device_gemm_multiple_d_wmma_cshuffle.hpp Source File

device_gemm_multiple_d_wmma_cshuffle.hpp Source File#

Composable Kernel: device_gemm_multiple_d_wmma_cshuffle.hpp Source File
device_gemm_multiple_d_wmma_cshuffle.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
20
21namespace ck {
22namespace tensor_operation {
23namespace device {
24
25template <typename ALayout,
26 typename BLayout,
27 typename DsLayout,
28 typename ELayout,
29 typename ADataType,
30 typename BDataType,
31 typename AccDataType,
32 typename CShuffleDataType,
33 typename DsDataType,
34 typename EDataType,
35 typename AElementwiseOperation,
36 typename BElementwiseOperation,
37 typename CDEElementwiseOperation,
38 GemmSpecialization GemmSpec,
39 ck::index_t NumPrefetch,
40 ck::index_t BlockSize,
41 ck::index_t MPerBlock,
42 ck::index_t NPerBlock,
43 ck::index_t KPerBlock,
44 ck::index_t K1,
45 ck::index_t MPerWmma,
46 ck::index_t NPerWmma,
47 ck::index_t MRepeat,
48 ck::index_t NRepeat,
49 typename ABlockTransferThreadClusterLengths_K0_M_K1,
50 typename ABlockTransferThreadClusterArrangeOrder,
51 typename ABlockTransferSrcAccessOrder,
52 ck::index_t ABlockTransferSrcVectorDim,
53 ck::index_t ABlockTransferSrcScalarPerVector,
54 ck::index_t ABlockTransferDstScalarPerVector_K1,
55 bool ABlockLdsAddExtraM,
56 typename BBlockTransferThreadClusterLengths_K0_N_K1,
57 typename BBlockTransferThreadClusterArrangeOrder,
58 typename BBlockTransferSrcAccessOrder,
59 ck::index_t BBlockTransferSrcVectorDim,
60 ck::index_t BBlockTransferSrcScalarPerVector,
61 ck::index_t BBlockTransferDstScalarPerVector_K1,
62 bool BBlockLdsAddExtraN,
63 index_t CShuffleMRepeatPerShuffle,
64 index_t CShuffleNRepeatPerShuffle,
65 typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
66 index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
70 BLayout,
71 DsLayout,
72 ELayout,
73 ADataType,
74 BDataType,
75 DsDataType,
76 EDataType,
77 AElementwiseOperation,
78 BElementwiseOperation,
79 CDEElementwiseOperation>
80{
// Compile-time constants shared by the descriptor builders and the invoker below.
// NumDTensor is the number of entries in the DsDataType tuple.
82 static constexpr index_t NumDTensor = DsDataType::Size();
83
// Integral-constant indices I0..I6, used below with GetLength() and as unit strides.
84 static constexpr auto I0 = Number<0>{};
85 static constexpr auto I1 = Number<1>{};
86 static constexpr auto I2 = Number<2>{};
87 static constexpr auto I3 = Number<3>{};
88 static constexpr auto I4 = Number<4>{};
89 static constexpr auto I5 = Number<5>{};
90 static constexpr auto I6 = Number<6>{};
91 // K1 = Max Vector Access Pixels
92 static constexpr auto K1Number = Number<K1>{};
93
// Tile-derived quantities: MWaves/NWaves are the per-block tile extent divided by
// (repeat count * per-WMMA extent); WmmaK is 32 when K1 == 16 and 16 otherwise.
94 static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
95 static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
96 static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
97
// NOTE(review): the initializers of AEnableLds_auto / BEnableLds_auto (listing lines 99
// and 101) are missing from this extract, so their conditions cannot be documented here.
98 static constexpr auto AEnableLds_auto =
100 static constexpr auto BEnableLds_auto =
102
103 // If true, LDS is used unconditionally
104 static constexpr auto AEnableLds_manu = false;
105 static constexpr auto BEnableLds_manu = false;
106
// LDS is enabled when the automatic or the manual switch is set, and always when
// NumPrefetch > 1.
107 static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
108 static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1);
109
// Padder configured with the GEMM specialization and the per-block M/N/K tile sizes;
// used by every Make*GridDescriptor function below.
110 static constexpr auto matrix_padder =
111 MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
112
113 // Describes how the A matrix is read from global memory.
// Builds the A grid descriptor from the raw extents (MRaw, KRaw) and StrideA:
//   1. a raw descriptor (a_grid_desc_mraw_kraw) is padded through
//      matrix_padder.PadADescriptor_M_K, giving the (M, K) view a_grid_desc_m_k;
//   2. that view is then reshaped, and the resulting shape depends on AEnableLds.
// NOTE(review): listing lines 117, 120, 124, 127, 141, 143-146, 157, 159-161 and 163-164
// are missing from this extract, so the lambda's branch conditions, the construction of
// the raw descriptor and the transform calls themselves are not visible here.
114 static auto MakeAGridDescriptor(index_t MRaw, index_t KRaw, index_t StrideA)
115 {
116 const auto a_grid_desc_m_k = [&]() {
118 {
119 const auto a_grid_desc_mraw_kraw =
121
122 return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
123 }
125 {
126 const auto a_grid_desc_mraw_kraw =
128
129 return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
130 }
131 }();
132
// Padded extents; dimension 0 is M and dimension 1 is K.
133 const auto M = a_grid_desc_m_k.GetLength(I0);
134 const auto K = a_grid_desc_m_k.GetLength(I1);
// K must be a whole number of K1-wide groups. This is only an assert, so it is not
// enforced when assertions are compiled out.
135 assert(K % K1 == 0);
136
137 if constexpr(AEnableLds)
138 {
// LDS path: K is split into K0 groups of K1 elements.
139 const index_t K0 = K / K1;
140
142 a_grid_desc_m_k,
147 }
148 else
149 {
// Non-LDS path: K is decomposed into A_KWmma * A_K0PerWmma * A_KRow * K1 and M is
// decomposed using M0 * MRepeat, MWaves and MPerWmma (see the layout comment below).
// The divisions are integer divisions.
150 constexpr auto A_KRow = 2;
151 constexpr auto A_K0PerWmma = WmmaK / A_KRow / K1Number;
152 const auto A_KWmma = K / WmmaK;
153
154 const auto M0 = M / MPerBlock;
155 // 0 1 0 1 2 3 4 5 6
156 // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
158 a_grid_desc_m_k,
162 make_tuple(M0 * MRepeat, Number<MWaves>{}, Number<MPerWmma>{}))),
165 }
166 }
167
// Builds the B grid descriptor from the raw extents (KRaw, NRaw) and StrideB:
//   1. a raw descriptor (b_grid_desc_nraw_kraw) is padded through
//      matrix_padder.PadBDescriptor_N_K, giving the (N, K) view b_grid_desc_n_k;
//   2. that view is then reshaped, and the resulting shape depends on BEnableLds.
// NOTE(review): listing lines 171, 174, 178, 181, 195, 197-200, 211, 213-215 and 217-218
// are missing from this extract, so the lambda's branch conditions, the construction of
// the raw descriptor and the transform calls themselves are not visible here.
168 static auto MakeBGridDescriptor(index_t KRaw, index_t NRaw, index_t StrideB)
169 {
170 const auto b_grid_desc_n_k = [&]() {
172 {
173 const auto b_grid_desc_nraw_kraw =
175
176 return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
177 }
179 {
180 const auto b_grid_desc_nraw_kraw =
182
183 return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
184 }
185 }();
186
// Padded extents; dimension 0 is N and dimension 1 is K.
187 const auto N = b_grid_desc_n_k.GetLength(I0);
188 const auto K = b_grid_desc_n_k.GetLength(I1);
// K must be a whole number of K1-wide groups. This is only an assert, so it is not
// enforced when assertions are compiled out.
189 assert(K % K1 == 0);
190
191 if constexpr(BEnableLds)
192 {
// LDS path: K is split into K0 groups of K1 elements.
193 const index_t K0 = K / K1;
194
196 b_grid_desc_n_k,
201 }
202 else
203 {
// Non-LDS path: K is decomposed into B_KWmma * B_K0PerWmma * B_KRow * K1 and N is
// decomposed using N0 * NRepeat, NWaves and NPerWmma. The divisions are integer
// divisions.
204 constexpr auto B_KRow = 2;
205 constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number;
206 const auto B_KWmma = K / WmmaK;
207
208 const auto N0 = N / NPerBlock;
209 // 0 1 0 1 2 3 4 5 6
210 // N - K <-> B_KWmma - NBlock*NRepeat - NWaves - B_K0PerWmma - B_KRow - NPerWmma - B_K1
// NOTE(review): the line above previously repeated the A-side names (M, A_KWmma, ...).
// It is rewritten as the B-side counterpart; the transform lines are missing from this
// extract, so confirm the dimension order against the full file.
212 b_grid_desc_n_k,
216 make_tuple(N0 * NRepeat, Number<NWaves>{}, Number<NPerWmma>{}))),
219 }
220 }
221
// Builds a padded (M, N) descriptor for an output-shaped tensor with extents
// (MRaw, NRaw) and stride StrideE. Two stride arrangements are produced:
//   - (StrideE, I1): unit stride along the N dimension;
//   - (I1, StrideE): unit stride along the M dimension.
// The raw descriptor is then padded with matrix_padder.PadCDescriptor_M_N.
// NOTE(review): the branch conditions (listing lines 226 and 231) are missing from this
// extract; presumably they select on ELayout_ — confirm against the full file.
222 template <typename ELayout_>
223 static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
224 {
225 const auto e_grid_desc_mraw_nraw = [&]() {
227 {
228 return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
229 make_tuple(StrideE, I1));
230 }
232 {
233 return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
234 make_tuple(I1, StrideE));
235 }
236 }();
237
238 return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
239 }
240
// Builds a tuple of (M, N) descriptors, one per D tensor. For index i the layout is
// taken from element i of DsLayout, and the descriptor is produced by
// MakeEGridDescriptor_M_N<DLayout>(Ms[i], Ns[i], DsStride[i]).
// NOTE(review): the final argument of generate_tuple (listing line 251) is missing from
// this extract.
241 static auto MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& Ms,
242 const std::array<index_t, NumDTensor>& Ns,
243 const std::array<index_t, NumDTensor>& DsStride)
244 {
245 return generate_tuple(
246 [&](auto i) {
247 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
248
249 return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(Ms[i], Ns[i], DsStride[i]);
250 },
252 }
253
254 // Gridwise descriptor, mapping to the whole given problem.
255 using AGridDesc = decltype(MakeAGridDescriptor(1, 1, 1));
256 using BGridDesc = decltype(MakeBGridDescriptor(1, 1, 1));
259
260 // GridwiseOp
262 // DataType Family
263 ADataType,
264 BDataType,
265 AccDataType,
266 CShuffleDataType,
267 DsDataType,
268 EDataType,
269 // InMemory Data Descriptor
270 AGridDesc,
271 BGridDesc,
274 // ElementwiseOp Family
275 AElementwiseOperation,
276 BElementwiseOperation,
277 CDEElementwiseOperation,
279 // Tiling Family
280 MPerBlock,
281 NPerBlock,
282 KPerBlock,
283 MPerWmma,
284 NPerWmma,
285 K1,
286 MRepeat,
287 NRepeat,
288 // ThreadCluster Family
289 BlockSize,
290 ABlockTransferThreadClusterLengths_K0_M_K1,
291 ABlockTransferThreadClusterArrangeOrder,
292 ABlockTransferSrcAccessOrder,
293 ABlockTransferSrcVectorDim,
294 ABlockTransferSrcScalarPerVector,
295 ABlockTransferDstScalarPerVector_K1,
296 false, // AThreadTransferSrcResetCoordinateAfterRun,
298 ABlockLdsAddExtraM,
299 BBlockTransferThreadClusterLengths_K0_N_K1,
300 BBlockTransferThreadClusterArrangeOrder,
301 BBlockTransferSrcAccessOrder,
302 BBlockTransferSrcVectorDim,
303 BBlockTransferSrcScalarPerVector,
304 BBlockTransferDstScalarPerVector_K1,
305 false, // BThreadTransferSrcResetCoordinateAfterRun,
307 BBlockLdsAddExtraN,
308 CShuffleMRepeatPerShuffle,
309 CShuffleNRepeatPerShuffle,
310 CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
311 CDEShuffleBlockTransferScalarPerVector_NPerBlock,
312 NumPrefetch,
313 LoopSched,
314 PipelineVer>;
315
316 // Argument
// Bundles everything one GEMM launch needs: typed grid pointers, tensor descriptors,
// the block-to-tile map, the element-wise operators and the raw problem sizes.
// NOTE(review): many lines of this struct are missing from this extract (341-345,
// 355-356, 365-366, 368, 370, 372-376, 378-380, 382-384, 391, 395-402, 405, 408-409,
// 417-419), so parts of the constructor body and most member declarations are not shown.
317 struct Argument : public BaseArgument
318 {
// The type-erased pointers are cast to the element types of this instantiation.
// M, N and K are kept as MRaw_/NRaw_/KRaw_ for the vector-access checks in
// IsSupportedArgument.
319 Argument(const void* p_a_grid,
320 const void* p_b_grid,
321 std::array<const void*, NumDTensor> p_ds_grid,
322 void* p_e_grid,
323 index_t M,
324 index_t N,
325 index_t K,
326 index_t StrideA,
327 index_t StrideB,
328 std::array<index_t, NumDTensor> StrideDs,
329 index_t StrideE,
330 index_t M01,
331 index_t N01,
332 AElementwiseOperation a_element_op,
333 BElementwiseOperation b_element_op,
334 CDEElementwiseOperation cde_element_op)
335 : p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
336 p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
337 p_ds_grid_{},
338 p_e_grid_{static_cast<EDataType*>(p_e_grid)},
339 a_grid_desc{},
340 b_grid_desc{},
346 M01_{M01},
347 N01_{N01},
348 a_element_op_{a_element_op},
349 b_element_op_{b_element_op},
350 cde_element_op_{cde_element_op},
351 MRaw_{M},
352 NRaw_{N},
353 KRaw_{K}
354 {
// p_ds_grid_ is value-initialized above and filled here, one D tensor at a time,
// because each D tensor may have its own element type.
357 static_for<0, NumDTensor, 1>{}([&](auto i) {
358 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
359 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
360
361 // D pointer
362 p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
363
364 // D desc
367 });
369
371
377 {
381
385 }
386 }
387
388 // Pointers
// (The symbol index of this page also lists p_ds_grid_ of type
// GridwiseOp::DsGridPointer at listing line 391.)
389 const ADataType* p_a_grid_;
390 const BDataType* p_b_grid_;
392 EDataType* p_e_grid_;
393
394 // Tensor Descriptors
// (Per the symbol index: a_grid_desc, b_grid_desc, ds_grid_desc_m_n_, e_grid_desc_m_n_,
// ds_grid_desc_mblock_mperblock_nblock_nperblock and
// e_grid_desc_mblock_mperblock_nblock_nperblock are declared on lines 395-402.)
403
404 // Block to Tile mapping
// (Per the symbol index: GridwiseOp::DefaultBlock2CTileMap block_2_ctile_map_, line 405.)
406
407 // Idle
// (Per the symbol index: index_t M01_ and N01_, lines 408-409.)
410
411 // ElementwiseOp
412 AElementwiseOperation a_element_op_;
413 BElementwiseOperation b_element_op_;
414 CDEElementwiseOperation cde_element_op_;
415
416 // for checking vector load/store
// (Per the symbol index: index_t MRaw_, NRaw_ and KRaw_, lines 417-419.)
420 };
421
422 // Invoker
// Launches the GEMM kernel for a prepared Argument and returns the value produced by
// launch_and_time_kernel.
// NOTE(review): listing lines 425, 429, 431-433, 456, 459, 461-466, 470, 484-485, 489
// and 492 are missing from this extract, so the validity check, part of the kernel
// template argument list and the main-loop condition are not fully visible.
423 struct Invoker : public BaseInvoker
424 {
426
// Runs the kernel described by arg. Throws std::runtime_error when the validity check
// at the top (its condition is only partly visible here) fails.
427 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
428 {
430 arg.b_grid_desc,
434 {
435 throw std::runtime_error(
436 "wrong! GridwiseGemm_k0mk1_k0nk1_m0nm1_wmma_v1r1 has invalid setting");
437 }
438
// The block-to-tile map derives the number of workgroups from the E descriptor.
439 const index_t grid_size =
440 arg.block_2_ctile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
441
// Recover the total K extent from the A descriptor. With LDS enabled it is the product
// of dimensions 0 and 2; otherwise it is the product of dimensions 0, 3, 4 and 6, which
// the layout comment in MakeAGridDescriptor names A_KWmma, A_K0PerWmma, A_KRow and A_K1.
442 const auto K = [&]() {
443 if constexpr(AEnableLds)
444 {
445 return arg.a_grid_desc.GetLength(I0) * arg.a_grid_desc.GetLength(I2);
446 }
447 else
448 {
449 return arg.a_grid_desc.GetLength(I0) * arg.a_grid_desc.GetLength(I3) *
450 arg.a_grid_desc.GetLength(I4) * arg.a_grid_desc.GetLength(I6);
451 }
452 }();
453
// has_main_k_block_loop is passed as an integral_constant, so it can be used as the
// last template argument of the kernel. The kernel name is spelled "mupltipe" (sic),
// matching its declaration.
454 auto launch_kernel = [&](auto has_main_k_block_loop) {
455 const auto kernel = kernel_gemm_mupltipe_d_wmma_cshuffle<
457 ADataType,
458 BDataType,
460 EDataType,
467 AElementwiseOperation,
468 BElementwiseOperation,
469 CDEElementwiseOperation,
471 has_main_k_block_loop>; // Last Option is W/O
472
// One-dimensional launch: grid_size workgroups of BlockSize threads. The literal 0 is
// the lds_byte argument of launch_and_time_kernel.
473 return launch_and_time_kernel(stream_config,
474 kernel,
475 dim3(grid_size),
476 dim3(BlockSize),
477 0,
478 arg.p_a_grid_,
479 arg.p_b_grid_,
480 arg.p_ds_grid_,
481 arg.p_e_grid_,
482 arg.a_grid_desc,
483 arg.b_grid_desc,
486 arg.a_element_op_,
487 arg.b_element_op_,
488 arg.cde_element_op_,
490 };
491
// Select the kernel variant with or without a main K-block loop. The condition
// (listing line 492) is missing from this extract; presumably it depends on K.
493 {
494 return launch_kernel(integral_constant<bool, true>{});
495 }
496 else
497 {
498 return launch_kernel(integral_constant<bool, false>{});
499 }
500 }
501
502 // polymorphic
// Type-erased entry point: forwards to the typed Run above.
// NOTE(review): the dynamic_cast result is dereferenced without a null check, so a
// p_arg that is null or not an Argument of this device op is undefined behaviour.
503 float Run(const BaseArgument* p_arg,
504 const StreamConfig& stream_config = StreamConfig{}) override
505 {
506 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
507 }
508 };
509
// Compile-time check of the template parameters. Currently a placeholder: it always
// returns true and validates nothing.
510 static constexpr bool IsValidCompilationParameter()
511 {
512 // TODO: properly implement this check
513 return true;
514 }
515
// Returns whether this instantiation can run the problem described by arg.
// The visible checks are, in order: an initial gate (conditions missing from this
// extract), vector-access divisibility for A and B, the layouts of all D tensors, and
// vector-store divisibility for E. The final return statement is only partly visible.
// NOTE(review): listing lines 518, 520, 531-532, 609 and 611-613 are missing from this
// extract. The page's symbol index lists is_gfx11_supported / is_gfx12_supported, so the
// initial gate is presumably a device check — confirm against the full file.
516 static bool IsSupportedArgument(const Argument& arg)
517 {
519 {
521 {
522 return false;
523 }
524 }
525 else
526 {
527 return false;
528 }
529 // check vector load/store
530 {
533
// A: depending on layout and vector dimension, KRaw_ or MRaw_ must be a multiple of the
// source vector width. Any other layout / vector-dimension combination is rejected.
534 // check vector load of A
535 if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
536 {
537 if(arg.KRaw_ % ABlockTransferSrcScalarPerVector != 0)
538 {
539 return false;
540 }
541 }
542 else if constexpr(is_same_v<ALayout, Col> && ABlockTransferSrcVectorDim == 1)
543 {
544 // FIXME: not rigorous
545 if(arg.MRaw_ % ABlockTransferSrcScalarPerVector != 0)
546 {
547 return false;
548 }
549 }
550 else
551 {
552 return false;
553 }
554
// B: same scheme as A, with KRaw_ or NRaw_ checked against the B source vector width.
555 // check vector load of B
556 if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
557 {
558 if(arg.KRaw_ % BBlockTransferSrcScalarPerVector != 0)
559 {
560 return false;
561 }
562 }
563 else if constexpr(is_same_v<BLayout, Row> && BBlockTransferSrcVectorDim == 1)
564 {
565 // FIXME: not rigorous
566 if(arg.NRaw_ % BBlockTransferSrcScalarPerVector != 0)
567 {
568 return false;
569 }
570 }
571 else
572 {
573 return false;
574 }
575
576 // check vector load of Ds
577 // only support RowMajor for now
// Only the layout of each D tensor is examined in the visible lines; no size
// divisibility test for the D tensors appears here.
578 bool all_valid = true;
579
580 static_for<0, NumDTensor, 1>{}([&](auto i) {
581 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
582
583 if constexpr(!is_same_v<DLayout, Row>)
584 {
585 all_valid = false;
586 }
587 });
588
589 if(!all_valid)
590 {
591 return false;
592 }
593
594 // check vector store of E
595 // only support RowMajor for now
// NRaw_ must be a multiple of the C/D/E shuffle vector width.
596 if constexpr(is_same_v<ELayout, Row>)
597 {
598 if(arg.NRaw_ % CDEShuffleBlockTransferScalarPerVector_NPerBlock != 0)
599 {
600 return false;
601 }
602 }
603 else
604 {
605 return false;
606 }
607 }
608
610 arg.b_grid_desc,
614 }
615
616 // polymorphic
// Type-erased overload: forwards to the static IsSupportedArgument(const Argument&).
// NOTE(review): the dynamic_cast result is dereferenced without a null check, so a
// p_arg that is null or not an Argument of this device op is undefined behaviour.
617 bool IsSupportedArgument(const BaseArgument* p_arg) override
618 {
619 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
620 }
621
// Convenience factory returning an Argument by value. All parameters are forwarded
// unchanged; the two literal 1s fill the Argument constructor's M01 and N01 parameters,
// which this factory does not expose.
622 static auto MakeArgument(const void* p_a,
623 const void* p_b,
624 std::array<const void*, NumDTensor> p_ds,
625 void* p_e,
626 index_t M,
627 index_t N,
628 index_t K,
629 index_t StrideA,
630 index_t StrideB,
631 std::array<ck::index_t, NumDTensor> StrideDs,
632 index_t StrideE,
633 AElementwiseOperation a_element_op,
634 BElementwiseOperation b_element_op,
635 CDEElementwiseOperation cde_element_op)
636 {
637 return Argument{p_a,
638 p_b,
639 p_ds,
640 p_e,
641 M,
642 N,
643 K,
644 StrideA,
645 StrideB,
646 StrideDs,
647 StrideE,
648 1,
649 1,
650 a_element_op,
651 b_element_op,
652 cde_element_op};
653 }
654
655 // polymorphic
// Heap-allocating counterpart of MakeArgument: returns the Argument behind a
// std::unique_ptr<BaseArgument>. As in MakeArgument, the two literal 1s fill the
// Argument constructor's M01 and N01 parameters.
656 std::unique_ptr<BaseArgument>
657 MakeArgumentPointer(const void* p_a,
658 const void* p_b,
659 std::array<const void*, NumDTensor> p_ds,
660 void* p_e,
661 index_t M,
662 index_t N,
663 index_t K,
664 index_t StrideA,
665 index_t StrideB,
666 std::array<ck::index_t, NumDTensor> StrideDs,
667 index_t StrideE,
668 AElementwiseOperation a_element_op,
669 BElementwiseOperation b_element_op,
670 CDEElementwiseOperation cde_element_op) override
671 {
672 return std::make_unique<Argument>(p_a,
673 p_b,
674 p_ds,
675 p_e,
676 M,
677 N,
678 K,
679 StrideA,
680 StrideB,
681 StrideDs,
682 StrideE,
683 1,
684 1,
685 a_element_op,
686 b_element_op,
687 cde_element_op);
688 }
689
// Returns a default-constructed Invoker by value.
690 static auto MakeInvoker() { return Invoker{}; }
691
692 // polymorphic
// Returns a heap-allocated Invoker behind a std::unique_ptr<BaseInvoker>. The new
// object is constructed from a temporary Invoker{}.
693 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
694 {
695 return std::make_unique<Invoker>(Invoker{});
696 }
697
698 // polymorphic
// Returns a human-readable description of this instantiation: the operator name, the
// tile parameters in angle brackets, then the LDS switches, prefetch count, loop
// scheduler and pipeline version.
699 std::string GetTypeString() const override
700 {
701 auto str = std::stringstream();
702
// Enum-to-name tables. They are looked up with operator[], so an enumerator not listed
// here yields an empty string.
703 std::map<LoopScheduler, std::string> LoopSchedToString{
704 {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
705
706 std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
707 {PipelineVersion::v2, "v2"}};
708
709 // clang-format off
710 str << "DeviceGemmMultipleD_Wmma_CShuffle"
711 << "<"
712 << BlockSize << ", "
713 << MPerBlock << ", "
714 << NPerBlock << ", "
715 << KPerBlock << ", "
716 << K1 << ", "
717 << MPerWmma << ", "
718 << NPerWmma << ", "
719 << MRepeat << ", "
720 << NRepeat
721 << ">"
722 << " AEnableLds: "
723 << AEnableLds << ", "
724 << "BEnableLds: "
725 << BEnableLds << ", "
726 << "NumPrefetch: "
727 << NumPrefetch << ", "
728 << "LoopScheduler: "
729 << LoopSchedToString[LoopSched] << ", "
730 << "PipelineVersion: "
731 << PipelineVersionToString[PipelineVer];
732 // clang-format on
733
734 return str.str();
735 }
736};
737
738} // namespace device
739} // namespace tensor_operation
740} // namespace ck
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
Definition convolution_backward_data_specialization.hpp:8
GemmSpecialization
Definition gemm_specialization.hpp:11
Definition convolution_backward_data_specialization.hpp:7
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition tile/host/kernel_launch.hpp:173
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
integral_constant< index_t, N > Number
Definition number.hpp:12
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__global__ void kernel_gemm_mupltipe_d_wmma_cshuffle(const ADataType *__restrict__ p_a_grid, const BDataType *__restrict__ p_b_grid, DsPointer p_ds_grid, EDataType *__restrict__ p_e_grid, const AGridDesc a_grid_desc, const BGridDesc b_grid_desc, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op, const Block2CTileMap block_2_ctile_map)
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:225
bool is_gfx12_supported()
Definition host_utility/device_prop.hpp:55
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
typename remove_reference< T >::type remove_reference_t
Definition type.hpp:292
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
@ Default
Definition loop_scheduler.hpp:16
@ Interwave
Definition loop_scheduler.hpp:17
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v2
Definition gridwise_gemm_pipeline_selector.hpp:20
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:326
Definition utility/sequence.hpp:43
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition functional2.hpp:33
Definition tensor_operation/gpu/device/tensor_layout.hpp:31
Definition tensor_operation/gpu/device/tensor_layout.hpp:26
Definition device_base.hpp:197
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:318
const BDataType * p_b_grid_
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:390
GridwiseOp::DefaultBlock2CTileMap block_2_ctile_map_
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:405
DsGridDesc_M_N ds_grid_desc_m_n_
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:397
GridwiseOp::DsGridPointer p_ds_grid_
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:391
index_t KRaw_
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:419
const ADataType * p_a_grid_
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:389
CDEElementwiseOperation cde_element_op_
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:414
AGridDesc a_grid_desc
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:395
AElementwiseOperation a_element_op_
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:412
Argument(const void *p_a_grid, const void *p_b_grid, std::array< const void *, NumDTensor > p_ds_grid, void *p_e_grid, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, std::array< index_t, NumDTensor > StrideDs, index_t StrideE, index_t M01, index_t N01, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:319
EDataType * p_e_grid_
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:392
index_t MRaw_
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:417
index_t M01_
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:408
BGridDesc b_grid_desc
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:396
index_t N01_
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:409
BElementwiseOperation b_element_op_
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:413
GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:400
EGridDesc_M_N e_grid_desc_m_n_
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:398
index_t NRaw_
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:418
GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:402
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:424
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:503
DeviceOp::Argument Argument
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:425
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:427
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:80
static constexpr auto AEnableLds_auto
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:98
decltype(MakeAGridDescriptor(1, 1, 1)) AGridDesc
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:255
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, std::array< ck::index_t, NumDTensor > StrideDs, index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op) override
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:657
static constexpr auto MWaves
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:94
static constexpr auto I4
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:88
static constexpr auto BEnableLds_auto
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:100
static constexpr auto K1Number
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:92
static constexpr auto BEnableLds
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:108
static constexpr auto AEnableLds
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:107
static constexpr auto AEnableLds_manu
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:104
static constexpr auto NWaves
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:95
decltype(MakeBGridDescriptor(1, 1, 1)) BGridDesc
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:256
static auto MakeAGridDescriptor(index_t MRaw, index_t KRaw, index_t StrideA)
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:114
static constexpr auto I3
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:87
static auto MakeArgument(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, std::array< ck::index_t, NumDTensor > StrideDs, index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:622
static bool IsSupportedArgument(const Argument &arg)
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:516
static constexpr auto I0
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:84
static auto MakeDsGridDescriptor_M_N(const std::array< index_t, NumDTensor > &Ms, const std::array< index_t, NumDTensor > &Ns, const std::array< index_t, NumDTensor > &DsStride)
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:241
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:617
remove_cvref_t< decltype(MakeDsGridDescriptor_M_N({}, {}, {}))> DsGridDesc_M_N
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:257
static constexpr auto I2
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:86
static constexpr auto BEnableLds_manu
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:105
static constexpr auto I5
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:89
static constexpr auto I6
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:90
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:693
static constexpr auto matrix_padder
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:110
decltype(MakeEGridDescriptor_M_N< ELayout >(1, 1, 1)) EGridDesc_M_N
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:258
static auto MakeBGridDescriptor(index_t KRaw, index_t NRaw, index_t StrideB)
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:168
GridwiseGemmMultipleD_Wmma< ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AGridDesc, BGridDesc, DsGridDesc_M_N, EGridDesc_M_N, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, K1, MRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, AEnableLds, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BEnableLds, BBlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVector_NPerBlock, NumPrefetch, LoopSched, PipelineVer > GridwiseOp
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:261
static constexpr index_t NumDTensor
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:82
static constexpr auto WmmaK
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:96
std::string GetTypeString() const override
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:699
DeviceGemmMultipleD_Wmma_CShuffle DeviceOp
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:81
static constexpr bool IsValidCompilationParameter()
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:510
static auto MakeInvoker()
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:690
static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:223
static constexpr auto I1
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:85
Definition device_gemm_multiple_d.hpp:36
Definition matrix_padder.hpp:180