blockwise_gemm_pipeline_xdlops_v1_mx.hpp Source File

blockwise_gemm_pipeline_xdlops_v1_mx.hpp Source File

Composable Kernel: blockwise_gemm_pipeline_xdlops_v1_mx.hpp Source File
blockwise_gemm_pipeline_xdlops_v1_mx.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8namespace ck {
9
10// Naive pipeline with lowest resource request per WGP
11// GlobalPrefetchStages: 1
12// LocalPreFillStages: 1
13// LocalPreFetchStages: 0
14// LocalSharedMemoryBuffer: 1
15
16template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
17 index_t ThreadBlockSize,
18 index_t ScaleBlockSize,
19 typename ADataType,
20 typename AScaleDataType,
21 typename BDataType,
22 typename BScaleDataType,
23 typename ATileDesc,
24 typename BTileDesc,
25 typename AMmaTileDesc,
26 typename BMmaTileDesc,
27 index_t ABlockTransferSrcScalarPerVector,
28 index_t BBlockTransferSrcScalarPerVector,
29 index_t MPerBlock,
30 index_t NPerBlock,
31 index_t KPerBlock,
32 index_t MPerXDL,
33 index_t NPerXDL,
34 index_t MRepeat, // MXdlPerWave
35 index_t NRepeat, // NXdlPerWave
36 index_t KPack>
40
41template <index_t ThreadBlockSize,
42 index_t ScaleBlockSize,
43 typename ADataType,
44 typename AScaleDataType,
45 typename BDataType,
46 typename BScaleDataType,
47 typename ATileDesc,
48 typename BTileDesc,
49 typename AMmaTileDesc,
50 typename BMmaTileDesc,
51 index_t ABlockTransferSrcScalarPerVector,
52 index_t BBlockTransferSrcScalarPerVector,
53 index_t MPerBlock,
54 index_t NPerBlock,
55 index_t KPerBlock,
56 index_t MPerXDL,
57 index_t NPerXDL,
58 index_t MRepeat, // MXdlPerWave
59 index_t NRepeat, // NXdlPerWave
60 index_t KPack>
62 ThreadBlockSize,
63 ScaleBlockSize,
64 ADataType,
65 AScaleDataType,
66 BDataType,
67 BScaleDataType,
68 ATileDesc,
69 BTileDesc,
70 AMmaTileDesc,
71 BMmaTileDesc,
72 ABlockTransferSrcScalarPerVector,
73 BBlockTransferSrcScalarPerVector,
74 MPerBlock,
75 NPerBlock,
76 KPerBlock,
77 MPerXDL,
78 NPerXDL,
79 MRepeat,
80 NRepeat,
81 KPack>
83 ADataType,
84 BDataType,
85 ATileDesc,
86 BTileDesc,
87 AMmaTileDesc,
88 BMmaTileDesc,
89 ABlockTransferSrcScalarPerVector,
90 BBlockTransferSrcScalarPerVector,
91 MPerBlock,
92 NPerBlock,
93 KPerBlock,
94 MPerXDL,
95 NPerXDL,
96 MRepeat,
97 NRepeat,
98 KPack>
99
100{
101
103 ADataType,
104 BDataType,
105 ATileDesc,
106 BTileDesc,
107 AMmaTileDesc,
108 BMmaTileDesc,
109 ABlockTransferSrcScalarPerVector,
110 BBlockTransferSrcScalarPerVector,
111 MPerBlock,
112 NPerBlock,
113 KPerBlock,
114 MPerXDL,
115 NPerXDL,
116 MRepeat,
117 NRepeat,
118 KPack>;
119 using Base::I0;
120 using Base::I1;
121 using Base::KRepeat;
122 using Base::MWaves;
123 using Base::NWaves;
124 using Base::WaveSize;
125 using Base::xdlops_gemm;
126
135 using Base::GetWaveIdx;
138
141
142 using Base::AMmaKStride;
143 using Base::APackedSize;
144 using Base::BMmaKStride;
145 using Base::BPackedSize;
146 using Base::KThreadChunk;
147
148 using Base::KXdlPack;
149 using Base::MXdlPack;
150 using Base::NXdlPack;
151
152 using AccType = typename Base::AccType;
153 using Tuple5 = typename Base::Tuple5;
156
157 static constexpr index_t PrefetchStages = 1;
158 static constexpr index_t PrefillStages = 1;
159 static constexpr index_t GlobalBufferNum = 1;
160
161 static constexpr auto ScalesPerKBlockSize =
162 KPerBlock / ScaleBlockSize; // How many mx-vectors per K block
163
164 //> How many mx-vectors in each row/col is processed in one call to xdlops_gemm.Run()
165 static constexpr auto AScalesPerXdlopsRun =
166 (APackedSize * KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize;
167 static constexpr auto BScalesPerXdlopsRun =
168 (BPackedSize * KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize;
169
170 //> How many scales a thread must read to accommodate one call to xdlops_gemm.Run()
171 static constexpr auto ScalesPerXdlopsRunPerThreadA =
172 AScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks;
173 static constexpr auto ScalesPerXdlopsRunPerThreadB =
174 BScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks;
175
177 static constexpr auto scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t);
178 static constexpr auto scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t);
179 static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0,
180 "A scale pack data type too large!");
181 static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0,
182 "B scale pack data type too large!");
185
// Returns true when num_loop exceeds PrefetchStages (= 1 for this pipeline).
// Run's main loop is a do/while that executes num_loop - 1 iterations and always
// runs at least once, so HasMainLoop must only be selected when this returns true.
186 __host__ static constexpr bool BlockHasHotloop(index_t num_loop)
187 {
188 return num_loop > PrefetchStages;
189 }
190
// Returns the tail variant Run should be instantiated with.
// This pipeline has a single tail variant, so the result is always TailNumber::Full
// and num_loop is intentionally unused.
191 __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
192 {
193 ignore = num_loop;
194 return TailNumber::Full;
195 }
196
// Executes the whole K loop for one block tile using a single shared block buffer per
// operand (a_block_buf / b_block_buf) and per-thread scale buffers.
//
// Sequence visible in the body:
//   1. Issue the first A/B tile copies (a_blockwise_copy / b_blockwise_copy .Run) and
//      advance their source windows by a_block_copy_step / b_block_copy_step.
//   2. Load the A and B scales of the current K block into a_scale_thread_buf /
//      b_scale_thread_buf, then move the scale source windows forward by
//      KRepeat / KXdlPack along their second index.
//   3. Clear c_thread_buf.
//   4. If HasMainLoop: a do/while that runs while i < num_loop - 1. Each iteration fills
//      a_thread_buf / b_thread_buf from a_block_buf / b_block_buf, issues the next A/B
//      tile copies, accumulates scaled xdlops_gemm.Run results into c_thread_buf, and
//      finally loads the next K block's scales.
//   5. If TailNum == TailNumber::Full: repeats the read + MFMA work for the last K block
//      without issuing any further tile or scale copies.
//
// num_loop: number of K blocks (steps of KPerBlock) to process. Because the do/while
// always executes at least once, HasMainLoop must only be true when num_loop > 1
// (see BlockHasHotloop); otherwise one extra K block would be processed.
// c_thread_buf: cleared here, then used as the MFMA accumulator.
197 template <bool HasMainLoop,
198 TailNumber TailNum,
199 typename AGridDesc,
200 typename ABlockDesc,
201 typename ABlockTransfer,
202 typename AGridBuffer,
203 typename ABlockBuffer,
204 typename ABlockTransferStep,
205 typename BGridDesc,
206 typename BBlockDesc,
207 typename BBlockTransfer,
208 typename BGridBuffer,
209 typename BBlockBuffer,
210 typename BBlockTransferStep,
211 typename CThreadBuffer,
212 typename AScaleGridBuffer,
213 typename AScaleGridDesc,
214 typename AScaleThreadTransfer,
215 typename BScaleGridBuffer,
216 typename BScaleGridDesc,
217 typename BScaleThreadTransfer>
218 __device__ void Run(
219 // ABlockCopy
220 const AGridDesc& a_grid_desc,
221 const ABlockDesc& a_block_desc,
222 ABlockTransfer& a_blockwise_copy,
223 const AGridBuffer& a_grid_buf,
224 ABlockBuffer& a_block_buf,
225 const ABlockTransferStep& a_block_copy_step,
226 // BBlockCopy
227 const BGridDesc& b_grid_desc,
228 const BBlockDesc& b_block_desc,
229 BBlockTransfer& b_blockwise_copy,
230 const BGridBuffer& b_grid_buf,
231 BBlockBuffer& b_block_buf,
232 const BBlockTransferStep& b_block_copy_step,
233 // CThread
234 CThreadBuffer& c_thread_buf,
235 // A and B scales
236 const AScaleGridDesc& a_scale_grid_desc,
237 AScaleThreadTransfer& a_scale_thread_copy,
238 const AScaleGridBuffer& a_scale_grid_buf,
239 const BScaleGridDesc& b_scale_grid_desc,
240 BScaleThreadTransfer& b_scale_thread_copy,
241 const BScaleGridBuffer& b_scale_grid_buf,
242 index_t num_loop) const
243 {
245 a_thread_desc_.GetElementSpaceSize());
247 b_thread_desc_.GetElementSpaceSize());
248
250 a_scale_thread_desc.GetElementSpaceSize());
251
253 b_scale_thread_desc.GetElementSpaceSize());
254
255 // Global prefetch 1
256 a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
257 b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf);
258
259 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
260 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
261
262 // Prefetch a_scales
// Each m0 step reads KRepeat / KXdlPack scale groups (window +1 per k0), then moves the
// window by (MWaves, -KRepeat / KXdlPack): net move after the loop is
// (MWaves * (MRepeat / MXdlPack), 0, 0).
263 static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
264 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
265 a_scale_thread_copy.Run(a_scale_grid_desc,
266 a_scale_grid_buf,
268 make_tuple(m0, k0, I0),
269 a_scale_thread_buf);
270
271 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
272 make_multi_index(0, I1, 0));
273 });
274 a_scale_thread_copy.MoveSrcSliceWindow(
275 a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0));
276 });
277
278 // restore row id and advance to the next set of scales
// NOTE(review): -MWaves * MRepeat / MXdlPack evaluates as (-MWaves * MRepeat) / MXdlPack,
// which cancels the loop's net move MWaves * (MRepeat / MXdlPack) only when
// MRepeat % MXdlPack == 0 — presumably enforced elsewhere; confirm. The same applies to
// the N restore below and to both restores inside the main loop.
279 a_scale_thread_copy.MoveSrcSliceWindow(
280 a_scale_grid_desc,
281 make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0));
282
283 // Prefetch b_scales
284 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
285 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
286 b_scale_thread_copy.Run(b_scale_grid_desc,
287 b_scale_grid_buf,
289 make_tuple(n0, k0, I0),
290 b_scale_thread_buf);
291
292 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
293 make_multi_index(0, I1, 0));
294 });
295 b_scale_thread_copy.MoveSrcSliceWindow(
296 b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
297 });
298
299 // restore col id and advance to the next set of scales
300 // NWaves * NPerXDL * NRepeat == NPerBlock
301 b_scale_thread_copy.MoveSrcSliceWindow(
302 b_scale_grid_desc,
303 make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0));
304
305 // Local prefill 1
// NOTE(review): with the GFX9-family s_waitcnt immediate layout
// (vmcnt[3:0], expcnt[6:4], lgkmcnt[11:8], vmcnt_hi[15:14]), 3952 = 0x0F70 decodes to
// vmcnt(0) expcnt(7) lgkmcnt(15): it waits for all outstanding vector-memory operations
// and does not wait on the EXP or LGKM counters. The trailing comment on this line
// ("wait for EXP_CNT, LDS, GDS, Constant and Message") appears to describe the opposite —
// confirm against the target ISA before relying on either description.
306 __builtin_amdgcn_s_waitcnt(3952); // wait for EXP_CNT, LDS, GDS, Constant and Message
308
309 // Initialize C
310 c_thread_buf.Clear();
311
312 // main body
// Runs num_loop - 1 iterations (see the while condition); the final K block is handled
// by the tail section below.
313 if constexpr(HasMainLoop)
314 {
315 // loop over k with the step KPerBlock
316 index_t i = 0;
317 do
318 {
319 // -------------------------------------------------------------------------------------------
320
321 // wait previous blockwise copy to finish
// NOTE(review): no statement follows the comment above; the waits visible in this
// function are the s_waitcnt calls after the prefill and at the end of each iteration.
// Consider moving or removing this comment — confirm intent.
322
323 // k indexes mapping to threads for 32x32x64:
324 // t0 : |0 --> 15 32 --> 47 | 64 --> 79 96 --> 111 | etc.
325 // t32: |16 --> 31 48 --> 63 | 80 --> 95 112 --> 127 | etc.
326 // k = 0 k = 1
327
328 // k indexes mapping to threads for 16x16x128:
329 // t0 : |0 --> 15 64 --> 79 | 128 --> 143 192 --> 207| etc.
330 // t16: |16 --> 31 80 --> 95 | 144 --> 159 208 --> 223| etc.
331 // t32: |32 --> 47 96 --> 111| 160 --> 175 224 --> 239| etc.
332 // t48: |48 --> 63 112 --> 127| 176 --> 191 240 --> 255| etc.
333 // k = 0 k = 1
334 static_for<0, KRepeat, 1>{}([&](auto k) {
335 constexpr auto k_step =
336 k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops;
337
338 static_for<0, MRepeat, 1>{}([&](auto m0) {
339 static_for<0, xdlops_gemm.K1PerXdlops / APackedSize / KThreadChunk, 1>{}(
340 [&](auto chunk) {
341 constexpr auto a_k_step_chunk =
342 k_step +
343 chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
346 I0,
348 I0,
350 a_block_buf,
353 I0,
355 k,
357 a_thread_buf);
358 });
359 });
360 static_for<0, NRepeat, 1>{}([&](auto n0) {
361 // read block data in chunks to assemble correct thread vectors
362 static_for<0, xdlops_gemm.K1PerXdlops / BPackedSize / KThreadChunk, 1>{}(
363 [&](auto chunk) {
364 constexpr auto b_k_step_chunk =
365 k_step +
366 chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
369 I0,
371 I0,
373 b_block_buf,
376 I0,
378 k,
380 b_thread_buf);
381 });
382 });
383 });
384
385 // load for next k loop
387 a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
388 b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf);
389 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
390 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
391
// Scaled MFMAs for the current K block: consume a_thread_buf / b_thread_buf (filled
// above) and the scales currently held in a_scale_thread_buf / b_scale_thread_buf.
392 static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
393 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
394 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
395 constexpr index_t a_scale_offset =
396 a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
397 constexpr index_t b_scale_offset =
398 b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
399
400 static_assert(0 < ScalesPerXdlopsRunPerThreadA &&
402 "Must have at least one scale per Xdlops per Thread.");
403
406
407 // Pack scale_thread_buf into scale_thread_vec
409 a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
410 a_scale_thread_buf[Number<a_scale_offset + s>{}];
411 });
413 b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
414 b_scale_thread_buf[Number<b_scale_offset + s>{}];
415 });
416
417 static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
418 static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
419 static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
420 constexpr auto kxdl = ikxdl + k0 * KXdlPack;
421
424
425 static_for<0, KPack, 1>{}([&](auto ik) {
426 a_thread_vec.template AsType<ComputeTypeA>()(ik) =
427 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
428 make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
429 b_thread_vec.template AsType<ComputeTypeB>()(ik) =
430 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
431 make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
432 });
433
434 using mfma_input_type_a = typename vector_type< //
436 xdlops_gemm.K1PerXdlops / APackedSize>::type;
437 using mfma_input_type_b = typename vector_type< //
439 xdlops_gemm.K1PerXdlops / BPackedSize>::type;
440
441 using mfma_scale_input_type_a = typename vector_type< //
442 AScaleDataType,
444 using mfma_scale_input_type_b = typename vector_type< //
445 BScaleDataType,
447
448 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
449 make_tuple(m0, n0, imxdl, inxdl, 0));
450
451 // MFMA accumulation
452 xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
453 ikxdl * NXdlPack + inxdl>(
454 a_thread_vec.template AsType<mfma_input_type_a>(),
455 a_scale_thread_vec
456 .template AsType<mfma_scale_input_type_a>(),
457 b_thread_vec.template AsType<mfma_input_type_b>(),
458 b_scale_thread_vec
459 .template AsType<mfma_scale_input_type_b>(),
460 c_thread_buf.GetVectorTypeReference(
462 });
463 });
464 });
465 });
466 });
467 });
468
469 // Prefetch a_scales
// Scales for the next K block. a_scale_thread_buf / b_scale_thread_buf are overwritten
// only here, after the MFMA section above has read them for the current K block.
470 static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
471 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
472 a_scale_thread_copy.Run(a_scale_grid_desc,
473 a_scale_grid_buf,
475 make_tuple(m0, k0, I0),
476 a_scale_thread_buf);
477
478 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
479 make_multi_index(0, I1, 0));
480 });
481 a_scale_thread_copy.MoveSrcSliceWindow(
482 a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0));
483 });
484
485 // restore row id and advance to the next set of scales
486 a_scale_thread_copy.MoveSrcSliceWindow(
487 a_scale_grid_desc,
488 make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0));
489
490 // Prefetch b_scales
491 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
492 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
493 b_scale_thread_copy.Run(b_scale_grid_desc,
494 b_scale_grid_buf,
496 make_tuple(n0, k0, I0),
497 b_scale_thread_buf);
498
499 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
500 make_multi_index(0, I1, 0));
501 });
502 b_scale_thread_copy.MoveSrcSliceWindow(
503 b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
504 });
505
506 // restore col id and advance to the next set of scales
507 // NWaves * NPerXDL * NRepeat == NPerBlock
508 b_scale_thread_copy.MoveSrcSliceWindow(
509 b_scale_grid_desc,
510 make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0));
511
// NOTE(review): same immediate 3952 (0x0F70) as the prefill wait; on GFX9-family targets
// it appears to decode to vmcnt(0) expcnt(7) lgkmcnt(15), i.e. a wait on vector-memory
// operations only. The trailing comment ("wait for EXP_CNT and LGKM_CNT") disagrees with
// that decoding and with the prefill's comment — confirm against the target ISA.
512 __builtin_amdgcn_s_waitcnt(3952); // wait for EXP_CNT and LGKM_CNT
514
515 i += 1;
516 } while(i < (num_loop - 1));
517 }
518
519 // tail
// Last K block: reuses the data already in a_block_buf / b_block_buf and the scales
// loaded by the final loop iteration (or by the initial prefetch when HasMainLoop is
// false). No further tile or scale copies are issued here.
520 if constexpr(TailNum == TailNumber::Full)
521 {
522 static_for<0, KRepeat, 1>{}([&](auto k) {
523 constexpr auto k_step =
524 k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops;
525
526 static_for<0, MRepeat, 1>{}([&](auto m0) {
527 static_for<0, xdlops_gemm.K1PerXdlops / APackedSize / KThreadChunk, 1>{}(
528 [&](auto chunk) {
529 constexpr auto a_k_step_chunk =
530 k_step +
531 chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
534 I0,
536 I0,
538 a_block_buf,
541 I0,
543 k,
545 a_thread_buf);
546 });
547 });
548 static_for<0, NRepeat, 1>{}([&](auto n0) {
549 // read block data in chunks to assemble correct thread vectors
550 static_for<0, xdlops_gemm.K1PerXdlops / BPackedSize / KThreadChunk, 1>{}(
551 [&](auto chunk) {
552 constexpr auto b_k_step_chunk =
553 k_step +
554 chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
557 I0,
559 I0,
561 b_block_buf,
564 I0,
566 k,
568 b_thread_buf);
569 });
570 });
571 });
572 static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
573 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
574 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
575 constexpr index_t a_scale_offset =
576 a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
577 constexpr index_t b_scale_offset =
578 b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
579
580 static_assert(0 < ScalesPerXdlopsRunPerThreadA &&
582 "Must have at least one scale per Xdlops per Thread.");
583
586
587 // Pack scale_thread_buf into scale_thread_vec
589 a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
590 a_scale_thread_buf[Number<a_scale_offset + s>{}];
591 });
593 b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
594 b_scale_thread_buf[Number<b_scale_offset + s>{}];
595 });
596
597 static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
598 static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
599 static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
600 constexpr auto kxdl = ikxdl + k0 * KXdlPack;
601
604
605 static_for<0, KPack, 1>{}([&](auto ik) {
606 a_thread_vec.template AsType<ComputeTypeA>()(ik) =
607 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
608 make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
609 b_thread_vec.template AsType<ComputeTypeB>()(ik) =
610 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
611 make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
612 });
613
614 using mfma_input_type_a = typename vector_type< //
616 xdlops_gemm.K1PerXdlops / APackedSize>::type;
617 using mfma_input_type_b = typename vector_type< //
619 xdlops_gemm.K1PerXdlops / BPackedSize>::type;
620
621 using mfma_scale_input_type_a = typename vector_type< //
622 AScaleDataType,
624 using mfma_scale_input_type_b = typename vector_type< //
625 BScaleDataType,
627
628 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
629 make_tuple(m0, n0, imxdl, inxdl, 0));
630
631 // MFMA accumulation
632 xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
633 ikxdl * NXdlPack + inxdl>(
634 a_thread_vec.template AsType<mfma_input_type_a>(),
635 a_scale_thread_vec
636 .template AsType<mfma_scale_input_type_a>(),
637 b_thread_vec.template AsType<mfma_input_type_b>(),
638 b_scale_thread_vec
639 .template AsType<mfma_scale_input_type_b>(),
640 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
641 });
642 });
643 });
644 });
645 });
646 });
647 }
648 }
649
650 // TODO: make this field protected when a_scale_thread_copy_ is moved
651 // here
654 Number<KRepeat / KXdlPack>{},
656
657 // TODO: make this field protected when b_scale_thread_copy_ is moved
658 // here
661 Number<KRepeat / KXdlPack>{},
663
664 protected:
665 using Base::a_thread_copy_;
666 using Base::a_thread_desc_;
667 using Base::b_thread_copy_;
668 using Base::b_thread_desc_;
669 using Base::c_thread_desc_;
670};
671
672} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
int32_t index_t
Definition ck.hpp:299
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Full
Definition blkgemmpipe_scheduler.hpp:49
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ BlockwiseGemmXdlops_mx_pipeline_base(Tuple5 a_origin=CalculateAThreadOriginDataIndex(), Tuple5 b_origin=CalculateBThreadOriginDataIndex())
Definition blockwise_gemm_mx_pipeline_xdlops_base.hpp:204
BlockwiseGemmXdlops_mx_pipeline_base< ThreadBlockSize, ADataType, BDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack > Base
Definition blockwise_gemm_pipeline_xdlops_v1_mx.hpp:102
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, const AScaleGridDesc &a_scale_grid_desc, AScaleThreadTransfer &a_scale_thread_copy, const AScaleGridBuffer &a_scale_grid_buf, const BScaleGridDesc &b_scale_grid_desc, BScaleThreadTransfer &b_scale_thread_copy, const BScaleGridBuffer &b_scale_grid_buf, index_t num_loop) const
Definition blockwise_gemm_pipeline_xdlops_v1_mx.hpp:218
Definition blockwise_gemm_pipeline_xdlops_v1_mx.hpp:38
Unsigned representation of a conventional biased Float32 exponent.
Definition utility/e8m0.hpp:26
Definition functional2.hpp:33
Definition dtype_vector.hpp:10