tile_window_base.hpp Source File

tile_window_base.hpp Source File

Composable Kernel: tile_window_base.hpp Source File
tile_window_base.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
18
19namespace ck_tile {
20
29template <typename TileWindowType_, typename BottomTensorView_, typename WindowLengths_>
31{
32
35 using BottomTensorDesc = typename BottomTensorView::TensorDesc;
37
38 static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension();
39
41 "wrong! lengths should be static");
42
44
45 CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; } // returns the window origin by value (a copy)
46 CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; } // returns the window lengths by value (a copy)
49
50 CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin) // replace the origin, then run the derived window's hook
51 {
52 window_origin_ = new_window_origin;
53
54 // CRTP dispatch: call the derived window type's set_window_origin_extended (the base version is a no-op) so it can run any extra origin-update logic of its own
55 static_cast<TileWindowType_*>(this)->set_window_origin_extended(new_window_origin);
56 }
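57 // Default no-op; a derived window may hide it with its own version (reached through static_cast<TileWindowType_*>(this) in set_window_origin, not through virtual dispatch)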
59
60 CK_TILE_DEVICE constexpr void
61 set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data) // point the bottom tensor view's buffer at new data
62 {
63 bottom_tensor_view_.buf_.p_data_ = data; // only the raw data pointer is rebound; the rest of the view and the window origin are left untouched
64 }
65
66 // move window-origin
68 {
69 window_origin_ += step;
70
71 // Delegate to child if it implements extra movement logic
72 static_cast<TileWindowType_*>(this)->move_extended(step);
73 }
74
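75 // Default no-op; a derived window may hide it with its own version (reached through static_cast<TileWindowType_*>(this) in move, not through virtual dispatch)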
77
78 // origin ([x0', x1', ...]) of window on bottom tensor
80
82
83 // this is the bottom tensor view
84 // [x0', x1', ...] ==> [offset]
86};
87
88template <typename TileWindowType_,
89 typename BottomTensorView_,
90 typename WindowLengths_,
91 typename StaticTileDistribution_>
93 : public tile_window_base<TileWindowType_, BottomTensorView_, WindowLengths_>
94{
97
98 using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor;
99
100 static constexpr index_t NDimWindowAdaptorTop = WindowAdaptor::get_num_of_top_dimension();
101
102 static constexpr index_t NDimP = TileDstr::get_num_of_dimension_p();
103 static constexpr index_t NDimY = TileDstr::get_num_of_dimension_y();
104
106 // using BottomTensorIndex = array<index_t, TileWindowBase::NDimBottomTensor>;
107
110
113
114 static_assert(TileDstr::is_static(), "wrong! tile distribution should be static");
115 static_assert(TileWindowBase::NDimBottomTensor == WindowAdaptor::get_num_of_bottom_dimension(), // adaptor bottom dims must match bottom tensor dims
116 "wrong! inconsistent # of dimensions");
117
118 CK_TILE_DEVICE constexpr auto get_tile_distribution() const { return tile_dstr_; } // returns the tile distribution by value (a copy)
120
122 {
123 return TileDstr::is_static();
124 }
125
126 // move thread's window adaptor coordinate and bottom tensor coordinate
127 // [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
128 template <typename ATopIndex>
130 WindowAdaptorCoord& window_adaptor_thread_coord,
131 BottomTensorCoord& bottom_tensor_thread_coord,
132 const ATopIndex& idx_diff_adaptor_top) const
133 {
135
136 move_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
137 window_adaptor_thread_coord,
138 idx_diff_adaptor_top,
139 idx_diff_adaptor_bottom);
140
141 move_tensor_coordinate(this->bottom_tensor_view_.get_tensor_descriptor(),
142 bottom_tensor_thread_coord,
143 idx_diff_adaptor_bottom);
144 }
145
146 struct Traits
147 {
148 public:
149 static constexpr index_t PackedSize =
151
153 {
154 const auto [ys_vector_lengths, ys_vector_strides] =
156
157 index_t VectorDimY_ = 0;
158 index_t ScalarPerVector_ = 1;
159
160 for(index_t i = 0; i < NDimY; ++i)
161 {
162 if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_)
163 {
164 ScalarPerVector_ = ys_vector_lengths[i];
165 VectorDimY_ = i;
166 }
167 }
168
169 return make_tuple(VectorDimY_, ScalarPerVector_);
170 }
171
172 static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>();
173 static constexpr index_t ScalarPerVector =
174 get_vector_dim_y_scalar_per_vector().template at<1>();
175 using vector_t =
177
178 static constexpr auto scalars_per_access_ = [] {
179 constexpr auto scalars_per_access_arr = generate_array(
180 [&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, number<NDimY>{});
181
183 constexpr auto NDimY_ = NDimY;
184
185 return TO_SEQUENCE(scalars_per_access_arr, NDimY_);
186 }();
187
188 static constexpr auto get_space_filling_curve()
189 {
190 constexpr auto thread_tensor_lengths_ys =
191 to_sequence(TileDstr{}.get_ys_to_d_descriptor().get_lengths());
192
193 // FIXME: need logic to judge dim access order
194 using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type;
195
196 return space_filling_curve<decltype(thread_tensor_lengths_ys),
197 DimAccessOrder,
198 decltype(scalars_per_access_),
199 false >{};
200 }
201
202 using SFC_Ys = decltype(get_space_filling_curve());
203
204 static constexpr index_t NumAccess = SFC_Ys::get_num_of_access();
205
206 static_assert(0 < NumAccess, "Wrong! NumAccess should be larger than 0");
207 };
208
209 // return vector dimension among [y0, y1, ...]
211 {
212 // bottom tensor top dimension vector lengths and strides
213 const auto [bottom_tensor_top_dim_vector_lengths, bottom_tensor_top_dim_vector_strides] =
214 TileWindowBase::BottomTensorDesc::get_top_dimension_safe_vector_length_strides();
215
216 // window vector lengths/strides
217 const auto window_adaptor_bottom_dim_vector_lengths = bottom_tensor_top_dim_vector_lengths;
218 const auto window_adaptor_bottom_dim_vector_strides = bottom_tensor_top_dim_vector_strides;
219
220 // window adaptor [p0, p1, ..., y0, y1, ...]
221 array<index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_lengths{
222 -1};
223 array<index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_strides{
224 -1};
225
226 constexpr auto window_adaptor_bottom_dims =
227 WindowAdaptor::get_bottom_dimension_hidden_ids();
228
229 set_container_subset(window_adaptor_vector_lengths,
230 window_adaptor_bottom_dims,
231 window_adaptor_bottom_dim_vector_lengths);
232 set_container_subset(window_adaptor_vector_strides,
233 window_adaptor_bottom_dims,
234 window_adaptor_bottom_dim_vector_strides);
235
236 const auto [window_adaptor_ps_ys_vector_lengths, window_adaptor_ps_ys_vector_strides] =
237 WindowAdaptor{}.get_top_dimension_safe_vector_length_strides(
238 window_adaptor_vector_lengths, window_adaptor_vector_strides);
239
240 // [y0, y1, ...]
241 constexpr auto y_dims = typename arithmetic_sequence_gen<TileDstr::get_num_of_dimension_p(),
243 1>::type{};
244
245 return make_tuple(get_container_subset(window_adaptor_ps_ys_vector_lengths, y_dims),
246 get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims));
247 }
248
249 CK_TILE_DEVICE constexpr auto get_num_of_access() const { return Traits::NumAccess; } // number of accesses along the Ys space-filling curve (Traits::SFC_Ys)
250 // Tile tensor distribution, which contains:
251 // 1. adaptor for window: [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...]
252 // 2. thread descriptor for thread tensor in register: [y0, y1, ...] ==> [d]
254};
255
256} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
typename std::remove_reference< T >::type remove_reference_t
Definition type_traits.hpp:15
CK_TILE_HOST_DEVICE constexpr void move_tensor_adaptor_coordinate(const Adaptor &adaptor, AdaptorCoord &coord, const TopIndex &idx_diff_top, BottomIndex &idx_diff_bottom)
Definition tensor_adaptor_coordinate.hpp:97
CK_TILE_HOST_DEVICE constexpr auto generate_array(F &&f, number< N >)
Definition tile/core/container/sequence.hpp:1115
CK_TILE_HOST_DEVICE constexpr void set_container_subset(array< T, N > &y, sequence< Is... > picks, const array< T, sizeof...(Is)> &x)
Definition tile/core/container/container_helper.hpp:420
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_HOST_DEVICE constexpr auto get_container_subset(const array< T, N > &arr, sequence< Is... >)
Definition tile/core/container/container_helper.hpp:389
CK_TILE_HOST_DEVICE constexpr void move_tensor_coordinate(const TensorDesc &tensor_desc, TensorCoord &coord, const Index &coord_step)
Definition tensor_coordinate.hpp:72
CK_TILE_HOST_DEVICE constexpr auto make_tensor_adaptor_coordinate(const Adaptor &adaptor, const TopIndex &idx_top)
Definition tensor_adaptor_coordinate.hpp:55
CK_TILE_HOST_DEVICE constexpr auto to_sequence(tuple< number< Is >... >)
Definition tile/core/container/sequence.hpp:1055
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_tensor_coordinate(const TensorDesc &tensor_desc, const TopIndex &idx_top)
Definition tensor_coordinate.hpp:60
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition tile/core/container/sequence.hpp:287
typename std::conditional< kHasContent, type0, type1 >::type type
Definition tile/core/container/sequence.hpp:302
A fixed-size array container similar to std::array with additional utilities.
Definition tile/core/container/array.hpp:43
static constexpr bool value
Definition type_traits.hpp:77
Definition tile/core/numeric/numeric.hpp:81
Definition space_filling_curve.hpp:20
Definition tile/core/utility/debug.hpp:67
This class provides a description of a tile-windowed view of device memory.
Definition tile_window_base.hpp:31
CK_TILE_DEVICE constexpr auto get_window_origin() const
Definition tile_window_base.hpp:45
remove_cvref_t< typename BottomTensorView::DataType > DataType
Definition tile_window_base.hpp:36
CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const
Definition tile_window_base.hpp:47
CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex &new_window_origin)
Definition tile_window_base.hpp:50
CK_TILE_DEVICE constexpr auto get_window_lengths() const
Definition tile_window_base.hpp:46
static CK_TILE_DEVICE constexpr index_t get_num_of_dimension()
Definition tile_window_base.hpp:48
CK_TILE_DEVICE void move_extended(const BottomTensorIndex &)
Definition tile_window_base.hpp:76
CK_TILE_DEVICE void move(const BottomTensorIndex &step)
Definition tile_window_base.hpp:67
CK_TILE_DEVICE constexpr void set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType *data)
Definition tile_window_base.hpp:61
CK_TILE_DEVICE void set_window_origin_extended(const BottomTensorIndex &)
Definition tile_window_base.hpp:58
Definition tile_window_base.hpp:147
decltype(get_space_filling_curve()) SFC_Ys
Definition tile_window_base.hpp:202
static constexpr auto get_space_filling_curve()
Definition tile_window_base.hpp:188
static constexpr index_t ScalarPerVector
Definition tile_window_base.hpp:173
static constexpr index_t PackedSize
Definition tile_window_base.hpp:149
static constexpr auto scalars_per_access_
Definition tile_window_base.hpp:178
static constexpr index_t VectorDimY
Definition tile_window_base.hpp:172
thread_buffer< typename TileWindowBase::DataType, ScalarPerVector/PackedSize > vector_t
Definition tile_window_base.hpp:175
static constexpr index_t NumAccess
Definition tile_window_base.hpp:204
static constexpr auto get_vector_dim_y_scalar_per_vector()
Definition tile_window_base.hpp:152
Definition tile_window_base.hpp:94
CK_TILE_DEVICE constexpr auto get_num_of_access() const
Definition tile_window_base.hpp:249
CK_TILE_DEVICE constexpr auto get_tile_distribution() const
Definition tile_window_base.hpp:118
decltype(make_tensor_coordinate(typename TileWindowBase::BottomTensorDesc{}, typename TileWindowBase::BottomTensorIndex{})) BottomTensorCoord
Definition tile_window_base.hpp:111
static CK_TILE_DEVICE constexpr auto get_window_adaptor_ys_safe_vector_length_strides()
Definition tile_window_base.hpp:210
CK_TILE_HOST_DEVICE void init_raw()
Definition tile_window_base.hpp:119
tile_window_base< tile_window_with_static_distribution< BottomTensorView_, WindowLengths_, StaticTileDistribution_, NumCoord >, BottomTensorView_, WindowLengths_ > TileWindowBase
Definition tile_window_base.hpp:96
CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(WindowAdaptorCoord &window_adaptor_thread_coord, BottomTensorCoord &bottom_tensor_thread_coord, const ATopIndex &idx_diff_adaptor_top) const
Definition tile_window_base.hpp:129
static CK_TILE_DEVICE constexpr bool has_static_tile_distribution()
Definition tile_window_base.hpp:121
#define TO_SEQUENCE(a, n)
Definition to_sequence.hpp:10