device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp Source File

device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp Source File

Composable Kernel: device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp Source File
device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#ifndef DEVICE_CONV3D_FWD_NAIVE_HPP
5#define DEVICE_CONV3D_FWD_NAIVE_HPP
6
7#include <iostream>
8#include <memory>
9#include <sstream>
10#include "conv_util.hpp"
11#include "device.hpp"
12#include "device_conv_fwd.hpp"
13#include "common_header.hpp"
14#include "naive_conv_fwd.hpp"
15
16namespace ck {
17namespace tensor_operation {
18namespace device {
19
20// specialization for 3D conv: in[n, di, hi, wi, c] * wei[k, z, y, x, c] = out[n, do, ho, wo, k]
21template <typename InDataType,
22 typename WeiDataType, // WeiDataType must be the same as InDataType
23 typename OutDataType,
24 typename AccDataType,
25 typename InElementwiseOperation,
26 typename WeiElementwiseOperation,
27 typename OutElementwiseOperation>
29 : public DeviceConvFwd<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>
30
31{
33
// Short aliases for the element types: A is the input, B the weight, C the output.
34 using ADataType = InDataType;
35 using BDataType = WeiDataType;
36 using CDataType = OutDataType;
37 // TODO: let A and B use different data types; until then the shared alias follows InDataType.
38 using ABDataType = InDataType;
39
40 // Argument
41 struct Argument : public BaseArgument
42 {
43 Argument(const InDataType* p_in,
44 const WeiDataType* p_wei,
45 OutDataType* p_out,
46 const index_t N,
47 const index_t K,
48 const index_t C,
49 std::vector<ck::index_t> input_spatial_lengths,
50 std::vector<ck::index_t> filter_spatial_lengths,
51 std::vector<ck::index_t> output_spatial_lengths,
52 std::vector<ck::index_t> conv_filter_strides,
53 std::vector<ck::index_t> conv_filter_dilations,
54 std::vector<ck::index_t> input_left_pads,
55 std::vector<ck::index_t> input_right_pads,
56 InElementwiseOperation in_element_op,
57 WeiElementwiseOperation wei_element_op,
58 OutElementwiseOperation out_element_op)
59 : params_{3,
60 N,
61 K,
62 C,
63 filter_spatial_lengths,
64 input_spatial_lengths,
65 conv_filter_strides,
66 conv_filter_dilations,
67 input_left_pads,
68 input_right_pads},
69 out_spatial_lengths_{output_spatial_lengths},
70 p_in_{p_in},
71 p_wei_{p_wei},
72 p_out_{p_out},
73 in_element_op_{in_element_op},
74 wei_element_op_{wei_element_op},
75 out_element_op_{out_element_op}
76
77 {
78 }
79
80 // private:
81 utils::conv::ConvParams params_;
82 std::vector<index_t> out_spatial_lengths_;
83
84 const InDataType* p_in_;
85 const WeiDataType* p_wei_;
86 OutDataType* p_out_;
87
88 InElementwiseOperation in_element_op_;
89 WeiElementwiseOperation wei_element_op_;
90 OutElementwiseOperation out_element_op_;
91 };
92
93 // Invoker
94 struct Invoker : public BaseInvoker
95 {
97
// Launches the naive 3D forward-convolution kernel for `arg` via
// launch_and_time_kernel and returns the value that call produces
// (presumably the average kernel time, going by the name ave_time -- confirm).
// The launch passes dim3(256) as both grid_dim and block_dim and 0 as lds_byte.
//
// NOTE(review): this body reads arg.N_, arg.K_, arg.C_, arg.in_spatial_lengths_,
// arg.filter_spatial_lengths_, arg.conv_filter_strides_, arg.conv_filter_dilations_
// and arg.in_left_pads_, but the Argument struct shown in this file declares only
// params_, out_spatial_lengths_, the three pointers and the three element-wise
// operators. Presumably these values live inside params_ -- confirm against
// utils::conv::ConvParams before instantiating this template.
98 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
99 {
// Pick the kernel instantiation that matches this device op's template parameters.
100 const auto naive_conv3d_fwd =
101 ref::naive_conv_fwd_ndhwc_kzyxc_ndhwk<InDataType,
102 WeiDataType,
103 OutDataType,
104 AccDataType,
105 InElementwiseOperation,
106 WeiElementwiseOperation,
107 OutElementwiseOperation>;
108
// After the launch configuration, the kernel receives the tensor pointers followed by
// the problem sizes, one scalar per spatial dimension (indices 0, 1, 2).
109 float ave_time = launch_and_time_kernel(stream_config,
110 naive_conv3d_fwd,
111 dim3(256),
112 dim3(256),
113 0,
114 arg.p_in_,
115 arg.p_wei_,
116 arg.p_out_,
117 arg.N_,
118 arg.K_,
119 arg.C_,
120 arg.in_spatial_lengths_[0],
121 arg.in_spatial_lengths_[1],
122 arg.in_spatial_lengths_[2],
123 arg.filter_spatial_lengths_[0],
124 arg.filter_spatial_lengths_[1],
125 arg.filter_spatial_lengths_[2],
// NOTE(review): the listing's own numbering jumps from 125 to 129 here, so three
// kernel arguments may be missing from this extract. arg.out_spatial_lengths_ is
// not passed anywhere in the visible call -- verify against the kernel signature.
129 arg.conv_filter_strides_[0],
130 arg.conv_filter_strides_[1],
131 arg.conv_filter_strides_[2],
132 arg.conv_filter_dilations_[0],
133 arg.conv_filter_dilations_[1],
134 arg.conv_filter_dilations_[2],
135 arg.in_left_pads_[0],
136 arg.in_left_pads_[1],
137 arg.in_left_pads_[2]);
138
139 return ave_time;
140 }
141
142 // polymorphic
143 float Run(const BaseArgument* p_arg,
144 const StreamConfig& stream_config = StreamConfig{}) override
145 {
146 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
147 }
148 };
149
/// Compile-time validation hook for the template parameters.
/// TODO: implement a real check; every instantiation is currently accepted.
static constexpr bool IsValidCompilationParameter()
{
    constexpr bool accept_every_instantiation = true;
    return accept_every_instantiation;
}
155
156 static bool IsSupportedArgument(const Argument& arg)
157 {
158 std::vector<index_t> out_spatial_lengths = arg.params_.GetOutputSpatialLengths();
159
160 bool out_lengths_are_consistent = out_spatial_lengths[0] == arg.out_spatial_lengths_[0] &&
161 out_spatial_lengths[1] == arg.out_spatial_lengths_[1] &&
162 out_spatial_lengths[2] == arg.out_spatial_lengths_[2];
163 return out_lengths_are_consistent;
164 }
165
166 // polymorphic
167 bool IsSupportedArgument(const BaseArgument* p_arg) override
168 {
169 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
170 }
171
172 static auto MakeArgument(const InDataType* p_in,
173 const WeiDataType* p_wei,
174 OutDataType* p_out,
175 const index_t N,
176 const index_t K,
177 const index_t C,
178 std::vector<ck::index_t> input_spatial_lengths,
179 std::vector<ck::index_t> filter_spatial_lengths,
180 std::vector<ck::index_t> output_spatial_lengths,
181 std::vector<ck::index_t> conv_filter_strides,
182 std::vector<ck::index_t> conv_filter_dilations,
183 std::vector<ck::index_t> input_left_pads,
184 std::vector<ck::index_t> input_right_pads,
185 InElementwiseOperation in_element_op,
186 WeiElementwiseOperation wei_element_op,
187 OutElementwiseOperation out_element_op)
188 {
189 return Argument{p_in,
190 p_wei,
191 p_out,
192 N,
193 K,
194 C,
195 input_spatial_lengths,
196 filter_spatial_lengths,
197 output_spatial_lengths,
198 conv_filter_strides,
199 conv_filter_dilations,
200 input_left_pads,
201 input_right_pads,
202 in_element_op,
203 wei_element_op,
204 out_element_op};
205 }
206
207 static auto MakeInvoker() { return Invoker{}; }
208
209 // polymorphic
210 std::unique_ptr<BaseArgument>
211 MakeArgumentPointer(const void* p_in,
212 const void* p_wei,
213 void* p_out,
214 const index_t N,
215 const index_t K,
216 const index_t C,
217 std::vector<ck::index_t> input_spatial_lengths,
218 std::vector<ck::index_t> filter_spatial_lengths,
219 std::vector<ck::index_t> output_spatial_lengths,
220 std::vector<ck::index_t> conv_filter_strides,
221 std::vector<ck::index_t> conv_filter_dilations,
222 std::vector<ck::index_t> input_left_pads,
223 std::vector<ck::index_t> input_right_pads,
224 InElementwiseOperation in_element_op,
225 WeiElementwiseOperation wei_element_op,
226 OutElementwiseOperation out_element_op) override
227
228 {
229 return std::make_unique<Argument>(static_cast<const InDataType*>(p_in),
230 static_cast<const WeiDataType*>(p_wei),
231 static_cast<OutDataType*>(p_out),
232 N,
233 K,
234 C,
235 input_spatial_lengths,
236 filter_spatial_lengths,
237 output_spatial_lengths,
238 conv_filter_strides,
239 conv_filter_dilations,
240 input_left_pads,
241 input_right_pads,
242 in_element_op,
243 wei_element_op,
244 out_element_op);
245 }
246
247 // polymorphic
248 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
249 {
250 return std::make_unique<Invoker>(Invoker{});
251 }
252
253 std::string GetTypeString() const override
254 {
255 auto str = std::stringstream();
256
257 // clang-format off
258 str << "DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K<>";
259 // clang-format on
260
261 return str.str();
262 }
263};
264
265} // namespace device
266} // namespace tensor_operation
267} // namespace ck
268#endif
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
Definition convolution_backward_data_specialization.hpp:8
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
Definition ck/stream_config.hpp:10
Definition device_base.hpp:197
const InDataType * p_in_
Definition device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp:84
std::vector< index_t > out_spatial_lengths_
Definition device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp:82
const WeiDataType * p_wei_
Definition device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp:85
InElementwiseOperation in_element_op_
Definition device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp:88
utils::conv::ConvParams params_
Definition device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp:81
Argument(const InDataType *p_in, const WeiDataType *p_wei, OutDataType *p_out, const index_t N, const index_t K, const index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > filter_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths, std::vector< ck::index_t > conv_filter_strides, std::vector< ck::index_t > conv_filter_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op)
Definition device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp:43
OutElementwiseOperation out_element_op_
Definition device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp:90
WeiElementwiseOperation wei_element_op_
Definition device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp:89
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp:98
DeviceOp::Argument Argument
Definition device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp:96
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp:143
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp:248
InDataType ABDataType
Definition device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp:38
static auto MakeInvoker()
Definition device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp:207
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_in, const void *p_wei, void *p_out, const index_t N, const index_t K, const index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > filter_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths, std::vector< ck::index_t > conv_filter_strides, std::vector< ck::index_t > conv_filter_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op) override
Definition device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp:211
static constexpr bool IsValidCompilationParameter()
Definition device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp:150
WeiDataType BDataType
Definition device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp:35
OutDataType CDataType
Definition device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp:36
static auto MakeArgument(const InDataType *p_in, const WeiDataType *p_wei, OutDataType *p_out, const index_t N, const index_t K, const index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > filter_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths, std::vector< ck::index_t > conv_filter_strides, std::vector< ck::index_t > conv_filter_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op)
Definition device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp:172
InDataType ADataType
Definition device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp:34
DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K DeviceOp
Definition device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp:32
static bool IsSupportedArgument(const Argument &arg)
Definition device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp:156
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp:167
std::string GetTypeString() const override
Definition device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp:253
Definition device_conv_fwd.hpp:25