mxfp_utils.hpp Source File

mxfp_utils.hpp Source File

Composable Kernel: mxfp_utils.hpp Source File
mxfp_utils.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
// Under CK_CODE_GEN_RTC the standard UINT_MAX is presumably unavailable, so a
// local definition is provided. The U suffix gives it type unsigned int, like
// the <climits> macro (an unsuffixed 4294967295 would be a signed long), and
// the #ifndef avoids a redefinition if some RTC header already supplies it.
#ifdef CK_CODE_GEN_RTC
#ifndef UINT_MAX
#define UINT_MAX 4294967295U
#endif
#endif
11namespace ck::utils {
12
18
/// Report whether DTYPE's encoding has an infinity, as recorded in its static
/// `dataInfo.hasInf` member.
///
/// Not marked __host__ __device__, so it is only callable from host code.
template <typename DTYPE>
inline bool getDataHasInf()
{
    const bool has_inf = DTYPE::dataInfo.hasInf;
    return has_inf;
}
24
// Forward declarations of classification predicates on an (E8M0 block scale,
// element) pair. Their definitions are not part of this listing; the exact
// semantics (e.g. how a special scale value interacts with the element)
// should be confirmed there.
25template <typename T>
26__host__ __device__ inline bool is_zero(e8m0_bexp_t const scale, T const data);
27
28template <typename T>
29__host__ __device__ inline bool is_nan(e8m0_bexp_t const scale, T const data);
30
31template <typename T>
32__host__ __device__ inline bool is_inf(e8m0_bexp_t const scale, T const data);
33
34template <typename T>
35__host__ __device__ inline constexpr int32_t get_exponent_value(T x)
36{
38
39 x &= ((1 << NumericUtils<T>::exp) - 1);
40
41 return static_cast<int32_t>(x);
42}
43
44template <typename T>
45__host__ __device__ inline bool is_subnormal(T x)
46{
47 return get_exponent_value<T>(x) == 0;
48}
49
50template <typename T>
51__host__ __device__ inline double get_mantissa_value(T x)
52{
53 double mantissa = is_subnormal<T>(x) ? 0.0f : 1.0f;
54
55 for(uint i = 0; i < NumericUtils<T>::mant; i++)
56 {
57
58 mantissa += std::pow(2, -int32_t((NumericUtils<T>::mant - i))) * (x & 0b1);
59
60 x >>= 1;
61 }
62
63 return mantissa;
64}
65
// Query whether the encoding T can represent infinity. convert_to_type_sr
// calls this to decide whether an out-of-range value may be rounded up to an
// exponent-incremented ("inf") pattern instead of being clamped.
//
// NOTE(review): original source line 69 (the function body) is missing from
// this listing. As shown, control falls off the end of a bool-returning
// function, which is undefined behaviour. Restore the body; it presumably
// reads a per-type flag (e.g. from NumericUtils<T>) — confirm against the real
// header.
66template <typename T>
67__host__ __device__ inline bool get_data_has_inf()
68{
70}
71
72template <typename T>
73__host__ __device__ float convert_to_float(T data, int scale_exp)
74{
75 float d_sign =
76 std::pow(-1, static_cast<float>(data >> (NumericUtils<T>::exp + NumericUtils<T>::mant)));
77
78 float d_exp;
79 if(is_subnormal<T>(data))
80 d_exp = std::pow(2, 1 - static_cast<int>(NumericUtils<T>::bias));
81 else
82 d_exp = std::pow(2, get_exponent_value<T>(data) - static_cast<int>(NumericUtils<T>::bias));
83 float d_mant = get_mantissa_value<T>(data);
84
85 float data_value = d_sign * d_exp * d_mant;
86 float scale_value = std::pow(
87 2, static_cast<float>((scale_exp - static_cast<int>(NumericUtils<e8m0_bexp_t>::bias))));
88
89 return data_value * scale_value;
90}
91
// Forward declarations; the definitions are not part of this listing.
//
// Presumably converts an (E8M0 block scale, element) pair to float.
92template <typename T>
93__host__ __device__ inline float to_float(e8m0_bexp_t const scale, T const data);
94
// Saturating float -> T conversion (rounding mode not visible here).
95template <typename T>
96__host__ __device__ T sat_convert_to_type(float value);
97
// Saturating float -> T conversion with stochastic rounding driven by `seed`.
98template <typename T>
99__host__ __device__ T sat_convert_to_type_sr(float value, uint32_t seed);
100
// Convert a float to the packed encoding T, rounding to nearest with ties to
// even.
//
// Outline of what is visible here:
//  1. |value| > NumericLimits<T>::Max(): return either the max-magnitude
//     encoding or an overflow pattern, using max_value + diff / 2 as the
//     cut-off.
//  2. x == 0 returns encoding 0 (x is presumably the bit pattern of value;
//     its assignment is on a missing line).
//  3. Otherwise the float significand is aligned to T's mantissa width,
//     denormalised when the value lies below T's smallest normal exponent,
//     rounded half-to-even, and packed as sign | exponent | mantissa.
//
// NOTE(review): this listing skips original source lines 119, 121-122, 129,
// 146, 161, 167, 169-171, 216 and 252 (see the gaps in the line-number
// prefixes). The initialisers of `sign` and `exp`, the statements that
// presumably assign `x`, `head`, `exponent`, `bias` and `sign`, the inner
// overflow condition and two return statements are therefore not visible;
// this copy is not compilable as shown.
101template <typename T>
102__host__ __device__ inline T convert_to_type(float value)
103{
104 using bitwise_type = typename NumericUtils<T>::bitwise_type;
105
// Overflow: magnitude beyond the largest finite T value.
106 if(std::abs(value) > NumericLimits<T>::Max())
107 {
108 float max_value = NumericLimits<T>::Max();
109
// `cvt` (declared outside this listing) reinterprets a float as its uint32_t
// bit pattern through value_float / value_bitwise.
110 cvt t;
111
112 // cppcheck-suppress redundantAssignment
113 t.value_float = max_value;
114 uint32_t max_bitwise = t.value_bitwise;
115
116 // cppcheck-suppress redundantAssignment
117 t.value_float = value;
118 bitwise_type sign =
120 bitwise_type exp =
// NOTE(review): `mantissa` is not masked to T's mantissa width before it is
// OR-ed into the result below; float exponent bits shifted into its upper
// positions must agree with `exp` (or be truncated away by T) for the result
// to be correct — verify for every T this is instantiated with.
123 bitwise_type mantissa = max_bitwise >> (NumericUtils<float>::mant - NumericUtils<T>::mant);
124
// mant_prev: Max()'s T-mantissa field minus one, used to estimate the
// spacing of representable values just below Max().
125 uint32_t mant_prev = max_bitwise >> (NumericUtils<float>::mant - NumericUtils<T>::mant);
126 mant_prev &= ((1 << NumericUtils<T>::mant) - 1);
127 mant_prev--;
128
// NOTE(review): mant_prev is OR-ed into the *low* bits of the float mantissa
// rather than shifted back up by (float mant - T mant). prev_val is therefore
// the next-lower T value only when mant_prev == 0; otherwise diff (and hence
// actual_max) can exceed one T ulp — confirm this is intended.
130 uint32_t prev_bit =
131 ((max_bitwise >> NumericUtils<float>::mant) << NumericUtils<float>::mant) | mant_prev;
132
133 t.value_bitwise = prev_bit;
134 float prev_val = t.value_float;
135 float diff = max_value - prev_val;
136
137 float actual_max = max_value + (diff / 2);
138
// Below the cut-off: saturate to the max-magnitude encoding with value's sign.
139 if(std::abs(value) < actual_max)
140 {
141 return sign << ((NumericUtils<T>::exp + NumericUtils<T>::mant)) |
142 (exp << NumericUtils<T>::mant) | mantissa;
143 }
// At or beyond the cut-off. The condition choosing between the two returns
// (original line 146) is missing from this listing.
// NOTE(review): the first return carries no sign bit, so a negative input
// reaching it yields a positive encoding — verify against that condition.
// The second bumps the exponent past Max()'s and clears the mantissa
// (presumably T's infinity pattern).
144 else
145 {
147 {
148
149 return (1 << (NumericUtils<T>::mant + NumericUtils<T>::exp)) - 1;
150 }
151 else
152 {
153 exp++;
154 return sign << ((NumericUtils<T>::exp + NumericUtils<T>::mant)) |
155 (exp << NumericUtils<T>::mant);
156 }
157 }
158 }
// Regular path. mfmt is the number of explicit float mantissa bits.
159 const int mfmt = NumericUtils<float>::mant;
160 uint32_t x;
162
163 uint32_t head, mantissa;
164 int32_t exponent, bias;
165 uint32_t sign;
166
168 mantissa = x & NumericUtils<float>::mant_mask;
172
173 if(x == 0)
174 {
175 return 0b0;
176 }
177
178 const int mini_bias = NumericUtils<T>::bias;
179 const int mini_denormal_act_exponent = 1 - mini_bias;
180
181 int act_exponent, out_exponent, exponent_diff;
182
183 bool is_subnorm = false;
184
// act_exponent is the unbiased input exponent. exponent_diff is how far it
// lies below T's smallest normal exponent, i.e. how many extra low bits must
// be shifted out to encode the result as a T subnormal. Normal float inputs
// also get their implicit leading one made explicit.
185 if(exponent == 0)
186 {
187 act_exponent = exponent - bias + 1;
188 exponent_diff = mini_denormal_act_exponent - act_exponent;
189 is_subnorm = true;
190 }
191 else
192 {
193 act_exponent = exponent - bias;
194 if(act_exponent <= mini_denormal_act_exponent)
195 {
196 exponent_diff = mini_denormal_act_exponent - act_exponent;
197 is_subnorm = true;
198 }
199 else
200 {
201 exponent_diff = 0;
202 }
203 mantissa += (1UL << mfmt);
204 }
205
// shift_amount: total number of low bits that will be discarded. The clamp
// to 63 keeps the shifts below within a 64-bit unsigned long (assumes LP64).
// midpoint: the discarded bits are exactly half of the kept LSB (a tie).
206 auto shift_amount = (mfmt - NumericUtils<T>::mant + exponent_diff);
207 shift_amount = (shift_amount >= 64) ? 63 : shift_amount;
208 bool midpoint = (mantissa & ((1UL << shift_amount) - 1)) == (1UL << (shift_amount - 1));
209
210 float min_subnorm = NumericLimits<T>::DataMinSubnorm() * (sign ? -1 : 1);
211
// Below T's smallest subnormal: round either to zero (return on missing
// original line 216) or to the smallest subnormal carrying the input's sign.
212 if(is_subnorm && std::abs(value) < std::abs(min_subnorm))
213 {
214 // closer to 0
215 if(std::abs(value) <= std::abs(min_subnorm - value))
217 else
218 return 1 | (sign << (NumericUtils<T>::exp + NumericUtils<T>::mant));
219 }
220
// Denormalise by dropping exponent_diff bits.
// NOTE(review): every visible assignment gives exponent_diff >= 0 when `bias`
// is the float bias, so the `== -1` branch appears unreachable — confirm.
221 if(exponent_diff > 0)
222 mantissa >>= exponent_diff;
223 else if(exponent_diff == -1)
224 mantissa <<= -exponent_diff;
// Biased T exponent; without the implicit one the result is subnormal.
225 bool implicit_one = mantissa & (1 << mfmt);
226 out_exponent = (act_exponent + exponent_diff) + mini_bias - (implicit_one ? 0 : 1);
227
// Round half to even. Adding the dropped bits to the mantissa carries into
// the kept LSB iff they are >= one half. On an exact tie with an even kept
// LSB, one less is added, so no carry occurs and the value rounds down.
228 uint32_t drop_mask = (1UL << (mfmt - NumericUtils<T>::mant)) - 1;
229 bool odd = mantissa & (1UL << (mfmt - NumericUtils<T>::mant));
230 mantissa += (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa) & drop_mask;
231
// Renormalise after a rounding carry. A subnormal that rounded up into the
// implicit-one position becomes the smallest normal; a normal whose
// significand overflowed to 2.0 shifts right and bumps the exponent.
232 if(out_exponent == 0)
233 {
234 if((1UL << mfmt) & mantissa)
235 {
236 out_exponent = 1;
237 }
238 }
239 else
240 {
241 if((1UL << (mfmt + 1)) & mantissa)
242 {
243 mantissa >>= 1;
244 out_exponent++;
245 }
246 }
247
// Discard the bits that were rounded away.
248 mantissa >>= (mfmt - NumericUtils<T>::mant);
249
// Everything rounded away: the handling statement (original line 252) is
// missing from this listing.
250 if(out_exponent == 0 && mantissa == 0)
251 {
253 }
254
// Pack sign | exponent | mantissa.
255 mantissa &= (1UL << NumericUtils<T>::mant) - 1;
256 return (sign << (NumericUtils<T>::exp + NumericUtils<T>::mant)) |
257 (out_exponent << NumericUtils<T>::mant) | mantissa;
258}
259
// Convert a float to the packed encoding T using stochastic rounding. Random
// bits derived from `seed` are added below the kept mantissa bits, and the
// result is then truncated.
//
// NOTE(review): this listing skips original source lines 275, 277, 283,
// 306-307, 317, 328, 332, 345, 351, 369 and 378 (see the gaps in the
// line-number prefixes). Not visible as a result:
//   - the `sign` used in the overflow branch;
//   - the tail of the `exp` initialiser;
//   - the statement controlled by the round-down `if`;
//   - the inner overflow condition;
//   - the definitions of `f32`, `f32_exp` and `sr_shift`;
//   - the conditions at lines 345/351;
//   - one statement before the final mantissa mask.
// This copy is not compilable as shown.
260template <typename T>
261__host__ __device__ inline T convert_to_type_sr(float value, uint32_t seed)
262{
// Overflow: magnitude beyond the largest finite T value.
263 if(std::abs(value) > NumericLimits<T>::Max())
264 {
265 float max_value = NumericLimits<T>::Max();
266
// `cvt` (declared outside this listing) reinterprets a float as its uint32_t
// bit pattern through value_float / value_bitwise.
267 cvt t;
268
269 // cppcheck-suppress redundantAssignment
270 t.value_float = max_value;
// NOTE(review): `uint` is a non-standard (POSIX) typedef; uint32_t, as used
// in convert_to_type, would be portable.
271 uint max_bitwise = t.value_bitwise;
272
273 // cppcheck-suppress redundantAssignment
274 t.value_float = value;
276 T exp = ((max_bitwise >> NumericUtils<float>::mant) & NumericUtils<float>::exp_mask) -
278
// mant_prev: Max()'s T-mantissa field minus one.
// NOTE(review): as in convert_to_type, it is OR-ed into the *low* float
// mantissa bits below, so prev_val is the next-lower T value only when
// mant_prev == 0 — confirm the resulting actual_max is intended.
279 uint32_t mant_prev = max_bitwise >> (NumericUtils<float>::mant - NumericUtils<T>::mant);
280 mant_prev &= ((1UL << NumericUtils<T>::mant) - 1);
281 mant_prev--;
282
284 uint32_t prev_bit =
285 ((max_bitwise >> NumericUtils<float>::mant) << NumericUtils<float>::mant) | mant_prev;
286
287 t.value_bitwise = prev_bit;
288 float prev_val = t.value_float;
289 float diff = max_value - prev_val;
290
291 float actual_max = max_value + (diff / 2);
292
// Between Max() and the cut-off: pick at random between Max() and the
// overflow pattern. d_prob is the probability of rounding down; it is 1 at
// Max() and falls linearly to 0 at actual_max. The seed is compared against
// UINT_MAX * d_prob, so it is presumably uniform over the uint32_t range.
293 if(std::abs(value) < actual_max)
294 {
295 double d_max_value = static_cast<double>(max_value);
296 double d_actual_max = static_cast<double>(actual_max);
297 double d_value = static_cast<double>(value);
298 double d_is = std::abs(d_max_value - d_actual_max);
299 double d_seed = static_cast<double>(seed);
300 double d_prob = 1.0f - (std::abs(d_value - d_max_value) / d_is); // prob to round down
301
302 double thresh = UINT_MAX * d_prob;
303
// Types without infinity always round down here. The statement this `if`
// controls (original lines 306-307) is missing; only a commented-out line
// is visible.
304 if(!get_data_has_inf<T>() || d_seed <= thresh)
305 // return static_cast<T>(satConvertToType(getDataMax<DTYPE>())); //round down time
308 else
309 {
310 exp++;
311 return sign << ((NumericUtils<T>::exp + NumericUtils<T>::mant)) // inf
312 | (exp << NumericUtils<T>::mant);
313 }
314 }
// At or beyond the cut-off. The condition selecting the first return
// (original line 317) is missing.
// NOTE(review): that return carries no sign bit, so a negative input
// reaching it yields a positive encoding — verify.
315 else
316 {
318 return (1 << (NumericUtils<T>::mant + NumericUtils<T>::exp)) - 1;
319 else
320 {
321 exp++;
322 return sign << ((NumericUtils<T>::exp + NumericUtils<T>::mant)) // inf
323 | (exp << NumericUtils<T>::mant);
324 }
325 }
326 }
327
// Regular path. `f32` (presumably the bit pattern of value) and `f32_exp` are
// defined on lines missing from this listing.
329
330 auto f32_mant = f32 & NumericUtils<float>::mant_mask;
331 auto head = f32 & NumericUtils<float>::head_mask;
333
334 auto sign_bit = head >> (NumericUtils<float>::mant + NumericUtils<float>::exp);
335 auto sign = sign_bit << (NumericUtils<T>::exp + NumericUtils<T>::mant);
336
// Remove the float bias to get the unbiased exponent.
337 f32_exp = static_cast<int32_t>(f32_exp) - NumericUtils<float>::bias;
338 int32_t exp = f32_exp;
339 auto mant = f32_mant;
340 bool subnorm = false;
341
342 if(f32 == 0)
343 return 0b0;
344
// The condition guarding this first branch is on missing original line 345.
346 {
347 mant = f32_mant;
348 }
// Below T's minimum normal exponent: make the implicit one explicit and shift
// right by the deficit so the result is a T subnormal (exponent field 0).
// Deficits of 32 or more flush the mantissa to zero, which also avoids an
// out-of-range shift.
349 // if the exponent bit is 8, then the subnormal is exactly the same as f32
350 else if(exp < NumericUtils<T>::unbiased_exp_min &&
352 {
353 subnorm = true;
354 auto diff = static_cast<uint32_t>(NumericUtils<T>::unbiased_exp_min - exp);
355 if(diff >= 32)
356 {
357 mant = 0;
358 f32_mant = 0;
359 }
360 else
361 {
362 f32_mant |= static_cast<uint32_t>(1) << NumericUtils<float>::mant;
363 f32_mant >>= diff;
364 }
365 exp = 0;
366 mant = f32_mant;
367 }
368
// sr_shift is defined on missing original line 369; presumably it aligns the
// random bits just below T's kept mantissa bits.
370
371 // For stochastic-rounding we add the aligned random value to the
372 // mantissa and then truncate (RTZ).
373 mant += seed >> sr_shift;
374
375 // Increment exponent when mantissa overflows due to rounding
376 if(mant >= static_cast<uint32_t>(1) << NumericUtils<float>::mant)
377 ++exp;
// Original line 378 is missing here.
379 mant &= ((1 << NumericUtils<T>::mant) - 1);
380
// Re-bias the exponent for normal results. Subnormal results keep the
// exponent set above: 0, or 1 after a rounding carry into the implicit bit.
381 auto biased_exp = static_cast<uint32_t>(exp);
382 if(!subnorm)
383 biased_exp = static_cast<uint32_t>(exp + NumericUtils<T>::bias);
384 biased_exp &= ((1 << NumericUtils<T>::exp) - 1);
385 auto val = sign | biased_exp << NumericUtils<T>::mant | mant;
386 return val;
387}
388} // namespace ck::utils
Definition library/utility/check_err.hpp:24
__host__ __device__ T sat_convert_to_type(float value)
__host__ __device__ bool is_subnormal(T x)
Definition mxfp_utils.hpp:45
__host__ __device__ bool get_data_has_inf()
Definition mxfp_utils.hpp:67
__host__ __device__ constexpr int32_t get_exponent_value(T x)
Definition mxfp_utils.hpp:35
__host__ __device__ T sat_convert_to_type_sr(float value, uint32_t seed)
__host__ __device__ float convert_to_float(T data, int scale_exp)
Definition mxfp_utils.hpp:73
__host__ __device__ T convert_to_type_sr(float value, uint32_t seed)
Definition mxfp_utils.hpp:261
__host__ __device__ bool is_zero(e8m0_bexp_t const scale, T const data)
__host__ __device__ T convert_to_type(float value)
Definition mxfp_utils.hpp:102
__host__ __device__ bool is_inf(e8m0_bexp_t const scale, T const data)
__host__ __device__ double get_mantissa_value(T x)
Definition mxfp_utils.hpp:51
__host__ __device__ bool is_nan(e8m0_bexp_t const scale, T const data)
bool getDataHasInf()
Definition mxfp_utils.hpp:20
__host__ __device__ float to_float(e8m0_bexp_t const scale, T const data)
__host__ __device__ constexpr Y bit_cast(const X &x)
Definition type.hpp:306
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
unsigned int uint32_t
Definition stdint.h:126
signed int int32_t
Definition stdint.h:123
Definition numeric_limits.hpp:309
__host__ static __device__ constexpr T Max()
Definition numeric_limits.hpp:311
Definition numeric_utils.hpp:10
Unsigned representation of a conventional biased Float32 exponent.
Definition utility/e8m0.hpp:26
Definition mxfp_utils.hpp:14
float value_float
Definition mxfp_utils.hpp:15
uint32_t value_bitwise
Definition mxfp_utils.hpp:16