transform_conv_bwd_data_to_gemm_v1.hpp Source File

transform_conv_bwd_data_to_gemm_v1.hpp Source File

Composable Kernel: transform_conv_bwd_data_to_gemm_v1.hpp Source File
transform_conv_bwd_data_to_gemm_v1.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
12
13namespace ck {
14namespace tensor_operation {
15
22#define CK_USE_CUSTOM_TENSOR_TRANSFORM_FOR_BWD_DATA_OUT 1
23
24template <
25 index_t NDimSpatial,
27 index_t AK1,
28 index_t BK1,
29 index_t GemmMPerBlock,
30 index_t GemmNPerBlock,
31 index_t GemmKPerBlock,
32 bool DoPadGemmM,
33 bool DoPadGemmN,
34 typename ALayout,
35 typename BLayout,
36 typename CLayout,
37 bool SplitN = false,
38 typename ADataType = float,
39 typename CDataType = float,
40 index_t NumGroupsToMerge = 1,
41 typename IndexType = index_t,
42 bool CTranspose = false>
44{
45 private:
// Compile-time integral constants used to index the length/stride tuples below.
46 static constexpr auto I0 = Number<0>{};
47 static constexpr auto I1 = Number<1>{};
48 static constexpr auto I2 = Number<2>{};
49 static constexpr auto I3 = Number<3>{};
50
// Count of leading non-spatial entries in each dimension array; spatial entries start
// at this index (presumably the G, N and K/C slots, judging by argument names such as
// a_g_n_k_wos_lengths used elsewhere in this struct -- TODO confirm).
51 static constexpr auto NonSpatialDimsNum = Number<3>{};
52
// Positions of the D/H/W entries inside the tensor dimension arrays.
// DIdx is the first spatial slot. HIdx is also the first spatial slot when
// NDimSpatial == 2, and the slot after it otherwise.
53 static constexpr auto DIdx = NonSpatialDimsNum;
54 static constexpr auto HIdx =
55 NDimSpatial == 2 ? NonSpatialDimsNum : Number<NonSpatialDimsNum + 1>{};
// NOTE(review): the initializer of WIdx is not present in this listing.
56 static constexpr auto WIdx =
58
// Positions of the Z/Y/X filter entries; ZIdx and YIdx follow the same scheme as
// DIdx and HIdx above.
59 static constexpr auto ZIdx = NonSpatialDimsNum;
60 static constexpr auto YIdx =
61 NDimSpatial == 2 ? NonSpatialDimsNum : Number<NonSpatialDimsNum + 1>{};
// NOTE(review): the initializer of XIdx is not present in this listing.
62 static constexpr auto XIdx =
64
65 template <typename ConvDimsType>
66 static long_index_t calculate_element_space_size_impl(const ConvDimsType& lengths,
67 const ConvDimsType& strides,
68 index_t i)
69 {
70 long_index_t acc = 1;
71 for(; i < (NDimSpatial + 3); i++)
72 {
73 acc +=
74 static_cast<long_index_t>(lengths[i] - I1) * static_cast<long_index_t>(strides[i]);
75 }
76
77 return acc;
78 }
79
80 template <typename ConvDimsType>
81 static IndexType GetSplitedNSize(const ConvDimsType& a_g_n_k_wos_lengths,
82 const ConvDimsType& a_g_n_k_wos_strides,
83 const ConvDimsType& c_g_n_c_wis_lengths,
84 const ConvDimsType& c_g_n_c_wis_strides)
85 {
86 const long_index_t a_element_space_size =
87 calculate_element_space_size_impl(a_g_n_k_wos_lengths, a_g_n_k_wos_strides, I1);
88 const long_index_t c_element_space_size =
89 calculate_element_space_size_impl(c_g_n_c_wis_lengths, c_g_n_c_wis_strides, I1);
90 const long_index_t element_space_size = math::max(a_element_space_size * sizeof(ADataType),
91 c_element_space_size * sizeof(CDataType));
92 constexpr long_index_t TwoGB = (long_index_t{1} << 31);
93
94 const IndexType N = a_g_n_k_wos_lengths[I1];
95
96 if(element_space_size > TwoGB)
97 {
98 // Minimum divisor of N to not exceed 2GB
99 const auto divisor = math::integer_divide_ceil(element_space_size, TwoGB);
100
101 if(divisor <= static_cast<double>(N))
102 {
103 // Find least divisor of N larger than element_space_size / TwoGB
104 // Iterate up to sqrt(N). There are no divisors above this value.
105 for(IndexType least_divisor = divisor; least_divisor * least_divisor <= N;
106 least_divisor++)
107 {
108 if(N % least_divisor == 0)
109 {
110 return N / least_divisor;
111 }
112 }
113 // Not found, process one Convolution N per block
114 return 1;
115 }
116 else
117 {
118 // Split Convolution's N dimension into N workgroups. However
119 // this still might not result in sufficiently small tensor,
120 // but at least later on we could divide the image as well.
121 return 1;
122 }
123 }
124 else
125 {
126 // Split N is not needed.
127 return N;
128 }
129 }
130
131 public:
// Default constructor with an empty body; no member is assigned here.
132 __host__ __device__ constexpr TransformConvBwdDataToGemm_v1() {}
133
// Converting constructor: builds this transformer from another object exposing the
// same set of fields (typically an instantiation with a different IndexType --
// TODO confirm). Every field is read from the source object and cast to this
// instantiation's IndexType; batch_k_ is copied without a cast.
// NOTE(review): a number of initializer lines (the member names that precede the
// bare static_cast lines, e.g. for CStrideTensorB_, ConvDilationD_,
// GcdStrideDilationD_) are not present in this listing.
134 template <typename TransformConvBwdDataToGemm_v1Base>
135 __host__ __device__ TransformConvBwdDataToGemm_v1(
136 const TransformConvBwdDataToGemm_v1Base& transform_conv_bwd_data_to_gemm_base)
// Batch size and spatial extents of the input (Di/Hi/Wi) and output (Do/Ho/Wo).
137 : N_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.N_)},
138 Di_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.Di_)},
139 Hi_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.Hi_)},
140 Wi_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.Wi_)},
141 Do_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.Do_)},
142 Ho_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.Ho_)},
143 Wo_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.Wo_)},
// Filter extents and channel counts.
144 Z_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.Z_)},
145 Y_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.Y_)},
146 X_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.X_)},
147 K_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.K_)},
148 C_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.C_)},
// Memory strides.
149 DiStride_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.DiStride_)},
150 HiStride_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.HiStride_)},
151 WiStride_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.WiStride_)},
152 DoStride_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.DoStride_)},
153 HoStride_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.HoStride_)},
154 WoStride_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.WoStride_)},
156 static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.CStrideTensorB_)},
158 static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.CStrideTensorC_)},
160 static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.KStrideTensorA_)},
162 static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.KStrideTensorB_)},
164 static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.NStrideTensorA_)},
166 static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.NStrideTensorC_)},
// Convolution strides, dilations and paddings.
167 ConvStrideD_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.ConvStrideD_)},
168 ConvStrideH_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.ConvStrideH_)},
169 ConvStrideW_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.ConvStrideW_)},
171 static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.ConvDilationD_)},
173 static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.ConvDilationH_)},
175 static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.ConvDilationW_)},
176 InLeftPadD_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.InLeftPadD_)},
177 InLeftPadH_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.InLeftPadH_)},
178 InLeftPadW_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.InLeftPadW_)},
179 InRightPadD_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.InRightPadD_)},
180 InRightPadH_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.InRightPadH_)},
181 InRightPadW_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.InRightPadW_)},
// Tilde indices and the quantities derived from stride/dilation (gcd, tilde, dot).
182 IdxZTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.IdxZTilde_)},
183 IdxYTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.IdxYTilde_)},
184 IdxXTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.IdxXTilde_)},
186 static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.GcdStrideDilationD_)},
188 static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.GcdStrideDilationH_)},
190 static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.GcdStrideDilationW_)},
191 ZTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.ZTilde_)},
192 YTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.YTilde_)},
193 XTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.XTilde_)},
194 DTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.DTilde_)},
195 HTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.HTilde_)},
196 WTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.WTilde_)},
197 ZDot_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.ZDot_)},
198 YDot_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.YDot_)},
199 XDot_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.XDot_)},
// batch_k_ is taken over as-is (no IndexType cast).
200 batch_k_{transform_conv_bwd_data_to_gemm_base.batch_k_}
201 {
202 }
203
// Main constructor: captures a convolution problem description into the member
// fields that the Make*Descriptor functions read.
//
// Parameters (each ConvDimsType array is indexed with I1/I2 and the D/H/W or Z/Y/X
// index constants; each ConvSpatialDimsType array is indexed with those constants
// minus NonSpatialDimsNum):
//   a_g_n_k_wos_lengths/strides  - A tensor (source of N when SplitN, K, Do/Ho/Wo)
//   b_g_k_c_xs_lengths/strides   - B tensor (source of C, Z/Y/X)
//   c_g_n_c_wis_lengths/strides  - C tensor (source of N otherwise, Di/Hi/Wi)
//   conv_filter_strides, conv_filter_dilations, input_left_pads, input_right_pads
//   tildes                       - source of IdxZTilde_/IdxYTilde_/IdxXTilde_
//   batch_k                      - stored in batch_k_ (defaults to 1)
//
// NOTE(review): several lines of this constructor are not present in this listing
// (the original numbering skips 248-251, 274-277, 282 and 285-295), so members such
// as the Gcd/Tilde/Dot values for H and W are presumably set there -- TODO confirm.
204 template <typename ConvDimsType, typename ConvSpatialDimsType>
205 __host__ __device__
206 TransformConvBwdDataToGemm_v1(const ConvDimsType& a_g_n_k_wos_lengths,
207 const ConvDimsType& a_g_n_k_wos_strides,
208 const ConvDimsType& b_g_k_c_xs_lengths,
209 const ConvDimsType& b_g_k_c_xs_strides,
210 const ConvDimsType& c_g_n_c_wis_lengths,
211 const ConvDimsType& c_g_n_c_wis_strides,
212 const ConvSpatialDimsType& conv_filter_strides,
213 const ConvSpatialDimsType& conv_filter_dilations,
214 const ConvSpatialDimsType& input_left_pads,
215 const ConvSpatialDimsType& input_right_pads,
216 const ConvSpatialDimsType& tildes,
217 const index_t batch_k = 1)
// H/W related fields are initialized unconditionally in the initializer list.
218 : Hi_{c_g_n_c_wis_lengths[HIdx]},
219 Wi_{c_g_n_c_wis_lengths[WIdx]},
220 Ho_{a_g_n_k_wos_lengths[HIdx]},
221 Wo_{a_g_n_k_wos_lengths[WIdx]},
222 Y_{b_g_k_c_xs_lengths[YIdx]},
223 X_{b_g_k_c_xs_lengths[XIdx]},
224 K_{a_g_n_k_wos_lengths[I2]},
225 C_{b_g_k_c_xs_lengths[I2]},
226 HiStride_{c_g_n_c_wis_strides[HIdx]},
227 WiStride_{c_g_n_c_wis_strides[WIdx]},
228 HoStride_{a_g_n_k_wos_strides[HIdx]},
229 WoStride_{a_g_n_k_wos_strides[WIdx]},
230 CStrideTensorB_{b_g_k_c_xs_strides[I2]},
231 CStrideTensorC_{c_g_n_c_wis_strides[I2]},
232 KStrideTensorA_{a_g_n_k_wos_strides[I2]},
233 KStrideTensorB_{b_g_k_c_xs_strides[I1]},
234 NStrideTensorA_{a_g_n_k_wos_strides[I1]},
235 NStrideTensorC_{c_g_n_c_wis_strides[I1]},
// Spatial-only arrays have no non-spatial prefix, hence the index shift.
236 ConvStrideH_{conv_filter_strides[HIdx - NonSpatialDimsNum]},
237 ConvStrideW_{conv_filter_strides[WIdx - NonSpatialDimsNum]},
238 ConvDilationH_{conv_filter_dilations[HIdx - NonSpatialDimsNum]},
239 ConvDilationW_{conv_filter_dilations[WIdx - NonSpatialDimsNum]},
240 InLeftPadH_{input_left_pads[HIdx - NonSpatialDimsNum]},
241 InLeftPadW_{input_left_pads[WIdx - NonSpatialDimsNum]},
242 InRightPadH_{input_right_pads[HIdx - NonSpatialDimsNum]},
243 InRightPadW_{input_right_pads[WIdx - NonSpatialDimsNum]},
244 IdxYTilde_{tildes[YIdx - NonSpatialDimsNum]},
245 IdxXTilde_{tildes[XIdx - NonSpatialDimsNum]},
246 batch_k_{batch_k}
247 {
252
// N_ is either the split size chosen by GetSplitedNSize() or the full N of the
// C tensor, depending on the SplitN template flag.
253 if constexpr(SplitN)
254 {
255 N_ = GetSplitedNSize(
256 a_g_n_k_wos_lengths, a_g_n_k_wos_strides, c_g_n_c_wis_lengths, c_g_n_c_wis_strides);
257 }
258 else
259 {
260 N_ = c_g_n_c_wis_lengths[I1];
261 }
// Depth-related fields exist only for 3-D problems; otherwise they are set to 1 below.
262 if constexpr(NDimSpatial == 3)
263 {
264 Di_ = c_g_n_c_wis_lengths[DIdx];
265 Do_ = a_g_n_k_wos_lengths[DIdx];
266 Z_ = b_g_k_c_xs_lengths[ZIdx];
267 DiStride_ = c_g_n_c_wis_strides[DIdx];
268 DoStride_ = a_g_n_k_wos_strides[DIdx];
269 ConvStrideD_ = conv_filter_strides[DIdx - NonSpatialDimsNum];
270 ConvDilationD_ = conv_filter_dilations[DIdx - NonSpatialDimsNum];
271 InLeftPadD_ = input_left_pads[DIdx - NonSpatialDimsNum];
272 InRightPadD_ = input_right_pads[DIdx - NonSpatialDimsNum];
273 IdxZTilde_ = tildes[ZIdx - NonSpatialDimsNum];
278 }
279 else
280 {
// Non-3-D case: the listed depth fields all take the value 1.
281 Di_ = Do_ = Z_ = ZTilde_ = ConvStrideD_ = DTilde_ = ZDot_ = 1;
283 }
284
287
290
293
296 }
297
298#if 0 // At now not supported to split tensor
// Reports whether both the output (A) and the input (C) tensor fit within 2^31 bytes.
// Each tensor's element span is computed as 1 + sum((length - 1) * stride) over the
// N, D, H, W and channel dimensions, then scaled by the element size.
// NOTE: this function sits inside an `#if 0` region and is therefore not compiled.
//
// Returns true only when both byte sizes are <= 2^31.
299 __host__ bool AreDescriptorsSmallerThan2GB() const
300 {
301 constexpr long_index_t TwoGB = (long_index_t{1} << 31);
302
// Element span of the input tensor (uses the *TensorC_ strides and Di/Hi/Wi/C).
303 const long_index_t in_desc_space_size =
304 I1 + (N_ - I1) * NStrideTensorC_ + (Di_ - I1) * DiStride_ + (Hi_ - I1) * HiStride_ +
305 (Wi_ - I1) * WiStride_ + (C_ - I1) * CStrideTensorC_;
// Element span of the output tensor (uses the *TensorA_ strides and Do/Ho/Wo/K).
306 const long_index_t out_desc_space_size =
307 I1 + (N_ - I1) * NStrideTensorA_ + (Do_ - I1) * DoStride_ + (Ho_ - I1) * HoStride_ +
308 (Wo_ - I1) * WoStride_ + (K_ - I1) * KStrideTensorA_;
309
// Convert element counts to bytes before comparing against the limit.
310 bool is_a_descriptor_smaller_than_2GB = (out_desc_space_size * sizeof(ADataType)) <= TwoGB;
311 bool is_c_descriptor_smaller_than_2GB = (in_desc_space_size * sizeof(CDataType)) <= TwoGB;
312
313 return is_a_descriptor_smaller_than_2GB && is_c_descriptor_smaller_than_2GB;
314 }
315
// Splits this problem into a "left" and a "right" sub-problem by halving one output
// spatial dimension (D is tried first, then H, then W).
// NOTE: this function sits inside an `#if 0` region and is therefore not compiled.
//
// A dimension is split only if its output extent is not 1, the right half starts
// beyond the left padding, and the left half ends within left padding + input extent.
// If no dimension qualifies, both returned transformers are unmodified copies and
// both offsets are 0.
//
// Parameters:
//   a_grid_ptr_base - base pointer of the A (output) tensor
//   c_grid_ptr_base - base pointer of the C (input) tensor
// Returns a tuple of (left transformer, right transformer, A pointer advanced to the
// right half, C pointer advanced to the right half).
316 __host__ auto SplitConvProblem(const ADataType* a_grid_ptr_base,
317 CDataType* c_grid_ptr_base) const
318 {
319 // Create copies
320 auto conv_to_gemm_transformer_left = *this;
321 auto conv_to_gemm_transformer_right = *this;
322 IndexType a_right_offset = 0;
323 IndexType c_right_offset = 0;
324 // Calculate real filter size (extent covered by the dilated filter)
325 const IndexType z_eff = (Z_ - 1) * ConvDilationD_ + 1;
326 const IndexType y_eff = (Y_ - 1) * ConvDilationH_ + 1;
327 const IndexType x_eff = (X_ - 1) * ConvDilationW_ + 1;
328 // Calculate start position in input for right tensor
329 const IndexType di_right_transformer_start_idx = (Do_ / 2) * ConvStrideD_;
330 const IndexType hi_right_transformer_start_idx = (Ho_ / 2) * ConvStrideH_;
331 const IndexType wi_right_transformer_start_idx = (Wo_ / 2) * ConvStrideW_;
332 // Calculate last position in input for left tensor
333 const IndexType di_left_transformer_end_idx = (Do_ / 2 - 1) * ConvStrideD_ + z_eff;
334 const IndexType hi_left_transformer_end_idx = (Ho_ / 2 - 1) * ConvStrideH_ + y_eff;
335 const IndexType wi_left_transformer_end_idx = (Wo_ / 2 - 1) * ConvStrideW_ + x_eff;
336 // Allow a split only if the whole left padding falls in the left tensor and the
337 // whole right padding in the right tensor
338 const bool is_possible_to_split_d = Do_ != 1 &&
339 di_right_transformer_start_idx > InLeftPadD_ &&
340 di_left_transformer_end_idx <= (InLeftPadD_ + Di_);
341 const bool is_possible_to_split_h = Ho_ != 1 &&
342 hi_right_transformer_start_idx > InLeftPadH_ &&
343 hi_left_transformer_end_idx <= (InLeftPadH_ + Hi_);
344 const bool is_possible_to_split_w = Wo_ != 1 &&
345 wi_right_transformer_start_idx > InLeftPadW_ &&
346 wi_left_transformer_end_idx <= (InLeftPadW_ + Wi_);
347
348 if(is_possible_to_split_d)
349 {
350 // Apply new sizes
351 // Split output in half (the right part gets the extra element when Do_ is odd)
352 conv_to_gemm_transformer_left.Do_ = Do_ / 2;
353 conv_to_gemm_transformer_right.Do_ = Do_ - Do_ / 2;
354 // Assign left padding to left convolution
355 conv_to_gemm_transformer_left.InLeftPadD_ = InLeftPadD_;
356 conv_to_gemm_transformer_right.InLeftPadD_ = 0;
357 // Assign right padding to right convolution
358 conv_to_gemm_transformer_left.InRightPadD_ = 0;
359 conv_to_gemm_transformer_right.InRightPadD_ = InRightPadD_;
360 // Calculate new input size
361 conv_to_gemm_transformer_left.Di_ = di_left_transformer_end_idx - InLeftPadD_;
362 conv_to_gemm_transformer_right.Di_ =
363 math::min(Di_ - (di_right_transformer_start_idx - InLeftPadD_),
364 (conv_to_gemm_transformer_right.Do_ - 1) * ConvStrideD_ + z_eff);
365 ;
366 // Calculate offsets
367 a_right_offset = (Do_ / 2) * DoStride_;
368 c_right_offset = ((Do_ / 2) * ConvStrideD_ - InLeftPadD_) * DiStride_;
369 }
370 else if(is_possible_to_split_h)
371 {
// Same steps as the D branch above, applied to the H dimension.
372 conv_to_gemm_transformer_left.Ho_ = Ho_ / 2;
373 conv_to_gemm_transformer_right.Ho_ = Ho_ - Ho_ / 2;
374
375 conv_to_gemm_transformer_left.InLeftPadH_ = InLeftPadH_;
376 conv_to_gemm_transformer_right.InLeftPadH_ = 0;
377
378 conv_to_gemm_transformer_left.InRightPadH_ = 0;
379 conv_to_gemm_transformer_right.InRightPadH_ = InRightPadH_;
380
381 conv_to_gemm_transformer_left.Hi_ = hi_left_transformer_end_idx - InLeftPadH_;
382 conv_to_gemm_transformer_right.Hi_ =
383 math::min(Hi_ - (hi_right_transformer_start_idx - InLeftPadH_),
384 (conv_to_gemm_transformer_right.Ho_ - 1) * ConvStrideH_ + y_eff);
385 a_right_offset = (Ho_ / 2) * HoStride_;
386 c_right_offset = ((Ho_ / 2) * ConvStrideH_ - InLeftPadH_) * HiStride_;
387 }
388 else if(is_possible_to_split_w)
389 {
// Same steps as the D branch above, applied to the W dimension.
390 conv_to_gemm_transformer_left.Wo_ = Wo_ / 2;
391 conv_to_gemm_transformer_right.Wo_ = Wo_ - Wo_ / 2;
392
393 conv_to_gemm_transformer_left.InLeftPadW_ = InLeftPadW_;
394 conv_to_gemm_transformer_right.InLeftPadW_ = 0;
395
396 conv_to_gemm_transformer_left.InRightPadW_ = 0;
397 conv_to_gemm_transformer_right.InRightPadW_ = InRightPadW_;
398
399 conv_to_gemm_transformer_left.Wi_ = wi_left_transformer_end_idx - InLeftPadW_;
400 conv_to_gemm_transformer_right.Wi_ =
401 math::min(Wi_ - (wi_right_transformer_start_idx - InLeftPadW_),
402 (conv_to_gemm_transformer_right.Wo_ - 1) * ConvStrideW_ + x_eff);
403
404 a_right_offset = (Wo_ / 2) * WoStride_;
405 c_right_offset = ((Wo_ / 2) * ConvStrideW_ - InLeftPadW_) * WiStride_;
406 }
407 // Return left transformer, right transformer, the A (output) pointer moved to the
408 // right half and the C (input) pointer moved to the right half
409 return ck::make_tuple(conv_to_gemm_transformer_left,
410 conv_to_gemm_transformer_right,
411 a_grid_ptr_base + a_right_offset,
412 c_grid_ptr_base + c_right_offset);
413 }
414
// Alternative split that halves one *input* spatial dimension (D first, then H,
// then W) and derives the matching output ranges from it.
// NOTE: this function sits inside an `#if 0` region and is therefore not compiled.
// NOTE(review): it has the same name and parameter list as the function defined
// immediately before it in this region, so both could not be enabled together.
//
// Unlike the preceding variant, the only condition for splitting a dimension is
// that its input extent is not 1.
//
// Parameters:
//   a_grid_ptr_base - base pointer of the A (output) tensor
//   c_grid_ptr_base - base pointer of the C (input) tensor
// Returns a tuple of (left transformer, right transformer, A pointer advanced to the
// right half, C pointer advanced to the right half).
415 __host__ auto SplitConvProblem(const ADataType* a_grid_ptr_base,
416 CDataType* c_grid_ptr_base) const
417 {
418 // Create copies
419 auto conv_to_gemm_transformer_left = *this;
420 auto conv_to_gemm_transformer_right = *this;
421 IndexType a_right_offset = 0;
422 IndexType c_right_offset = 0;
423
424 // Calculate start position in output for right tensor
425 const IndexType do_right_transformer_start_idx = math::integer_divide_ceil((Di_ / 2) + InLeftPadD_ - ((Z_ - 1) * ConvDilationD_), ConvStrideD_);
426 const IndexType ho_right_transformer_start_idx = math::integer_divide_ceil((Hi_ / 2) + InLeftPadH_ - ((Y_ - 1) * ConvDilationH_), ConvStrideH_);
427 const IndexType wo_right_transformer_start_idx = math::integer_divide_ceil((Wi_ / 2) + InLeftPadW_ - ((X_ - 1) * ConvDilationW_), ConvStrideW_);
428 // Calculate last position in output for left tensor
429 const IndexType do_left_transformer_end_idx = math::integer_divide_ceil((Di_ / 2 - 1) + InLeftPadD_, ConvStrideD_);
430 const IndexType ho_left_transformer_end_idx = math::integer_divide_ceil((Hi_ / 2 - 1) + InLeftPadH_, ConvStrideH_);
431 const IndexType wo_left_transformer_end_idx = math::integer_divide_ceil((Wi_ / 2 - 1) + InLeftPadW_, ConvStrideW_);
432
433
434 if(Di_!=1)
435 {
436 // Apply new sizes
437 // Split input in half (the right part gets the extra element when Di_ is odd)
438 conv_to_gemm_transformer_left.Di_ = Di_ / 2;
439 conv_to_gemm_transformer_right.Di_ = Di_ - Di_ / 2;
440 // Assign left padding to left convolution
441 conv_to_gemm_transformer_left.InLeftPadD_ = InLeftPadD_;
442 conv_to_gemm_transformer_right.InLeftPadD_ = 0;
443 // Assign right padding to right convolution
444 conv_to_gemm_transformer_left.InRightPadD_ = 0;
445 conv_to_gemm_transformer_right.InRightPadD_ = InRightPadD_;
446 // Calculate new output size
447 conv_to_gemm_transformer_left.Do_ = do_left_transformer_end_idx;
448 conv_to_gemm_transformer_right.Do_ = Do_ - do_right_transformer_start_idx;
449 ;
450 // Calculate offsets
451 a_right_offset = do_right_transformer_start_idx * DoStride_;
452 c_right_offset = (Di_ / 2) * DiStride_;
453 }
454 else if(Hi_!=1)
455 {
456 // Apply new sizes
457 // Split input in half
458 conv_to_gemm_transformer_left.Hi_ = Hi_ / 2;
459 conv_to_gemm_transformer_right.Hi_ = Hi_ - Hi_ / 2;
460 // Assign left padding to left convolution
461 conv_to_gemm_transformer_left.InLeftPadH_ = InLeftPadH_;
462 conv_to_gemm_transformer_right.InLeftPadH_ = 0;
463 // Assign right padding to right convolution
464 conv_to_gemm_transformer_left.InRightPadH_ = 0;
465 conv_to_gemm_transformer_right.InRightPadH_ = InRightPadH_;
466 // Calculate new output size
467 conv_to_gemm_transformer_left.Ho_ = ho_left_transformer_end_idx ;
468 conv_to_gemm_transformer_right.Ho_ = Ho_ - ho_right_transformer_start_idx ;
469 ;
470 // Calculate offsets
471 a_right_offset = ho_right_transformer_start_idx * HoStride_;
472 c_right_offset = (Hi_ / 2) * HiStride_;
473 }
474 else if(Wi_!=1)
475 {
476 // Apply new sizes
477 // Split input in half
478 conv_to_gemm_transformer_left.Wi_ = Wi_ / 2;
479 conv_to_gemm_transformer_right.Wi_ = Wi_ - Wi_ / 2;
480 // Assign left padding to left convolution
481 conv_to_gemm_transformer_left.InLeftPadW_ = InLeftPadW_;
482 conv_to_gemm_transformer_right.InLeftPadW_ = 0;
483 // Assign right padding to right convolution
484 conv_to_gemm_transformer_left.InRightPadW_ = 0;
485 conv_to_gemm_transformer_right.InRightPadW_ = InRightPadW_;
486 // Calculate new output size
487 conv_to_gemm_transformer_left.Wo_ = wo_left_transformer_end_idx;
488 conv_to_gemm_transformer_right.Wo_ = Wo_ - wo_right_transformer_start_idx;
489 ;
490 // Calculate offsets
491 a_right_offset = wo_right_transformer_start_idx * WoStride_;
492 c_right_offset = (Wi_ / 2) * WiStride_;
493 }
494 // Return left transformer, right transformer, the A (output) pointer moved to the
495 // right half and the C (input) pointer moved to the right half
496 return ck::make_tuple(conv_to_gemm_transformer_left,
497 conv_to_gemm_transformer_right,
498 a_grid_ptr_base + a_right_offset,
499 c_grid_ptr_base + c_right_offset);
500 }
501#endif
502
// Builds the grid descriptor of the output tensor (the GEMM A operand).
//
// The body is a chain of layout-specific branches; inside most of them a second
// compile-time choice is made on whether ConvBwdDataSpecialization is
// Filter1x1Stride1Pad0. The final `else` throws std::runtime_error with the message
// "wrong! unsupported layout: " followed by ALayout::name().
//
// NOTE(review): the branch conditions (original lines 505, 522, 539, 553, 567, 584)
// and most of the descriptor-construction/return statements are not present in this
// listing, so which layout each branch serves cannot be stated from here.
503 __host__ __device__ auto MakeOutGridDesc() const
504 {
506 {
507 if constexpr(ConvBwdDataSpecialization ==
508 ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
509 Filter1x1Stride1Pad0)
510 {
511
514 }
515 else
516 {
// Lengths here are (N, Ho, Wo, K), i.e. a 2-D spatial case.
518 make_tuple(N_, Ho_, Wo_, K_),
520 }
521 }
523 {
524 if constexpr(ConvBwdDataSpecialization ==
525 ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
526 Filter1x1Stride1Pad0)
527 {
528
531 }
532 else
533 {
// Lengths here are (N, Do, Ho, Wo, K), i.e. a 3-D spatial case.
535 make_tuple(N_, Do_, Ho_, Wo_, K_),
537 }
538 }
540 {
541 // assume packed
542 if constexpr(ConvBwdDataSpecialization ==
543 ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
544 Filter1x1Stride1Pad0)
545 {
547 }
548 else
549 {
551 }
552 }
554 {
555 // assume packed
556 if constexpr(ConvBwdDataSpecialization ==
557 ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
558 Filter1x1Stride1Pad0)
559 {
561 }
562 else
563 {
565 }
566 }
568 {
569 // assume packed
// This branch is only valid for the Filter1x1Stride1Pad0 specialization.
570 static_assert(ConvBwdDataSpecialization ==
571 ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
572 Filter1x1Stride1Pad0);
573
574 const auto out_gemm_raw_grid_desc = make_naive_tensor_descriptor(
576
578 out_gemm_raw_grid_desc,
583 }
585 {
586 // assume packed
// This branch is only valid for the Filter1x1Stride1Pad0 specialization.
587 static_assert(ConvBwdDataSpecialization ==
588 ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
589 Filter1x1Stride1Pad0);
590
591 const auto out_gemm_raw_grid_desc =
594
596 out_gemm_raw_grid_desc,
601 }
602 else
603 {
// No branch matched ALayout.
604 throw std::runtime_error("wrong! unsupported layout: " + ALayout::name());
605 }
606 }
607
// Builds the grid descriptor of the weight tensor (the GEMM B operand).
//
// Two layout-specific branches are followed by a fallback that throws
// std::runtime_error with the message "wrong! unsupported layout: " followed by
// BLayout::name().
//
// NOTE(review): the branch conditions and the statements inside both branches
// (original lines 611, 613, 615, 617) are not present in this listing.
608 __host__ __device__ auto MakeWeiGridDesc() const
609 {
610
612 {
614 }
616 {
618 }
619 else
620 {
// No branch matched BLayout.
621 throw std::runtime_error("wrong! unsupported layout: " + BLayout::name());
622 }
623 }
624
648
// Builds the GEMM A-operand descriptor (dimensions named AK0, M, AK1 in the local
// variable names) from the output-tensor grid descriptor returned by
// MakeOutGridDesc().
//
// Enabled through std::enable_if only for NDimSpatial == 2 or 3, combined with
// further conditions that are not present in this listing.
//
// Two code paths exist:
//   * Filter1x1Stride1Pad0: K is split into AK0 x AK1 and M is handled per
//     GemmMPerBlock.
//   * otherwise: only the slice of the H/W (and D) "tilde" range selected by the
//     *TildeSliceBegin/End values is used, with a separate 2-D and 3-D
//     implementation.
//
// NOTE(review): many argument lines of the transform calls are missing from this
// listing (the original numbering has gaps), so the exact transforms cannot be
// described from here.
649 template <
650 typename ALayout_ = ALayout,
651 typename std::enable_if<(NDimSpatial == 2 || NDimSpatial == 3) &&
658 bool>::type = false>
659 __host__ __device__ auto MakeADescriptor_AK0_M_AK1() const
660 {
661 // n_do_ho_wo_k for 3d or n_ho_wo_k for 2d
662 const auto out_grid_desc = MakeOutGridDesc();
663
664 if constexpr(ConvBwdDataSpecialization ==
665 ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
666 Filter1x1Stride1Pad0)
667 {
// AK0 = ceil(K / (AK1 * K0PerBlock * batch_k)) * K0PerBlock, i.e. the per-batch
// K0 extent rounded up to a whole number of K0PerBlock chunks.
668 const index_t K0PerBlock = GemmKPerBlock / AK1;
669 const index_t AK0 =
670 math::integer_divide_ceil(K_, AK1 * K0PerBlock * batch_k_) * K0PerBlock;
671
672 // A: output tensor
673 const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor(
674 out_grid_desc,
679
680 const auto out_gemmak0_gemmm_gemmak1_grid_desc =
682 out_gemmak0_gemmmraw_gemmak1_grid_desc,
683 make_tuple(AK0 * batch_k_, GemmMPerBlock, AK1),
685
686 return out_gemmak0_gemmm_gemmak1_grid_desc;
687 }
688 else
689 {
690 // only work on HTilde and WTilde that contribute to non-padding area of input tensor
691 const auto IDTildeSliceBegin = math::integer_divide_floor(
693 const auto IHTildeSliceBegin = math::integer_divide_floor(
695 const auto IWTildeSliceBegin = math::integer_divide_floor(
697
698 const auto IDTildeSliceEnd = math::min(
700 const auto IHTildeSliceEnd = math::min(
702 const auto IWTildeSliceEnd = math::min(
704
// Extent of the contributing tilde range per dimension.
705 const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin;
706 const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
707 const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
708
709 // GemmK is different for each GEMM
710 const auto ZDotSlice = math::integer_divide_ceil(Z_ - IdxZTilde_, ZTilde_);
711 const auto YDotSlice = math::integer_divide_ceil(Y_ - IdxYTilde_, YTilde_);
712 const auto XDotSlice = math::integer_divide_ceil(X_ - IdxXTilde_, XTilde_);
713
714 if constexpr(NDimSpatial == 2)
715 {
// GemmK for this GEMM is YDotSlice * XDotSlice * K; AK0 is derived from it the
// same way as in the 1x1 path above.
716 const index_t K0PerBlock = GemmKPerBlock / AK1;
717 const index_t AK0 = math::integer_divide_ceil(YDotSlice * XDotSlice * K_,
718 AK1 * K0PerBlock * batch_k_) *
719 K0PerBlock;
720
// Two alternative 2-D implementations selected at preprocessing time: a chain of
// generic transforms (macro == 0) or a single custom transform (otherwise). The
// macro is defined to 1 near the top of this file.
721#if CK_USE_CUSTOM_TENSOR_TRANSFORM_FOR_BWD_DATA_OUT == 0
722 // A: output tensor
723 const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
724 out_grid_desc,
726 make_pad_transform(Ho_, I0, I0),
727 make_pad_transform(Wo_, I0, I0),
731
732 const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor(
733 out_n_hop_wop_k_grid_desc,
743
744 const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc =
746 out_n_ydot_htilde_xdot_wtilde_k_grid_desc,
748 make_slice_transform(YDot_, I0, YDotSlice),
749 make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice),
750 make_slice_transform(XDot_, I0, XDotSlice),
751 make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice),
754 Sequence<1>{},
755 Sequence<2>{},
756 Sequence<3>{},
757 Sequence<4>{},
758 Sequence<5>{}),
760 Sequence<1>{},
761 Sequence<2>{},
762 Sequence<3>{},
763 Sequence<4>{},
764 Sequence<5>{}));
765
// Merge (YDotSlice, XDotSlice, K) into GemmK and (N, HTildeSlice, WTildeSlice)
// into GemmM.
766 const auto out_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor(
767 out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc,
768 make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K_)),
769 make_merge_transform(make_tuple(N_, HTildeSlice, WTildeSlice))),
772
773 const auto out_gemmk_gemmm_padded_grid_desc =
775 out_gemmk_gemmmraw_grid_desc,
776 make_tuple(GemmKPerBlock, GemmMPerBlock),
778
779 const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor(
780 out_gemmk_gemmm_padded_grid_desc,
783 out_gemmk_gemmm_padded_grid_desc.GetLength(I1))),
786 return out_gemmak0_gemmm_gemmak1_grid_desc;
787#else
788 const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
789 out_grid_desc,
791 make_pad_transform(Ho_, I0, I0),
792 make_pad_transform(Wo_, I0, I0),
796
// NOTE(review): the name of the custom transform constructed here is on a line
// missing from this listing; only its arguments are visible.
797 const auto out_n_hop_wop_k_grid_desc_final = transform_tensor_descriptor(
798 out_n_hop_wop_k_grid_desc,
800 Ho_,
801 Wo_,
802 K_,
803 YDot_,
804 XDot_,
805 HTilde_,
806 WTilde_,
809 HTildeSlice,
810 WTildeSlice,
811 YDotSlice,
812 XDotSlice,
813 IHTildeSliceBegin,
814 IWTildeSliceBegin,
817 AK0 * batch_k_,
818 AK1,
819 GemmMPerBlock,
820 GemmKPerBlock)),
823
824 return out_n_hop_wop_k_grid_desc_final;
825#endif
826 }
827 else if constexpr(NDimSpatial == 3)
828 {
829 // A: output tensor
830 const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
831 out_grid_desc,
833 make_pad_transform(Do_, I0, I0),
834 make_pad_transform(Ho_, I0, I0),
835 make_pad_transform(Wo_, I0, I0),
841
842 const auto out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc =
844 out_n_hop_wop_k_grid_desc,
857 Sequence<1>{},
858 Sequence<2>{},
859 Sequence<3>{},
860 Sequence<4>{}),
865 Sequence<7>{}));
866
// Keep only the dot/tilde sub-ranges that contribute to this GEMM.
867 const auto
868 out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc =
870 out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc,
873 make_slice_transform(ZDot_, I0, ZDotSlice),
874 make_slice_transform(DTilde_, IDTildeSliceBegin, DTildeSlice),
875 make_slice_transform(YDot_, I0, YDotSlice),
876 make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice),
877 make_slice_transform(XDot_, I0, XDotSlice),
878 make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice),
881 Sequence<1>{},
882 Sequence<2>{},
883 Sequence<3>{},
884 Sequence<4>{},
885 Sequence<5>{},
886 Sequence<6>{},
887 Sequence<7>{}),
889 Sequence<1>{},
890 Sequence<2>{},
891 Sequence<3>{},
892 Sequence<4>{},
893 Sequence<5>{},
894 Sequence<6>{},
895 Sequence<7>{}));
896
// Merge (ZDotSlice, YDotSlice, XDotSlice, K) into GemmK and
// (N, DTildeSlice, HTildeSlice, WTildeSlice) into GemmM.
897 const auto out_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor(
898 out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc,
900 make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K_)),
902 make_tuple(N_, DTildeSlice, HTildeSlice, WTildeSlice))),
905
906 const auto out_gemmk_gemmm_padded_grid_desc =
908 out_gemmk_gemmmraw_grid_desc,
909 make_tuple(GemmKPerBlock, GemmMPerBlock),
911
// Unlike the 2-D path, AK0 is derived from the padded GemmK length here.
912 const index_t K0PerBlock = GemmKPerBlock / AK1;
913 const index_t AK0 =
914 math::integer_divide_ceil(out_gemmk_gemmm_padded_grid_desc.GetLength(I0),
915 AK1 * K0PerBlock * batch_k_) *
916 K0PerBlock;
917
918 const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor(
919 out_gemmk_gemmm_padded_grid_desc,
922 out_gemmk_gemmm_padded_grid_desc.GetLength(I1))),
925
926 return out_gemmak0_gemmm_gemmak1_grid_desc;
927 }
928 else
929 {
// Fallback for any NDimSpatial other than 2 or 3.
930 throw std::runtime_error("wrong! only implemented for 2D and 3D now");
931 }
932 }
933 }
934
// Builds the GEMM B-operand descriptor (dimensions named BK0, N, BK1 in the local
// variable names) for the weight tensor.
//
// Enabled through std::enable_if only for NDimSpatial == 2 or 3, combined with
// further conditions that are not present in this listing.
//
// Two code paths exist:
//   * Filter1x1Stride1Pad0: K is split into BK0 x BK1 and N is handled per
//     GemmNPerBlock.
//   * otherwise: starts from MakeWeiGridDesc(), keeps only the Z/Y/X "dot" slices
//     that belong to this GEMM, merges them with K into GemmK, and then splits
//     GemmK into BK0 x BK1. Separate 2-D and 3-D implementations are provided.
//
// NOTE(review): many argument lines of the transform calls are missing from this
// listing (the original numbering has gaps), so the exact transforms cannot be
// described from here.
935 template <
936 typename BLayout_ = BLayout,
937 typename std::enable_if<(NDimSpatial == 2 || NDimSpatial == 3) &&
942 bool>::type = false>
943 __host__ __device__ auto MakeBDescriptor_BK0_N_BK1() const
944 {
945
946 if constexpr(ConvBwdDataSpecialization ==
947 ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
948 Filter1x1Stride1Pad0)
949 {
// BK0 = ceil(K / (BK1 * K0PerBlock * batch_k)) * K0PerBlock.
950 const index_t K0PerBlock = GemmKPerBlock / BK1;
951 const index_t BK0 =
952 math::integer_divide_ceil(K_, BK1 * K0PerBlock * batch_k_) * K0PerBlock;
953
954 // B: weight tensor
955 const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc = transform_tensor_descriptor(
962
963 const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc =
965 wei_gemmbk0_gemmnraw_gemmbk1_grid_desc,
966 make_tuple(BK0 * batch_k_, GemmNPerBlock, BK1),
968
969 return wei_gemmbk0_gemmn_gemmbk1_grid_desc;
970 }
971 else
972 {
973 // assume packed
974 // k_y_x_c for 2d or k_z_y_x_c for 3d
977 const auto wei_grid_desc = MakeWeiGridDesc();
978
979 // GemmK is different for each GEMM
980 const auto ZDotSlice = math::integer_divide_ceil(Z_ - IdxZTilde_, ZTilde_);
981 const auto YDotSlice = math::integer_divide_ceil(Y_ - IdxYTilde_, YTilde_);
982 const auto XDotSlice = math::integer_divide_ceil(X_ - IdxXTilde_, XTilde_);
983
984 // B weight tensor
985 if constexpr(NDimSpatial == 2)
986 {
987 const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor(
988 wei_grid_desc,
998
// Slice the dot dimensions; the output index sequences below show the two
// "tilde" dimensions being dropped (empty Sequence<>).
999 const auto wei_k_ydotslice_xdotslice_c_grid_desc = transform_tensor_descriptor(
1000 wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc,
1002 make_slice_transform(YDot_, I0, YDotSlice),
1003 make_slice_transform(XDot_, I0, XDotSlice),
1008 Sequence<1>{},
1009 Sequence<3>{},
1010 Sequence<2>{},
1011 Sequence<4>{},
1012 Sequence<5>{}),
1014 Sequence<1>{},
1015 Sequence<2>{},
1016 Sequence<>{},
1017 Sequence<>{},
1018 Sequence<3>{}));
1019
// Merge (YDotSlice, XDotSlice, K) into GemmK.
1020 const auto wei_gemmk_gemmnraw_grid_desc = transform_tensor_descriptor(
1021 wei_k_ydotslice_xdotslice_c_grid_desc,
1022 make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K_)),
1026
1027 const auto wei_gemmk_gemmn_padded_grid_desc =
1029 wei_gemmk_gemmnraw_grid_desc,
1030 make_tuple(GemmKPerBlock, GemmNPerBlock),
1032
// BK0 is derived from the padded GemmK length.
1033 const index_t K0PerBlock = GemmKPerBlock / BK1;
1034 const index_t BK0 =
1035 math::integer_divide_ceil(wei_gemmk_gemmn_padded_grid_desc.GetLength(I0),
1036 BK1 * K0PerBlock * batch_k_) *
1037 K0PerBlock;
1038
1039 const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc = transform_tensor_descriptor(
1040 wei_gemmk_gemmn_padded_grid_desc,
1043 wei_gemmk_gemmn_padded_grid_desc.GetLength(I1))),
1046
1047 return wei_gemmbk0_gemmn_gemmbk1_grid_desc;
1048 }
1049 else if constexpr(NDimSpatial == 3)
1050 {
1051 const auto wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc =
1053 wei_grid_desc,
1066 Sequence<1>{},
1067 Sequence<2>{},
1068 Sequence<3>{},
1069 Sequence<4>{}),
1074 Sequence<7>{}));
1075
// Slice the dot dimensions; the three "tilde" dimensions are dropped (empty
// Sequence<> in the output index list).
1076 const auto wei_gemmk_zdotslice_ydotslice_xdotslice_c_grid_desc =
1078 wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc,
1080 make_slice_transform(ZDot_, I0, ZDotSlice),
1081 make_slice_transform(YDot_, I0, YDotSlice),
1082 make_slice_transform(XDot_, I0, XDotSlice),
1088 Sequence<1>{},
1089 Sequence<3>{},
1090 Sequence<5>{},
1091 Sequence<2>{},
1092 Sequence<4>{},
1093 Sequence<6>{},
1094 Sequence<7>{}),
1096 Sequence<1>{},
1097 Sequence<2>{},
1098 Sequence<3>{},
1099 Sequence<>{},
1100 Sequence<>{},
1101 Sequence<>{},
1102 Sequence<4>{}));
1103
// Merge (ZDotSlice, YDotSlice, XDotSlice, K) into GemmK.
1104 const auto wei_gemmk_gemmnraw_grid_desc = transform_tensor_descriptor(
1105 wei_gemmk_zdotslice_ydotslice_xdotslice_c_grid_desc,
1106 make_tuple(
1107 make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K_)),
1111
1112 const auto wei_gemmk_gemmn_padded_grid_desc =
1114 wei_gemmk_gemmnraw_grid_desc,
1115 make_tuple(GemmKPerBlock, GemmNPerBlock),
1117
// BK0 is derived from the padded GemmK length.
1118 const index_t K0PerBlock = GemmKPerBlock / BK1;
1119 const index_t BK0 =
1120 math::integer_divide_ceil(wei_gemmk_gemmn_padded_grid_desc.GetLength(I0),
1121 BK1 * K0PerBlock * batch_k_) *
1122 K0PerBlock;
1123
1124 const auto wei_gemmbk0_gemm_gemmbk1_grid_desc = transform_tensor_descriptor(
1125 wei_gemmk_gemmn_padded_grid_desc,
1128 wei_gemmk_gemmn_padded_grid_desc.GetLength(I1))),
1131
1132 return wei_gemmbk0_gemm_gemmbk1_grid_desc;
1133 }
1134 else
1135 {
// Fallback for any NDimSpatial other than 2 or 3.
1136 throw std::runtime_error("wrong! only implemented for 2D and 3D now");
1137 }
1138 }
1139 }
1140
// Builds the two-dimensional (GemmM, GemmN) descriptor for the C operand of the
// GEMM. Per the in-code notes, C is the convolution *input* tensor here, laid
// out as n_hi_wi_c (2D) or n_di_hi_wi_c (3D) and assumed strided.
//
// Overload selection: enabled for NDimSpatial == 2 or 3; the remaining
// enable_if conditions are not visible in this listing.
// Precondition: CTranspose must be false (static_assert below).
// Returns: in_gemmm_gemmn_grid_desc, derived from MakeInGridDesc() through a
// chain of transform_tensor_descriptor steps whose last step is given the tile
// lengths (GemmMPerBlock, GemmNPerBlock).
// Throws: std::runtime_error from the fallback branches that handle any other
// NDimSpatial value.
1141 template <
1142 typename CLayout_ = CLayout,
1143 typename std::enable_if<(NDimSpatial == 2 || NDimSpatial == 3) &&
1149 bool>::type = false>
1150 __host__ __device__ auto MakeCDescriptor_M_N() const
1151 {
1152 static_assert(CTranspose == false);
1153 // assume strided
1154 // n_hi_wi_c for 2d n_di_hi_wi_c for 3d
1155 const auto in_grid_desc = MakeInGridDesc();
1156
// Compile-time dispatch on the backward-data specialization: the
// Filter1x1Stride1Pad0 path skips the tilde-slice computation done in the
// general path further down.
1157 if constexpr(ConvBwdDataSpecialization ==
1158 ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
1159 Filter1x1Stride1Pad0)
1160 {
1161 // C: input tensor
1162 if constexpr(NDimSpatial == 2)
1163 {
1164 const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
1165 in_grid_desc,
1166 make_tuple(
1173
1174 const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor(
1175 in_n_y_ho_x_wo_c_grid_desc,
1182
// Final step: the raw GEMM view is adjusted to the block tile
// (GemmMPerBlock, GemmNPerBlock).
1183 const auto in_gemmm_gemmn_grid_desc =
1185 in_gemmmraw_gemmnraw_grid_desc,
1186 make_tuple(GemmMPerBlock, GemmNPerBlock),
1188
1189 return in_gemmm_gemmn_grid_desc;
1190 }
1191 else if constexpr(NDimSpatial == 3)
1192 {
1193
1194 // C: input tensor
1195 const auto in_n_x_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
1196 in_grid_desc,
1197 make_tuple(
1203 make_tuple(
1209 Sequence<7>{}));
1210
1211 const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor(
1212 in_n_x_do_y_ho_x_wo_c_grid_desc,
1219 Sequence<3>{},
1220 Sequence<5>{},
1222 Sequence<7>{}),
1223 make_tuple(
1225
1226 const auto in_gemmm_gemmn_grid_desc =
1228 in_gemmmraw_gemmnraw_grid_desc,
1229 make_tuple(GemmMPerBlock, GemmNPerBlock),
1231
1232 return in_gemmm_gemmn_grid_desc;
1233 }
1234 else
1235 {
1236 throw std::runtime_error("wrong! only implemented for 2D and 3D now");
1237 }
1238 }
1239 else
1240 {
1241 // only work on DTilde, HTilde and WTilde that contribute to
1242 // non-padding area of input tensor
1243 const auto IDTildeSliceBegin = math::integer_divide_floor(
1245 const auto IHTildeSliceBegin = math::integer_divide_floor(
1247 const auto IWTildeSliceBegin = math::integer_divide_floor(
1249
1250 const auto IDTildeSliceEnd = math::min(
1252 const auto IHTildeSliceEnd = math::min(
1254 const auto IWTildeSliceEnd = math::min(
1256
// Slice lengths are end minus begin, i.e. [Begin, End) is half-open.
1257 const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin;
1258 const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
1259 const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
1260
1261 // C: input tensor
1262 if constexpr(NDimSpatial == 2)
1263 {
1264 const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
1265 in_grid_desc,
1272
1273 const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc =
1275 in_n_hip_wip_c_grid_desc,
1283 make_tuple(
1285
// Keep only the HTilde/WTilde windows computed above. Lower dims 1 and 3
// map to empty upper sequences, so they do not appear in the result.
1286 const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor(
1287 in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc,
1290 make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice),
1292 make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice),
1295 Sequence<1>{},
1296 Sequence<2>{},
1297 Sequence<3>{},
1298 Sequence<4>{},
1299 Sequence<5>{}),
1301 Sequence<>{},
1302 Sequence<1>{},
1303 Sequence<>{},
1304 Sequence<2>{},
1305 Sequence<3>{}));
1306
// Merge (N, HTildeSlice, WTildeSlice) into a single dimension.
1307 const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor(
1308 in_n_htildeslice_wtildeslice_c_grid_desc,
1309 make_tuple(make_merge_transform(make_tuple(N_, HTildeSlice, WTildeSlice)),
1313
1314 const auto in_gemmm_gemmn_grid_desc =
1316 in_gemmmraw_gemmnraw_grid_desc,
1317 make_tuple(GemmMPerBlock, GemmNPerBlock),
1319
1320 return in_gemmm_gemmn_grid_desc;
1321 }
// `if constexpr` matches the other dimension dispatches in this function,
// so the fallback `throw` below is discarded at compile time for 3D.
1322 else if constexpr(NDimSpatial == 3)
1323 {
1324 const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor(
1325 in_grid_desc,
1331 make_tuple(
1333 make_tuple(
1335
1336 const auto in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc =
1338 in_n_dip_hip_wip_c_grid_desc,
1348 Sequence<1>{},
1349 Sequence<2>{},
1350 Sequence<3>{},
1351 Sequence<4>{}),
1356 Sequence<7>{}));
1357
// Keep only the DTilde/HTilde/WTilde windows. Lower dims 1, 3 and 5 map to
// empty upper sequences, so they do not appear in the result.
1358 const auto in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc =
1360 in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc,
1363 make_slice_transform(DTilde_, IDTildeSliceBegin, DTildeSlice),
1365 make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice),
1367 make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice),
1370 Sequence<1>{},
1371 Sequence<2>{},
1372 Sequence<3>{},
1373 Sequence<4>{},
1374 Sequence<5>{},
1375 Sequence<6>{},
1376 Sequence<7>{}),
1378 Sequence<>{},
1379 Sequence<1>{},
1380 Sequence<>{},
1381 Sequence<2>{},
1382 Sequence<>{},
1383 Sequence<3>{},
1384 Sequence<4>{}));
1385
// Merge (N, DTildeSlice, HTildeSlice, WTildeSlice) into a single dimension.
1386 const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor(
1387 in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc,
1388 make_tuple(
1389 make_merge_transform(make_tuple(N_, DTildeSlice, HTildeSlice, WTildeSlice)),
1393
1394 const auto in_gemmm_gemmn_grid_desc =
1396 in_gemmmraw_gemmnraw_grid_desc,
1397 make_tuple(GemmMPerBlock, GemmNPerBlock),
1399 return in_gemmm_gemmn_grid_desc;
1400 }
1401 else
1402 {
1403 throw std::runtime_error("wrong! only implemented for 2D and 3D now");
1404 }
1405 }
1406 }
1407
// Overload of MakeCDescriptor_M_N that only supports the Filter1x1Stride1Pad0
// specialization (enforced by the static_assert below).
// Overload selection: enabled for NDimSpatial == 2 or 3; the remaining
// enable_if conditions are not visible in this listing.
// The input descriptor is built directly with make_naive_tensor_descriptor
// (its arguments are not visible here) rather than via MakeInGridDesc().
// NOTE(review): the return statements of both branches are not visible in this
// listing; presumably each branch returns the descriptor produced from
// in_gemmmraw_gemmnraw_grid_desc and the tile lengths -- confirm against the
// full file.
1408 template <typename CLayout_ = CLayout,
1409 typename std::enable_if<(NDimSpatial == 2 || NDimSpatial == 3) &&
1412 bool>::type = false>
1413 __host__ __device__ auto MakeCDescriptor_M_N() const
1414 {
1415 const auto in_grid_desc = make_naive_tensor_descriptor(
1417
1418 static_assert(ConvBwdDataSpecialization ==
1419 ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
1420 Filter1x1Stride1Pad0);
1421
1422 if constexpr(CTranspose)
1423 {
1424 const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor(
1425 in_grid_desc,
// Transposed case: the tile lengths are passed in swapped order
// (GemmNPerBlock first, then GemmMPerBlock).
1431 in_gemmmraw_gemmnraw_grid_desc,
1432 make_tuple(GemmNPerBlock, GemmMPerBlock),
1434 }
1435 else
1436 {
1437 const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor(
1438 in_grid_desc,
1443
// Non-transposed case: tile lengths in (GemmMPerBlock, GemmNPerBlock) order.
1445 in_gemmmraw_gemmnraw_grid_desc,
1446 make_tuple(GemmMPerBlock, GemmNPerBlock),
1448 }
1449 }
1450 // for input bias
// Overload of MakeCDescriptor_M_N that describes a bias tensor as a 2D
// (GemmM, GemmN) view without reading any strides from the input tensor.
// Overload selection: enabled for NDimSpatial == 2; the remaining enable_if
// conditions are not visible in this listing.
// Every descriptor built here gives stride I0 (zero) to the merged
// N/spatial dimension and stride I1 (one) to the C_ dimension, so all
// positions along the N/spatial dimension address the same C_ elements.
// Returns: in_gemmm_gemmn_grid_desc.
1451 template <typename CLayout_ = CLayout,
1452 typename std::enable_if<NDimSpatial == 2 &&
1455 bool>::type = false>
1456 __host__ __device__ auto MakeCDescriptor_M_N() const
1457 {
1458 if constexpr(ConvBwdDataSpecialization ==
1459 ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
1460 Filter1x1Stride1Pad0)
1461 {
1462 if constexpr(CTranspose)
1463 {
// Transposed: lengths (C_, N_*Ho_*Wo_) with strides (1, 0).
1464 const auto in_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
1465 make_tuple(C_, N_ * Ho_ * Wo_), make_tuple(I1, I0));
1466
1467 return in_gemmm_gemmn_grid_desc;
1468 }
1469 else
1470 {
// Non-transposed: lengths (N_*Ho_*Wo_, C_) with strides (0, 1).
1471 const auto in_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
1472 make_tuple(N_ * Ho_ * Wo_, C_), make_tuple(I0, I1));
1473
1474 return in_gemmm_gemmn_grid_desc;
1475 }
1476 }
1477 else
1478 {
// The general (non-1x1) path does not support the transposed layout.
1479 static_assert(CTranspose == false);
1480 // only work on HTilde and WTilde that contribute to non-padding area of input
1481 // tensor
1482 const auto IHTildeSliceBegin = math::integer_divide_floor(
1484 const auto IWTildeSliceBegin = math::integer_divide_floor(
1486
1487 const auto IHTildeSliceEnd = math::min(
1489 const auto IWTildeSliceEnd = math::min(
1491
// Slice lengths are end minus begin, i.e. [Begin, End) is half-open.
1492 const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
1493 const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
1494
1495 // bias tensor
// Lengths (N_*HTildeSlice*WTildeSlice, C_) with strides (0, 1).
1496 const auto in_gemmmraw_gemmnraw_grid_desc = make_naive_tensor_descriptor(
1497 make_tuple(N_ * HTildeSlice * WTildeSlice, C_), make_tuple(I0, I1));
1498
// Pad the raw view against the (GemmMPerBlock, GemmNPerBlock) tile; the
// third PadTensorDescriptor argument (the pad flags) is not visible here.
1499 const auto in_gemmm_gemmn_grid_desc = ck::tensor_operation::device::PadTensorDescriptor(
1500 in_gemmmraw_gemmnraw_grid_desc,
1501 make_tuple(GemmMPerBlock, GemmNPerBlock),
1503
1504 return in_gemmm_gemmn_grid_desc;
1505 }
1506 }
1507
// Data members read by the descriptor builders above. All are IndexType.
// NOTE(review): several member declarations between these lines are not
// visible in this listing; the meanings below are inferred from the names and
// from how the values are used in this file -- confirm against the constructor.
// N_ is merged with the spatial slice lengths in MakeCDescriptor_M_N.
1508 IndexType N_;
// presumably input spatial extents (depth, height, width)
1509 IndexType Di_, Hi_, Wi_;
// presumably output spatial extents; Ho_ and Wo_ are multiplied with N_ in the
// bias overload of MakeCDescriptor_M_N
1510 IndexType Do_, Ho_, Wo_;
// presumably filter extents; Y_ and X_ feed the YDotSlice/XDotSlice computation
1511 IndexType Z_, Y_, X_;
// K_ is merged with the dot-slice lengths in the weight descriptor; C_ is the
// second length of the bias descriptors
1512 IndexType K_, C_;
// Low lengths of the slice transforms applied to the weight descriptor
1525 IndexType ZDot_, YDot_, XDot_;
1527};
1528
1529} // namespace tensor_operation
1530} // namespace ck
__host__ __device__ constexpr index_t gcd(index_t x, index_t y)
Definition utility/math.hpp:154
__host__ __device__ constexpr auto integer_divide_floor(X x, Y y)
Definition utility/math.hpp:66
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
__host__ __device__ constexpr T min(T x)
Definition utility/math.hpp:116
__host__ __device__ constexpr auto PadTensorDescriptor(const TensorDesc &desc, const TileLengths &tile_lengths, DoPads)
Definition matrix_padder.hpp:19
ConvolutionBackwardDataSpecialization
Definition convolution_backward_data_specialization.hpp:11
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
__host__ __device__ constexpr auto make_slice_transform(const LowLength &low_length, const SliceBegin &slice_begin, const SliceEnd &slice_end)
Definition multi_index_transform_helper.hpp:163
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_pad_transform(const LowLength &low_length, const LeftPad &left_pad, const RightPad &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:19
__host__ __device__ constexpr auto make_embed_transform(const UpLengths &up_lengths, const Coefficients &coefficients)
Definition multi_index_transform_helper.hpp:48
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_conv_bwd_data_out_transform(index_t N, index_t Ho, index_t Wo, index_t K, index_t YDot, index_t XDot, index_t HTilde, index_t WTilde, index_t ConvDilationH, index_t ConvDilationW, index_t HTildeSlice, index_t WTildeSlice, index_t YDotSlice, index_t XDotSlice, index_t IHTildeSliceBegin, index_t IWTildeSliceBegin, index_t GcdStrideDilationH, index_t GcdStrideDilationW, index_t K0, index_t K1, index_t MPerBlock, index_t GemmKPerBlock)
Definition multi_index_transform_helper.hpp:97
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
int64_t long_index_t
Definition ck.hpp:300
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
Definition utility/sequence.hpp:43
__host__ __device__ auto MakeADescriptor_AK0_M_AK1() const
Definition transform_conv_bwd_data_to_gemm_v1.hpp:659
__host__ __device__ TransformConvBwdDataToGemm_v1(const TransformConvBwdDataToGemm_v1Base &transform_conv_bwd_data_to_gemm_base)
Definition transform_conv_bwd_data_to_gemm_v1.hpp:135
__host__ __device__ auto MakeWeiGridDesc() const
Definition transform_conv_bwd_data_to_gemm_v1.hpp:608
__host__ __device__ auto MakeInGridDesc() const
Definition transform_conv_bwd_data_to_gemm_v1.hpp:625
__host__ __device__ TransformConvBwdDataToGemm_v1(const ConvDimsType &a_g_n_k_wos_lengths, const ConvDimsType &a_g_n_k_wos_strides, const ConvDimsType &b_g_k_c_xs_lengths, const ConvDimsType &b_g_k_c_xs_strides, const ConvDimsType &c_g_n_c_wis_lengths, const ConvDimsType &c_g_n_c_wis_strides, const ConvSpatialDimsType &conv_filter_strides, const ConvSpatialDimsType &conv_filter_dilations, const ConvSpatialDimsType &input_left_pads, const ConvSpatialDimsType &input_right_pads, const ConvSpatialDimsType &tildes, const index_t batch_k=1)
Definition transform_conv_bwd_data_to_gemm_v1.hpp:206
__host__ __device__ auto MakeOutGridDesc() const
Definition transform_conv_bwd_data_to_gemm_v1.hpp:503
__host__ __device__ auto MakeBDescriptor_BK0_N_BK1() const
Definition transform_conv_bwd_data_to_gemm_v1.hpp:943
__host__ __device__ constexpr TransformConvBwdDataToGemm_v1()
Definition transform_conv_bwd_data_to_gemm_v1.hpp:132
__host__ __device__ auto MakeCDescriptor_M_N() const
Definition transform_conv_bwd_data_to_gemm_v1.hpp:1150