gridwise_gemm_pipeline_v3.hpp Source File

gridwise_gemm_pipeline_v3.hpp Source File

Composable Kernel: gridwise_gemm_pipeline_v3.hpp Source File
gridwise_gemm_pipeline_v3.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8namespace ck {
9
11{
// Applicability check for this pipeline.
// The index_t argument is ignored (presumably the main-loop count later passed
// to Run -- confirm against callers); every value is currently accepted.
__host__ __device__ static constexpr bool IsSupported(index_t)
{
    // TODO: improve applicability -- no argument-dependent restriction yet.
    constexpr bool always_supported = true;
    return always_supported;
}
17
// Software-pipelined K loop for a block-level GEMM (pipeline v3).
//
// One pair of LDS buffers (a_block_buf / b_block_buf) is both read by
// blockwise_gemm.Run and overwritten by the blockwise copies' RunWrite. The
// global read of the next K tile is issued before the GEMM on the tile that is
// currently resident in LDS:
//
//   prologue : global read tile 0 -> LDS write tile 0
//   loop     : global read tile i+1 | GEMM on tile i (LDS) | LDS write tile i+1
//   tail     : GEMM on the last tile
//
// Because the same LDS storage is shared by writers and readers across the
// whole block, every hand-off is separated by block_sync_lds() (declared in
// synchronization.hpp), used here as a block-wide barrier for LDS accesses:
//   * RAW: an LDS write must be complete before any thread's GEMM reads it.
//   * WAR: every thread's GEMM reads must finish before LDS is overwritten.
//
// Parameters:
//   a_grid_desc / b_grid_desc            descriptors of the global source tensors
//   a_block_desc / b_block_desc          descriptors of the LDS destination tiles
//   a_blockwise_copy / b_blockwise_copy  copy objects: RunRead fetches a tile from
//                                        the grid buffer (staged inside the copy
//                                        object, presumably in registers), and
//                                        RunWrite stores it into the block buffer
//   a_grid_buf / b_grid_buf              global source buffers
//   a_block_buf / b_block_buf            LDS buffers consumed by blockwise_gemm
//   a_block_copy_step / b_block_copy_step  source-window advance per K tile
//   blockwise_gemm                       invoked once per K tile on the same
//                                        c_thread_buf
//   c_thread_buf                         C accumulator; cleared once here, then
//                                        accumulated into by every GEMM call
//   num_loop                             number of K tiles to process
//
// NOTE(review): num_loop is assumed to be >= 1. With num_loop <= 0 the prologue
// and the tail still read and multiply one tile. Confirm that callers guarantee
// this.
template <typename AGridDesc,
          typename ABlockDesc,
          typename ABlockTransfer,
          typename AGridBuffer,
          typename ABlockBuffer,
          typename ABlockTransferStep,
          typename BGridDesc,
          typename BBlockDesc,
          typename BBlockTransfer,
          typename BGridBuffer,
          typename BBlockBuffer,
          typename BBlockTransferStep,
          typename BlockwiseGemm,
          typename CThreadBuffer>
__device__ static void Run(const AGridDesc& a_grid_desc,
                           const ABlockDesc& a_block_desc,
                           ABlockTransfer& a_blockwise_copy,
                           const AGridBuffer& a_grid_buf,
                           ABlockBuffer& a_block_buf,
                           const ABlockTransferStep& a_block_copy_step,
                           const BGridDesc& b_grid_desc,
                           const BBlockDesc& b_block_desc,
                           BBlockTransfer& b_blockwise_copy,
                           const BGridBuffer& b_grid_buf,
                           BBlockBuffer& b_block_buf,
                           const BBlockTransferStep& b_block_copy_step,
                           const BlockwiseGemm& blockwise_gemm,
                           CThreadBuffer& c_thread_buf,
                           index_t num_loop)
{
    // global read 0
    a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
    b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

    a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
    b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

    // Initialize C
    c_thread_buf.Clear();

    // LDS write 0
    a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
    b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

    num_loop--;

    while(num_loop > 0)
    {
        // Issue the next global read before the barrier so that it is already
        // in flight while the block waits and computes. RunRead only touches
        // the grid buffer, not LDS.
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);

        // RAW: the previous LDS write (prologue or last iteration) must be
        // complete for all threads before the GEMM reads LDS.
        block_sync_lds();

        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);

        // WAR: every thread must finish reading LDS in the GEMM above before
        // RunWrite below overwrites the same buffers.
        block_sync_lds();

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

        num_loop--;
    }

    // tail
    {
        // RAW: the last LDS write must be complete before the final GEMM.
        block_sync_lds();

        blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
    }
}
87};
88
89} // namespace ck
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
__device__ void block_sync_lds()
Definition synchronization.hpp:16
Definition gridwise_gemm_pipeline_v3.hpp:11
static __device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, const BlockwiseGemm &blockwise_gemm, CThreadBuffer &c_thread_buf, index_t num_loop)
Definition gridwise_gemm_pipeline_v3.hpp:32
__host__ __device__ static constexpr bool IsSupported(index_t)
Definition gridwise_gemm_pipeline_v3.hpp:12