device_normalization_fwd_splitk_impl.hpp Source File
device_normalization_fwd_splitk_impl.hpp
Go to the documentation of this file.
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition convolution_backward_data_specialization.hpp:8
auto make_tuple_from_array(const std::vector< index_t > &lengths, Number< arraySize >)
Definition device_reduce_common.hpp:65
__host__ __device__ constexpr auto PadTensorDescriptor(const TensorDesc &desc, const TileLengths &tile_lengths, DoPads)
Definition matrix_padder.hpp:19
std::pair< long_index_t, long_index_t > get_2d_lengths(const std::vector< index_t > &inLengths)
Definition device_reduce_common.hpp:20
std::vector< index_t > shuffle_tensor_dimensions(const std::vector< index_t > &origLengthsStrides, const std::vector< int > &reduceDims)
Definition device_reduce_common.hpp:75
auto make_tuple_from_array_and_index_seq(const std::vector< index_t > &lengths, Sequence< Ns... >)
Definition device_reduce_common.hpp:59
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__global__ void kernel_normalizationSplitK2nd(const MeanVarGridDesc_M_KBlock mean_var_grid_desc_m_kblock, const CountGridDesc_M_KBlock count_grid_desc_m_kblock, const XYGammaBetaGridDesc_M_K x_grid_desc_m_k, const XYGammaBetaGridDesc_M_K gamma_grid_desc_m_k, const XYGammaBetaGridDesc_M_K beta_grid_desc_m_k, const XYGammaBetaGridDesc_M_K y_grid_desc_m_k, const SaveMeanInvStdGridDesc_M save_mean_grid_desc_m, const SaveMeanInvStdGridDesc_M save_inv_std_grid_desc_m, index_t num_k_mean_var_count_iteration, index_t num_k_block_tile_iteration, index_t k_grid_size, ComputeDataType epsilon, const WorkspaceMeanVarDataType *const p_mean_global, const WorkspaceMeanVarDataType *const p_variance_global, const int32_t *const p_welford_count_global, const XDataType *const __restrict__ p_x_global, const GammaDataType *const __restrict__ p_gamma_global, const BetaDataType *const __restrict__ p_beta_global, YDataType *const __restrict__ p_y_global, SaveMeanInvStdDataType *const __restrict__ p_save_mean_global, SaveMeanInvStdDataType *const __restrict__ p_save_inv_std_global, const YElementwiseOperation y_elementwise_op)
Definition device_normalization_fwd_splitk_impl.hpp:58
__global__ void kernel_normalizationSplitK1st(const XGridDesc_M_K x_grid_desc_m_k, const MeanVarGridDesc_M_KBlock mean_var_grid_desc_m_kblock, index_t num_k_block_tile_iteration, const XDataType *const __restrict__ p_x_global, WorkspaceMeanVarDataType *const __restrict__ p_welford_mean, WorkspaceMeanVarDataType *const __restrict__ p_welford_variance, int32_t *const __restrict__ p_welford_count)
Definition device_normalization_fwd_splitk_impl.hpp:27
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
Definition ck/stream_config.hpp:10
Definition gridwise_normalization_splitk_1st.hpp:28
Definition gridwise_normalization_splitk_2nd.hpp:42
Definition utility/sequence.hpp:43
typename conditional< kHasContent, type0, type1 >::type type
Definition utility/sequence.hpp:271
Definition device_base.hpp:197
BaseArgument()=default
BaseInvoker()=default
Definition device_normalization_fwd.hpp:23
Definition device_normalization_fwd_splitk_impl.hpp:338
std::vector< index_t > Lengths_
Definition device_normalization_fwd_splitk_impl.hpp:445
SaveMeanInvStdGridDesc_M save_mean_grid_desc_m_
Definition device_normalization_fwd_splitk_impl.hpp:464
Kernel2MeanVarGridDesc_M_KBlock kernel2_mean_var_grid_desc_m_kblock_
Definition device_normalization_fwd_splitk_impl.hpp:468
std::vector< index_t > gammaStrides_
Definition device_normalization_fwd_splitk_impl.hpp:447
std::vector< index_t > betaStrides_
Definition device_normalization_fwd_splitk_impl.hpp:448
std::vector< index_t > saveInvStdStrides_
Definition device_normalization_fwd_splitk_impl.hpp:451
Kernel2CountGridDesc_M_KBlock kernel2_count_grid_desc_m_kblock_
Definition device_normalization_fwd_splitk_impl.hpp:469
void * p_workspace_mean_
Definition device_normalization_fwd_splitk_impl.hpp:441
SrcGridDesc_M_K y_grid_desc_m_k_
Definition device_normalization_fwd_splitk_impl.hpp:463
int kGridSize_
Definition device_normalization_fwd_splitk_impl.hpp:455
SrcGridDesc_M_K gamma_grid_desc_m_k_
Definition device_normalization_fwd_splitk_impl.hpp:461
std::vector< index_t > saveMeanStrides_
Definition device_normalization_fwd_splitk_impl.hpp:450
index_t KRaw_
Definition device_normalization_fwd_splitk_impl.hpp:472
const XDataType * p_x_
Definition device_normalization_fwd_splitk_impl.hpp:435
SaveMeanInvStdGridDesc_M save_inv_std_grid_desc_m_
Definition device_normalization_fwd_splitk_impl.hpp:465
void * p_workspace_var_
Definition device_normalization_fwd_splitk_impl.hpp:442
YElementwiseOperation y_elementwise_op_
Definition device_normalization_fwd_splitk_impl.hpp:453
size_t gridSize_
Definition device_normalization_fwd_splitk_impl.hpp:458
std::vector< index_t > xStrides_
Definition device_normalization_fwd_splitk_impl.hpp:446
int numBlockTileIteration_
Definition device_normalization_fwd_splitk_impl.hpp:457
std::vector< index_t > yStrides_
Definition device_normalization_fwd_splitk_impl.hpp:449
ComputeDataType epsilon_
Definition device_normalization_fwd_splitk_impl.hpp:433
void * p_workspace_count_
Definition device_normalization_fwd_splitk_impl.hpp:443
const BetaDataType * p_beta_
Definition device_normalization_fwd_splitk_impl.hpp:437
SrcGridDesc_M_K beta_grid_desc_m_k_
Definition device_normalization_fwd_splitk_impl.hpp:462
SrcGridDesc_M_K x_grid_desc_m_k_
Definition device_normalization_fwd_splitk_impl.hpp:460
SaveMeanInvStdDataType * p_saveMean_
Definition device_normalization_fwd_splitk_impl.hpp:439
YDataType * p_y_
Definition device_normalization_fwd_splitk_impl.hpp:438
index_t invariant_lowest_length_
Definition device_normalization_fwd_splitk_impl.hpp:474
const GammaDataType * p_gamma_
Definition device_normalization_fwd_splitk_impl.hpp:436
index_t MRaw_
Definition device_normalization_fwd_splitk_impl.hpp:471
SaveMeanInvStdDataType * p_saveInvStd_
Definition device_normalization_fwd_splitk_impl.hpp:440
Kernel1MeanVarGridDesc_M_KBlock kernel1_mean_var_grid_desc_m_kblock_
Definition device_normalization_fwd_splitk_impl.hpp:467
int numMeanVarCountIteration_
Definition device_normalization_fwd_splitk_impl.hpp:456
Argument(const std::vector< index_t > lengths, const std::vector< index_t > xStrides, const std::vector< index_t > gammaStrides, const std::vector< index_t > betaStrides, const std::vector< index_t > yStrides, const std::vector< index_t > saveMeanStrides, const std::vector< index_t > saveInvStdStrides, const std::vector< index_t > reduceDims, YElementwiseOperation y_elementwise_op, double epsilon, const XDataType *p_x, const GammaDataType *p_gamma, const BetaDataType *p_beta, YDataType *p_y, SaveMeanInvStdDataType *p_saveMean, SaveMeanInvStdDataType *p_saveInvStd)
Definition device_normalization_fwd_splitk_impl.hpp:339
Definition device_normalization_fwd_splitk_impl.hpp:478
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_normalization_fwd_splitk_impl.hpp:479
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_normalization_fwd_splitk_impl.hpp:553
Definition device_normalization_fwd_splitk_impl.hpp:145
decltype(MakeWorkspaceMeanVarDescriptor_M_K< Sequence< true, true >, 1, 1 >(1, 1)) Kernel2MeanVarGridDesc_M_KBlock
Definition device_normalization_fwd_splitk_impl.hpp:288
static constexpr auto I0
Definition device_normalization_fwd_splitk_impl.hpp:165
void SetWorkSpacePointer(BaseArgument *pArg, void *p_workspace, const StreamConfig &=StreamConfig{}) const override
Definition device_normalization_fwd_splitk_impl.hpp:580
decltype(MakeWorkspaceMeanVarDescriptor_M_K< Sequence< true, false >, 1, 1 >(1, 1)) Kernel1MeanVarGridDesc_M_KBlock
Definition device_normalization_fwd_splitk_impl.hpp:285
decltype(MakeSaveMeanInvStdDescriptor_M({1}, {1})) SaveMeanInvStdGridDesc_M
Definition device_normalization_fwd_splitk_impl.hpp:294
static constexpr index_t K_BlockTileSize
Definition device_normalization_fwd_splitk_impl.hpp:170
static auto MakeWorkspaceCountDescriptor_M_K(index_t M, index_t K)
Definition device_normalization_fwd_splitk_impl.hpp:248
static auto MakeSaveMeanInvStdDescriptor_M(const std::vector< index_t > &lengths, const std::vector< index_t > &strides)
Definition device_normalization_fwd_splitk_impl.hpp:255
size_t GetWorkSpaceSize(const BaseArgument *pArg) const override
Definition device_normalization_fwd_splitk_impl.hpp:560
decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1)) SrcGridDesc_M_K
Definition device_normalization_fwd_splitk_impl.hpp:284
static constexpr index_t NumInvariantDim
Definition device_normalization_fwd_splitk_impl.hpp:168
GridwiseNormalizationSplitK1st< XDataType, ComputeDataType, WorkspaceMeanVarDataType, SrcGridDesc_M_K, Kernel1MeanVarGridDesc_M_KBlock, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYVectorDim, XSrcVectorSize > GridwiseWelford
Definition device_normalization_fwd_splitk_impl.hpp:296
SaveMeanInvStdDataType WorkspaceMeanVarDataType
Definition device_normalization_fwd_splitk_impl.hpp:146
static constexpr auto I1
Definition device_normalization_fwd_splitk_impl.hpp:166
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_normalization_fwd_splitk_impl.hpp:607
decltype(MakeWorkspaceCountDescriptor_M_K< Sequence< true, true >, 1, 1 >(1, 1)) Kernel2CountGridDesc_M_KBlock
Definition device_normalization_fwd_splitk_impl.hpp:291
static auto MakeSrc2dDescriptor(const std::vector< index_t > &inLengths, const std::vector< index_t > &inStrides, int kBlockSize, int numBlockTileIteration)
Definition device_normalization_fwd_splitk_impl.hpp:175
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_normalization_fwd_splitk_impl.hpp:727
static constexpr index_t M_BlockTileSize
Definition device_normalization_fwd_splitk_impl.hpp:169
GridwiseNormalizationSplitK2nd< WorkspaceMeanVarDataType, XDataType, GammaDataType, BetaDataType, YDataType, SaveMeanInvStdDataType, ComputeDataType, YElementwiseOperation, Kernel2MeanVarGridDesc_M_KBlock, Kernel2CountGridDesc_M_KBlock, SrcGridDesc_M_K, SaveMeanInvStdGridDesc_M, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, XYVectorDim, YDstVectorSize, SaveMeanInvStdDstVectorSize > GridwiseWelfordNormalization
Definition device_normalization_fwd_splitk_impl.hpp:309
static auto MakeWorkspaceMeanVarDescriptor_M_K(index_t M, index_t K)
Definition device_normalization_fwd_splitk_impl.hpp:240
tensor_operation::element_wise::PassThrough PassThrough
Definition device_normalization_fwd_splitk_impl.hpp:163
std::string GetTypeString() const override
Definition device_normalization_fwd_splitk_impl.hpp:732
std::unique_ptr< BaseArgument > MakeArgumentPointer(const std::vector< index_t > lengths, const std::vector< index_t > xStrides, const std::vector< index_t > gammaStrides, const std::vector< index_t > betaStrides, const std::vector< index_t > yStrides, const std::vector< index_t > saveMeanStrides, const std::vector< index_t > saveInvStdStrides, const std::vector< index_t > reduceDims, double epsilon, const void *p_x, const void *p_gamma, const void *p_beta, void *p_y, void *p_saveMean, void *p_saveInvStd, YElementwiseOperation y_elementwise_op) override
Definition device_normalization_fwd_splitk_impl.hpp:687
static constexpr bool reduceAllDim
Definition device_normalization_fwd_splitk_impl.hpp:172
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340