tile_gemm_quant_traits.hpp Source File

tile_gemm_quant_traits.hpp Source File#

Composable Kernel: tile_gemm_quant_traits.hpp Source File
tile_gemm_quant_traits.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
7#include <cstdint>
8
9namespace ck_tile {
10
/// Selects one of the quantization variants referenced by the GEMM quant traits.
///
/// NOTE(review): the enumerator list was missing from the extracted listing.
/// The names and their order were restored from the symbol index of the same
/// page (definition lines 13-16) and from the switch below. Explicit
/// enumerator values, if the original declared any, were not visible; implicit
/// numbering (0, 1, 2, 3) is used here -- confirm against upstream.
enum struct QuantType : std::uint16_t
{
    AQuantGrouped,
    BQuantGrouped,
    RowColQuant,
    TensorQuant
};

/// Returns the enumerator name of @p quant_type as a string.
///
/// @param quant_type Value to convert.
/// @return The enumerator's identifier ("AQuantGrouped", "BQuantGrouped",
///         "RowColQuant" or "TensorQuant"), or "Unknown" for any value that
///         does not match a named enumerator.
inline std::string quant_type_to_string(QuantType quant_type)
{
    switch(quant_type)
    {
    case QuantType::AQuantGrouped: return "AQuantGrouped";
    case QuantType::BQuantGrouped: return "BQuantGrouped";
    case QuantType::RowColQuant: return "RowColQuant";
    case QuantType::TensorQuant: return "TensorQuant";
    // Reached only for values cast from the underlying type that name no enumerator.
    default: return "Unknown";
    }
}
30
31template <bool kPadM_,
32 bool kPadN_,
33 bool kPadK_,
34 bool PreshuffleQuant_,
35 bool PreshuffleB_,
36 typename ALayout_,
37 typename BLayout_,
38 typename CLayout_,
39 QuantType QuantType_,
40 typename AQLayout_ = ALayout_,
41 typename BQLayout_ = BLayout_,
42 bool TransposeC_ = false,
43 bool DoubleSmemBuffer_ = false,
44 bool UsePersistentKernel_ = false>
46{
47 static constexpr bool kPadM = kPadM_;
48 static constexpr bool kPadN = kPadN_;
49 static constexpr bool kPadK = kPadK_;
50
51 static constexpr QuantType kQuantType = QuantType_;
52
53 static constexpr int _VectorSize = 16;
54 static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_;
55
56 using ALayout = ALayout_;
57 using BLayout = BLayout_;
58 using CLayout = CLayout_;
59 using AQLayout = AQLayout_;
60 using BQLayout = BQLayout_;
61
62 // TODO: It should be replaced to single value
63 using AsLayout = ALayout_;
64 using BsLayout = BLayout_;
65
66 static constexpr bool TransposeC = TransposeC_;
67 static constexpr bool UseStructuredSparsity = false;
68 static constexpr index_t NumWaveGroups = 1;
69 static constexpr bool UsePersistentKernel = UsePersistentKernel_;
70
71 static constexpr bool PreshuffleQuant = PreshuffleQuant_;
72 static constexpr bool PreshuffleB = PreshuffleB_;
73};
74
75} // namespace ck_tile
Definition tile/core/algorithm/cluster_descriptor.hpp:13
std::string quant_type_to_string(QuantType quant_type)
Definition tile_gemm_quant_traits.hpp:19
QuantType
Definition tile_gemm_quant_traits.hpp:12
@ BQuantGrouped
Definition tile_gemm_quant_traits.hpp:14
@ RowColQuant
Definition tile_gemm_quant_traits.hpp:15
@ TensorQuant
Definition tile_gemm_quant_traits.hpp:16
@ AQuantGrouped
Definition tile_gemm_quant_traits.hpp:13
int32_t index_t
Definition integer.hpp:9
Definition tile_gemm_quant_traits.hpp:46
AQLayout_ AQLayout
Definition tile_gemm_quant_traits.hpp:59
static constexpr QuantType kQuantType
Definition tile_gemm_quant_traits.hpp:51
static constexpr bool kPadN
Definition tile_gemm_quant_traits.hpp:48
static constexpr bool kPadK
Definition tile_gemm_quant_traits.hpp:49
ALayout_ ALayout
Definition tile_gemm_quant_traits.hpp:56
static constexpr bool DoubleSmemBuffer
Definition tile_gemm_quant_traits.hpp:54
CLayout_ CLayout
Definition tile_gemm_quant_traits.hpp:58
BLayout_ BLayout
Definition tile_gemm_quant_traits.hpp:57
static constexpr int _VectorSize
Definition tile_gemm_quant_traits.hpp:53
static constexpr bool TransposeC
Definition tile_gemm_quant_traits.hpp:66
static constexpr bool UsePersistentKernel
Definition tile_gemm_quant_traits.hpp:69
static constexpr bool PreshuffleB
Definition tile_gemm_quant_traits.hpp:72
static constexpr bool kPadM
Definition tile_gemm_quant_traits.hpp:47
BQLayout_ BQLayout
Definition tile_gemm_quant_traits.hpp:60
BLayout_ BsLayout
Definition tile_gemm_quant_traits.hpp:64
static constexpr index_t NumWaveGroups
Definition tile_gemm_quant_traits.hpp:68
static constexpr bool PreshuffleQuant
Definition tile_gemm_quant_traits.hpp:71
ALayout_ AsLayout
Definition tile_gemm_quant_traits.hpp:63
static constexpr bool UseStructuredSparsity
Definition tile_gemm_quant_traits.hpp:67