18template <
typename ThreadGroup,
19 typename ElementwiseOperation,
21 typename SliceLengths,
22 typename ThreadClusterLengths,
23 typename ThreadClusterArrangeOrder,
28 typename DimAccessOrder,
31 bool ThreadTransferSrcResetCoordinateAfterRun,
32 bool ThreadTransferDstResetCoordinateAfterRun>
42 const Index& src_block_slice_origin,
43 const DstDesc& dst_desc,
44 const Index& dst_block_slice_origin,
45 const ElementwiseOperation& element_op)
46 : threadwise_transfer_(src_desc,
55 nDim == ThreadClusterLengths::Size() &&
56 nDim == ThreadClusterArrangeOrder::Size() &&
57 nDim == DimAccessOrder::Size(),
58 "wrong! nDim not consistent");
62 "wrong! threads should be mapped to cover entire slicing window");
64 static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
65 "wrong! ThreadGroup::GetNumOfThread() too small");
67 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
68 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
70 const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
71 make_multi_index(ThreadGroup::GetThreadId()));
73 const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
75 threadwise_transfer_.SetSrcSliceOrigin(src_desc,
76 src_block_slice_origin + thread_data_idx_begin);
77 threadwise_transfer_.SetDstSliceOrigin(dst_desc,
78 dst_block_slice_origin + thread_data_idx_begin);
82 template <
typename SrcBuffer,
typename DstBuffer>
83 __device__
void Run(
const SrcDesc& src_desc,
84 const SrcBuffer& src_buf,
85 const DstDesc& dst_desc,
88 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
89 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
91 threadwise_transfer_.Run(src_desc, src_buf, dst_desc, dst_buf);
97 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
98 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
100 threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
106 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
107 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
109 threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
114 static constexpr auto thread_cluster_desc_ =
117 using ThreadwiseTransfer =
118 ThreadwiseTensorSliceTransfer_v6r1<SrcData,
122 ElementwiseOperation,
128 ThreadTransferSrcResetCoordinateAfterRun,
129 ThreadTransferDstResetCoordinateAfterRun>;
131 ThreadwiseTransfer threadwise_transfer_;
int32_t index_t
Definition ck.hpp:299
InMemoryDataOperationEnum
Definition ck.hpp:277
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
__host__ __device__ constexpr auto make_zero_multi_index()
Definition array_multi_index.hpp:21
typename remove_reference< T >::type remove_reference_t
Definition type.hpp:292
Array< index_t, N > MultiIndex
Definition array_multi_index.hpp:12
MultiIndex< nDim > Index
Definition thread_group_tensor_slice_transfer_v6r1.hpp:39
static constexpr auto thread_slice_lengths
Definition thread_group_tensor_slice_transfer_v6r1.hpp:37
__device__ void MoveSrcSliceWindow(const SrcDesc &src_desc, const Index &step)
Definition thread_group_tensor_slice_transfer_v6r1.hpp:95
__device__ constexpr ThreadGroupTensorSliceTransfer_v6r1(const SrcDesc &src_desc, const Index &src_block_slice_origin, const DstDesc &dst_desc, const Index &dst_block_slice_origin, const ElementwiseOperation &element_op)
Definition thread_group_tensor_slice_transfer_v6r1.hpp:41
__device__ void MoveDstSliceWindow(const DstDesc &dst_desc, const Index &step)
Definition thread_group_tensor_slice_transfer_v6r1.hpp:104
static constexpr index_t nDim
Definition thread_group_tensor_slice_transfer_v6r1.hpp:35
__device__ void Run(const SrcDesc &src_desc, const SrcBuffer &src_buf, const DstDesc &dst_desc, DstBuffer &dst_buf)
Definition thread_group_tensor_slice_transfer_v6r1.hpp:83