device_elementwise_dynamic_vector_dims_impl.hpp Source File

device_elementwise_dynamic_vector_dims_impl.hpp Source File

Composable Kernel: device_elementwise_dynamic_vector_dims_impl.hpp Source File
device_elementwise_dynamic_vector_dims_impl.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
9#include "ck/utility/math.hpp"
15
18
19namespace ck {
20namespace tensor_operation {
21namespace device {
22
23template <typename InDataTypeTuple,
24 typename OutDataTypeTuple,
25 typename ElementwiseOperation,
26 index_t NumDim,
27 index_t BlockSize,
28 index_t M0PerBlock,
29 index_t M1PerBlock,
30 index_t M0PerThread,
31 index_t M1PerThread,
32 typename ThreadClusterArrangeOrder,
33 typename InScalarPerVectorSeq,
34 typename OutScalarPerVectorSeq>
36 : public DeviceElementwise<InDataTypeTuple, OutDataTypeTuple, ElementwiseOperation, NumDim>
37{
38 static constexpr int NumInput = InDataTypeTuple::Size();
39 static constexpr int NumOutput = OutDataTypeTuple::Size();
40
41 static constexpr auto I0 = Number<0>{};
42 static constexpr auto I1 = Number<1>{};
43
44 static_assert(NumInput == InScalarPerVectorSeq::Size() &&
45 NumOutput == OutScalarPerVectorSeq::Size(),
46 "Tuple size is inconsistent with the number of in/out!");
47
49 {
50 return generate_tuple(
51 [&](auto I) {
52 using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
53
54 return static_cast<const DataType*>(nullptr);
55 },
57 };
58
60 {
61 return generate_tuple(
62 [&](auto I) {
63 using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
64
65 return static_cast<DataType*>(nullptr);
66 },
68 };
69
72
73 static index_t GetLowestStrideDim(const std::array<index_t, NumDim>& strides)
74 {
75 index_t most_continous_dim = NumDim - 1;
76 index_t most_continous_dim_stride = strides[most_continous_dim];
77 for(index_t dim = 0; dim < NumDim; dim++)
78 {
79 if(strides[dim] < most_continous_dim_stride)
80 {
81 most_continous_dim_stride = strides[dim];
82 most_continous_dim = dim;
83 }
84 }
85 return most_continous_dim;
86 }
87
88 template <typename InOutDescriptor>
89 static auto PadInputOutputDescriptor(const InOutDescriptor& desc)
90 {
91 const auto M0 = desc.GetLength(I0);
92 const auto M1 = desc.GetLength(I1);
93 const auto pad_M0 = math::integer_divide_ceil(M0, M0PerThread) * M0PerThread - M0;
94 const auto pad_M1 = math::integer_divide_ceil(M1, M1PerThread) * M1PerThread - M1;
95
96 const auto padded_desc = transform_tensor_descriptor(
97 desc,
101
102 return padded_desc;
103 }
104
105 static auto GenerateBatchDimsLenghtsTuple(const std::array<index_t, NumDim>& lengths,
106 const index_t M0_dim,
107 const index_t M1_dim)
108 {
109 // Generate batch dims, they will be merged to M0
110 // Add one more dim than needed in case that M0 is equal to M1
111 // If M0 is equal to M1, then will be one more batch dim
112 std::array<index_t, NumDim - 1> batch_dims;
113 index_t batch_dim = 0;
114 for(index_t i = 0; i < NumDim; i++)
115 {
116 if(i != M0_dim && i != M1_dim)
117 {
118 batch_dims[batch_dim] = lengths[i];
119 batch_dim++;
120 }
121 }
122 // Add dummy dim if M0_dim is not equal to M1_dim
123 if(M0_dim != M1_dim && NumDim >= 2)
124 batch_dims[NumDim - 2] = 1;
125 return generate_tuple([&](auto I) { return batch_dims[I]; }, Number<NumDim - 1>{});
126 }
127
128 static auto MakeDescriptor(const std::array<index_t, NumDim>& lengths,
129 const std::array<index_t, NumDim>& in_strides,
130 const std::array<index_t, NumDim>& out_strides,
131 const std::array<index_t, NumDim>& desc_strides)
132 {
133 const auto M0_dim = GetLowestStrideDim(out_strides);
134 const auto M1_dim = GetLowestStrideDim(in_strides);
135
136 // If M0_dim is equal to M1_dim, then make M0_dim dummy
137 const auto M0 = M0_dim == M1_dim ? I1 : lengths[M0_dim];
138 const auto M1 = lengths[M1_dim];
139 const auto M0_stride = M0_dim == M1_dim ? I1 : desc_strides[M0_dim];
140 const auto M1_stride = desc_strides[M1_dim];
141
142 const auto batch_dims_lenghts = GenerateBatchDimsLenghtsTuple(lengths, M0_dim, M1_dim);
143 const auto batch_dims_strides = GenerateBatchDimsLenghtsTuple(desc_strides, M0_dim, M1_dim);
144
145 const auto desc = make_naive_tensor_descriptor(
146 concat_tuple(batch_dims_lenghts, make_tuple(M0), make_tuple(M1)),
147 concat_tuple(batch_dims_strides, make_tuple(M0_stride), make_tuple(M1_stride)));
148 // Merged batch dims with M0
149 const auto transforms =
150 make_tuple(make_merge_transform(concat_tuple(batch_dims_lenghts, make_tuple(M0))),
152 using BatchElemsSequence =
153 typename arithmetic_sequence_gen<0, decltype(batch_dims_lenghts)::Size() + 1, 1>::type;
154 const auto lower_dims = make_tuple(BatchElemsSequence{}, Sequence<NumDim>{});
155 const auto upper_dims = make_tuple(Sequence<0>{}, Sequence<1>{});
156 // desc: (merged_dims + M0, M1)
157 auto merged_desc = transform_tensor_descriptor(desc, transforms, lower_dims, upper_dims);
158 return PadInputOutputDescriptor(merged_desc);
159 }
160
161 template <index_t NumTensors>
163 {
164 std::array<index_t, NumDim> ones;
165 for(index_t d = 0; d < NumDim; d++)
166 {
167 ones[d] = 1;
168 }
169
170 return generate_tuple([&](auto) { return MakeDescriptor(ones, ones, ones, ones); },
172 };
173
176
178
184 ElementwiseOperation,
185 BlockSize,
186 M0PerBlock,
187 M1PerBlock,
188 M0PerThread,
189 M1PerThread,
190 ThreadClusterArrangeOrder,
191 InScalarPerVectorSeq,
192 OutScalarPerVectorSeq,
193 I1,
194 I0>;
195
201 ElementwiseOperation,
202 BlockSize,
203 M0PerBlock,
204 M1PerBlock,
205 M0PerThread,
206 M1PerThread,
207 ThreadClusterArrangeOrder,
208 InScalarPerVectorSeq,
209 OutScalarPerVectorSeq,
210 I1,
211 I1>;
212
213 struct Argument : public BaseArgument
214 {
215 Argument(const std::array<index_t, NumDim> lengths,
216 const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
217 const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
218 const std::array<const void*, NumInput> in_dev_buffers,
219 const std::array<void*, NumOutput> out_dev_buffers,
220 ElementwiseOperation elementwise_op)
221
222 : lengths_(lengths),
223 inStridesArray_(inStridesArray),
224 outStridesArray_(outStridesArray),
225 elementwise_op_(elementwise_op)
226 {
228 [&](auto I) {
229 using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
230 return static_cast<const DataType*>(in_dev_buffers[I.value]);
231 },
233
235 [&](auto I) {
236 using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
237 return static_cast<DataType*>(out_dev_buffers[I.value]);
238 },
240 }
241
244
245 std::array<index_t, NumDim> lengths_;
246 std::array<std::array<index_t, NumDim>, NumInput> inStridesArray_;
247 std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray_;
248
249 ElementwiseOperation elementwise_op_;
250 };
251
252 struct Invoker : public BaseInvoker
253 {
254 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
255 {
256 auto in_grid_desc_tuple = generate_tuple(
257 [&](auto src_i) {
258 // Use Strides from first tensor to assert that M0 dim and
259 // M1 dim are the same for each tensor.
260 return MakeDescriptor(arg.lengths_,
261 arg.inStridesArray_[I0],
262 arg.outStridesArray_[I0],
263 arg.inStridesArray_[src_i]);
264 },
266
267 auto out_grid_desc_tuple = generate_tuple(
268 [&](auto dst_i) {
269 return MakeDescriptor(arg.lengths_,
270 arg.inStridesArray_[I0],
271 arg.outStridesArray_[I0],
272 arg.outStridesArray_[dst_i]);
273 },
275
276 const index_t M0 = in_grid_desc_tuple.At(I0).GetLength(Number<I0>{});
277 const index_t M1 = in_grid_desc_tuple.At(I0).GetLength(Number<I1>{});
278
279 const auto block_2_tile_map = Block2TileMap(M0, M1);
280 const index_t grid_size = block_2_tile_map.CalculateGridSize(M0, M1);
281
282 const bool in_out_same_vector_dim = GetLowestStrideDim(arg.inStridesArray_[I0]) ==
284
285 const auto kernel = in_out_same_vector_dim
292 ElementwiseOperation>
299 ElementwiseOperation>;
300
301 float elapsed_time = launch_and_time_kernel(stream_config,
302 kernel,
303 dim3(grid_size),
304 dim3(BlockSize),
305 0,
306 in_grid_desc_tuple,
307 out_grid_desc_tuple,
308 arg.in_dev_buffers_,
310 block_2_tile_map,
311 arg.elementwise_op_);
312 return elapsed_time;
313 }
314
315 // polymorphic
316 float Run(const BaseArgument* p_arg,
317 const StreamConfig& stream_config = StreamConfig{}) override
318 {
319 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
320 }
321 };
322
323 static bool IsSupportedArgument(const Argument& arg)
324 {
325 const index_t M0_dim = GetLowestStrideDim(arg.inStridesArray_[I0]);
326 const index_t M1_dim = GetLowestStrideDim(arg.outStridesArray_[I0]);
327
328 auto IsScalarPerVectorValid = [&](const std::array<index_t, NumDim>& lengths,
329 const std::array<index_t, NumDim>& strides,
330 index_t scalarPerVector,
331 index_t M_dim) {
332 if(scalarPerVector == 1)
333 {
334 return true;
335 }
336 if(strides[M_dim] == 1 && lengths[M_dim] % scalarPerVector == 0)
337 {
338 return true;
339 }
340 return false;
341 };
342
343 bool is_valid = true;
344 static_for<0, NumInput, 1>{}([&](auto I) {
345 static_assert(M0PerThread % InScalarPerVectorSeq::At(I) == 0 &&
346 M1PerThread % InScalarPerVectorSeq::At(I) == 0);
347 is_valid &= IsScalarPerVectorValid(
348 arg.lengths_, arg.inStridesArray_[I.value], InScalarPerVectorSeq::At(I), M0_dim);
349 });
350
351 static_for<0, NumOutput, 1>{}([&](auto I) {
352 static_assert(M0PerThread % OutScalarPerVectorSeq::At(I) == 0 &&
353 M1PerThread % OutScalarPerVectorSeq::At(I) == 0);
354 is_valid &= IsScalarPerVectorValid(
355 arg.lengths_, arg.outStridesArray_[I.value], OutScalarPerVectorSeq::At(I), M1_dim);
356 });
357
358 return is_valid;
359 };
360
361 bool IsSupportedArgument(const BaseArgument* p_arg) override
362 {
363 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
364 }
365
366 static auto
367 MakeArgument(const std::array<index_t, NumDim> lengths,
368 const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
369 const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
370 const std::array<const void*, NumInput> in_dev_buffers,
371 const std::array<void*, NumOutput> out_dev_buffers,
372 ElementwiseOperation elementwise_op)
373 {
374 return Argument{lengths,
375 inStridesArray,
376 outStridesArray,
377 in_dev_buffers,
378 out_dev_buffers,
379 elementwise_op};
380 }
381
382 std::unique_ptr<BaseArgument>
383 MakeArgumentPointer(const std::array<index_t, NumDim> lengths,
384 const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
385 const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
386 const std::array<const void*, NumInput> in_dev_buffers,
387 const std::array<void*, NumOutput> out_dev_buffers,
388 ElementwiseOperation elementwise_op) override
389 {
390 return std::make_unique<Argument>(lengths,
391 inStridesArray,
392 outStridesArray,
393 in_dev_buffers,
394 out_dev_buffers,
395 elementwise_op);
396 }
397
398 static auto MakeInvoker() { return Invoker{}; }
399 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
400 {
401 return std::make_unique<Invoker>();
402 };
403
404 std::string GetTypeString() const override
405 {
406 auto str = std::stringstream();
407
408 // clang-format off
409 str << "DeviceElementwiseImpl<";
410 str << NumDim << ", ";
411 str << BlockSize << ", ";
412 str << M0PerBlock << ", ";
413 str << M1PerBlock << ", ";
414 str << M0PerThread << ", ";
415 str << M1PerThread << ">";
416 // clang-format on
417
418 return str.str();
419 }
420};
421
422} // namespace device
423} // namespace tensor_operation
424} // namespace ck
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition convolution_backward_data_specialization.hpp:8
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__host__ __device__ constexpr auto concat_tuple(const Tuple< X... > &tx, const Tuple< Y... > &ty)
Definition tuple_helper.hpp:52
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__global__ void kernel_elementwise(const InGridDescTuple in_grid_desc_tuple, const OutGridDescTuple out_grid_desc_tuple, const InDataTypePointerTuple p_in_global_tuple, const OutDataTypePointerTuple p_out_global_tuple, const Block2TileMap block_2_tile_map, const ElementwiseOperation elementwise_op)
Definition gridwise_elementwise_2d.hpp:29
Definition ck/stream_config.hpp:10
Definition block_to_ctile_map.hpp:261
Definition utility/sequence.hpp:43
Definition utility/sequence.hpp:256
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition device_elementwise.hpp:21
Definition device_elementwise_dynamic_vector_dims_impl.hpp:214
Argument(const std::array< index_t, NumDim > lengths, const std::array< std::array< index_t, NumDim >, NumInput > inStridesArray, const std::array< std::array< index_t, NumDim >, NumOutput > outStridesArray, const std::array< const void *, NumInput > in_dev_buffers, const std::array< void *, NumOutput > out_dev_buffers, ElementwiseOperation elementwise_op)
Definition device_elementwise_dynamic_vector_dims_impl.hpp:215
InDataTypePointerTuple in_dev_buffers_
Definition device_elementwise_dynamic_vector_dims_impl.hpp:242
std::array< index_t, NumDim > lengths_
Definition device_elementwise_dynamic_vector_dims_impl.hpp:245
OutDataTypePointerTuple out_dev_buffers_
Definition device_elementwise_dynamic_vector_dims_impl.hpp:243
ElementwiseOperation elementwise_op_
Definition device_elementwise_dynamic_vector_dims_impl.hpp:249
std::array< std::array< index_t, NumDim >, NumInput > inStridesArray_
Definition device_elementwise_dynamic_vector_dims_impl.hpp:246
std::array< std::array< index_t, NumDim >, NumOutput > outStridesArray_
Definition device_elementwise_dynamic_vector_dims_impl.hpp:247
Definition device_elementwise_dynamic_vector_dims_impl.hpp:253
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_elementwise_dynamic_vector_dims_impl.hpp:254
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_elementwise_dynamic_vector_dims_impl.hpp:316
Definition device_elementwise_dynamic_vector_dims_impl.hpp:37
static auto MakeInvoker()
Definition device_elementwise_dynamic_vector_dims_impl.hpp:398
static auto PadInputOutputDescriptor(const InOutDescriptor &desc)
Definition device_elementwise_dynamic_vector_dims_impl.hpp:89
decltype(GenerateInOutGridDescTuple< NumOutput >()) OutGridDescTuple
Definition device_elementwise_dynamic_vector_dims_impl.hpp:175
GridwiseElementwise< InGridDescTuple, OutGridDescTuple, InDataTypePointerTuple, OutDataTypePointerTuple, Block2TileMap, ElementwiseOperation, BlockSize, M0PerBlock, M1PerBlock, M0PerThread, M1PerThread, ThreadClusterArrangeOrder, InScalarPerVectorSeq, OutScalarPerVectorSeq, I1, I0 > GridwiseElementwiseOp
Definition device_elementwise_dynamic_vector_dims_impl.hpp:179
static auto GenerateBatchDimsLenghtsTuple(const std::array< index_t, NumDim > &lengths, const index_t M0_dim, const index_t M1_dim)
Definition device_elementwise_dynamic_vector_dims_impl.hpp:105
static constexpr auto I1
Definition device_elementwise_dynamic_vector_dims_impl.hpp:42
static constexpr auto I0
Definition device_elementwise_dynamic_vector_dims_impl.hpp:41
decltype(GenerateInDataTypePointerTuple()) InDataTypePointerTuple
Definition device_elementwise_dynamic_vector_dims_impl.hpp:70
decltype(GenerateOutDataTypePointerTuple()) OutDataTypePointerTuple
Definition device_elementwise_dynamic_vector_dims_impl.hpp:71
BlockToCTileMap_M00_N0_M01Adapt< M0PerBlock, M1PerBlock > Block2TileMap
Definition device_elementwise_dynamic_vector_dims_impl.hpp:177
static constexpr int NumInput
Definition device_elementwise_dynamic_vector_dims_impl.hpp:38
static bool IsSupportedArgument(const Argument &arg)
Definition device_elementwise_dynamic_vector_dims_impl.hpp:323
static auto GenerateInOutGridDescTuple()
Definition device_elementwise_dynamic_vector_dims_impl.hpp:162
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_elementwise_dynamic_vector_dims_impl.hpp:399
GridwiseElementwise< InGridDescTuple, OutGridDescTuple, InDataTypePointerTuple, OutDataTypePointerTuple, Block2TileMap, ElementwiseOperation, BlockSize, M0PerBlock, M1PerBlock, M0PerThread, M1PerThread, ThreadClusterArrangeOrder, InScalarPerVectorSeq, OutScalarPerVectorSeq, I1, I1 > GridwiseElementwiseOpSameInOutVectorDim
Definition device_elementwise_dynamic_vector_dims_impl.hpp:196
GridwiseElementwise_1D< InGrid1dDescTuple, OutGrid1dDescTuple, InDataTypePointerTuple, OutDataTypePointerTuple, ElementwiseOperation, UnaryOperation, Scale, MPerThread, InScalarPerVectorSeq, OutScalarPerVectorSeq > GridwiseElementwise
Definition device_elementwise_scale_impl.hpp:136
std::string GetTypeString() const override
Definition device_elementwise_dynamic_vector_dims_impl.hpp:404
static auto MakeDescriptor(const std::array< index_t, NumDim > &lengths, const std::array< index_t, NumDim > &in_strides, const std::array< index_t, NumDim > &out_strides, const std::array< index_t, NumDim > &desc_strides)
Definition device_elementwise_dynamic_vector_dims_impl.hpp:128
static index_t GetLowestStrideDim(const std::array< index_t, NumDim > &strides)
Definition device_elementwise_dynamic_vector_dims_impl.hpp:73
static constexpr int NumOutput
Definition device_elementwise_dynamic_vector_dims_impl.hpp:39
std::unique_ptr< BaseArgument > MakeArgumentPointer(const std::array< index_t, NumDim > lengths, const std::array< std::array< index_t, NumDim >, NumInput > inStridesArray, const std::array< std::array< index_t, NumDim >, NumOutput > outStridesArray, const std::array< const void *, NumInput > in_dev_buffers, const std::array< void *, NumOutput > out_dev_buffers, ElementwiseOperation elementwise_op) override
Definition device_elementwise_dynamic_vector_dims_impl.hpp:383
static auto GenerateInDataTypePointerTuple()
Definition device_elementwise_dynamic_vector_dims_impl.hpp:48
static auto GenerateOutDataTypePointerTuple()
Definition device_elementwise_dynamic_vector_dims_impl.hpp:59
decltype(GenerateInOutGridDescTuple< NumInput >()) InGridDescTuple
Definition device_elementwise_dynamic_vector_dims_impl.hpp:174
static auto MakeArgument(const std::array< index_t, NumDim > lengths, const std::array< std::array< index_t, NumDim >, NumInput > inStridesArray, const std::array< std::array< index_t, NumDim >, NumOutput > outStridesArray, const std::array< const void *, NumInput > in_dev_buffers, const std::array< void *, NumOutput > out_dev_buffers, ElementwiseOperation elementwise_op)
Definition device_elementwise_dynamic_vector_dims_impl.hpp:367
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_elementwise_dynamic_vector_dims_impl.hpp:361