warp_gemm_impl.hpp Source File

warp_gemm_impl.hpp Source File

Composable Kernel: warp_gemm_impl.hpp Source File
warp_gemm_impl.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
7namespace ck_tile {
8
/// NOTE(review): this is a Doxygen source-view listing. Every code line carries its
/// original line number as a prefix, and several original lines are absent (the
/// numbering skips 10, 12, 18-21, 32-34, 36-38, 46, 49-51, 76, 90, 108-110,
/// 132-133 and 160-161). Compare against the real header before relying on it.
///
/// Warp-level GEMM wrapper parameterised on a WarpGemmAttribute_ type. It re-exports
/// the attribute's compile-time constants, element types and distribution encodings.
/// Its operator() overloads read the A/B/C thread buffers as ext_vector_t values and
/// forward them to a default-constructed WarpGemmAttribute.
9template <typename WarpGemmAttribute_>
11{
13
// Compile-time constants forwarded unchanged from WarpGemmAttribute
// (kM/kN/kK are presumably the warp tile's M/N/K extents).
14 static constexpr index_t kM = WarpGemmAttribute::kM;
15 static constexpr index_t kN = WarpGemmAttribute::kN;
16 static constexpr index_t kK = WarpGemmAttribute::kK;
17 static constexpr index_t kCMLane = WarpGemmAttribute::kCMLane;
22 static constexpr index_t kKPerThread = WarpGemmAttribute::kKPerThread;
23
// Element types of the A, B and C operands, forwarded from WarpGemmAttribute.
24 using ADataType = typename WarpGemmAttribute::ADataType;
25 using BDataType = typename WarpGemmAttribute::BDataType;
26 using CDataType = typename WarpGemmAttribute::CDataType;
27
// Distribution encodings of the A, B and C warp tensors, forwarded from
// WarpGemmAttribute.
28 using AWarpDstrEncoding = typename WarpGemmAttribute::AWarpDstrEncoding;
29 using BWarpDstrEncoding = typename WarpGemmAttribute::BWarpDstrEncoding;
30 using CWarpDstrEncoding = typename WarpGemmAttribute::CWarpDstrEncoding;
31
35
39
/// Returns WarpGemmAttribute_::get_num_of_access().
/// NOTE(review): this calls through the raw template parameter WarpGemmAttribute_,
/// whereas every other member goes through the WarpGemmAttribute alias. If that alias
/// strips cv/ref qualifiers (remove_cvref_t<WarpGemmAttribute_>, as the definition
/// index suggests), confirm WarpGemmAttribute_ is never cv/ref-qualified here.
40 CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access()
41 {
42 return WarpGemmAttribute_::get_num_of_access();
43 }
44
/// Accumulating warp GEMM: c += a * b.
/// Each tensor's thread buffer is read as a single ext_vector_t whose length equals
/// that tensor's thread-buffer size. The attribute is then invoked on
/// (c_vec, a_vec, b_vec), and c_vec is written back into c's thread buffer.
/// post_nop_ is forwarded to the attribute unchanged; its meaning is defined there.
/// @param c  accumulator tensor; its thread buffer is read and then overwritten.
/// @param a  A operand tensor (read only).
/// @param b  B operand tensor (read only).
45 template <typename CTensor, typename ATensor, typename BTensor, bool post_nop_ = false>
47 operator()(CTensor& c, const ATensor& a, const BTensor& b, bool_constant<post_nop_> = {}) const
48 {
// Each vector type spans the whole thread buffer, so element I0 presumably covers
// the entire buffer.
52 using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
53 using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
54 using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
55
56 constexpr auto I0 = number<0>{};
57
58 const auto a_vec = a.get_thread_buffer().template get_as<AVec>()[I0];
59 const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
60 auto c_vec = c.get_thread_buffer().template get_as<CVec>()[I0];
61
62 // c_vec += a_vec * b_vec
63 WarpGemmAttribute{}(c_vec, a_vec, b_vec, bool_constant<post_nop_>{});
64
// Store the updated accumulator back into c.
65 c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
66 }
67
/// Accumulating warp GEMM variant selected by a compile-time i_subk index. The
/// number<i_subk> tag parameter (original line 76) is missing from this listing.
/// NOTE(review): the statement that invokes WarpGemmAttribute (original line 90) is
/// also missing. As shown, c_vec would be written back unchanged; confirm against the
/// real source that this overload forwards i_subk and post_nop_ to the attribute.
68 template <typename CTensor,
69 typename ATensor,
70 typename BTensor,
71 index_t i_subk,
72 bool post_nop_ = false>
73 CK_TILE_DEVICE void operator()(CTensor& c,
74 const ATensor& a,
75 const BTensor& b,
77 bool_constant<post_nop_> = {}) const
78 {
79 using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
80 using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
81 using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
82
83 constexpr auto I0 = number<0>{};
84
85 const auto a_vec = a.get_thread_buffer().template get_as<AVec>()[I0];
86 const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
87 auto c_vec = c.get_thread_buffer().template get_as<CVec>()[I0];
88
89 // c_vec += a_vec * b_vec
91
92 c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
93 }
94
/// Accumulating warp GEMM with per-operand int32_t scale arguments. It calls
/// WarpGemmAttribute::operator()<opselA, opselB>(c_vec, a_vec, a_scale, b_vec,
/// b_scale, post_nop) and writes c_vec back into c's thread buffer.
/// The meaning of opselA/opselB and of the scale values is defined by the attribute.
/// @param a_scale  scale argument for A, passed through unchanged.
/// @param b_scale  scale argument for B, passed through unchanged.
95 template <index_t opselA,
96 index_t opselB,
97 typename CTensor,
98 typename ATensor,
99 typename BTensor,
100 bool post_nop_ = false>
101 CK_TILE_DEVICE void operator()(CTensor& c,
102 const ATensor& a,
103 const BTensor& b,
104 const int32_t& a_scale,
105 const int32_t& b_scale,
106 bool_constant<post_nop_> = {}) const
107 {
111 using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
112 using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
113 using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
114
115 constexpr auto I0 = number<0>{};
116
117 const auto a_vec = a.get_thread_buffer().template get_as<AVec>()[I0];
118 const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
119 auto c_vec = c.get_thread_buffer().template get_as<CVec>()[I0];
120
121 // c_vec += a_vec * b_vec
// The explicit `.template operator()<...>` syntax is needed to pass the
// opselA/opselB template arguments to the attribute's call operator.
122 WarpGemmAttribute{}.template operator()<opselA, opselB>(
123 c_vec, a_vec, a_scale, b_vec, b_scale, bool_constant<post_nop_>{});
124
125 c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
126 }
127
/// Non-accumulating warp GEMM: returns a new CWarpTensor holding a * b.
/// The result tensor c is default-constructed. Its thread buffer is then set from the
/// vector returned by WarpGemmAttribute{}(a_vec, b_vec); CVec is sized to the whole
/// thread buffer.
128 template <typename ATensor, typename BTensor>
129 CK_TILE_DEVICE auto operator()(const ATensor& a, const BTensor& b) const
130 {
131 using CTensor = CWarpTensor;
134 CTensor c;
135
136 using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
137 using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
138 using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
139
140 constexpr auto I0 = number<0>{};
141
142 const auto a_vec = a.get_thread_buffer().template get_as<AVec>()[I0];
143 const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
144
145 // c_vec = a_vec * b_vec
146 auto c_vec = WarpGemmAttribute{}(a_vec, b_vec);
147
148 c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
149
150 return c;
151 }
152
/// Non-accumulating warp GEMM with int32_t scale arguments. It returns a new
/// CWarpTensor filled from
/// WarpGemmAttribute::operator()<opselA, opselB>(a_vec, a_scale, b_vec, b_scale).
/// The meaning of opselA/opselB and of the scale values is defined by the attribute.
153 template <index_t opselA, index_t opselB, typename ATensor, typename BTensor>
154 CK_TILE_DEVICE auto operator()(const ATensor& a,
155 const BTensor& b,
156 const int32_t& a_scale,
157 const int32_t& b_scale) const
158 {
159 using CTensor = CWarpTensor;
162 CTensor c;
163
164 using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
165 using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
166 using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
167
168 constexpr auto I0 = number<0>{};
169
170 const auto a_vec = a.get_thread_buffer().template get_as<AVec>()[I0];
171 const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
172
173 // c_vec = a_vec * b_vec
174 auto c_vec =
175 WarpGemmAttribute{}.template operator()<opselA, opselB>(a_vec, a_scale, b_vec, b_scale);
176
177 c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
178
179 return c;
180 }
181};
182
183} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
constexpr bool is_similiar_distributed_tensor_v
Definition static_distributed_tensor.hpp:230
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
constant< b > bool_constant
Definition tile/core/numeric/integral_constant.hpp:43
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
int32_t int32_t
Definition integer.hpp:10
typename impl::ext_vector< T, N >::type ext_vector_t
Definition vector_type.hpp:84
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition pointer.h:1517
Definition warp_gemm_impl.hpp:11
CK_TILE_DEVICE void operator()(CTensor &c, const ATensor &a, const BTensor &b, number< i_subk >, bool_constant< post_nop_ >={}) const
Definition warp_gemm_impl.hpp:73
typename WarpGemmAttribute::CWarpDstrEncoding CWarpDstrEncoding
Definition warp_gemm_impl.hpp:30
CK_TILE_DEVICE auto operator()(const ATensor &a, const BTensor &b) const
Definition warp_gemm_impl.hpp:129
typename WarpGemmAttribute::BWarpDstrEncoding BWarpDstrEncoding
Definition warp_gemm_impl.hpp:29
remove_cvref_t< decltype(make_static_tile_distribution(BWarpDstrEncoding{}))> BWarpDstr
Definition warp_gemm_impl.hpp:33
CK_TILE_DEVICE auto operator()(const ATensor &a, const BTensor &b, const int32_t &a_scale, const int32_t &b_scale) const
Definition warp_gemm_impl.hpp:154
typename WarpGemmAttribute::AWarpDstrEncoding AWarpDstrEncoding
Definition warp_gemm_impl.hpp:28
CK_TILE_DEVICE void operator()(CTensor &c, const ATensor &a, const BTensor &b, bool_constant< post_nop_ >={}) const
Definition warp_gemm_impl.hpp:47
remove_cvref_t< decltype(make_static_tile_distribution(CWarpDstrEncoding{}))> CWarpDstr
Definition warp_gemm_impl.hpp:34
static_distributed_tensor< ADataType, AWarpDstr > AWarpTensor
Definition warp_gemm_impl.hpp:36
static_distributed_tensor< BDataType, BWarpDstr > BWarpTensor
Definition warp_gemm_impl.hpp:37
remove_cvref_t< WarpGemmAttributeMfma< WarpGemmAttributeMfmaImplF32F32F32M16N16K4< WGAttrCtlEnum::Default_ > > > WarpGemmAttribute
Definition warp_gemm_impl.hpp:12
remove_cvref_t< decltype(make_static_tile_distribution(AWarpDstrEncoding{}))> AWarpDstr
Definition warp_gemm_impl.hpp:32
static_distributed_tensor< CDataType, CWarpDstr > CWarpTensor
Definition warp_gemm_impl.hpp:38
CK_TILE_DEVICE void operator()(CTensor &c, const ATensor &a, const BTensor &b, const int32_t &a_scale, const int32_t &b_scale, bool_constant< post_nop_ >={}) const
Definition warp_gemm_impl.hpp:101
static CK_TILE_HOST_DEVICE constexpr auto get_num_of_access()
Definition warp_gemm_impl.hpp:40
Definition static_distributed_tensor.hpp:21