mx_flatmm_kernel.hpp Source File#
mx_flatmm_kernel.hpp
Go to the documentation of this file.
66 return concat('_', "mx_flatmm_gemm", gemm_prec_str<ADataType, BDataType>, FlatmmPipeline::GetName());
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor_packed(const tuple< Lengths... > &lengths, number< GuaranteedLastDimensionVectorLength >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:371
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_view(DataType *__restrict__ p, const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tensor_view.hpp:471
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition tile/core/arch/amd_buffer_addressing.hpp:35
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType *__restrict__ p, const tensor_descriptor< Ts... > &desc)
Definition tensor_view.hpp:452
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1615
CK_TILE_DEVICE index_t get_warp_id(bool_constant< ReturnSgpr >={})
Definition arch.hpp:104
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition tile/core/tensor/tensor_descriptor.hpp:203
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition concat.hpp:43
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F &&f, number< N >)
Definition tile/core/container/tuple.hpp:429
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_HOST_DEVICE constexpr auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition tensor_view.hpp:530
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition flatmm_kernel.hpp:229
Definition flatmm_kernel.hpp:249
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemPongSize()
Definition flatmm_kernel.hpp:356
static CK_TILE_HOST constexpr auto BlockSize()
Definition flatmm_kernel.hpp:330
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemPingSize()
Definition flatmm_kernel.hpp:352
Definition mx_flatmm_kernel.hpp:18
remove_cvref_t< typename FlatmmPipeline::CLayout > ELayout
Definition mx_flatmm_kernel.hpp:28
static CK_TILE_HOST const std::string GetName()
Definition mx_flatmm_kernel.hpp:63
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition mx_flatmm_kernel.hpp:25
FlatmmKernel< TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_ > Underlying
Definition mx_flatmm_kernel.hpp:19
static constexpr index_t NumDTensor
Definition mx_flatmm_kernel.hpp:50
remove_cvref_t< typename EpiloguePipeline::DsLayout > DsLayout
Definition mx_flatmm_kernel.hpp:29
remove_cvref_t< typename FlatmmPipeline::BLayout > BLayout
Definition mx_flatmm_kernel.hpp:27
static constexpr int NThreadPerXdl
Definition mx_flatmm_kernel.hpp:40
static CK_TILE_DEVICE auto MakeGemmTensorViews(const ADataType *a_ptr, const BDataType *b_flat_ptr, const std::array< const void *, NumDTensor > &ds_ptr, EDataType *e_ptr, const KernelArgs &kargs, const SplitKBatchOffset &splitk_batch_offset)
Definition mx_flatmm_kernel.hpp:118
remove_cvref_t< typename EpiloguePipeline::ODataType > EDataType
Definition mx_flatmm_kernel.hpp:37
static constexpr bool UsePersistentKernel
Definition mx_flatmm_kernel.hpp:32
static CK_TILE_DEVICE auto MakeGemmTileWindows(const PadView &views, const index_t i_m, const index_t i_n)
Definition mx_flatmm_kernel.hpp:321
remove_cvref_t< typename FlatmmPipeline::ALayout > ALayout
Definition mx_flatmm_kernel.hpp:26
static CK_TILE_HOST constexpr auto GridSize(const FlatmmKernelArgs< ScaleM, ScaleN, DsDataType::size()> &kargs)
Definition mx_flatmm_kernel.hpp:72
static constexpr int MThreadPerXdl
Definition mx_flatmm_kernel.hpp:39
remove_cvref_t< typename FlatmmPipeline::ADataType > ADataType
Definition mx_flatmm_kernel.hpp:34
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition mx_flatmm_kernel.hpp:21
typename Underlying::SplitKBatchOffset SplitKBatchOffset
Definition mx_flatmm_kernel.hpp:114
CK_TILE_DEVICE void operator()(FlatmmKernelArgs< ScaleM, ScaleN, DsDataType::size()> kargs, int partition_idx=blockIdx.x) const
Definition mx_flatmm_kernel.hpp:468
remove_cvref_t< typename EpiloguePipeline::DsDataType > DsDataType
Definition mx_flatmm_kernel.hpp:30
remove_cvref_t< typename FlatmmPipeline::BDataType > BDataType
Definition mx_flatmm_kernel.hpp:35
static CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView &views)
Definition mx_flatmm_kernel.hpp:253
static constexpr index_t KernelBlockSize
Definition mx_flatmm_kernel.hpp:31
remove_cvref_t< MXFlatmmPipeline_ > FlatmmPipeline
Definition mx_flatmm_kernel.hpp:22
static CK_TILE_DEVICE void RunFlatmm(const ADataType *a_ptr, const BDataType *b_flat_ptr, const std::array< const void *, NumDTensor > &ds_ptr, EDataType *e_ptr, void *smem_ptr_ping, void *smem_ptr_pong, const FlatmmKernelArgs< ScaleM, ScaleN, DsDataType::size()> &kargs, const SplitKBatchOffset &splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n)
Definition mx_flatmm_kernel.hpp:400
static constexpr int KThreadPerXdl
Definition mx_flatmm_kernel.hpp:41
remove_cvref_t< typename MXFlatmmPipeline_::BlockGemmShape > BlockGemmShape
Definition mx_flatmm_kernel.hpp:23
static constexpr int BPackedSize
Definition mx_flatmm_kernel.hpp:44
static constexpr int APackedSize
Definition mx_flatmm_kernel.hpp:43
Definition type_traits.hpp:115
static constexpr int PackedSize
Definition tile/core/numeric/numeric.hpp:82
Definition tile/core/container/sequence.hpp:49