blockwise_tensor_slice_transfer_v5r1.hpp Source File

blockwise_tensor_slice_transfer_v5r1.hpp Source File#

Composable Kernel: blockwise_tensor_slice_transfer_v5r1.hpp Source File
blockwise_tensor_slice_transfer_v5r1.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
11
12namespace ck {
13
14// this version does the following things to avoid scratch memory issues
15// 1. Use StaticallyIndexedArray instead of C array for thread buffer
16// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
17// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
18template <index_t BlockSize,
20 typename BlockSliceLengths,
21 typename ThreadSliceLengths,
22 typename ThreadClusterLengths,
23 typename ThreadClusterArrangeOrder,
24 typename SrcData,
25 typename DstData,
26 typename SrcDesc,
27 typename DstDesc,
28 typename SrcDimAccessOrder,
29 typename DstDimAccessOrder,
30 typename SrcVectorTensorLengths,
31 typename DstVectorTensorLengths,
32 typename SrcVectorTensorContiguousDimOrder,
33 typename DstVectorTensorContiguousDimOrder,
34 bool ThreadTransferSrcResetCoordinateAfterRun,
35 bool ThreadTransferDstResetCoordinateAfterRun>
37{
39
41
42 __device__ constexpr BlockwiseTensorSliceTransfer_v5r1(const SrcDesc& src_desc,
43 const Index& src_block_slice_origin,
44 const DstDesc& dst_desc,
45 const Index& dst_block_slice_origin)
46 : threadwise_transfer_(
47 src_desc, make_zero_multi_index<nDim>(), dst_desc, make_zero_multi_index<nDim>())
48
49 {
52 nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() &&
53 nDim == ThreadClusterLengths::Size() &&
54 nDim == ThreadClusterArrangeOrder::Size() &&
55 nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
56 "wrong! nDim not consistent");
57
58 static_assert(
59 is_same<BlockSliceLengths, decltype(ThreadSliceLengths{} * ThreadClusterLengths{})>{},
60 "wrong! threads should be mapped to cover entire slicing window");
61
62 static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
63 "wrong! BlockSize too small");
64
65 if(BlockSize == thread_cluster_desc_.GetElementSize() or
66 get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
67 {
68 const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
69 make_multi_index(get_thread_local_1d_id()));
70
71 const auto thread_data_idx_begin = thread_cluster_idx * ThreadSliceLengths{};
72
73 threadwise_transfer_.SetSrcSliceOrigin(src_desc,
74 src_block_slice_origin + thread_data_idx_begin);
75 threadwise_transfer_.SetDstSliceOrigin(dst_desc,
76 dst_block_slice_origin + thread_data_idx_begin);
77 }
78 }
79
80 template <typename SrcBuffer>
81 __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf)
82 {
83 if(BlockSize == thread_cluster_desc_.GetElementSize() or
84 get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
85 {
86 threadwise_transfer_.RunRead(src_desc, src_buf);
87 }
88 }
89
90 template <typename DstBuffer>
91 __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf)
92 {
93 if(BlockSize == thread_cluster_desc_.GetElementSize() or
94 get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
95 {
96 threadwise_transfer_.RunWrite(dst_desc, dst_buf);
97 }
98 }
99
100 __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
101 {
102 if(BlockSize == thread_cluster_desc_.GetElementSize() or
103 get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
104 {
105 threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
106 }
107 }
108
109 // SrcMoveSliceWindowStepHack to control index calculation move slice window
110 template <typename SrcMoveSliceWindowStepHack>
111 __device__ void
112 MoveSrcSliceWindow(const SrcDesc& src_desc,
113 const Index& step,
114 const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
115 {
116 if(BlockSize == thread_cluster_desc_.GetElementSize() or
117 get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
118 {
119 threadwise_transfer_.MoveSrcSliceWindow(
120 src_desc, step, src_move_slice_window_step_hack);
121 }
122 }
123
124 __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
125 {
126 if(BlockSize == thread_cluster_desc_.GetElementSize() or
127 get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
128 {
129 threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
130 }
131 }
132
133 private:
134 static constexpr auto thread_cluster_desc_ =
135 make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
136
137 using ThreadwiseTransfer =
138 ThreadwiseTensorSliceTransfer_v5r1<ThreadSliceLengths,
139 DstInMemOp,
140 SrcData,
141 DstData,
142 SrcDesc,
143 DstDesc,
144 SrcDimAccessOrder,
145 DstDimAccessOrder,
146 SrcVectorTensorLengths,
147 DstVectorTensorLengths,
148 SrcVectorTensorContiguousDimOrder,
149 DstVectorTensorContiguousDimOrder,
150 ThreadTransferSrcResetCoordinateAfterRun,
151 ThreadTransferDstResetCoordinateAfterRun>;
152
153 ThreadwiseTransfer threadwise_transfer_;
154};
155
156} // namespace ck
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
InMemoryDataOperationEnum
Definition ck.hpp:277
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
__host__ __device__ constexpr auto make_zero_multi_index()
Definition array_multi_index.hpp:21
typename remove_reference< T >::type remove_reference_t
Definition type.hpp:292
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
Array< index_t, N > MultiIndex
Definition array_multi_index.hpp:12
static constexpr index_t nDim
Definition blockwise_tensor_slice_transfer_v5r1.hpp:38
__device__ void RunRead(const SrcDesc &src_desc, const SrcBuffer &src_buf)
Definition blockwise_tensor_slice_transfer_v5r1.hpp:81
__device__ constexpr BlockwiseTensorSliceTransfer_v5r1(const SrcDesc &src_desc, const Index &src_block_slice_origin, const DstDesc &dst_desc, const Index &dst_block_slice_origin)
Definition blockwise_tensor_slice_transfer_v5r1.hpp:42
MultiIndex< nDim > Index
Definition blockwise_tensor_slice_transfer_v5r1.hpp:40
__device__ void RunWrite(const DstDesc &dst_desc, DstBuffer &dst_buf)
Definition blockwise_tensor_slice_transfer_v5r1.hpp:91
__device__ void MoveSrcSliceWindow(const SrcDesc &src_desc, const Index &step)
Definition blockwise_tensor_slice_transfer_v5r1.hpp:100
__device__ void MoveDstSliceWindow(const DstDesc &dst_desc, const Index &step)
Definition blockwise_tensor_slice_transfer_v5r1.hpp:124
__device__ void MoveSrcSliceWindow(const SrcDesc &src_desc, const Index &step, const SrcMoveSliceWindowStepHack &src_move_slice_window_step_hack)
Definition blockwise_tensor_slice_transfer_v5r1.hpp:112
Definition threadwise_tensor_slice_transfer_v5r1.hpp:37
Definition type.hpp:177