fused_moegemm_pipeline_flatmm_ex.hpp Source File
fused_moegemm_pipeline_flatmm_ex.hpp
Go to the documentation of this file.
345 constexpr auto a_warp_y_lengths = to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
346 constexpr auto b_warp_y_lengths = to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
347 constexpr auto c_warp_y_lengths = to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
399 constexpr auto a_warp_y_lengths = to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
400 constexpr auto b_warp_y_lengths = to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
401 constexpr auto c_warp_y_lengths = to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
CK_TILE_HOST_DEVICE constexpr auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition tile_distribution_encoding.hpp:457
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_view(DataType *__restrict__ p, const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tensor_view.hpp:471
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType *__restrict__ p, const tensor_descriptor< Ts... > &desc)
Definition tensor_view.hpp:452
constant< b > bool_constant
Definition tile/core/numeric/integral_constant.hpp:43
CK_TILE_DEVICE void lds_load_fence(index_t cnt=0)
Definition tile/core/arch/amd_buffer_addressing.hpp:820
CK_TILE_HOST_DEVICE constexpr auto merge_sequences(Seqs...)
Definition tile/core/container/sequence.hpp:826
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_DEVICE auto cast_tile(const SrcTensor &src_tensor)
Definition tile_elementwise.hpp:327
CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F &&f, number< N >)
Definition tile/core/container/tuple.hpp:429
CK_TILE_HOST_DEVICE constexpr auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition tensor_view.hpp:530
CK_TILE_DEVICE constexpr auto make_tile_window_linear(const TensorView_ &tensor_view, const WindowLengths_ &window_lengths, const multi_index< TensorView_::get_num_of_dimension()> &origin, const StaticTileDistribution_ &tile_distribution, LinearBottomDims_={})
Definition tile_window_linear.hpp:993
CK_TILE_HOST_DEVICE constexpr auto to_sequence(tuple< number< Is >... >)
Definition tile/core/container/sequence.hpp:1055
CK_TILE_DEVICE auto load_tile_raw(T &tile, const tile_window_with_static_distribution< BottomTensorView_, WindowLengths_, TileDistribution_, NumCoord > &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={})
Loads a tile of data using inline assembly.
Definition load_tile.hpp:81
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition tile/core/container/sequence.hpp:1026
CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_ &&lds_tile, const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={})
Definition load_tile.hpp:133
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
CK_TILE_DEVICE void clear_tile(DstrTensors &dstr_tensor)
Definition tile_elementwise.hpp:177
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
CK_TILE_DEVICE void block_sync_load_raw(index_t cnt=0)
Definition arch.hpp:121
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
tuple_array< T, N > statically_indexed_array
Definition tile/core/container/statically_indexed_array.hpp:16
CK_TILE_DEVICE void update_tile_raw(tile_window_with_static_distribution< BottomTensorView_, WindowLengths_, TileDistribution_, NumCoord > &tile_window, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor, number< i_access >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={})
Definition update_tile.hpp:68
Definition fused_moegemm_pipeline_flatmm_ex.hpp:23
static constexpr index_t SLD_A
Definition fused_moegemm_pipeline_flatmm_ex.hpp:54
static constexpr index_t kAlignmentA
Definition fused_moegemm_pipeline_flatmm_ex.hpp:49
static constexpr bool PadHiddenSize
Definition fused_moegemm_pipeline_flatmm_ex.hpp:46
typename Problem::DScaleDataType DScaleDataType
Definition fused_moegemm_pipeline_flatmm_ex.hpp:36
typename Problem::BlockShape BlockShape
Definition fused_moegemm_pipeline_flatmm_ex.hpp:27
static constexpr index_t kAlignmentO
Definition fused_moegemm_pipeline_flatmm_ex.hpp:52
typename Problem::IndexDataType IndexDataType
Definition fused_moegemm_pipeline_flatmm_ex.hpp:39
static constexpr index_t kBlockPerCu
Definition fused_moegemm_pipeline_flatmm_ex.hpp:59
remove_cvref_t< Policy_ > Policy
Definition fused_moegemm_pipeline_flatmm_ex.hpp:25
static constexpr const char * name
Definition fused_moegemm_pipeline_flatmm_ex.hpp:69
static constexpr index_t GLD_B
Definition fused_moegemm_pipeline_flatmm_ex.hpp:56
static constexpr index_t GLD_A
Definition fused_moegemm_pipeline_flatmm_ex.hpp:55
remove_cvref_t< Problem_ > Problem
Definition fused_moegemm_pipeline_flatmm_ex.hpp:24
typename Problem::ADataType ADataType
Definition fused_moegemm_pipeline_flatmm_ex.hpp:29
static CK_TILE_HOST_DEVICE auto GetOCoord()
Definition fused_moegemm_pipeline_flatmm_ex.hpp:91
typename Problem::GDataType GDataType
Definition fused_moegemm_pipeline_flatmm_ex.hpp:30
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize()
Definition fused_moegemm_pipeline_flatmm_ex.hpp:77
static constexpr index_t kAlignmentG
Definition fused_moegemm_pipeline_flatmm_ex.hpp:50
typename Problem::DDataType DDataType
Definition fused_moegemm_pipeline_flatmm_ex.hpp:31
static CK_TILE_HOST_DEVICE auto GetACoord()
Definition fused_moegemm_pipeline_flatmm_ex.hpp:83
typename Problem::ODataType ODataType
Definition fused_moegemm_pipeline_flatmm_ex.hpp:33
typename Problem::GScaleDataType GScaleDataType
Definition fused_moegemm_pipeline_flatmm_ex.hpp:35
typename Problem::AccDataType AccDataType
Definition fused_moegemm_pipeline_flatmm_ex.hpp:32
static constexpr index_t GST_O
Definition fused_moegemm_pipeline_flatmm_ex.hpp:57
typename Problem::TopkWeightDataType TopkWeightDataType
Definition fused_moegemm_pipeline_flatmm_ex.hpp:38
static constexpr index_t kAlignmentD
Definition fused_moegemm_pipeline_flatmm_ex.hpp:51
typename Problem::YDataType YDataType
Definition fused_moegemm_pipeline_flatmm_ex.hpp:40
static constexpr bool IsGateOnly
Definition fused_moegemm_pipeline_flatmm_ex.hpp:44
typename Problem::Traits Traits
Definition fused_moegemm_pipeline_flatmm_ex.hpp:42
typename Problem::YSmoothScaleDataType YSmoothScaleDataType
Definition fused_moegemm_pipeline_flatmm_ex.hpp:37
static constexpr bool PadIntermediateSize
Definition fused_moegemm_pipeline_flatmm_ex.hpp:47
typename Problem::AScaleDataType AScaleDataType
Definition fused_moegemm_pipeline_flatmm_ex.hpp:34
static constexpr bool UseSmoothQuant
Definition fused_moegemm_pipeline_flatmm_ex.hpp:45
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize_A()
Definition fused_moegemm_pipeline_flatmm_ex.hpp:72
CK_TILE_DEVICE auto operator()(const AWindow &a_window_, const GWindow &g_window_, const DWindow &d_window_, OWindow &o_window_, TopkWeightDataType, CK_TILE_LDS_ADDR void *smem, index_t hidden_size, index_t intermediate_size)
Definition fused_moegemm_pipeline_flatmm_ex.hpp:99
static constexpr value_type value
Definition tile/core/numeric/integral_constant.hpp:16
Definition tile/core/container/sequence.hpp:49
Definition tile/core/utility/functional.hpp:43
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192