gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp File Reference

gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp File Reference

Composable Kernel: gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp File Reference
gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp File Reference

Go to the source code of this file.

Classes

struct  ck::GridwiseWelfordSecondHalfBatchNormForwardFinal< XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, XYGridDesc_M_K, MeanVarCountGridDesc_M_K, ScaleBiasGridDesc_M, MeanVarGridDesc_M, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize >

Namespaces

namespace  ck

Functions

template<typename GridwiseWelfordSecondHalfBatchNormForwardFinal_, typename XDataType, typename YDataType, typename AccDataType, typename ScaleDataType, typename BiasDataType, typename MeanVarDataType, typename YElementwiseOp, typename XYGridDesc_M_K, typename MeanVarCountGridDesc_M_K, typename ScaleBiasGridDesc_M, typename MeanVarGridDesc_M>
__global__ void ck::kernel_welford_second_half_batchnorm_forward_final (const XYGridDesc_M_K x_grid_desc_m_k, const XYGridDesc_M_K y_grid_desc_m_k, const MeanVarCountGridDesc_M_K mean_var_count_grid_desc_m_k, const ScaleBiasGridDesc_M scale_grid_desc_m, const ScaleBiasGridDesc_M bias_grid_desc_m, const MeanVarGridDesc_M mean_var_grid_desc_m, index_t blkgroup_size, index_t num_xy_k_block_tile_iteration, AccDataType epsilon, const MeanVarDataType *const __restrict__ p_in_welford_mean, const MeanVarDataType *const __restrict__ p_in_welford_variance, const int32_t *const __restrict__ p_in_welford_count, const XDataType *const __restrict__ p_x, const ScaleDataType *const __restrict__ p_scale, const BiasDataType *const __restrict__ p_bias, const YElementwiseOp y_elementwise_op, YDataType *const __restrict__ p_y, bool updateMovingAverage, AccDataType averageFactor, MeanVarDataType *const __restrict__ resultRunningMean, MeanVarDataType *const __restrict__ resultRunningVariance, bool saveMeanInvVariance, MeanVarDataType *const __restrict__ resultSaveMean, MeanVarDataType *const __restrict__ resultSaveInvVariance)