warp_gemm_attribute_wmma_impl_base_traits.hpp Source File

warp_gemm_attribute_wmma_impl_base_traits.hpp Source File

Composable Kernel: warp_gemm_attribute_wmma_impl_base_traits.hpp Source File
warp_gemm_attribute_wmma_impl_base_traits.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5namespace ck_tile {
// Primary template: declared only, never defined. Each supported target
// architecture tag (Arch, e.g. gfx11_t or gfx12_t) gets its own partial
// specialization that supplies the WMMA tile traits. Naming WmmaTraitsBase
// with an Arch that has no specialization yields an incomplete type, so any
// access to its members fails to compile.
template <typename Arch, typename ADType, typename BDType, typename CDType>
struct WmmaTraitsBase;
9// GFX11 specialization
10template <typename ADType, typename BDType, typename CDType>
11struct WmmaTraitsBase<gfx11_t, ADType, BDType, CDType>
12{
13 using ADataType = ADType;
14 using BDataType = BDType;
15 using CDataType = CDType;
16
20
21 static constexpr index_t kM = 16;
22 static constexpr index_t kN = 16;
23 static constexpr index_t kK = 16;
24
25 static constexpr index_t kAMBlock = 1;
26 static constexpr index_t kBNBlock = 1;
27
28 static constexpr index_t kRepeat = 2;
29 static constexpr index_t kAMLane = 16;
30 static constexpr index_t kBNLane = 16;
31 static constexpr index_t kABK0PerLane = 1;
32 static constexpr index_t kABKLane = 1;
33 static constexpr index_t kABK1PerLane = 16;
34
35 static constexpr index_t kCMLane = 2;
36 static constexpr index_t kCNLane = 16;
37 static constexpr index_t kCM0PerLane = 8;
38 static constexpr index_t kCM1PerLane = 1;
39
44
49
54};
55
56// GFX12 specialization
57template <typename ADType, typename BDType, typename CDType>
58struct WmmaTraitsBase<gfx12_t, ADType, BDType, CDType>
59{
60 using ADataType = ADType;
61 using BDataType = BDType;
62 using CDataType = CDType;
63
67
68 static constexpr index_t kM = 16;
69 static constexpr index_t kN = 16;
70 static constexpr index_t kK = 16;
71
72 static constexpr index_t kAMBlock = 1;
73 static constexpr index_t kBNBlock = 1;
74
75 static constexpr index_t kRepeat = 1;
76 static constexpr index_t kAMLane = 16;
77 static constexpr index_t kBNLane = 16;
78 static constexpr index_t kABK0PerLane = 1;
79 static constexpr index_t kABKLane = 2;
80 static constexpr index_t kABK1PerLane = 8;
81
82 static constexpr index_t kCMLane = 2;
83 static constexpr index_t kCNLane = 16;
84 static constexpr index_t kCM0PerLane = 1;
85 static constexpr index_t kCM1PerLane = 8;
86
91
96
101};
102} // namespace ck_tile
Definition tile/core/algorithm/cluster_descriptor.hpp:13
typename impl::ext_vector< T, N >::type ext_vector_t
Definition vector_type.hpp:84
int32_t index_t
Definition integer.hpp:9
ext_vector_t< CDataType, 8 > CVecType
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:19
sequence< 1, 2 > kCPs2RHssMajor
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:45
static constexpr index_t kBNBlock
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:26
static constexpr index_t kBNLane
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:30
static constexpr index_t kN
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:22
static constexpr index_t kCNLane
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:36
static constexpr index_t kM
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:21
sequence< 2, 1 > kCTPs2RHssMajor
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:50
sequence< 0, 2 > kCYs2RHsMinor
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:48
CDType CDataType
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:15
sequence< 2, 2 > kCTYs2RHsMajor
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:52
sequence< 2, 2 > kABYs2RHsMajor
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:42
sequence< 0, 2 > kCTYs2RHsMinor
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:53
static constexpr index_t kK
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:23
static constexpr index_t kCMLane
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:35
static constexpr index_t kAMLane
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:29
static constexpr index_t kCM1PerLane
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:38
static constexpr index_t kABK0PerLane
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:31
ext_vector_t< ADataType, 16 > AVecType
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:17
sequence< 1, 0 > kCTPs2RHssMinor
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:51
static constexpr index_t kAMBlock
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:25
sequence< 0, 1, 0 > kABPs2RHssMinor
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:41
static constexpr index_t kABKLane
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:32
static constexpr index_t kRepeat
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:28
sequence< 1, 1 > kCYs2RHsMajor
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:47
ADType ADataType
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:13
static constexpr index_t kCM0PerLane
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:37
sequence< 0, 2 > kABYs2RHsMinor
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:43
sequence< 1, 0 > kCPs2RHssMinor
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:46
static constexpr index_t kABK1PerLane
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:33
sequence< 0, 2, 1 > kABPs2RHssMajor
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:40
ext_vector_t< BDataType, 16 > BVecType
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:18
BDType BDataType
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:14
static constexpr index_t kABK0PerLane
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:78
static constexpr index_t kCMLane
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:82
static constexpr index_t kCNLane
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:83
BDType BDataType
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:61
ext_vector_t< ADataType, 8 > AVecType
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:64
static constexpr index_t kN
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:69
static constexpr index_t kRepeat
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:75
static constexpr index_t kBNLane
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:77
static constexpr index_t kCM1PerLane
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:85
sequence< 2, 1 > kABPs2RHssMajor
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:87
sequence< 2, 1 > kCTPs2RHssMajor
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:97
sequence< 2, 2 > kABYs2RHsMajor
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:89
CDType CDataType
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:62
sequence< 1, 1 > kCYs2RHsMajor
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:94
static constexpr index_t kBNBlock
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:73
sequence< 1, 0 > kABPs2RHssMinor
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:88
sequence< 2, 2 > kCTYs2RHsMajor
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:99
static constexpr index_t kCM0PerLane
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:84
sequence< 0, 2 > kCTYs2RHsMinor
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:100
static constexpr index_t kABK1PerLane
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:80
sequence< 1, 0 > kCTPs2RHssMinor
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:98
sequence< 1, 0 > kCPs2RHssMinor
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:93
sequence< 0, 2 > kABYs2RHsMinor
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:90
static constexpr index_t kAMLane
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:76
ext_vector_t< BDataType, 8 > BVecType
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:65
static constexpr index_t kK
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:70
ext_vector_t< CDataType, 8 > CVecType
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:66
sequence< 0, 2 > kCYs2RHsMinor
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:95
static constexpr index_t kM
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:68
sequence< 1, 2 > kCPs2RHssMajor
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:92
static constexpr index_t kAMBlock
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:72
ADType ADataType
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:60
static constexpr index_t kABKLane
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:79
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:7
Definition arch.hpp:363
Definition arch.hpp:366
Definition tile/core/container/sequence.hpp:49