device_batchnorm_backward_impl.hpp Source File

device_batchnorm_backward_impl.hpp Source File

Composable Kernel: device_batchnorm_backward_impl.hpp Source File
device_batchnorm_backward_impl.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
20
21namespace ck {
22namespace tensor_operation {
23namespace device {
24
25template <typename XDataType,
26 typename DxDataType,
27 typename DyDataType,
28 typename AccDataType,
29 typename ScaleDataType,
30 typename DscaleDbiasDataType,
31 typename MeanVarDataType,
32 typename DyElementwiseOp,
33 index_t Rank,
34 index_t NumBatchNormReduceDim,
35 bool UseMultiblockInK,
36 index_t BlockSize,
37 index_t MThreadClusterSize,
38 index_t KThreadClusterSize,
39 index_t MThreadSliceSize,
40 index_t KThreadSliceSize,
41 index_t XDyDxVectorDim,
42 index_t XSrcVectorSize,
43 index_t DySrcVectorSize,
44 index_t DxDstVectorSize,
45 index_t ScaleSrcVectorSize,
46 index_t DscaleDbiasDstVectorSize,
47 index_t MeanVarSrcVectorSize>
49 DxDataType,
50 DyDataType,
51 AccDataType,
52 ScaleDataType,
53 DscaleDbiasDataType,
54 MeanVarDataType,
55 DyElementwiseOp,
56 Rank,
57 NumBatchNormReduceDim>
58{
59 static_assert(Rank <= 6, "Bigger Rank size is not supported!");
60 static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
61 "Invalid thread cluster size assignments!");
62
63 static_assert((XDyDxVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0 &&
64 MThreadSliceSize % DySrcVectorSize == 0 &&
65 MThreadSliceSize % DxDstVectorSize == 0) ||
66 (XDyDxVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0 &&
67 KThreadSliceSize % DySrcVectorSize == 0 &&
68 KThreadSliceSize % DxDstVectorSize == 0),
69 "Invalid thread slice sizes and/or vector sizes configuration, please check!");
70
71 static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim;
72
73 static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
74 static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
75
// Builds the padded 2-D (M, K) view used for x, dy and dx.
// The leading NumInvariantDim dimensions are merged into M and the remaining dimensions
// (indices NumInvariantDim and up) into K.
// M is right-padded to a multiple of M_BlockTileSize.
// K is right-padded to K_BlockTileSize * numBlockTileIteration * blkGroupSize.
// Assumes xyLengths/xyStrides are already ordered with the reduced dimensions last — TODO confirm
// against the caller.
76 static auto MakeXY2dDescriptor(const std::array<index_t, Rank>& xyLengths,
77 const std::array<index_t, Rank>& xyStrides,
78 int blkGroupSize,
79 int numBlockTileIteration)
80 {
81 const auto tupleXYLengths =
82 generate_tuple([&](auto I) { return xyLengths[I]; }, Number<Rank>{});
83 const auto tupleXYStrides =
84 generate_tuple([&](auto I) { return xyStrides[I]; }, Number<Rank>{});
85
86 const auto raw_grid_desc = make_naive_tensor_descriptor(tupleXYLengths, tupleXYStrides);
87
// Merge invariant dims -> M and reduced dims -> K.
88 const auto grid_desc_m_k = [&]() {
89 using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
91
92 const auto reduceDimLengths =
93 generate_tuple([&](auto I) { return xyLengths[NumInvariantDim + I]; },
95 const auto invariantDimLengths =
96 generate_tuple([&](auto I) { return xyLengths[I]; }, Number<NumInvariantDim>{});
97
98 return transform_tensor_descriptor(raw_grid_desc,
99 make_tuple(make_merge_transform(invariantDimLengths),
100 make_merge_transform(reduceDimLengths)),
101 make_tuple(InvariantDims{}, ReduceDims{}),
103 }();
104
105 const auto invariantLength = grid_desc_m_k.GetLength(Number<0>{});
106 const auto reduceLength = grid_desc_m_k.GetLength(Number<1>{});
107
// Each block group walks numBlockTileIteration K-tiles, so K is padded to the total
// extent covered by all block groups.
// NOTE(review): kPad is negative unless the caller's blkGroupSize and numBlockTileIteration
// together cover reduceLength.
108 const int workSizePerBlock = K_BlockTileSize * numBlockTileIteration;
109 const auto mPad =
110 math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
111 const auto kPad = workSizePerBlock * blkGroupSize - reduceLength;
112
113 auto grid_desc_m_k_padded =
114 transform_tensor_descriptor(grid_desc_m_k,
115 make_tuple(make_right_pad_transform(invariantLength, mPad),
116 make_right_pad_transform(reduceLength, kPad)),
119
120 return (grid_desc_m_k_padded);
121 };
122
// Packed (invariantLength x blkGroupSize) descriptor: one slot per invariant index and block group.
// M is right-padded to a multiple of M_BlockTileSize.
// The block-group dimension G is passed through unpadded.
123 static auto MakeMultiblockFirstReduceOutputMG2dDescriptor(int invariantLength, int blkGroupSize)
124 {
125 const auto grid_desc_m_g =
126 make_naive_tensor_descriptor_packed(make_tuple(invariantLength, blkGroupSize));
127
128 const auto mPad =
129 math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
130
131 auto grid_desc_m_g_padded =
132 transform_tensor_descriptor(grid_desc_m_g,
133 make_tuple(make_right_pad_transform(invariantLength, mPad),
134 make_pass_through_transform(blkGroupSize)),
137
138 return (grid_desc_m_g_padded);
139 };
140
// Views a packed (invariantLength x blkGroupSize) buffer as (M, K), with the block-group count as
// the reduced length.
// M is right-padded to a multiple of M_BlockTileSize.
// K is right-padded to a multiple of KThreadClusterSize.
141 static auto MakeMultiblockFinalReduceInputMK2dDescriptor(int invariantLength, int blkGroupSize)
142 {
143 const auto reduceLength = blkGroupSize;
144 const auto grid_desc_m_k =
145 make_naive_tensor_descriptor_packed(make_tuple(invariantLength, reduceLength));
146
147 const auto mPad =
148 math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
// K padding only needs to fill the thread cluster (each thread takes one slice element per
// iteration here), not a full K_BlockTileSize.
149 const auto kPad =
150 math::integer_least_multiple(reduceLength, KThreadClusterSize) - reduceLength;
151
152 auto grid_desc_m_k_padded =
153 transform_tensor_descriptor(grid_desc_m_k,
154 make_tuple(make_right_pad_transform(invariantLength, mPad),
155 make_right_pad_transform(reduceLength, kPad)),
158
159 return (grid_desc_m_k_padded);
160 };
161
// Builds a 1-D descriptor for a scale/dscale-dbias/mean-var tensor.
// All NumInvariantDim dimensions (with their caller-supplied strides) are merged into a single
// M dimension.
// M is right-padded to a multiple of M_BlockTileSize.
162 static auto
163 MakeScaleBiasMeanVar1dDescriptor(const std::array<index_t, NumInvariantDim>& lengths,
164 const std::array<index_t, NumInvariantDim>& strides)
165 {
166 const auto tupleLengths =
167 generate_tuple([&](auto I) { return lengths[I]; }, Number<NumInvariantDim>{});
168 const auto tupleStrides =
169 generate_tuple([&](auto I) { return strides[I]; }, Number<NumInvariantDim>{});
170
171 auto raw_grid_desc = make_naive_tensor_descriptor(tupleLengths, tupleStrides);
172
173 auto grid_desc_m = transform_tensor_descriptor(
174 raw_grid_desc,
175 make_tuple(make_merge_transform(tupleLengths)),
178
179 const auto invariantLength = grid_desc_m.GetLength(Number<0>{});
180
181 const auto mPad =
182 math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
183
184 auto grid_desc_m_padded =
185 transform_tensor_descriptor(grid_desc_m,
186 make_tuple(make_right_pad_transform(invariantLength, mPad)),
189 return (grid_desc_m_padded);
190 };
191
// Descriptor type for x/dy/dx, deduced from an unevaluated dummy call; only the type is used.
192 using XYGridDesc_M_K = decltype(MakeXY2dDescriptor({1}, {1}, 1, 1));
195
// Runtime argument bundle for one batch-norm backward call.
// Holds copies of the tensor geometry, the caller's device pointers, and the work-split values
// (blkGroupSize, numBlockTileIteration, gridSize) derived from the geometry.
196 struct Argument : public BaseArgument
197 {
198 Argument(const std::array<index_t, Rank> xyLengths,
199 const std::array<index_t, Rank> xStrides,
200 const std::array<index_t, Rank> dyStrides,
201 const std::array<index_t, Rank> dxStrides,
202 const std::array<int, NumBatchNormReduceDim> reduceDims,
203 const std::array<ck::index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
204 const std::array<ck::index_t, NumInvariantDim> bnScaleStrides,
205 const std::array<ck::index_t, NumInvariantDim> bnDscaleDbiasStrides,
206 const std::array<ck::index_t, NumInvariantDim> bnMeanVarStrides,
207 const XDataType* p_x,
208 const DyDataType* p_dy,
209 const ScaleDataType* p_scale,
210 const MeanVarDataType* p_savedMean,
211 const MeanVarDataType* p_savedInvVar,
212 const DyElementwiseOp dy_elementwise_op,
213 double epsilon,
214 DxDataType* p_dx,
215 DscaleDbiasDataType* p_dscale,
216 DscaleDbiasDataType* p_dbias)
217 : bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths),
218 bnScaleStrides_(bnScaleStrides),
219 bnDscaleDbiasStrides_(bnDscaleDbiasStrides),
220 bnMeanVarStrides_(bnMeanVarStrides),
221 p_x_(p_x),
222 p_dy_(p_dy),
223 p_scale_(p_scale),
224 p_savedMean_(p_savedMean),
225 p_savedInvVar_(p_savedInvVar),
226 dy_elementwise_op_(dy_elementwise_op),
227 p_dx_(p_dx),
228 p_dscale_(p_dscale),
229 p_dbias_(p_dbias)
230 {
231 xyLengths_ =
233 xStrides_ =
235 dyStrides_ =
237 dxStrides_ =
239
242
244
// Saved statistics are used only when BOTH the mean and inverse-variance pointers are given;
// otherwise the kernels compute them.
245 haveSavedMeanInvVar_ = (p_savedMean_ != nullptr && p_savedInvVar_ != nullptr);
246
247 if(UseMultiblockInK)
248 {
// Split the reduced (K) extent across block groups. Increase the number of K-tiles
// each group iterates over until at most 128 groups are needed to cover reduce_length.
249 int iterations = 1;
250 while(true)
251 {
252 int testBlkGroupSize = (reduce_length + (K_BlockTileSize * iterations) - 1) /
253 (K_BlockTileSize * iterations);
254
255 // we want the blkGroupSize be not more than 128
256 if(testBlkGroupSize <= 128)
257 break;
258
259 iterations++;
260 };
261
// Ceiling division: enough groups of (K_BlockTileSize * iterations) elements to cover
// reduce_length.
262 blkGroupSize = (reduce_length + (K_BlockTileSize * iterations) - 1) /
263 (K_BlockTileSize * iterations);
264
265 numBlockTileIteration = iterations;
266 }
267 else
268 {
269 blkGroupSize = 1;
271 };
272
274
// 1-D descriptors for scale, dscale/dbias and saved mean/var. They share the invariant
// lengths but each uses its own strides.
282 MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnScaleStrides);
284 MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnDscaleDbiasStrides);
286 MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnMeanVarStrides);
287 }
288
289 AccDataType epsilon_;
290
292
// Full-rank x/dy/dx geometry.
293 std::array<index_t, Rank> xyLengths_;
294 std::array<index_t, Rank> xStrides_;
295 std::array<index_t, Rank> dyStrides_;
296 std::array<index_t, Rank> dxStrides_;
297
// Geometry of the per-invariant-index tensors (scale, dscale/dbias, mean/var).
298 std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths_;
299 std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides_;
300 std::array<index_t, Rank - NumBatchNormReduceDim> bnDscaleDbiasStrides_;
301 std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides_;
302
// Caller-owned device buffers; this struct does not own or free them.
303 const XDataType* p_x_;
304 const DyDataType* p_dy_;
305 const ScaleDataType* p_scale_;
306 const MeanVarDataType* p_savedMean_;
307 const MeanVarDataType* p_savedInvVar_;
308 const DyElementwiseOp dy_elementwise_op_;
309 DxDataType* p_dx_;
310 DscaleDbiasDataType* p_dscale_;
311 DscaleDbiasDataType* p_dbias_;
312
315
318 size_t gridSize;
319
326
330
333
336 };
337
338 size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
339 {
340 const Argument* pArg_ = dynamic_cast<const Argument*>(pArg);
341
342 size_t workspace_size = 0;
343
344 if(UseMultiblockInK && pArg_->blkGroupSize > 1)
345 {
346 // workspace for the partial reduced result for dscale
347 workspace_size +=
348 pArg_->invariant_length * pArg_->blkGroupSize * sizeof(DscaleDbiasDataType) + 64;
349
350 // workspace for the partial reduced result for dbias
351 workspace_size +=
352 pArg_->invariant_length * pArg_->blkGroupSize * sizeof(DscaleDbiasDataType) + 64;
353
354 if(!pArg_->haveSavedMeanInvVar_)
355 {
356 // workspace for welford intermediate mean
357 workspace_size +=
358 pArg_->invariant_length * pArg_->blkGroupSize * sizeof(MeanVarDataType) + 64;
359
360 // workspace for welford intermediate variance
361 workspace_size +=
362 pArg_->invariant_length * pArg_->blkGroupSize * sizeof(MeanVarDataType) + 64;
363
364 // workspace for welford intermediate count
365 workspace_size +=
366 pArg_->invariant_length * pArg_->blkGroupSize * sizeof(int32_t) + 64;
367
368 // workspace for welford result mean
369 workspace_size += pArg_->invariant_length * sizeof(MeanVarDataType) + 64;
370
371 // workspace for welford result inv_variance
372 workspace_size += pArg_->invariant_length * sizeof(MeanVarDataType) + 64;
373 };
374 }
375
376 return (workspace_size);
377 };
378
// (Tail of SetWorkSpacePointer.) Records the caller's workspace and splits it into consecutive
// sub-buffers, each sized up to a multiple of 64 bytes.
// Order: partial dscale, partial dbias, Welford mean, Welford variance, Welford count,
// final mean, final inverse variance.
// NOTE(review): space_sz is index_t (a 32-bit int) but the products are computed in 64-bit /
// size_t arithmetic, so a large invariant_length * blkGroupSize would be truncated here.
// GetWorkSpaceSize uses size_t; a size_t space_sz would keep the two consistent.
// NOTE(review): the dbias and Welford sub-buffer pointers are computed even when GetWorkSpaceSize
// reserved no room for them: on the single-block path, and for the Welford buffers when
// haveSavedMeanInvVar_ is set. The single-block branch of Invoker::Run does not use the
// workspace, and the multi-block branch appears to pass nullptr for the Welford buffers when
// saved statistics are supplied. Still, the pointer arithmetic itself points past the allocation.
380 void* p_workspace,
381 const StreamConfig& = StreamConfig{}) const override
382 {
383 Argument* pArg_ = dynamic_cast<Argument*>(pArg);
384
385 pArg_->p_workspace_ = p_workspace;
386
387 index_t space_sz;
388
389 // setup buffer for the partial reduced result for dscale
390 pArg_->workspace_reduce_dscale = pArg_->p_workspace_;
391
392 space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(DscaleDbiasDataType);
393 space_sz = math::integer_least_multiple(space_sz, 64);
394
395 // setup buffer for the partial reduced result for dbias
396 pArg_->workspace_reduce_dbias =
397 reinterpret_cast<char*>(pArg_->workspace_reduce_dscale) + space_sz;
398
399 if(UseMultiblockInK && pArg_->blkGroupSize > 1)
400 {
401 space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(DscaleDbiasDataType);
402 space_sz = math::integer_least_multiple(space_sz, 64);
403
404 // setup buffer for welford intermediate mean
405 pArg_->workspace_mean =
406 reinterpret_cast<char*>(pArg_->workspace_reduce_dbias) + space_sz;
407
408 space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(MeanVarDataType);
409 space_sz = math::integer_least_multiple(space_sz, 64);
410
411 // setup buffer for welford intermediate variance
412 pArg_->workspace_variance = reinterpret_cast<char*>(pArg_->workspace_mean) + space_sz;
413
414 space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(MeanVarDataType);
415 space_sz = math::integer_least_multiple(space_sz, 64);
416
417 // setup buffer for welford intermediate count
418 pArg_->workspace_count = reinterpret_cast<char*>(pArg_->workspace_variance) + space_sz;
419
420 space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(int32_t);
421 space_sz = math::integer_least_multiple(space_sz, 64);
422
423 // setup buffer for welford result mean
424 pArg_->workspace_savedMean = reinterpret_cast<char*>(pArg_->workspace_count) + space_sz;
425
426 space_sz = pArg_->invariant_length * sizeof(MeanVarDataType);
427 space_sz = math::integer_least_multiple(space_sz, 64);
428
429 // setup buffer for welford result inv_variance
430 pArg_->workspace_savedInvVar =
431 reinterpret_cast<char*>(pArg_->workspace_savedMean) + space_sz;
432 };
433 };
434
435 struct Invoker : public BaseInvoker
436 {
// Launches the batch-norm backward kernels for `arg` and returns the sum of the times reported by
// launch_and_time_kernel for each launch.
// Multi-block path (UseMultiblockInK && arg.blkGroupSize > 1), up to three launches, with partial
// results staged in the workspace carved out by SetWorkSpacePointer:
//   1. kernel_multiblock_welford_first_half: only when saved mean/inv-var were not supplied;
//      writes per-block-group Welford mean/variance/count.
//   2. kernel_welford_second_half_reduce_first_half: writes per-block-group partial
//      dscale/dbias and, when needed, the final mean/inv-var into the workspace.
//   3. kernel_reduce_second_half_batchnorm_backward_final: consumes the partial dscale/dbias and
//      writes dx, dscale and dbias.
// Otherwise a single kernel_batchnorm_backward_with_blockwise_welford launch writes dx, dscale
// and dbias directly without using the workspace.
437 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
438 {
439 float avg_time = 0;
440
// Descriptor objects for the workspace buffers; their types parameterize the gridwise
// kernels instantiated below.
441 const auto mean_var_count_grid_desc_m_g =
444
445 const auto dscale_dbias_grid_desc_m_g =
448
449 const auto mean_var_count_grid_desc_m_k =
452
453 const auto dscale_dbias_grid_desc_m_k =
456
457 using MeanVarCountGridDesc_M_G = decltype(mean_var_count_grid_desc_m_g);
458 using MeanVarCountGridDesc_M_K = decltype(mean_var_count_grid_desc_m_k);
459 using DscaleDbiasGridDesc_M_G = decltype(dscale_dbias_grid_desc_m_g);
460 using DscaleDbiasGridDesc_M_K = decltype(dscale_dbias_grid_desc_m_k);
461
462 using GridwiseWelfordSecondHalfReduceFirstHalf_ =
464 DyDataType,
465 AccDataType,
466 ScaleDataType,
467 DscaleDbiasDataType,
468 MeanVarDataType,
469 DyElementwiseOp,
472 MeanVarCountGridDesc_M_K,
473 DscaleDbiasGridDesc_M_G,
474 BlockSize,
475 MThreadClusterSize,
476 KThreadClusterSize,
477 MThreadSliceSize,
478 KThreadSliceSize,
479 XDyDxVectorDim,
480 XSrcVectorSize,
481 DySrcVectorSize,
482 MeanVarSrcVectorSize>;
483
484 using GridwiseReduceSecondHalfBatchNormBwdFinal_ =
486 DyDataType,
487 DxDataType,
488 AccDataType,
489 ScaleDataType,
490 DscaleDbiasDataType,
491 MeanVarDataType,
492 DyElementwiseOp,
494 DscaleDbiasGridDesc_M_K,
497 BlockSize,
498 MThreadClusterSize,
499 KThreadClusterSize,
500 MThreadSliceSize,
501 KThreadSliceSize,
502 XDyDxVectorDim,
503 XSrcVectorSize,
504 DySrcVectorSize,
505 DxDstVectorSize,
506 ScaleSrcVectorSize,
507 DscaleDbiasDstVectorSize,
508 MeanVarSrcVectorSize>;
509
// This condition matches the one GetWorkSpaceSize uses to decide whether a workspace is needed.
510 if(UseMultiblockInK && arg.blkGroupSize > 1)
511 {
512 using GetReduceCountPerThreadFunctor =
514
515 GetReduceCountPerThreadFunctor get_reduce_count_per_thread(
517
// Statistics not supplied by the caller: first compute per-block-group Welford
// mean/variance/count into the workspace.
518 if(!arg.haveSavedMeanInvVar_)
519 {
520 using GridwiseMultiblockWelfordFirstHalf_ =
522 AccDataType,
523 MeanVarDataType,
525 MeanVarCountGridDesc_M_G,
526 GetReduceCountPerThreadFunctor,
527 BlockSize,
528 MThreadClusterSize,
529 KThreadClusterSize,
530 MThreadSliceSize,
531 KThreadSliceSize,
532 XDyDxVectorDim,
533 XSrcVectorSize>;
534
535 const auto kern_multiblock_welford_first_half =
536 kernel_multiblock_welford_first_half<GridwiseMultiblockWelfordFirstHalf_,
537 XDataType,
538 MeanVarDataType,
540 MeanVarCountGridDesc_M_G,
541 GetReduceCountPerThreadFunctor>;
542
543 avg_time += launch_and_time_kernel(
544 stream_config,
545 kern_multiblock_welford_first_half,
546 dim3(arg.gridSize),
547 dim3(BlockSize),
548 0,
549 arg.x_grid_desc_m_k,
550 mean_var_count_grid_desc_m_g,
551 get_reduce_count_per_thread,
553 arg.p_x_,
554 static_cast<MeanVarDataType*>(arg.workspace_mean),
555 static_cast<MeanVarDataType*>(arg.workspace_variance),
556 static_cast<int32_t*>(arg.workspace_count));
557 };
558
559 const auto kern_welford_second_half_reduce_first_half =
561 GridwiseWelfordSecondHalfReduceFirstHalf_,
562 XDataType,
563 DyDataType,
564 AccDataType,
565 ScaleDataType,
566 DscaleDbiasDataType,
567 MeanVarDataType,
568 DyElementwiseOp,
571 MeanVarCountGridDesc_M_K,
572 DscaleDbiasGridDesc_M_G>;
573
574 const auto kern_reduce_second_half_batchnorm_backward_final =
576 GridwiseReduceSecondHalfBatchNormBwdFinal_,
577 XDataType,
578 DyDataType,
579 DxDataType,
580 ScaleDataType,
581 DscaleDbiasDataType,
582 MeanVarDataType,
583 DyElementwiseOp,
585 DscaleDbiasGridDesc_M_K,
588
// Ceiling division: K-tile iterations needed to reduce blkGroupSize partial results
// with KThreadClusterSize threads along K.
589 index_t numDscaleDbiasBlockTileIteration =
590 (arg.blkGroupSize + KThreadClusterSize - 1) / KThreadClusterSize;
591
// Second launch: per-block-group partial dscale/dbias. The Welford inputs/outputs
// are workspace buffers only when the statistics were not supplied.
592 avg_time += launch_and_time_kernel(
593 stream_config,
594 kern_welford_second_half_reduce_first_half,
595 dim3(arg.gridSize),
596 dim3(BlockSize),
597 0,
598 arg.x_grid_desc_m_k,
601 mean_var_count_grid_desc_m_k,
602 dscale_dbias_grid_desc_m_g,
603 arg.blkGroupSize,
605 numDscaleDbiasBlockTileIteration,
606 arg.epsilon_,
608 arg.haveSavedMeanInvVar_ ? arg.p_savedMean_ : nullptr,
609 arg.haveSavedMeanInvVar_ ? arg.p_savedInvVar_ : nullptr,
611 ? nullptr
612 : static_cast<const MeanVarDataType*>(arg.workspace_mean),
614 ? nullptr
615 : static_cast<const MeanVarDataType*>(arg.workspace_variance),
616 arg.haveSavedMeanInvVar_ ? nullptr
617 : static_cast<const int32_t*>(arg.workspace_count),
620 ? nullptr
621 : static_cast<MeanVarDataType*>(arg.workspace_savedMean),
623 ? nullptr
624 : static_cast<MeanVarDataType*>(arg.workspace_savedInvVar),
625 arg.p_x_,
626 arg.p_dy_,
627 static_cast<DscaleDbiasDataType*>(arg.workspace_reduce_dscale),
628 static_cast<DscaleDbiasDataType*>(arg.workspace_reduce_dbias));
629
// Third launch: consume the partial dscale/dbias and write dx, dscale and dbias. Mean
// and inv-var come from the caller or from the workspace.
630 avg_time += launch_and_time_kernel(
631 stream_config,
632 kern_reduce_second_half_batchnorm_backward_final,
633 dim3(arg.gridSize),
634 dim3(BlockSize),
635 0,
636 arg.x_grid_desc_m_k,
639 dscale_dbias_grid_desc_m_k,
643 arg.blkGroupSize,
644 arg.reduce_length,
646 numDscaleDbiasBlockTileIteration,
647 static_cast<const DscaleDbiasDataType*>(arg.workspace_reduce_dscale),
648 static_cast<const DscaleDbiasDataType*>(arg.workspace_reduce_dbias),
650 ? arg.p_savedMean_
651 : static_cast<const MeanVarDataType*>(arg.workspace_savedMean),
653 ? arg.p_savedInvVar_
654 : static_cast<const MeanVarDataType*>(arg.workspace_savedInvVar),
655 arg.p_x_,
656 arg.p_dy_,
657 arg.p_scale_,
659 arg.p_dx_,
660 arg.p_dscale_,
661 arg.p_dbias_);
662 }
663 else
664 {
// Single-block path: one kernel writes dx, dscale and dbias directly; no workspace used.
665 using GetReduceCountPerThreadFunctor =
666 GetReduceCountPerThreadForBlockwiseWelford<K_BlockTileSize, KThreadSliceSize>;
667
668 GetReduceCountPerThreadFunctor get_reduce_count_per_thread(
670
671 using GridwiseBatchNormBackwardWithBlockwiseWelford_ =
673 DyDataType,
674 DxDataType,
675 AccDataType,
676 ScaleDataType,
677 DscaleDbiasDataType,
678 MeanVarDataType,
679 DyElementwiseOp,
683 GetReduceCountPerThreadFunctor,
684 BlockSize,
685 MThreadClusterSize,
686 KThreadClusterSize,
687 MThreadSliceSize,
688 KThreadSliceSize,
689 XDyDxVectorDim,
690 XSrcVectorSize,
691 DySrcVectorSize,
692 DxDstVectorSize,
693 ScaleSrcVectorSize,
694 DscaleDbiasDstVectorSize,
695 MeanVarSrcVectorSize>;
696
697 const auto kern_batchnorm_bwd = kernel_batchnorm_backward_with_blockwise_welford<
698 GridwiseBatchNormBackwardWithBlockwiseWelford_,
699 XDataType,
700 DyDataType,
701 DxDataType,
702 AccDataType,
703 ScaleDataType,
704 DscaleDbiasDataType,
705 MeanVarDataType,
706 DyElementwiseOp,
710 GetReduceCountPerThreadFunctor>;
711
712 avg_time += launch_and_time_kernel(stream_config,
713 kern_batchnorm_bwd,
714 dim3(arg.gridSize),
715 dim3(BlockSize),
716 0,
717 arg.x_grid_desc_m_k,
723 get_reduce_count_per_thread,
724 arg.reduce_length,
726 arg.epsilon_,
727 arg.p_x_,
728 arg.p_dy_,
729 arg.p_scale_,
731 arg.p_savedMean_,
732 arg.p_savedInvVar_,
734 arg.p_dx_,
735 arg.p_dscale_,
736 arg.p_dbias_);
737 };
738
739 return (avg_time);
740 };
741
742 float Run(const BaseArgument* pArg,
743 const StreamConfig& stream_config = StreamConfig{}) override
744 {
745 return Run(*dynamic_cast<const Argument*>(pArg), stream_config);
746 };
747 };
748
// Returns true only if the argument's layout matches this instance's vectorization settings.
// - The vectorized x/dy/dx dimension must be unit-stride and its length divisible by each x/dy/dx
//   vector size. That dimension is the last invariant dimension when XDyDxVectorDim == 0, and the
//   last dimension otherwise.
// - Scale, dscale/dbias and (when supplied) saved mean/inv-var must allow their vector sizes:
//   unit stride unless the vector size is 1, and length divisible by the vector size.
// - The x/dy lengths must equal the scale/bias/mean/var lengths index-by-index (presumably over
//   the NumInvariantDim leading dimensions — the loop header is not visible here).
// NOTE(review): pArg_ is not null-checked after dynamic_cast; passing a foreign BaseArgument
// would dereference nullptr.
749 bool IsSupportedArgument(const BaseArgument* pArg) override
750 {
751 const Argument* pArg_ = dynamic_cast<const Argument*>(pArg);
752
// Vectorized x/dy/dx dimension: unit stride and length divisible by every vector width.
753 if constexpr(XDyDxVectorDim == 0)
754 {
755 if(pArg_->xStrides_[NumInvariantDim - 1] != 1 ||
756 pArg_->dyStrides_[NumInvariantDim - 1] != 1 ||
757 pArg_->dxStrides_[NumInvariantDim - 1] != 1)
758 return false;
759
760 if(pArg_->xyLengths_[NumInvariantDim - 1] % XSrcVectorSize != 0 ||
761 pArg_->xyLengths_[NumInvariantDim - 1] % DySrcVectorSize != 0 ||
762 pArg_->xyLengths_[NumInvariantDim - 1] % DxDstVectorSize != 0)
763 return false;
764 }
765 else
766 {
767 if(pArg_->xStrides_[Rank - 1] != 1 || pArg_->dyStrides_[Rank - 1] != 1 ||
768 pArg_->dxStrides_[Rank - 1] != 1)
769 return false;
770
771 if(pArg_->xyLengths_[Rank - 1] % XSrcVectorSize != 0 ||
772 pArg_->xyLengths_[Rank - 1] % DySrcVectorSize != 0 ||
773 pArg_->xyLengths_[Rank - 1] % DxDstVectorSize != 0)
774 return false;
775 };
776
// A non-unit innermost stride is only acceptable with scalar (vector size 1) access.
777 if(pArg_->bnScaleStrides_[NumInvariantDim - 1] != 1 && ScaleSrcVectorSize != 1)
778 return false;
779
780 if(pArg_->bnDscaleDbiasStrides_[NumInvariantDim - 1] != 1 && DscaleDbiasDstVectorSize != 1)
781 return false;
782
783 if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % ScaleSrcVectorSize != 0)
784 return false;
785
786 if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % DscaleDbiasDstVectorSize != 0)
787 return false;
788
// Mean/var layout only matters when the caller supplied saved statistics.
789 if(pArg_->haveSavedMeanInvVar_)
790 {
791 if(pArg_->bnMeanVarStrides_[NumInvariantDim - 1] != 1 && MeanVarSrcVectorSize != 1)
792 return false;
793
794 if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % MeanVarSrcVectorSize != 0)
795 return false;
796 };
797
798 bool is_valid = true;
799
801 if(pArg_->xyLengths_[I] != pArg_->bnScaleBiasMeanVarLengths_[I])
802 is_valid = false;
803 });
804
805 if(!is_valid)
806 return false;
807
808 return true;
809 };
810
811 std::unique_ptr<BaseArgument>
812 MakeArgumentPointer(const std::array<index_t, Rank> xyLengths,
813 const std::array<index_t, Rank> xStrides,
814 const std::array<index_t, Rank> dyStrides,
815 const std::array<index_t, Rank> dxStrides,
816 const std::array<int, NumBatchNormReduceDim> reduceDims,
817 const std::array<ck::index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
818 const std::array<ck::index_t, NumInvariantDim> bnScaleStrides,
819 const std::array<ck::index_t, NumInvariantDim> bnDscaleDbiasStrides,
820 const std::array<ck::index_t, NumInvariantDim> bnMeanVarStrides,
821 const void* p_x,
822 const void* p_dy,
823 const void* p_scale,
824 const void* p_savedMean,
825 const void* p_savedInvVar,
826 double epsilon,
827 const DyElementwiseOp dy_elementwise_op,
828 void* p_dx,
829 void* p_dscale,
830 void* p_dbias) override
831 {
832 return std::make_unique<Argument>(xyLengths,
833 xStrides,
834 dyStrides,
835 dxStrides,
836 reduceDims,
837 bnScaleBiasMeanVarLengths,
838 bnScaleStrides,
839 bnDscaleDbiasStrides,
840 bnMeanVarStrides,
841 static_cast<const XDataType*>(p_x),
842 static_cast<const DyDataType*>(p_dy),
843 static_cast<const ScaleDataType*>(p_scale),
844 static_cast<const MeanVarDataType*>(p_savedMean),
845 static_cast<const MeanVarDataType*>(p_savedInvVar),
846 dy_elementwise_op,
847 epsilon,
848 static_cast<DxDataType*>(p_dx),
849 static_cast<DscaleDbiasDataType*>(p_dscale),
850 static_cast<DscaleDbiasDataType*>(p_dbias));
851 };
852
853 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
854 {
855 return std::make_unique<Invoker>();
856 };
857
858 std::string GetTypeString() const override
859 {
860 auto str = std::stringstream();
861
862 // clang-format off
863 str << "DeviceBatchNormBwdImpl<" << BlockSize << ",";
864 str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
865 str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
866 str << "XDyDxVectorDim_" << XDyDxVectorDim << ",";
867 str << "VectorSize_X" << XSrcVectorSize << "_scale_" << ScaleSrcVectorSize << "_bias_" << DscaleDbiasDstVectorSize << "_mean_var_" << MeanVarSrcVectorSize << "_Dx_" << DxDstVectorSize << ">";
868 // clang-format on
869
870 return str.str();
871 }
872}; // namespace device
873
874} // namespace device
875} // namespace tensor_operation
876} // namespace ck
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
Definition convolution_backward_data_specialization.hpp:8
std::pair< long_index_t, long_index_t > get_2d_lengths(const std::vector< index_t > &inLengths)
Definition device_reduce_common.hpp:20
std::vector< index_t > shuffle_tensor_dimensions(const std::vector< index_t > &origLengthsStrides, const std::vector< int > &reduceDims)
Definition device_reduce_common.hpp:75
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
__global__ void kernel_multiblock_welford_first_half(const XGridDesc_M_K x_grid_desc_m_k, const MeanVarCountGridDesc_M_G mean_var_count_grid_desc_m_g, const GetReduceCountPerThreadFunctor get_reduce_count_per_thread, index_t num_k_block_tile_iteration, const XDataType *const __restrict__ p_x, MeanVarDataType *const p_welford_mean, MeanVarDataType *const p_welford_variance, int32_t *const p_welford_count)
Definition gridwise_multiblock_welford_first_half.hpp:21
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
__global__ void kernel_welford_second_half_reduce_first_half(const XYGridDesc_M_K x_grid_desc_m_k, const XYGridDesc_M_K dy_grid_desc_m_k, const MeanVarGridDesc_M mean_var_grid_desc_m, const MeanVarCountGridDesc_M_K mean_var_count_grid_desc_m_k, const DscaleDbiasGridDesc_M_G dscale_dbias_grid_desc_m_g, index_t blkgroup_size, index_t num_xy_k_block_tile_iteration, index_t num_mean_var_count_k_block_tile_iteration, AccDataType epsilon, bool haveSavedMeanInvVar, const MeanVarDataType *const __restrict__ p_savedMean, const MeanVarDataType *const __restrict__ p_savedInvVar, const MeanVarDataType *const __restrict__ p_in_welford_mean, const MeanVarDataType *const __restrict__ p_in_welford_variance, const int32_t *const __restrict__ p_in_welford_count, const DyElementwiseOp dy_elementwise_op, MeanVarDataType *const __restrict__ p_out_welford_mean, MeanVarDataType *const __restrict__ p_out_welford_inv_variance, const XDataType *const __restrict__ p_x, const DyDataType *const __restrict__ p_dy, DscaleDbiasDataType *const __restrict__ p_reduce_dscale, DscaleDbiasDataType *const __restrict__ p_reduce_dbias)
Definition gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp:27
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
__global__ void kernel_batchnorm_backward_with_blockwise_welford(const XYGridDesc_M_K x_grid_desc_m_k, const XYGridDesc_M_K dy_grid_desc_m_k, const XYGridDesc_M_K dx_grid_desc_m_k, const ScaleBiasGridDesc_M scale_grid_desc_m, const ScaleBiasGridDesc_M dscale_dbias_grid_desc_m, const MeanVarGridDesc_M mean_var_grid_desc_m, const GetReduceCountPerThreadFunctor get_reduce_count_per_thread, long_index_t reduce_size, index_t num_k_block_tile_iteration, AccDataType epsilon, const XDataType *const __restrict__ p_x, const DyDataType *const __restrict__ p_dy, const ScaleDataType *const __restrict__ p_scale, bool haveSavedMeanInvVar, const MeanVarDataType *const __restrict__ p_savedMean, const MeanVarDataType *const __restrict__ p_savedInvVar, const DyElementwiseOp dy_elementwise_op, DxDataType *const __restrict__ p_dx, DscaleDbiasDataType *const __restrict__ p_dscale, DscaleDbiasDataType *const __restrict__ p_dbias)
Definition gridwise_batchnorm_backward_blockwise_welford.hpp:31
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
__global__ void kernel_reduce_second_half_batchnorm_backward_final(const XYGridDesc_M_K x_grid_desc_m_k, const XYGridDesc_M_K dy_grid_desc_m_k, const XYGridDesc_M_K dx_grid_desc_m_k, const DscaleDbiasGridDesc_M_K dscale_dbias_grid_desc_m_k, const MeanVarGridDesc_M mean_var_grid_desc_m, const ScaleBiasGridDesc_M scale_grid_desc_m, const ScaleBiasGridDesc_M bias_grid_desc_m, index_t blkgroup_size, long_index_t reduce_size, index_t num_xy_k_block_tile_iteration, index_t num_dscale_dbias_k_block_tile_iteration, const DscaleDbiasDataType *const __restrict__ p_reduce_dscale, const DscaleDbiasDataType *const __restrict__ p_reduce_dbias, const MeanVarDataType *const __restrict__ p_mean, const MeanVarDataType *const __restrict__ p_inv_var, const XDataType *const __restrict__ p_x, const DyDataType *const __restrict__ p_dy, const ScaleDataType *const __restrict__ p_scale, const DyElementwiseOp dy_elementwise_op, DxDataType *const __restrict__ p_dx, DscaleDbiasDataType *const __restrict__ p_dscale, DscaleDbiasDataType *const __restrict__ p_dbias)
Definition gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp:26
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
int64_t long_index_t
Definition ck.hpp:300
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
signed int int32_t
Definition stdint.h:123
Definition ck/stream_config.hpp:10
Definition gridwise_batchnorm_backward_blockwise_welford.hpp:100
Definition gridwise_multiblock_welford_first_half.hpp:55
Definition gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp:99
Definition gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp:96
Definition utility/sequence.hpp:43
typename conditional< kHasContent, type0, type1 >::type type
Definition utility/sequence.hpp:271
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition device_batchnorm_backward.hpp:27
Definition device_batchnorm_backward_impl.hpp:197
std::array< index_t, Rank > dyStrides_
Definition device_batchnorm_backward_impl.hpp:295
XYGridDesc_M_K x_grid_desc_m_k
Definition device_batchnorm_backward_impl.hpp:320
AccDataType epsilon_
Definition device_batchnorm_backward_impl.hpp:289
DscaleDbiasDataType * p_dscale_
Definition device_batchnorm_backward_impl.hpp:310
std::array< index_t, Rank > xStrides_
Definition device_batchnorm_backward_impl.hpp:294
std::array< index_t, Rank > xyLengths_
Definition device_batchnorm_backward_impl.hpp:293
std::array< index_t, Rank - NumBatchNormReduceDim > bnScaleStrides_
Definition device_batchnorm_backward_impl.hpp:299
bool haveSavedMeanInvVar_
Definition device_batchnorm_backward_impl.hpp:291
const MeanVarDataType * p_savedMean_
Definition device_batchnorm_backward_impl.hpp:306
int blkGroupSize
Definition device_batchnorm_backward_impl.hpp:316
std::array< index_t, Rank > dxStrides_
Definition device_batchnorm_backward_impl.hpp:296
ScaleBiasGridDesc_M dscale_dbias_grid_desc_m
Definition device_batchnorm_backward_impl.hpp:324
std::array< index_t, Rank - NumBatchNormReduceDim > bnMeanVarStrides_
Definition device_batchnorm_backward_impl.hpp:301
void * workspace_reduce_dbias
Definition device_batchnorm_backward_impl.hpp:335
const ScaleDataType * p_scale_
Definition device_batchnorm_backward_impl.hpp:305
long_index_t reduce_length
Definition device_batchnorm_backward_impl.hpp:314
const DyDataType * p_dy_
Definition device_batchnorm_backward_impl.hpp:304
Argument(const std::array< index_t, Rank > xyLengths, const std::array< index_t, Rank > xStrides, const std::array< index_t, Rank > dyStrides, const std::array< index_t, Rank > dxStrides, const std::array< int, NumBatchNormReduceDim > reduceDims, const std::array< ck::index_t, NumInvariantDim > bnScaleBiasMeanVarLengths, const std::array< ck::index_t, NumInvariantDim > bnScaleStrides, const std::array< ck::index_t, NumInvariantDim > bnDscaleDbiasStrides, const std::array< ck::index_t, NumInvariantDim > bnMeanVarStrides, const XDataType *p_x, const DyDataType *p_dy, const ScaleDataType *p_scale, const MeanVarDataType *p_savedMean, const MeanVarDataType *p_savedInvVar, const DyElementwiseOp dy_elementwise_op, double epsilon, DxDataType *p_dx, DscaleDbiasDataType *p_dscale, DscaleDbiasDataType *p_dbias)
Definition device_batchnorm_backward_impl.hpp:198
ScaleBiasGridDesc_M scale_grid_desc_m
Definition device_batchnorm_backward_impl.hpp:323
size_t gridSize
Definition device_batchnorm_backward_impl.hpp:318
DxDataType * p_dx_
Definition device_batchnorm_backward_impl.hpp:309
void * workspace_variance
Definition device_batchnorm_backward_impl.hpp:328
MeanVarGridDesc_M mean_var_grid_desc_m
Definition device_batchnorm_backward_impl.hpp:325
const XDataType * p_x_
Definition device_batchnorm_backward_impl.hpp:303
void * workspace_savedMean
Definition device_batchnorm_backward_impl.hpp:331
int numBlockTileIteration
Definition device_batchnorm_backward_impl.hpp:317
void * workspace_mean
Definition device_batchnorm_backward_impl.hpp:327
void * workspace_savedInvVar
Definition device_batchnorm_backward_impl.hpp:332
long_index_t invariant_length
Definition device_batchnorm_backward_impl.hpp:313
DscaleDbiasDataType * p_dbias_
Definition device_batchnorm_backward_impl.hpp:311
std::array< index_t, Rank - NumBatchNormReduceDim > bnScaleBiasMeanVarLengths_
Definition device_batchnorm_backward_impl.hpp:298
std::array< index_t, Rank - NumBatchNormReduceDim > bnDscaleDbiasStrides_
Definition device_batchnorm_backward_impl.hpp:300
void * workspace_count
Definition device_batchnorm_backward_impl.hpp:329
XYGridDesc_M_K dy_grid_desc_m_k
Definition device_batchnorm_backward_impl.hpp:321
const MeanVarDataType * p_savedInvVar_
Definition device_batchnorm_backward_impl.hpp:307
const DyElementwiseOp dy_elementwise_op_
Definition device_batchnorm_backward_impl.hpp:308
void * workspace_reduce_dscale
Definition device_batchnorm_backward_impl.hpp:334
XYGridDesc_M_K dx_grid_desc_m_k
Definition device_batchnorm_backward_impl.hpp:322
Definition device_batchnorm_backward_impl.hpp:436
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_batchnorm_backward_impl.hpp:437
float Run(const BaseArgument *pArg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_batchnorm_backward_impl.hpp:742
Definition device_batchnorm_backward_impl.hpp:58
std::string GetTypeString() const override
Definition device_batchnorm_backward_impl.hpp:858
static constexpr index_t NumInvariantDim
Definition device_batchnorm_backward_impl.hpp:71
std::unique_ptr< BaseArgument > MakeArgumentPointer(const std::array< index_t, Rank > xyLengths, const std::array< index_t, Rank > xStrides, const std::array< index_t, Rank > dyStrides, const std::array< index_t, Rank > dxStrides, const std::array< int, NumBatchNormReduceDim > reduceDims, const std::array< ck::index_t, NumInvariantDim > bnScaleBiasMeanVarLengths, const std::array< ck::index_t, NumInvariantDim > bnScaleStrides, const std::array< ck::index_t, NumInvariantDim > bnDscaleDbiasStrides, const std::array< ck::index_t, NumInvariantDim > bnMeanVarStrides, const void *p_x, const void *p_dy, const void *p_scale, const void *p_savedMean, const void *p_savedInvVar, double epsilon, const DyElementwiseOp dy_elementwise_op, void *p_dx, void *p_dscale, void *p_dbias) override
Definition device_batchnorm_backward_impl.hpp:812
static constexpr index_t M_BlockTileSize
Definition device_batchnorm_backward_impl.hpp:73
bool IsSupportedArgument(const BaseArgument *pArg) override
Definition device_batchnorm_backward_impl.hpp:749
static auto MakeMultiblockFirstReduceOutputMG2dDescriptor(int invariantLength, int blkGroupSize)
Definition device_batchnorm_backward_impl.hpp:123
ScaleBiasGridDesc_M MeanVarGridDesc_M
Definition device_batchnorm_backward_impl.hpp:194
static constexpr index_t K_BlockTileSize
Definition device_batchnorm_backward_impl.hpp:74
static auto MakeXY2dDescriptor(const std::array< index_t, Rank > &xyLengths, const std::array< index_t, Rank > &xyStrides, int blkGroupSize, int numBlockTileIteration)
Definition device_batchnorm_backward_impl.hpp:76
decltype(MakeScaleBiasMeanVar1dDescriptor({1}, {1})) ScaleBiasGridDesc_M
Definition device_batchnorm_backward_impl.hpp:193
static auto MakeScaleBiasMeanVar1dDescriptor(const std::array< index_t, NumInvariantDim > &lengths, const std::array< index_t, NumInvariantDim > &strides)
Definition device_batchnorm_backward_impl.hpp:163
size_t GetWorkSpaceSize(const BaseArgument *pArg) const override
Definition device_batchnorm_backward_impl.hpp:338
static auto MakeMultiblockFinalReduceInputMK2dDescriptor(int invariantLength, int blkGroupSize)
Definition device_batchnorm_backward_impl.hpp:141
void SetWorkSpacePointer(BaseArgument *pArg, void *p_workspace, const StreamConfig &=StreamConfig{}) const override
Definition device_batchnorm_backward_impl.hpp:379
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_batchnorm_backward_impl.hpp:853
decltype(MakeXY2dDescriptor({1}, {1}, 1, 1)) XYGridDesc_M_K
Definition device_batchnorm_backward_impl.hpp:192