block_to_ctile_map.hpp Source File

block_to_ctile_map.hpp Source File#

Composable Kernel: block_to_ctile_map.hpp Source File
block_to_ctile_map.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck/utility/math.hpp"
11#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
12#include <limits>
13#include <stdlib.h>
14#endif
15
16namespace ck {
17
18// Rows of column-vectors
19template <index_t MPerBlock,
20 index_t NPerBlock,
21 typename CGridDesc_M_N,
22 bool DeviceCTileIndexCheck = false>
24{
25 static constexpr auto I0 = Number<0>{};
26 static constexpr auto I1 = Number<1>{};
27 static constexpr auto I2 = Number<2>{};
28 static constexpr auto I3 = Number<3>{};
29
// Compiler-generated default constructor.
30 __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01() = default;
31
// Stores M01 and builds the underlying block-id -> C-tile adaptor for the given C grid
// descriptor by calling GetBlockToCTileMap(). M01 defaults to 1.
32 __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01(const CGridDesc_M_N& c_grid_desc_m_n,
33 index_t M01 = 1)
34 : M01_(M01), underlying_map_(GetBlockToCTileMap(c_grid_desc_m_n, M01))
35 {
36 }
37
38 __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
39 {
40 const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
41 const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
42
43 const auto M00 = math::integer_divide_ceil(M0, M01_);
44
45 const index_t grid_size = M00 * M01_ * N0;
46
47 return grid_size;
48 }
49
// Translates a top index into a bottom index by delegating to the adaptor stored in
// underlying_map_.
50 template <typename TopIdx>
51 __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
52 {
53 return underlying_map_.CalculateBottomIndex(idx_top);
54 }
55
56 template <typename CTileIdx, typename CTileDim>
57 __host__ __device__ constexpr bool ValidCTileIndex(const CTileIdx& c_tile_idx,
58 const CTileDim& c_tile_dim) const
59 {
60 if constexpr(DeviceCTileIndexCheck)
61 return DefaultValidCTileIndex(c_tile_idx, c_tile_dim);
62 else
63 return true;
64 }
65
66 __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
67 {
68 if constexpr(DeviceCTileIndexCheck)
69 return true; // validity check moved to kernel
70
71 const index_t M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
72 if(M0 % M01_ == 0)
73 {
74 return true;
75 }
76 else
77 {
78 return false;
79 }
80 }
81
82 private:
83 __host__ __device__ static constexpr auto
84 GetBlockToCTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01)
85 {
86 const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
87 const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
88
89 const auto M00 = math::integer_divide_ceil(M0, M01);
90
91 const auto m00_n0_m01_to_m0_n0_block_cluster_adaptor = make_single_stage_tensor_adaptor(
95 make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}),
96 make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}));
97
98 const auto cblockid_to_m00_n0_m01_block_cluster_adaptor = make_single_stage_tensor_adaptor(
100 make_tuple(Sequence<0, 1, 2, 3>{}),
101 make_tuple(Sequence<0>{}));
102
103 const auto cblockid_to_m0_n0_block_cluster_adaptor =
104 chain_tensor_adaptors(m00_n0_m01_to_m0_n0_block_cluster_adaptor,
105 cblockid_to_m00_n0_m01_block_cluster_adaptor);
106
107 return cblockid_to_m0_n0_block_cluster_adaptor;
108 }
109
110 index_t M01_;
111 using UnderlyingMap = decltype(GetBlockToCTileMap(CGridDesc_M_N{}, 1));
112 UnderlyingMap underlying_map_;
113};
114
115// Rows of column-vectors
116// This C-tile map dynamically adjusts M01 when C-tile index is out of range
117template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N = void>
119
120template <index_t MPerBlock, index_t NPerBlock>
121struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
122{
123 static constexpr auto I0 = Number<0>{};
124 static constexpr auto I1 = Number<1>{};
125
// Compiler-generated default constructor.
126 __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt() = default;
127
// Compiler-generated copy constructor.
128 __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(
129 const BlockToCTileMap_M00_N0_M01Adapt&) = default;
130 __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(
132 __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt&
134 __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt&
136
137 __host__
138 __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(index_t M, index_t N, index_t M01 = 8)
139 : M_(M), N_(N), M01_(M01)
140 {
141#if 0
143 printf("Ctor called, M= %d, N= %d, M01 = %d\n", M_, N_, M01_);
144 }
145#endif
146 }
147
148 template <typename CGridDesc_M_N>
149 __host__
150 __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
151 index_t M01 = 8)
153 c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), M01)
154 {
155 }
156
157 __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
158 {
159 const auto M0 = math::integer_divide_ceil(M, MPerBlock);
160 const auto N0 = math::integer_divide_ceil(N, NPerBlock);
161
162 return M0 * N0;
163 }
164
// Convenience overload: takes the lengths of descriptor dimensions 0 and 1 and forwards them
// to CalculateGridSize(M, N).
165 template <typename CGridDesc_M_N>
166 __host__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
167 {
168 return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
169 }
170
// Always returns true; the descriptor argument is not inspected.
171 template <typename CGridDesc_M_N>
172 __host__ __device__ constexpr bool
173 CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
174 {
175 return true;
176 }
177
178 template <typename TopIdx>
179 __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
180 {
181 auto block_1d_id = idx_top[I0];
182
183 const auto M0 = math::integer_divide_ceil(M_, MPerBlock);
184 const auto N0 = math::integer_divide_ceil(N_, NPerBlock);
185
186 block_1d_id = block_1d_id % (M0 * N0); // swallow batch index
187
188 index_t idx_N0 = block_1d_id % N0;
189 index_t idx_M0 = block_1d_id / N0;
190
191 const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;
192
193 index_t idx_M00 = idx_M0 / M01_;
194 index_t idx_M01 = idx_M0 % M01_;
195 index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
196
240
241 return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
242 idx_N0_M01_local / M01_adapt);
243 }
244
// Always returns true; both arguments are ignored.
245 template <typename CTileIdx, typename CTileDim>
246 __host__ __device__ constexpr bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
247 const CTileDim& /* c_tile_dim */) const
248 {
249 return true; // always valid provided that user gets grid size from CalculateGridSize()
250 }
251
252 private:
253 index_t M_;
254 index_t N_;
255 index_t M01_;
256};
257
258// keep the redundant type argument for backward compatibility
259template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N>
261{
262 using BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>::
263 BlockToCTileMap_M00_N0_M01Adapt;
264};
265
266// Grouped Rows of column-vectors WGP mapping
267// Optimized for gfx94x-like multiple-die chip
268
269template <index_t GroupNum, index_t MPerBlock, index_t NPerBlock>
271{
272 static constexpr auto I0 = Number<0>{};
273 static constexpr auto I1 = Number<1>{};
274
// Compiler-generated default constructor.
275 __host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt() = default;
277 index_t N,
278 index_t M01 = 8)
279 : M_(M), N_(N), M01_(M01)
280 {
281 }
282
283 __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
284 {
285 const auto M0 = math::integer_divide_ceil(M, MPerBlock);
286 const auto N0 = math::integer_divide_ceil(N, NPerBlock);
287
288 return M0 * N0;
289 }
290
// Always returns true; the descriptor argument is not inspected.
291 template <typename CGridDesc_M_N>
292 __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
293 {
294 return true;
295 }
296
297 template <typename TopIdx>
298 __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
299 {
300 auto block_1d_id = idx_top[I0];
301
302 const auto M0 = math::integer_divide_ceil(M_, MPerBlock);
303 const auto N0 = math::integer_divide_ceil(N_, NPerBlock);
304
305 if(M0 == 1)
306 {
307 return make_tuple(0, block_1d_id);
308 }
309 else if(N0 == 1)
310 {
311 return make_tuple(block_1d_id, 0);
312 }
313 // block_1d_id = block_1d_id % (M0 * N0); // swallow batch index
314 else
315 {
316 const auto group_size = math::integer_divide_ceil(M0 * N0, GroupNum);
317 const auto big_group_num = GroupNum - (group_size * GroupNum - M0 * N0);
318 auto group_id_x = block_1d_id % GroupNum;
319 auto group_id_y = block_1d_id / GroupNum;
320 auto remap_block_1d_id =
321 group_id_x <= big_group_num
322 ? group_id_x * group_size + group_id_y
323 : group_id_x * group_size + big_group_num - group_id_x + group_id_y;
324
325 index_t idx_N0 = remap_block_1d_id % N0;
326 index_t idx_M0 = remap_block_1d_id / N0;
327
328 const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;
329
330 index_t idx_M00 = idx_M0 / M01_;
331 index_t idx_M01 = idx_M0 % M01_;
332 index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
333
377
378 return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
379 idx_N0_M01_local / M01_adapt);
380 }
381 }
382
// Always returns true; both arguments are ignored.
383 template <typename CTileIdx, typename CTileDim>
384 __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
385 const CTileDim& /* c_tile_dim */) const
386 {
387 return true; // always valid provided that user gets grid size from CalculateGridSize()
388 }
389
390 private:
391 index_t M_;
392 index_t N_;
393 index_t M01_;
394};
395
396// columns of row-vectors
397// This C-tile map dynamically adjusts N01 when C-tile index is out of range
398template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N = void>
400
401template <index_t MPerBlock, index_t NPerBlock>
402struct BlockToCTileMap_N00_M0_N01Adapt<MPerBlock, NPerBlock, void>
403{
404 static constexpr auto I0 = Number<0>{};
405 static constexpr auto I1 = Number<1>{};
406
// Compiler-generated default constructor.
407 __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt() = default;
408
410 default;
412 default;
413 __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt&
415 __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt&
417
418 __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(index_t M, index_t N, index_t N01 = 8)
419 : M_(M), N_(N), N01_(N01)
420 {
421#if 0
423 printf("Ctor called, M= %d, N= %d, N01 = %d\n", M_, N_, N01_);
424 }
425#endif
426 }
427
428 template <typename CGridDesc_M_N>
429 __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
430 index_t N01 = 8)
432 c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), N01)
433 {
434 }
435
436 __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
437 {
438 const auto M0 = math::integer_divide_ceil(M, MPerBlock);
439 const auto N0 = math::integer_divide_ceil(N, NPerBlock);
440
441 return M0 * N0;
442 }
443
// Convenience overload: takes the lengths of descriptor dimensions 0 and 1 and forwards them
// to CalculateGridSize(M, N).
444 template <typename CGridDesc_M_N>
445 __host__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
446 {
447 return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
448 }
449
// Always returns true; the descriptor argument is not inspected.
450 template <typename CGridDesc_M_N>
451 __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
452 {
453 return true;
454 }
455
456 template <typename TopIdx>
457 __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
458 {
459 auto block_1d_id = idx_top[I0];
460
461 const auto M0 = math::integer_divide_ceil(M_, MPerBlock);
462 const auto N0 = math::integer_divide_ceil(N_, NPerBlock);
463
464 block_1d_id = block_1d_id % (M0 * N0); // swallow batch index
465
466 index_t idx_M0 = block_1d_id % M0;
467 index_t idx_N0 = block_1d_id / M0;
468
469 const auto N01_adapt = (idx_N0 < N0 - N0 % N01_) ? N01_ : N0 % N01_;
470
471 index_t idx_N00 = idx_N0 / N01_;
472 index_t idx_N01 = idx_N0 % N01_;
473 index_t idx_M0_N01_local = idx_M0 + idx_N01 * M0;
474
519
520 return make_tuple(idx_M0_N01_local / N01_adapt,
521 idx_M0_N01_local % N01_adapt + idx_N00 * N01_);
522 }
523
// Always returns true; both arguments are ignored.
524 template <typename CTileIdx, typename CTileDim>
525 __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
526 const CTileDim& /* c_tile_dim */) const
527 {
528 return true; // always valid provided that user gets grid size from CalculateGridSize()
529 }
530
531 private:
532 index_t M_;
533 index_t N_;
534 index_t N01_;
535};
536
537// 2D slices of column-vectors in 3D space
538// This C-tile map dynamically adjusts M01 when C-tile index is out of range
539template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N>
541{
542 static constexpr auto I0 = Number<0>{};
543 static constexpr auto I1 = Number<1>{};
544 static constexpr auto I2 = Number<2>{};
545 static constexpr auto I3 = Number<3>{};
546
// Compiler-generated default constructor.
547 __host__ __device__ BlockToCTileMap_KSplit_M00_N0_M01Adapt() = default;
548
// Stores a copy of the C grid descriptor together with M01 (default 8) and KSplit
// (default 1).
549 __host__ __device__ BlockToCTileMap_KSplit_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
550 index_t M01 = 8,
551 index_t KSplit = 1)
552 : M01_(M01), KSplit_(KSplit), c_grid_desc_m_n_(c_grid_desc_m_n)
553 {
554 }
555
556 __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
557 {
558 const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
559 const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
560
561 const index_t grid_size = M0 * N0 * KSplit_;
562
563 return grid_size;
564 }
565
566 template <typename TopIdx>
567 __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
568 {
569 auto block_1d_id = idx_top[I0];
570
571 const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I0), MPerBlock);
572 const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I1), NPerBlock);
573
574 block_1d_id = block_1d_id % (M0 * N0 * KSplit_); // hide groups
575
576 const index_t idx_ksplit = block_1d_id / (M0 * N0);
577 block_1d_id = block_1d_id % (M0 * N0);
578
579 index_t idx_N0 = block_1d_id % N0;
580 index_t idx_M0 = block_1d_id / N0;
581
582 const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;
583
584 index_t idx_M00 = idx_M0 / M01_;
585 index_t idx_M01 = idx_M0 % M01_;
586 index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
587
588 return make_tuple(idx_ksplit,
589 idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
590 idx_N0_M01_local / M01_adapt);
591 }
592
// Always returns true; both arguments are ignored.
593 template <typename CTileIdx, typename CTileDim>
594 __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
595 const CTileDim& /* c_tile_dim */) const
596 {
597 return true; // always valid provided that user gets grid size from CalculateGridSize()
598 }
599
// Always returns true; the descriptor argument is not inspected.
600 __host__ constexpr bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
601 {
602 return true;
603 }
604
605 private:
606 index_t M01_;
607 index_t KSplit_;
608 CGridDesc_M_N c_grid_desc_m_n_;
609};
610
611// Blocks of row-vectors
612template <index_t MPerBlock,
613 index_t NPerBlock,
614 typename CGridDesc_M_N,
615 bool DeviceCTileIndexCheck = false>
617{
618 static constexpr auto I0 = Number<0>{};
619 static constexpr auto I1 = Number<1>{};
620 static constexpr auto I2 = Number<2>{};
621 static constexpr auto I3 = Number<3>{};
622
// Compiler-generated default constructor.
623 __host__ __device__ BlockToCTileMap_M00_N00_M01_N01() = default;
624
// Stores M01 and N01 (both default to 1) and builds the underlying adaptor for the given C
// grid descriptor by calling GetBlockToCTileMap().
625 __host__ __device__ BlockToCTileMap_M00_N00_M01_N01(const CGridDesc_M_N& c_grid_desc_m_n,
626 index_t M01 = 1,
627 index_t N01 = 1)
628 : M01_(M01), N01_(N01), underlying_map_(GetBlockToCTileMap(c_grid_desc_m_n, M01, N01))
629 {
630 }
631
632 __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
633 {
634 const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
635 const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
636
637 const auto M00 = math::integer_divide_ceil(M0, M01_);
638 const auto N00 = math::integer_divide_ceil(N0, N01_);
639
640 const index_t grid_size = M00 * M01_ * N00 * N01_;
641
642 return grid_size;
643 }
644
// Translates a top index into a bottom index by delegating to the adaptor stored in
// underlying_map_.
645 template <typename TopIdx>
646 __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
647 {
648 return underlying_map_.CalculateBottomIndex(idx_top);
649 }
650
651 template <typename CTileIdx, typename CTileDim>
652 __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
653 const CTileDim& c_tile_dim) const
654 {
655 if constexpr(DeviceCTileIndexCheck)
656 return DefaultValidCTileIndex(c_tile_idx, c_tile_dim);
657 else
658 return true;
659 }
660
661 __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
662 {
663 if constexpr(DeviceCTileIndexCheck)
664 return true; // validity check moved to kernel
665
666 const index_t M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
667 const index_t N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
668 if(M0 % M01_ == 0 && N0 % N01_ == 0)
669 {
670 return true;
671 }
672 else
673 {
674 return false;
675 }
676 }
677
678 private:
679 __host__ __device__ static constexpr auto
680 GetBlockToCTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01)
681 {
682 const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
683 const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
684
685 const auto M00 = math::integer_divide_ceil(M0, M01);
686 const auto N00 = math::integer_divide_ceil(N0, N01);
687
688 const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
690 make_tuple(make_insert_transform(1), // swallow the carry from lower dimensions
693 make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}),
694 make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}));
695
696 const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
698 make_tuple(make_merge_transform(make_tuple(1, M00, N00, M01, N01))),
699 make_tuple(Sequence<0, 1, 2, 3, 4>{}),
700 make_tuple(Sequence<0>{}));
701
702 const auto cblockid_to_m0_n0_block_cluster_adaptor =
703 chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
704 cblockid_to_m00_m01_n00_n01_block_cluster_adaptor);
705
706 return cblockid_to_m0_n0_block_cluster_adaptor;
707 }
708
709 index_t M01_, N01_;
710 using UnderlyingMap = decltype(GetBlockToCTileMap(CGridDesc_M_N{}, 1, 1));
711 UnderlyingMap underlying_map_;
712};
713
714// 2D slices of row-vectors in 3D space
715template <index_t MPerBlock,
716 index_t NPerBlock,
717 typename CGridDesc_M_N,
718 bool DeviceCTileIndexCheck = false>
720{
721 static constexpr auto I0 = Number<0>{};
722 static constexpr auto I1 = Number<1>{};
723 static constexpr auto I2 = Number<2>{};
724 static constexpr auto I3 = Number<3>{};
725
727
// Host-only constructor. Stores a copy of the descriptor plus M01, N01 and KSplit (each
// defaults to 1) and builds the underlying adaptor by calling GetBlockToCTileMap().
728 __host__ BlockToCTileMap_KSplit_M00_N00_M01_N01(const CGridDesc_M_N& c_grid_desc_m_n,
729 index_t M01 = 1,
730 index_t N01 = 1,
731 index_t KSplit = 1)
732 : c_grid_desc_m_n_(c_grid_desc_m_n),
733 M01_(M01),
734 N01_(N01),
735 KSplit_(KSplit),
736 underlying_map_(GetBlockToCTileMap(c_grid_desc_m_n, M01, N01, KSplit))
737 {
738 }
739
740 __host__ __device__ constexpr index_t
741 CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
742 {
743 const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
744 const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
745
746 const auto M00 = math::integer_divide_ceil(M0, M01_);
747 const auto N00 = math::integer_divide_ceil(N0, N01_);
748
749 const index_t grid_size = M00 * M01_ * N00 * N01_ * KSplit_;
750
751 return grid_size;
752 }
753
754 template <typename TopIdx>
755 __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
756 {
757 static_assert(TopIdx::Size() == 1);
758
759 return underlying_map_.CalculateBottomIndex(
761 }
762
763 template <typename CTileIdx, typename CTileDim>
764 __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
765 const CTileDim& c_tile_dim) const
766 {
767 if constexpr(DeviceCTileIndexCheck)
768 return DefaultValidCTileIndex(c_tile_idx, c_tile_dim);
769 else
770 return true;
771 }
772
773 __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
774 {
775 if constexpr(DeviceCTileIndexCheck)
776 return true; // validity check moved to kernel
777
778 const index_t M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
779 const index_t N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
780 if(M0 % M01_ == 0 && N0 % N01_ == 0)
781 {
782 return true;
783 }
784 else
785 {
786 return false;
787 }
788 }
789
790 private:
// Device-only overload: grid size computed from the descriptor stored in this map.
791 __device__ constexpr index_t CalculateGridSize() const
792 {
793 return CalculateGridSize(c_grid_desc_m_n_);
794 }
795
796 __host__ static constexpr auto GetBlockToCTileMap(const CGridDesc_M_N& c_grid_desc_m_n,
797 index_t M01,
798 index_t N01,
799 index_t KSplit)
800 {
801 const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
802 const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
803
804 const auto M00 = math::integer_divide_ceil(M0, M01);
805 const auto N00 = math::integer_divide_ceil(N0, N01);
806
807 const auto ksplit_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
812 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
813 make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}));
814
815 const auto c_blockid_to_ksplit_m00_m01_n00_n01_block_cluster_adaptor =
817 make_tuple(make_merge_transform(make_tuple(KSplit, M00, N00, M01, N01))),
818 make_tuple(Sequence<0, 1, 2, 3, 4>{}),
819 make_tuple(Sequence<0>{}));
820
821 const auto c_blockid_to_ksplit_m0_n0_block_cluster_adaptor =
822 chain_tensor_adaptors(ksplit_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
823 c_blockid_to_ksplit_m00_m01_n00_n01_block_cluster_adaptor);
824
825 return c_blockid_to_ksplit_m0_n0_block_cluster_adaptor;
826 }
827
828 CGridDesc_M_N c_grid_desc_m_n_;
829 index_t M01_, N01_, KSplit_;
830 using UnderlyingMap = decltype(GetBlockToCTileMap(CGridDesc_M_N{}, 1, 1, 1));
831 UnderlyingMap underlying_map_;
832};
833
834template <typename CTileIdx, typename CTileDim>
835__host__ __device__ bool DefaultValidCTileIndex(const CTileIdx& c_tile_idx,
836 const CTileDim& c_tile_dim)
837{
838 bool is_valid = false;
839
840 const index_t m_block = c_tile_dim[Number<0>{}];
841 const index_t n_block = c_tile_dim[Number<1>{}];
842
843 if constexpr(CTileIdx::Size() == 2)
844 {
845 const index_t m_block_idx = c_tile_idx[Number<0>{}];
846 const index_t n_block_idx = c_tile_idx[Number<1>{}];
847 if(0 <= m_block_idx && m_block_idx < m_block && 0 <= n_block_idx && n_block_idx < n_block)
848 {
849 is_valid = true;
850 }
851 }
852 else if constexpr(CTileIdx::Size() == 3)
853 {
854 const index_t ksplit_idx = c_tile_idx[Number<0>{}];
855 const index_t m_block_idx = c_tile_idx[Number<1>{}];
856 const index_t n_block_idx = c_tile_idx[Number<2>{}];
857 if(0 <= m_block_idx && m_block_idx < m_block && 0 <= n_block_idx && n_block_idx < n_block)
858 {
859 is_valid = true;
860 }
861 ignore = ksplit_idx;
862 }
863
864 return is_valid;
865}
866
867// This wrapper class is for grouped gemm where it subtracts blockIdx by a value so that the
868// workgroups assigned to a given gemm problem have top index offsetted to range [0,
869// grid_size_per_gemm]
870template <typename UnderlyingBlockToCTileMap>
872{
873 using underlying_type = UnderlyingBlockToCTileMap;
874
875 __host__ __device__ OffsettedBlockToCTileMap() = default;
876 __host__ __device__ OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map,
877 index_t block_start)
878 {
879 block_to_ctile_map_ = block_to_ctile_map;
880 block_start_ = block_start;
881 }
882
883 template <typename TopIdx>
884 __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
885 {
886 return block_to_ctile_map_.CalculateBottomIndex(
888 }
889
// Forwards the validity query unchanged to the wrapped map.
890 template <typename CTileIdx, typename CTileDim>
891 __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
892 const CTileDim& c_tile_dim) const
893 {
894 return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
895 }
896
// Forwards the host-side validity check unchanged to the wrapped map.
897 template <typename CGridDesc_M_N>
898 __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
899 {
900 return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
901 }
902
// Forwards to the wrapped map's descriptor-based CalculateGridSize; the stored block offset
// is not involved.
903 template <typename CGridDesc_M_N>
904 __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
905 {
906 return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n);
907 }
908
// Forwards to the wrapped map's CalculateGridSize(M, N); the stored block offset is not
// involved.
909 __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const
910 {
911 return block_to_ctile_map_.CalculateGridSize(M, N);
912 }
913
914 UnderlyingBlockToCTileMap block_to_ctile_map_;
916};
917// second version with 2 offsets
918template <typename UnderlyingBlockToCTileMap>
920{
921 using underlying_type = UnderlyingBlockToCTileMap;
922
// Wraps `block_to_ctile_map` and records the two offsets (group_offset, tile_offset) in the
// corresponding members.
923 __host__ __device__ OffsettedBlockToCTileMap2(UnderlyingBlockToCTileMap block_to_ctile_map,
924 index_t group_offset,
925 index_t tile_offset)
926 : block_to_ctile_map_{block_to_ctile_map},
927 group_offset_{group_offset},
928 tile_offset_{tile_offset}
929 {
930 }
931
932 template <typename TopIdx>
933 __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
934 {
935 return block_to_ctile_map_.CalculateBottomIndex(
937 }
938
// Forwards the validity query unchanged to the wrapped map.
939 template <typename CTileIdx, typename CTileDim>
940 __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
941 const CTileDim& c_tile_dim) const
942 {
943 return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
944 }
945
// Forwards the host-side validity check unchanged to the wrapped map.
946 template <typename CGridDesc_M_N>
947 __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
948 {
949 return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
950 }
951
// Forwards to the wrapped map's CalculateGridSize(M, N); the stored offsets are not
// involved.
952 __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const
953 {
954 return block_to_ctile_map_.CalculateGridSize(M, N);
955 }
956
// Device-only setter: overwrites the stored tile offset with `offset`.
957 __device__ void UpdateTileOffset(index_t offset) { tile_offset_ = offset; }
958 UnderlyingBlockToCTileMap block_to_ctile_map_;
961};
962
975template <index_t MPerBlock, index_t NPerBlock>
977{
978
// Compiler-generated default constructor.
979 __host__ __device__ BlockToCTileMap_3DGrid_KSplit() = default;
980
981 __host__ __device__ constexpr auto
983 {
984 // Create 3D grid
985 const auto M0 = math::integer_divide_ceil(M, MPerBlock);
986 const auto N0 = math::integer_divide_ceil(N, NPerBlock);
987 return make_tuple(N0, M0, k_split);
988 }
989
// Device-only. Ignores its argument and returns the built-in block index as the tuple
// (blockIdx.z, blockIdx.y, blockIdx.x).
990 template <typename TopIdx>
991 __device__ constexpr auto CalculateBottomIndex(const TopIdx&) const
992 {
993 return make_tuple(blockIdx.z, blockIdx.y, blockIdx.x);
994 }
995
// Always returns true; both arguments are ignored.
996 template <typename CTileIdx, typename CTileDim>
997 __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
998 const CTileDim& /* c_tile_dim */) const
999 {
1000 return true; // always valid provided that user gets grid size from CalculateGridSize()
1001 }
1002
// Always returns true; the descriptor argument is not inspected.
1003 template <typename CGridDesc_M_N>
1004 __host__ constexpr bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
1005 {
1006 return true;
1007 }
1008};
1009
1011{
1012 Atomic = 0, // sk block use atomic to do reduction
1013 Reduction, // let some workgroup responsible for doing the reduction operation
1014};
1015
1016template <uint32_t MPerBlock_,
1017 uint32_t NPerBlock_,
1018 uint32_t KPerBlock_,
1020 uint32_t TileSwizzleSubM_ = 8>
1022{
1024 static constexpr uint32_t MPerBlock = MPerBlock_;
1025 static constexpr uint32_t NPerBlock = NPerBlock_;
1026 static constexpr uint32_t KPerBlock = KPerBlock_;
1027 static constexpr StreamKReductionStrategy ReductionStrategy = ReductionStrategy_;
1028 static constexpr uint32_t tile_swizzle_sub_m = TileSwizzleSubM_;
1029
1030 //--------------------------------------
1031 // pass to device
1039 MDiv eqav_tiles_big; // for reduction
1040 MDiv eqav_tiles_little; // for reduction
1041
1042 // MDiv tile_swizzle_sub_m_rem;
1043 //--------------------------------------
1044
1045 // prefer construct on host
1047 uint32_t n,
1048 uint32_t k,
1049 uint32_t num_cu,
1050 uint32_t occupancy,
1051 uint32_t sk_blocks = 0xffffffff)
1052 {
1053 uint32_t num_tiles =
1056
1057 // one cu can hold one wg at one time, from the whole chip's point of view
1058 // if number of wg is same as num_cu, we call it 1 dispatch
1059 // if number of wg is 2x num_cu, we call it 2 dispatches.
1060 // one dispatch can deliver wg same as num_cu (full dispatch), or less than num_cu (partial
1061 // dispatch)
1062 //
1063 uint32_t full_dispatches = num_tiles / num_cu;
1064 uint32_t full_dispatch_tiles = full_dispatches * num_cu;
1065 uint32_t partial_dispatche_tiles = num_tiles - full_dispatch_tiles;
1066
1067 uint32_t sk_occupancy = occupancy;
1068 uint32_t dp_tiles = full_dispatch_tiles;
1069 uint32_t sk_tiles = partial_dispatche_tiles;
1070
1071 if(full_dispatches < occupancy)
1072 {
1073 // in this case, we allocate all blocks as sk blocks
1074 // sk_occupancy = occupancy - full_dispatches;
1075 sk_occupancy = 1; // TODO: single occ seems better
1076 dp_tiles = full_dispatch_tiles;
1077 sk_tiles = partial_dispatche_tiles;
1078 }
1079 else if((occupancy > 1) && (full_dispatches % occupancy == occupancy - 1))
1080 {
1081 // e.g. occupancy = 2, full_dispatches = 3, 5, 7 ...
1082 // occupancy = 3, full_dispatches = 5, 8, 11 ...
1083 // occupancy = 4, full_dispatches = 7, 11 ...
1084 sk_occupancy = 1; // left 1 slot for sk occupancy
1085 dp_tiles = full_dispatch_tiles;
1086 sk_tiles = partial_dispatche_tiles;
1087 }
1088 else
1089 {
1090 // others, we reduce 1 dispatch from dp, together with partial dispatch,
1091 // to construct sk dispatch
1092 sk_occupancy = occupancy - ((full_dispatches - 1) % occupancy);
1093 dp_tiles = full_dispatch_tiles - num_cu;
1094 sk_tiles = partial_dispatche_tiles + num_cu;
1095 }
1096
1097 // uint32_t dp_iters_per_block = k_iters_per_tile.get();
1098 uint32_t sk_total_iters = k_iters_per_tile.get() * sk_tiles;
1099 uint32_t dp_num_blocks = 0;
1100
1101 {
1102 uint32_t min_sk_tiles = (sk_tiles >= num_cu) ? num_cu : (sk_tiles + 1);
1103 uint32_t max_sk_tiles =
1104 (sk_tiles >= num_cu) ? num_cu * sk_occupancy
1105 : math::min(num_cu, sk_total_iters / min_k_iters_per_sk_block);
1106
1107 // if use dp for sk-block, how many iters do we need
1108 uint32_t dp_for_sk_iters = k_iters_per_tile.get();
1109
1110 uint32_t best_sk_score =
1111 NumericLimits<int32_t>::Max(); // we need to find the smallest sk iters
1112 for(uint32_t tentative_sk_blocks = min_sk_tiles; tentative_sk_blocks < max_sk_tiles;
1113 tentative_sk_blocks++)
1114 {
1115 uint32_t tentative_sk_iters_per_block =
1116 (sk_total_iters + tentative_sk_blocks - 1) / tentative_sk_blocks;
1117 uint32_t tentative_sk_iters = tentative_sk_iters_per_block;
1118 uint32_t sk_blocks_per_tile = (tentative_sk_blocks + sk_tiles - 1) / sk_tiles;
1119
1120 // TODO: carefully adjust this parameter
1121 // the more sk_blocks_per_tile, the worse the overhead
1122 uint32_t cross_sk_blocks_overhead = sk_blocks_per_tile;
1123 if(tentative_sk_blocks % sk_tiles != 0)
1124 {
1125 // penalty for uneven divide
1126 cross_sk_blocks_overhead +=
1127 sk_blocks_per_tile * tentative_sk_iters_per_block / 50;
1128 }
1129
1130 uint32_t tentative_sk_score = tentative_sk_iters + cross_sk_blocks_overhead;
1131
1132 if(tentative_sk_score < best_sk_score)
1133 {
1134 best_sk_score = tentative_sk_score;
1135 sk_num_blocks = tentative_sk_blocks;
1136 }
1137 }
1138
1139 if(best_sk_score >= dp_for_sk_iters)
1140 {
1141 sk_num_blocks = 0;
1142 }
1143
1144 // give a chance to control num of sk blocks
1145 sk_num_blocks = sk_blocks != 0xffffffff ? sk_blocks : sk_num_blocks;
1146
1147 if(sk_num_blocks == 0)
1148 {
1151
1152 dp_num_blocks = num_tiles; // all tile to be dp block
1154 sk_total_iters = 0; // clear this tiles
1155 }
1156 else
1157 {
1158 // k_iters_per_sk_block is the floor of avg each ck block loop over tiles.
1159 // we need to decide how many iters for each sk block
1160 // let m = k_iters_per_sk_block
1161 // some of the sk block (little) will cover m iters, some (big) will cover m+1
1162 // we have
1163 // 1) l + b = sk_blocks
1164 // 2) l * m + b * (m + 1) = sk_total_iters
1165 // => (l + b) * m + b = sk_total_iters
1166 // => sk_blocks * m + b = sk_total_iters
1167 // => b = sk_total_iters - m * sk_blocks
1168 // NOTE: big could be zero
1169 uint32_t k_iters_per_sk_block = sk_total_iters / sk_num_blocks;
1170 sk_num_big_blocks = sk_total_iters - k_iters_per_sk_block * sk_num_blocks;
1171 k_iters_per_big_block = k_iters_per_sk_block + 1;
1172
1173 dp_num_blocks = dp_tiles;
1174 dp_start_block_idx = (sk_num_blocks + num_cu - 1) / num_cu * num_cu;
1175 }
1176 }
1179
1181 {
1183 uint32_t upper_little = math::lcm(k_iters_per_big_block - 1, k_iters_per_tile.get());
1184 eqav_tiles_big = MDiv(upper_big / k_iters_per_tile.get());
1185 eqav_tiles_little = MDiv(upper_little / k_iters_per_tile.get());
1186 }
1187
1188#if 0
1189 printf("cu:%d, occupancy:%d, grids:%d, num_tiles:%d, dp_tiles:%d, sk_num_big_blocks:%d, "
1190 "sk_num_blocks:%d, "
1191 "sk_total_iters:%d, dp_start_block_idx:%d, dp_iters_per_block:%d, dp_num_blocks:%d, "
1192 "k_iters_per_tile:%d, k_iters_per_big_block:%d, reduction_start_block_idx:%u, "
1193 "sk_tiles:%u, workspace(acc float):%u\n",
1194 num_cu,
1195 occupancy,
1196 get_grid_dims().x,
1197 num_tiles,
1198 dp_tiles,
1201 sk_total_iters,
1203 dp_iters_per_block,
1204 dp_num_blocks,
1205 k_iters_per_tile.get(),
1208 get_sk_tiles(),
1209 get_workspace_size(sizeof(float)));
1210#endif
1211 }
1212
1213 __host__ __device__ uint32_t get_sk_total_iters() const
1214 {
1217 return sk_total_iters;
1218 }
1219
1220 __host__ __device__ uint32_t get_sk_tiles() const
1221 {
1222 // tiles for sk
1223 uint32_t sk_total_iters = get_sk_total_iters();
1224 return k_iters_per_tile.div(sk_total_iters);
1225 }
1226
1227 __host__ __device__ dim3 get_grid_dims() const
1228 {
1230 {
1231 return dim3(reduction_start_block_idx + get_sk_tiles(), 1, 1);
1232 }
1233 else
1234 return dim3(reduction_start_block_idx, 1, 1);
1235 }
1236
1237 __device__ uint32_t get_block_idx() const
1238 {
1239 // TODO: swizzle block index for better locality
1240 return __builtin_amdgcn_readfirstlane(blockIdx.x);
1241 }
1242
1243 __device__ void
1244 get_block_itr(uint32_t block_idx, uint32_t& iter_start, uint32_t& iter_end) const
1245 {
1246 if(block_idx < sk_num_big_blocks)
1247 {
1248 iter_start = block_idx * k_iters_per_big_block;
1249 iter_end = iter_start + k_iters_per_big_block;
1250 }
1251 else if(block_idx < sk_num_blocks)
1252 {
1253 iter_start = (sk_num_big_blocks * k_iters_per_big_block) +
1254 (block_idx - sk_num_big_blocks) * (k_iters_per_big_block - 1);
1255 iter_end = iter_start + (k_iters_per_big_block - 1);
1256 }
1257 else if(block_idx >= dp_start_block_idx)
1258 {
1259 uint32_t sk_total_iters = get_sk_total_iters();
1260 uint32_t dp_iters_per_block = k_iters_per_tile.get();
1261 iter_start = sk_total_iters + (block_idx - dp_start_block_idx) * dp_iters_per_block;
1262 iter_end = iter_start + dp_iters_per_block;
1263 }
1264 }
1265
1267 uint32_t iter_end,
1268 uint32_t total_iter_length) const
1269 {
1270 uint32_t iter_length_mod, iter_length_quo /*unused*/;
1271 k_iters_per_tile.divmod(iter_end, iter_length_quo, iter_length_mod);
1272 uint32_t current_iter_length = math::min(
1273 iter_length_mod == 0 ? (iter_end - iter_start) : iter_length_mod, total_iter_length);
1274 return current_iter_length;
1275 }
1276
1277 __device__ uint32_t get_tile_idx(uint32_t iter) const { return k_iters_per_tile.div(iter); }
1278
1279 __device__ void
1280 get_tile_idx_with_offset(uint32_t iter, uint32_t& tile_idx, uint32_t& iter_offset) const
1281 {
1282 k_iters_per_tile.divmod(iter, tile_idx, iter_offset);
1283 }
1284
1285 __device__ auto tile_to_spatial(uint32_t tile_idx, uint32_t m, uint32_t n) const
1286 {
1287 uint32_t m_tile_idx, n_tile_idx;
1288 uint32_t n_tiles_value = math::integer_divide_ceil(n, NPerBlock);
1289 n_tiles.divmod(tile_idx, n_tiles_value, m_tile_idx, n_tile_idx);
1290
1291 // swizzle tile
1293
1294 uint32_t tile_swizzle_sub_m_rem = m_tiles % tile_swizzle_sub_m;
1295
1296 const auto sub_m_adapt = (m_tile_idx < (m_tiles - tile_swizzle_sub_m_rem))
1298 : tile_swizzle_sub_m_rem;
1299
1300 uint32_t m_tile_idx_sub0, m_tile_idx_sub1;
1301 m_tile_idx_sub0 = m_tile_idx / tile_swizzle_sub_m;
1302 m_tile_idx_sub1 = m_tile_idx % tile_swizzle_sub_m;
1303
1304 uint32_t tile_idx_local = n_tile_idx + m_tile_idx_sub1 * n_tiles_value;
1305
1306 uint32_t m_tile_idx_with_adapt, n_tile_idx_with_adapt;
1307
1308 n_tile_idx_with_adapt = tile_idx_local / sub_m_adapt;
1309 m_tile_idx_with_adapt = tile_idx_local % sub_m_adapt;
1310 return make_tuple(m_tile_idx_with_adapt + m_tile_idx_sub0 * tile_swizzle_sub_m,
1311 n_tile_idx_with_adapt);
1312 }
1313
1314 __host__ __device__ uint32_t get_workspace_size_for_acc(uint32_t acc_element_bytes) const
1315 {
1316 static constexpr uint32_t alignment = 128;
1317 uint32_t acc_buffer_bytes =
1318 MPerBlock * NPerBlock * get_total_acc_buffers() * acc_element_bytes;
1319 return (acc_buffer_bytes + alignment - 1) / alignment * alignment;
1320 }
1321
1322 __host__ __device__ uint32_t get_workspace_size_for_semaphore() const
1323 {
1324 return get_sk_tiles() * sizeof(uint32_t);
1325 }
1326
1327 __host__ __device__ uint32_t get_workspace_size(uint32_t acc_element_bytes) const
1328 {
1330 }
1331
1332 __host__ __device__ uint32_t get_tile_intersections(uint32_t tiles_,
1333 const MDiv& eqav_tiles_) const
1334 {
1335 uint32_t tile_idx_ = tiles_ == 0 ? 0 : (tiles_ - 1);
1336 uint32_t max_eqav_tiles_ = eqav_tiles_.get() - 1;
1337 uint32_t quo_, rem_;
1338 eqav_tiles_.divmod(tile_idx_, quo_, rem_);
1339 return quo_ * max_eqav_tiles_ + rem_;
1340 }
1341
1342 __host__ __device__ uint32_t get_tiles_cover_sk_block(uint32_t num_sk_blocks_,
1343 uint32_t iters_per_sk_block_) const
1344 {
1345 return k_iters_per_tile.div(num_sk_blocks_ * iters_per_sk_block_ + k_iters_per_tile.get() -
1346 1);
1347 }
1348
1349 __host__ __device__ uint32_t get_total_acc_buffers() const
1350 {
1351 uint32_t tiles_cover_big_blocks =
1353 uint32_t tiles_cover_little_blocks =
1355
1356 uint32_t total_intersec_big =
1357 get_tile_intersections(tiles_cover_big_blocks, eqav_tiles_big);
1358 uint32_t total_intersec_little =
1359 get_tile_intersections(tiles_cover_little_blocks, eqav_tiles_little);
1360
1361 return sk_num_blocks + total_intersec_big + total_intersec_little;
1362 }
1363
1365 {
1366 // TODO: from big to little
1367 uint32_t tiles_cover_big_blocks =
1369 if(tile_idx_ < tiles_cover_big_blocks)
1370 {
1371 uint32_t touched_sk_blocks =
1372 (tile_idx_ * k_iters_per_tile.get() + k_iters_per_big_block - 1) /
1374 uint32_t current_intersec = get_tile_intersections(tile_idx_, eqav_tiles_big);
1375 return touched_sk_blocks + current_intersec;
1376 }
1377 else
1378 {
1379 uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1;
1380 uint32_t tile_idx_little_reverse = get_sk_tiles() - tile_idx_;
1381 uint32_t touched_sk_blocks =
1382 (tile_idx_little_reverse * k_iters_per_tile.get() + iters_per_little_sk_block - 1) /
1383 iters_per_little_sk_block;
1384 uint32_t current_intersec =
1385 get_tile_intersections(tile_idx_little_reverse, eqav_tiles_little);
1386 return get_total_acc_buffers() - (touched_sk_blocks + current_intersec);
1387 }
1388 }
1389
1391 {
1392 uint32_t iters_per_big_sk_block = k_iters_per_big_block;
1393 uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1;
1394 if(block_idx_ < sk_num_big_blocks)
1395 {
1396 uint32_t touched_tiles = k_iters_per_tile.div(block_idx_ * iters_per_big_sk_block +
1397 k_iters_per_tile.get() - 1);
1398 uint32_t current_intersec = get_tile_intersections(touched_tiles, eqav_tiles_big);
1399 return block_idx_ + current_intersec;
1400 }
1401 else
1402 {
1403 uint32_t block_idx_little_reverse = sk_num_blocks - block_idx_;
1404 uint32_t touched_tiles = k_iters_per_tile.div(
1405 block_idx_little_reverse * iters_per_little_sk_block + k_iters_per_tile.get() - 1);
1406 uint32_t current_intersec = get_tile_intersections(touched_tiles, eqav_tiles_little);
1407 return get_total_acc_buffers() - (block_idx_little_reverse + current_intersec);
1408 }
1409 }
1410};
1411
1412template <uint32_t MPerBlock_,
1413 uint32_t NPerBlock_,
1414 uint32_t KPerBlock_,
1416 uint32_t TileSwizzleSubM_ = 8,
1417 index_t GroupNum = 8,
1418 index_t M01_ = 4>
1420{
1422 static constexpr uint32_t MPerBlock = MPerBlock_;
1423 static constexpr uint32_t NPerBlock = NPerBlock_;
1424 static constexpr uint32_t KPerBlock = KPerBlock_;
1425 static constexpr uint32_t tile_swizzle_sub_m = TileSwizzleSubM_;
1426
1427 //--------------------------------------
1428 // pass to device
1436 MDiv equiv_tiles_big; // for reduction
1437 MDiv equiv_tiles_little; // for reduction
1439
1440 // prefer construct on host
1442 uint32_t m,
1443 uint32_t n,
1444 uint32_t k,
1445 uint32_t grid_size = 1,
1446 uint32_t streamk_sel = 1,
1448 : reduction_strategy(reduction_strategy_)
1449 {
1450
1451 // total output tiles
1452 uint32_t num_tiles =
1455
1456 uint32_t dp_tiles, dp_num_blocks, sk_total_iters;
1457
1458 // Ensure grid_size is at least 1 to avoid division by zero
1459 grid_size = math::max(grid_size, 1u);
1460
1461 // default to regular DP GEMM if sk blocks == 0
1462 if(streamk_sel == 0)
1463 {
1464 sk_num_blocks = 0;
1465 dp_tiles = num_tiles;
1468
1469 dp_num_blocks = num_tiles; // all tile to be dp block
1471 sk_total_iters = 0; // clear this tiles
1472 }
1473 // 2-tile sk + DP GEMM
1474 else
1475 {
1476 // check if there's enough work for DP+ stream-k
1477 bool bigEnough = num_tiles > grid_size;
1478
1479 // Select between stream-k strategies
1480 // Add safety checks to prevent zero or negative values
1481 uint32_t sk_tiles = 0;
1482 if(streamk_sel == 1) // 1 tile stream-k
1483 {
1484 sk_tiles = bigEnough ? (num_tiles % grid_size) : num_tiles;
1485
1486 // Ensure sk_tiles is at least 1
1487 sk_tiles = math::max(sk_tiles, 1u);
1488 }
1489 else if(streamk_sel == 2) // 2-tile stream-k
1490 {
1491 sk_tiles = bigEnough ? (grid_size + num_tiles % grid_size) : num_tiles;
1492
1493 // Ensure sk_tiles is at least 1 but not more than num_tiles
1494 sk_tiles = math::min(math::max(sk_tiles, 1u), num_tiles);
1495 }
1496 else if(streamk_sel == 3) // 3-tile stream-k
1497 {
1498 sk_tiles = (num_tiles > (2 * grid_size)) ? (2 * grid_size + num_tiles % grid_size)
1499 : num_tiles;
1500
1501 // Ensure sk_tiles is at least 1 but not more than num_tiles
1502 sk_tiles = math::min(math::max(sk_tiles, 1u), num_tiles);
1503 }
1504 else if(streamk_sel == 4) // 4-tile stream-k
1505 {
1506 sk_tiles = (num_tiles > (3 * grid_size)) ? (3 * grid_size + num_tiles % grid_size)
1507 : num_tiles;
1508
1509 // Ensure sk_tiles is at least 1 but not more than num_tiles
1510 sk_tiles = math::min(math::max(sk_tiles, 1u), num_tiles);
1511 }
1512
1513 sk_num_blocks = sk_tiles;
1514 // Remaining tiles are DP tiles
1515 dp_tiles = bigEnough ? (num_tiles - sk_tiles) : 0;
1516
1517 sk_total_iters = k_iters_per_tile.get() * sk_tiles;
1518
1519 // k_iters_per_sk_block is the floor of avg each ck block loop over tiles.
1520 // we need to decide how many iters for each sk block
1521 // let m = k_iters_per_sk_block
1522 // some of the sk block (little) will cover m iters, some (big) will cover m+1
1523 // we have
1524 // 1) l + b = sk_blocks
1525 // 2) l * m + b * (m + 1) = sk_total_iters
1526 // => (l + b) * m + b = sk_total_iters
1527 // => sk_blocks * m + b = sk_total_iters
1528 // => b = sk_total_iters - m * sk_blocks
1529 // NOTE: big could be zero
1530
1531 // Add safety check for sk_num_blocks to prevent division by zero
1532 if(sk_num_blocks > 0)
1533 {
1534 uint32_t k_iters_per_sk_block = sk_total_iters / sk_num_blocks;
1535 sk_num_big_blocks = sk_total_iters - k_iters_per_sk_block * sk_num_blocks;
1536 k_iters_per_big_block = k_iters_per_sk_block + 1;
1537 }
1538 else
1539 {
1540 // Fallback to default GEMM if no stream-k blocks
1541 sk_num_blocks = 0;
1544 dp_tiles = num_tiles;
1545 dp_num_blocks = num_tiles;
1547 sk_total_iters = 0;
1548 }
1549
1550 dp_num_blocks = dp_tiles;
1552 }
1553
1555 // Using multiple blocks for parallel reduction
1557
1559 {
1560 // Add additional safety checks
1561 if(k_iters_per_big_block > 0 && k_iters_per_tile.get() > 0)
1562 {
1564 uint32_t upper_little =
1566 equiv_tiles_big = MDiv(upper_big / k_iters_per_tile.get());
1567 equiv_tiles_little = MDiv(upper_little / k_iters_per_tile.get());
1568 }
1569 else
1570 {
1571 // Default safe values
1572 equiv_tiles_big = MDiv(1);
1574 }
1575 }
1576 }
1577
1578 __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
1579 {
1580 const auto M0 = math::integer_divide_ceil(M, MPerBlock);
1581 const auto N0 = math::integer_divide_ceil(N, NPerBlock);
1582
1583 return M0 * N0;
1584 }
1585 __host__ __device__ uint32_t get_sk_total_iters() const
1586 {
1589 return sk_total_iters;
1590 }
1591
1592 __host__ __device__ uint32_t get_sk_tiles() const
1593 {
1594 // tiles for sk
1595 uint32_t sk_total_iters = get_sk_total_iters();
1596 return k_iters_per_tile.div(sk_total_iters);
1597 }
1598
1599 __host__ __device__ index_t get_grid_dims() const
1600 {
1602 {
1603 // return dim3(reduction_start_block_idx + get_sk_tiles(), 1, 1);
1605 }
1606 else
1608 }
1609
1610 __device__ uint32_t get_block_idx() const
1611 {
1612 // TODO: swizzle block index for better locality
1613 return __builtin_amdgcn_readfirstlane(blockIdx.x);
1614 }
1615
1616 __device__ void
1617 get_block_itr(uint32_t block_idx, uint32_t& iter_start, uint32_t& iter_end) const
1618 {
1619 if(block_idx < sk_num_big_blocks)
1620 {
1621 iter_start = block_idx * k_iters_per_big_block;
1622 iter_end = iter_start + k_iters_per_big_block;
1623 }
1624 else if(block_idx < sk_num_blocks)
1625 {
1626 iter_start = (sk_num_big_blocks * k_iters_per_big_block) +
1627 (block_idx - sk_num_big_blocks) * (k_iters_per_big_block - 1);
1628 iter_end = iter_start + (k_iters_per_big_block - 1);
1629 }
1630 else if(block_idx >= dp_start_block_idx)
1631 {
1632 uint32_t sk_total_iters = get_sk_total_iters();
1633 uint32_t dp_iters_per_block = k_iters_per_tile.get();
1634 iter_start = sk_total_iters + (block_idx - dp_start_block_idx) * dp_iters_per_block;
1635 iter_end = iter_start + dp_iters_per_block;
1636 }
1637 }
1638
1640 uint32_t iter_end,
1641 uint32_t total_iter_length) const
1642 {
1643 uint32_t iter_length_mod, iter_length_quo /*unused*/;
1644 k_iters_per_tile.divmod(iter_end, iter_length_quo, iter_length_mod);
1645 uint32_t current_iter_length = math::min(
1646 iter_length_mod == 0 ? (iter_end - iter_start) : iter_length_mod, total_iter_length);
1647 return current_iter_length;
1648 }
1649
1650 __device__ uint32_t get_tile_idx(uint32_t iter) const { return k_iters_per_tile.div(iter); }
1651
1652 __device__ void
1653 get_tile_idx_with_offset(uint32_t iter, uint32_t& tile_idx, uint32_t& iter_offset) const
1654 {
1655 k_iters_per_tile.divmod(iter, tile_idx, iter_offset);
1656 }
1657
1658 __device__ auto tile_to_spatial(uint32_t tile_idx, uint32_t m, uint32_t n) const
1659 {
1660 uint32_t m_tile_idx, n_tile_idx;
1661 uint32_t n_tiles_value = math::integer_divide_ceil(n, NPerBlock);
1662 n_tiles.divmod(tile_idx, n_tiles_value, m_tile_idx, n_tile_idx);
1663
1664 // // swizzle tile
1666
1667 uint32_t tile_swizzle_sub_m_rem = m_tiles % tile_swizzle_sub_m;
1668
1669 const auto sub_m_adapt = (m_tile_idx < (m_tiles - tile_swizzle_sub_m_rem))
1671 : tile_swizzle_sub_m_rem;
1672
1673 uint32_t m_tile_idx_sub0, m_tile_idx_sub1;
1674 m_tile_idx_sub0 = m_tile_idx / tile_swizzle_sub_m;
1675 m_tile_idx_sub1 = m_tile_idx % tile_swizzle_sub_m;
1676
1677 uint32_t tile_idx_local = n_tile_idx + m_tile_idx_sub1 * n_tiles_value;
1678
1679 uint32_t m_tile_idx_with_adapt, n_tile_idx_with_adapt;
1680
1681 n_tile_idx_with_adapt = tile_idx_local / sub_m_adapt;
1682 m_tile_idx_with_adapt = tile_idx_local % sub_m_adapt;
1683 return make_tuple(m_tile_idx_with_adapt + m_tile_idx_sub0 * tile_swizzle_sub_m,
1684 n_tile_idx_with_adapt);
1685 }
1686
1687 __host__ __device__ uint32_t get_workspace_size_for_acc(uint32_t acc_element_bytes) const
1688 {
1689 static constexpr uint32_t alignment = 128;
1690 uint32_t acc_buffer_bytes =
1691 MPerBlock * NPerBlock * get_total_acc_buffers() * acc_element_bytes;
1692 return (acc_buffer_bytes + alignment - 1) / alignment * alignment;
1693 }
1694
1695 __host__ __device__ uint32_t get_workspace_size_for_semaphore() const
1696 {
1697 return get_sk_tiles() * sizeof(uint32_t);
1698 }
1699
1700 __host__ __device__ uint32_t get_workspace_size(uint32_t acc_element_bytes) const
1701 {
1703 }
1704
1705 __host__ __device__ uint32_t get_tile_intersections(uint32_t tiles_,
1706 const MDiv& equiv_tiles_) const
1707 {
1708 uint32_t tile_idx_ = tiles_ == 0 ? 0 : (tiles_ - 1);
1709 uint32_t max_equiv_tiles_ = equiv_tiles_.get() - 1;
1710 uint32_t quo_, rem_;
1711 equiv_tiles_.divmod(tile_idx_, quo_, rem_);
1712 return quo_ * max_equiv_tiles_ + rem_;
1713 }
1714
1715 __host__ __device__ uint32_t get_tiles_cover_sk_block(uint32_t num_sk_blocks_,
1716 uint32_t iters_per_sk_block_) const
1717 {
1718 return k_iters_per_tile.div(num_sk_blocks_ * iters_per_sk_block_ + k_iters_per_tile.get() -
1719 1);
1720 }
1721
1722 __host__ __device__ uint32_t get_total_acc_buffers() const
1723 {
1724 uint32_t tiles_cover_big_blocks =
1726 uint32_t tiles_cover_little_blocks =
1728
1729 uint32_t total_intersec_big =
1730 get_tile_intersections(tiles_cover_big_blocks, equiv_tiles_big);
1731 uint32_t total_intersec_little =
1732 get_tile_intersections(tiles_cover_little_blocks, equiv_tiles_little);
1733
1734 return sk_num_blocks + total_intersec_big + total_intersec_little;
1735 }
1736
1738 {
1739 // TODO: from big to little
1740 uint32_t tiles_cover_big_blocks =
1742 if(tile_idx_ < tiles_cover_big_blocks)
1743 {
1744 uint32_t touched_sk_blocks =
1745 (tile_idx_ * k_iters_per_tile.get() + k_iters_per_big_block - 1) /
1747 uint32_t current_intersec = get_tile_intersections(tile_idx_, equiv_tiles_big);
1748 return touched_sk_blocks + current_intersec;
1749 }
1750 else
1751 {
1752 uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1;
1753 uint32_t tile_idx_little_reverse = get_sk_tiles() - tile_idx_;
1754 uint32_t touched_sk_blocks =
1755 (tile_idx_little_reverse * k_iters_per_tile.get() + iters_per_little_sk_block - 1) /
1756 iters_per_little_sk_block;
1757 uint32_t current_intersec =
1758 get_tile_intersections(tile_idx_little_reverse, equiv_tiles_little);
1759 return get_total_acc_buffers() - (touched_sk_blocks + current_intersec);
1760 }
1761 }
1762
1764 {
1765 uint32_t iters_per_big_sk_block = k_iters_per_big_block;
1766 uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1;
1767 if(block_idx_ < sk_num_big_blocks)
1768 {
1769 uint32_t touched_tiles = k_iters_per_tile.div(block_idx_ * iters_per_big_sk_block +
1770 k_iters_per_tile.get() - 1);
1771 uint32_t current_intersec = get_tile_intersections(touched_tiles, equiv_tiles_big);
1772 return block_idx_ + current_intersec;
1773 }
1774 else
1775 {
1776 uint32_t block_idx_little_reverse = sk_num_blocks - block_idx_;
1777 uint32_t touched_tiles = k_iters_per_tile.div(
1778 block_idx_little_reverse * iters_per_little_sk_block + k_iters_per_tile.get() - 1);
1779 uint32_t current_intersec = get_tile_intersections(touched_tiles, equiv_tiles_little);
1780 return get_total_acc_buffers() - (block_idx_little_reverse + current_intersec);
1781 }
1782 }
1783};
1784
1785} // namespace ck
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
__host__ __device__ constexpr auto lcm(X x, Y y)
Definition utility/math.hpp:198
__host__ __device__ constexpr T min(T x)
Definition utility/math.hpp:116
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
StreamKReductionStrategy
Definition block_to_ctile_map.hpp:1011
@ Atomic
Definition block_to_ctile_map.hpp:1012
@ Reduction
Definition block_to_ctile_map.hpp:1013
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0 &adaptor0, const TensorAdaptor1 &adaptor1)
Definition tensor_description/tensor_adaptor.hpp:245
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_insert_transform(const UpperIndex &up_idx)
Definition multi_index_transform_helper.hpp:157
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_thread_global_1d_id()
Definition get_id.hpp:43
__host__ __device__ bool DefaultValidCTileIndex(const CTileIdx &c_tile_idx, const CTileDim &c_tile_dim)
Definition block_to_ctile_map.hpp:835
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
unsigned int uint32_t
Definition stdint.h:126
__host__ __device__ constexpr auto CalculateGridSize(index_t M, index_t N, index_t k_split) const
Definition block_to_ctile_map.hpp:982
__device__ constexpr auto CalculateBottomIndex(const TopIdx &) const
Definition block_to_ctile_map.hpp:991
__host__ __device__ bool ValidCTileIndex(const CTileIdx &, const CTileDim &) const
Definition block_to_ctile_map.hpp:997
__host__ constexpr bool CheckValidity(const CGridDesc_M_N &) const
Definition block_to_ctile_map.hpp:1004
__host__ __device__ BlockToCTileMap_3DGrid_KSplit()=default
__host__ __device__ uint32_t get_sk_tiles() const
Definition block_to_ctile_map.hpp:1592
__host__ __device__ uint32_t get_workspace_size(uint32_t acc_element_bytes) const
Definition block_to_ctile_map.hpp:1700
__host__ __device__ uint32_t get_tile_intersections(uint32_t tiles_, const MDiv &equiv_tiles_) const
Definition block_to_ctile_map.hpp:1705
__device__ uint32_t get_acc_buffer_offset_from_block(uint32_t block_idx_) const
Definition block_to_ctile_map.hpp:1763
__host__ __device__ uint32_t get_sk_total_iters() const
Definition block_to_ctile_map.hpp:1585
__host__ __device__ uint32_t get_total_acc_buffers() const
Definition block_to_ctile_map.hpp:1722
__host__ __device__ index_t get_grid_dims() const
Definition block_to_ctile_map.hpp:1599
__device__ uint32_t get_tile_idx(uint32_t iter) const
Definition block_to_ctile_map.hpp:1650
__host__ __device__ uint32_t get_workspace_size_for_semaphore() const
Definition block_to_ctile_map.hpp:1695
__device__ void get_block_itr(uint32_t block_idx, uint32_t &iter_start, uint32_t &iter_end) const
Definition block_to_ctile_map.hpp:1617
__device__ uint32_t get_acc_buffer_offset_from_tile(uint32_t tile_idx_) const
Definition block_to_ctile_map.hpp:1737
__device__ auto tile_to_spatial(uint32_t tile_idx, uint32_t m, uint32_t n) const
Definition block_to_ctile_map.hpp:1658
__host__ __device__ uint32_t get_workspace_size_for_acc(uint32_t acc_element_bytes) const
Definition block_to_ctile_map.hpp:1687
__device__ uint32_t get_current_iter_length(uint32_t iter_start, uint32_t iter_end, uint32_t total_iter_length) const
Definition block_to_ctile_map.hpp:1639
__host__ __device__ uint32_t get_tiles_cover_sk_block(uint32_t num_sk_blocks_, uint32_t iters_per_sk_block_) const
Definition block_to_ctile_map.hpp:1715
__host__ __device__ BlockToCTileMap_GemmStreamK_v2(uint32_t m, uint32_t n, uint32_t k, uint32_t grid_size=1, uint32_t streamk_sel=1, StreamKReductionStrategy reduction_strategy_=StreamKReductionStrategy::Atomic)
Definition block_to_ctile_map.hpp:1441
__device__ uint32_t get_block_idx() const
Definition block_to_ctile_map.hpp:1610
__device__ void get_tile_idx_with_offset(uint32_t iter, uint32_t &tile_idx, uint32_t &iter_offset) const
Definition block_to_ctile_map.hpp:1653
__host__ static __device__ constexpr index_t CalculateGridSize(index_t M, index_t N)
Definition block_to_ctile_map.hpp:1578
uint32_t k_iters_per_big_block
Definition block_to_ctile_map.hpp:1036
__host__ __device__ uint32_t get_workspace_size(uint32_t acc_element_bytes) const
Definition block_to_ctile_map.hpp:1327
__device__ uint32_t get_acc_buffer_offset_from_block(uint32_t block_idx_) const
Definition block_to_ctile_map.hpp:1390
__host__ __device__ uint32_t get_sk_total_iters() const
Definition block_to_ctile_map.hpp:1213
__host__ __device__ uint32_t get_tiles_cover_sk_block(uint32_t num_sk_blocks_, uint32_t iters_per_sk_block_) const
Definition block_to_ctile_map.hpp:1342
static constexpr uint32_t MPerBlock
Definition block_to_ctile_map.hpp:1024
uint32_t dp_start_block_idx
Definition block_to_ctile_map.hpp:1034
__host__ __device__ uint32_t get_sk_tiles() const
Definition block_to_ctile_map.hpp:1220
static constexpr uint32_t KPerBlock
Definition block_to_ctile_map.hpp:1026
__host__ __device__ uint32_t get_total_acc_buffers() const
Definition block_to_ctile_map.hpp:1349
__device__ uint32_t get_current_iter_length(uint32_t iter_start, uint32_t iter_end, uint32_t total_iter_length) const
Definition block_to_ctile_map.hpp:1266
static constexpr uint32_t NPerBlock
Definition block_to_ctile_map.hpp:1025
__device__ uint32_t get_acc_buffer_offset_from_tile(uint32_t tile_idx_) const
Definition block_to_ctile_map.hpp:1364
uint32_t reduction_start_block_idx
Definition block_to_ctile_map.hpp:1035
__host__ __device__ uint32_t get_workspace_size_for_acc(uint32_t acc_element_bytes) const
Definition block_to_ctile_map.hpp:1314
MDiv k_iters_per_tile
Definition block_to_ctile_map.hpp:1038
__device__ void get_tile_idx_with_offset(uint32_t iter, uint32_t &tile_idx, uint32_t &iter_offset) const
Definition block_to_ctile_map.hpp:1280
static constexpr uint32_t tile_swizzle_sub_m
Definition block_to_ctile_map.hpp:1028
BlockToCTileMap_GemmStreamK(uint32_t m, uint32_t n, uint32_t k, uint32_t num_cu, uint32_t occupancy, uint32_t sk_blocks=0xffffffff)
Definition block_to_ctile_map.hpp:1046
static constexpr StreamKReductionStrategy ReductionStrategy
Definition block_to_ctile_map.hpp:1027
__device__ auto tile_to_spatial(uint32_t tile_idx, uint32_t m, uint32_t n) const
Definition block_to_ctile_map.hpp:1285
__device__ uint32_t get_tile_idx(uint32_t iter) const
Definition block_to_ctile_map.hpp:1277
__host__ __device__ uint32_t get_tile_intersections(uint32_t tiles_, const MDiv &eqav_tiles_) const
Definition block_to_ctile_map.hpp:1332
__device__ uint32_t get_block_idx() const
Definition block_to_ctile_map.hpp:1237
__device__ void get_block_itr(uint32_t block_idx, uint32_t &iter_start, uint32_t &iter_end) const
Definition block_to_ctile_map.hpp:1244
MDiv eqav_tiles_little
Definition block_to_ctile_map.hpp:1040
uint32_t sk_num_blocks
Definition block_to_ctile_map.hpp:1032
MDiv2 n_tiles
Definition block_to_ctile_map.hpp:1037
MDiv eqav_tiles_big
Definition block_to_ctile_map.hpp:1039
static constexpr uint32_t min_k_iters_per_sk_block
Definition block_to_ctile_map.hpp:1023
uint32_t sk_num_big_blocks
Definition block_to_ctile_map.hpp:1033
__host__ __device__ dim3 get_grid_dims() const
Definition block_to_ctile_map.hpp:1227
__host__ __device__ uint32_t get_workspace_size_for_semaphore() const
Definition block_to_ctile_map.hpp:1322
__host__ static __device__ constexpr index_t CalculateGridSize(index_t M, index_t N)
Definition block_to_ctile_map.hpp:283
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx &idx_top) const
Definition block_to_ctile_map.hpp:298
__host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt()=default
static constexpr auto I1
Definition block_to_ctile_map.hpp:273
__host__ bool CheckValidity(const CGridDesc_M_N &) const
Definition block_to_ctile_map.hpp:292
static constexpr auto I0
Definition block_to_ctile_map.hpp:272
__host__ __device__ bool ValidCTileIndex(const CTileIdx &, const CTileDim &) const
Definition block_to_ctile_map.hpp:384
__host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt(index_t M, index_t N, index_t M01=8)
Definition block_to_ctile_map.hpp:276
__host__ constexpr bool CheckValidity(const CGridDesc_M_N &c_grid_desc_m_n) const
Definition block_to_ctile_map.hpp:773
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx &idx_top) const
Definition block_to_ctile_map.hpp:755
static constexpr auto I2
Definition block_to_ctile_map.hpp:723
static constexpr auto I0
Definition block_to_ctile_map.hpp:721
__host__ BlockToCTileMap_KSplit_M00_N00_M01_N01(const CGridDesc_M_N &c_grid_desc_m_n, index_t M01=1, index_t N01=1, index_t KSplit=1)
Definition block_to_ctile_map.hpp:728
__host__ BlockToCTileMap_KSplit_M00_N00_M01_N01()=default
static constexpr auto I3
Definition block_to_ctile_map.hpp:724
static constexpr auto I1
Definition block_to_ctile_map.hpp:722
__host__ __device__ bool ValidCTileIndex(const CTileIdx &c_tile_idx, const CTileDim &c_tile_dim) const
Definition block_to_ctile_map.hpp:764
__host__ __device__ constexpr index_t CalculateGridSize(const CGridDesc_M_N &c_grid_desc_m_n) const
Definition block_to_ctile_map.hpp:741
__host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N &c_grid_desc_m_n) const
Definition block_to_ctile_map.hpp:556
__host__ __device__ BlockToCTileMap_KSplit_M00_N0_M01Adapt(const CGridDesc_M_N &c_grid_desc_m_n, index_t M01=8, index_t KSplit=1)
Definition block_to_ctile_map.hpp:549
__host__ __device__ BlockToCTileMap_KSplit_M00_N0_M01Adapt()=default
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx &idx_top) const
Definition block_to_ctile_map.hpp:567
static constexpr auto I0
Definition block_to_ctile_map.hpp:542
static constexpr auto I1
Definition block_to_ctile_map.hpp:543
__host__ constexpr bool CheckValidity(const CGridDesc_M_N &) const
Definition block_to_ctile_map.hpp:600
__host__ __device__ bool ValidCTileIndex(const CTileIdx &, const CTileDim &) const
Definition block_to_ctile_map.hpp:594
static constexpr auto I2
Definition block_to_ctile_map.hpp:544
static constexpr auto I3
Definition block_to_ctile_map.hpp:545
__host__ __device__ BlockToCTileMap_M00_N00_M01_N01()=default
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx &idx_top) const
Definition block_to_ctile_map.hpp:646
__host__ __device__ BlockToCTileMap_M00_N00_M01_N01(const CGridDesc_M_N &c_grid_desc_m_n, index_t M01=1, index_t N01=1)
Definition block_to_ctile_map.hpp:625
__host__ constexpr bool CheckValidity(const CGridDesc_M_N &c_grid_desc_m_n) const
Definition block_to_ctile_map.hpp:661
static constexpr auto I0
Definition block_to_ctile_map.hpp:618
static constexpr auto I3
Definition block_to_ctile_map.hpp:621
__host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N &c_grid_desc_m_n) const
Definition block_to_ctile_map.hpp:632
static constexpr auto I1
Definition block_to_ctile_map.hpp:619
__host__ __device__ bool ValidCTileIndex(const CTileIdx &c_tile_idx, const CTileDim &c_tile_dim) const
Definition block_to_ctile_map.hpp:652
static constexpr auto I2
Definition block_to_ctile_map.hpp:620
__host__ __device__ constexpr bool ValidCTileIndex(const CTileIdx &, const CTileDim &) const
Definition block_to_ctile_map.hpp:246
__host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(index_t M, index_t N, index_t M01=8)
Definition block_to_ctile_map.hpp:138
__host__ static __device__ constexpr index_t CalculateGridSize(index_t M, index_t N)
Definition block_to_ctile_map.hpp:157
__host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt & operator=(const BlockToCTileMap_M00_N0_M01Adapt &)=default
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx &idx_top) const
Definition block_to_ctile_map.hpp:179
__host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(BlockToCTileMap_M00_N0_M01Adapt &&)=default
static constexpr auto I0
Definition block_to_ctile_map.hpp:123
static __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N &c_grid_desc_m_n)
Definition block_to_ctile_map.hpp:166
__host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt & operator=(BlockToCTileMap_M00_N0_M01Adapt &&)=default
__host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N &c_grid_desc_m_n, index_t M01=8)
Definition block_to_ctile_map.hpp:150
__host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt()=default
__host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(const BlockToCTileMap_M00_N0_M01Adapt &)=default
__host__ __device__ constexpr bool CheckValidity(const CGridDesc_M_N &) const
Definition block_to_ctile_map.hpp:173
static constexpr auto I1
Definition block_to_ctile_map.hpp:124
Definition block_to_ctile_map.hpp:261
static constexpr auto I3
Definition block_to_ctile_map.hpp:28
__host__ __device__ constexpr BlockToCTileMap_M00_N0_M01()=default
__host__ constexpr bool CheckValidity(const CGridDesc_M_N &c_grid_desc_m_n) const
Definition block_to_ctile_map.hpp:66
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx &idx_top) const
Definition block_to_ctile_map.hpp:51
__host__ __device__ constexpr BlockToCTileMap_M00_N0_M01(const CGridDesc_M_N &c_grid_desc_m_n, index_t M01=1)
Definition block_to_ctile_map.hpp:32
static constexpr auto I2
Definition block_to_ctile_map.hpp:27
static constexpr auto I0
Definition block_to_ctile_map.hpp:25
__host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N &c_grid_desc_m_n) const
Definition block_to_ctile_map.hpp:38
static constexpr auto I1
Definition block_to_ctile_map.hpp:26
__host__ __device__ constexpr bool ValidCTileIndex(const CTileIdx &c_tile_idx, const CTileDim &c_tile_dim) const
Definition block_to_ctile_map.hpp:57
__host__ bool CheckValidity(const CGridDesc_M_N &) const
Definition block_to_ctile_map.hpp:451
static __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N &c_grid_desc_m_n)
Definition block_to_ctile_map.hpp:445
__host__ __device__ BlockToCTileMap_N00_M0_N01Adapt & operator=(BlockToCTileMap_N00_M0_N01Adapt &&)=default
__host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(const CGridDesc_M_N &c_grid_desc_m_n, index_t N01=8)
Definition block_to_ctile_map.hpp:429
__host__ __device__ BlockToCTileMap_N00_M0_N01Adapt & operator=(const BlockToCTileMap_N00_M0_N01Adapt &)=default
__host__ __device__ bool ValidCTileIndex(const CTileIdx &, const CTileDim &) const
Definition block_to_ctile_map.hpp:525
__host__ static __device__ constexpr index_t CalculateGridSize(index_t M, index_t N)
Definition block_to_ctile_map.hpp:436
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx &idx_top) const
Definition block_to_ctile_map.hpp:457
static constexpr auto I1
Definition block_to_ctile_map.hpp:405
__host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(const BlockToCTileMap_N00_M0_N01Adapt &)=default
__host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(BlockToCTileMap_N00_M0_N01Adapt &&)=default
__host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(index_t M, index_t N, index_t N01=8)
Definition block_to_ctile_map.hpp:418
static constexpr auto I0
Definition block_to_ctile_map.hpp:404
Definition block_to_ctile_map.hpp:399
Definition magic_division.hpp:204
Definition magic_division.hpp:162
__host__ __device__ uint32_t get() const
Definition magic_division.hpp:200
__host__ __device__ void divmod(uint32_t dividend_, uint32_t &quotient_, uint32_t &remainder_) const
Definition magic_division.hpp:194
__host__ static __device__ constexpr T Max()
Definition numeric_limits.hpp:311
__host__ __device__ bool ValidCTileIndex(const CTileIdx &c_tile_idx, const CTileDim &c_tile_dim) const
Definition block_to_ctile_map.hpp:940
index_t tile_offset_
Definition block_to_ctile_map.hpp:960
Block2ETileMap block_to_ctile_map_
Definition block_to_ctile_map.hpp:958
__host__ __device__ OffsettedBlockToCTileMap2(UnderlyingBlockToCTileMap block_to_ctile_map, index_t group_offset, index_t tile_offset)
Definition block_to_ctile_map.hpp:923
Block2ETileMap underlying_type
Definition block_to_ctile_map.hpp:921
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx &idx_top) const
Definition block_to_ctile_map.hpp:933
index_t group_offset_
Definition block_to_ctile_map.hpp:959
__device__ void UpdateTileOffset(index_t offset)
Definition block_to_ctile_map.hpp:957
__host__ constexpr bool CheckValidity(const CGridDesc_M_N &c_grid_desc_m_n) const
Definition block_to_ctile_map.hpp:947
__host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const
Definition block_to_ctile_map.hpp:952
__host__ __device__ bool ValidCTileIndex(const CTileIdx &c_tile_idx, const CTileDim &c_tile_dim) const
Definition block_to_ctile_map.hpp:891
__host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const
Definition block_to_ctile_map.hpp:909
index_t block_start_
Definition block_to_ctile_map.hpp:915
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx &idx_top) const
Definition block_to_ctile_map.hpp:884
__host__ __device__ OffsettedBlockToCTileMap()=default
__host__ constexpr bool CheckValidity(const CGridDesc_M_N &c_grid_desc_m_n) const
Definition block_to_ctile_map.hpp:898
__host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N &c_grid_desc_m_n) const
Definition block_to_ctile_map.hpp:904
__host__ __device__ OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start)
Definition block_to_ctile_map.hpp:876
Block2ETileMap underlying_type
Definition block_to_ctile_map.hpp:873
Block2ETileMap block_to_ctile_map_
Definition block_to_ctile_map.hpp:914