block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp Source File

block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp Source File

Composable Kernel: block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp Source File
block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
11
12namespace ck_tile {
13
14template <typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
16{
17 static constexpr auto is_qr_qtr_dor_pipeline = true;
18
36 // using HotLoopScheduler = typename Policy::template HotLoopScheduler<Problem>;
37
39
40 static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
41 static constexpr index_t kBlockSize = Problem::kBlockSize;
42
43 static constexpr index_t kM0 = BlockFmhaShape::kM0;
44 static constexpr index_t kN0 = BlockFmhaShape::kN0;
45 static constexpr index_t kK0 = BlockFmhaShape::kK0;
46 static constexpr index_t kK1 = BlockFmhaShape::kK1;
47 static constexpr index_t kK2 = BlockFmhaShape::kK2;
48 static constexpr index_t kK3 = BlockFmhaShape::kK3;
49 static constexpr index_t kK4 = BlockFmhaShape::kK4;
50 static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
51 static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
52
53 static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
54 static constexpr index_t kPadHeadDimQ = Problem::kPadHeadDimQ;
55 static constexpr index_t kPadHeadDimV = Problem::kPadHeadDimV;
56 static constexpr auto BiasEnum = Problem::BiasEnum;
57 static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
58 static constexpr bool kIsDeterministic = Problem::kIsDeterministic;
59 static constexpr bool kUseTrLoad = Problem::kUseTrLoad;
60 static_assert(kUseTrLoad, "This pipeline uses trload!");
61
62 // Last-dimension vector length used to create the tensor view (and to decide the buffer_load
63 // vector length) together with the tensor distribution. The tensor distribution should be able to override this.
64 static constexpr index_t kAlignmentQ =
65 kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentQ<Problem>();
66 static constexpr index_t kAlignmentK =
67 kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentK<Problem>();
68 static constexpr index_t kAlignmentV =
69 kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentV<Problem>();
70 static constexpr index_t kAlignmentOGrad =
71 kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentOGrad<Problem>();
72 static constexpr index_t kAlignmentQGrad = 1;
73 static constexpr index_t kAlignmentKGrad =
74 kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentKGrad<Problem>();
75 static constexpr index_t kAlignmentVGrad =
76 kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentVGrad<Problem>();
77 static constexpr index_t kAlignmentBias = 1;
78
// Pipeline identifier string.
// NOTE(review): the literal reads "kr_ktr_vr", while this file and its is_qr_qtr_dor_pipeline
// marker use "qr_qtr_dor" -- possibly a copy/paste leftover; confirm before relying on it.
79 static constexpr const char* name = "trload_kr_ktr_vr";
80
82 {
83 return Policy::template GetSmemSize<Problem>();
84 }
85
87 {
88 if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || FmhaMask::IsMasking)
89 return (raw_lse == -numeric<LSEDataType>::infinity()) //
91 : raw_lse;
92 else
93 return raw_lse;
94 };
95
// Entry point of the pipeline: splits the raw buffer `smem_ptr` into typed sub-regions and
// forwards those pointers, followed by the caller-supplied `args`, to run().
//
// Byte offsets from `smem_ptr`, as computed below (Size* = Policy::GetSmemSize*<Problem>()):
//   K    : 0
//   V    : SizeK
//   dO   : 0                              (same start address as K)
//   Q    : SizeOGrad
//   LSE  : SizeOGrad + SizeQ
//   D    : SizeOGrad + SizeQ + SizeLSE
//   dS   : SizeK + SizeV
//   Bias : same address as dS
//
// NOTE(review): the {K, V} group and the {dO, Q, LSE, D} group both start at offset 0, so they
// overlap. Presumably this is intentional buffer reuse (run() reads dO/Q/LSE/D back before it
// writes K/V) -- confirm, especially since run() declares all of these pointers __restrict__.
//
// Returns whatever run() returns.
96 template <typename... Ts>
97 CK_TILE_DEVICE auto operator()(void* smem_ptr, Ts&&... args) const
98 {
99 // LDS allocation
100 const auto smem_ptr_ =
101 reinterpret_cast<char*>(smem_ptr); // cast to char* to do pointer arithmetic
102
// K at the start of the buffer, V immediately after the K region.
103 const auto k_lds_ptr = reinterpret_cast<KDataType* __restrict__>(smem_ptr_);
104 const auto v_lds_ptr = reinterpret_cast<VDataType* __restrict__>(
105 smem_ptr_ + Policy::template GetSmemSizeK<Problem>());
106
// dO also starts at the beginning of the buffer; Q, LSE and D are packed after it in that order.
107 const auto do_lds_ptr = reinterpret_cast<OGradDataType*>(smem_ptr_);
108 const auto q_lds_ptr = reinterpret_cast<QDataType*>( //
109 smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>());
110 const auto lse_lds_ptr = reinterpret_cast<LSEDataType*>( //
111 smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
112 Policy::template GetSmemSizeQ<Problem>());
113 const auto d_lds_ptr = reinterpret_cast<DDataType*>(
114 smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
115 Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeLSE<Problem>());
116
// dS is placed after the K and V regions; the bias pointer is the same address reinterpreted.
117 const auto ds_lds_ptr =
118 reinterpret_cast<GemmDataType*>(smem_ptr_ + Policy::template GetSmemSizeK<Problem>() +
119 Policy::template GetSmemSizeV<Problem>());
120 const auto bias_lds_ptr = reinterpret_cast<BiasDataType*>(ds_lds_ptr);
// Hand the typed pointers plus the perfectly-forwarded caller arguments to run().
121 return run(k_lds_ptr,
122 v_lds_ptr,
123 do_lds_ptr,
124 q_lds_ptr,
125 lse_lds_ptr,
126 d_lds_ptr,
127 ds_lds_ptr,
128 bias_lds_ptr,
129 std::forward<Ts>(args)...);
130 }
131
132 template <typename QDramBlockWindowTmp,
133 typename KDramBlockWindowTmp,
134 typename VDramBlockWindowTmp,
135 typename BiasDramBlockWindowTmp,
136 typename RandValDramBlockWindowTmp,
137 typename OGradDramBlockWindowTmp,
138 typename LSEDramBlockWindowTmp,
139 typename DDramBlockWindowTmp,
140 typename QGradDramBlockWindowTmp,
141 typename KGradDramBlockWindowTmp,
142 typename VGradDramBlockWindowTmp,
143 typename BiasGradDramBlockWindowTmp,
144 typename QGradEpilogue,
145 typename KGradEpilogue,
146 typename VGradEpilogue,
147 typename PositionEncoding>
149 KDataType* __restrict__ k_lds_ptr,
150 VDataType* __restrict__ v_lds_ptr,
151 OGradDataType* __restrict__ do_lds_ptr,
152 QDataType* __restrict__ q_lds_ptr,
153 LSEDataType* __restrict__ lse_lds_ptr,
154 DDataType* __restrict__ d_lds_ptr,
155 GemmDataType* __restrict__ ds_lds_ptr,
156 BiasDataType* __restrict__ bias_lds_ptr,
157 const QDramBlockWindowTmp& q_dram_block_window_tmp,
158 const KDramBlockWindowTmp& k_dram_block_window_tmp,
159 const VDramBlockWindowTmp& v_dram_block_window_tmp,
160 const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
161 const RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
162 const OGradDramBlockWindowTmp& do_dram_block_window_tmp,
163 const LSEDramBlockWindowTmp& lse_dram_block_window_tmp,
164 const DDramBlockWindowTmp& d_dram_block_window_tmp,
165 const QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
166 const KGradDramBlockWindowTmp& dk_dram_block_window_tmp,
167 const VGradDramBlockWindowTmp& dv_dram_block_window_tmp,
168 const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp,
169 const QGradEpilogue& dq_epilogue,
170 const KGradEpilogue& dk_epilogue,
171 const VGradEpilogue& dv_epilogue,
172 FmhaMask mask,
173 PositionEncoding position_encoding,
174 float raw_scale,
175 float scale,
176 float rp_undrop,
177 float scale_rp_undrop,
178 FmhaDropout& dropout) const
179 {
180 static_assert(
181 std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
182 std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
183 std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>> &&
184 std::is_same_v<OGradDataType,
186 std::is_same_v<LSEDataType,
188 std::is_same_v<DDataType, remove_cvref_t<typename DDramBlockWindowTmp::DataType>>,
189 "wrong!");
190
191 static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
192 kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
193 kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
194 kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
195 kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
196 kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
197 kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
198 kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
199 kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
200 kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
201 kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
202 "wrong!");
203
204 // Block GEMM
205 constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
206 constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm<Problem>();
207 constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm<Problem>();
208 constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm<Problem>();
209 constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm<Problem>();
210
211 const auto q_origin = q_dram_block_window_tmp.get_window_origin();
212
213 // Early termination
214 const auto [seqlen_kv_start, seqlen_kv_end] =
215 mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
216
217 const auto num_total_loop = integer_divide_ceil(seqlen_kv_end - seqlen_kv_start, kN0);
218
219 // K, HBM ->LDS ->Reg
220 auto k_dram_window =
221 make_tile_window(Policy::template TransformXDramTensorView<KDataType>(
222 k_dram_block_window_tmp.get_bottom_tensor_view()),
223 k_dram_block_window_tmp.get_window_lengths(),
224 {seqlen_kv_start, 0},
225 Policy::template MakeKDramTileDistribution<Problem>());
226
228 k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor<Problem>());
229 auto k_lds_write_window =
231
232 //------------------------------------------------------------------
233 // V, HBM ->LDS ->Reg
234 auto v_dram_window =
235 make_tile_window(Policy::template TransformXDramTensorView<VDataType>(
236 v_dram_block_window_tmp.get_bottom_tensor_view()),
237 v_dram_block_window_tmp.get_window_lengths(),
238 {seqlen_kv_start, 0},
239 Policy::template MakeVDramTileDistribution<Problem>());
241 v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor<Problem>());
242 auto v_lds_write_window =
244
245 //------------------------------------------------------------------
246 // KT, HBM -> LDS --trload-->Reg
247
248 //------------------------------------------------------------------
249 // Pre-Load KV into Registers
251 k_lds_ptr, Policy::template MakeKLdsReadBlockDescriptor<Problem>());
252 auto k_lds_read_window =
253 make_tile_window(k_lds_read,
255 k_lds_write_window.get_window_origin(),
256 Policy::template MakeKRegBlockDescriptor<Problem>());
257
258 auto kt_lds_read_window =
259 make_tile_window(k_lds_read,
261 {0, 0},
262 Policy::template MakeKTRegBlockDescriptor<Problem>());
263
265 v_lds_ptr, Policy::template MakeVLdsReadBlockDescriptor<Problem>());
266 auto v_lds_read_window =
267 make_tile_window(v_lds_read,
269 v_lds_write_window.get_window_origin(),
270 Policy::template MakeVRegBlockDescriptor<Problem>());
271
272 //---------------------------- Loop Load in ----------------------------//
273 // Q: HBM -->LDS
274 auto q_dram_window =
275 make_tile_window(Policy::template TransformXDramTensorView<QDataType>(
276 q_dram_block_window_tmp.get_bottom_tensor_view()),
277 q_dram_block_window_tmp.get_window_lengths(),
278 {0, 0},
279 Policy::template MakeQDramTileDistribution<Problem>());
280
282 q_lds_ptr, Policy::template MakeQLdsWriteBlockDescriptor<Problem>());
283 auto q_lds_write_window =
285
287 q_lds_ptr, Policy::template MakeQLdsReadBlockDescriptor<Problem>());
288 auto q_lds_read_window =
289 make_tile_window(q_lds_read,
291 q_lds_write_window.get_window_origin(),
292 Policy::template MakeQRegSliceBlockDescriptor<Problem>());
293 auto qt_lds_read_window =
294 make_tile_window(q_lds_read,
296 {0, 0},
297 Policy::template MakeQTRegSliceBlockDescriptor<Problem>());
298
299 // dO: HBM ->LDS ---load--> Reg
300 // dOT: \-loadtr-> Reg
301 auto do_dram_window =
302 make_tile_window(Policy::template TransformXDramTensorView<OGradDataType>(
303 do_dram_block_window_tmp.get_bottom_tensor_view()),
304 do_dram_block_window_tmp.get_window_lengths(),
305 {0, 0},
306 Policy::template MakeOGradDramTileDistribution<Problem>());
307
309 do_lds_ptr, Policy::template MakeOGradLdsWriteBlockDescriptor<Problem>());
310 auto do_lds_write_window =
312
314 do_lds_ptr, Policy::template MakeOGradLdsReadBlockDescriptor<Problem>());
315 auto do_lds_read_window =
316 make_tile_window(do_lds_read,
318 do_lds_write_window.get_window_origin(),
319 Policy::template MakeOGradRegSliceBlockDescriptor<Problem>());
320 auto dot_lds_read_window =
321 make_tile_window(do_lds_read,
323 {0, 0},
324 Policy::template MakeOGradTRegSliceBlockDescriptor<Problem>());
325
326 // dS: Reg -> Reg -> LDS
328 ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem>());
329
330 auto ds_lds_window =
331 make_tile_window(ds_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
332
333 // transform it to make it from col-major to row-major; prepared for load_tile_transpose
335 ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem, true>());
336 auto ds_lds_read_window =
337 make_tile_window(ds_lds_t,
339 {0, 0},
340 Policy::template MakeSGradRegSliceBlockDescriptor<Problem>());
341
342 // Bias: HBM ->Reg ->Reg ->LDS
343 const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
344
345 auto bias_dram_window =
346 make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
347 bias_dram_block_window_tmp.get_window_lengths(),
348 {bias_origin.at(number<0>{}), seqlen_kv_start},
349 Policy::template MakeBiasTileDistribution<Problem>());
350
352 bias_lds_ptr, Policy::template MakeBiasLdsBlockDescriptor<Problem>());
353 auto bias_lds_write_window =
354 make_tile_window(bias_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
355
356 auto bias_s_lds_read_window =
357 make_tile_window(bias_lds_write_window.get_bottom_tensor_view(),
358 bias_lds_write_window.get_window_lengths(),
359 bias_lds_write_window.get_window_origin(),
360 Policy::template MakeBiasSTileDistribution<decltype(gemm_0)>());
361
362 static_assert(std::is_same_v<BiasDataType, BiasGradDataType>,
363 "BiasDataType and BiasGradDataType should be the same!");
364
365 // LSE: HBM -> LDS ->Reg
366 auto lse_dram_window =
367 make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(),
368 lse_dram_block_window_tmp.get_window_lengths(),
369 {0},
370 Policy::template MakeLSEDDramTileDistribution<Problem>());
371
373 lse_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
374
375 auto lse_lds_write_window = make_tile_window(lse_lds, make_tuple(number<kM0>{}), {0});
376
377 auto lse_lds_read_window =
378 make_tile_window(lse_lds,
380 {0},
381 Policy::template MakeLSEDLdsReadBlockDescriptor<Problem>());
382
383 // D: HBM ->Reg
384 auto d_dram_window =
385 make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(),
386 d_dram_block_window_tmp.get_window_lengths(),
387 {0},
388 Policy::template MakeLSEDDramTileDistribution<Problem>());
389
391 d_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
392 auto d_lds_write_window = make_tile_window(d_lds, make_tuple(number<kM0>{}), {0});
393 auto d_lds_read_window =
394 make_tile_window(d_lds,
396 {0},
397 Policy::template MakeLSEDLdsReadBlockDescriptor<Problem>());
398
399 // RandVal: HBM ->Reg
400 auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0), true>(
401 randval_dram_block_window_tmp, seqlen_kv_start);
402
403 // BiasGrad
404 // Reg ->LDS ->Reg ->HBM
405 const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin();
406
407 auto dbias_dram_window =
408 make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(),
409 dbias_dram_block_window_tmp.get_window_lengths(),
410 {dbias_origin.at(number<0>{}), seqlen_kv_start}); // M/N
411
412 auto dbias_lds_read_window =
413 make_tile_window(bias_lds,
415 {0, 0},
416 Policy::template MakeShuffledBiasTileDistribution<Problem>());
417
418 // ----------------------------Loop write out------------------------------//
419 auto dq_dram_window = make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
420 dq_dram_block_window_tmp.get_window_lengths(),
421 {0, 0});
422 auto dk_dram_window = make_tile_window(dk_dram_block_window_tmp.get_bottom_tensor_view(),
423 dk_dram_block_window_tmp.get_window_lengths(),
424 {0, 0});
425 auto dv_dram_window = make_tile_window(dv_dram_block_window_tmp.get_bottom_tensor_view(),
426 dv_dram_block_window_tmp.get_window_lengths(),
427 {0, 0});
428
429 index_t i_total_loops = 0;
430 index_t seqlen_kv_step = seqlen_kv_start;
431 static_assert(kQKHeaddim >= kK0, "kQKHeaddim should be equal or greater than kK0");
432 static_assert(kM0 == kK1, "kM0 should equal to kK1");
433 static_assert(kVHeaddim >= kK2, "kVHeaddim should be equal or greater than kK2");
434 static_assert(kM0 == kK3, "kM0 should equal to kK3");
435 constexpr index_t k4_loops = kN0 / kK4;
436
437 __builtin_amdgcn_sched_barrier(0);
438
439 decltype(load_tile(q_lds_read_window)) q_reg_tensor;
440 decltype(load_tile(lse_lds_read_window)) lse;
441 decltype(load_tile_transpose(ds_lds_read_window)) ds_reg_tensor;
442 decltype(load_tile_transpose(ds_lds_read_window)) ds_reg_tensor_next;
443 decltype(load_tile(do_lds_read_window)) do_reg_tensor;
444 decltype(load_tile_transpose(dot_lds_read_window)) dot_reg_tensor;
445 decltype(load_tile(d_lds_read_window)) d;
446 decltype(load_tile_transpose(qt_lds_read_window)) qt_reg_tensor;
447 decltype(gemm_0.MakeCBlockTile()) s_acc, p;
448 decltype(gemm_2.MakeCBlockTile()) dp_acc, ds;
449 decltype(gemm_4.MakeCBlockTile()) dq_acc;
450 clear_tile(dq_acc);
451
452 decltype(load_tile(lse_dram_window)) lse_block_tile;
453 decltype(load_tile(d_dram_window)) d_block_tile;
454
455 async_load_tile(q_lds_write_window, q_dram_window);
456 async_load_tile(do_lds_write_window, do_dram_window);
457 __builtin_amdgcn_s_waitcnt(0);
458 qt_reg_tensor = load_tile_transpose(qt_lds_read_window);
459 q_reg_tensor = load_tile(q_lds_read_window);
460 dot_reg_tensor = load_tile_transpose(dot_lds_read_window);
461 do_reg_tensor = load_tile(do_lds_read_window);
462
463 lse_block_tile = load_tile(lse_dram_window);
464 d_block_tile = load_tile(d_dram_window);
465 __builtin_amdgcn_s_waitcnt(0);
466 store_tile(lse_lds_write_window, lse_block_tile);
467 store_tile(d_lds_write_window, d_block_tile);
468 __builtin_amdgcn_s_waitcnt(0);
469 lse = load_tile(lse_lds_read_window);
470 d = load_tile(d_lds_read_window);
471
472 auto main_body = [&](auto is_prologue_, auto is_epilogue_) mutable {
473 constexpr bool is_prologue = is_prologue_.value;
474 constexpr bool is_epilogue = is_epilogue_.value;
475 static_assert(is_prologue || is_epilogue, "is_prologue or is_epilogue should be true");
476 constexpr bool is_main_body = is_prologue && is_epilogue;
477
478 // init VGrad & KGrad
479 decltype(gemm_1.MakeCBlockTile()) dv_acc;
480 decltype(gemm_3.MakeCBlockTile()) dk_acc;
481
482 decltype(load_tile(k_lds_read_window)) k_reg_tensor;
483 decltype(load_tile(v_lds_read_window)) v_reg_tensor;
484 decltype(load_tile_transpose(kt_lds_read_window)) kt_reg_tensor;
485
486 if constexpr(is_epilogue)
487 {
488 async_load_tile(k_lds_write_window, k_dram_window);
489 move_tile_window(k_dram_window, {kN0, 0});
490 async_load_tile(v_lds_write_window, v_dram_window);
491 move_tile_window(v_dram_window, {kN0, 0});
492 s_waitcnt</*vmcnt=*/0>();
493 k_reg_tensor = load_tile(k_lds_read_window);
494 v_reg_tensor = load_tile(v_lds_read_window);
495 kt_reg_tensor = load_tile_transpose(kt_lds_read_window);
496 }
497 if constexpr(is_epilogue)
498 {
499 // STAGE 1, Q@K Gemm0
500 s_acc = gemm_0(q_reg_tensor, k_reg_tensor);
501 }
502 if constexpr(is_main_body)
503 Policy::template HotLoopScheduler<Problem>::SchedulerGemm0();
504 __builtin_amdgcn_sched_barrier(0);
505 if constexpr(is_epilogue)
506 {
507 // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
508 if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
509 {
510 const auto bias_tile = load_tile(bias_dram_window);
511 auto shuffled_bias_tile = make_static_distributed_tensor<BiasDataType>(
512 Policy::template MakeShuffledBiasTileDistribution<Problem>());
513 shuffle_tile(shuffled_bias_tile, bias_tile);
514 store_tile(bias_lds_write_window, shuffled_bias_tile);
516 auto bias_s_tile = load_tile(bias_s_lds_read_window);
518 [&](auto& x, const auto& y) {
520 },
521 s_acc,
522 bias_s_tile);
523 move_tile_window(bias_dram_window, {kM0, 0});
524 __builtin_amdgcn_sched_barrier(0);
525 }
526 else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
527 {
528 constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
529 sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
530 sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
531 const auto tile_idx = get_x_indices_from_distributed_indices(
532 s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
533
534 const auto row = tile_idx.at(number<0>{});
535 const auto col = seqlen_kv_step + tile_idx.at(number<1>{});
536 constexpr auto i_j_idx = make_tuple(idx0, idx1);
537
538 s_acc(i_j_idx) *= scale;
539 position_encoding.update(s_acc(i_j_idx), row, col);
540 });
541 });
542 }
543
544 {
545 bool need_perpixel_check =
546 mask.IsEdgeTile(0, seqlen_kv_step, number<kM0>{}, number<kN0>{});
547 if(need_perpixel_check)
548 {
549 set_tile_if(s_acc, -numeric<AccDataType>::infinity(), [&](auto tile_idx) {
550 const auto row = tile_idx.at(number<0>{});
551 const auto col = seqlen_kv_step + tile_idx.at(number<1>{});
552 return mask.IsOutOfBound(row, col);
553 });
554 }
555 }
556
557 constexpr auto p_spans = decltype(p)::get_distributed_spans();
558 sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
559 constexpr auto i_idx = make_tuple(idx0);
560 auto row_lse = log2e_v<LSEDataType> * get_validated_lse(lse[i_idx]);
561
562 sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
563 constexpr auto i_j_idx = make_tuple(idx0, idx1);
564
565 if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
567 p(i_j_idx) = exp2(s_acc[i_j_idx] - row_lse);
568 else
569 p(i_j_idx) = exp2(scale * s_acc[i_j_idx] - row_lse);
570 });
571 });
572
573 if constexpr(FmhaDropout::IsDropout)
574 {
575 dropout.template Run<decltype(gemm_0), RandValOutputDataType>(
576 0, seqlen_kv_step, p, randval_dram_window);
577 }
578 const auto p_gemm = [&]() { // dropout / type conversion
579 if constexpr(FmhaDropout::IsDropout)
580 {
581 return tile_elementwise_in(
582 [](const auto& x) {
583 return type_convert<GemmDataType>(x > 0.f ? x : 0.f);
584 },
585 p);
586 }
587 else
588 {
589 return cast_tile<GemmDataType>(p);
590 }
591 }();
592
593 // STAGE 4, OGrad@V Gemm2
594 dp_acc = gemm_2(do_reg_tensor, v_reg_tensor);
595
596 // STAGE 3, P^T@OGrad^T Gemm1
598 Policy::template MakePTRegSliceBlockDescriptor<Problem>());
599 pt_reg_tensor.get_thread_buffer() = p_gemm.get_thread_buffer();
600
601 dv_acc = gemm_1(pt_reg_tensor, dot_reg_tensor);
602 }
604 if constexpr(is_main_body)
605 Policy::template HotLoopScheduler<Problem>::SchedulerGemm12();
606 __builtin_amdgcn_sched_barrier(0);
607 if constexpr(is_epilogue)
608 {
609 // STAGE 5, P^T(PGrad^T - D)
610 constexpr auto ds_spans = decltype(ds)::get_distributed_spans();
611 sweep_tile_span(ds_spans[number<0>{}], [&](auto idx0) {
612 constexpr auto i_idx = make_tuple(idx0);
613 sweep_tile_span(ds_spans[number<1>{}], [&](auto idx1) {
614 constexpr auto i_j_idx = make_tuple(idx0, idx1);
615 bool undrop_flag = p[i_j_idx] >= 0;
616 ds(i_j_idx) = p[i_j_idx] * (!FmhaDropout::IsDropout || undrop_flag
617 ? (dp_acc[i_j_idx] - d[i_idx])
618 : d[i_idx]);
619 });
620 });
621
622 if constexpr(kHasBiasGrad)
623 {
624 const auto dbias = [&]() {
625 if constexpr(FmhaDropout::IsDropout)
626 {
627 return tile_elementwise_in(
628 [&rp_undrop](const auto& x) {
629 return type_convert<BiasGradDataType>(x * rp_undrop);
630 },
631 ds);
632 }
633 else
634 {
636 }
637 }();
638 store_tile(bias_lds_write_window, dbias);
639 s_waitcnt</*vmcnt=*/0>();
641 auto shuffled_dbias_tile = load_tile(dbias_lds_read_window);
643 Policy::template MakeBiasTileDistribution<Problem>());
644 shuffle_tile(dbias_tile, shuffled_dbias_tile);
645 store_tile(dbias_dram_window, dbias_tile);
646 move_tile_window(dbias_dram_window, {kM0, 0});
647 __builtin_amdgcn_sched_barrier(0);
648 }
649 }
650 if constexpr(is_epilogue)
651 {
652 // STAGE 6, SGrad^T@Q^T Gemm3
653 const auto ds_gemm = cast_tile<GemmDataType>(ds);
655 Policy::template MakeSGradTRegSliceBlockDescriptor<Problem>());
656 dst_reg_tensor.get_thread_buffer() = ds_gemm.get_thread_buffer();
657 dk_acc = gemm_3(dst_reg_tensor, qt_reg_tensor);
658
659 if constexpr(kHasBiasGrad)
660 {
661 // SGrad and BiasGrad use the same address in LDS, finish loading dbias to reuse
662 // LDS.
664 }
665 store_tile(ds_lds_window, ds_gemm);
666 }
667 s_waitcnt</*vmcnt=*/0>();
669 if constexpr(is_epilogue)
670 {
671 ds_reg_tensor = load_tile_transpose(ds_lds_read_window);
672 move_tile_window(ds_lds_read_window, {kK4, 0});
673 }
674 if constexpr(is_main_body)
675 Policy::template HotLoopScheduler<Problem>::SchedulerGemm3();
676 __builtin_amdgcn_sched_barrier(0);
677 if constexpr(is_epilogue)
678 {
679 // STAGE7 SGrad@K^T Gemm4
680 static_for<0, k4_loops, 1>{}([&](auto i_k4) {
681 if constexpr(i_k4 < k4_loops - 1)
682 {
683 ds_reg_tensor_next = load_tile_transpose(ds_lds_read_window);
684 move_tile_window(ds_lds_read_window, {kK4, 0});
685 }
686 auto kt_reg_tensor_slice = get_slice_tile( //
687 kt_reg_tensor,
689 sequence<kQKHeaddim, (i_k4 + 1) * kK4>{});
690 gemm_4(dq_acc, ds_reg_tensor, kt_reg_tensor_slice);
691
692 if constexpr(i_k4 < k4_loops - 1)
693 {
694 ds_reg_tensor.get_thread_buffer() = ds_reg_tensor_next.get_thread_buffer();
695 }
696 });
697 move_tile_window(ds_lds_read_window, {-kN0, 0});
698 }
700 if constexpr(is_main_body)
701 Policy::template HotLoopScheduler<Problem>::SchedulerGemm4();
702 if constexpr(is_epilogue)
703 {
704 // Results Scale
705 if constexpr(FmhaDropout::IsDropout)
706 {
707 tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
708 dk_acc);
709 tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc);
710 }
711 else
712 {
713 tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc);
714 }
715
716 dk_epilogue(dk_dram_window, dk_acc, nullptr);
717 move_tile_window(dk_dram_window, {kN0, 0});
718 dv_epilogue(dv_dram_window, dv_acc, nullptr);
719 move_tile_window(dv_dram_window, {kN0, 0});
720 }
721 };
722
723 for(index_t i = 0; i < seqlen_kv_start; i += kN0)
724 {
725 dk_epilogue(dk_dram_window, decltype(gemm_3.MakeCBlockTile()){0}, nullptr);
726 move_tile_window(dk_dram_window, {kN0, 0});
727 dv_epilogue(dv_dram_window, decltype(gemm_1.MakeCBlockTile()){0}, nullptr);
728 move_tile_window(dv_dram_window, {kN0, 0});
729 }
730
731 main_body(std::true_type{}, std::false_type{});
732 // Hot loop
733 if(num_total_loop > 1)
734 {
735 do
736 {
737 main_body(std::true_type{}, std::true_type{});
738 i_total_loops += 1;
739 seqlen_kv_step += kN0;
740 } while(i_total_loops < num_total_loop - 1);
741 }
742 main_body(std::false_type{}, std::true_type{});
743 seqlen_kv_step += kN0;
744
745 const auto k_length = k_dram_block_window_tmp.get_window_lengths();
746 const auto seqlen_kv_length = k_length.at(number<0>{});
747 for(; seqlen_kv_step < seqlen_kv_length; seqlen_kv_step += kN0)
748 {
749 dk_epilogue(dk_dram_window, decltype(gemm_3.MakeCBlockTile()){0}, nullptr);
750 move_tile_window(dk_dram_window, {kN0, 0});
751 dv_epilogue(dv_dram_window, decltype(gemm_1.MakeCBlockTile()){0}, nullptr);
752 move_tile_window(dv_dram_window, {kN0, 0});
753 }
754
755 // QGrad Scale
756 if constexpr(FmhaDropout::IsDropout)
757 tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
758 dq_acc);
759 else
760 tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
761 dq_epilogue(dq_dram_window, dq_acc, nullptr);
762 return;
763 }
764};
765
766// We don't support C++20 concepts yet, so we use SFINAE to check the existence and truthiness
767// of the is_qr_qtr_dor_pipeline static member instead of using concepts directly.
768//
769// The template struct's value field is equivalent to the following commented concept definition.
770//
771// template <class T>
772// concept fmha_bwd_qr_qtr_dor_pipeline_c = T::is_qr_qtr_dor_pipeline;
773
774// SFINAE test for existence and truthiness of static member is_qr_qtr_dor_pipeline.
// Primary template: used whenever the partial specialization below is not viable, i.e. when
// `T::is_qr_qtr_dor_pipeline` is not a valid expression. Derives from std::false_type.
775template <typename, typename = void>
776struct fmha_bwd_qr_qtr_dor_pipeline : std::false_type
777{
778};
779
// Partial specialization: selected when `decltype(T::is_qr_qtr_dor_pipeline)` is well-formed
// (std::void_t then yields `void`, matching the primary template's default argument).
// Derives from std::bool_constant of that member, so `value` mirrors the member's value.
780template <typename T>
781struct fmha_bwd_qr_qtr_dor_pipeline<T, std::void_t<decltype(T::is_qr_qtr_dor_pipeline)>>
782 : std::bool_constant<T::is_qr_qtr_dor_pipeline>
783{
784};
785
786} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_DEVICE auto async_load_tile(LdsTileWindow_ &&lds_tile, const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:119
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc &in_element_func, const InTensor &... in_dstr_tensors)
Definition tile_elementwise.hpp:40
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType *__restrict__ p, const tensor_descriptor< Ts... > &desc)
Definition tensor_view.hpp:452
CK_TILE_HOST_DEVICE constexpr auto get_x_indices_from_distributed_indices(StaticTileDistribution tile_distribution, DistributedIndices distributed_indices)
Definition static_distributed_tensor.hpp:159
CK_TILE_DEVICE constexpr auto get_slice_tile(const tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile, sequence< SliceBegins... > slice_begins, sequence< SliceEnds... > slice_ends)
Definition slice_tile.hpp:23
@ ALIBI
Definition block_attention_bias_enum.hpp:15
@ ELEMENTWISE_BIAS
Definition block_attention_bias_enum.hpp:14
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc &inout_element_func, InOutDstrTensors &... inout_dstr_tensors)
Definition tile_elementwise.hpp:23
constexpr T log2e_v
Definition tile/core/numeric/math.hpp:488
CK_TILE_DEVICE void block_sync_lds()
Definition arch.hpp:282
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
CK_TILE_DEVICE void shuffle_tile(OutTensor &out, const InTensor &in)
Definition shuffle_tile.hpp:154
CK_TILE_DEVICE auto load_tile_transpose(const tile_window_with_static_distribution< BottomTensorView_, WindowLengths_, TileDistribution_, NumCoord > &tile_window)
transpose loads tile from a tensor and returns the resulting tensor with a new (transposed) tile dist...
Definition load_tile_transpose.hpp:403
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_DEVICE auto cast_tile(const SrcTensor &src_tensor)
Definition tile_elementwise.hpp:327
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_DEVICE void s_waitcnt()
Definition arch.hpp:241
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition sweep_tile.hpp:20
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
CK_TILE_HOST_DEVICE void set_tile_if(static_distributed_tensor< DataType, StaticTileDistribution > &out_tensor, DataType value, XIndicesPredicate predicate)
Definition static_distributed_tensor.hpp:175
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
int32_t index_t
Definition integer.hpp:9
CK_TILE_DEVICE void clear_tile(DstrTensors &dstr_tensor)
Definition tile_elementwise.hpp:177
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_DEVICE bfloat16_t exp2(bfloat16_t x)
Definition bfloat16.hpp:425
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
STL namespace.
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:16
remove_cvref_t< typename Problem::QGradDataType > QGradDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:30
static constexpr index_t kVHeaddim
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:51
static constexpr index_t kAlignmentOGrad
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:70
static constexpr index_t kAlignmentQ
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:64
static constexpr bool kUseTrLoad
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:59
static constexpr bool kIsDeterministic
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:58
remove_cvref_t< typename Problem::AccDataType > AccDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:25
static constexpr index_t kM0
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:43
remove_cvref_t< typename Problem::LSEDataType > LSEDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:24
static constexpr auto BiasEnum
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:56
remove_cvref_t< typename Problem::VGradDataType > VGradDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:32
remove_cvref_t< typename Problem::KGradDataType > KGradDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:31
remove_cvref_t< typename Problem::GemmDataType > GemmDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:22
static constexpr index_t kK2
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:47
static constexpr index_t kAlignmentQGrad
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:72
static constexpr index_t kAlignmentBias
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:77
static constexpr index_t kBlockSize
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:41
static constexpr index_t kBlockPerCu
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:40
remove_cvref_t< typename Problem::BlockFmhaShape > BlockFmhaShape
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:38
static constexpr index_t kQKHeaddim
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:50
static constexpr index_t kPadHeadDimQ
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:54
static constexpr index_t kAlignmentKGrad
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:73
remove_cvref_t< typename Problem::RandValOutputDataType > RandValOutputDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:27
static constexpr index_t kPadHeadDimV
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:55
static constexpr auto is_qr_qtr_dor_pipeline
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:17
remove_cvref_t< typename Problem::FmhaMask > FmhaMask
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:34
remove_cvref_t< typename Problem::ODataType > ODataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:28
static constexpr index_t kK1
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:46
CK_TILE_DEVICE auto run(KDataType *__restrict__ k_lds_ptr, VDataType *__restrict__ v_lds_ptr, OGradDataType *__restrict__ do_lds_ptr, QDataType *__restrict__ q_lds_ptr, LSEDataType *__restrict__ lse_lds_ptr, DDataType *__restrict__ d_lds_ptr, GemmDataType *__restrict__ ds_lds_ptr, BiasDataType *__restrict__ bias_lds_ptr, const QDramBlockWindowTmp &q_dram_block_window_tmp, const KDramBlockWindowTmp &k_dram_block_window_tmp, const VDramBlockWindowTmp &v_dram_block_window_tmp, const BiasDramBlockWindowTmp &bias_dram_block_window_tmp, const RandValDramBlockWindowTmp &randval_dram_block_window_tmp, const OGradDramBlockWindowTmp &do_dram_block_window_tmp, const LSEDramBlockWindowTmp &lse_dram_block_window_tmp, const DDramBlockWindowTmp &d_dram_block_window_tmp, const QGradDramBlockWindowTmp &dq_dram_block_window_tmp, const KGradDramBlockWindowTmp &dk_dram_block_window_tmp, const VGradDramBlockWindowTmp &dv_dram_block_window_tmp, const BiasGradDramBlockWindowTmp &dbias_dram_block_window_tmp, const QGradEpilogue &dq_epilogue, const KGradEpilogue &dk_epilogue, const VGradEpilogue &dv_epilogue, FmhaMask mask, PositionEncoding position_encoding, float raw_scale, float scale, float rp_undrop, float scale_rp_undrop, FmhaDropout &dropout) const
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:148
static constexpr bool kIsGroupMode
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:53
remove_cvref_t< typename Problem::OGradDataType > OGradDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:29
remove_cvref_t< typename Problem::VDataType > VDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:21
remove_cvref_t< typename Problem::DDataType > DDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:26
static constexpr index_t kAlignmentV
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:68
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize()
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:81
static constexpr bool kHasBiasGrad
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:57
static CK_TILE_HOST_DEVICE LSEDataType get_validated_lse(const LSEDataType raw_lse)
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:86
CK_TILE_DEVICE auto operator()(void *smem_ptr, Ts &&... args) const
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:97
remove_cvref_t< typename Problem::FmhaDropout > FmhaDropout
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:35
remove_cvref_t< typename Problem::QDataType > QDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:19
static constexpr index_t kK0
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:45
remove_cvref_t< typename Problem::BiasGradDataType > BiasGradDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:33
static constexpr index_t kAlignmentK
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:66
static constexpr index_t kN0
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:44
static constexpr const char * name
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:79
static constexpr index_t kK3
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:48
remove_cvref_t< typename Problem::KDataType > KDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:20
static constexpr index_t kK4
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:49
remove_cvref_t< typename Problem::BiasDataType > BiasDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:23
static constexpr index_t kAlignmentVGrad
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:75
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:777
static CK_TILE_HOST_DEVICE constexpr T infinity()
Definition tile/core/numeric/numeric.hpp:38
Definition tile/core/container/sequence.hpp:49
Definition tile/core/utility/functional.hpp:43