blockwise_gemm_dlops_v2r2.hpp Source File

blockwise_gemm_dlops_v2r2.hpp Source File

Composable Kernel: blockwise_gemm_dlops_v2r2.hpp Source File
blockwise_gemm_dlops_v2r2.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#ifndef CK_BLOCKWISE_GEMM_DLOPS_V2R2_HPP
5#define CK_BLOCKWISE_GEMM_DLOPS_V2R2_HPP
6
7#include "common_header.hpp"
8#include "tensor_adaptor.hpp"
10#include "threadwise_contraction_dlops.hpp"
11
12namespace ck {
13
14// C[M0, M1, N0, N1] += transpose(A[K, M0, M1]) * B[K, N0, N1]
15// A and B are visible to the whole block, C is distributed among each thread
16// Assume:
17//   1. A:
18//     1. AKMBlockDesc is known at compile-time
19//     2. ABlockBuffer is DynamicBuffer
20//   2. B:
21//     1. BKNBlockDesc is known at compile-time
22//     2. BBlockBuffer is DynamicBuffer
23//   3. C:
24//     1. CM0M1N0N1ThreadDesc is known at compile-time
25//     2. CThreadBuffer is StaticBuffer
26// Also assume:
27// M0 = N0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization)
28template <
29 index_t BlockSize,
30 typename FloatA,
31 typename FloatB,
32 typename FloatC,
33 typename AKMBlockDesc,
34 typename BKNBlockDesc,
35 index_t M1PerThreadM11,
36 index_t N1PerThreadN11,
37 index_t KPerThread,
38 index_t M1N1ThreadClusterM100,
39 index_t M1N1ThreadClusterN100,
40 index_t M1N1ThreadClusterM101,
41 index_t M1N1ThreadClusterN101,
42 index_t AThreadCopyScalarPerVector_M11,
43 index_t BThreadCopyScalarPerVector_N11,
44 typename enable_if<AKMBlockDesc::IsKnownAtCompileTime() && BKNBlockDesc::IsKnownAtCompileTime(),
45 bool>::type = false>
47{
51
52 static constexpr auto I0 = Number<0>{};
53 static constexpr auto I1 = Number<1>{};
54 static constexpr auto I2 = Number<2>{};
55 static constexpr auto I3 = Number<3>{};
56
57 static constexpr index_t K = AKMBlockDesc{}.GetLength(I0);
58 static constexpr index_t M = AKMBlockDesc{}.GetLength(I1);
59 static constexpr index_t N = BKNBlockDesc{}.GetLength(I1);
60
61 static constexpr index_t M100 = M1N1ThreadClusterM100;
62 static constexpr index_t N100 = M1N1ThreadClusterN100;
63
64 static constexpr index_t M101 = M1N1ThreadClusterM101;
65 static constexpr index_t N101 = M1N1ThreadClusterN101;
66
67 static constexpr index_t M11 = M1PerThreadM11;
68 static constexpr index_t N11 = N1PerThreadN11;
69
70 static constexpr index_t M1 = M1N1ThreadClusterM100 * M1N1ThreadClusterM101 * M1PerThreadM11;
71 static constexpr index_t N1 = M1N1ThreadClusterN100 * M1N1ThreadClusterN101 * N1PerThreadN11;
72
73 static constexpr index_t M0 = M / M1;
74 static constexpr index_t N0 = N / N1;
75
76 __host__ __device__ static constexpr auto
77 MakeAKM0M1BlockDescriptor(const AKMBlockDesc& /* a_k_m_block_desc */)
78 {
79 const auto a_k_m0_m1_block_desc = transform_tensor_descriptor(
80 AKMBlockDesc{},
85
86 return a_k_m0_m1_block_desc;
87 }
88
89 __host__ __device__ static constexpr auto
90 MakeBKN0N1BlockDescriptor(const BKNBlockDesc& /* b_k_n_block_desc */)
91 {
92 const auto b_k_n0_n1_block_desc = transform_tensor_descriptor(
93 BKNBlockDesc{},
98
99 return b_k_n0_n1_block_desc;
100 }
101
102 __host__ __device__ static constexpr auto MakeCM0M100M101M11N0N100N101N11ToMNBlockAdaptor()
103 {
104 // upper: [M0, M100, M101, M11, N0, N100, N101, N11]
105 // lower: [M, N]
106 constexpr auto c_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n_block_adaptor =
114
115 return c_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n_block_adaptor;
116 }
117
118 __host__ __device__ static constexpr auto
120 {
121 // upper: [M0, M100, M101, M11, N0, N100, N101, N11]
122 // lower: [M0, M1, N0, N1]
123 constexpr auto c_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1_block_adaptor =
133
134 return c_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1_block_adaptor;
135 }
136
137 __host__ __device__ static constexpr auto GetCM0M1N0N1ThreadTensorLengths()
138 {
140 }
141
142 static constexpr auto a_k_m0_m1_block_desc_ = MakeAKM0M1BlockDescriptor(AKMBlockDesc{});
143 static constexpr auto b_k_n0_n1_block_desc_ = MakeBKN0N1BlockDescriptor(BKNBlockDesc{});
144
145 public:
147 : c_thread_origin_data_idx_{CalculateCM0M1N0N1ThreadOriginOnBlock(
149 a_thread_copy_{
150 make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1])},
151 b_thread_copy_{
152 make_tuple(0, c_thread_origin_data_idx_[I2], c_thread_origin_data_idx_[I3])}
153 {
154 static_assert(AKMBlockDesc::IsKnownAtCompileTime() && BKNBlockDesc::IsKnownAtCompileTime(),
155 "wrong! Desc should be known at compile-time");
156
157 static_assert(BlockSize == M101 * M100 * N101 * N100,
158 "wrong! blocksize and cluster size not consistent");
159
160 static_assert(M % M1 == 0 && N % N1 == 0, "wrong!");
161
162 static_assert(AKMBlockDesc{}.GetLength(I0) == BKNBlockDesc{}.GetLength(I0),
163 "wrong! K dimension not consistent");
164
165 // TODO: remove this restriction
166 static_assert(M0 == 2 && N0 == 2, "wrong");
167 }
168
170 {
171 // lower: [M0, M1, N0, N1]
172 // upper: [M0, M100, M101, M11, N0, N100, N101, N11]
174
175 // lower: [M0, M100, M101, M11, N0, N100, N101, N11]
176 // upper: [Tid, M0, M11, N0, N11]
177 constexpr auto adaptor1 = make_single_stage_tensor_adaptor(
186
187 constexpr auto adaptor = chain_tensor_adaptors(adaptor0, adaptor1);
188
189 return adaptor.CalculateBottomIndex(make_multi_index(thread_id, 0, 0, 0, 0));
190 }
191
// Returns the M1PerThreadM11 template parameter (the same value exposed as M11).
// Declared constexpr and __host__ __device__, so it can be evaluated at compile
// time from host or device code.
// NOTE(review): the name suggests callers treat this as the required alignment
// of the A block along M - confirm against callers.
192 __host__ __device__ static constexpr index_t GetABlockAlignment() { return M1PerThreadM11; }
193
// Returns the N1PerThreadN11 template parameter (the same value exposed as N11).
// The return type is deduced (auto) from that index_t template parameter.
// Declared constexpr and __host__ __device__, so it can be evaluated at compile
// time from host or device code.
// NOTE(review): the name suggests callers treat this as the required alignment
// of the B block along N - confirm against callers.
194 __host__ __device__ static constexpr auto GetBBlockAlignment() { return N1PerThreadN11; }
195
196 template <typename CM0M1N0N1ThreadDesc,
197 typename ABlockBuffer,
198 typename BBlockBuffer,
199 typename CThreadBuffer>
200 __device__ void Run(const CM0M1N0N1ThreadDesc& /* c_m0_m1_n0_n1_thread_desc */,
201 const ABlockBuffer& a_block_buf,
202 const BBlockBuffer& b_block_buf,
203 CThreadBuffer& c_thread_buf) const
204 {
205 static_assert(CM0M1N0N1ThreadDesc::IsKnownAtCompileTime(),
206 "wrong! Desc should be known at compile-time");
207
208 // TODO: remove this restriction
209 static_assert(M0 == 2 && N0 == 2 && CM0M1N0N1ThreadDesc{}.GetLength(I0) == M0 &&
210 CM0M1N0N1ThreadDesc{}.GetLength(I2) == N0,
211 "wrong");
212
214 a_k_m0_m1_thread_desc_.GetElementSpaceSize());
216 b_k_n0_n1_thread_desc_.GetElementSpaceSize());
217
218 constexpr auto threadwise_gemm =
219 ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1<FloatA,
220 FloatB,
221 FloatC,
222 decltype(a_k_m0_m1_thread_desc_),
223 decltype(b_k_n0_n1_thread_desc_),
224 CM0M1N0N1ThreadDesc,
228
229 // read A_sub_0
230 a_thread_copy_.Run(a_k_m0_m1_block_desc_,
231 make_tuple(I0, I0, I0),
232 a_block_buf,
233 a_k_m0_m1_thread_desc_,
234 make_tuple(I0, I0, I0),
235 a_thread_buf);
236
237 // read B_sub_0
238 b_thread_copy_.Run(b_k_n0_n1_block_desc_,
239 make_tuple(I0, I0, I0),
240 b_block_buf,
241 b_k_n0_n1_thread_desc_,
242 make_tuple(I0, I0, I0),
243 b_thread_buf);
244
245 // read B_sub_1
246 b_thread_copy_.Run(b_k_n0_n1_block_desc_,
247 make_tuple(I0, I1, I0),
248 b_block_buf,
249 b_k_n0_n1_thread_desc_,
250 make_tuple(I0, I1, I0),
251 b_thread_buf);
252
253 // read A_sub_1
254 a_thread_copy_.Run(a_k_m0_m1_block_desc_,
255 make_tuple(I0, I1, I0),
256 a_block_buf,
257 a_k_m0_m1_thread_desc_,
258 make_tuple(I0, I1, I0),
259 a_thread_buf);
260
261 // C_sub_00 += transpose(A_sub_0) * B_sub_0
262 threadwise_gemm.Run(a_thread_buf,
263 make_tuple(I0, I0, I0),
264 b_thread_buf,
265 make_tuple(I0, I0, I0),
266 c_thread_buf,
267 make_tuple(I0, I0, I0, I0));
268
269 // C_sub_01 += transpose(A_sub_0) * B_sub_1
270 threadwise_gemm.Run(a_thread_buf,
271 make_tuple(I0, I0, I0),
272 b_thread_buf,
273 make_tuple(I0, I1, I0),
274 c_thread_buf,
275 make_tuple(I0, I0, I1, I0));
276
277 // loop over rest of k
279 // read A_sub_0
280 a_thread_copy_.Run(a_k_m0_m1_block_desc_,
281 make_tuple(k, I0, I0),
282 a_block_buf,
283 a_k_m0_m1_thread_desc_,
284 make_tuple(I0, I0, I0),
285 a_thread_buf);
286
287 // C_sub_10 += transpose(A_sub_1) * B_sub_0
288 threadwise_gemm.Run(a_thread_buf,
289 make_tuple(I0, I1, I0),
290 b_thread_buf,
291 make_tuple(I0, I0, I0),
292 c_thread_buf,
293 make_tuple(I1, I0, I0, I0));
294
295 // read B_sub_0
296 b_thread_copy_.Run(b_k_n0_n1_block_desc_,
297 make_tuple(k, I0, I0),
298 b_block_buf,
299 b_k_n0_n1_thread_desc_,
300 make_tuple(I0, I0, I0),
301 b_thread_buf);
302
303 // C_sub_11 += transpose(A_sub_1) * B_sub_1
304 threadwise_gemm.Run(a_thread_buf,
305 make_tuple(I0, I1, I0),
306 b_thread_buf,
307 make_tuple(I0, I1, I0),
308 c_thread_buf,
309 make_tuple(I1, I0, I1, I0));
310
311 // read B_sub_1
312 b_thread_copy_.Run(b_k_n0_n1_block_desc_,
313 make_tuple(k, I1, I0),
314 b_block_buf,
315 b_k_n0_n1_thread_desc_,
316 make_tuple(I0, I1, I0),
317 b_thread_buf);
318
319 // read A_sub_1
320 a_thread_copy_.Run(a_k_m0_m1_block_desc_,
321 make_tuple(k, I1, I0),
322 a_block_buf,
323 a_k_m0_m1_thread_desc_,
324 make_tuple(I0, I1, I0),
325 a_thread_buf);
326
327 // C_sub_00 += transpose(A_sub_0) * B_sub_0
328 threadwise_gemm.Run(a_thread_buf,
329 make_tuple(I0, I0, I0),
330 b_thread_buf,
331 make_tuple(I0, I0, I0),
332 c_thread_buf,
333 make_tuple(I0, I0, I0, I0));
334
335 // C_sub_01 += transpose(A_sub_0) * B_sub_1
336 threadwise_gemm.Run(a_thread_buf,
337 make_tuple(I0, I0, I0),
338 b_thread_buf,
339 make_tuple(I0, I1, I0),
340 c_thread_buf,
341 make_tuple(I0, I0, I1, I0));
342 });
343
344 // C_sub_10 += transpose(A_sub_1) * B_sub_0
345 threadwise_gemm.Run(a_thread_buf,
346 make_tuple(I0, I1, I0),
347 b_thread_buf,
348 make_tuple(I0, I0, I0),
349 c_thread_buf,
350 make_tuple(I1, I0, I0, I0));
351
352 // C_sub_11 += transpose(A_sub_1) * B_sub_1
353 threadwise_gemm.Run(a_thread_buf,
354 make_tuple(I0, I1, I0),
355 b_thread_buf,
356 make_tuple(I0, I1, I0),
357 c_thread_buf,
358 make_tuple(I1, I0, I1, I0));
359 }
360
361 private:
362 // A[K, M0, M1]
363 static constexpr auto a_k_m0_m1_thread_desc_ = make_naive_tensor_descriptor_packed(
365
366 // B[K, N0, N1]
367 static constexpr auto b_k_n0_n1_thread_desc_ = make_naive_tensor_descriptor_packed(
369
370 using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
371 FloatA,
372 decltype(a_k_m0_m1_block_desc_),
373 decltype(a_k_m0_m1_thread_desc_),
374 Sequence<KPerThread, 1, M1PerThreadM11>,
375 Sequence<0, 1, 2>,
376 2,
377 AThreadCopyScalarPerVector_M11,
378 1>;
379
380 using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatB,
381 FloatB,
382 decltype(b_k_n0_n1_block_desc_),
383 decltype(b_k_n0_n1_thread_desc_),
384 Sequence<KPerThread, 1, N1PerThreadN11>,
385 Sequence<0, 1, 2>,
386 2,
387 BThreadCopyScalarPerVector_N11,
388 1>;
389
390 CIndex c_thread_origin_data_idx_;
391
392 AThreadCopy a_thread_copy_;
393 BThreadCopy b_thread_copy_;
394};
395
396} // namespace ck
397#endif
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0 &adaptor0, const TensorAdaptor1 &adaptor1)
Definition tensor_description/tensor_adaptor.hpp:245
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
std::enable_if< B, T > enable_if
Definition enable_if.hpp:24
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
Array< index_t, N > MultiIndex
Definition array_multi_index.hpp:12
__host__ static __device__ constexpr auto MakeCM0M100M101M11N0N100N101N11ToM0M1N0N1BlockAdaptor()
Definition blockwise_gemm_dlops_v2r2.hpp:119
static constexpr auto b_k_n0_n1_block_desc_
Definition blockwise_gemm_dlops_v2r2.hpp:143
__host__ static __device__ constexpr auto MakeAKM0M1BlockDescriptor(const AKMBlockDesc &)
Definition blockwise_gemm_dlops_v2r2.hpp:77
static constexpr auto I2
Definition blockwise_gemm_dlops_v2r2.hpp:54
static __device__ CIndex CalculateCM0M1N0N1ThreadOriginOnBlock(index_t thread_id)
Definition blockwise_gemm_dlops_v2r2.hpp:169
__host__ static __device__ constexpr auto MakeCM0M100M101M11N0N100N101N11ToMNBlockAdaptor()
Definition blockwise_gemm_dlops_v2r2.hpp:102
static constexpr auto I0
Definition blockwise_gemm_dlops_v2r2.hpp:52
static constexpr index_t N101
Definition blockwise_gemm_dlops_v2r2.hpp:65
MultiIndex< 4 > CIndex
Definition blockwise_gemm_dlops_v2r2.hpp:50
static constexpr index_t M11
Definition blockwise_gemm_dlops_v2r2.hpp:67
static constexpr index_t N11
Definition blockwise_gemm_dlops_v2r2.hpp:68
static constexpr index_t N
Definition blockwise_gemm_dlops_v2r2.hpp:59
static constexpr index_t N1
Definition blockwise_gemm_dlops_v2r2.hpp:71
static constexpr index_t N100
Definition blockwise_gemm_dlops_v2r2.hpp:62
static constexpr index_t M0
Definition blockwise_gemm_dlops_v2r2.hpp:73
static constexpr auto I1
Definition blockwise_gemm_dlops_v2r2.hpp:53
__device__ void Run(const CM0M1N0N1ThreadDesc &, const ABlockBuffer &a_block_buf, const BBlockBuffer &b_block_buf, CThreadBuffer &c_thread_buf) const
Definition blockwise_gemm_dlops_v2r2.hpp:200
static constexpr index_t K
Definition blockwise_gemm_dlops_v2r2.hpp:57
static constexpr index_t M100
Definition blockwise_gemm_dlops_v2r2.hpp:61
static constexpr auto a_k_m0_m1_block_desc_
Definition blockwise_gemm_dlops_v2r2.hpp:142
MultiIndex< 3 > BIndex
Definition blockwise_gemm_dlops_v2r2.hpp:49
static constexpr index_t N0
Definition blockwise_gemm_dlops_v2r2.hpp:74
__host__ static __device__ constexpr auto GetCM0M1N0N1ThreadTensorLengths()
Definition blockwise_gemm_dlops_v2r2.hpp:137
__device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2()
Definition blockwise_gemm_dlops_v2r2.hpp:146
MultiIndex< 3 > AIndex
Definition blockwise_gemm_dlops_v2r2.hpp:48
__host__ static __device__ constexpr auto GetBBlockAlignment()
Definition blockwise_gemm_dlops_v2r2.hpp:194
__host__ static __device__ constexpr index_t GetABlockAlignment()
Definition blockwise_gemm_dlops_v2r2.hpp:192
__host__ static __device__ constexpr auto MakeBKN0N1BlockDescriptor(const BKNBlockDesc &)
Definition blockwise_gemm_dlops_v2r2.hpp:90
static constexpr auto I3
Definition blockwise_gemm_dlops_v2r2.hpp:55
static constexpr index_t M101
Definition blockwise_gemm_dlops_v2r2.hpp:64
static constexpr index_t M1
Definition blockwise_gemm_dlops_v2r2.hpp:70
static constexpr index_t M
Definition blockwise_gemm_dlops_v2r2.hpp:58
Definition utility/sequence.hpp:43
Definition functional2.hpp:33