helper.hpp Source File

helper.hpp Source File#

Composable Kernel: helper.hpp Source File
helper.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
11#include <fstream>
12#include <variant>
13
14// functions to return the corresponding structs based on generated template parameters
15
21// return the layout type: currently this is the only type supported in MIOpen
22auto layout_type(std::string type)
23{
24 if(type == "ck::tensor_layout::convolution::NHWGK")
25 {
27 }
28 throw std::runtime_error("Incorrect layout");
29}
30// return the right gemm spec based on the generated template parameters
32{
33 if(type == "ck::tensor_operation::device::GemmSpecialization::Default")
34 {
36 }
37 if(type == "ck::tensor_operation::device::GemmSpecialization::MNKPadding")
38 {
40 }
41 throw std::runtime_error("Incorrect gemm spec: " + type);
42}
43
44// return the type of convolution
46{
47 if(type == "ck::tensor_operation::device::ConvolutionForwardSpecialization::Default")
48 {
50 }
51 if(type == "ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0")
52 {
54 }
55 if(type ==
56 "ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0")
57 {
59 }
60 if(type == "ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC")
61 {
63 }
64 throw std::runtime_error("Incorrect conv spec: " + type);
65}
66
67// Function to call on MatrixPadder via a wrapper struct
68// NOTE: CK only uses MNKPadding for forward convolution
69template <typename CDesc_MRaw_NRaw>
71 ck::index_t npb,
72 ck::index_t kpb,
74 CDesc_MRaw_NRaw conv)
75{
77 {
83 a;
84 a.MPerTile_ = mpb;
85 a.NPerTile_ = npb;
86 a.KPerTile_ = kpb;
87 auto tmp = grid_desc(a, conv);
88 return tmp;
89 }
90 throw std::runtime_error("Incorrect template parameters, check gemm spec");
91}
92
93// Functions to call on TransformConvFwdToGemm through wrapper: different functions based on num
94// dims
95// FIXME: add a way to properly pass in the layout
98 ck::Array<ck::index_t, 5> out_lengths,
99 ck::Array<ck::index_t, 5> out_strides)
100{
101 ck::Array<ck::index_t, 5> dummy_dims;
102 ck::Array<ck::index_t, 2> dummy_spatial_dims;
103 if(num_dim == 2 &&
105 {
107 2,
109 conv_fwd{dummy_dims,
110 dummy_dims,
111 dummy_dims,
112 dummy_dims,
113 out_lengths,
114 out_strides,
115 dummy_spatial_dims,
116 dummy_spatial_dims,
117 dummy_spatial_dims,
118 dummy_spatial_dims};
119
121 return res.transform_func(conv_fwd);
122 }
123 if(num_dim == 2 &&
125 {
127 2,
129 conv_fwd{dummy_dims,
130 dummy_dims,
131 dummy_dims,
132 dummy_dims,
133 out_lengths,
134 out_strides,
135 dummy_spatial_dims,
136 dummy_spatial_dims,
137 dummy_spatial_dims,
138 dummy_spatial_dims};
139
141 return res.transform_func(conv_fwd);
142 }
143 if(num_dim == 2 &&
145 {
147 2,
149 conv_fwd{dummy_dims,
150 dummy_dims,
151 dummy_dims,
152 dummy_dims,
153 out_lengths,
154 out_strides,
155 dummy_spatial_dims,
156 dummy_spatial_dims,
157 dummy_spatial_dims,
158 dummy_spatial_dims};
159
161 return res.transform_func(conv_fwd);
162 }
164 {
166 2,
168 conv_fwd{dummy_dims,
169 dummy_dims,
170 dummy_dims,
171 dummy_dims,
172 out_lengths,
173 out_strides,
174 dummy_spatial_dims,
175 dummy_spatial_dims,
176 dummy_spatial_dims,
177 dummy_spatial_dims};
178
180 return res.transform_func(conv_fwd);
181 }
182 throw std::runtime_error("Incorrect conv spec");
183}
184
187 ck::Array<ck::index_t, 6> out_lengths,
188 ck::Array<ck::index_t, 6> out_strides)
189{
190 ck::Array<ck::index_t, 6> dummy_dims;
191 ck::Array<ck::index_t, 3> dummy_spatial_dims;
192
193 if(num_dim == 3 &&
195 {
197 3,
199 conv_fwd{dummy_dims,
200 dummy_dims,
201 dummy_dims,
202 dummy_dims,
203 out_lengths,
204 out_strides,
205 dummy_spatial_dims,
206 dummy_spatial_dims,
207 dummy_spatial_dims,
208 dummy_spatial_dims};
209
211 return res.transform_func(conv_fwd);
212 }
213 if(num_dim == 3 &&
215 {
217 3,
219 conv_fwd{dummy_dims,
220 dummy_dims,
221 dummy_dims,
222 dummy_dims,
223 out_lengths,
224 out_strides,
225 dummy_spatial_dims,
226 dummy_spatial_dims,
227 dummy_spatial_dims,
228 dummy_spatial_dims};
229
231 return res.transform_func(conv_fwd);
232 }
233 if(num_dim == 3 &&
235 {
237 3,
239 conv_fwd{dummy_dims,
240 dummy_dims,
241 dummy_dims,
242 dummy_dims,
243 out_lengths,
244 out_strides,
245 dummy_spatial_dims,
246 dummy_spatial_dims,
247 dummy_spatial_dims,
248 dummy_spatial_dims};
249
251 return res.transform_func(conv_fwd);
252 }
254 {
256 3,
258 conv_fwd{dummy_dims,
259 dummy_dims,
260 dummy_dims,
261 dummy_dims,
262 out_lengths,
263 out_strides,
264 dummy_spatial_dims,
265 dummy_spatial_dims,
266 dummy_spatial_dims,
267 dummy_spatial_dims};
268
270 return res.transform_func(conv_fwd);
271 }
272 throw std::runtime_error("Incorrect conv spec");
273}
274
277 ck::Array<ck::index_t, 4> out_lengths,
278 ck::Array<ck::index_t, 4> out_strides)
279{
280 ck::Array<ck::index_t, 4> dummy_dims;
281 ck::Array<ck::index_t, 1> dummy_spatial_dims;
282
283 if(num_dim == 1 &&
285 {
287 1,
289 conv_fwd{dummy_dims,
290 dummy_dims,
291 dummy_dims,
292 dummy_dims,
293 out_lengths,
294 out_strides,
295 dummy_spatial_dims,
296 dummy_spatial_dims,
297 dummy_spatial_dims,
298 dummy_spatial_dims};
299
301 return res.transform_func(conv_fwd);
302 }
303 if(num_dim == 1 &&
305 {
307 1,
309 conv_fwd{dummy_dims,
310 dummy_dims,
311 dummy_dims,
312 dummy_dims,
313 out_lengths,
314 out_strides,
315 dummy_spatial_dims,
316 dummy_spatial_dims,
317 dummy_spatial_dims,
318 dummy_spatial_dims};
319
321 return res.transform_func(conv_fwd);
322 }
323 if(num_dim == 1 &&
325 {
327 1,
329 conv_fwd{dummy_dims,
330 dummy_dims,
331 dummy_dims,
332 dummy_dims,
333 out_lengths,
334 out_strides,
335 dummy_spatial_dims,
336 dummy_spatial_dims,
337 dummy_spatial_dims,
338 dummy_spatial_dims};
339
341 return res.transform_func(conv_fwd);
342 }
344 {
346 1,
348 conv_fwd{dummy_dims,
349 dummy_dims,
350 dummy_dims,
351 dummy_dims,
352 out_lengths,
353 out_strides,
354 dummy_spatial_dims,
355 dummy_spatial_dims,
356 dummy_spatial_dims,
357 dummy_spatial_dims};
358
360 return res.transform_func(conv_fwd);
361 }
362 throw std::runtime_error("Incorrect dims or conv spec");
363}
364
365template <typename CGridDesc_M_N>
366auto block_2_etile(ck::index_t m_per_block, ck::index_t n_per_block, CGridDesc_M_N matrix_padder)
367{
368 if(m_per_block == 32 && n_per_block == 64)
369 {
371 return b2e.CalculateGridSize(matrix_padder);
372 }
373 if(m_per_block == 32 && n_per_block == 128)
374 {
376 return b2e.CalculateGridSize(matrix_padder);
377 }
378 if(m_per_block == 64 && n_per_block == 32)
379 {
381 return b2e.CalculateGridSize(matrix_padder);
382 }
383 if(m_per_block == 64 && n_per_block == 64)
384 {
386 return b2e.CalculateGridSize(matrix_padder);
387 }
388 if(m_per_block == 64 && n_per_block == 128)
389 {
391 return b2e.CalculateGridSize(matrix_padder);
392 }
393 if(m_per_block == 128 && n_per_block == 32)
394 {
396 return b2e.CalculateGridSize(matrix_padder);
397 }
398 if(m_per_block == 128 && n_per_block == 64)
399 {
401 return b2e.CalculateGridSize(matrix_padder);
402 }
403 if(m_per_block == 128 && n_per_block == 128)
404 {
406 return b2e.CalculateGridSize(matrix_padder);
407 }
408 if(m_per_block == 128 && n_per_block == 256)
409 {
411 return b2e.CalculateGridSize(matrix_padder);
412 }
413 if(m_per_block == 256 && n_per_block == 128)
414 {
416 return b2e.CalculateGridSize(matrix_padder);
417 }
418 throw std::runtime_error("Incorrect template parameters");
419}
420
421// wrapper functions by dims to get grid size - uses above 3 functions
422// TODO: eventually remove the 1d/2d versions as CK will only support 3d convolutions
423auto get_launch_params_1d(ck::host::Solution solution,
424 ck::Array<ck::index_t, 4> out_lengths,
425 ck::Array<ck::index_t, 4> out_strides)
426{
427 auto num_dim = solution.GetTemplateParameter<ck::index_t>("NumDim");
428 auto m_per_block = solution.GetTemplateParameter<ck::index_t>("MPerBlock");
429 auto n_per_block = solution.GetTemplateParameter<ck::index_t>("NPerBlock");
430 auto k_per_block = solution.GetTemplateParameter<ck::index_t>("KPerBlock");
431 auto GemmType = solution.GetTemplateParameter<std::string>("GemmSpecialization");
432 auto ConvType = solution.GetTemplateParameter<std::string>("ConvSpecialization");
435 auto conv_to_gemm_transformer = transform_conv_1d(num_dim, ConvSpec, out_lengths, out_strides);
436 auto matrix_padder =
437 pad(m_per_block, n_per_block, k_per_block, GemmSpec, conv_to_gemm_transformer);
438 auto b2e = block_2_etile(m_per_block, n_per_block, matrix_padder);
439 return b2e;
440}
441
442auto get_launch_params(ck::host::Solution solution,
443 ck::Array<ck::index_t, 5> out_lengths,
444 ck::Array<ck::index_t, 5> out_strides)
445{
446 auto num_dim = solution.GetTemplateParameter<ck::index_t>("NumDim");
447 auto m_per_block = solution.GetTemplateParameter<ck::index_t>("MPerBlock");
448 auto n_per_block = solution.GetTemplateParameter<ck::index_t>("NPerBlock");
449 auto k_per_block = solution.GetTemplateParameter<ck::index_t>("KPerBlock");
450 auto GemmType = solution.GetTemplateParameter<std::string>("GemmSpecialization");
451 auto ConvType = solution.GetTemplateParameter<std::string>("ConvSpecialization");
454 auto conv_to_gemm_transformer = transform_conv(num_dim, ConvSpec, out_lengths, out_strides);
455 auto matrix_padder =
456 pad(m_per_block, n_per_block, k_per_block, GemmSpec, conv_to_gemm_transformer);
457 auto b2e = block_2_etile(m_per_block, n_per_block, matrix_padder);
458 return b2e;
459}
460
461auto get_launch_params_3d(ck::host::Solution solution,
462 ck::Array<ck::index_t, 6> out_lengths,
463 ck::Array<ck::index_t, 6> out_strides)
464{
465 auto num_dim = solution.GetTemplateParameter<ck::index_t>("NumDim");
466 auto m_per_block = solution.GetTemplateParameter<ck::index_t>("MPerBlock");
467 auto n_per_block = solution.GetTemplateParameter<ck::index_t>("NPerBlock");
468 auto k_per_block = solution.GetTemplateParameter<ck::index_t>("KPerBlock");
469 auto GemmType = solution.GetTemplateParameter<std::string>("GemmSpecialization");
470 auto ConvType = solution.GetTemplateParameter<std::string>("ConvSpecialization");
473 auto conv_to_gemm_transformer = transform_conv_3d(num_dim, ConvSpec, out_lengths, out_strides);
474 auto matrix_padder =
475 pad(m_per_block, n_per_block, k_per_block, GemmSpec, conv_to_gemm_transformer);
476 auto b2e = block_2_etile(m_per_block, n_per_block, matrix_padder);
477 return b2e;
478}
auto transform_conv(ck::index_t num_dim, ck::tensor_operation::device::ConvolutionForwardSpecialization spec, ck::Array< ck::index_t, 5 > out_lengths, ck::Array< ck::index_t, 5 > out_strides)
Definition helper.hpp:96
auto block_2_etile(ck::index_t m_per_block, ck::index_t n_per_block, CGridDesc_M_N matrix_padder)
Definition helper.hpp:366
std::variant< ck::tensor_layout::convolution::GNWK, ck::tensor_layout::convolution::GNHWK, ck::tensor_layout::convolution::NHWGK, ck::tensor_layout::convolution::GNDHWK, ck::tensor_layout::convolution::NDHWGK > layouts
Definition helper.hpp:16
auto transform_conv_1d(ck::index_t num_dim, ck::tensor_operation::device::ConvolutionForwardSpecialization spec, ck::Array< ck::index_t, 4 > out_lengths, ck::Array< ck::index_t, 4 > out_strides)
Definition helper.hpp:275
auto layout_type(std::string type)
Definition helper.hpp:22
auto get_launch_params_3d(ck::host::Solution solution, ck::Array< ck::index_t, 6 > out_lengths, ck::Array< ck::index_t, 6 > out_strides)
Definition helper.hpp:461
auto get_launch_params(ck::host::Solution solution, ck::Array< ck::index_t, 5 > out_lengths, ck::Array< ck::index_t, 5 > out_strides)
Definition helper.hpp:442
ck::tensor_operation::device::GemmSpecialization gemm_type(std::string type)
Definition helper.hpp:31
auto transform_conv_3d(ck::index_t num_dim, ck::tensor_operation::device::ConvolutionForwardSpecialization spec, ck::Array< ck::index_t, 6 > out_lengths, ck::Array< ck::index_t, 6 > out_strides)
Definition helper.hpp:185
auto get_launch_params_1d(ck::host::Solution solution, ck::Array< ck::index_t, 4 > out_lengths, ck::Array< ck::index_t, 4 > out_strides)
Definition helper.hpp:423
auto pad(ck::index_t mpb, ck::index_t npb, ck::index_t kpb, ck::tensor_operation::device::GemmSpecialization gemm, CDesc_MRaw_NRaw conv)
Definition helper.hpp:70
ck::tensor_operation::device::ConvolutionForwardSpecialization conv_type(std::string type)
Definition helper.hpp:45
GemmSpecialization
Definition gemm_specialization.hpp:11
@ Default
Definition gemm_specialization.hpp:13
@ MNKPadding
Definition gemm_specialization.hpp:20
ConvolutionForwardSpecialization
Definition convolution_forward_specialization.hpp:15
@ OddC
Definition convolution_forward_specialization.hpp:19
@ Filter1x1Stride1Pad0
Definition convolution_forward_specialization.hpp:18
@ Default
Definition convolution_forward_specialization.hpp:16
@ Filter1x1Pad0
Definition convolution_forward_specialization.hpp:17
int32_t index_t
Definition ck.hpp:299
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition pointer.h:1517
Definition utility/array.hpp:14
Definition block_to_ctile_map.hpp:261
Definition tensor_operation/gpu/device/tensor_layout.hpp:345
Definition tensor_operation/gpu/device/tensor_layout.hpp:340
Definition tensor_operation/gpu/device/tensor_layout.hpp:335
Definition tensor_operation/gpu/device/tensor_layout.hpp:362
Definition tensor_operation/gpu/device/tensor_layout.hpp:357
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:25
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:1768
Definition matrix_padder.hpp:180