195 CDEShuffleBlockTransferScalarPerVectors{}[
I0];
231 return static_cast<const DDataType*
>(
nullptr);
244 const index_t gridx = NSwizzle ? nblock * mblock : nblock;
245 const index_t gridy = NSwizzle ? 1 : mblock;
247 return std::make_tuple(gridx, gridy, 1);
267 auto K_t = K_Batch * KPerBlock;
268 return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
273 auto K_t = K_Batch * KPerBlock;
274 return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
279 auto K_t = K_Batch * KPerBlock;
280 return (K + K_t - 1) / K_t * KPerBlock;
286 auto K_t = K_Batch * KReadVec;
287 return (K + K_t - 1) / K_t * KReadVec;
300 template <
index_t MNXdlPerWave,
304 typename TileDesc_K0_MN_K1>
322 IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
324 const auto a_grid_desc_mraw_kraw = [&]() {
337 if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
338 GemmSpec == GemmSpecialization::MNKPadding)
341 const auto a_grid_desc_m_k =
355 return a_grid_desc_ak0_m_ak1;
357 else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
358 GemmSpec == GemmSpecialization::MNPadding)
362 a_grid_desc_mraw_kraw,
368 return a_grid_desc_ak0_m_ak1;
370 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
371 GemmSpec == GemmSpecialization::NKPadding)
375 a_grid_desc_mraw_kraw,
387 return a_grid_desc_ak0_m_ak1;
393 a_grid_desc_mraw_kraw,
399 return a_grid_desc_ak0_m_ak1;
406 const auto b_grid_desc_nraw_kraw = [&]() {
420 GemmSpec != GemmSpecialization::Default),
421 "pk_i4_t does not support padding");
423 GemmSpec != GemmSpecialization::Default),
424 "f4x2_pk_t does not support padding");
426 if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
427 GemmSpec == GemmSpecialization::MNKPadding)
430 const auto b_grid_desc_n_k =
444 return b_grid_desc_bk0_n_bk1;
446 else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
447 GemmSpec == GemmSpecialization::MNPadding)
451 b_grid_desc_nraw_kraw,
457 return b_grid_desc_bk0_n_bk1;
459 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
460 GemmSpec == GemmSpecialization::MKPadding)
464 b_grid_desc_nraw_kraw,
476 return b_grid_desc_bk0_n_bk1;
482 b_grid_desc_nraw_kraw,
488 return b_grid_desc_bk0_n_bk1;
492 template <
typename ABlockDesc_AK0_M_AK1>
493 __host__ __device__
static constexpr auto
496 constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
499 ABlockDesc_AK0_M_AK1{});
502 template <
typename BBlockDesc_BK0_N_BK1>
503 __host__ __device__
static constexpr auto
506 constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
509 BBlockDesc_BK0_N_BK1{});
512 template <
typename ELayout>
514 IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
516 const auto c_grid_desc_mraw_nraw = [&]() {
535 template <
typename DLayout>
536 __host__ __device__
static auto
539 const auto c_grid_desc_mraw_nraw = [&]() {
569 template <
typename DsGr
idDesc>
571 const DsGridDesc& ds_grid_desc_m_n,
index_t MBlock,
index_t NBlock)
576 ds_grid_desc_m_n[i], MBlock, NBlock);
592 std::array<index_t, NumDTensor> StrideDs_,
620 std::cout <<
"problem {" <<
"NumTokens:" <<
NumTokens <<
", " <<
"TopK:" <<
TopK <<
", "
621 <<
"M:" <<
M <<
", " <<
"N:" <<
N <<
", " <<
"K:" <<
K <<
", "
625 <<
", " <<
"KRead:" <<
KRead <<
", " <<
"KP:" <<
KPadded <<
", "
626 <<
"AK0:" <<
AK0 <<
", " <<
"BK0:" <<
BK0 <<
", " <<
"MBlock: " <<
MBlock
627 <<
", " <<
"NBlock: " <<
NBlock <<
"}" << std::endl;
656 const index_t* p_sorted_expert_ids_,
657 const index_t* p_max_token_id_,
658 const ADataType* p_a_grid_,
659 const AScaleDataType* p_a_scale_grid_,
660 const BDataType* p_b_grid_,
661 const BScaleDataType* p_b_scale_grid_,
662 std::array<const void*, NumDTensor> p_ds_grid_,
663 CDataType* p_c_grid_,
673 std::array<index_t, NumDTensor> StrideDs_,
676 AElementwiseOperation a_element_op_,
677 BElementwiseOperation b_element_op_,
678 CElementwiseOperation c_element_op_)
710 p_ds_grid(i) =
static_cast<const DDataType_*
>(p_ds_grid_[i]);
774 if(k_id < karg.
KBatch - 1)
792 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
793 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
794 constexpr index_t WaveSize = BlockSize / (MWave * NWave);
807 constexpr auto a_lds_block_desc =
819 return a_lds_block_desc_permuted;
826 constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I1);
827 constexpr auto M1 = MPerBlock / M0;
829 constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I0);
830 constexpr auto K0PerThreadWrite =
AK0Number / KThreadWrite;
831 constexpr auto KThreadRead = WaveSize / MPerXdl;
832 constexpr auto K0PerThreadRead =
AK0Number / KThreadRead;
834 constexpr auto kfold = (
AK1Number * M0 *
sizeof(ADataType) > 128)
836 : 128 / (
AK1Number * M0 *
sizeof(ADataType));
837 constexpr auto KThreadReadPerm =
838 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
839 ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
843 constexpr auto mpair = (
AK1Number * MPerXdl *
sizeof(ADataType) > 128)
845 : ((128 / (
AK1Number * MPerXdl *
sizeof(ADataType))) > M0
847 : 128 / (
AK1Number * MPerXdl *
sizeof(ADataType)));
853 Number<kfold * M0 / mpair>{},
872 a_lds_block_desc_permuted,
894 a_lds_block_desc_unmerged,
897 Number<KThreadWrite / kfold / KThreadReadPerm>{},
906 return a_lds_block_desc_ak0_m_ak1;
912 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
913 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
914 constexpr index_t WaveSize = BlockSize / (MWave * NWave);
926 constexpr auto b_lds_block_desc =
938 return b_lds_block_desc_permuted;
942 constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(
I1);
943 constexpr auto N1 = NPerBlock / N0;
945 constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(
I0);
946 constexpr auto K0PerThreadWrite =
BK0Number / KThreadWrite;
947 constexpr auto KThreadRead = WaveSize / NPerXdl;
948 constexpr auto K0PerThreadRead =
BK0Number / KThreadRead;
950 constexpr auto kfold = (
BK1Number * N0 *
sizeof(BDataType) > 128)
952 : 128 / (
BK1Number * N0 *
sizeof(BDataType));
953 constexpr auto KThreadReadPerm =
954 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
955 ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
959 constexpr auto npair = (
BK1Number * NPerXdl *
sizeof(BDataType) > 128)
961 : ((128 / (
BK1Number * NPerXdl *
sizeof(BDataType))) > N0
963 : 128 / (
BK1Number * NPerXdl *
sizeof(BDataType)));
969 Number<kfold * N0 / npair>{},
988 b_lds_block_desc_permuted,
1010 b_lds_block_desc_unmerged,
1013 Number<KThreadWrite / kfold / KThreadReadPerm>{},
1022 return b_lds_block_desc_bk0_n_bk1;
1028 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1029 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1031 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1038 return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
1059 ABlockTransferSrcScalarPerVector,
1060 BBlockTransferSrcScalarPerVector,
1081 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1084 b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
1087 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1090 constexpr auto c_block_size =
1091 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
1093 if constexpr(IsInputGemm)
1095 return math::max((a_block_space_size_aligned *
sizeof(ADataType) +
1096 b_block_space_size_aligned *
sizeof(BDataType)) *
1098 c_block_size *
sizeof(CShuffleDataType));
1102 return math::max((a_block_space_size_aligned *
sizeof(ADataType) +
1103 b_block_space_size_aligned *
sizeof(BDataType)),
1104 c_block_size *
sizeof(CShuffleDataType));
1113 static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
1114 (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
1115 "Invalid tuning param!");
1117 static_assert(KPerBlock % (ScaleBlockSize /
BPackedSize) == 0,
1118 "KPerBlock should be multiple of ScaleBlockSize");
1126 if(!(karg.M % MPerBlock == 0))
1130 std::cout <<
"Arg M value is not a multiple of MPerBlock! M: " << karg.M <<
" "
1131 << __FILE__ <<
":" << __LINE__ <<
", in function: " << __func__
1144 if(!(karg.N % NPerBlock == 0))
1148 std::cout <<
"Arg N value is not a multiple of NPerBlock! N: " << karg.N <<
" "
1149 << __FILE__ <<
":" << __LINE__ <<
", in function: " << __func__
1161 auto K_t = karg.KBatch * KPerBlock;
1162 if(!(karg.K % K_t == 0))
1166 std::cout <<
"Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
1167 << karg.K <<
" " << __FILE__ <<
":" << __LINE__
1168 <<
", in function: " << __func__ << std::endl;
1176 auto K_t = karg.KBatch * KReadVec;
1178 if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
1186 if(karg.K % ABlockTransferSrcScalarPerVector != 0)
1190 std::cout <<
"Arg K (" << karg.K
1191 <<
") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1192 << ABlockTransferSrcScalarPerVector <<
" )! " << __FILE__ <<
":"
1193 << __LINE__ <<
", in function: " << __func__ << std::endl;
1200 if(karg.M % ABlockTransferSrcScalarPerVector != 0)
1204 std::cout <<
"Arg M (" << karg.M
1205 <<
") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1206 << ABlockTransferSrcScalarPerVector <<
" )! " << __FILE__ <<
":"
1207 << __LINE__ <<
", in function: " << __func__ << std::endl;
1215 if(karg.N % BBlockTransferSrcScalarPerVector != 0)
1219 std::cout <<
"Arg N (" << karg.N
1220 <<
") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1221 << BBlockTransferSrcScalarPerVector <<
" )! " << __FILE__ <<
":"
1222 << __LINE__ <<
", in function: " << __func__ << std::endl;
1229 if(karg.K % BBlockTransferSrcScalarPerVector != 0)
1233 std::cout <<
"Arg K (" << karg.K
1234 <<
") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1235 << BBlockTransferSrcScalarPerVector <<
" )! " << __FILE__ <<
":"
1236 << __LINE__ <<
", in function: " << __func__ << std::endl;
1248 std::cout <<
"Arg N (" << karg.N
1249 <<
") value is not a multiple of "
1250 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1252 << __FILE__ <<
":" << __LINE__ <<
", in function: " << __func__
1264 std::cout <<
"Arg M (" << karg.M
1265 <<
") value is not a multiple of "
1266 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1268 << __FILE__ <<
":" << __LINE__ <<
", in function: " << __func__
1278 const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
1280 if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
1291 const index_t num_loop = K / KPerBlock;
1293 return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
1298 const index_t num_loop = K / KPerBlock;
1300 return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
1303 template <
typename CGr
idDesc>
1305 const CGridDesc& c_grid_desc_m_n,
index_t MBlock,
index_t NBlock)
1314 return c_grid_desc_mblock_mperblock_nblock_nperblock;
1326 "A scale pack data type too large!");
1328 "B scale pack data type too large!");
1330 template <
bool HasMainKBlockLoop,
1334 const index_t* p_sorted_expert_ids,
1335 const index_t* p_max_token_id,
1336 const ADataType* p_a_grid,
1337 const AScaleDataType* p_a_scale_grid,
1338 const BDataType* p_b_grid,
1339 const BScaleDataType* p_b_scale_grid,
1341 CDataType* p_c_grid,
1343 const Problem& problem,
1344 AElementwiseOperation a_element_op,
1345 BElementwiseOperation b_element_op,
1346 CElementwiseOperation c_element_op)
1350 IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
1357 problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
1359 IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
1377 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1379 c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1381 const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
1382 const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
1383 if(expert_block_id * MPerBlock >= max_token_id)
1386 __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
1388 const auto block_mn = [&]() -> std::pair<int, int> {
1389 if constexpr(NSwizzle)
1391 const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
1392 const index_t prefix_block = ecnt_prefix * problem.NBlock;
1393 const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
1394 const index_t expert_swizzle =
1395 ecnt > 0 ? ecnt : 1;
1396 const index_t bid_new = blockIdx.x - prefix_block;
1397 const index_t nid = __builtin_amdgcn_readfirstlane(
1398 bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
1400 __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
1405 return {blockIdx.x, blockIdx.y};
1409 const index_t block_n_id = block_mn.first;
1410 const index_t block_m_id = block_mn.second;
1412 __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
1415 constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I1);
1416 constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I0);
1417 constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I2);
1418 constexpr auto AKThreads = AK0Threads * AK1Threads;
1419 constexpr auto AMRepeats = MPerBlock / AMThreads;
1420 const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
1422 if(token_pos >= max_token_id || token0 >= problem.NumTokens)
1426 const index_t fused_token = p_sorted_token_ids[token_pos + m0];
1427 index_t token_offset = fused_token & 0xffffff;
1428 if constexpr(!IsInputGemm)
1430 token_offset = token_offset * problem.TopK + (fused_token >> 24);
1432 gather_offsets(m0) =
static_cast<IndexType
>(token_offset) * problem.K;
1436 __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
1437 const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
1438 problem.N * (IsInputGemm ? 2 : 1) *
1442 const index_t n_block_data_idx_on_grid =
1443 __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
1447 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1449 p_b_grid + expert_id * expert_stride, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1453 p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
1455 p_b_scale_grid + (expert_id * expert_scale_stride) /
sizeof(BScaleDataType),
1456 b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1470 AElementwiseOperation,
1474 ABlockTransferThreadClusterLengths_AK0_M_AK1,
1475 ABlockTransferThreadClusterArrangeOrder,
1478 decltype(a_grid_desc_ak0_m_ak1),
1479 decltype(a_block_desc_ak0_m_ak1),
1480 ABlockTransferSrcAccessOrder,
1482 ABlockTransferSrcVectorDim,
1484 ABlockTransferSrcScalarPerVector,
1485 ABlockTransferDstScalarPerVector_AK1,
1488 AThreadTransferSrcResetCoordinateAfterRun,
1492 BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
1495 a_block_desc_ak0_m_ak1,
1501 auto b_blockwise_copy =
1503 BElementwiseOperation,
1507 BBlockTransferThreadClusterLengths_BK0_N_BK1,
1508 BBlockTransferThreadClusterArrangeOrder,
1511 decltype(b_grid_desc_bk0_n_bk1),
1512 decltype(b_block_desc_bk0_n_bk1),
1513 BBlockTransferSrcAccessOrder,
1515 BBlockTransferSrcVectorDim,
1517 BBlockTransferSrcScalarPerVector,
1518 BBlockTransferDstScalarPerVector_BK1,
1521 BThreadTransferSrcResetCoordinateAfterRun,
1523 BlockwiseGemmPipe::GlobalBufferNum>(
1524 b_grid_desc_bk0_n_bk1,
1527 b_block_desc_bk0_n_bk1,
1533 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1537 static_cast<ADataType*
>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1540 reinterpret_cast<BDataType*
>(
static_cast<char*
>(p_shared) +
1541 a_block_space_size_aligned *
sizeof(ADataType)),
1542 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1548 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1550 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1551 decltype(c_thread_buf) c_thread_buf_up;
1555 c_thread_buf.num_of_v_,
1556 c_thread_buf.s_per_v,
1560 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1561 (a_grid_desc_ak0_m_ak1.GetLength(
I0) * a_grid_desc_ak0_m_ak1.GetLength(
I2)) /
1565 const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
1566 const auto waveId_m = wave_idx[
I0];
1567 const auto waveId_n = wave_idx[
I1];
1569 auto thread_offset_shuffled =
1572 auto a_thread_offset_m = waveId_m;
1577 decltype(a_scale_grid_desc_am_ak),
1578 decltype(BlockwiseGemmPipe::a_scale_thread_desc),
1584 true>(a_scale_grid_desc_am_ak,
1590 auto b_thread_offset_n = waveId_n;
1595 decltype(b_scale_grid_desc_bn_ak),
1596 decltype(BlockwiseGemmPipe::b_scale_thread_desc),
1602 true>(b_scale_grid_desc_bn_ak,
1607 if constexpr(IsInputGemm)
1610 b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
1612 reinterpret_cast<BDataType*
>(
static_cast<char*
>(p_shared) +
1613 a_block_space_size_aligned *
sizeof(ADataType) +
1614 b_block_space_size_aligned *
sizeof(BDataType)),
1615 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1617 const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
1619 p_b_grid_up + expert_id * expert_stride,
1620 b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1622 auto b_blockwise_copy_up =
1624 BElementwiseOperation,
1628 BBlockTransferThreadClusterLengths_BK0_N_BK1,
1629 BBlockTransferThreadClusterArrangeOrder,
1632 decltype(b_grid_desc_bk0_n_bk1),
1633 decltype(b_block_desc_bk0_n_bk1),
1634 BBlockTransferSrcAccessOrder,
1636 BBlockTransferSrcVectorDim,
1638 BBlockTransferSrcScalarPerVector,
1639 BBlockTransferDstScalarPerVector_BK1,
1642 BThreadTransferSrcResetCoordinateAfterRun,
1644 BlockwiseGemmPipe::GlobalBufferNum>(
1645 b_grid_desc_bk0_n_bk1,
1648 b_block_desc_bk0_n_bk1,
1652 const BScaleDataType* p_b_scale_grid_up =
1653 p_b_scale_grid + expert_scale_stride / 2 /
sizeof(BScaleDataType);
1655 p_b_scale_grid_up + expert_id * expert_scale_stride /
sizeof(BScaleDataType),
1656 b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1661 decltype(b_scale_grid_desc_bn_ak),
1662 decltype(BlockwiseGemmPipe::b_scale_thread_desc),
1669 b_scale_grid_desc_bn_ak,
1676 a_grid_desc_ak0_m_ak1,
1677 a_block_desc_ak0_m_ak1,
1681 a_block_slice_copy_step,
1683 b_grid_desc_bk0_n_bk1,
1684 b_block_desc_bk0_n_bk1,
1686 b_blockwise_copy_up,
1691 b_block_slice_copy_step,
1696 a_scale_grid_desc_am_ak,
1697 a_scale_thread_copy,
1700 b_scale_grid_desc_bn_ak,
1701 b_scale_thread_copy,
1702 b_scale_thread_copy_up,
1704 b_scale_grid_buf_up,
1705 num_k_block_main_loop);
1710 a_grid_desc_ak0_m_ak1,
1711 a_block_desc_ak0_m_ak1,
1715 a_block_slice_copy_step,
1716 b_grid_desc_bk0_n_bk1,
1717 b_block_desc_bk0_n_bk1,
1721 b_block_slice_copy_step,
1723 a_scale_grid_desc_am_ak,
1724 a_scale_thread_copy,
1726 b_scale_grid_desc_bn_ak,
1727 b_scale_thread_copy,
1729 num_k_block_main_loop);
1734 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1735 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1737 static_assert(CShuffleMXdlPerWavePerShuffle %
MXdlPack == 0 &&
1738 CShuffleNXdlPerWavePerShuffle %
NXdlPack == 0,
1741 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1742 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1745 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
1746 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
1750 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
1751 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
1753 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I0);
1754 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I1);
1755 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I2);
1756 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I3);
1757 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I4);
1758 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I5);
1759 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I6);
1760 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I7);
1761 constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I8);
1762 constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I9);
1765 static_assert(M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock);
1766 static_assert(M5 == 4);
1776 const index_t m_pos = block_m_id * MPerBlock +
1777 m0 * M2 * M1 * M3 * M4 * M5 +
1778 m1 * M2 * M3 * M4 * M5 +
1779 imxdl * M3 * M4 * M5 + m3 * M4 * M5 + m4 * M5;
1781 if constexpr(MulRoutedWeight)
1785 p_ds_grid[
I2] + m_pos);
1789 blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
1790 make_tuple(m0, n0, imxdl, inxdl, m3 * M5 + m5));
1793 if constexpr(IsInputGemm)
1795 if constexpr(ActivationOperation ==
1796 Activation::silu_and_mul)
1798 float gate = c_thread_buf[cidx];
1799 float up = c_thread_buf_up[cidx];
1800 if constexpr(MulRoutedWeight)
1802 gate = gate * topk_weights.AsType<
float>()[m5];
1803 up = up * topk_weights.AsType<
float>()[m5];
1806 c_thread_buf_fp32(cidx) = gate * up;
1808 else if(ActivationOperation == Activation::gelu_and_mul)
1810 float gate = c_thread_buf[cidx];
1811 float up = c_thread_buf_up[cidx];
1812 if constexpr(MulRoutedWeight)
1814 gate = gate * topk_weights.AsType<
float>()[m5];
1815 up = up * topk_weights.AsType<
float>()[m5];
1818 c_thread_buf_fp32(cidx) = gate * up;
1833 c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
1834 if constexpr(MulRoutedWeight)
1836 c_thread_buf_fp32(cidx) =
1837 topk_weights.AsType<
float>()[m5] *
1838 c_thread_buf_fp32[cidx];
1848 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1852 static_cast<CShuffleDataType*
>(p_shared),
1853 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1856 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1882 const auto c_thread_mtx_on_block =
1883 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(
I0,
I0,
I0,
I0);
1885 const index_t m_thread_data_on_block = c_thread_mtx_on_block[
I0];
1886 const index_t n_thread_data_on_block = c_thread_mtx_on_block[
I1];
1888 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1894 const auto m_thread_data_on_block_idx =
1895 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1898 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
1904 const auto n_thread_data_on_block_idx =
1905 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1912 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1913 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1916 CShuffleNXdlPerWavePerShuffle /
NXdlPack,
1925 Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
1930 true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1933 m_thread_data_on_block_idx[
I1],
1934 n_thread_data_on_block_idx[
I1],
1935 m_thread_data_on_block_idx[
I2],
1936 n_thread_data_on_block_idx[
I2],
1937 m_thread_data_on_block_idx[
I3],
1938 m_thread_data_on_block_idx[
I4],
1939 m_thread_data_on_block_idx[
I5],
1940 n_thread_data_on_block_idx[
I3]),
1943 using EDataType = CDataType;
1946 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
1948 const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
1950 ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
1955 p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
1961 tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1963 {
return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
1968 tie(c_shuffle_block_buf),
1970 {
return ds_grid_buf[i]; },
1974 const auto idx_c_ds_block_begin =
1984 const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
1985 c_grid_desc_mblock_mperblock_nblock_nperblock;
1987 using CDEBlockTransferCluster =
1988 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
1989 const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
1990 constexpr index_t scatter_weight_idx = 3;
1995 decltype(c_ds_desc_refs),
1996 decltype(
tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
1997 CElementwiseOperation,
2002 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2004 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
2005 CDEBlockTransferCluster,
2011 CDEShuffleBlockTransferScalarPerVectors,
2023 idx_c_ds_block_begin,
2024 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2029 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2031 constexpr auto sfc_c_vgpr =
2042 Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
2044 CShuffleNXdlPerWavePerShuffle /
NXdlPack,
2054 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2057 constexpr auto sfc_cde_block =
2061 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2063 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2065 static_assert(num_access == sfc_cde_block.GetNumOfAccess(),
"wrong!");
2066 constexpr auto EMThreads =
2067 CDEBlockTransferCluster{}.
At(
I0) * CDEBlockTransferCluster{}.At(
I1);
2068 constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
2069 constexpr auto ENThreads =
2070 CDEBlockTransferCluster{}.At(
I2) * CDEBlockTransferCluster{}.At(
I3);
2075 auto dstidx = sfc_cde_block.GetIndex(access_id);
2077 block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(
I1);
2079 const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
2080 IndexType token_offset = fused_token & 0xffffff;
2081 if constexpr(IsInputGemm)
2083 token_offset = token_offset * problem.TopK + (fused_token >> 24);
2085 scatter_offsets(m0) =
static_cast<IndexType
>(token_offset) * problem.N;
2091 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2092 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2094 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2095 c_shuffle_block_buf);
2101 cde_block_copy_lds_and_global.Run(
2104 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2108 if constexpr(access_id < num_access - 1)
2110 constexpr auto cde_lds_and_global_step =
2111 sfc_cde_block.GetForwardStep(access_id);
2115 cde_block_copy_lds_and_global.MoveSrcSliceWindow(
2116 c_ds_desc_refs, i +
I1, cde_lds_and_global_step);
2120 cde_block_copy_lds_and_global.MoveDstSliceWindow(
2121 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2123 cde_lds_and_global_step);
2130 template <
bool HasMainKBlockLoop,
2133 __device__
static void Run_2Lds(
const index_t* p_sorted_token_ids,
2134 const index_t* p_sorted_expert_ids,
2135 const index_t* p_max_token_id,
2136 const ADataType* p_a_grid,
2137 const AScaleDataType* p_a_scale_grid,
2138 const BDataType* p_b_grid,
2139 const BScaleDataType* p_b_scale_grid,
2141 CDataType* p_c_grid,
2144 const Problem& problem,
2145 AElementwiseOperation a_element_op,
2146 BElementwiseOperation b_element_op,
2147 CElementwiseOperation c_element_op)
2151 IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
2158 problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
2160 IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
2178 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
2180 c_grid_desc_m_n, problem.MBlock, problem.NBlock);
2181 const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
2183 const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
2184 if(expert_block_id * MPerBlock >= max_token_id)
2187 __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
2188 const auto block_mn = [&]() -> std::pair<int, int> {
2189 if constexpr(NSwizzle)
2191 const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
2192 const index_t prefix_block = ecnt_prefix * problem.NBlock;
2193 const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
2194 const index_t expert_swizzle =
2195 ecnt > 0 ? ecnt : 1;
2196 const index_t bid_new = blockIdx.x - prefix_block;
2197 const index_t nid = __builtin_amdgcn_readfirstlane(
2198 bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
2200 __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
2205 return {blockIdx.x, blockIdx.y};
2209 const index_t block_n_id = block_mn.first;
2210 const index_t block_m_id = block_mn.second;
2212 __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
2215 constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I1);
2216 constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I0);
2217 constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I2);
2218 constexpr auto AKThreads = AK0Threads * AK1Threads;
2219 constexpr auto AMRepeats = MPerBlock / AMThreads;
2220 const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
2222 if(token_pos >= max_token_id || token0 >= problem.NumTokens)
2225 static_for<0, AMRepeats, 1>{}([&](
auto m0) {
2226 const index_t fused_token = p_sorted_token_ids[token_pos + m0];
2227 index_t token_offset = fused_token & 0xffffff;
2228 if constexpr(!IsInputGemm)
2230 token_offset = token_offset * problem.TopK + (fused_token >> 24);
2232 gather_offsets(m0) =
static_cast<IndexType
>(token_offset) * problem.K;
2236 __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
2237 const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
2241 const index_t n_block_data_idx_on_grid =
2242 __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
2245 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
2248 p_b_grid + expert_id * expert_stride, b_grid_desc_bpreshuffled.GetElementSpaceSize());
2251 p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
2253 p_b_scale_grid + (expert_id * expert_scale_stride) /
sizeof(BScaleDataType),
2254 b_scale_grid_desc_bn_ak.GetElementSpaceSize());
2263 auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
2265 AElementwiseOperation,
2266 ck::tensor_operation::element_wise::PassThrough,
2268 Sequence<AK0Number, MPerBlock, AK1Number>,
2269 ABlockTransferThreadClusterLengths_AK0_M_AK1,
2270 ABlockTransferThreadClusterArrangeOrder,
2273 decltype(a_grid_desc_ak0_m_ak1),
2274 decltype(a_block_desc_ak0_m_ak1),
2275 ABlockTransferSrcAccessOrder,
2277 ABlockTransferSrcVectorDim,
2279 ABlockTransferSrcScalarPerVector,
2280 ABlockTransferDstScalarPerVector_AK1,
2283 AThreadTransferSrcResetCoordinateAfterRun,
2287 BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
2290 a_block_desc_ak0_m_ak1,
2292 ck::tensor_operation::element_wise::PassThrough{},
2298 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2300 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2301 auto b_block_bufs =
make_tuple(b_block_buf_ping, b_block_buf_pong);
2303 auto b_blockwise_copy =
2304 ThreadwiseTensorSliceTransfer_v2<BDataType,
2306 decltype(b_grid_desc_bpreshuffled),
2307 decltype(b_block_desc_bk0_n_bk1),
2313 Sequence<1, 2, 0, 3, 4>,
2315 BBlockTransferSrcScalarPerVector,
2316 BThreadTransferSrcResetCoordinateAfterRun,
2318 b_grid_desc_bpreshuffled,
2328 static_cast<ADataType*
>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2330 static_cast<ADataType*
>(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2331 auto a_block_bufs =
make_tuple(a_block_buf_ping, a_block_buf_pong);
2334 constexpr auto b_block_slice_copy_step =
make_multi_index(0, 0, 0, KRepeat, 0);
2337 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
2339 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
2340 decltype(c_thread_buf) c_thread_buf_up;
2344 c_thread_buf.num_of_v_,
2345 c_thread_buf.s_per_v,
2349 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
2350 (a_grid_desc_ak0_m_ak1.GetLength(
I0) * a_grid_desc_ak0_m_ak1.GetLength(
I2)) /
2354 const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
2355 const auto waveId_m = wave_idx[
I0];
2356 const auto waveId_n = wave_idx[
I1];
2358 auto thread_offset_shuffled =
2361 auto a_thread_offset_m = waveId_m;
2364 const index_t token_scale_pos = block_m_id * MPerBlock;
2365 if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
2368 auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
2371 decltype(a_scale_grid_desc_am_ak),
2372 decltype(BlockwiseGemmPipe::a_scale_thread_desc),
2378 true>(a_scale_grid_desc_am_ak,
2384 auto b_thread_offset_n = waveId_n;
2386 auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
2389 decltype(b_scale_grid_desc_bn_ak),
2390 decltype(BlockwiseGemmPipe::b_scale_thread_desc),
2396 true>(b_scale_grid_desc_bn_ak,
2401 if constexpr(IsInputGemm)
2403 const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 /
BPackedSize;
2405 p_b_grid_up + expert_id * expert_stride /
BPackedSize,
2406 b_grid_desc_bpreshuffled.GetElementSpaceSize());
2407 auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
2410 decltype(b_grid_desc_bpreshuffled),
2411 decltype(b_block_desc_bk0_n_bk1),
2413 Sequence<1, 2, 0, 3>,
2415 BBlockTransferSrcScalarPerVector,
2416 BThreadTransferSrcResetCoordinateAfterRun,
2417 true>(b_grid_desc_bpreshuffled,
2422 const BScaleDataType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2;
2424 p_b_scale_grid_up + expert_id * expert_scale_stride,
2425 b_scale_grid_desc_bn_ak.GetElementSpaceSize());
2426 auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
2429 decltype(b_scale_grid_desc_bn_ak),
2430 decltype(BlockwiseGemmPipe::b_scale_thread_desc),
2437 b_scale_grid_desc_bn_ak,
2443 a_grid_desc_ak0_m_ak1,
2444 a_block_desc_ak0_m_ak1,
2448 a_block_slice_copy_step,
2449 b_grid_desc_bpreshuffled,
2450 b_block_desc_bk0_n_bk1,
2452 b_blockwise_copy_up,
2456 b_block_slice_copy_step,
2459 a_scale_grid_desc_am_ak,
2460 a_scale_thread_copy,
2462 b_scale_grid_desc_bn_ak,
2463 b_scale_thread_copy,
2464 b_scale_thread_copy_up,
2466 b_scale_grid_buf_up,
2467 num_k_block_main_loop);
2472 a_grid_desc_ak0_m_ak1,
2473 a_block_desc_ak0_m_ak1,
2477 a_block_slice_copy_step,
2478 b_grid_desc_bpreshuffled,
2479 b_block_desc_bk0_n_bk1,
2483 b_block_slice_copy_step,
2485 a_scale_grid_desc_am_ak,
2486 a_scale_thread_copy,
2488 b_scale_grid_desc_bn_ak,
2489 b_scale_thread_copy,
2491 num_k_block_main_loop);
2496 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
2497 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
2501 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
2502 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
2506 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
2507 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
2509 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I0);
2510 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I1);
2511 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I2);
2512 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I3);
2513 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I4);
2514 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I5);
2515 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I6);
2516 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I7);
2520 static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
2521 static_assert(M4 == 4);
2525 vector_type<float, 4> topk_weights;
2526 static_for<0, NXdlPerWave, 1>{}([&](
auto n0) {
2527 static_for<0, MXdlPerWave, 1>{}([&](
auto m0) {
2528 static_for<0, M2, 1>{}([&](
auto m2) {
2529 const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
2530 m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
2531 if constexpr(MulRoutedWeight)
2534 p_ds_grid[
I2] + m_pos);
2536 static_for<0, M4, 1>{}([&](
auto m4) {
2538 blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
2546 if constexpr(IsInputGemm)
2548 if constexpr(ActivationOperation == Activation::silu_and_mul)
2550 float gate = c_thread_buf[cidx];
2551 float up = c_thread_buf_up[cidx];
2552 if constexpr(MulRoutedWeight)
2554 gate = gate * topk_weights.AsType<
float>()[m4];
2555 up = up * topk_weights.AsType<
float>()[m4];
2557 tensor_operation::element_wise::Silu{}(gate, gate);
2558 c_thread_buf_fp32(cidx) = gate * up;
2560 else if(ActivationOperation == Activation::gelu_and_mul)
2562 float gate = c_thread_buf[cidx];
2563 float up = c_thread_buf_up[cidx];
2564 if constexpr(MulRoutedWeight)
2566 gate = gate * topk_weights.AsType<
float>()[m4];
2567 up = up * topk_weights.AsType<
float>()[m4];
2569 tensor_operation::element_wise::Gelu{}(gate, gate);
2570 c_thread_buf_fp32(cidx) = gate * up;
2575 c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
2576 if constexpr(MulRoutedWeight)
2578 c_thread_buf_fp32(cidx) =
2579 topk_weights.AsType<
float>()[m4] * c_thread_buf_fp32[cidx];
2587 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
2591 static_cast<CShuffleDataType*
>(p_shared),
2592 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2595 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2610 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
2612 Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
2616 const auto c_thread_mtx_on_block =
2617 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(
I0,
I0,
I0,
I0);
2619 const index_t m_thread_data_on_block = c_thread_mtx_on_block[
I0];
2620 const index_t n_thread_data_on_block = c_thread_mtx_on_block[
I1];
2622 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
2628 const auto m_thread_data_on_block_idx =
2629 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
2632 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
2638 const auto n_thread_data_on_block_idx =
2639 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
2643 auto c_thread_copy_vgpr_to_lds =
2644 ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
2646 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2647 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2648 ck::tensor_operation::element_wise::PassThrough,
2649 Sequence<CShuffleMXdlPerWavePerShuffle,
2650 CShuffleNXdlPerWavePerShuffle,
2657 Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
2663 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2666 m_thread_data_on_block_idx[
I1],
2667 n_thread_data_on_block_idx[
I1],
2668 m_thread_data_on_block_idx[
I2],
2669 m_thread_data_on_block_idx[
I3],
2670 m_thread_data_on_block_idx[
I4],
2671 n_thread_data_on_block_idx[
I2]),
2672 ck::tensor_operation::element_wise::PassThrough{}};
2674 using EDataType = CDataType;
2677 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
2679 const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
2681 ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
2686 p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
2692 tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
2694 {
return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
2699 tie(c_shuffle_block_buf),
2701 {
return ds_grid_buf[i]; },
2705 const auto idx_c_ds_block_begin =
2715 const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
2716 c_grid_desc_mblock_mperblock_nblock_nperblock;
2718 using CDEBlockTransferCluster =
2719 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
2720 const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
2721 constexpr index_t scatter_weight_idx = 3;
2722 auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
2726 decltype(c_ds_desc_refs),
2727 decltype(
tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
2728 CElementwiseOperation,
2729 Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>,
2733 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2735 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
2736 CDEBlockTransferCluster,
2737 Sequence<0, 1, 2, 3>,
2738 Sequence<0, 1, 2, 3>,
2739 Sequence<0, 1, 2, 3>,
2742 CDEShuffleBlockTransferScalarPerVectors,
2754 idx_c_ds_block_begin,
2755 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2760 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2761 constexpr auto sfc_c_vgpr =
2762 SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
2763 Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
2764 Sequence<CShuffleMXdlPerWavePerShuffle,
2765 CShuffleNXdlPerWavePerShuffle,
2773 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2776 constexpr auto sfc_cde_block =
2777 SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
2778 Sequence<0, 2, 1, 3>,
2780 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2782 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2784 static_assert(num_access == sfc_cde_block.GetNumOfAccess(),
"wrong!");
2785 constexpr auto EMThreads =
2786 CDEBlockTransferCluster{}.At(
I0) * CDEBlockTransferCluster{}.At(
I1);
2787 constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
2788 constexpr auto ENThreads =
2789 CDEBlockTransferCluster{}.At(
I2) * CDEBlockTransferCluster{}.At(
I3);
2790 static_for<0, num_access, 1>{}([&](
auto access_id) {
2794 auto dstidx = sfc_cde_block.GetIndex(access_id);
2796 block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(
I1);
2797 static_for<0, EMRepeats, 1>{}([&](
auto m0) {
2798 const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
2799 IndexType token_offset = fused_token & 0xffffff;
2800 if constexpr(IsInputGemm)
2802 token_offset = token_offset * problem.TopK + (fused_token >> 24);
2804 scatter_offsets(m0) =
static_cast<IndexType
>(token_offset) * problem.N;
2810 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2811 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2813 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2814 c_shuffle_block_buf);
2820 cde_block_copy_lds_and_global.Run(
2823 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2827 if constexpr(access_id < num_access - 1)
2829 constexpr auto cde_lds_and_global_step =
2830 sfc_cde_block.GetForwardStep(access_id);
2833 static_for<0, NumDTensor, 1>{}([&](
auto i) {
2834 cde_block_copy_lds_and_global.MoveSrcSliceWindow(
2835 c_ds_desc_refs, i +
I1, cde_lds_and_global_step);
2839 cde_block_copy_lds_and_global.MoveDstSliceWindow(
2840 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2842 cde_lds_and_global_step);