block_position_encoding.hpp Source File

block_position_encoding.hpp Source File#

Composable Kernel: block_position_encoding.hpp Source File
block_position_encoding.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
8#include <cmath>
9#include <vector>
10
11namespace ck_tile {
12
14{
15 NO = 0,
16 ALIBI = 1,
17};
18
19/*
20VERTICAL:
21 [0] 1 2 3 4 5
22 [0] 1 2 3 4 5
23 [0] 1 2 3 4 5
24 [0] 1 2 3 4 5
25
26FROM_TOP_LEFT(but negative):
27 [0] 1 2 3 4 5
28 1 [0] 1 2 3 4
29 2 1 [0] 1 2 3
30 3 2 1 [0] 1 2
31
32FROM_BOTTOM_RIGHT(but negative):
33 2 1 [0] 1 2 3
34 3 2 1 [0] 1 2
35 4 3 2 1 [0] 1
36 5 4 3 2 1 [0]
37*/
38
39enum struct AlibiMode
40{
42 FROM_TOP_LEFT = 1, // keep sync with mask enum
44};
45
46template <typename DataType, bool RowMajor = true, unsigned LogMaxSadOprndSize = 16>
47struct Alibi
48{
49 static_assert(1 <= LogMaxSadOprndSize && LogMaxSadOprndSize <= 32,
50 "for LogMaxSadOprndSize <= 16, we use SAD uint16_t, otherwise, use SAD uint32_t");
51
52 // RowMajor tells whether the pixels held by one thread lie along a row (true) or a column (false).
53 // This may affect the performance of update(), but the result is the same either way.
54 // e.g. fwd prefers RowMajor=true, while some bwd cases prefer RowMajor=false
55 CK_TILE_HOST_DEVICE Alibi(DataType slope_,
56 index_t y_total_,
57 index_t x_total_,
59 {
60 slope = mode_ == AlibiMode::VERTICAL ? slope_ : -slope_;
61
62 shift_left_up = [&]() {
63 if(RowMajor)
64 {
65 return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(y_total_ - x_total_, 0) : 0;
66 }
67 else
68 {
69 return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(x_total_ - y_total_, 0) : 0;
70 }
71 }();
72 shift_right_down = [&]() {
73 if(RowMajor)
74 {
75 return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(x_total_ - y_total_, 0) : 0;
76 }
77 else
78 {
79 return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(y_total_ - x_total_, 0) : 0;
80 }
81 }();
82 mode = mode_;
83 }
84
85 CK_TILE_HOST uint32_t sad(uint32_t x, uint32_t y, uint32_t acc) { return sad_u32(x, y, acc); }
86
88 {
89 if constexpr(LogMaxSadOprndSize <= 16)
90 {
91 return sad_u16(
92 static_cast<uint16_t>(x), static_cast<uint16_t>(y), static_cast<uint16_t>(acc));
93 }
94
95 return sad_u32(x, y, acc);
96 }
97
98 CK_TILE_HOST_DEVICE void update(DataType& pixel, index_t row_idx, index_t col_idx)
99 {
100 if constexpr(RowMajor)
101 {
102 // at least 3 instructions per row
103 index_t current_zero_point =
105
106 // for every thread, most of the pixels are along the row, so the operation below should be
107 // the main hot spot.
108 auto position = type_convert<DataType>(sad(bit_cast<uint32_t>(current_zero_point),
110 0));
111 pixel += slope * position;
112 }
113 else
114 {
115 // at least 3 instructions per col;
116 index_t current_zero_point = mode == AlibiMode::VERTICAL
117 ? row_idx + col_idx + shift_right_down
118 : col_idx + shift_right_down;
119
120 // for every thread, most of the pixels are along the col, so the operation below should be
121 // the main hot spot.
122 auto position = type_convert<DataType>(sad(bit_cast<uint32_t>(current_zero_point),
124 0));
125 pixel += slope * position;
126 }
127 }
128
129 DataType slope; // float? holds slope_ for AlibiMode::VERTICAL and -slope_ otherwise (set in the constructor)
130 index_t shift_left_up; // always non-negative: the constructor assigns either max(..., 0) or 0
131 index_t shift_right_down; // always non-negative: the constructor assigns either max(..., 0) or 0
133};
134
135template <typename DataType>
137{
138 CK_TILE_HOST_DEVICE void update(DataType& /*pixel*/, index_t /*row_idx*/, index_t /*col_idx*/)
139 {
140 }
141};
142
143//
144// Converts the FA-style left/right window sizes to our generic coordinate.
145// If left_size < 0 && right_size == 0, it is a normal causal mask;
146// a local mask has left_size >= 0 or right_size >= 0.
147template <typename DataType, bool RowMajor = true, unsigned LogMaxSadOprndSize = 16>
149 index_t window_left_size,
150 index_t window_right_size,
151 index_t y_total,
152 index_t x_total,
153 GenericAttentionMaskEnum mask_enum)
154{
155 // assume mask_enum will never be NO_MASK, since if we do not have mask, it's
156 // totally OK to use constexpr
157 bool is_causal = window_left_size < 0 && window_right_size == 0;
158 AlibiMode alibi_mode =
159 is_causal ? AlibiMode::VERTICAL
160 : static_cast<AlibiMode>(mask_enum) /*either top-left or bottom-right*/;
161 return Alibi<DataType, RowMajor, LogMaxSadOprndSize>{slope, y_total, x_total, alibi_mode};
162}
163
164// https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742
165// Do we need a device version?
166template <typename DataType>
167CK_TILE_HOST std::vector<DataType> get_alibi_slopes(ck_tile::index_t nheads)
168{
169 auto get_slopes_power_of_2 = [](ck_tile::index_t n) {
170 float start = std::powf(
171 static_cast<float>(2),
172 -std::powf(static_cast<float>(2), -static_cast<float>((integer_log2_floor(n) - 3))));
173
174 std::vector<DataType> rtn;
175 for(auto i = 0; i < n; i++)
176 {
177 rtn.push_back(static_cast<DataType>(start * std::powf(start, i)));
178 }
179 return rtn;
180 };
181 if(is_power_of_two_integer(nheads))
182 {
183 // power of 2 calculation
184 return get_slopes_power_of_2(nheads);
185 }
186 else
187 {
188 ck_tile::index_t closest_power_of_2 = 1 << integer_log2_floor(nheads);
189 auto v0 = get_slopes_power_of_2(closest_power_of_2);
190 auto v1 = get_slopes_power_of_2(closest_power_of_2 * 2);
191 auto v1_sliced = [&](auto vec, ck_tile::index_t rem) {
192 std::vector<DataType> sliced;
193 for(ck_tile::index_t i = 0; i < static_cast<ck_tile::index_t>(vec.size()); i++)
194 {
195 if(i % 2 == 0)
196 sliced.push_back(vec[i]);
197 }
198 std::vector<DataType> sliced_2(sliced.begin(), sliced.begin() + rem);
199 return sliced_2;
200 }(v1, nheads - closest_power_of_2);
201 v0.insert(v0.end(), v1_sliced.begin(), v1_sliced.end());
202 return v0;
203 }
204}
205} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_HOST_DEVICE constexpr bool is_power_of_two_integer(int32_t x)
Definition tile/core/numeric/math.hpp:462
CK_TILE_HOST_DEVICE constexpr Y bit_cast(const X &x)
Definition bit_cast.hpp:11
constexpr uint32_t ALIBI
Definition variants.hpp:232
CK_TILE_HOST_DEVICE auto make_alibi_from_lr_mask(DataType slope, index_t window_left_size, index_t window_right_size, index_t y_total, index_t x_total, GenericAttentionMaskEnum mask_enum)
Definition block_position_encoding.hpp:148
CK_TILE_DEVICE uint32_t sad_u32(uint32_t x, uint32_t y, uint32_t acc)
Definition tile/core/numeric/math.hpp:504
CK_TILE_DEVICE uint16_t sad_u16(uint16_t x, uint16_t y, uint16_t acc)
Definition tile/core/numeric/math.hpp:499
PositionEncodingEnum
Definition block_position_encoding.hpp:14
@ NO
Definition block_position_encoding.hpp:15
CK_TILE_HOST_DEVICE constexpr int32_t integer_log2_floor(int32_t x)
Definition tile/core/numeric/math.hpp:455
GenericAttentionMaskEnum
Definition block_masking.hpp:11
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
int32_t index_t
Definition integer.hpp:9
AlibiMode
Definition block_position_encoding.hpp:40
@ VERTICAL
Definition block_position_encoding.hpp:41
@ FROM_TOP_LEFT
Definition block_position_encoding.hpp:42
@ FROM_BOTTOM_RIGHT
Definition block_position_encoding.hpp:43
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_HOST std::vector< DataType > get_alibi_slopes(ck_tile::index_t nheads)
Definition block_position_encoding.hpp:167
unsigned short uint16_t
Definition stdint.h:125
unsigned int uint32_t
Definition stdint.h:126
Definition block_position_encoding.hpp:48
CK_TILE_HOST_DEVICE Alibi(DataType slope_, index_t y_total_, index_t x_total_, AlibiMode mode_=AlibiMode::VERTICAL)
Definition block_position_encoding.hpp:55
AlibiMode mode
Definition block_position_encoding.hpp:132
CK_TILE_DEVICE uint32_t sad(uint32_t x, uint32_t y, uint32_t acc)
Definition block_position_encoding.hpp:87
index_t shift_right_down
Definition block_position_encoding.hpp:131
CK_TILE_HOST uint32_t sad(uint32_t x, uint32_t y, uint32_t acc)
Definition block_position_encoding.hpp:85
DataType slope
Definition block_position_encoding.hpp:129
CK_TILE_HOST_DEVICE void update(DataType &pixel, index_t row_idx, index_t col_idx)
Definition block_position_encoding.hpp:98
index_t shift_left_up
Definition block_position_encoding.hpp:130
Definition block_position_encoding.hpp:137
CK_TILE_HOST_DEVICE void update(DataType &, index_t, index_t)
Definition block_position_encoding.hpp:138