ck_tile::tile_distribution_encoding_pattern_bq< BlockGemmShape, WarpGemm, BlockSize, YPerTile, XPerTile, XPerQ > Struct Template Reference
#include <gemm_group_quant_utils.hpp>
Static Public Member Functions

| Return type | Function | Description |
|---|---|---|
| static CK_TILE_HOST_DEVICE constexpr auto | make_2d_static_tile_distribution () | Creates a 2D tile distribution for BQ (B-matrix quantization scales). |

Static Public Attributes

| Type | Attribute |
|---|---|
| static constexpr index_t | warp_size = get_warp_size() |
| static constexpr index_t | num_warps = BlockSize / get_warp_size() |
| static constexpr index_t | MWarps = BlockGemmShape::BlockWarps::at(number<0>{}) |
| static constexpr index_t | NWarps = BlockGemmShape::BlockWarps::at(number<1>{}) |
| static constexpr index_t | KWarps = BlockGemmShape::BlockWarps::at(number<2>{}) |
| static constexpr index_t | NIterPerWarp = BlockGemmShape::kN / (NWarps * WarpGemm::kN) |
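The derived attributes above are plain compile-time arithmetic. The sketch below restates them in standalone C++ so the numbers can be checked by hand. `ExampleBlockGemmShape`, `ExampleWarpGemm`, `DerivedAttributes`, the 64-lane warp size, and every parameter value are hypothetical stand-ins chosen for illustration; they are not ck_tile types, and a plain array replaces `BlockWarps::at(number<i>{})`.

```cpp
#include <cstdint>

using index_t = std::int32_t;

// ASSUMPTION: a 64-lane wavefront stands in for get_warp_size().
constexpr index_t kAssumedWarpSize = 64;

// Hypothetical stand-in for BlockGemmShape (not a ck_tile type).
struct ExampleBlockGemmShape {
    static constexpr index_t kN = 128;                  // block tile width along N
    static constexpr index_t BlockWarps[3] = {1, 4, 1}; // {M, N, K} warp counts
};

// Hypothetical stand-in for WarpGemm (not a ck_tile type).
struct ExampleWarpGemm {
    static constexpr index_t kN = 16;                   // warp tile width along N
};

// Mirrors the formulas listed in the Static Public Attributes table.
template <typename BlockGemmShape, typename WarpGemm, index_t BlockSize>
struct DerivedAttributes {
    static constexpr index_t warp_size    = kAssumedWarpSize;
    static constexpr index_t num_warps    = BlockSize / warp_size;
    static constexpr index_t MWarps       = BlockGemmShape::BlockWarps[0];
    static constexpr index_t NWarps       = BlockGemmShape::BlockWarps[1];
    static constexpr index_t KWarps       = BlockGemmShape::BlockWarps[2];
    static constexpr index_t NIterPerWarp = BlockGemmShape::kN / (NWarps * WarpGemm::kN);
};

using Attrs = DerivedAttributes<ExampleBlockGemmShape, ExampleWarpGemm, 256>;

static_assert(Attrs::num_warps == 4, "256 threads / 64 lanes per warp");
static_assert(Attrs::MWarps * Attrs::NWarps * Attrs::KWarps == Attrs::num_warps,
              "the M x N x K warp grid uses every warp in the block");
static_assert(Attrs::NIterPerWarp == 2, "128 columns / (4 warps * 16 columns)");
```

With these values each of the 4 warps along N covers 16 columns per iteration, so 2 iterations are needed to span the 128-column block tile.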
Member Function Documentation
◆ make_2d_static_tile_distribution()
template<typename BlockGemmShape, typename WarpGemm, index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t XPerQ>
static CK_TILE_HOST_DEVICE constexpr auto make_2d_static_tile_distribution () [inline]
Creates a 2D tile distribution for BQ (B-matrix quantization scales).
This function selects the thread distribution pattern used to load and apply the quantization scales of the B matrix, based on the quantization group size (XPerQ) relative to the warp tile width (WarpGemm::kN) and the number of warps along N (NWarps).
Three distinct distribution patterns are handled:
- Fine-grained quantization (XPerQ < WarpGemm::kN):
  - Multiple quantization groups fall within a single warp's N-dimension.
  - Each warp processes WarpGemm::kN / XPerQ scales.
  - The distribution includes an explicit replication factor (XR = XPerQ) to broadcast each scale.
  - Example: XPerQ=8, WarpGemm::kN=16, NWarps=4 → 2 scales per warp.
- Medium-grained quantization (WarpGemm::kN <= XPerQ <= WarpGemm::kN * NWarps):
  - Each warp handles exactly one quantization scale.
  - Scales are distributed across warps with replication factor XR = XPerQ / WarpGemm::kN.
  - Example: XPerQ=64, WarpGemm::kN=16, NWarps=4 → 1 scale per warp, XR=4.
- Coarse-grained quantization (XPerQ > WarpGemm::kN * NWarps):
  - A single quantization group is wider than the combined N-width of all NWarps warp tiles (WarpGemm::kN * NWarps).
  - All warps share the same scale value.
  - Example: XPerQ=128, WarpGemm::kN=16, NWarps=4 → all warps use the same scale.
Returns
A static tile distribution encoding for the BQ scale tensor.
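The case split above reduces to two comparisons on XPerQ. The sketch below restates it as constexpr C++ and checks the three worked examples with static_assert. The enum `BqPattern` and the helpers `classify_bq`, `scales_per_warp`, and `replication_factor` are hypothetical names for illustration, not part of ck_tile; the sketch models only the arithmetic stated above, not the tile distribution encoding that make_2d_static_tile_distribution() actually builds.

```cpp
#include <cstdint>

using index_t = std::int32_t;

// Hypothetical labels for the three regimes (not ck_tile names).
enum class BqPattern { Fine, Medium, Coarse };

// Picks the regime from XPerQ (x_per_q), WarpGemm::kN (warp_kn) and NWarps (n_warps).
constexpr BqPattern classify_bq(index_t x_per_q, index_t warp_kn, index_t n_warps) {
    return x_per_q < warp_kn            ? BqPattern::Fine
         : x_per_q <= warp_kn * n_warps ? BqPattern::Medium
                                        : BqPattern::Coarse;
}

// Distinct scales one warp tile of width warp_kn needs: warp_kn / x_per_q in the
// fine-grained case (assumes warp_kn is a multiple of x_per_q), otherwise 1.
constexpr index_t scales_per_warp(index_t x_per_q, index_t warp_kn) {
    return x_per_q < warp_kn ? warp_kn / x_per_q : 1;
}

// Replication factor XR as stated for the fine- and medium-grained cases only;
// the text gives no XR for the coarse-grained case, so do not rely on it there.
constexpr index_t replication_factor(index_t x_per_q, index_t warp_kn) {
    return x_per_q < warp_kn ? x_per_q : x_per_q / warp_kn;
}

// The three worked examples, all with WarpGemm::kN = 16 and NWarps = 4.
static_assert(classify_bq(8, 16, 4) == BqPattern::Fine, "8 < 16");
static_assert(scales_per_warp(8, 16) == 2, "16 / 8 scales per warp");
static_assert(replication_factor(8, 16) == 8, "XR = XPerQ");

static_assert(classify_bq(64, 16, 4) == BqPattern::Medium, "16 <= 64 <= 64");
static_assert(scales_per_warp(64, 16) == 1, "one scale per warp");
static_assert(replication_factor(64, 16) == 4, "XR = 64 / 16");

static_assert(classify_bq(128, 16, 4) == BqPattern::Coarse, "128 > 64");
static_assert(scales_per_warp(128, 16) == 1, "all warps share one scale");
```

Note that XPerQ=64 sits exactly on the medium/coarse boundary (WarpGemm::kN * NWarps = 64), so the inclusive `<=` in the medium-grained condition is what places it in that regime.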
Member Data Documentation
◆ KWarps
template<typename BlockGemmShape, typename WarpGemm, index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t XPerQ>
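static constexpr index_t KWarps = BlockGemmShape::BlockWarps::at(number<2>{})
Number of warps along the K dimension of the block (entry 2 of BlockGemmShape::BlockWarps).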
◆ MWarps
template<typename BlockGemmShape, typename WarpGemm, index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t XPerQ>
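static constexpr index_t MWarps = BlockGemmShape::BlockWarps::at(number<0>{})
Number of warps along the M dimension of the block (entry 0 of BlockGemmShape::BlockWarps).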
◆ NIterPerWarp
template<typename BlockGemmShape, typename WarpGemm, index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t XPerQ>
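static constexpr index_t NIterPerWarp = BlockGemmShape::kN / (NWarps * WarpGemm::kN)
Number of WarpGemm::kN-wide iterations each warp performs along N to cover the block tile width BlockGemmShape::kN.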
◆ num_warps
template<typename BlockGemmShape, typename WarpGemm, index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t XPerQ>
static constexpr index_t num_warps = BlockSize / get_warp_size()
Total number of warps in the thread block.
◆ NWarps
template<typename BlockGemmShape, typename WarpGemm, index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t XPerQ>
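static constexpr index_t NWarps = BlockGemmShape::BlockWarps::at(number<1>{})
Number of warps along the N dimension of the block (entry 1 of BlockGemmShape::BlockWarps).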
◆ warp_size
template<typename BlockGemmShape, typename WarpGemm, index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t XPerQ>
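static constexpr index_t warp_size = get_warp_size()
Number of threads per warp (wavefront), as returned by get_warp_size().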