20 typename BlockSliceLengths,
21 typename ThreadSliceLengths,
22 typename ThreadClusterLengths,
23 typename ThreadClusterArrangeOrder,
28 typename SrcDimAccessOrder,
29 typename DstDimAccessOrder,
30 typename SrcVectorTensorLengths,
31 typename DstVectorTensorLengths,
32 typename SrcVectorTensorContiguousDimOrder,
33 typename DstVectorTensorContiguousDimOrder,
34 bool ThreadTransferSrcResetCoordinateAfterRun,
35 bool ThreadTransferDstResetCoordinateAfterRun>
43 const Index& src_block_slice_origin,
44 const DstDesc& dst_desc,
45 const Index& dst_block_slice_origin)
46 : threadwise_transfer_(
52 nDim == BlockSliceLengths::Size() &&
nDim == ThreadSliceLengths::Size() &&
53 nDim == ThreadClusterLengths::Size() &&
54 nDim == ThreadClusterArrangeOrder::Size() &&
55 nDim == SrcDimAccessOrder::Size() &&
nDim == DstDimAccessOrder::Size(),
56 "wrong! nDim not consistent");
59 is_same<BlockSliceLengths,
decltype(ThreadSliceLengths{} * ThreadClusterLengths{})>{},
60 "wrong! threads should be mapped to cover entire slicing window");
62 static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
63 "wrong! BlockSize too small");
65 if(BlockSize == thread_cluster_desc_.GetElementSize() or
68 const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
69 make_multi_index(get_thread_local_1d_id()));
71 const auto thread_data_idx_begin = thread_cluster_idx * ThreadSliceLengths{};
73 threadwise_transfer_.SetSrcSliceOrigin(src_desc,
74 src_block_slice_origin + thread_data_idx_begin);
75 threadwise_transfer_.SetDstSliceOrigin(dst_desc,
76 dst_block_slice_origin + thread_data_idx_begin);
80 template <
typename SrcBuffer>
81 __device__
void RunRead(
const SrcDesc& src_desc,
const SrcBuffer& src_buf)
83 if(BlockSize == thread_cluster_desc_.GetElementSize() or
86 threadwise_transfer_.RunRead(src_desc, src_buf);
90 template <
typename DstBuffer>
91 __device__
void RunWrite(
const DstDesc& dst_desc, DstBuffer& dst_buf)
93 if(BlockSize == thread_cluster_desc_.GetElementSize() or
96 threadwise_transfer_.RunWrite(dst_desc, dst_buf);
102 if(BlockSize == thread_cluster_desc_.GetElementSize() or
105 threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
110 template <
typename SrcMoveSliceWindowStepHack>
114 const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
116 if(BlockSize == thread_cluster_desc_.GetElementSize() or
119 threadwise_transfer_.MoveSrcSliceWindow(
120 src_desc, step, src_move_slice_window_step_hack);
126 if(BlockSize == thread_cluster_desc_.GetElementSize() or
129 threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
134 static constexpr auto thread_cluster_desc_ =
137 using ThreadwiseTransfer =
146 SrcVectorTensorLengths,
147 DstVectorTensorLengths,
148 SrcVectorTensorContiguousDimOrder,
149 DstVectorTensorContiguousDimOrder,
150 ThreadTransferSrcResetCoordinateAfterRun,
151 ThreadTransferDstResetCoordinateAfterRun>;
153 ThreadwiseTransfer threadwise_transfer_;
int32_t index_t
Definition ck.hpp:299
InMemoryDataOperationEnum
Definition ck.hpp:277
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
__host__ __device__ constexpr auto make_zero_multi_index()
Definition array_multi_index.hpp:21
typename remove_reference< T >::type remove_reference_t
Definition type.hpp:292
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
Array< index_t, N > MultiIndex
Definition array_multi_index.hpp:12
static constexpr index_t nDim
Definition blockwise_tensor_slice_transfer_v5r1.hpp:38
__device__ void RunRead(const SrcDesc &src_desc, const SrcBuffer &src_buf)
Definition blockwise_tensor_slice_transfer_v5r1.hpp:81
__device__ constexpr BlockwiseTensorSliceTransfer_v5r1(const SrcDesc &src_desc, const Index &src_block_slice_origin, const DstDesc &dst_desc, const Index &dst_block_slice_origin)
Definition blockwise_tensor_slice_transfer_v5r1.hpp:42
MultiIndex< nDim > Index
Definition blockwise_tensor_slice_transfer_v5r1.hpp:40
__device__ void RunWrite(const DstDesc &dst_desc, DstBuffer &dst_buf)
Definition blockwise_tensor_slice_transfer_v5r1.hpp:91
__device__ void MoveSrcSliceWindow(const SrcDesc &src_desc, const Index &step)
Definition blockwise_tensor_slice_transfer_v5r1.hpp:100
__device__ void MoveDstSliceWindow(const DstDesc &dst_desc, const Index &step)
Definition blockwise_tensor_slice_transfer_v5r1.hpp:124
__device__ void MoveSrcSliceWindow(const SrcDesc &src_desc, const Index &step, const SrcMoveSliceWindowStepHack &src_move_slice_window_step_hack)
Definition blockwise_tensor_slice_transfer_v5r1.hpp:112
Definition threadwise_tensor_slice_transfer_v5r1.hpp:37