device_gemm_xdl_cshuffle_v3_b_scale.hpp Source File

device_gemm_xdl_cshuffle_v3_b_scale.hpp Source File#

Composable Kernel: device_gemm_xdl_cshuffle_v3_b_scale.hpp Source File
device_gemm_xdl_cshuffle_v3_b_scale.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
10
20
21namespace ck {
22namespace tensor_operation {
23namespace device {
24
25template <typename ALayout,
26 typename BLayout,
27 typename CLayout,
28 typename ADataType,
29 typename BDataType,
30 typename BScaleDataType,
31 typename CDataType,
32 typename GemmAccDataType,
33 typename CShuffleDataType,
34 typename AElementwiseOperation,
35 typename BElementwiseOperation,
36 typename CElementwiseOperation,
37 GemmSpecialization GemmSpec,
38 index_t BlockSize,
39 index_t ScaleBlockN, // scale block for N
40 index_t ScaleBlockK, // scale block for K
41 index_t MPerBlock,
42 index_t NPerBlock,
43 index_t KPerBlock,
44 index_t AK1,
45 index_t BK1,
46 index_t MPerXDL,
47 index_t NPerXDL,
48 index_t MXdlPerWave,
49 index_t NXdlPerWave,
50 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
51 typename ABlockTransferThreadClusterArrangeOrder,
52 typename ABlockTransferSrcAccessOrder,
53 index_t ABlockTransferSrcVectorDim,
54 index_t ABlockTransferSrcScalarPerVector,
55 index_t ABlockTransferDstScalarPerVector_AK1,
56 bool ABlockLdsExtraM,
57 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
58 typename BBlockTransferThreadClusterArrangeOrder,
59 typename BBlockTransferSrcAccessOrder,
60 index_t BBlockTransferSrcVectorDim,
61 index_t BBlockTransferSrcScalarPerVector,
62 index_t BBlockTransferDstScalarPerVector_BK1,
63 bool BBlockLdsExtraN,
64 index_t CShuffleMXdlPerWavePerShuffle,
65 index_t CShuffleNXdlPerWavePerShuffle,
66 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
67 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
70 typename ComputeTypeA = CDataType,
71 typename ComputeTypeB = ComputeTypeA,
72 bool PermuteA = false,
73 bool PermuteB = false>
74struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
75 BLayout,
76 CLayout,
77 ADataType,
78 BDataType,
79 BScaleDataType,
80 CDataType,
81 ScaleBlockN,
82 ScaleBlockK,
83 AElementwiseOperation,
84 BElementwiseOperation,
85 CElementwiseOperation>
86{
88 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
89 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
90
91 // GridwiseGemm
92 template <index_t NXdlPerWave_>
94 ALayout,
95 BLayout,
96 CLayout,
97 ADataType,
98 BDataType,
99 GemmAccDataType,
100 CShuffleDataType,
101 CDataType,
102 AElementwiseOperation,
103 BElementwiseOperation,
104 CElementwiseOperation,
105 GemmSpec,
106 BlockSize,
107 ScaleBlockN,
108 ScaleBlockK,
109 MPerBlock,
110 NPerBlock,
111 KPerBlock,
112 AK1,
113 BK1,
114 MPerXDL,
115 NPerXDL,
116 MXdlPerWave,
117 NXdlPerWave_,
118 ABlockTransferThreadClusterLengths_AK0_M_AK1,
119 ABlockTransferThreadClusterArrangeOrder,
120 ABlockTransferSrcAccessOrder,
121 ABlockTransferSrcVectorDim,
122 ABlockTransferSrcScalarPerVector,
123 ABlockTransferDstScalarPerVector_AK1,
124 false,
125 ABlockLdsExtraM,
126 BBlockTransferThreadClusterLengths_BK0_N_BK1,
127 BBlockTransferThreadClusterArrangeOrder,
128 BBlockTransferSrcAccessOrder,
129 BBlockTransferSrcVectorDim,
130 BBlockTransferSrcScalarPerVector,
131 BBlockTransferDstScalarPerVector_BK1,
132 false,
133 BBlockLdsExtraN,
134 CShuffleMXdlPerWavePerShuffle,
135 CShuffleNXdlPerWavePerShuffle,
136 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
137 CShuffleBlockTransferScalarPerVector_NPerBlock,
138 BlkGemmPipeSched,
139 BlkGemmPipelineVer,
140 ComputeTypeA,
141 ComputeTypeB,
142 PermuteA,
143 PermuteB>;
146
147 using Argument = typename GridwiseGemm64::Argument;
148
149 static constexpr index_t APackedSize = []() {
151 return 2;
152 else
153 return 1;
154 }();
155
156 static constexpr index_t BPackedSize = []() {
158 return 2;
159 else
160 return 1;
161 }();
162
163 // Invoker
164 struct Invoker : public BaseInvoker
165 {
166 template <typename GridwiseGemm>
167 float RunImp(const typename GridwiseGemm::Argument& arg,
168 const StreamConfig& stream_config = StreamConfig{})
169 {
170 if(stream_config.log_level_ > 0)
171 {
172 arg.Print();
173 }
174
175 if(!GridwiseGemm::CheckValidity(arg))
176 {
177 throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
178 }
179
180 index_t gdx, gdy, gdz;
181 std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
182
183 float ave_time = 0;
184
185 index_t k_grain = arg.KBatch * KPerBlock;
186 index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
187
188 const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
189
190 const auto Run = [&](const auto& kernel) {
191 if(stream_config.flush_cache)
192 {
193 auto arg_ = arg;
194
195 const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
196 arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
197 const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
198 arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
199
200 auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() *
201 sizeof(ADataType) / APackedSize;
202 auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
203 sizeof(BDataType) / BPackedSize;
204
206 arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer);
207 rotating_mem.Print();
208
209 auto run_flush_cache = [&]() {
210 // flush icache
212 // rotating mem
213 rotating_mem.Next();
214 // clear c mem
215 if(arg_.KBatch > 1)
216 hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
217 0,
218 arg_.M * arg_.N * sizeof(CDataType),
219 stream_config.stream_id_));
220 };
221
223 stream_config,
224 run_flush_cache,
225 kernel,
226 dim3(gdx, gdy, gdz),
227 dim3(BlockSize),
228 0,
229 arg_);
230 }
231 else
232 {
233 if(arg.KBatch > 1)
234 hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
235 0,
236 arg.M * arg.N * sizeof(CDataType),
237 stream_config.stream_id_));
238
239 ave_time = launch_and_time_kernel(
240 stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
241 }
242 };
243
244 constexpr index_t minimum_occupancy =
245 BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave
246 ? (BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 &&
247 MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) <= 128 * 128 * 64 * 2)
248 ? 2
249 : 1
250 : 2;
251
252 if(has_main_k_block_loop)
253 {
254 // Tail number always full
255 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
256 BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
257 {
258 if(arg.KBatch > 1)
259 {
260 const auto kernel =
261 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
262 true,
264 minimum_occupancy>;
265 Run(kernel);
266 }
267 else
268 {
269 const auto kernel =
270 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
271 true,
273 minimum_occupancy>;
274 Run(kernel);
275 }
276 }
277 // Tail number could be One to Seven
278 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
279 {
280 if(arg.KBatch > 1)
281 {
282 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
283 {
284 const auto kernel =
285 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
286 true,
288 minimum_occupancy,
290 Run(kernel);
291 }
292 else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
294 {
295 const auto kernel =
296 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
297 true,
299 minimum_occupancy,
301 Run(kernel);
302 }
303
304 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
305 {
306 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
307 {
308 const auto kernel = kernel_gemm_xdl_cshuffle_v3<
309 GridwiseGemm,
310 true,
312 minimum_occupancy,
314 Run(kernel);
315 }
316 }
317
318 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
319 {
320 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
322 {
323 const auto kernel = kernel_gemm_xdl_cshuffle_v3<
324 GridwiseGemm,
325 true,
327 minimum_occupancy,
329 Run(kernel);
330 }
331 }
332
333 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
334 {
335 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
337 {
338 const auto kernel = kernel_gemm_xdl_cshuffle_v3<
339 GridwiseGemm,
340 true,
342 minimum_occupancy,
344 Run(kernel);
345 }
346 }
347
348 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
349 {
350 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
352 {
353 const auto kernel = kernel_gemm_xdl_cshuffle_v3<
354 GridwiseGemm,
355 true,
357 minimum_occupancy,
359 Run(kernel);
360 }
361 }
362
363 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
364 {
365 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
366 {
367 const auto kernel = kernel_gemm_xdl_cshuffle_v3<
368 GridwiseGemm,
369 true,
371 minimum_occupancy,
373 Run(kernel);
374 }
375 }
376
377 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
378 {
379 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
381 {
382 const auto kernel = kernel_gemm_xdl_cshuffle_v3<
383 GridwiseGemm,
384 true,
386 minimum_occupancy,
388 Run(kernel);
389 }
390 }
391 }
392 else
393 {
394 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
395 {
396 const auto kernel =
397 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
398 true,
400 minimum_occupancy,
402 Run(kernel);
403 }
404 else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
406 {
407 const auto kernel =
408 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
409 true,
411 minimum_occupancy,
413 Run(kernel);
414 }
415
416 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
417 {
418 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
419 {
420 const auto kernel =
421 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
422 true,
424 minimum_occupancy,
426 Run(kernel);
427 }
428 }
429
430 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
431 {
432 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
434 {
435 const auto kernel =
436 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
437 true,
439 minimum_occupancy,
441 Run(kernel);
442 }
443 }
444
445 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
446 {
447 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
449 {
450 const auto kernel =
451 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
452 true,
454 minimum_occupancy,
456 Run(kernel);
457 }
458 }
459
460 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
461 {
462 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
464 {
465 const auto kernel =
466 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
467 true,
469 minimum_occupancy,
471 Run(kernel);
472 }
473 }
474
475 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
476 {
477 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
478 {
479 const auto kernel =
480 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
481 true,
483 minimum_occupancy,
485 Run(kernel);
486 }
487 }
488
489 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
490 {
491 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
493 {
494 const auto kernel =
495 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
496 true,
498 minimum_occupancy,
500 Run(kernel);
501 }
502 }
503 }
504 }
505 // Tail number could be Odd or Even
506 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
507 {
508 if(arg.KBatch > 1)
509 {
510 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
511 {
512 const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
513 GridwiseGemm,
514 true,
516 minimum_occupancy,
518 Run(kernel);
519 }
520 else
521 {
522 const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
523 GridwiseGemm,
524 true,
526 minimum_occupancy,
528 Run(kernel);
529 }
530 }
531 else
532 {
533 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
534 {
535 const auto kernel =
537 true,
539 minimum_occupancy,
541 Run(kernel);
542 }
543 else
544 {
545 const auto kernel =
547 true,
549 minimum_occupancy,
551 Run(kernel);
552 }
553 }
554 }
555 else
556 {
557 if(arg.KBatch > 1)
558 {
559 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
560 {
561 const auto kernel =
562 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
563 true,
565 minimum_occupancy,
567 Run(kernel);
568 }
569 else
570 {
571 const auto kernel =
572 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
573 true,
575 minimum_occupancy,
577 Run(kernel);
578 }
579 }
580 else
581 {
582 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
583 {
584 const auto kernel =
585 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
586 true,
588 minimum_occupancy,
590 Run(kernel);
591 }
592 else
593 {
594 const auto kernel =
595 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
596 true,
598 minimum_occupancy,
600 Run(kernel);
601 }
602 }
603 }
604 }
605 else
606 {
607 // Tail number always 1
608 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
609 {
610 if(arg.KBatch > 1)
611 {
612 const auto kernel =
613 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
614 false,
616 minimum_occupancy>;
617 Run(kernel);
618 }
619 else
620 {
621 const auto kernel =
622 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
623 false,
625 minimum_occupancy>;
626 Run(kernel);
627 }
628 }
629 }
630
631 return ave_time;
632 }
633
635 // polymorphic: type-erased entry point of the invoker.
// Downcasts p_arg to `const Argument*` with dynamic_cast and forwards it, together with
// stream_config, to another Run overload (that overload is not visible in this listing).
// Returns whatever that overload returns (a float).
// NOTE(review): the cast result is dereferenced without a null check, so p_arg must really
// point to an Argument; a failed dynamic_cast yields nullptr and dereferencing it is undefined.
636 float Run(const BaseArgument* p_arg,
637 const StreamConfig& stream_config = StreamConfig{}) override
638 {
639 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
640 }
641 };
642
// Reports whether the compile-time template parameters of this instantiation are valid.
// Currently a stub: it performs no check and unconditionally returns true.
643 static constexpr bool IsValidCompilationParameter()
644 {
645 // TODO: properly implement this check
646 return true;
647 }
648
// Runtime check of whether this kernel instance can handle the given problem `arg`.
// Returns false as soon as one of the visible rejection rules below matches.
// NOTE(review): original lines 651, 677 and 685 are missing from this listing (dropped by
// the documentation extraction), so the first condition and the statements inside the two
// `if constexpr` branches cannot be read here; consult the real header for them.
649 static bool IsSupportedArgument(const Argument& arg)
650 {
// The condition guarding this first rejection (original line 651) is not visible here.
652 {
653 return false;
654 }
655
// Reject split-K (KBatch > 1) when is_gfx11_supported() reports true.
656 if(is_gfx11_supported() && arg.KBatch > 1)
657 {
658 return false;
659 }
660
// Reject split-K with a bhalf_t C type when is_bf16_atomic_supported() reports false.
661 if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
662 {
663 return false;
664 }
665
// K must be a multiple of both AK1 and BK1 unless GemmSpec is one of the four
// specializations whose name includes K padding (MKPadding, NKPadding, MNKPadding, KPadding).
666 if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
667 GemmSpec == GemmSpecialization::NKPadding ||
668 GemmSpec == GemmSpecialization::MNKPadding ||
669 GemmSpec == GemmSpecialization::KPadding))
670 {
671 return false;
672 }
// Branch on the warp size reported by get_warp_size(): 64 uses the NXdlPerWave64
// configuration, anything else uses the NXdlPerWave32 configuration.
673 if(get_warp_size() == 64)
674 {
675 if constexpr(NXdlPerWave64 > 0)
676 {
// (statement on original line 677 is not visible in this listing)
678 }
679 }
680 else
681 {
682
683 if constexpr(NXdlPerWave32 > 0)
684 {
// (start of this statement, original line 685, is not visible; the visible tail
// reinterprets `arg` as a GridwiseGemm32::Argument reference)
686 reinterpret_cast<const typename GridwiseGemm32::Argument&>(arg));
687 }
688 }
// Final fall-through result when no earlier statement returned.
689 return false;
690 }
691
692 // polymorphic: type-erased overload of the support check.
// Downcasts p_arg to `const Argument*` with dynamic_cast and forwards the referenced object
// to the static IsSupportedArgument(const Argument&) overload, returning its result.
// NOTE(review): the cast result is dereferenced without a null check, so p_arg must really
// point to an Argument; a failed dynamic_cast yields nullptr and dereferencing it is undefined.
693 bool IsSupportedArgument(const BaseArgument* p_arg) override
694 {
695 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
696 }
697
// Returns the KPerBlock template parameter of this instantiation.
698 index_t GetKPerBlock() override { return KPerBlock; }
699
// Returns the PermuteB template parameter of this instantiation (defaults to false).
700 bool GetPermuteB() override { return PermuteB; }
701
// Builds an Argument (alias of GridwiseGemm64::Argument) from typed inputs.
// All fifteen parameters are passed, unchanged and in declaration order, to the Argument
// brace-initializer; no validation is done here.
// Note the order: StrideScaleB comes before p_b_scale, and KBatch follows p_b_scale.
// Presumably p_a/p_b/p_c are the A, B and C buffers, M/N/K the GEMM problem sizes and
// KBatch the split-K factor -- inferred from names only; TODO confirm against Argument.
702 static auto MakeArgument(const ADataType* p_a,
703 const BDataType* p_b,
704 CDataType* p_c,
705 index_t M,
706 index_t N,
707 index_t K,
708 index_t StrideA,
709 index_t StrideB,
710 index_t StrideC,
711 index_t StrideScaleB,
712 const BScaleDataType* p_b_scale,
713 index_t KBatch,
714 AElementwiseOperation a_element_op,
715 BElementwiseOperation b_element_op,
716 CElementwiseOperation c_element_op)
717 {
718 return Argument{p_a,
719 p_b,
720 p_c,
721 M,
722 N,
723 K,
724 StrideA,
725 StrideB,
726 StrideC,
727 StrideScaleB,
728 p_b_scale,
729 KBatch,
730 a_element_op,
731 b_element_op,
732 c_element_op};
733 }
734
// Returns a value-initialized Invoker by value.
735 static auto MakeInvoker() { return Invoker{}; }
736
737 // polymorphic: type-erased counterpart of MakeArgument.
// Takes the A, B, C and B-scale pointers as void pointers, static_casts them to
// ADataType / BDataType / CDataType / BScaleDataType pointers, and constructs an Argument on
// the heap with std::make_unique. The remaining parameters are forwarded unchanged, in the
// same order as in MakeArgument. The result is returned as std::unique_ptr<BaseArgument>.
// static_cast from void* performs no type check, so the caller is responsible for passing
// pointers whose element types match this instantiation.
738 std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
739 const void* p_b,
740 void* p_c,
741 index_t M,
742 index_t N,
743 index_t K,
744 index_t StrideA,
745 index_t StrideB,
746 index_t StrideC,
747 index_t StrideScaleB,
748 const void* p_b_scale,
749 index_t KBatch,
750 AElementwiseOperation a_element_op,
751 BElementwiseOperation b_element_op,
752 CElementwiseOperation c_element_op) override
753 {
754 return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
755 static_cast<const BDataType*>(p_b),
756 static_cast<CDataType*>(p_c),
757 M,
758 N,
759 K,
760 StrideA,
761 StrideB,
762 StrideC,
763 StrideScaleB,
764 static_cast<const BScaleDataType*>(p_b_scale),
765 KBatch,
766 a_element_op,
767 b_element_op,
768 c_element_op);
769 }
770
771 // polymorphic: type-erased counterpart of MakeInvoker.
// Heap-allocates an Invoker (constructed from a temporary Invoker{}) via std::make_unique
// and returns it as std::unique_ptr<BaseInvoker>.
772 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
773 {
774 return std::make_unique<Invoker>(Invoker{});
775 }
776
777 // polymorphic: builds a human-readable description of this kernel instance.
// The string is assembled in a std::stringstream and contains, in order: the literal
// "DeviceGemmXdlUniversal", the GEMM specialization string, the first character of each
// of the A/B/C layout names, then BlockSize, the MPerBlock x NPerBlock x KPerBlock tile,
// the MPerXDL x NPerXDL wave tile, the MXdlPerWave x NXdlPerWave wave map, the A/B source
// scalars-per-vector, the pipeline scheduler and version names, and the PrefetchStages
// value taken from GridwiseGemm64's BlockwiseGemmPipe.
// NOTE(review): the initializer lists of the two lookup maps below (original lines
// 783-784 and 787-791) are missing from this listing, so the exact names they map to
// cannot be read here.
778 std::string GetTypeString() const override
779 {
780 auto str = std::stringstream();
781
// Enum-to-name lookup table for the pipeline scheduler (initializer not visible here).
782 std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
785
// Enum-to-name lookup table for the pipeline version (initializer not visible here).
786 std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
792
793 // clang-format off
794 str << "DeviceGemmXdlUniversal"
795 << "<"
796 << getGemmSpecializationString(GemmSpec) << ", "
797 << std::string(ALayout::name)[0]
798 << std::string(BLayout::name)[0]
799 << std::string(CLayout::name)[0]
800 << ">"
801 << " BlkSize: "
802 << BlockSize << ", "
803 << "BlkTile: "
804 << MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
805 << "WaveTile: "
806 << MPerXDL<<"x"<<NPerXDL << ", "
807 << "WaveMap: "
808 << MXdlPerWave<<"x" << NXdlPerWave<<", "
809 << "VmemReadVec: "
810 << ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
811 << "BlkGemmPipelineScheduler: "
812 << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
813 << "BlkGemmPipelineVersion: "
814 << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
815 << "BlkGemmPipelinePrefetchStages: "
816 << GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages;
817 // clang-format on
818
819 return str.str();
820 }
821};
822
823} // namespace device
824} // namespace tensor_operation
825} // namespace ck
#define INVOKER_RUN3_IMPL
Definition device_base.hpp:114
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ MNKPadding
Definition gemm_specialization.hpp:20
@ NKPadding
Definition gemm_specialization.hpp:19
Definition convolution_backward_data_specialization.hpp:7
void flush_icache()
Definition flush_cache.hpp:383
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, GemmArgs &gemm_args, Args... args)
Definition flush_cache.hpp:398
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
@ AtomicAdd
Definition ck.hpp:279
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v2
Definition blkgemmpipe_scheduler.hpp:15
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v5
Definition blkgemmpipe_scheduler.hpp:18
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
@ One
Definition blkgemmpipe_scheduler.hpp:37
@ Seven
Definition blkgemmpipe_scheduler.hpp:43
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
@ Four
Definition blkgemmpipe_scheduler.hpp:40
@ Two
Definition blkgemmpipe_scheduler.hpp:38
@ Full
Definition blkgemmpipe_scheduler.hpp:49
@ Three
Definition blkgemmpipe_scheduler.hpp:39
@ Five
Definition blkgemmpipe_scheduler.hpp:41
@ Six
Definition blkgemmpipe_scheduler.hpp:42
__global__ void kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:62
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
constexpr bool is_same_v
Definition type.hpp:283
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
@ Interwave
Definition blkgemmpipe_scheduler.hpp:27
__global__ void kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:38
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
bool is_bf16_atomic_supported()
Definition host_utility/device_prop.hpp:108
Definition ck/stream_config.hpp:10
"Universal" GEMM kernel with SplitK support.
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:247
Definition data_type.hpp:187
Definition device_base.hpp:197
Helper structure responsible for kernel invocation.
Definition device_gemm_xdl_cshuffle_v3.hpp:263
float RunImp(const typename GridwiseGemm::Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_xdl_cshuffle_v3_b_scale.hpp:167
INVOKER_RUN3_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_xdl_cshuffle_v3.hpp:747
"Universal" GEMM operation with SplitK support.
Definition device_gemm_xdl_cshuffle_v3.hpp:178
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_gemm_xdl_cshuffle_v3.hpp:235
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_xdl_cshuffle_v3_b_scale.hpp:693
index_t GetKPerBlock() override
Definition device_gemm_xdl_cshuffle_v3_b_scale.hpp:698
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_xdl_cshuffle_v3_b_scale.hpp:772
static bool IsSupportedArgument(const Argument &arg)
Definition device_gemm_xdl_cshuffle_v3_b_scale.hpp:649
static constexpr bool IsValidCompilationParameter()
Definition device_gemm_xdl_cshuffle_v3_b_scale.hpp:643
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_gemm_xdl_cshuffle_v3.hpp:234
std::string GetTypeString() const override
Definition device_gemm_xdl_cshuffle_v3_b_scale.hpp:778
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_gemm_xdl_cshuffle_v3.hpp:181
GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB > GridwiseGemmBase
Definition device_gemm_xdl_cshuffle_v3.hpp:185
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t StrideScaleB, const BScaleDataType *p_b_scale, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition device_gemm_xdl_cshuffle_v3_b_scale.hpp:702
bool GetPermuteB() override
Definition device_gemm_xdl_cshuffle_v3_b_scale.hpp:700
static auto MakeInvoker()
Definition device_gemm_xdl_cshuffle_v3_b_scale.hpp:735
static constexpr index_t BPackedSize
Definition device_gemm_xdl_cshuffle_v3.hpp:246
typename GridwiseGemm64::Argument Argument
Definition device_gemm_xdl_cshuffle_v3.hpp:237
static constexpr auto NXdlPerWave32
Definition device_gemm_xdl_cshuffle_v3.hpp:182
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t StrideScaleB, const void *p_b_scale, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override
Definition device_gemm_xdl_cshuffle_v3_b_scale.hpp:738
static constexpr index_t APackedSize
Definition device_gemm_xdl_cshuffle_v3.hpp:239
Definition device_gemm_v2.hpp:93
Definition flush_cache.hpp:299