tile_fmha_shape.hpp Source File

tile_fmha_shape.hpp Source File#

Composable Kernel: tile_fmha_shape.hpp Source File
tile_fmha_shape.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
7
8namespace ck_tile {
9
10template <index_t Headdim>
11static CK_TILE_HOST_DEVICE constexpr index_t ceil_to_qualified_tile_length()
12{
13 if constexpr(Headdim == 48)
14 return 48;
15 else if constexpr(Headdim == 96)
16 return 128;
17 else if constexpr(Headdim == 160)
18 return 256;
19 else if constexpr(Headdim == 192)
20 return 192;
21 else if constexpr(is_power_of_two_integer(Headdim))
22 return Headdim;
23 else
24 static_assert(Headdim == 0,
25 "only Headdim of 48, 96, 160, 192 and power-of-two is supported");
26};
27
// NOTE(review): this listing comes from a rendered documentation page; the struct declaration
// line, the type aliases and some initializer continuation lines are not present here.
//
// Compile-time shape description. The tile sizes below are read by position (indices 0..5)
// from BlockTile, and kSubQKHeaddim is kQKHeaddim passed through
// ceil_to_qualified_tile_length().
28template <typename BlockTile_, // sequence<...
29 typename Gemm0BlockWarps_,
30 typename Gemm0WarpTile_,
31 typename Gemm1BlockWarps_,
32 typename Gemm1WarpTile_,
33 bool IsVLayoutRowMajor_>
35{
41
42 static constexpr index_t NumGemm0Warps =
44 static constexpr index_t NumGemm1Warps =
46 static_assert(NumGemm1Warps % NumGemm0Warps == 0); // gemm1 warp count must be a whole multiple of gemm0's
47
49
50 static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen
51 static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen
52 static constexpr index_t kK0 = BlockTile::at(number<2>{}); // tile size along qk gemm unroll
53 static constexpr index_t kN1 = BlockTile::at(number<3>{}); // tile size along v head_dim
54 static constexpr index_t kK1 = BlockTile::at(number<4>{}); // tile size along kv gemm unroll
55 static constexpr index_t kQKHeaddim =
56 BlockTile::at(number<5>{}); // total length of K0, used for pipelines that need to load Q
57 // at once (or repeatedly load Q as a whole tile)
58 static_assert(kQKHeaddim % kK0 == 0, "kQKHeaddim should be divisible by kK0");
59
60 static constexpr index_t kSubQKHeaddim = ceil_to_qualified_tile_length<kQKHeaddim>();
61
62 // V layout: row-major is seqlen*hdim, column-major is hdim*seqlen
63 static constexpr bool IsVLayoutRowMajor = IsVLayoutRowMajor_;
64 using VLayout = std::conditional_t<IsVLayoutRowMajor,
67};
68
// NOTE(review): this listing comes from a rendered documentation page; the struct declaration
// line, the type aliases and the NumWarps initializer are not present here.
//
// Compile-time shape description for a five-gemm pipeline (gemm0..gemm4, as labelled in the
// member comments). The tile sizes below are read by position (indices 0..8) from BlockTile.
// kMaxSeqLenQ_ defaults to 0 and must be either 0 or equal to kM0.
69template <typename BlockTile_, // sequence<...
70 typename Gemm0BlockWarps_,
71 typename Gemm0WarpTile_,
72 typename Gemm1BlockWarps_,
73 typename Gemm1WarpTile_,
74 typename Gemm2BlockWarps_,
75 typename Gemm2WarpTile_,
76 typename Gemm3BlockWarps_,
77 typename Gemm3WarpTile_,
78 typename Gemm4BlockWarps_,
79 typename Gemm4WarpTile_,
80 index_t kMaxSeqLenQ_ = 0>
82{
94
95 static constexpr index_t NumWarps =
97
100
101 static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen
102 static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen
103 static constexpr index_t kK0 =
104 BlockTile::at(number<2>{}); // tile size along gemm0(Q@K^T) unroll
105 static constexpr index_t kK1 =
106 BlockTile::at(number<3>{}); // tile size along gemm1(P^T@dO) unroll
107 static constexpr index_t kK2 =
108 BlockTile::at(number<4>{}); // tile size along gemm2(dO@V^T) unroll
109 static constexpr index_t kK3 =
110 BlockTile::at(number<5>{}); // tile size along gemm3(dS^T@Q) unroll
111 static constexpr index_t kK4 = BlockTile::at(number<6>{}); // tile size along gemm4(dS@K) unroll
112 static constexpr index_t kQKHeaddim =
113 BlockTile::at(number<7>{}); // Q & K headdim, used for pipelines that need to load Q/Q^T
114 // or K/K^T at once
115 static constexpr index_t kVHeaddim = BlockTile::at(number<8>{}); // V headdim, used for pipelines
116 // that need to load V at once
117
118 static constexpr index_t kMaxSeqLenQ = kMaxSeqLenQ_;
119 static_assert(kMaxSeqLenQ == kM0 || kMaxSeqLenQ == 0,
120 "kMaxSeqLenQ should be equal to kM0 or 0, if 0, it means seq len Q is unlimited");
121};
122
123} // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr bool is_power_of_two_integer(int32_t x)
Definition tile/core/numeric/math.hpp:462
__host__ __device__ multiplies() -> multiplies< void, void >
FIXME: create macro to replace 'host device' and nothing more.
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr index_t reduce_on_sequence(Seq, Reduce f, number< Init >)
Definition tile/core/container/sequence.hpp:982
Definition tile_fmha_shape.hpp:82
static constexpr index_t kQKHeaddim
Definition tile_fmha_shape.hpp:112
remove_cvref_t< Gemm0BlockWarps_ > Gemm0BlockWarps
Definition tile_fmha_shape.hpp:84
static constexpr index_t kK3
Definition tile_fmha_shape.hpp:109
remove_cvref_t< Gemm1WarpTile_ > Gemm1WarpTile
Definition tile_fmha_shape.hpp:87
static constexpr index_t kN0
Definition tile_fmha_shape.hpp:102
static constexpr index_t kMaxSeqLenQ
Definition tile_fmha_shape.hpp:118
remove_cvref_t< Gemm4WarpTile_ > Gemm4WarpTile
Definition tile_fmha_shape.hpp:93
remove_cvref_t< Gemm4BlockWarps_ > Gemm4BlockWarps
Definition tile_fmha_shape.hpp:92
static constexpr index_t kVHeaddim
Definition tile_fmha_shape.hpp:115
remove_cvref_t< Gemm2BlockWarps_ > Gemm2BlockWarps
Definition tile_fmha_shape.hpp:88
remove_cvref_t< Gemm3BlockWarps_ > Gemm3BlockWarps
Definition tile_fmha_shape.hpp:90
remove_cvref_t< Gemm0WarpTile_ > Gemm0WarpTile
Definition tile_fmha_shape.hpp:85
remove_cvref_t< Gemm2WarpTile_ > Gemm2WarpTile
Definition tile_fmha_shape.hpp:89
static constexpr index_t kM0
Definition tile_fmha_shape.hpp:101
remove_cvref_t< BlockTile_ > BlockTile
Definition tile_fmha_shape.hpp:83
static constexpr index_t kK4
Definition tile_fmha_shape.hpp:111
static constexpr index_t kK0
Definition tile_fmha_shape.hpp:103
static constexpr index_t kK2
Definition tile_fmha_shape.hpp:107
static constexpr index_t NumWarps
Definition tile_fmha_shape.hpp:95
static constexpr index_t kK1
Definition tile_fmha_shape.hpp:105
remove_cvref_t< Gemm3WarpTile_ > Gemm3WarpTile
Definition tile_fmha_shape.hpp:91
remove_cvref_t< Gemm1BlockWarps_ > Gemm1BlockWarps
Definition tile_fmha_shape.hpp:86
Definition tile_fmha_shape.hpp:35
static constexpr bool IsVLayoutRowMajor
Definition tile_fmha_shape.hpp:63
static constexpr index_t kQKHeaddim
Definition tile_fmha_shape.hpp:55
static constexpr index_t kK0
Definition tile_fmha_shape.hpp:52
remove_cvref_t< Gemm1BlockWarps_ > Gemm1BlockWarps
Definition tile_fmha_shape.hpp:39
static constexpr index_t NumGemm0Warps
Definition tile_fmha_shape.hpp:42
static constexpr index_t NumWarps
Definition tile_fmha_shape.hpp:48
static constexpr index_t kK1
Definition tile_fmha_shape.hpp:54
remove_cvref_t< Gemm1WarpTile_ > Gemm1WarpTile
Definition tile_fmha_shape.hpp:40
remove_cvref_t< Gemm0WarpTile_ > Gemm0WarpTile
Definition tile_fmha_shape.hpp:38
static constexpr index_t kSubQKHeaddim
Definition tile_fmha_shape.hpp:60
static constexpr index_t kN0
Definition tile_fmha_shape.hpp:51
remove_cvref_t< Gemm0BlockWarps_ > Gemm0BlockWarps
Definition tile_fmha_shape.hpp:37
static constexpr index_t kN1
Definition tile_fmha_shape.hpp:53
static constexpr index_t kM0
Definition tile_fmha_shape.hpp:50
std::conditional_t< IsVLayoutRowMajor, ck_tile::tensor_layout::gemm::RowMajor, ck_tile::tensor_layout::gemm::ColumnMajor > VLayout
Definition tile_fmha_shape.hpp:64
remove_cvref_t< BlockTile_ > BlockTile
Definition tile_fmha_shape.hpp:36
static constexpr index_t NumGemm1Warps
Definition tile_fmha_shape.hpp:44
Definition tile/core/numeric/math.hpp:98
Definition tile/ops/common/tensor_layout.hpp:22
Definition tile/ops/common/tensor_layout.hpp:17