device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp Source File

device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp Source File

Composable Kernel: device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp Source File
device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8#include <tuple>
9
24
25namespace ck {
26namespace tensor_operation {
27namespace device {
28
42template <typename GridwiseGemm,
43 typename GemmDesc,
44 GemmSpecialization GemmSpec,
45 typename ADataType,
46 typename BDataType,
47 typename DsDataType,
48 typename EDataType,
49 typename ALayout,
50 typename BLayout,
51 typename DsLayout,
52 typename ELayout,
53 index_t KPerBlock,
54 typename OffsettedBlockToCTileMap,
55 typename LocalBlock2ETileMap,
56 typename AElementwiseOperation,
57 typename BElementwiseOperation,
58 typename CDEElementwiseOperation,
59 BlockGemmPipelineScheduler BlkGemmPipeSched,
60 BlockGemmPipelineVersion BlkGemmPipelineVer>
61__global__ void
62#if CK_USE_LAUNCH_BOUNDS
64#endif
66 const index_t group_count,
67 const AElementwiseOperation a_element_op,
68 const BElementwiseOperation b_element_op,
69 const CDEElementwiseOperation cde_element_op)
70{
71#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
72 if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
73 {
74 constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
75 __shared__ uint8_t p_shared[shared_size];
76 __shared__ uint8_t p_shared1[shared_size];
77
78 const auto gemm_desc_ptr = reinterpret_cast<const GemmDesc*>(
80
81 constexpr auto NumDTensor = DsDataType::Size();
82 index_t tile_id = get_block_1d_id();
83 index_t tile_offset = 0;
84 index_t group_id = -1;
85 index_t group_offset = 0;
86 index_t grid_size_grp = 0;
87
88 index_t gemm_tile_id_start = 0;
89 index_t gemm_tile_id_end = 0;
90
91 index_t M = 0, N = 0, K = 0;
92
93 auto b2c_tile_map = OffsettedBlockToCTileMap(LocalBlock2ETileMap(1, 1), 1, 1);
94
95 do
96 {
97 // Find corresponding GEMM group for our tile
98 while(!(tile_id >= gemm_tile_id_start && tile_id < gemm_tile_id_end) &&
99 group_id < group_count)
100 {
101 group_offset += grid_size_grp;
102 group_id++;
103
104 if(group_id >= group_count)
105 return;
106
107 M = gemm_desc_ptr[group_id].M;
108 N = gemm_desc_ptr[group_id].N;
109 K = gemm_desc_ptr[group_id].K;
110
111 if(M == 0 || N == 0 || K == 0)
112 {
113 grid_size_grp = 0;
114 continue;
115 }
116
117 b2c_tile_map = OffsettedBlockToCTileMap(
118 LocalBlock2ETileMap(M, N, 4), group_offset, tile_offset);
119 grid_size_grp = b2c_tile_map.CalculateGridSize(M, N);
120
121 gemm_tile_id_start = group_offset;
122 gemm_tile_id_end = group_offset + grid_size_grp;
123 }
124
125 using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer());
126 DsGridPointer p_ds_grid;
127
128 static_for<0, NumDTensor, 1>{}([&](auto i) {
129 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
130 p_ds_grid(i) = static_cast<const DDataType*>(gemm_desc_ptr[group_id].p_ds_grid[i]);
131 });
132
133 static constexpr index_t kbatch = 1;
134 static constexpr index_t k_grain = kbatch * KPerBlock;
135 index_t K_split = (K + k_grain - 1) / k_grain * KPerBlock;
136
137 const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
138
139 // Update tile offset if we have moved within group
140 b2c_tile_map.UpdateTileOffset(tile_offset);
141
142 using Problem = typename GridwiseGemm::Problem;
143 auto problem = Problem(gemm_desc_ptr[group_id].M,
144 gemm_desc_ptr[group_id].N,
145 gemm_desc_ptr[group_id].K,
146 gemm_desc_ptr[group_id].StrideA,
147 gemm_desc_ptr[group_id].StrideB,
148 gemm_desc_ptr[group_id].StrideDs,
149 gemm_desc_ptr[group_id].StrideE,
150 kbatch);
151
152 if(has_main_k_block_loop)
153 {
154 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
155 BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
156 {
157 GridwiseGemm::template Run<OffsettedBlockToCTileMap,
158 true,
161 static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
162 static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
163 p_ds_grid,
164 static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
165 static_cast<void*>(p_shared),
166 problem,
167 a_element_op,
168 b_element_op,
169 cde_element_op,
170 b2c_tile_map);
171 }
172 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
173 {
174 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
175 {
176 GridwiseGemm::template Run<OffsettedBlockToCTileMap,
177 true,
180 static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
181 static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
182 p_ds_grid,
183 static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
184 static_cast<void*>(p_shared),
185 problem,
186 a_element_op,
187 b_element_op,
188 cde_element_op,
189 b2c_tile_map);
190 }
191 else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full)
192 {
193 GridwiseGemm::template Run<OffsettedBlockToCTileMap,
194 true,
197 static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
198 static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
199 p_ds_grid,
200 static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
201 static_cast<void*>(p_shared),
202 problem,
203 a_element_op,
204 b_element_op,
205 cde_element_op,
206 b2c_tile_map);
207 }
208
209 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
210 {
211 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
212 {
213 GridwiseGemm::template Run<OffsettedBlockToCTileMap,
214 true,
217 static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
218 static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
219 p_ds_grid,
220 static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
221 static_cast<void*>(p_shared),
222 problem,
223 a_element_op,
224 b_element_op,
225 cde_element_op,
226 b2c_tile_map);
227 }
228 }
229
230 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
231 {
232 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Three)
233 {
234 GridwiseGemm::template Run<OffsettedBlockToCTileMap,
235 true,
238 static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
239 static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
240 p_ds_grid,
241 static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
242 static_cast<void*>(p_shared),
243 problem,
244 a_element_op,
245 b_element_op,
246 cde_element_op,
247 b2c_tile_map);
248 }
249 }
250
251 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
252 {
253 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Four)
254 {
255 GridwiseGemm::template Run<OffsettedBlockToCTileMap,
256 true,
259 static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
260 static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
261 p_ds_grid,
262 static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
263 static_cast<void*>(p_shared),
264 problem,
265 a_element_op,
266 b_element_op,
267 cde_element_op,
268 b2c_tile_map);
269 }
270 }
271
272 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
273 {
274 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Five)
275 {
276 GridwiseGemm::template Run<OffsettedBlockToCTileMap,
277 true,
280 static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
281 static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
282 p_ds_grid,
283 static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
284 static_cast<void*>(p_shared),
285 problem,
286 a_element_op,
287 b_element_op,
288 cde_element_op,
289 b2c_tile_map);
290 }
291 }
292
293 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
294 {
295 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
296 {
297 GridwiseGemm::template Run<OffsettedBlockToCTileMap,
298 true,
301 static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
302 static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
303 p_ds_grid,
304 static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
305 static_cast<void*>(p_shared),
306 problem,
307 a_element_op,
308 b_element_op,
309 cde_element_op,
310 b2c_tile_map);
311 }
312 }
313
314 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
315 {
316 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Seven)
317 {
318 GridwiseGemm::template Run<OffsettedBlockToCTileMap,
319 true,
322 static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
323 static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
324 p_ds_grid,
325 static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
326 static_cast<void*>(p_shared),
327 problem,
328 a_element_op,
329 b_element_op,
330 cde_element_op,
331 b2c_tile_map);
332 }
333 }
334 }
335 // Tail number could be Odd or Even
336 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
337 {
338 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
339 {
340 GridwiseGemm::template Run_2Lds<OffsettedBlockToCTileMap,
341 true,
344 static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
345 static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
346 p_ds_grid,
347 static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
348 static_cast<void*>(p_shared),
349 static_cast<void*>(p_shared1),
350 problem,
351 a_element_op,
352 b_element_op,
353 cde_element_op,
354 b2c_tile_map);
355 }
356 else
357 {
358 GridwiseGemm::template Run_2Lds<OffsettedBlockToCTileMap,
359 true,
362 static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
363 static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
364 p_ds_grid,
365 static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
366 static_cast<void*>(p_shared),
367 static_cast<void*>(p_shared1),
368 problem,
369 a_element_op,
370 b_element_op,
371 cde_element_op,
372 b2c_tile_map);
373 }
374 }
375 }
376 else
377 {
378 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
379 {
380 GridwiseGemm::template Run<OffsettedBlockToCTileMap,
381 false,
384 static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
385 static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
386 p_ds_grid,
387 static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
388 static_cast<void*>(p_shared),
389 problem,
390 a_element_op,
391 b_element_op,
392 cde_element_op,
393 b2c_tile_map);
394 }
395 }
396
397 tile_id += get_grid_size();
398 tile_offset += get_grid_size();
399
400 } while(group_id < group_count);
401 }
402#else
403 ignore = gemm_descs_const;
404 ignore = group_count;
405 ignore = a_element_op;
406 ignore = b_element_op;
407 ignore = cde_element_op;
408#endif // end of if (defined(__gfx9__))
409}
410
411template <typename ALayout,
412 typename BLayout,
413 typename DsLayout,
414 typename ELayout,
415 typename ADataType,
416 typename BDataType,
417 typename AccDataType,
418 typename CShuffleDataType,
419 typename DsDataType,
420 typename EDataType,
421 typename AElementwiseOperation,
422 typename BElementwiseOperation,
423 typename CDEElementwiseOperation,
424 GemmSpecialization GemmSpec,
425 ck::index_t NumGemmKPrefetchStage,
426 ck::index_t BlockSize,
427 ck::index_t MPerBlock,
428 ck::index_t NPerBlock,
429 ck::index_t KPerBlock,
430 ck::index_t AK1,
431 ck::index_t BK1,
432 ck::index_t MPerXDL,
433 ck::index_t NPerXDL,
434 ck::index_t MXdlPerWave,
435 ck::index_t NXdlPerWave,
436 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
437 typename ABlockTransferThreadClusterArrangeOrder,
438 typename ABlockTransferSrcAccessOrder,
439 index_t ABlockTransferSrcVectorDim,
440 index_t ABlockTransferSrcScalarPerVector,
441 index_t ABlockTransferDstScalarPerVector_AK1,
442 index_t ABlockLdsExtraM,
443 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
444 typename BBlockTransferThreadClusterArrangeOrder,
445 typename BBlockTransferSrcAccessOrder,
446 index_t BBlockTransferSrcVectorDim,
447 index_t BBlockTransferSrcScalarPerVector,
448 index_t BBlockTransferDstScalarPerVector_BK1,
449 index_t BBlockLdsExtraN,
450 index_t CShuffleMXdlPerWavePerShuffle,
451 index_t CShuffleNXdlPerWavePerShuffle,
452 typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
453 typename CDEShuffleBlockTransferScalarPerVectors,
456 typename ComputeTypeA = EDataType,
457 typename ComputeTypeB = ComputeTypeA>
458
460 : public DeviceGroupedGemmTileLoop<ALayout,
461 BLayout,
462 DsLayout,
463 ELayout,
464 ADataType,
465 BDataType,
466 DsDataType,
467 EDataType,
468 AElementwiseOperation,
469 BElementwiseOperation,
470 CDEElementwiseOperation>
471{
474 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
475 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
476 static constexpr index_t NumDTensor = DsDataType::Size();
477
478 template <index_t NXdlPerWave_>
480 ALayout,
481 BLayout,
482 DsLayout,
483 ELayout,
484 ADataType,
485 BDataType,
486 AccDataType,
487 CShuffleDataType,
488 DsDataType,
489 EDataType,
490 AElementwiseOperation,
491 BElementwiseOperation,
492 CDEElementwiseOperation,
493 GemmSpec,
494 BlockSize,
495 MPerBlock,
496 NPerBlock,
497 KPerBlock,
498 AK1,
499 BK1,
500 MPerXDL,
501 NPerXDL,
502 MXdlPerWave,
503 NXdlPerWave_,
504 ABlockTransferThreadClusterLengths_AK0_M_AK1,
505 ABlockTransferThreadClusterArrangeOrder,
506 ABlockTransferSrcAccessOrder,
507 ABlockTransferSrcVectorDim,
508 ABlockTransferSrcScalarPerVector,
509 ABlockTransferDstScalarPerVector_AK1,
510 false, // AThreadTransferSrcResetCoordinateAfterRun,
511 ABlockLdsExtraM,
512 BBlockTransferThreadClusterLengths_BK0_N_BK1,
513 BBlockTransferThreadClusterArrangeOrder,
514 BBlockTransferSrcAccessOrder,
515 BBlockTransferSrcVectorDim,
516 BBlockTransferSrcScalarPerVector,
517 BBlockTransferDstScalarPerVector_BK1,
518 false, // BThreadTransferSrcResetCoordinateAfterRun,
519 BBlockLdsExtraN,
520 CShuffleMXdlPerWavePerShuffle,
521 CShuffleNXdlPerWavePerShuffle,
522 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
523 CDEShuffleBlockTransferScalarPerVectors,
524 BlkGemmPipeSched,
525 BlkGemmPipelineVer,
526 ComputeTypeA,
527 ComputeTypeB>;
530
534
535 // Argument
536 struct Argument : public BaseArgument
537 {
538 Argument(std::vector<const void*>& /* p_As */,
539 std::vector<const void*>& /* p_Bs */,
540 std::vector<std::array<const void*, NumDTensor>>& /* p_Ds */,
541 std::vector<void*>& /* p_Es */,
542 const std::vector<GemmDesc>& gemm_descs,
543 AElementwiseOperation a_element_op,
544 BElementwiseOperation b_element_op,
545 CDEElementwiseOperation cde_element_op,
546 int occupancy_num_blocks,
547 int gpu_cu_count)
548 : group_count_{static_cast<index_t>(gemm_descs.size())},
549 occupancy_num_blocks_{occupancy_num_blocks},
550 gpu_cu_count_{gpu_cu_count},
551 gemm_descs_{gemm_descs},
552 a_element_op_{a_element_op},
553 b_element_op_{b_element_op},
554 cde_element_op_{cde_element_op},
555 tile_count_{0}
556 {
557 for(const auto& desc : gemm_descs)
558 {
559 const auto M = desc.M_;
560 const auto N = desc.N_;
561 const auto b2c_tile_map = Block2ETileMap(M, N);
562 tile_count_ += b2c_tile_map.CalculateGridSize(M, N);
563 }
564 }
565
567 const void* p_dev_gemm_args_;
570 const std::vector<GemmDesc>& gemm_descs_;
571 AElementwiseOperation a_element_op_;
572 BElementwiseOperation b_element_op_;
573 CDEElementwiseOperation cde_element_op_;
575 };
576
578 {
579 // The oversubscription factor for the number of blocks that can simultaneously reside on
580 // GPU.
581 static constexpr int BLOCK_SUBSCRIPTION_FACTOR = 1;
582 // static constexpr int BLOCK_WAVES = BlockSize / get_warp_size();
583 static constexpr int CU_SIMDS = 4;
584 // Assume we want to have at most 2 waves per SIMD
585 // static constexpr int CU_BLOCKS = math::integer_divide_floor(2 * CU_SIMDS, BLOCK_WAVES);
586 static int GetCuBlocks()
587 {
588 int BLOCK_WAVES = BlockSize / get_warp_size();
589 return math::integer_divide_floor(2 * CU_SIMDS, BLOCK_WAVES);
590 }
591 };
592
593 // Invoker
594 struct Invoker : public BaseInvoker
595 {
609 template <typename GridwiseGemm>
610 float Run(const Argument& arg,
611 const void* dev_gemm_args,
612 const StreamConfig& stream_config = StreamConfig{})
613 {
614 if(dev_gemm_args == nullptr)
615 {
616 std::ostringstream err;
617 err << "The gemm arguments device buffer is not allocated!" << " In " << __FILE__
618 << ":" << __LINE__ << ", in function: " << __func__;
619 throw std::runtime_error(err.str());
620 }
621
622 float ave_time = 0;
623 ave_time = DispatchKernel<GridwiseGemm>(arg, dev_gemm_args, stream_config);
624
625 return ave_time;
626 }
627
641 template <typename GridwiseGemm>
642 float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
643 {
644 if(arg.p_dev_gemm_args_ == nullptr)
645 {
646 std::ostringstream err;
647 err << "The gemm arguments device buffer is not allocated!" << " In " << __FILE__
648 << ":" << __LINE__ << ", in function: " << __func__;
649 throw std::runtime_error(err.str());
650 }
651
652 return Run<GridwiseGemm>(arg, arg.p_dev_gemm_args_, stream_config);
653 }
654
656
657 float Run(const BaseArgument* p_arg,
658 const StreamConfig& stream_config = StreamConfig{}) override
659 {
660 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
661 }
662
663 private:
664 template <typename GridwiseGemm>
665 float DispatchKernel(const Argument& arg,
666 const void* dev_gemm_args,
667 const StreamConfig& stream_config) const
668 {
669 const auto kernel = kernel_grouped_gemm_multiple_d_xdl<GridwiseGemm,
671 GemmSpec,
672 ADataType,
673 BDataType,
674 DsDataType,
675 EDataType,
676 ALayout,
677 BLayout,
678 DsLayout,
679 ELayout,
680 KPerBlock,
683 AElementwiseOperation,
684 BElementwiseOperation,
685 CDEElementwiseOperation,
686 BlkGemmPipeSched,
687 BlkGemmPipelineVer>;
688 return LaunchKernel(kernel, arg, dev_gemm_args, stream_config);
689 }
690
691 template <typename KernelFunction>
692 int CalculateMaxOccupancyGridSize(const KernelFunction& kernel,
693 const StreamConfig& stream_config) const
694 {
695 // Calculate max number of workgroups that can simultaneously reside on the CU.
696 int occ_num_blocks = 0;
697 size_t dyn_shared_mem_per_blk = 0;
698 hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
699 &occ_num_blocks, kernel, BlockSize, dyn_shared_mem_per_blk));
700
701 int cu_count = getAvailableComputeUnitCount(stream_config);
702
703 if(stream_config.log_level_ > 0)
704 {
705 std::cout << "MaxActiveBlocksPerCU: " << occ_num_blocks
706 << ", available CUs count: " << cu_count << ", occup. grid size: "
707 << ck::math::min(occ_num_blocks, KernelConfig::GetCuBlocks()) * cu_count
708 << std::endl;
709 }
710
711 return cu_count * ck::math::min(occ_num_blocks, KernelConfig::GetCuBlocks());
712 }
713
714 template <typename KernelFunction>
715 float LaunchKernel(const KernelFunction& kernel,
716 const Argument& arg,
717 const void* dev_gemm_args,
718 const StreamConfig& stream_config) const
719 {
720 int grid_size = CalculateMaxOccupancyGridSize(kernel, stream_config);
721
722 if(stream_config.log_level_ > 0)
723 {
724 std::cout << "grid_size: " << grid_size << " tile_count: " << arg.tile_count_
725 << std::endl;
726 }
727
728 // run multiple kernels
729
730 return launch_and_time_kernel(stream_config,
731 kernel,
732 dim3(grid_size),
733 dim3(BlockSize),
734 0,
736 arg.group_count_,
737 arg.a_element_op_,
738 arg.b_element_op_,
739 arg.cde_element_op_);
740 }
741 };
742
// Compile-time validity hook for this instance's template parameters.
// TODO: properly implement this check; every configuration is currently accepted.
static constexpr bool IsValidCompilationParameter() { return !false; }
748
749 static bool IsSupportedArgument(const Argument& arg)
750 {
752 {
753 return false;
754 }
755 bool supported = true;
756
757 constexpr index_t k_batch = 1;
758 bool isWave64 = get_warp_size() == 64;
759 for(index_t i = 0; i < arg.group_count_; ++i)
760 {
761 std::array<const void*, NumDTensor> placeholder_p_ds_grid{};
762 std::array<index_t, NumDTensor> stride_Ds;
763 std::copy_n(arg.gemm_descs_[i].stride_Ds_.begin(), NumDTensor, stride_Ds.begin());
764 if((arg.gemm_descs_[i].K_ % AK1 != 0 || arg.gemm_descs_[i].K_ % BK1 != 0) &&
765 !(GemmSpec == GemmSpecialization::MKPadding ||
766 GemmSpec == GemmSpecialization::NKPadding ||
767 GemmSpec == GemmSpecialization::MNKPadding ||
768 GemmSpec == GemmSpecialization::KPadding))
769 {
770 return false;
771 }
772 if(isWave64)
773 {
774 if constexpr(NXdlPerWave64 > 0)
775 {
776 using GridArg = typename GridwiseGemm64::Argument;
777 GridArg gridwise_arg(nullptr, // p_a_grid,
778 nullptr, // p_b_grid,
779 placeholder_p_ds_grid, // p_ds_grid,
780 nullptr, // p_e_grid ,
781 arg.gemm_descs_[i].M_,
782 arg.gemm_descs_[i].N_,
783 arg.gemm_descs_[i].K_,
784 arg.gemm_descs_[i].stride_A_,
785 arg.gemm_descs_[i].stride_B_,
786 stride_Ds,
787 arg.gemm_descs_[i].stride_C_,
788 k_batch,
789 arg.a_element_op_,
790 arg.b_element_op_,
791 arg.cde_element_op_);
792
793 supported = supported && GridwiseGemm64::CheckValidity(gridwise_arg);
794 }
795 else
796 {
797 supported = false;
798 }
799 }
800 else
801 {
802 if constexpr(NXdlPerWave32 > 0)
803 {
804 using GridArg = typename GridwiseGemm32::Argument;
805 GridArg gridwise_arg(nullptr, // p_a_grid,
806 nullptr, // p_b_grid,
807 placeholder_p_ds_grid, // p_ds_grid,
808 nullptr, // p_e_grid ,
809 arg.gemm_descs_[i].M_,
810 arg.gemm_descs_[i].N_,
811 arg.gemm_descs_[i].K_,
812 arg.gemm_descs_[i].stride_A_,
813 arg.gemm_descs_[i].stride_B_,
814 stride_Ds,
815 arg.gemm_descs_[i].stride_C_,
816 k_batch,
817 arg.a_element_op_,
818 arg.b_element_op_,
819 arg.cde_element_op_);
820
821 supported = supported && GridwiseGemm32::CheckValidity(gridwise_arg);
822 }
823 else
824 {
825 supported = false;
826 }
827 }
828 }
829
830 return supported;
831 }
832
833 bool IsSupportedArgument(const BaseArgument* p_arg) override
834 {
835 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
836 }
837
839 {
840 int occupancy = 0;
841 if(get_warp_size() == 64)
842 {
843 if constexpr(NXdlPerWave64 > 0)
844 {
847 GemmSpec,
848 ADataType,
849 BDataType,
850 DsDataType,
851 EDataType,
852 ALayout,
853 BLayout,
854 DsLayout,
855 ELayout,
856 KPerBlock,
859 AElementwiseOperation,
860 BElementwiseOperation,
861 CDEElementwiseOperation,
862 BlkGemmPipeSched,
863 BlkGemmPipelineVer>;
865 hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
866 }
867 }
868 else
869 {
870
871 if constexpr(NXdlPerWave32 > 0)
872 {
875 GemmSpec,
876 ADataType,
877 BDataType,
878 DsDataType,
879 EDataType,
880 ALayout,
881 BLayout,
882 DsLayout,
883 ELayout,
884 KPerBlock,
887 AElementwiseOperation,
888 BElementwiseOperation,
889 CDEElementwiseOperation,
890 BlkGemmPipeSched,
891 BlkGemmPipelineVer>;
893 hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
894 }
895 }
896 return occupancy;
897 }
898
899 static auto MakeArgument(std::vector<const void*>& p_As,
900 std::vector<const void*>& p_Bs,
901 std::vector<std::array<const void*, NumDTensor>>& p_Ds,
902 std::vector<void*>& p_Es,
903 std::vector<GemmDesc>& gemm_descs,
904 AElementwiseOperation a_elementwise_op,
905 BElementwiseOperation b_elementwise_op,
906 CDEElementwiseOperation cde_elementwise_op)
907 {
908 int occupancy = GetKernelOccupancy();
909 int num_cu;
910
911 hipDeviceProp_t dev_prop;
912 hipDevice_t dev;
913 hip_check_error(hipGetDevice(&dev));
914 hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
915 num_cu = dev_prop.multiProcessorCount;
916
917 return Argument{p_As,
918 p_Bs,
919 p_Ds,
920 p_Es,
921 gemm_descs,
922 a_elementwise_op,
923 b_elementwise_op,
924 cde_elementwise_op,
925 occupancy,
926 num_cu};
927 }
928
929 std::unique_ptr<BaseArgument>
930 MakeArgumentPointer(std::vector<const void*>& p_As,
931 std::vector<const void*>& p_Bs,
932 std::vector<std::array<const void*, NumDTensor>>& p_Ds,
933 std::vector<void*>& p_Es,
934 std::vector<GemmDesc>& gemm_descs,
935 AElementwiseOperation a_elementwise_op,
936 BElementwiseOperation b_elementwise_op,
937 CDEElementwiseOperation cde_elementwise_op) override
938 {
939 int occupancy = GetKernelOccupancy();
940 int num_cu;
941
942 hipDeviceProp_t dev_prop;
943 hipDevice_t dev;
944 hip_check_error(hipGetDevice(&dev));
945 hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
946 num_cu = dev_prop.multiProcessorCount;
947
948 return std::make_unique<Argument>(p_As,
949 p_Bs,
950 p_Ds,
951 p_Es,
952 gemm_descs,
953 a_elementwise_op,
954 b_elementwise_op,
955 cde_elementwise_op,
956 occupancy,
957 num_cu);
958 }
959
960 static auto MakeInvoker() { return Invoker{}; }
961
962 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
963 {
964 return std::make_unique<Invoker>(Invoker{});
965 }
966
967 std::string GetTypeString() const override
968 {
969 auto str = std::ostringstream();
970
971 std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
974
975 std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
981
982 // clang-format off
983 str << "DeviceGroupedGemmMultipleDXdlCShuffleTileLoop"
984 << "<"
985 << std::string(ALayout::name)[0] << ","
986 << std::string(BLayout::name)[0] << ","
987 << std::string(ELayout::name)[0] << ","
988 << BlockSize << ", "
989 << MPerBlock << ", "
990 << NPerBlock << ", "
991 << KPerBlock << ", "
992 << AK1 << ", "
993 << BK1 << ", "
994 << MPerXDL << ", "
995 << NPerXDL << ", "
996 << MXdlPerWave << ", "
997 << NXdlPerWave << ", "
998 << ABlockTransferSrcScalarPerVector << ", "
999 << BBlockTransferSrcScalarPerVector << ", "
1000 << CShuffleMXdlPerWavePerShuffle << ", "
1001 << CShuffleNXdlPerWavePerShuffle << ", "
1002 << getGemmSpecializationString(GemmSpec) << ", "
1003 << "BlkGemmPipelineScheduler: "
1004 << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
1005 << "BlkGemmPipelineVersion: "
1006 << BlkGemmPipelineVersionToString[BlkGemmPipelineVer]
1007 << ">";
1008 // clang-format on
1009
1010 return str.str();
1011 }
1012
1014 void* p_dev_kernel_args,
1015 const void* p_host_kernel_args) const
1016 {
1017 arg.p_dev_gemm_args_ = p_dev_kernel_args;
1018 hip_check_error(hipMemcpyAsync(p_dev_kernel_args,
1019 p_host_kernel_args,
1021 hipMemcpyHostToDevice));
1022 }
1023
1025 void* p_dev_kernel_args,
1026 const void* p_host_kernel_args) const override
1027 {
1028 return SetDeviceKernelArgs(
1029 *dynamic_cast<Argument*>(p_arg), p_dev_kernel_args, p_host_kernel_args);
1030 }
1031
1032 void SetDeviceKernelArgs(Argument& arg, void* p_dev_kernel_args) const
1033 {
1034 arg.p_dev_gemm_args_ = p_dev_kernel_args;
1035 }
1036
1037 virtual void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override
1038 {
1039 return SetDeviceKernelArgs(*dynamic_cast<Argument*>(p_arg), p_dev_kernel_args);
1040 }
1041
1042 size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
1043 {
1044 return dynamic_cast<const Argument*>(p_arg)->group_count_ * sizeof(KernelArguments);
1045 }
1046};
1047
1048} // namespace device
1049} // namespace tensor_operation
1050} // namespace ck
#define CK_CONSTANT_ADDRESS_SPACE
Definition ck.hpp:23
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition device_base.hpp:94
void hip_check_error(hipError_t x)
Definition host_utility/hip_check_error.hpp:10
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr auto integer_divide_floor(X x, Y y)
Definition utility/math.hpp:66
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr T min(T x)
Definition utility/math.hpp:116
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ MNKPadding
Definition gemm_specialization.hpp:20
@ NKPadding
Definition gemm_specialization.hpp:19
__global__ void kernel_grouped_gemm_multiple_d_xdl(const void CK_CONSTANT_ADDRESS_SPACE *gemm_descs_const, const index_t group_count, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op)
Entry point kernel for device-wide Grouped GEMM operation.
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:65
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__host__ __device__ T CK_CONSTANT_ADDRESS_SPACE * cast_pointer_to_constant_address_space(T *p)
Definition amd_address_space.hpp:35
__device__ index_t get_grid_size()
Definition get_id.hpp:49
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v2
Definition blkgemmpipe_scheduler.hpp:15
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v5
Definition blkgemmpipe_scheduler.hpp:18
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
@ One
Definition blkgemmpipe_scheduler.hpp:37
@ Seven
Definition blkgemmpipe_scheduler.hpp:43
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
@ Four
Definition blkgemmpipe_scheduler.hpp:40
@ Two
Definition blkgemmpipe_scheduler.hpp:38
@ Full
Definition blkgemmpipe_scheduler.hpp:49
@ Three
Definition blkgemmpipe_scheduler.hpp:39
@ Five
Definition blkgemmpipe_scheduler.hpp:41
@ Six
Definition blkgemmpipe_scheduler.hpp:42
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
@ Interwave
Definition blkgemmpipe_scheduler.hpp:27
__device__ T * cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE *p)
Definition amd_address_space.hpp:24
unsigned char uint8_t
Definition stdint.h:124
Definition ck/stream_config.hpp:10
int log_level_
Definition ck/stream_config.hpp:13
Definition block_to_ctile_map.hpp:271
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:157
Definition block_to_ctile_map.hpp:920
Definition block_to_ctile_map.hpp:872
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:537
CDEElementwiseOperation cde_element_op_
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:573
const std::vector< GemmDesc > & gemm_descs_
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:570
index_t tile_count_
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:574
int gpu_cu_count_
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:569
BElementwiseOperation b_element_op_
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:572
const void * p_dev_gemm_args_
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:567
index_t group_count_
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:566
int occupancy_num_blocks_
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:568
Argument(std::vector< const void * > &, std::vector< const void * > &, std::vector< std::array< const void *, NumDTensor > > &, std::vector< void * > &, const std::vector< GemmDesc > &gemm_descs, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op, int occupancy_num_blocks, int gpu_cu_count)
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:538
AElementwiseOperation a_element_op_
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:571
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:595
float Run(const Argument &arg, const void *dev_gemm_args, const StreamConfig &stream_config=StreamConfig{})
Launch Grouped Gemm kernel.
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:610
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:657
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Launch Grouped Gemm kernel.
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:642
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:578
static constexpr int BLOCK_SUBSCRIPTION_FACTOR
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:581
static int GetCuBlocks()
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:586
static constexpr int CU_SIMDS
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:583
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:471
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:962
size_t GetDeviceKernelArgSize(const BaseArgument *p_arg) const override
Gets the device kernel argument size.
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:1042
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:833
BlockToCTileMap_Grouped_M00_N0_M01Adapt< 8, MPerBlock, NPerBlock > Block2ETileMap
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:532
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:474
static bool IsSupportedArgument(const Argument &arg)
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:749
virtual void SetDeviceKernelArgs(BaseArgument *p_arg, void *p_dev_kernel_args, const void *p_host_kernel_args) const override
Sets the device kernel arguments pointer and may copy data to device.
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:1024
static auto MakeInvoker()
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:960
void SetDeviceKernelArgs(Argument &arg, void *p_dev_kernel_args) const
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:1032
static constexpr bool IsValidCompilationParameter()
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:743
std::unique_ptr< BaseArgument > MakeArgumentPointer(std::vector< const void * > &p_As, std::vector< const void * > &p_Bs, std::vector< std::array< const void *, NumDTensor > > &p_Ds, std::vector< void * > &p_Es, std::vector< GemmDesc > &gemm_descs, AElementwiseOperation a_elementwise_op, BElementwiseOperation b_elementwise_op, CDEElementwiseOperation cde_elementwise_op) override
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:930
static constexpr index_t NumDTensor
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:476
virtual void SetDeviceKernelArgs(BaseArgument *p_arg, void *p_dev_kernel_args) const override
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:1037
static int GetKernelOccupancy()
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:838
GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB > GridwiseGemmBase
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:479
OffsettedBlockToCTileMap2< Block2ETileMap > OffsettedLocalBlock2ETileMap
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:533
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop DeviceOp
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:472
std::string GetTypeString() const override
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:967
GroupedGemmKernelArgument< NumDTensor > KernelArguments
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:531
static constexpr auto NXdlPerWave32
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:475
static auto MakeArgument(std::vector< const void * > &p_As, std::vector< const void * > &p_Bs, std::vector< std::array< const void *, NumDTensor > > &p_Ds, std::vector< void * > &p_Es, std::vector< GemmDesc > &gemm_descs, AElementwiseOperation a_elementwise_op, BElementwiseOperation b_elementwise_op, CDEElementwiseOperation cde_elementwise_op)
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:899
void SetDeviceKernelArgs(Argument &arg, void *p_dev_kernel_args, const void *p_host_kernel_args) const
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:1013
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:529
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp:528
Grouped GEMM kernel using output Tile Looping algorithm.
Definition device_grouped_gemm_tile_loop.hpp:43
Definition device_grouped_gemm.hpp:80
Structure representing single GEMM problem arguments.
Definition device_grouped_gemm.hpp:29