gemm_tile_partitioner.hpp Source File

gemm_tile_partitioner.hpp Source File#

Composable Kernel: gemm_tile_partitioner.hpp Source File
gemm_tile_partitioner.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
8
9#pragma once
10
11#include "ck_tile/core.hpp"
13
14namespace ck_tile {
15
20template <typename BlockGemmShapeType>
22{
24
25 static constexpr index_t MPerBlock = BlockGemmShape::kM;
26 static constexpr index_t NPerBlock = BlockGemmShape::kN;
27 static constexpr index_t KPerBlock = BlockGemmShape::kK;
28
31 [[maybe_unused]] index_t N) noexcept;
32
40 CK_TILE_HOST static auto
41 GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> dim3
42 {
43 const index_t GridDimX = (M + MPerBlock - 1) / MPerBlock;
44 const index_t GridDimY = (N + NPerBlock - 1) / NPerBlock;
45 return dim3(GridDimX, GridDimY, 1);
46 }
47
54 CK_TILE_HOST_DEVICE static auto GetLoopNum(index_t K) noexcept -> index_t
55 {
57 }
58
65
73 CK_TILE_DEVICE static auto
74 GetOutputTileIndex(index_t blockIdx, index_t blockIdy) noexcept -> const tuple<index_t, index_t>
75 {
76 const index_t iM = amd_wave_read_first_lane(blockIdx);
77 const index_t iN = amd_wave_read_first_lane(blockIdy);
78 return make_tuple(iM, iN);
79 }
80};
81
87template <typename BlockGemmShape_>
89{
91
92 static constexpr index_t MPerBlock = BlockGemmShape::kM;
93 static constexpr index_t NPerBlock = BlockGemmShape::kN;
94 static constexpr index_t KPerBlock = BlockGemmShape::kK;
95
97
105 {
106 N_ = N;
107 }
108
116 CK_TILE_HOST_DEVICE static auto
117 GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> index_t
118 {
119 const index_t GridDimX = (M + MPerBlock - 1) / MPerBlock;
120 const index_t GridDimY = (N + NPerBlock - 1) / NPerBlock;
121 return GridDimX * GridDimY;
122 }
123
131 {
133 }
134
141 CK_TILE_DEVICE static auto
143 {
144 const index_t NBlocks = integer_divide_ceil(N_, NPerBlock);
145
146 const index_t iM = amd_wave_read_first_lane(blockIdx / NBlocks);
147 const index_t iN = amd_wave_read_first_lane(blockIdx - iM * NBlocks);
148 return make_tuple(iM, iN);
149 }
150
151 private:
152 CK_TILE_DEVICE static index_t N_;
153};
154
/// @brief Detects whether `Partitioner::GetOutputTileIndex` can be called with a single
///        (integer) argument.
///
/// Primary template: selected when the detection expression in the specialization below is
/// ill-formed; the trait then derives from std::false_type.
template <typename Partitioner, typename Enable = void>
struct HasFnOneArgImpl : std::false_type
{
};

