18 template <
typename Problem>
23 constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
24 constexpr index_t NPerBlockBQ = NPerBlock / Problem::QuantGroupSize::kN;
25 constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
26 constexpr index_t KPerBlockBQ = KPerBlock / Problem::QuantGroupSize::kK;
28 static_assert(std::is_same_v<BQLayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
29 return GetABQGlobalVectorLoadSize<Problem, BQDataType, NPerBlockBQ, KPerBlockBQ>();
32 template <
typename Problem>
36 using BlockGemmShape =
typename Problem::BlockGemmShape;
38 constexpr index_t BlockSize = Problem::kBlockSize;
39 constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
40 constexpr index_t NPerBlockBQ = NPerBlock / Problem::QuantGroupSize::kN;
41 constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
42 constexpr index_t KPerBlockBQ = KPerBlock / Problem::QuantGroupSize::kK;
43 using WarpTile =
typename Problem::BlockGemmShape::WarpTile;
45 typename Problem::ComputeDataType,
46 typename Problem::CDataType,
52 static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
53 using TileEncodingPattern =
59 Problem::QuantGroupSize::kN>;
61 return TileEncodingPattern::make_2d_static_tile_distribution();
64 template <
typename Problem>
67 using BlockWarps =
typename Problem::BlockGemmShape::BlockWarps;
68 using WarpTile =
typename Problem::BlockGemmShape::WarpTile;
70 static_assert(Problem::QuantGroupSize::kK % WarpTile::at(
I2) == 0,
71 "KPerWarpGemm must be a multiple of QuantGroupSize!");
74 typename Problem::ComputeDataType,
75 typename Problem::CDataType,
80 static_assert(std::is_same_v<typename Problem::ComputeDataType, fp8_t> ||
81 std::is_same_v<typename Problem::ComputeDataType, bf8_t>);
82 static_assert(std::is_same_v<typename Problem::CDataType, float>);
84 typename Problem::BDataType,
85 typename Problem::CDataType,
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
typename impl::WarpGemmDispatcher< AType, BType, AccType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity, AttrNumAccess >::Type WarpGemmDispatcher
Definition warp_gemm_dispatcher.hpp:182
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
int32_t index_t
Definition integer.hpp:9
Definition block_universal_gemm_as_bs_bquant_cr.hpp:56
Definition block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp:18
Definition gemm_bquant_pipeline_ag_bg_cr_policy.hpp:12
UniversalGemmPipelineAgBgCrPolicy Base
Definition gemm_bquant_pipeline_ag_bg_cr_policy.hpp:13
static CK_TILE_HOST_DEVICE constexpr auto GetBlockGemm()
Definition gemm_bquant_pipeline_ag_bg_cr_policy.hpp:65
static CK_TILE_HOST_DEVICE constexpr auto MakeBQDramTileDistribution()
Definition gemm_bquant_pipeline_ag_bg_cr_policy.hpp:33
static CK_TILE_HOST_DEVICE constexpr auto GetVectorSizeBQ()
Definition gemm_bquant_pipeline_ag_bg_cr_policy.hpp:19
static constexpr auto I1
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:50
static constexpr auto I2
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:51
static constexpr auto I0
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:49
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:693
Definition gemm_group_quant_utils.hpp:176