HDK
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
float8.h
Go to the documentation of this file.
1 // Copyright (c) Microsoft Corporation. All rights reserved.
2 // Licensed under the MIT License.
3 
4 #pragma once
5 
6 #if !defined(DISABLE_FLOAT8_TYPES)
7 
8 #include "endian.h"
9 #if defined(CUDA_VERSION) && CUDA_VERSION >= 11080
10 #include "cuda_fp8.h"
11 #endif
12 
13 #if !defined(__CUDACC__) && !defined(__HIPCC__)
14 #include "core/common/narrow.h"
15 #endif
16 
17 #include "core/common/common.h"
18 
19 namespace onnxruntime {
20 
// ORT_HOST_DEVICE marks functions that must be callable from both host and
// device code. It expands to `__host__ __device__` under nvcc (__CUDACC__) or
// hipcc (__HIPCC__) and to nothing for a plain host compiler.
#if !defined(__CUDACC__) && !defined(__HIPCC__)
#define ORT_HOST_DEVICE
#else
#define ORT_HOST_DEVICE __host__ __device__
#endif
26 
27 // Float8E4M3FN
28 struct Float8E4M3FN {
29  uint8_t val{0};
30 #if defined(__HIP__)
31  ORT_HOST_DEVICE Float8E4M3FN() = default;
32 #else
33  Float8E4M3FN() = default;
34 #endif
35  struct FromBitsT {};
36  static constexpr ORT_HOST_DEVICE FromBitsT FromBits() { return FromBitsT(); }
37  constexpr ORT_HOST_DEVICE Float8E4M3FN(unsigned char bits, FromBitsT) : val(bits) {}
38 
39  inline explicit ORT_HOST_DEVICE Float8E4M3FN(float v, bool saturate = true) {
40 #if defined(CUDA_VERSION) && CUDA_VERSION >= 11080
41  val = __nv_cvt_float_to_fp8(v, saturate ? __NV_SATFINITE : __NV_NOSAT, __NV_E4M3);
42 #else
43  uint32_t b;
44  std::memcpy(&b, &v, sizeof(b));
45 
46  val = static_cast<uint8_t>((b & 0x80000000) >> 24); // sign
47  if ((b & 0x7fffffff) == 0x7f800000) { // infinity
48  if (saturate) {
49  val |= 126;
50  } else {
51  val |= 0x7f;
52  }
53  } else if ((b & 0x7F800000) == 0x7F800000) { // NaN
54  val |= 0x7f;
55  } else {
56  uint8_t e = static_cast<uint8_t>((b & 0x7F800000) >> 23); // exponent
57  uint32_t m = static_cast<uint32_t>(b & 0x007FFFFF); // mantissa
58  if (e != 0) {
59  if (e < 117) {
60  } else if (e < 121) {
61  // denormalized number
62  auto d = 120 - e;
63  if (d < 3) {
64  val |= 1 << (2 - d);
65  val |= m >> (21 + d);
66  } else if (m > 0) {
67  val |= 1;
68  }
69  auto mask = 1 << (20 + d);
70  if ((m & mask) && ((val & 1) || ((m & (mask - 1)) > 0) || ((m & mask) && (m & (mask << 1)) && ((m & (mask - 1)) == 0)))) {
71  // rounding
72  val += 1;
73  }
74  } else if (e < 136) {
75  // normalized number
76  auto ex = e - 120;
77  if (ex == 0) {
78  val |= 0x4;
79  val |= m >> 21;
80  } else {
81  val |= ex << 3;
82  val |= m >> 20;
83  if ((val & 0x7F) == 0x7F) {
84  val &= 0xFE;
85  }
86  }
87  if ((m & 0x80000) && ((m & 0x100000) || (m & 0x7FFFF))) {
88  if ((val & 0x7F) < 0x7E) {
89  // rounding
90  val += 1;
91  } else if (!saturate) {
92  val |= 0x7F;
93  }
94  }
95  } else if (saturate) {
96  val |= 126; // 0b01111110
97  } else {
98  val |= 0x7F;
99  }
100  }
101  }
102 #endif
103  }
104 
105  inline ORT_HOST_DEVICE bool IsNaN() const {
106  return (val & 0b01111111) == 0b01111111;
107  }
108 
109  inline ORT_HOST_DEVICE float ToFloat() const {
110 #if defined(CUDA_VERSION) && CUDA_VERSION >= 11080
111  return __half2float(__nv_cvt_fp8_to_halfraw(val, __NV_E4M3));
112 #else
113  uint32_t res;
114  if (val == 255) {
115  res = 0xffc00000;
116  } else if (val == 127) {
117  res = 0x7fc00000;
118  } else {
119  uint32_t expo = (val & 0x78) >> 3;
120  uint32_t mant = val & 0x07;
121  uint32_t sign = val & 0x80;
122  res = sign << 24;
123  if (expo == 0) {
124  if (mant > 0) {
125  expo = 0x7F - 7;
126  if ((mant & 0x4) == 0) {
127  mant &= 0x3;
128  mant <<= 1;
129  expo -= 1;
130  }
131  if ((mant & 0x4) == 0) {
132  mant &= 0x3;
133  mant <<= 1;
134  expo -= 1;
135  }
136  res |= (mant & 0x3) << 21;
137  res |= expo << 23;
138  }
139  } else {
140  res |= mant << 20;
141  expo -= 0x7;
142  expo += 0x7F;
143  res |= expo << 23;
144  }
145  }
146  float float_res;
147  std::memcpy(&float_res, &res, sizeof(float));
148  return float_res;
149 #endif
150  }
151 
152  inline ORT_HOST_DEVICE operator float() const { return ToFloat(); }
153 
154 #if defined(CUDA_VERSION) && CUDA_VERSION >= 11080
155  explicit ORT_HOST_DEVICE Float8E4M3FN(const __nv_fp8_e4m3& value) { val = *reinterpret_cast<const unsigned char*>(&value); }
156  explicit ORT_HOST_DEVICE operator __nv_fp8_e4m3() const { return *reinterpret_cast<const __nv_fp8_e4m3*>(&val); }
157 #endif
158 };
159 
160 inline ORT_HOST_DEVICE bool operator==(const Float8E4M3FN& left, const Float8E4M3FN& right) { return left.val == right.val; }
161 inline ORT_HOST_DEVICE bool operator!=(const Float8E4M3FN& left, const Float8E4M3FN& right) { return left.val != right.val; }
162 inline ORT_HOST_DEVICE bool operator<(const Float8E4M3FN& left, const Float8E4M3FN& right) { return left.val < right.val; }
163 
164 // User defined suffixes to make it easier to declare
165 // initializers with MLFloat8E4M3FN and Float8E4M3FN from unsigned char
166 #if !defined(__CUDACC__) && !defined(__HIPCC__)
167 
168 inline Float8E4M3FN operator"" _f8e4m3fn(unsigned long long int v) {
169  return Float8E4M3FN(narrow<uint8_t>(v), Float8E4M3FN::FromBits());
170 }
171 
172 inline Float8E4M3FN operator"" _f8e4m3fnp8(long double v) {
173  return Float8E4M3FN(static_cast<float>(v), true);
174 }
175 
176 #endif
177 
178 inline void Float8E4M3FNToFloat(const Float8E4M3FN* blf, float* flt, size_t size) {
179  auto src = blf;
180  auto d = flt;
181  for (; size != 0; ++src, ++d, --size) {
182  *d = src->ToFloat();
183  }
184 }
185 
186 inline void FloatToFloat8E4M3FN(const float* flt, Float8E4M3FN* blf, size_t size, bool saturate) {
187  auto src = flt;
188  auto d = blf;
189  for (; size != 0; ++src, ++d, --size) {
190  new (d) Float8E4M3FN(*src, saturate);
191  }
192 }
193 
194 // Float8E4M3FNUZ
196  uint8_t val{0};
197 #if defined(__HIP__)
198  ORT_HOST_DEVICE Float8E4M3FNUZ() = default;
199 #else
200  Float8E4M3FNUZ() = default;
201 #endif
202 
203  struct FromBitsT {};
204  static constexpr ORT_HOST_DEVICE FromBitsT FromBits() { return FromBitsT(); }
205  constexpr ORT_HOST_DEVICE Float8E4M3FNUZ(unsigned char bits, FromBitsT) : val(bits) {}
206 
207  inline explicit ORT_HOST_DEVICE Float8E4M3FNUZ(float v, bool saturate = true) {
208  // This type does not exist on CUDA.
209  uint32_t b;
210  std::memcpy(&b, &v, sizeof(b));
211 
212  val = static_cast<uint8_t>((b & 0x80000000) >> 24); // sign
213  if ((b & 0x7fffffff) == 0x7f800000) { // infinity
214  if (saturate) {
215  // the highest available value
216  val |= 0x7F;
217  } else {
218  // NaN
219  val = 0x80;
220  }
221  } else if ((b & 0x7F800000) == 0x7F800000) { // NaN
222  val = 0x80;
223  } else {
224  uint8_t e = static_cast<uint8_t>((b & 0x7F800000) >> 23); // exponent
225  uint32_t m = static_cast<uint32_t>(b & 0x007FFFFF); // mantissa
226 
227  if (e < 116) {
228  // all near-zero numbers round to positive zero:
229  val = 0;
230  } else if (e < 120) {
231  // denormalized number
232  auto d = 119 - e;
233  if (d < 3) {
234  val |= 1 << (2 - d);
235  val |= m >> (21 + d);
236  } else if (m > 0) {
237  val |= 1;
238  } else {
239  // round to positive zero:
240  val = 0;
241  }
242  auto mask = 1 << (20 + d);
243  if ((m & mask) && ((val & 1) || ((m & (mask - 1)) > 0) || ((m & mask) && (m & (mask << 1)) && ((m & (mask - 1)) == 0)))) {
244  // rounding
245  val += 1;
246  }
247  } else if (e < 135) {
248  // normalized number
249  auto ex = e - 119;
250  if (ex == 0) {
251  val |= 0x4;
252  val |= m >> 21;
253  } else {
254  val |= ex << 3;
255  val |= m >> 20;
256  }
257  if ((m & 0x80000) && ((m & 0x100000) || (m & 0x7FFFF))) {
258  if ((val & 0x7F) < 0x7F) {
259  // rounding
260  val += 1;
261  } else if (!saturate) {
262  val = 0x80;
263  }
264  }
265  } else if (saturate) {
266  val |= 0x7F;
267  } else {
268  val = 0x80;
269  }
270  }
271  }
272 
273  inline ORT_HOST_DEVICE bool IsNaN() const {
274  return val == 0b10000000;
275  }
276 
277  inline ORT_HOST_DEVICE float ToFloat() const {
278  // This type does not exist on CUDA.
279  uint32_t res;
280  if (val == 0x80) {
281  res = 0xffc00000;
282  } else {
283  uint32_t expo = (val & 0x78) >> 3;
284  uint32_t mant = val & 0x07;
285  uint32_t sign = val & 0x80;
286  res = sign << 24;
287  if (expo == 0) {
288  if (mant > 0) {
289  expo = 0x7F - 8;
290  if ((mant & 0x4) == 0) {
291  mant &= 0x3;
292  mant <<= 1;
293  expo -= 1;
294  }
295  if ((mant & 0x4) == 0) {
296  mant &= 0x3;
297  mant <<= 1;
298  expo -= 1;
299  }
300  res |= (mant & 0x3) << 21;
301  res |= expo << 23;
302  }
303  } else {
304  res |= mant << 20;
305  expo -= 8;
306  expo += 0x7F;
307  res |= expo << 23;
308  }
309  }
310  float float_res;
311  std::memcpy(&float_res, &res, sizeof(float));
312  return float_res;
313  }
314 
315  inline ORT_HOST_DEVICE operator float() const { return ToFloat(); }
316 };
317 
318 inline ORT_HOST_DEVICE bool operator==(const Float8E4M3FNUZ& left, const Float8E4M3FNUZ& right) { return left.val == right.val; }
319 inline ORT_HOST_DEVICE bool operator!=(const Float8E4M3FNUZ& left, const Float8E4M3FNUZ& right) { return left.val != right.val; }
320 inline ORT_HOST_DEVICE bool operator<(const Float8E4M3FNUZ& left, const Float8E4M3FNUZ& right) { return left.val < right.val; }
321 
322 // User defined suffixes to make it easier to declare
323 // initializers with MLFloat8E4M3FN and Float8E4M3FN from unsigned char
324 #if !defined(__CUDACC__) && !defined(__HIPCC__)
325 
326 inline Float8E4M3FNUZ operator"" _f8e4m3p8fnuz(unsigned long long int v) {
327  return Float8E4M3FNUZ(narrow<uint8_t>(v), Float8E4M3FNUZ::FromBits());
328 }
329 
330 inline Float8E4M3FNUZ operator"" _f8e4m3fnuzp8(long double v) {
331  return Float8E4M3FNUZ(static_cast<float>(v), true);
332 }
333 
334 #endif
335 
336 inline void Float8E4M3FNUZToFloat(const Float8E4M3FNUZ* blf, float* flt, size_t size) {
337  auto src = blf;
338  auto d = flt;
339  for (; size != 0; ++src, ++d, --size) {
340  *d = src->ToFloat();
341  }
342 }
343 
344 inline void FloatToFloat8E4M3FNUZ(const float* flt, Float8E4M3FNUZ* blf, size_t size, bool saturate) {
345  auto src = flt;
346  auto d = blf;
347  for (; size != 0; ++src, ++d, --size) {
348  new (d) Float8E4M3FNUZ(*src, saturate);
349  }
350 }
351 
352 // Float8E5M2
353 struct Float8E5M2 {
354  uint8_t val{0};
355 #if defined(__HIP__)
356  ORT_HOST_DEVICE Float8E5M2() = default;
357 #else
358  Float8E5M2() = default;
359 #endif
360 
361  struct FromBitsT {};
362  static constexpr ORT_HOST_DEVICE FromBitsT FromBits() { return FromBitsT(); }
363  constexpr ORT_HOST_DEVICE Float8E5M2(unsigned char bits, FromBitsT) : val(bits) {}
364 
365  inline explicit ORT_HOST_DEVICE Float8E5M2(float v, bool saturate = true) {
366 #if defined(CUDA_VERSION) && CUDA_VERSION >= 11080
367  val = __nv_cvt_float_to_fp8(v, saturate ? __NV_SATFINITE : __NV_NOSAT, __NV_E5M2);
368 #else
369  uint32_t b;
370  std::memcpy(&b, &v, sizeof(b));
371 
372  val = (b & 0x80000000) >> 24; // sign
373  if ((b & 0x7FFFFFFF) == 0x7F800000) { // inf
374  if (saturate) {
375  // the highest available value
376  val |= 0x7B;
377  } else {
378  // the infinity
379  val |= 0x7C;
380  }
381  } else if ((b & 0x7F800000) == 0x7F800000) { // NaN
382  val |= 0x7f;
383  } else {
384  uint32_t e = (b & 0x7F800000) >> 23; // exponent
385  uint32_t m = b & 0x007FFFFF; // mantissa
386 
387  if (e != 0) {
388  if (e < 110) {
389  } else if (e < 113) {
390  // denormalized number
391  auto d = 112 - e;
392  if (d < 2) {
393  val |= 1 << (1 - d);
394  val |= m >> (22 + d);
395  } else if (m > 0) {
396  val |= 1;
397  }
398  auto mask = 1 << (21 + d);
399  if ((m & mask) && ((val & 1) || ((m & (mask - 1)) > 0) || ((m & mask) && (m & (mask << 1)) && ((m & (mask - 1)) == 0)))) {
400  // rounding
401  val += 1;
402  }
403  } else if (e < 143) { // 127 + 15 + 1
404  auto ex = e - 112; // 127 - 15
405  val |= ex << 2;
406  val |= m >> 21;
407  if ((m & 0x100000) && ((m & 0xFFFFF) || (m & 0x200000))) {
408  if ((val & 0x7F) < 0x7B) {
409  // rounding
410  val += 1;
411  } else if (saturate) {
412  val |= 0x7B;
413  } else {
414  val |= 0x7C;
415  }
416  }
417  } else if (saturate) {
418  val |= 0x7B;
419  } else {
420  val |= 0x7C;
421  }
422  }
423  }
424 #endif
425  }
426 
427  inline ORT_HOST_DEVICE bool IsNaN() const {
428  // 7D, 7E, 7F are positive NaNs; FD, FE, FF are negative NaNs
429  return (val & 0b01111111) > 0b01111100;
430  }
431 
432  inline ORT_HOST_DEVICE bool IsInfinity() const {
433  // 7C and FC are infinity
434  return (val & 0b01111111) == 0b01111100;
435  }
436 
437  inline ORT_HOST_DEVICE float ToFloat() const {
438 #if defined(CUDA_VERSION) && CUDA_VERSION >= 11080
439  return __half2float(__nv_cvt_fp8_to_halfraw(val, __NV_E5M2));
440 #else
441  uint32_t res;
442  if (val >= 253) {
443  res = 0xffc00000;
444  } else if (val >= 125 && val <= 127) {
445  res = 0x7fc00000;
446  } else if (val == 252) {
447  res = 0xff800000;
448  } else if (val == 124) {
449  res = 0x7f800000;
450  } else {
451  uint32_t expo = (val & 0x7C) >> 2;
452  uint32_t mant = val & 0x03;
453  uint32_t sign = val & 0x80;
454  res = sign << 24;
455  if (expo == 0) {
456  if (mant > 0) {
457  expo = 0x7F - 15;
458  if ((mant & 0x2) == 0) {
459  mant &= 0x1;
460  mant <<= 1;
461  expo -= 1;
462  }
463  res |= (mant & 0x1) << 22;
464  res |= expo << 23;
465  }
466  } else {
467  res |= mant << 21;
468  expo -= 15;
469  expo += 0x7F;
470  res |= expo << 23;
471  }
472  }
473 
474  float float_res;
475  std::memcpy(&float_res, &res, sizeof(float));
476  return float_res;
477 #endif
478  }
479 
480  inline ORT_HOST_DEVICE operator float() const { return ToFloat(); }
481 
482 #if defined(CUDA_VERSION) && CUDA_VERSION >= 11080
483  ORT_HOST_DEVICE Float8E5M2(const __nv_fp8_e5m2& value) { val = *reinterpret_cast<const unsigned char*>(&value); }
484  explicit ORT_HOST_DEVICE operator __nv_fp8_e5m2() const { return *reinterpret_cast<const __nv_fp8_e5m2*>(&val); }
485 #endif
486 };
487 
488 inline ORT_HOST_DEVICE bool operator==(const Float8E5M2& left, const Float8E5M2& right) { return left.val == right.val; }
489 inline ORT_HOST_DEVICE bool operator!=(const Float8E5M2& left, const Float8E5M2& right) { return left.val != right.val; }
490 inline ORT_HOST_DEVICE bool operator<(const Float8E5M2& left, const Float8E5M2& right) { return left.val < right.val; }
491 
492 // User defined suffixes to make it easier to declare
493 // initializers with MLFloat8E5M2 and Float8E5M2 from unsigned char
494 #if !defined(__CUDACC__) && !defined(__HIPCC__)
495 
496 inline Float8E5M2 operator"" _f8e5m2fn(unsigned long long int v) {
497  return Float8E5M2(narrow<uint8_t>(v), Float8E5M2::FromBits());
498 }
499 
500 inline Float8E5M2 operator"" _f8e5m2fnp8(long double v) {
501  return Float8E5M2(static_cast<float>(v), true);
502 }
503 
504 #endif
505 
506 inline void Float8E5M2ToFloat(const Float8E5M2* blf, float* flt, size_t size) {
507  auto src = blf;
508  auto d = flt;
509  for (; size != 0; ++src, ++d, --size) {
510  *d = src->ToFloat();
511  }
512 }
513 
514 inline void FloatToFloat8E5M2(const float* flt, Float8E5M2* blf, size_t size, bool saturate) {
515  auto src = flt;
516  auto d = blf;
517  for (; size != 0; ++src, ++d, --size) {
518  new (d) Float8E5M2(*src, saturate);
519  }
520 }
521 
522 // Float8E5M2FNUZ
524  uint8_t val{0};
525 #if defined(__HIP__)
526  ORT_HOST_DEVICE Float8E5M2FNUZ() = default;
527 #else
528  Float8E5M2FNUZ() = default;
529 #endif
530 
531  struct FromBitsT {};
532  static constexpr ORT_HOST_DEVICE FromBitsT FromBits() { return FromBitsT(); }
533  constexpr ORT_HOST_DEVICE Float8E5M2FNUZ(unsigned char bits, FromBitsT) : val(bits) {}
534 
535  inline explicit ORT_HOST_DEVICE Float8E5M2FNUZ(float v, bool saturate = true) {
536  // This type does not exist on CUDA.
537  uint32_t b;
538  std::memcpy(&b, &v, sizeof(b));
539 
540  val = (b & 0x80000000) >> 24; // sign
541  if ((b & 0x7FFFFFFF) == 0x7F800000) { // inf
542  if (saturate) {
543  val |= 0x7F;
544  } else {
545  val = 0x80;
546  }
547  } else if ((b & 0x7F800000) == 0x7F800000) { // NaN
548  val = 0x80;
549  } else {
550  uint32_t e = (b & 0x7F800000) >> 23; // exponent
551  uint32_t m = b & 0x007FFFFF; // mantissa
552 
553  if (e < 109) {
554  // all near-zero numbers round to positive zero:
555  val = 0;
556  } else if (e < 112) {
557  // denormalized number
558  auto d = 111 - e;
559  if (d < 2) {
560  val |= 1 << (1 - d);
561  val |= m >> (22 + d);
562  } else if (m > 0) {
563  val |= 1;
564  } else {
565  // round to positive zero:
566  val = 0;
567  }
568  auto mask = 1 << (21 + d);
569  if ((m & mask) && ((val & 1) || ((m & (mask - 1)) > 0) || ((m & mask) && (m & (mask << 1)) && ((m & (mask - 1)) == 0)))) {
570  // rounding
571  val += 1;
572  }
573  } else if (e < 143) {
574  // normalized number
575  auto ex = e - 111;
576  val |= ex << 2;
577  val |= m >> 21;
578  if ((m & 0x100000) && ((m & 0xFFFFF) || (m & 0x200000))) {
579  if ((val & 0x7F) < 0x7F) {
580  // rounding
581  val += 1;
582  } else if (!saturate) {
583  val = 0x80;
584  }
585  }
586  } else if ((e == 255) && (m == 0)) {
587  val = 0x80;
588  } else if (saturate) {
589  val |= 0x7F;
590  } else {
591  val = 0x80;
592  }
593  }
594  }
595 
596  inline ORT_HOST_DEVICE bool IsNaN() const {
597  return val == 0b10000000;
598  }
599 
600  inline ORT_HOST_DEVICE float ToFloat() const {
601  // This type does not exist on CUDA.
602  uint32_t res;
603  if (val == 0x80) {
604  res = 0xffc00000;
605  } else {
606  uint32_t expo = (val & 0x7C) >> 2;
607  uint32_t mant = val & 0x03;
608  uint32_t sign = val & 0x80;
609  res = sign << 24;
610  if (expo == 0) {
611  if (mant > 0) {
612  expo = 0x7F - 16;
613  if ((mant & 0x2) == 0) {
614  mant &= 0x1;
615  mant <<= 1;
616  expo -= 1;
617  }
618  res |= (mant & 0x1) << 22;
619  res |= expo << 23;
620  }
621  } else {
622  res |= mant << 21;
623  expo -= 16;
624  expo += 0x7F;
625  res |= expo << 23;
626  }
627  }
628 
629  float float_res;
630  std::memcpy(&float_res, &res, sizeof(float));
631  return float_res;
632  }
633 
634  inline ORT_HOST_DEVICE operator float() const { return ToFloat(); }
635 };
636 
637 inline ORT_HOST_DEVICE bool operator==(const Float8E5M2FNUZ& left, const Float8E5M2FNUZ& right) { return left.val == right.val; }
638 inline ORT_HOST_DEVICE bool operator!=(const Float8E5M2FNUZ& left, const Float8E5M2FNUZ& right) { return left.val != right.val; }
639 inline ORT_HOST_DEVICE bool operator<(const Float8E5M2FNUZ& left, const Float8E5M2FNUZ& right) { return left.val < right.val; }
640 
641 // User defined suffixes to make it easier to declare
642 // initializers with MLFloat8E5M2 and Float8E5M2 from unsigned char
643 #if !defined(__CUDACC__) && !defined(__HIPCC__)
644 
645 inline Float8E5M2FNUZ operator"" _f8e5m2fnuz(unsigned long long int v) {
646  return Float8E5M2FNUZ(narrow<uint8_t>(v), Float8E5M2FNUZ::FromBits());
647 }
648 
649 inline Float8E5M2FNUZ operator"" _f8e5m2fnuzp8(long double v) {
650  return Float8E5M2FNUZ(static_cast<float>(v), true);
651 }
652 
653 #endif
654 
655 inline void Float8E5M2FNUZToFloat(const Float8E5M2FNUZ* blf, float* flt, size_t size) {
656  auto src = blf;
657  auto d = flt;
658  for (; size != 0; ++src, ++d, --size) {
659  *d = src->ToFloat();
660  }
661 }
662 
663 inline void FloatToFloat8E5M2FNUZ(const float* flt, Float8E5M2FNUZ* blf, size_t size, bool saturate) {
664  auto src = flt;
665  auto d = blf;
666  for (; size != 0; ++src, ++d, --size) {
667  new (d) Float8E5M2FNUZ(*src, saturate);
668  }
669 }
670 
671 } // namespace onnxruntime
672 
673 namespace std {
674 
675 template <>
676 class numeric_limits<onnxruntime::Float8E4M3FN> {
677  public:
678  static constexpr onnxruntime::Float8E4M3FN lowest() {
680  }
681 
682  static constexpr onnxruntime::Float8E4M3FN max() {
684  }
685 
686  static constexpr onnxruntime::Float8E4M3FN min() {
687  return onnxruntime::Float8E4M3FN(0x08, onnxruntime::Float8E4M3FN::FromBits()); // 2^-6 = 0.015625
688  }
689 
691  return onnxruntime::Float8E4M3FN(0x01, onnxruntime::Float8E4M3FN::FromBits()); // 2^-9 = 0.001953125
692  }
693 
694  static constexpr onnxruntime::Float8E4M3FN epsilon() {
696  }
697 
700  }
701 
702  static constexpr onnxruntime::Float8E4M3FN infinity() {
703  // no infinity, returns quiet NaN instead
704  return quiet_NaN();
705  }
706 
709  }
710 
711  static constexpr bool is_specialized = true;
712  static constexpr bool is_signed = true;
713  static constexpr bool is_integer = false;
714  static constexpr bool is_exact = false;
715  static constexpr bool has_infinity = false;
716  static constexpr bool has_quiet_NaN = true;
717  static constexpr bool has_signaling_NaN = false;
718  static constexpr auto has_denorm = true;
719  static constexpr auto has_denorm_loss = true;
720  static constexpr auto round_style = round_to_nearest;
721  static constexpr bool is_iec559 = false;
722  static constexpr bool is_bounded = true;
723  static constexpr bool is_modulo = false;
724  static constexpr int digits = 4;
725  static constexpr int digits10 = 0;
726  static constexpr int max_digits10 = 3;
727  static constexpr int radix = 2;
728  static constexpr int min_exponent = -5;
729  static constexpr int min_exponent10 = -1;
730  static constexpr int max_exponent = 8;
731  static constexpr int max_exponent10 = 2;
732  static constexpr auto traps = false;
733  static constexpr auto tinyness_before = false;
734 };
735 
736 template <>
737 class numeric_limits<onnxruntime::Float8E5M2> {
738  public:
739  static constexpr onnxruntime::Float8E5M2 lowest() {
741  }
742 
743  static constexpr onnxruntime::Float8E5M2 max() {
745  }
746 
747  static constexpr onnxruntime::Float8E5M2 min() {
748  return onnxruntime::Float8E5M2(0x4, onnxruntime::Float8E5M2::FromBits()); // 2^-14 = 0.00006103515
749  }
750 
751  static constexpr onnxruntime::Float8E5M2 denorm_min() {
752  return onnxruntime::Float8E5M2(0x01, onnxruntime::Float8E5M2::FromBits()); // 2^-16 = 0.00001525878
753  }
754 
755  static constexpr onnxruntime::Float8E5M2 epsilon() {
757  }
758 
761  }
762 
763  static constexpr onnxruntime::Float8E5M2 infinity() {
765  }
766 
767  static constexpr onnxruntime::Float8E5M2 quiet_NaN() {
769  }
770 
771  static constexpr bool is_specialized = true;
772  static constexpr bool is_signed = true;
773  static constexpr bool is_integer = false;
774  static constexpr bool is_exact = false;
775  static constexpr bool has_infinity = true;
776  static constexpr bool has_quiet_NaN = true;
777  static constexpr bool has_signaling_NaN = false;
778  static constexpr auto has_denorm = true;
779  static constexpr auto has_denorm_loss = true;
780  static constexpr auto round_style = round_to_nearest;
781  static constexpr bool is_iec559 = false;
782  static constexpr bool is_bounded = true;
783  static constexpr bool is_modulo = false;
784  static constexpr int digits = 3;
785  static constexpr int digits10 = 0;
786  static constexpr int max_digits10 = 2;
787  static constexpr int radix = 2;
788  static constexpr int min_exponent = -13;
789  static constexpr int min_exponent10 = -4;
790  static constexpr int max_exponent = 16;
791  static constexpr int max_exponent10 = 4;
792  static constexpr auto traps = false;
793  static constexpr auto tinyness_before = false;
794 };
795 
796 template <>
797 class numeric_limits<onnxruntime::Float8E4M3FNUZ> {
798  public:
799  static constexpr onnxruntime::Float8E4M3FNUZ lowest() {
801  }
802 
803  static constexpr onnxruntime::Float8E4M3FNUZ max() {
805  }
806 
807  static constexpr onnxruntime::Float8E4M3FNUZ min() {
808  return onnxruntime::Float8E4M3FNUZ(0x08, onnxruntime::Float8E4M3FNUZ::FromBits()); // 2^-7 = 0.0078125
809  }
810 
812  return onnxruntime::Float8E4M3FNUZ(0x01, onnxruntime::Float8E4M3FNUZ::FromBits()); // 2^-10 = 0.0009765625
813  }
814 
817  }
818 
821  }
822 
824  // no infinity, returns quiet NaN instead
825  return quiet_NaN();
826  }
827 
830  }
831 
832  static constexpr bool is_specialized = true;
833  static constexpr bool is_signed = true;
834  static constexpr bool is_integer = false;
835  static constexpr bool is_exact = false;
836  static constexpr bool has_infinity = false;
837  static constexpr bool has_quiet_NaN = true;
838  static constexpr bool has_signaling_NaN = false;
839  static constexpr auto has_denorm = true;
840  static constexpr auto has_denorm_loss = true;
841  static constexpr auto round_style = round_to_nearest;
842  static constexpr bool is_iec559 = false;
843  static constexpr bool is_bounded = true;
844  static constexpr bool is_modulo = false;
845  static constexpr int digits = 4;
846  static constexpr int digits10 = 0;
847  static constexpr int max_digits10 = 3;
848  static constexpr int radix = 2;
849  static constexpr int min_exponent = -6;
850  static constexpr int min_exponent10 = -1;
851  static constexpr int max_exponent = 8;
852  static constexpr int max_exponent10 = 2;
853  static constexpr auto traps = false;
854  static constexpr auto tinyness_before = false;
855 };
856 
857 template <>
858 class numeric_limits<onnxruntime::Float8E5M2FNUZ> {
859  public:
860  static constexpr onnxruntime::Float8E5M2FNUZ lowest() {
862  }
863 
864  static constexpr onnxruntime::Float8E5M2FNUZ max() {
866  }
867 
868  static constexpr onnxruntime::Float8E5M2FNUZ min() {
869  return onnxruntime::Float8E5M2FNUZ(0x04, onnxruntime::Float8E5M2FNUZ::FromBits()); // 2^-15 = 0.00003051757
870  }
871 
873  return onnxruntime::Float8E5M2FNUZ(0x01, onnxruntime::Float8E5M2FNUZ::FromBits()); // 2^-17 = 0.00000762939
874  }
875 
878  }
879 
882  }
883 
885  // no infinity, returns quiet NaN instead
886  return quiet_NaN();
887  }
888 
891  }
892 
893  static constexpr bool is_specialized = true;
894  static constexpr bool is_signed = true;
895  static constexpr bool is_integer = false;
896  static constexpr bool is_exact = false;
897  static constexpr bool has_infinity = false;
898  static constexpr bool has_quiet_NaN = true;
899  static constexpr bool has_signaling_NaN = false;
900  static constexpr auto has_denorm = true;
901  static constexpr auto has_denorm_loss = true;
902  static constexpr auto round_style = round_to_nearest;
903  static constexpr bool is_iec559 = false;
904  static constexpr bool is_bounded = true;
905  static constexpr bool is_modulo = false;
906  static constexpr int digits = 3;
907  static constexpr int digits10 = 0;
908  static constexpr int max_digits10 = 2;
909  static constexpr int radix = 2;
910  static constexpr int min_exponent = -14;
911  static constexpr int min_exponent10 = -4;
912  static constexpr int max_exponent = 16;
913  static constexpr int max_exponent10 = 4;
914  static constexpr auto traps = false;
915  static constexpr auto tinyness_before = false;
916 };
917 
918 } // namespace std
919 
920 #endif // DISABLE_FLOAT8_TYPES
static constexpr onnxruntime::Float8E4M3FN infinity()
Definition: float8.h:702
static constexpr onnxruntime::Float8E5M2 denorm_min()
Definition: float8.h:751
bool_constant< is_integral< T >::value &&!std::is_same< T, bool >::value &&!std::is_same< T, char >::value &&!std::is_same< T, wchar_t >::value > is_integer
Definition: format.h:824
void FloatToFloat8E4M3FNUZ(const float *flt, Float8E4M3FNUZ *blf, size_t size, bool saturate)
Definition: float8.h:344
static constexpr ORT_HOST_DEVICE FromBitsT FromBits()
Definition: float8.h:36
static constexpr onnxruntime::Float8E5M2 lowest()
Definition: float8.h:739
static constexpr onnxruntime::Float8E5M2FNUZ round_error()
Definition: float8.h:880
GLint left
Definition: glcorearb.h:2005
const GLdouble * v
Definition: glcorearb.h:837
static constexpr onnxruntime::Float8E4M3FNUZ lowest()
Definition: float8.h:799
static constexpr onnxruntime::Float8E4M3FN denorm_min()
Definition: float8.h:690
GLsizei const GLfloat * value
Definition: glcorearb.h:824
ORT_HOST_DEVICE bool operator<(const Float8E4M3FN &left, const Float8E4M3FN &right)
Definition: float8.h:162
static constexpr onnxruntime::Float8E5M2FNUZ infinity()
Definition: float8.h:884
GLdouble right
Definition: glad.h:2817
ORT_HOST_DEVICE float ToFloat() const
Definition: float8.h:600
static constexpr onnxruntime::Float8E4M3FN epsilon()
Definition: float8.h:694
ORT_HOST_DEVICE float ToFloat() const
Definition: float8.h:437
void Float8E4M3FNUZToFloat(const Float8E4M3FNUZ *blf, float *flt, size_t size)
Definition: float8.h:336
ORT_HOST_DEVICE Float8E4M3FNUZ(float v, bool saturate=true)
Definition: float8.h:207
static constexpr ORT_HOST_DEVICE FromBitsT FromBits()
Definition: float8.h:532
static constexpr onnxruntime::Float8E4M3FN quiet_NaN()
Definition: float8.h:707
static constexpr onnxruntime::Float8E4M3FNUZ max()
Definition: float8.h:803
static constexpr onnxruntime::Float8E4M3FNUZ denorm_min()
Definition: float8.h:811
static constexpr onnxruntime::Float8E5M2FNUZ min()
Definition: float8.h:868
static constexpr ORT_HOST_DEVICE FromBitsT FromBits()
Definition: float8.h:362
static constexpr onnxruntime::Float8E4M3FNUZ infinity()
Definition: float8.h:823
constexpr ORT_HOST_DEVICE Float8E4M3FN(unsigned char bits, FromBitsT)
Definition: float8.h:37
ORT_HOST_DEVICE bool operator==(const Float8E4M3FN &left, const Float8E4M3FN &right)
Definition: float8.h:160
ORT_HOST_DEVICE bool IsNaN() const
Definition: float8.h:427
GLdouble GLdouble x2
Definition: glad.h:2349
ORT_HOST_DEVICE bool IsNaN() const
Definition: float8.h:273
ORT_HOST_DEVICE bool operator!=(const Float8E4M3FN &left, const Float8E4M3FN &right)
Definition: float8.h:161
void FloatToFloat8E4M3FN(const float *flt, Float8E4M3FN *blf, size_t size, bool saturate)
Definition: float8.h:186
static constexpr onnxruntime::Float8E5M2 quiet_NaN()
Definition: float8.h:767
ORT_HOST_DEVICE Float8E5M2FNUZ(float v, bool saturate=true)
Definition: float8.h:535
static constexpr onnxruntime::Float8E5M2 epsilon()
Definition: float8.h:755
static constexpr onnxruntime::Float8E4M3FNUZ min()
Definition: float8.h:807
static constexpr onnxruntime::Float8E5M2 max()
Definition: float8.h:743
GLint GLuint mask
Definition: glcorearb.h:124
static constexpr onnxruntime::Float8E4M3FNUZ round_error()
Definition: float8.h:819
constexpr ORT_HOST_DEVICE Float8E4M3FNUZ(unsigned char bits, FromBitsT)
Definition: float8.h:205
ORT_HOST_DEVICE Float8E5M2(float v, bool saturate=true)
Definition: float8.h:365
ImageBuf OIIO_API saturate(const ImageBuf &src, float scale=0.0f, int firstchannel=0, ROI roi={}, int nthreads=0)
IMATH_HOSTDEVICE constexpr int sign(T a) IMATH_NOEXCEPT
Definition: ImathFun.h:33
#define ORT_HOST_DEVICE
Definition: float8.h:24
GLboolean GLboolean GLboolean b
Definition: glcorearb.h:1222
static constexpr onnxruntime::Float8E5M2FNUZ denorm_min()
Definition: float8.h:872
ORT_HOST_DEVICE bool IsNaN() const
Definition: float8.h:105
static constexpr ORT_HOST_DEVICE FromBitsT FromBits()
Definition: float8.h:204
static constexpr onnxruntime::Float8E4M3FN min()
Definition: float8.h:686
void FloatToFloat8E5M2(const float *flt, Float8E5M2 *blf, size_t size, bool saturate)
Definition: float8.h:514
static constexpr onnxruntime::Float8E5M2FNUZ max()
Definition: float8.h:864
ORT_HOST_DEVICE bool IsInfinity() const
Definition: float8.h:432
ORT_HOST_DEVICE bool IsNaN() const
Definition: float8.h:596
constexpr auto digits10() noexcept-> int
Definition: format.h:1289
ORT_HOST_DEVICE Float8E4M3FN(float v, bool saturate=true)
Definition: float8.h:39
GLsizeiptr size
Definition: glcorearb.h:664
std::integral_constant< bool, std::numeric_limits< T >::is_signed||std::is_same< T, int128_opt >::value > is_signed
Definition: format.h:818
IMATH_NAMESPACE::V2f IMATH_NAMESPACE::Box2i std::string this attribute is obsolete as of OpenEXR v3 float
static constexpr onnxruntime::Float8E5M2 round_error()
Definition: float8.h:759
void Float8E5M2ToFloat(const Float8E5M2 *blf, float *flt, size_t size)
Definition: float8.h:506
void FloatToFloat8E5M2FNUZ(const float *flt, Float8E5M2FNUZ *blf, size_t size, bool saturate)
Definition: float8.h:663
static constexpr onnxruntime::Float8E5M2FNUZ lowest()
Definition: float8.h:860
static constexpr onnxruntime::Float8E5M2FNUZ epsilon()
Definition: float8.h:876
ORT_HOST_DEVICE float ToFloat() const
Definition: float8.h:277
static constexpr onnxruntime::Float8E4M3FNUZ epsilon()
Definition: float8.h:815
GLuint GLfloat * val
Definition: glcorearb.h:1608
static constexpr onnxruntime::Float8E5M2 min()
Definition: float8.h:747
static constexpr onnxruntime::Float8E4M3FNUZ quiet_NaN()
Definition: float8.h:828
constexpr ORT_HOST_DEVICE Float8E5M2(unsigned char bits, FromBitsT)
Definition: float8.h:363
void Float8E4M3FNToFloat(const Float8E4M3FN *blf, float *flt, size_t size)
Definition: float8.h:178
static constexpr onnxruntime::Float8E4M3FN round_error()
Definition: float8.h:698
void Float8E5M2FNUZToFloat(const Float8E5M2FNUZ *blf, float *flt, size_t size)
Definition: float8.h:655
static constexpr onnxruntime::Float8E5M2 infinity()
Definition: float8.h:763
static constexpr onnxruntime::Float8E5M2FNUZ quiet_NaN()
Definition: float8.h:889
static constexpr onnxruntime::Float8E4M3FN max()
Definition: float8.h:682
static constexpr onnxruntime::Float8E4M3FN lowest()
Definition: float8.h:678
GLenum src
Definition: glcorearb.h:1793
constexpr ORT_HOST_DEVICE Float8E5M2FNUZ(unsigned char bits, FromBitsT)
Definition: float8.h:533
ORT_HOST_DEVICE float ToFloat() const
Definition: float8.h:109