device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp Source File

device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp Source File

Composable Kernel: device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp Source File
device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
13#include "ck/tensor_operation/gpu/device/device_conv_bwd_weight.hpp"
18
19namespace ck {
20namespace tensor_operation {
21namespace device {
22
23// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
24template <typename InDataType,
25 typename WeiDataType,
26 typename OutDataType,
27 typename AccDataType,
28 typename InElementwiseOperation,
29 typename WeiElementwiseOperation,
30 typename OutElementwiseOperation,
31 ck::index_t BlockSize,
32 ck::index_t MPerBlock,
33 ck::index_t NPerBlock,
34 ck::index_t K0PerBlock,
35 ck::index_t K1,
36 ck::index_t MPerXdl,
37 ck::index_t NPerXdl,
38 ck::index_t MXdlPerWave,
39 ck::index_t NXdlPerWave,
40 typename ABlockTransferThreadClusterLengths_K0_M_K1,
41 typename ABlockTransferThreadClusterArrangeOrder,
42 typename ABlockTransferSrcAccessOrder,
43 ck::index_t ABlockTransferSrcVectorDim,
44 ck::index_t ABlockTransferSrcScalarPerVector,
45 ck::index_t ABlockTransferDstScalarPerVector_K1,
46 bool ABlockLdsAddExtraM,
47 typename BBlockTransferThreadClusterLengths_K0_N_K1,
48 typename BBlockTransferThreadClusterArrangeOrder,
49 typename BBlockTransferSrcAccessOrder,
50 ck::index_t BBlockTransferSrcVectorDim,
51 ck::index_t BBlockTransferSrcScalarPerVector,
52 ck::index_t BBlockTransferDstScalarPerVector_K1,
53 bool BBlockLdsAddExtraN,
54 index_t CShuffleMXdlPerWavePerShuffle,
55 index_t CShuffleNXdlPerWavePerShuffle,
56 typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
57 index_t CBlockTransferScalarPerVector_NWaveNPerXdl>
59 : public DeviceConvBwdWeight<2,
60 ck::tensor_layout::convolution::NHWC,
61 ck::tensor_layout::convolution::KYXC,
62 ck::tensor_layout::convolution::NHWK,
63 InDataType,
64 WeiDataType,
65 OutDataType,
66 InElementwiseOperation,
67 WeiElementwiseOperation,
68 OutElementwiseOperation>
69{
// Compile-time configuration shared by the descriptor builder, the gridwise GEMM
// instantiations and the invoker.
// Number of spatial dimensions (H, W); matches the DeviceConvBwdWeight<2, ...> base.
70 static constexpr ck::index_t NDimSpatial = 2;
71
72 using DeviceOp =
74
// NXdlPerWave variants returned by GetNXdlPerWave<true>() / GetNXdlPerWave<false>().
// GetNXdlPerWave comes from a macro not shown in this listing; the 64/32 suffixes
// presumably mean wave64 / wave32 execution -- TODO confirm against device_base.hpp.
76 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
77 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
78
// Backward-weight is expressed as a GEMM:
//   A = output gradient, B = input, C = weight gradient
// (see the pointer mapping in Argument's constructor).
79 using ADataType = OutDataType;
80 using BDataType = InDataType;
81 using CDataType = WeiDataType;
82
// Element-wise operations follow the same A/B/C role mapping as the data types.
83 using AElementwiseOperation = OutElementwiseOperation;
84 using BElementwiseOperation = InElementwiseOperation;
85 using CElementwiseOperation = WeiElementwiseOperation;
86
87 // TODO make A/B datatype different
// NOTE: A and B are currently assumed to share one element type; the gridwise
// GEMMs are instantiated with ADataType as the single operand data type.
88 using ABDataType = InDataType;
89
// Compile-time integral constants used to index descriptor dimensions and tuples.
90 static constexpr auto I0 = Number<0>{};
91 static constexpr auto I1 = Number<1>{};
92 static constexpr auto I2 = Number<2>{};
93 static constexpr auto I3 = Number<3>{};
94 static constexpr auto I4 = Number<4>{};
95 static constexpr auto I5 = Number<5>{};
96
// The GEMM reduction dimension is split as K0 x K1, with K1 fixed at compile time.
97 static constexpr auto K1Number = Number<K1>{};
98 static constexpr auto GemmK1Number = K1Number;
99
// The batch dimension N is unmerged into N0 x N1 with N1 == K1, so N must be a
// multiple of K1 (checked in IsSupportedArgument).
100 static constexpr auto N1Number = K1Number;
101
102 // Bytes per 32 lds bank: 32 * 4 bytes
103 static constexpr auto BankLength = 128;
// Number of ADataType elements in one BankLength span. It is used to size the
// LDS layouts of both the A and the B tile.
104 static constexpr auto ElePerBank = BankLength / sizeof(ADataType);
105
106 // M1 & M0
// The A tile's M extent is split as M0 x M1, with M1 = ElePerBank / K1 plus padding
// on M1. These values are forwarded as template arguments to the gridwise GEMM.
// NOTE(review): if K1 > ElePerBank, ABlockLdsM1PerBlock is 0 and the next line
// divides by zero at compile time -- confirm the supported K1 / data-type pairs.
107 static constexpr auto ABlockLdsM1PerBlock = ElePerBank / K1;
108 static constexpr auto ABlockLdsM0PerBlock = MPerBlock / ABlockLdsM1PerBlock;
109 static constexpr auto ABlockLdsM1Padding = 4;
110
111 // N1 & N0
// The B tile uses the same scheme along N: N1 = ElePerBank / K1, N0 = NPerBlock / N1,
// plus padding on N1. The same K1 > ElePerBank caveat applies.
112 static constexpr auto BBlockLdsN1PerBlock = ElePerBank / K1;
113 static constexpr auto BBlockLdsN0PerBlock = NPerBlock / BBlockLdsN1PerBlock;
114 static constexpr auto BBlockLdsN1Padding = 4;
115
117 ck::index_t N,
118 ck::index_t K,
119 ck::index_t C,
120 std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
121 std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
122 std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
123 std::array<ck::index_t, NDimSpatial> conv_filter_strides,
124 std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
125 std::array<ck::index_t, NDimSpatial> input_left_pads,
126 std::array<ck::index_t, NDimSpatial> input_right_pads,
127 ck::index_t batch_k)
128 {
129 using namespace ck;
130
131 const index_t Hi = input_spatial_lengths[0];
132 const index_t Wi = input_spatial_lengths[1];
133
134 const index_t Ho = output_spatial_lengths[0];
135 const index_t Wo = output_spatial_lengths[1];
136
137 const index_t Y = filter_spatial_lengths[0];
138 const index_t X = filter_spatial_lengths[1];
139
140 const index_t ConvStrideH = conv_filter_strides[0];
141 const index_t ConvStrideW = conv_filter_strides[1];
142
143 const index_t ConvDilationH = conv_filter_dilations[0];
144 const index_t ConvDilationW = conv_filter_dilations[1];
145
146 const index_t InLeftPadH = input_left_pads[0];
147 const index_t InLeftPadW = input_left_pads[1];
148
149 const index_t InRightPadH = input_right_pads[0];
150 const index_t InRightPadW = input_right_pads[1];
151
152 const index_t GemmKTotal = N * Ho * Wo;
153 const index_t GemmM = K;
154 const index_t GemmN = C * X * Y;
155
156 const index_t GemmKBatch = batch_k;
157 const index_t GemmK0 =
158 math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
159 K0PerBlock;
160
161 const auto in_n_hi_wi_c_grid_desc =
163
164 // A: output tensor
165 const index_t N0 = N / N1Number;
166 const index_t GemmK0Total = N0 * Ho * Wo;
167
168 const index_t GemmK0S =
169 math::integer_divide_ceil(GemmK0Total, K0PerBlock * GemmKBatch) * K0PerBlock;
170 const index_t GemmK0Pad = GemmKBatch * GemmK0S;
171 const auto out_n_ho_wo_k_grid_desc =
173
174 const auto out_n0_ho_wo_k_n1_grid_desc =
175 transform_tensor_descriptor(out_n_ho_wo_k_grid_desc,
181
182 const auto out_gemmk0total_gemmm_gemmk1_grid_desc =
183 transform_tensor_descriptor(out_n0_ho_wo_k_n1_grid_desc,
189
190 const auto out_gemmk0pad_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
191 out_gemmk0total_gemmm_gemmk1_grid_desc,
192 make_tuple(make_right_pad_transform(GemmK0Total, GemmK0Pad - GemmK0Total),
197
198 const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
199 out_gemmk0pad_gemmm_gemmk1_grid_desc,
200 make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0)),
205
206 // B: input tensor
207 const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
208 in_n_hi_wi_c_grid_desc,
210 make_pad_transform(Hi, InLeftPadH, InRightPadH),
211 make_pad_transform(Wi, InLeftPadW, InRightPadW),
215
216 const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
217 in_n_hip_wip_c_grid_desc,
220 make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
221 make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
225
226 const auto in_n0_y_ho_x_wo_c_n1_grid_desc =
227 transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
235 Sequence<1>{},
236 Sequence<2>{},
237 Sequence<3>{},
238 Sequence<4>{},
239 Sequence<5>{}),
241 Sequence<1>{},
242 Sequence<2>{},
243 Sequence<3>{},
244 Sequence<4>{},
245 Sequence<5>{}));
246
247 const auto in_gemmk0total_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
248 in_n0_y_ho_x_wo_c_n1_grid_desc,
254
255 const auto in_gemmk0pad_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
256 in_gemmk0total_gemmn_gemmk1_grid_desc,
257 make_tuple(make_right_pad_transform(GemmK0Total, GemmK0Pad - GemmK0Total),
262
263 const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
264 in_gemmk0pad_gemmn_gemmk1_grid_desc,
265 make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0)),
270
271 // C: weight tensor
272 const auto wei_gemmm_gemmn_grid_desc =
274
275 return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
276 in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
277 wei_gemmm_gemmn_grid_desc);
278 }
279
281 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, 1));
282
286
287 // GridwiseGemm
288 template <index_t NXdlPerWave_>
290 BlockSize,
291 ADataType, // TODO: distinguish A/B datatype
292 AccDataType,
293 CDataType,
301 MPerBlock,
302 NPerBlock,
303 K0PerBlock,
304 MPerXdl,
305 NPerXdl,
306 K1,
307 MXdlPerWave,
308 NXdlPerWave_,
309 ABlockTransferThreadClusterLengths_K0_M_K1,
310 ABlockTransferThreadClusterArrangeOrder,
311 ABlockTransferSrcAccessOrder,
312 ABlockTransferSrcVectorDim,
313 ABlockTransferSrcScalarPerVector,
314 ABlockTransferDstScalarPerVector_K1,
315 false, // AThreadTransferSrcResetCoordinateAfterRun,
316 ABlockLdsAddExtraM,
320 BBlockTransferThreadClusterLengths_K0_N_K1,
321 BBlockTransferThreadClusterArrangeOrder,
322 BBlockTransferSrcAccessOrder,
323 BBlockTransferSrcVectorDim,
324 BBlockTransferSrcScalarPerVector,
325 BBlockTransferDstScalarPerVector_K1,
326 false, // BThreadTransferSrcResetCoordinateAfterRun,
327 BBlockLdsAddExtraN,
331 CShuffleMXdlPerWavePerShuffle,
332 CShuffleNXdlPerWavePerShuffle,
333 CBlockTransferScalarPerVector_NWaveNPerXdl,
334 CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
335 true,
336 true>;
339
340 template <index_t NXdlPerWave_>
342 BlockSize,
343 ADataType, // TODO: distinguish A/B datatype
344 AccDataType,
345 CDataType,
353 MPerBlock,
354 NPerBlock,
355 K0PerBlock,
356 MPerXdl,
357 NPerXdl,
358 K1,
359 MXdlPerWave,
360 NXdlPerWave_,
361 ABlockTransferThreadClusterLengths_K0_M_K1,
362 ABlockTransferThreadClusterArrangeOrder,
363 ABlockTransferSrcAccessOrder,
364 ABlockTransferSrcVectorDim,
365 ABlockTransferSrcScalarPerVector,
366 ABlockTransferDstScalarPerVector_K1,
367 false, // AThreadTransferSrcResetCoordinateAfterRun,
368 ABlockLdsAddExtraM,
372 BBlockTransferThreadClusterLengths_K0_N_K1,
373 BBlockTransferThreadClusterArrangeOrder,
374 BBlockTransferSrcAccessOrder,
375 BBlockTransferSrcVectorDim,
376 BBlockTransferSrcScalarPerVector,
377 BBlockTransferDstScalarPerVector_K1,
378 false, // BThreadTransferSrcResetCoordinateAfterRun,
379 BBlockLdsAddExtraN,
383 CShuffleMXdlPerWavePerShuffle,
384 CShuffleNXdlPerWavePerShuffle,
385 CBlockTransferScalarPerVector_NWaveNPerXdl,
386 CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
387 true,
388 true>;
391
392 // Argument
394 decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}));
395
397 decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1));
398 struct Argument : public BaseArgument
399 {
// Builds the GEMM view of one backward-weight convolution problem.
// Operand mapping:
//   A <- p_out_grid (output gradient)
//   B <- p_in_grid  (input)
//   C <- p_wei_grid (weight gradient)
// The element-wise operations are mapped the same way.
// M01 / N01 are forwarded to GridwiseGemm::MakeCBlockClusterAdaptor. split_k is
// stored as k_batch_, the number of partitions the GEMM K dimension is split into
// (the descriptor builder unmerges K0 into KBatch x K0).
400 Argument(const InDataType* p_in_grid,
401 WeiDataType* p_wei_grid,
402 const OutDataType* p_out_grid,
403 ck::index_t N,
404 ck::index_t K,
405 ck::index_t C,
406 std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
407 std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
408 std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
409 std::array<ck::index_t, NDimSpatial> conv_filter_strides,
410 std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
411 std::array<ck::index_t, NDimSpatial> input_left_pads,
412 std::array<ck::index_t, NDimSpatial> input_right_pads,
413 ck::index_t M01,
414 ck::index_t N01,
415 InElementwiseOperation in_element_op,
416 WeiElementwiseOperation wei_element_op,
417 OutElementwiseOperation out_element_op,
418 ck::index_t split_k)
419 : p_a_grid_{p_out_grid},
420 p_b_grid_{p_in_grid},
421 p_c_grid_{p_wei_grid},
427 M01_{M01},
428 N01_{N01},
429 a_element_op_{out_element_op},
430 b_element_op_{in_element_op},
431 c_element_op_{wei_element_op},
432 Conv_N_{N},
433 Conv_K_{K},
434 Conv_C_{C},
435 output_spatial_lengths_{output_spatial_lengths},
436 filter_spatial_lengths_{filter_spatial_lengths},
437 conv_filter_strides_{conv_filter_strides},
438 input_left_pads_{input_left_pads},
439 input_right_pads_{input_right_pads},
440 k_batch_{split_k}
441 {
// The descriptor builder returns a tuple of (A: output, B: input, C: weight)
// descriptors, in that order.
442 const auto descs =
444 K,
445 C,
446 input_spatial_lengths,
447 filter_spatial_lengths,
448 output_spatial_lengths,
449 conv_filter_strides,
450 conv_filter_dilations,
451 input_left_pads,
452 input_right_pads,
453 k_batch_);
454
// Plain M x N view of the weight gradient, then the block-to-C-tile map built from
// it together with M01, N01 and the split-K factor.
457 c_grid_desc_m_n_ = descs[I2];
458
460 GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
461
// The blocked C descriptor is only built when the gridwise GEMM accepts the problem;
// otherwise it is not assigned here. RunImp re-checks validity and throws before the
// blocked descriptor is used.
462 if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_,
466 {
468 GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n_);
469 }
470 }
471
482 InElementwiseOperation a_element_op_;
483 OutElementwiseOperation b_element_op_;
484 WeiElementwiseOperation c_element_op_;
485 // Copies of the convolution geometry, kept for checking IsSupportedArgument().
// The constructor fills these spatial arrays from its arguments.
// NOTE(review): the visible checks in IsSupportedArgument read Conv_N_, Conv_K_,
// Conv_C_ and k_batch_, not these arrays -- confirm whether they are still needed.
489 std::array<index_t, NDimSpatial> output_spatial_lengths_;
490 std::array<index_t, NDimSpatial> filter_spatial_lengths_;
491 std::array<index_t, NDimSpatial> conv_filter_strides_;
492 std::array<index_t, NDimSpatial> input_left_pads_;
493 std::array<index_t, NDimSpatial> input_right_pads_;
495 };
496
497 // Invoker
498 struct Invoker : public BaseInvoker
499 {
501
502 void Print(const Argument& arg)
503 {
504 std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{"
505 << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", "
506 << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", "
507 << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", "
508 << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl;
509
510 std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{"
511 << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", "
512 << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", "
513 << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", "
514 << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl;
515
516 std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
517 << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
518 }
519
// Launches the backward-weight GEMM for one argument and returns the time reported
// by launch_and_time_kernel.
// - Prints the descriptors when stream_config.log_level_ > 0.
// - Throws std::runtime_error if GridwiseGemm::CheckValidity rejects the descriptors.
// - Zero-fills the weight (C) buffer, then launches one of four kernel instantiations
//   chosen by (has main K0 block loop) x (kbatch == 1 ? GridwiseGemm : GridwiseGemmAtomicAdd).
//   Per the type listing, GridwiseGemmBase writes C with InMemoryDataOperationEnum::Set
//   and GridwiseGemmAtomicAddBase with AtomicAdd, so split-K partial results accumulate
//   into the zeroed buffer (assuming GridwiseGemm is a GridwiseGemmBase instantiation).
// NOTE(review): the split-K branches name GridwiseGemmAtomicAdd directly instead of a
// type derived from this function's GridwiseGemm parameter. Confirm it matches the
// wave-size variant chosen by the caller (a GridwiseGemmAtomicAdd32 alias also exists).
520 template <typename GridwiseGemm>
521 float RunImp(const typename GridwiseGemm::Argument& arg,
522 const StreamConfig& stream_config = StreamConfig{})
523 {
524 if(stream_config.log_level_ > 0)
525 {
526 Print(arg);
527 }
528
529 if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
530 arg.b_grid_desc_kbatch_k0_n_k1_,
531 arg.c_grid_desc_m_n_,
532 arg.block_2_ctile_map_))
533 {
534 throw std::runtime_error(
535 "wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight has invalid setting");
536 }
// Split-K factor = leading (KBatch) length of the A descriptor.
537 const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0);
538 const index_t grid_size =
539 arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
540
// K0 extent per K batch; decides whether the kernel runs its main K0 loop.
541 const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
542
543 const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
544
545 float ave_time = 0;
546
547 const auto Run = [&](const auto& kernel) {
// Zero the whole weight buffer on the stream before the launch, so AtomicAdd
// accumulation starts from 0.
// NOTE(review): hipMemsetAsync's hipError_t is only passed to hipGetErrorString and
// the resulting string is discarded, so a failed memset goes unnoticed.
// NOTE(review): if launch_and_time_kernel launches the kernel several times for
// timing, the buffer is zeroed only once before all of them -- confirm.
548 hipGetErrorString(hipMemsetAsync(
549 arg.p_c_grid_,
550 0,
551 arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
552 sizeof(CDataType),
553 stream_config.stream_id_));
554
555 ave_time =
556 launch_and_time_kernel(stream_config,
557 kernel,
558 dim3(grid_size),
559 dim3(BlockSize),
560 0,
561 arg.p_a_grid_,
562 arg.p_b_grid_,
563 arg.p_c_grid_,
564 arg.a_grid_desc_kbatch_k0_m_k1_,
565 arg.b_grid_desc_kbatch_k0_n_k1_,
566 arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
567 arg.a_element_op_,
568 arg.b_element_op_,
569 arg.c_element_op_,
570 arg.block_2_ctile_map_);
571 };
572
// The trailing bool template argument of each kernel mirrors has_main_k0_block_loop.
573 if(has_main_k0_block_loop)
574 {
575 if(kbatch == 1)
576 {
577 const auto kernel = kernel_gemm_xdlops_bwd_weight<
578 GridwiseGemm,
579 ADataType, // TODO: distinguish A/B datatype
580 CDataType,
584 OutElementwiseOperation,
585 InElementwiseOperation,
586 WeiElementwiseOperation,
588 true>;
589
590 Run(kernel);
591 }
592 else
593 {
594 const auto kernel = kernel_gemm_xdlops_bwd_weight<
595 GridwiseGemmAtomicAdd,
596 ADataType, // TODO: distinguish A/B datatype
597 CDataType,
601 OutElementwiseOperation,
602 InElementwiseOperation,
603 WeiElementwiseOperation,
605 true>;
606
607 Run(kernel);
608 }
609 }
610 else
611 {
612 if(kbatch == 1)
613 {
614 const auto kernel = kernel_gemm_xdlops_bwd_weight<
615 GridwiseGemm,
616 ADataType, // TODO: distinguish A/B datatype
617 CDataType,
621 OutElementwiseOperation,
622 InElementwiseOperation,
623 WeiElementwiseOperation,
625 false>;
626
627 Run(kernel);
628 }
629 else
630 {
631 const auto kernel = kernel_gemm_xdlops_bwd_weight<
632 GridwiseGemmAtomicAdd,
633 ADataType, // TODO: distinguish A/B datatype
634 CDataType,
638 OutElementwiseOperation,
639 InElementwiseOperation,
640 WeiElementwiseOperation,
642 false>;
643
644 Run(kernel);
645 }
646 }
647
648 return ave_time;
649 }
650
652
653 float Run(const BaseArgument* p_arg,
654 const StreamConfig& stream_config = StreamConfig{}) override
655 {
656 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
657 }
658 };
659
static constexpr bool IsValidCompilationParameter()
{
    // No compile-time parameter combination is rejected yet; every instantiation
    // is reported as valid.
    // TODO: properly implement this check
    constexpr bool all_parameters_accepted = true;
    return all_parameters_accepted;
}
665
// Host-side check of whether this instance can execute `arg`. Returns false when:
// - the leading check fails (its condition is not visible in this listing; the page
//   index references is_xdl_wmma_supported(), presumably a device-capability check --
//   TODO confirm);
// - A/B global loads are not vectorised along dim 2, or K / C are not multiples of
//   the A / B source vector widths;
// - N is not a multiple of K1 (N is unmerged into N0 x N1 with N1 == K1);
// - C is not a multiple of the C-shuffle store vector width;
// - k_batch_ < 1 (split-K auto-deduction is not supported).
// Otherwise the result of GridwiseGemm::CheckValidity on the GEMM descriptors is returned.
666 static bool IsSupportedArgument(const Argument& arg)
667 {
669 {
670 return false;
671 }
672 // vector load A/B matrix from global memory
// A (output gradient, NHWK) is vector-loaded along K; B (input, NHWC) along C.
673 if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 &&
674 arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 &&
675 arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0))
676 {
677 return false;
678 }
679
680 // unmerge N to N0 and N1, where N1 equals to K1
681 if(!(arg.Conv_N_ % K1 == 0))
682 {
683 return false;
684 }
685
686 // vector store C matrix into global memory
// The weight gradient (KYXC) is written in vectors along C.
687 if(!(arg.Conv_C_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0))
688 {
689 return false;
690 }
691
692 // Split-K autodeduction is not supported
693 if(arg.k_batch_ < 1)
694 {
695 return false;
696 }
697
698 // Gridwise GEMM size
699 return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
703 }
704
705 bool IsSupportedArgument(const BaseArgument* p_arg) override
706 {
707 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
708 }
709
710 static auto MakeArgument(const InDataType* p_in_grid,
711 WeiDataType* p_wei_grid,
712 const OutDataType* p_out_grid,
713 ck::index_t N,
714 ck::index_t K,
715 ck::index_t C,
716 std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
717 std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
718 std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
719 std::array<ck::index_t, NDimSpatial> conv_filter_strides,
720 std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
721 std::array<ck::index_t, NDimSpatial> input_left_pads,
722 std::array<ck::index_t, NDimSpatial> input_right_pads,
723 InElementwiseOperation in_element_op,
724 WeiElementwiseOperation wei_element_op,
725 OutElementwiseOperation out_element_op,
726 ck::index_t split_k)
727 {
728 return Argument{p_in_grid,
729 p_wei_grid,
730 p_out_grid,
731 N,
732 K,
733 C,
734 input_spatial_lengths,
735 filter_spatial_lengths,
736 output_spatial_lengths,
737 conv_filter_strides,
738 conv_filter_dilations,
739 input_left_pads,
740 input_right_pads,
741 1,
742 1,
743 in_element_op,
744 wei_element_op,
745 out_element_op,
746 split_k};
747 }
748
749 static auto MakeInvoker() { return Invoker{}; }
750
751 std::unique_ptr<BaseArgument>
752 MakeArgumentPointer(const void* p_in_grid,
753 void* p_wei_grid,
754 const void* p_out_grid,
755 ck::index_t N,
756 ck::index_t K,
757 ck::index_t C,
758 std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
759 std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
760 std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
761 std::array<ck::index_t, NDimSpatial> conv_filter_strides,
762 std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
763 std::array<ck::index_t, NDimSpatial> input_left_pads,
764 std::array<ck::index_t, NDimSpatial> input_right_pads,
765 InElementwiseOperation in_element_op,
766 WeiElementwiseOperation wei_element_op,
767 OutElementwiseOperation out_element_op,
768 ck::index_t split_k) override
769 {
770 return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
771 static_cast<WeiDataType*>(p_wei_grid),
772 static_cast<const OutDataType*>(p_out_grid),
773 N,
774 K,
775 C,
776 input_spatial_lengths,
777 filter_spatial_lengths,
778 output_spatial_lengths,
779 conv_filter_strides,
780 conv_filter_dilations,
781 input_left_pads,
782 input_right_pads,
783 1,
784 1,
785 in_element_op,
786 wei_element_op,
787 out_element_op,
788 split_k);
789 }
790
791 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
792 {
793 return std::make_unique<Invoker>(Invoker{});
794 }
795
796 std::string GetTypeString() const override
797 {
798 auto str = std::stringstream();
799
800 // clang-format off
801 str << "DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K"
802 << "<"
803 << BlockSize << ", "
804 << MPerBlock << ", "
805 << NPerBlock << ", "
806 << K0PerBlock << ", "
807 << K1 << ", "
808 << MPerXDL << ", "
809 << NPerXDL << ", "
810 << MXdlPerWave << ", "
811 << NXdlPerWave << ", "
812 << ABlockTransferSrcScalarPerVector << ", "
813 << ABlockTransferDstScalarPerVector_K1 << ", "
814 << BBlockTransferSrcScalarPerVector << ", "
815 << BBlockTransferDstScalarPerVector_K1 << ", "
816 << CShuffleMXdlPerWavePerShuffle << ", "
817 << CShuffleNXdlPerWavePerShuffle << ", "
818 << CBlockTransferScalarPerVector_NWaveNPerXdl
819 << ">";
820 // clang-format on
821
822 return str.str();
823 }
824};
825
826} // namespace device
827} // namespace tensor_operation
828} // namespace ck
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition device_base.hpp:94
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition convolution_backward_data_specialization.hpp:8
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__global__ void kernel_gemm_xdlops_bwd_weight(const FloatA *__restrict__ p_a_grid, const FloatB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc, const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op, const CBlockClusterAdaptor c_block_cluster_adaptor)
Definition gridwise_gemm_xdlops_bwd_weight.hpp:157
@ Set
Definition ck.hpp:278
@ AtomicAdd
Definition ck.hpp:279
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_pad_transform(const LowLength &low_length, const LeftPad &left_pad, const RightPad &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:19
__host__ __device__ constexpr auto make_embed_transform(const UpLengths &up_lengths, const Coefficients &coefficients)
Definition multi_index_transform_helper.hpp:48
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
typename remove_reference< T >::type remove_reference_t
Definition type.hpp:292
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_xdlops_bwd_weight.hpp:254
Definition utility/sequence.hpp:43
Definition device_base.hpp:197
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:69
static constexpr auto ABlockLdsM1Padding
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:109
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< BlockSize, ADataType, AccDataType, CDataType, InMemoryDataOperationEnum::AtomicAdd, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, CGridDesc_M_N, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, MPerBlock, NPerBlock, K0PerBlock, MPerXdl, NPerXdl, K1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsAddExtraM, ABlockLdsM1PerBlock, ABlockLdsM0PerBlock, ABlockLdsM1Padding, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsAddExtraN, BBlockLdsN1PerBlock, BBlockLdsN0PerBlock, BBlockLdsN1Padding, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CBlockTransferScalarPerVector_NWaveNPerXdl, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, true, true > GridwiseGemmAtomicAddBase
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:341
static constexpr auto I5
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:95
static constexpr auto I2
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:92
static constexpr bool IsValidCompilationParameter()
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:660
static bool IsSupportedArgument(const Argument &arg)
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:666
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_in_grid, void *p_wei_grid, const void *p_out_grid, ck::index_t N, ck::index_t K, ck::index_t C, std::array< ck::index_t, NDimSpatial > input_spatial_lengths, std::array< ck::index_t, NDimSpatial > filter_spatial_lengths, std::array< ck::index_t, NDimSpatial > output_spatial_lengths, std::array< ck::index_t, NDimSpatial > conv_filter_strides, std::array< ck::index_t, NDimSpatial > conv_filter_dilations, std::array< ck::index_t, NDimSpatial > input_left_pads, std::array< ck::index_t, NDimSpatial > input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op, ck::index_t split_k) override
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:752
static constexpr auto NXdlPerWave32
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:77
remove_cvref_t< decltype(ABCGridDescs{}[I0])> AGridDesc_K0_M_K1
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:283
static constexpr auto I3
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:93
static constexpr auto N1Number
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:100
GridwiseGemmAtomicAddBase< NXdlPerWave32 > GridwiseGemmAtomicAdd32
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:390
WeiDataType CDataType
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:81
static constexpr auto K1Number
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:97
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, ck::index_t K, ck::index_t C, std::array< ck::index_t, NDimSpatial > input_spatial_lengths, std::array< ck::index_t, NDimSpatial > filter_spatial_lengths, std::array< ck::index_t, NDimSpatial > output_spatial_lengths, std::array< ck::index_t, NDimSpatial > conv_filter_strides, std::array< ck::index_t, NDimSpatial > conv_filter_dilations, std::array< ck::index_t, NDimSpatial > input_left_pads, std::array< ck::index_t, NDimSpatial > input_right_pads, ck::index_t batch_k)
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:116
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< BlockSize, ADataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, CGridDesc_M_N, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, MPerBlock, NPerBlock, K0PerBlock, MPerXdl, NPerXdl, K1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsAddExtraM, ABlockLdsM1PerBlock, ABlockLdsM0PerBlock, ABlockLdsM1Padding, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsAddExtraN, BBlockLdsN1PerBlock, BBlockLdsN0PerBlock, BBlockLdsN1Padding, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CBlockTransferScalarPerVector_NWaveNPerXdl, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, true, true > GridwiseGemmBase
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:289
static constexpr auto I0
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:90
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:338
std::string GetTypeString() const override
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:796
static constexpr auto GemmK1Number
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:98
static constexpr auto BBlockLdsN1PerBlock
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:112
static constexpr auto BBlockLdsN0PerBlock
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:113
DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K DeviceOp
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:72
static auto MakeInvoker()
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:749
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:76
decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{})) CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:393
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:705
decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, 1)) ABCGridDescs
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:280
static constexpr auto I1
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:91
remove_cvref_t< decltype(ABCGridDescs{}[I1])> BGridDesc_K0_N_K1
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:284
remove_cvref_t< decltype(ABCGridDescs{}[I2])> CGridDesc_M_N
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:285
static constexpr auto ABlockLdsM0PerBlock
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:108
OutDataType ADataType
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:79
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:791
static constexpr auto BankLength
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:103
InDataType ABDataType
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:88
static constexpr auto ABlockLdsM1PerBlock
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:107
static auto MakeArgument(const InDataType *p_in_grid, WeiDataType *p_wei_grid, const OutDataType *p_out_grid, ck::index_t N, ck::index_t K, ck::index_t C, std::array< ck::index_t, NDimSpatial > input_spatial_lengths, std::array< ck::index_t, NDimSpatial > filter_spatial_lengths, std::array< ck::index_t, NDimSpatial > output_spatial_lengths, std::array< ck::index_t, NDimSpatial > conv_filter_strides, std::array< ck::index_t, NDimSpatial > conv_filter_dilations, std::array< ck::index_t, NDimSpatial > input_left_pads, std::array< ck::index_t, NDimSpatial > input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op, ck::index_t split_k)
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:710
static constexpr auto I4
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:94
OutElementwiseOperation AElementwiseOperation
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:83
InDataType BDataType
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:80
static constexpr auto BBlockLdsN1Padding
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:114
decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1)) Block2CTileMap
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:396
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:337
static constexpr ck::index_t NDimSpatial
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:70
WeiElementwiseOperation CElementwiseOperation
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:85
InElementwiseOperation BElementwiseOperation
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:84
static constexpr auto ElePerBank
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:104
GridwiseGemmAtomicAddBase< math::max(NXdlPerWave64, 1)> GridwiseGemmAtomicAdd64
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:389
float RunImp(const typename GridwiseGemm::Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:521
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:653
void Print(const Argument &arg)
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:502
DeviceOp::Argument Argument
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:500
index_t N01_
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:481
index_t Conv_N_
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:486
const ADataType * p_a_grid_
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:472
std::array< index_t, NDimSpatial > input_right_pads_
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:493
index_t Conv_K_
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:487
WeiElementwiseOperation c_element_op_
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:484
std::array< index_t, NDimSpatial > input_left_pads_
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:492
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:478
AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:475
index_t k_batch_
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:494
BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:476
std::array< index_t, NDimSpatial > filter_spatial_lengths_
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:490
Block2CTileMap block_2_ctile_map_
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:479
index_t M01_
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:480
index_t Conv_C_
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:488
CGridDesc_M_N c_grid_desc_m_n_
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:477
Argument(const InDataType *p_in_grid, WeiDataType *p_wei_grid, const OutDataType *p_out_grid, ck::index_t N, ck::index_t K, ck::index_t C, std::array< ck::index_t, NDimSpatial > input_spatial_lengths, std::array< ck::index_t, NDimSpatial > filter_spatial_lengths, std::array< ck::index_t, NDimSpatial > output_spatial_lengths, std::array< ck::index_t, NDimSpatial > conv_filter_strides, std::array< ck::index_t, NDimSpatial > conv_filter_dilations, std::array< ck::index_t, NDimSpatial > input_left_pads, std::array< ck::index_t, NDimSpatial > input_right_pads, ck::index_t M01, ck::index_t N01, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op, ck::index_t split_k)
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:400
const BDataType * p_b_grid_
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:473
std::array< index_t, NDimSpatial > conv_filter_strides_
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:491
OutElementwiseOperation b_element_op_
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:483
InElementwiseOperation a_element_op_
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:482
CDataType * p_c_grid_
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:474
std::array< index_t, NDimSpatial > output_spatial_lengths_
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:489