/// Partial specialization chosen through the std::void_t detection idiom whenever
/// `std::declval<Partitioner>().GetOutputTileIndex(1)` is a well-formed expression; the trait
/// then derives from std::true_type.
template <typename Partitioner>
struct HasFnOneArgImpl<Partitioner,
                       std::void_t<decltype(std::declval<Partitioner>().GetOutputTileIndex(1))>>
    : std::true_type
{
};
174
181template <typename TilePartitioner,
182 typename = typename std::enable_if_t<HasFnOneArgImpl<TilePartitioner>{}>>
184{
192 [[nodiscard]] CK_TILE_DEVICE static auto GetOffsetedTileIndex(
193 index_t block_start, index_t M, index_t N) noexcept -> const tuple<index_t, index_t>
194 {
195 const auto [iM, iN] = TilePartitioner{M, N}.GetOutputTileIndex(blockIdx.x - block_start);
196 return make_tuple(iM, iN);
197 }
198
207 [[nodiscard]] CK_TILE_DEVICE static auto
208 GetOffsetedTileIndex(index_t block_start, index_t M, index_t N, index_t block_idx) noexcept
210 {
211 const auto [iM, iN] = TilePartitioner{M, N}.GetOutputTileIndex(block_idx - block_start);
212 return make_tuple(iM, iN);
213 }
214};
215
227template <typename BlockGemmShapeType, index_t GroupNum, index_t M01>
229{
231
232 static constexpr index_t MPerBlock = BlockGemmShape::kM;
233 static constexpr index_t NPerBlock = BlockGemmShape::kN;
234 static constexpr index_t KPerBlock = BlockGemmShape::kK;
235
238 : M(M_), N(N_)
239 {
240 }
241
249 CK_TILE_HOST_DEVICE static auto
250 GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> index_t
251 {
252 const index_t GridDimX = integer_divide_ceil(M, MPerBlock);
253 const index_t GridDimY = integer_divide_ceil(N, NPerBlock);
254 return GridDimX * GridDimY;
255 }
256
264 {
266 }
267
274 CK_TILE_DEVICE auto
276 {
277 const auto M0 = integer_divide_ceil(M, MPerBlock);
278 const auto N0 = integer_divide_ceil(N, NPerBlock);
279
280 if(M0 == 1)
281 {
282 return make_tuple(0, block_1d_id);
283 }
284 else if(N0 == 1)
285 {
286 return make_tuple(block_1d_id, 0);
287 }
288 // block_1d_id = block_1d_id % (M0 * N0); // swallow batch index
289 else
290 {
291 const auto group_size = integer_divide_ceil(M0 * N0, GroupNum);
292 const auto big_group_num = GroupNum - (group_size * GroupNum - M0 * N0);
293 const auto group_id_y = block_1d_id / GroupNum;
294 const auto group_id_x = block_1d_id - group_id_y * GroupNum;
295 const auto remap_block_1d_id =
296 group_id_x <= big_group_num
297 ? group_id_x * group_size + group_id_y
298 : group_id_x * group_size + big_group_num - group_id_x + group_id_y;
299
300 const index_t idx_M0 = remap_block_1d_id / N0;
301 const index_t idx_N0 = remap_block_1d_id - idx_M0 * N0;
302
303 const index_t M0_tmp = M0 / M01;
304 const index_t M0_mod_M01 = M0 - M0_tmp * M01;
305
306 const auto M01_adapt = (idx_M0 < M0 - M0_mod_M01) ? M01 : M0_mod_M01;
307
308 const index_t idx_M00 = idx_M0 / M01;
309 const index_t idx_M01 = idx_M0 - idx_M00 * M01;
310 const index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
311
355
356 const index_t N_out = idx_N0_M01_local / M01_adapt;
357 const index_t idx_loc_mod_M01 = idx_N0_M01_local - N_out * M01_adapt;
358
359 return make_tuple(idx_loc_mod_M01 + idx_M00 * M01, N_out);
360 }
361 }
362
363 private:
364 index_t M;
365 index_t N;
366};
367
384template <typename BlockGemmShapeType,
386 uint32_t TileSwizzleSubM = 8>
388{
389 using BlockGemmShape = BlockGemmShapeType;
390
391 static constexpr uint32_t MPerBlock = BlockGemmShape::kM;
392 static constexpr uint32_t NPerBlock = BlockGemmShape::kN;
393 static constexpr uint32_t KPerBlock = BlockGemmShape::kK;
394
396
401 uint32_t N,
402 uint32_t K,
403 uint32_t num_cu,
404 uint32_t occupancy,
405 uint32_t sk_blocks = 0xffffffff) noexcept
406 : M_(M), N_(N), K_(K)
407 {
408 num_tile_m_ = integer_divide_ceil(M, MPerBlock);
409 num_tile_n_ = integer_divide_ceil(N, NPerBlock);
410 num_tile_k_ = integer_divide_ceil(K, KPerBlock);
411
412 constexpr uint32_t min_k_iters_per_sk_block = 2;
413 uint32_t num_tiles = num_tile_m_ * num_tile_n_;
414 k_iters_per_tile = mdiv(num_tile_k_);
415
416 // one cu can hold one wg at one time, from the whole cZ's point of view
417 // if number of wg is same as num_cu, we call it 1 dispatch
418 // if number of wg is 2x num_cu, we call it 2 dispatches.
419 // one dispatch can deliver wg same as num_cu (full dispatch), or less than num_cu (partial
420 // dispatch)
421 //
422 const uint32_t full_dispatches = num_tiles / num_cu;
423 const uint32_t full_dispatch_tiles = full_dispatches * num_cu;
424 const uint32_t partial_dispatch_tiles = num_tiles - full_dispatch_tiles;
425
426 uint32_t sk_occupancy = occupancy;
427 uint32_t dp_tiles = full_dispatch_tiles;
428 uint32_t sk_tiles = partial_dispatch_tiles;
429
430 if(full_dispatches < occupancy)
431 {
432 // in this case, we allocate all blocks as sk blocks
433 // sk_occupancy = occupancy - full_dispatches;
434 sk_occupancy = 1;
435 dp_tiles = full_dispatch_tiles;
436 sk_tiles = partial_dispatch_tiles;
437 }
438 else if((occupancy > 1) && (full_dispatches % occupancy == occupancy - 1))
439 {
440 // e.g. occupancy = 2, full_dispatches = 3, 5, 7 ...
441 // occupancy = 3, full_dispatches = 5, 8, 11 ...
442 // occupancy = 4, full_dispatches = 7, 11 ...
443 sk_occupancy = 1; // left 1 slot for sk occupancy
444 dp_tiles = full_dispatch_tiles;
445 sk_tiles = partial_dispatch_tiles;
446 }
447 else
448 {
449 // otherwise, we reduce 1 dispatch from dp, together with partial dispatch,
450 // to construct sk dispatch
451 sk_occupancy = occupancy - ((full_dispatches - 1) % occupancy);
452 dp_tiles = full_dispatch_tiles - num_cu;
453 sk_tiles = partial_dispatch_tiles + num_cu;
454 }
455
456 // uint32_t dp_iters_per_block = k_iters_per_tile.get();
457 uint32_t sk_total_iters = k_iters_per_tile.get() * sk_tiles;
458 uint32_t dp_num_blocks = 0;
459
460 {
461 const uint32_t min_sk_tiles = (sk_tiles >= num_cu) ? num_cu : (sk_tiles + 1);
462 const uint32_t max_sk_tiles =
463 (sk_tiles >= num_cu) ? num_cu * sk_occupancy
464 : min(num_cu, sk_total_iters / min_k_iters_per_sk_block);
465
466 // if use dp for sk-block, how many iters do we need
467 const uint32_t dp_for_sk_iters = k_iters_per_tile.get();
468
469 uint32_t best_sk_score =
470 std::numeric_limits<int>::max(); // we need to find the smallest sk iters
471 for(uint32_t tentative_sk_blocks = min_sk_tiles; tentative_sk_blocks < max_sk_tiles;
472 tentative_sk_blocks++)
473 {
474 const uint32_t tentative_sk_iters_per_block =
475 (sk_total_iters + tentative_sk_blocks - 1) / tentative_sk_blocks;
476 const uint32_t tentative_sk_iters = tentative_sk_iters_per_block;
477 const uint32_t sk_blocks_per_tile = (tentative_sk_blocks + sk_tiles - 1) / sk_tiles;
478
479 // the more sk_blocks_per_tile, the worse the overhead
480 uint32_t cross_sk_blocks_overhead = sk_blocks_per_tile;
481 if(tentative_sk_blocks % sk_tiles != 0)
482 {
483 // penalty for uneven divide
484 cross_sk_blocks_overhead +=
485 sk_blocks_per_tile * tentative_sk_iters_per_block / 50;
486 }
487
488 const uint32_t tentative_sk_score = tentative_sk_iters + cross_sk_blocks_overhead;
489
490 if(tentative_sk_score < best_sk_score)
491 {
492 best_sk_score = tentative_sk_score;
493 sk_num_blocks = tentative_sk_blocks;
494 }
495 }
496
497 if(best_sk_score >= dp_for_sk_iters)
498 {
499 sk_num_blocks = 0;
500 }
501
502 // give a chance to control num of sk blocks
503 sk_num_blocks = sk_blocks != 0xffffffff ? sk_blocks : sk_num_blocks;
504
505 if(sk_num_blocks == 0)
506 {
509
510 dp_num_blocks = num_tiles; // all tile to be dp block
512 sk_total_iters = 0; // clear this tiles
513 }
514 else
515 {
516 // k_iters_per_sk_block is the floor of avg each ck block loop over tiles.
517 // we need to decide how many iters for each sk block
518 // let m = k_iters_per_sk_block
519 // some of the sk block (little) will cover m iters, some (big) will cover m+1
520 // we have
521 // 1) l + b = sk_blocks
522 // 2) l * m + b * (m + 1) = sk_total_iters
523 // => (l + b) * m + b = sk_total_iters
524 // => sk_blocks * m + b = sk_total_iters
525 // => b = sk_total_iters - m * sk_blocks
526 // NOTE: big could be zero
527 const uint32_t k_iters_per_sk_block = sk_total_iters / sk_num_blocks;
528 sk_num_big_blocks = sk_total_iters - k_iters_per_sk_block * sk_num_blocks;
529 k_iters_per_big_block = k_iters_per_sk_block + 1;
530
531 dp_num_blocks = dp_tiles;
532 dp_start_block_idx = (sk_num_blocks + num_cu - 1) / num_cu * num_cu;
533 }
534 }
535 n_tiles = mdiv2(num_tile_n_);
537
538 if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Reduction)
539 {
540 const uint32_t upper_big = lcm(k_iters_per_big_block, k_iters_per_tile.get());
541 const uint32_t upper_little = lcm(k_iters_per_big_block - 1, k_iters_per_tile.get());
542 equiv_tiles_big = mdiv(upper_big / k_iters_per_tile.get());
543 equiv_tiles_little = mdiv(upper_little / k_iters_per_tile.get());
544 }
545 }
546
550 CK_TILE_HOST auto GridSize() const noexcept -> dim3
551 {
552 if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Reduction)
553 {
554 return dim3(reduction_start_block_idx + GetSkTiles(), 1, 1);
555 }
556 else
557 return dim3(reduction_start_block_idx, 1, 1);
558 }
559
564 {
565 return integer_divide_ceil(K, KPerBlock); // Stream-K processes one K-slice at a time
566 }
567
571 CK_TILE_DEVICE auto
573 {
574 uint32_t m_tile_idx, n_tile_idx;
575 n_tiles.divmod(tile_idx, num_tile_n_, m_tile_idx, n_tile_idx);
576
577 // swizzle tile
578
579 uint32_t tile_swizzle_sub_m_rem = num_tile_m_ % TileSwizzleSubM;
580
581 const auto sub_m_adapt = (m_tile_idx < (num_tile_m_ - tile_swizzle_sub_m_rem))
582 ? TileSwizzleSubM
583 : tile_swizzle_sub_m_rem;
584
585 uint32_t m_tile_idx_sub0, m_tile_idx_sub1;
586 m_tile_idx_sub0 = m_tile_idx / TileSwizzleSubM;
587 m_tile_idx_sub1 = m_tile_idx % TileSwizzleSubM;
588
589 uint32_t tile_idx_local = n_tile_idx + m_tile_idx_sub1 * num_tile_n_;
590
591 uint32_t m_tile_idx_with_adapt, n_tile_idx_with_adapt;
592
593 n_tile_idx_with_adapt = tile_idx_local / sub_m_adapt;
594 m_tile_idx_with_adapt = tile_idx_local % sub_m_adapt;
595 return make_tuple(m_tile_idx_with_adapt + m_tile_idx_sub0 * TileSwizzleSubM,
596 n_tile_idx_with_adapt);
597 }
598
602 CK_TILE_DEVICE void
603 GetBlockItr(uint32_t block_idx, uint32_t& iter_start, uint32_t& iter_end) const noexcept
604 {
605 if(block_idx < sk_num_big_blocks)
606 {
607 iter_start = block_idx * k_iters_per_big_block;
608 iter_end = iter_start + k_iters_per_big_block;
609 }
610 else if(block_idx < sk_num_blocks)
611 {
613 (block_idx - sk_num_big_blocks) * (k_iters_per_big_block - 1);
614 iter_end = iter_start + (k_iters_per_big_block - 1);
615 }
616 else if(block_idx >= dp_start_block_idx)
617 {
618 uint32_t sk_total_iters = GetSkTotalIters();
619 uint32_t dp_iters_per_block = k_iters_per_tile.get();
620 iter_start = sk_total_iters + (block_idx - dp_start_block_idx) * dp_iters_per_block;
621 iter_end = iter_start + dp_iters_per_block;
622 }
623 }
624
629 {
632 return sk_total_iters;
633 }
634
639 {
640 // tiles for sk
641 uint32_t sk_total_iters = GetSkTotalIters();
642 return k_iters_per_tile.div(sk_total_iters);
643 }
644
649 uint32_t iter_end) const noexcept
650 {
651 // A WG's iter_end is either in the current C macro tile or not.
652 // If it is not, then the macro tile boundary is where the WG must stop.
653 uint32_t distance_to_tile_boundary =
654 k_iters_per_tile.get() - (iter_start % k_iters_per_tile.get());
655 return min(iter_start + distance_to_tile_boundary, iter_end) - iter_start;
656 }
657
662 {
663 return k_iters_per_tile.div(iter);
664 }
665
669 CK_TILE_DEVICE void
670 GetTileIdxWithOffset(uint32_t iter, uint32_t& tile_idx, uint32_t& iter_offset) const noexcept
671 {
672 k_iters_per_tile.divmod(iter, tile_idx, iter_offset);
673 }
674
679 {
680 static constexpr uint32_t alignment = 128;
681 uint32_t acc_buffer_bytes =
682 MPerBlock * NPerBlock * GetTotalAccBuffers() * acc_element_bytes;
683 return (acc_buffer_bytes + alignment - 1) / alignment * alignment;
684 }
685
690 {
691 return GetSkTiles() * sizeof(uint32_t);
692 }
693
697 CK_TILE_HOST_DEVICE uint32_t GetWorkSpaceSize(uint32_t acc_element_bytes) const noexcept
698 {
699 return GetWorkSpaceSizeForAcc(acc_element_bytes) + GetWorkSpaceSizeForSemaphore();
700 }
701
706 const mdiv& equiv_tiles_) const noexcept
707 {
708 uint32_t tile_idx_ = tiles_ == 0 ? 0 : (tiles_ - 1);
709 uint32_t max_equiv_tiles_ = equiv_tiles_.get() - 1;
710 uint32_t quo_, rem_;
711 equiv_tiles_.divmod(tile_idx_, quo_, rem_);
712 return quo_ * max_equiv_tiles_ + rem_;
713 }
714
719 uint32_t iters_per_sk_block_) const noexcept
720 {
721 return k_iters_per_tile.div(num_sk_blocks_ * iters_per_sk_block_ + k_iters_per_tile.get() -
722 1);
723 }
724
729 {
730 uint32_t tiles_cover_big_blocks =
732 uint32_t tiles_cover_little_blocks =
734
735 uint32_t total_intersec_big = GetTileIntersections(tiles_cover_big_blocks, equiv_tiles_big);
736 uint32_t total_intersec_little =
737 GetTileIntersections(tiles_cover_little_blocks, equiv_tiles_little);
738
739 return sk_num_blocks + total_intersec_big + total_intersec_little;
740 }
741
746 {
747 uint32_t tiles_cover_big_blocks =
749 if(tile_idx_ < tiles_cover_big_blocks)
750 {
751 uint32_t touched_sk_blocks =
752 (tile_idx_ * k_iters_per_tile.get() + k_iters_per_big_block - 1) /
754 uint32_t current_intersec = GetTileIntersections(tile_idx_, equiv_tiles_big);
755 return touched_sk_blocks + current_intersec;
756 }
757 else
758 {
759 uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1;
760 uint32_t tile_idx_little_reverse = GetSkTiles() - tile_idx_;
761 uint32_t touched_sk_blocks =
762 (tile_idx_little_reverse * k_iters_per_tile.get() + iters_per_little_sk_block - 1) /
763 iters_per_little_sk_block;
764 uint32_t current_intersec =
765 GetTileIntersections(tile_idx_little_reverse, equiv_tiles_little);
766 return GetTotalAccBuffers() - (touched_sk_blocks + current_intersec);
767 }
768 }
769
774 {
775 uint32_t iters_per_big_sk_block = k_iters_per_big_block;
776 uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1;
777 if(block_idx_ < sk_num_big_blocks)
778 {
779 uint32_t touched_tiles = k_iters_per_tile.div(block_idx_ * iters_per_big_sk_block +
780 k_iters_per_tile.get() - 1);
781 uint32_t current_intersec = GetTileIntersections(touched_tiles, equiv_tiles_big);
782 return block_idx_ + current_intersec;
783 }
784 else
785 {
786 uint32_t block_idx_little_reverse = sk_num_blocks - block_idx_;
787 uint32_t touched_tiles = k_iters_per_tile.div(
788 block_idx_little_reverse * iters_per_little_sk_block + k_iters_per_tile.get() - 1);
789 uint32_t current_intersec = GetTileIntersections(touched_tiles, equiv_tiles_little);
790 return GetTotalAccBuffers() - (block_idx_little_reverse + current_intersec);
791 }
792 }
793
794 // Getters for problem dimensions
795 CK_TILE_HOST_DEVICE uint32_t GetNumTileM() const noexcept { return num_tile_m_; }
796 CK_TILE_HOST_DEVICE uint32_t GetNumTileN() const noexcept { return num_tile_n_; }
797 CK_TILE_HOST_DEVICE uint32_t GetNumTileK() const noexcept { return num_tile_k_; }
798
806 mdiv equiv_tiles_big; // for reduction
807 mdiv equiv_tiles_little; // for reduction
808
809 private:
810 uint32_t M_, N_, K_;
811 uint32_t num_tile_m_, num_tile_n_, num_tile_k_;
812};
813} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition tile/core/arch/amd_buffer_addressing.hpp:35
StreamKReductionStrategy
Definition streamk_common.hpp:10
@ Atomic
Definition streamk_common.hpp:11
@ Reduction
Definition streamk_common.hpp:12
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_HOST_DEVICE constexpr T min(T x)
Definition tile/core/numeric/math.hpp:210
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto lcm(X x, Y y)
Definition tile/core/numeric/math.hpp:314
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
STL namespace.
unsigned int uint32_t
Definition stdint.h:126
static constexpr index_t MPerBlock
Definition gemm_tile_partitioner.hpp:232
static CK_TILE_HOST_DEVICE auto GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock !=0 &&NPerBlock !=0)) -> index_t
Calculates GEMM kernel grid size.
Definition gemm_tile_partitioner.hpp:250
static constexpr index_t KPerBlock
Definition gemm_tile_partitioner.hpp:234
static CK_TILE_HOST_DEVICE auto GetLoopNum(index_t K) noexcept -> index_t
Calculate number of loop iterations over GEMM's K dimension.
Definition gemm_tile_partitioner.hpp:263
CK_TILE_HOST_DEVICE GemmSpatiallyLocalTilePartitioner() noexcept=delete
remove_cvref_t< BlockGemmShapeType > BlockGemmShape
Definition gemm_tile_partitioner.hpp:230
static constexpr index_t NPerBlock
Definition gemm_tile_partitioner.hpp:233
CK_TILE_DEVICE auto GetOutputTileIndex(index_t block_1d_id) noexcept -> const tuple< index_t, index_t >
Calculate workgroup 1D index mapping into 2D output C-tile space.
Definition gemm_tile_partitioner.hpp:275
CK_TILE_HOST_DEVICE GemmTile1DPartitioner() noexcept=delete
static CK_TILE_HOST_DEVICE auto GetLoopNum(index_t K) noexcept -> index_t
Calculate number of loop iterations over GEMM's K dimension.
Definition gemm_tile_partitioner.hpp:130
static CK_TILE_DEVICE auto GetOutputTileIndex(index_t blockIdx) noexcept -> const tuple< index_t, index_t >
Calculate workgroup 1D index mapping into 2D output C-tile space.
Definition gemm_tile_partitioner.hpp:142
remove_cvref_t< BlockGemmShape_ > BlockGemmShape
Definition gemm_tile_partitioner.hpp:90
static constexpr index_t MPerBlock
Definition gemm_tile_partitioner.hpp:92
static constexpr index_t NPerBlock
Definition gemm_tile_partitioner.hpp:93
static CK_TILE_HOST_DEVICE auto GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock !=0 &&NPerBlock !=0)) -> index_t
Calculates GEMM kernel grid size.
Definition gemm_tile_partitioner.hpp:117
static constexpr index_t KPerBlock
Definition gemm_tile_partitioner.hpp:94
static CK_TILE_DEVICE auto GetOutputTileIndex(index_t blockIdx, index_t blockIdy) noexcept -> const tuple< index_t, index_t >
Returns the 2D output tile indices (iM, iN) for the given block coordinates.
Definition gemm_tile_partitioner.hpp:74
static CK_TILE_HOST auto GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock !=0 &&NPerBlock !=0)) -> dim3
Calculates GEMM kernel grid size.
Definition gemm_tile_partitioner.hpp:41
remove_cvref_t< BlockGemmShapeType > BlockGemmShape
Definition gemm_tile_partitioner.hpp:23
static CK_TILE_HOST_DEVICE auto GetLoopNum(index_t K) noexcept -> index_t
Calculate number of loop iterations over GEMM's K dimension.
Definition gemm_tile_partitioner.hpp:54
static constexpr index_t NPerBlock
Definition gemm_tile_partitioner.hpp:26
static constexpr index_t KPerBlock
Definition gemm_tile_partitioner.hpp:27
static constexpr index_t MPerBlock
Definition gemm_tile_partitioner.hpp:25
CK_TILE_HOST_DEVICE GemmTile2DPartitioner() noexcept=delete
Primary (std::false_type) case of the trait that checks whether a GemmTile1DPartitioner-style GetOutputTileIndex can be called with a single argument (expression-validity check).
Definition gemm_tile_partitioner.hpp:161
Struct used to calculate offseted tile indexes.
Definition gemm_tile_partitioner.hpp:184
static CK_TILE_DEVICE auto GetOffsetedTileIndex(index_t block_start, index_t M, index_t N) noexcept -> const tuple< index_t, index_t >
The function subtracts the block's start (offset) from 1D raw-indexes.
Definition gemm_tile_partitioner.hpp:192
static CK_TILE_DEVICE auto GetOffsetedTileIndex(index_t block_start, index_t M, index_t N, index_t block_idx) noexcept -> const tuple< index_t, index_t >
The function subtracts the block's start (offset) from a given block index.
Definition gemm_tile_partitioner.hpp:208
CK_TILE_HOST_DEVICE uint32_t GetTileIntersections(uint32_t tiles_, const mdiv &equiv_tiles_) const noexcept
Get location of intersection of tiles for reduction.
Definition gemm_tile_partitioner.hpp:705
CK_TILE_HOST_DEVICE uint32_t GetNumTileK() const noexcept
Definition gemm_tile_partitioner.hpp:797
uint32_t k_iters_per_big_block
Definition gemm_tile_partitioner.hpp:803
CK_TILE_HOST_DEVICE uint32_t GetSkTotalIters() const noexcept
Get total number of iterations for sk tiles.
Definition gemm_tile_partitioner.hpp:628
CK_TILE_HOST_DEVICE StreamKTilePartitioner() noexcept=delete
static constexpr uint32_t MPerBlock
Definition gemm_tile_partitioner.hpp:391
CK_TILE_HOST_DEVICE uint32_t GetNumTileM() const noexcept
Definition gemm_tile_partitioner.hpp:795
CK_TILE_DEVICE uint32_t GetAccBufferOffsetFromBlock(uint32_t block_idx_) const noexcept
Calculate offset based on block_idx index for big/little streamk blocks.
Definition gemm_tile_partitioner.hpp:773
CK_TILE_DEVICE void GetTileIdxWithOffset(uint32_t iter, uint32_t &tile_idx, uint32_t &iter_offset) const noexcept
Get index of tile during a specified iteration.
Definition gemm_tile_partitioner.hpp:670
uint32_t sk_num_blocks
Definition gemm_tile_partitioner.hpp:799
mdiv equiv_tiles_little
Definition gemm_tile_partitioner.hpp:807
CK_TILE_DEVICE uint32_t GetAccBufferOffsetFromTile(uint32_t tile_idx_) const noexcept
Calculate offset based on tile index for big/little tiles.
Definition gemm_tile_partitioner.hpp:745
mdiv2 n_tiles
Definition gemm_tile_partitioner.hpp:804
CK_TILE_HOST_DEVICE uint32_t GetNumTileN() const noexcept
Definition gemm_tile_partitioner.hpp:796
CK_TILE_HOST_DEVICE uint32_t GetWorkSpaceSize(uint32_t acc_element_bytes) const noexcept
Calculates the total buffer space needed for accumulation and the semaphore.
Definition gemm_tile_partitioner.hpp:697
static constexpr uint32_t NPerBlock
Definition gemm_tile_partitioner.hpp:392
static constexpr uint32_t KPerBlock
Definition gemm_tile_partitioner.hpp:393
CK_TILE_HOST_DEVICE uint32_t GetTilesCoverSkBlock(uint32_t num_sk_blocks_, uint32_t iters_per_sk_block_) const noexcept
Calculate the number of tiles needed for the number of sk blocks.
Definition gemm_tile_partitioner.hpp:718
static CK_TILE_HOST_DEVICE auto GetLoopNum(uint32_t K) noexcept -> uint32_t
Calculate number of loop iterations over K dimension for given work unit.
Definition gemm_tile_partitioner.hpp:563
mdiv equiv_tiles_big
Definition gemm_tile_partitioner.hpp:806
CK_TILE_HOST_DEVICE uint32_t GetWorkSpaceSizeForSemaphore() const noexcept
Calculates the buffer space needed for the semaphore.
Definition gemm_tile_partitioner.hpp:689
CK_TILE_HOST auto GridSize() const noexcept -> dim3
Calculate optimal grid size for Stream-K.
Definition gemm_tile_partitioner.hpp:550
CK_TILE_HOST_DEVICE uint32_t GetSkTiles() const noexcept
Get total number of sk tiles.
Definition gemm_tile_partitioner.hpp:638
CK_TILE_DEVICE auto GetOutputTileIndex(uint32_t tile_idx) const noexcept -> tuple< uint32_t, uint32_t >
Get output tile index for standard 2D mapping (compatibility).
Definition gemm_tile_partitioner.hpp:572
uint32_t sk_num_big_blocks
Definition gemm_tile_partitioner.hpp:800
uint32_t dp_start_block_idx
Definition gemm_tile_partitioner.hpp:801
CK_TILE_DEVICE void GetBlockItr(uint32_t block_idx, uint32_t &iter_start, uint32_t &iter_end) const noexcept
Get work range for a given block ID.
Definition gemm_tile_partitioner.hpp:603
mdiv k_iters_per_tile
Definition gemm_tile_partitioner.hpp:805
CK_TILE_HOST_DEVICE uint32_t GetWorkSpaceSizeForAcc(uint32_t acc_element_bytes) const noexcept
Calculates the buffer space needed for accumulation.
Definition gemm_tile_partitioner.hpp:678
CK_TILE_HOST_DEVICE uint32_t GetTotalAccBuffers() const noexcept
Calculate the total number of accumulation buffers required for Stream-K.
Definition gemm_tile_partitioner.hpp:728
BlockGemmShapeType BlockGemmShape
Definition gemm_tile_partitioner.hpp:389
CK_TILE_DEVICE uint32_t GetCurrentIterLength(uint32_t iter_start, uint32_t iter_end) const noexcept
Get length of loop iterations for stream-k loop.
Definition gemm_tile_partitioner.hpp:648
CK_TILE_DEVICE uint32_t GetTileIdx(uint32_t iter) const noexcept
Get index of tile during a specified iteration.
Definition gemm_tile_partitioner.hpp:661
uint32_t reduction_start_block_idx
Definition gemm_tile_partitioner.hpp:802
Definition magic_div.hpp:228
Definition magic_div.hpp:186
Definition tile/core/container/tuple.hpp:192