6 #if !defined(DISABLE_FLOAT8_TYPES)
9 #if defined(CUDA_VERSION) && CUDA_VERSION >= 11080
13 #if !defined(__CUDACC__) && !defined(__HIPCC__)
19 namespace onnxruntime {
21 #if defined(__CUDACC__) || defined(__HIPCC__)
22 #define ORT_HOST_DEVICE __host__ __device__
24 #define ORT_HOST_DEVICE
40 #if defined(CUDA_VERSION) && CUDA_VERSION >= 11080
41 val = __nv_cvt_float_to_fp8(v,
saturate ? __NV_SATFINITE : __NV_NOSAT, __NV_E4M3);
44 std::memcpy(&b, &v,
sizeof(b));
46 val =
static_cast<uint8_t
>((b & 0x80000000) >> 24);
47 if ((b & 0x7fffffff) == 0x7f800000) {
53 }
else if ((b & 0x7F800000) == 0x7F800000) {
56 uint8_t e =
static_cast<uint8_t
>((b & 0x7F800000) >> 23);
57 uint32_t m =
static_cast<uint32_t
>(b & 0x007FFFFF);
69 auto mask = 1 << (20 + d);
70 if ((m &
mask) && ((
val & 1) || ((m & (mask - 1)) > 0) || ((m &
mask) && (m & (mask << 1)) && ((m & (mask - 1)) == 0)))) {
83 if ((
val & 0x7F) == 0x7F) {
87 if ((m & 0x80000) && ((m & 0x100000) || (m & 0x7FFFF))) {
88 if ((
val & 0x7F) < 0x7E) {
106 return (
val & 0b01111111) == 0b01111111;
110 #if defined(CUDA_VERSION) && CUDA_VERSION >= 11080
111 return __half2float(__nv_cvt_fp8_to_halfraw(
val, __NV_E4M3));
116 }
else if (
val == 127) {
119 uint32_t expo = (
val & 0x78) >> 3;
120 uint32_t mant =
val & 0x07;
126 if ((mant & 0x4) == 0) {
131 if ((mant & 0x4) == 0) {
136 res |= (mant & 0x3) << 21;
147 std::memcpy(&float_res, &res,
sizeof(
float));
154 #if defined(CUDA_VERSION) && CUDA_VERSION >= 11080
156 explicit ORT_HOST_DEVICE operator __nv_fp8_e4m3()
const {
return *
reinterpret_cast<const __nv_fp8_e4m3*
>(&
val); }
166 #if !defined(__CUDACC__) && !defined(__HIPCC__)
181 for (; size != 0; ++
src, ++d, --
size) {
189 for (; size != 0; ++
src, ++d, --
size) {
210 std::memcpy(&b, &v,
sizeof(b));
212 val =
static_cast<uint8_t
>((b & 0x80000000) >> 24);
213 if ((b & 0x7fffffff) == 0x7f800000) {
221 }
else if ((b & 0x7F800000) == 0x7F800000) {
224 uint8_t e =
static_cast<uint8_t
>((b & 0x7F800000) >> 23);
225 uint32_t m =
static_cast<uint32_t
>(b & 0x007FFFFF);
230 }
else if (e < 120) {
235 val |= m >> (21 + d);
242 auto mask = 1 << (20 + d);
243 if ((m &
mask) && ((
val & 1) || ((m & (mask - 1)) > 0) || ((m &
mask) && (m & (mask << 1)) && ((m & (mask - 1)) == 0)))) {
247 }
else if (e < 135) {
257 if ((m & 0x80000) && ((m & 0x100000) || (m & 0x7FFFF))) {
258 if ((
val & 0x7F) < 0x7F) {
274 return val == 0b10000000;
283 uint32_t expo = (
val & 0x78) >> 3;
284 uint32_t mant =
val & 0x07;
290 if ((mant & 0x4) == 0) {
295 if ((mant & 0x4) == 0) {
300 res |= (mant & 0x3) << 21;
311 std::memcpy(&float_res, &res,
sizeof(
float));
324 #if !defined(__CUDACC__) && !defined(__HIPCC__)
339 for (; size != 0; ++
src, ++d, --
size) {
347 for (; size != 0; ++
src, ++d, --
size) {
366 #if defined(CUDA_VERSION) && CUDA_VERSION >= 11080
367 val = __nv_cvt_float_to_fp8(v,
saturate ? __NV_SATFINITE : __NV_NOSAT, __NV_E5M2);
370 std::memcpy(&b, &v,
sizeof(b));
372 val = (b & 0x80000000) >> 24;
373 if ((b & 0x7FFFFFFF) == 0x7F800000) {
381 }
else if ((b & 0x7F800000) == 0x7F800000) {
384 uint32_t e = (b & 0x7F800000) >> 23;
385 uint32_t m = b & 0x007FFFFF;
389 }
else if (e < 113) {
394 val |= m >> (22 + d);
398 auto mask = 1 << (21 + d);
399 if ((m &
mask) && ((
val & 1) || ((m & (mask - 1)) > 0) || ((m &
mask) && (m & (mask << 1)) && ((m & (mask - 1)) == 0)))) {
403 }
else if (e < 143) {
407 if ((m & 0x100000) && ((m & 0xFFFFF) || (m & 0x200000))) {
408 if ((
val & 0x7F) < 0x7B) {
429 return (
val & 0b01111111) > 0b01111100;
434 return (
val & 0b01111111) == 0b01111100;
438 #if defined(CUDA_VERSION) && CUDA_VERSION >= 11080
439 return __half2float(__nv_cvt_fp8_to_halfraw(
val, __NV_E5M2));
444 }
else if (
val >= 125 &&
val <= 127) {
446 }
else if (
val == 252) {
448 }
else if (
val == 124) {
451 uint32_t expo = (
val & 0x7C) >> 2;
452 uint32_t mant =
val & 0x03;
458 if ((mant & 0
x2) == 0) {
463 res |= (mant & 0x1) << 22;
475 std::memcpy(&float_res, &res,
sizeof(
float));
482 #if defined(CUDA_VERSION) && CUDA_VERSION >= 11080
484 explicit ORT_HOST_DEVICE operator __nv_fp8_e5m2()
const {
return *
reinterpret_cast<const __nv_fp8_e5m2*
>(&
val); }
494 #if !defined(__CUDACC__) && !defined(__HIPCC__)
496 inline Float8E5M2 operator"" _f8e5m2fn(
unsigned long long int v) {
501 return Float8E5M2(static_cast<float>(v),
true);
509 for (; size != 0; ++
src, ++d, --
size) {
517 for (; size != 0; ++
src, ++d, --
size) {
538 std::memcpy(&b, &v,
sizeof(b));
540 val = (b & 0x80000000) >> 24;
541 if ((b & 0x7FFFFFFF) == 0x7F800000) {
547 }
else if ((b & 0x7F800000) == 0x7F800000) {
550 uint32_t e = (b & 0x7F800000) >> 23;
551 uint32_t m = b & 0x007FFFFF;
556 }
else if (e < 112) {
561 val |= m >> (22 + d);
568 auto mask = 1 << (21 + d);
569 if ((m &
mask) && ((
val & 1) || ((m & (mask - 1)) > 0) || ((m &
mask) && (m & (mask << 1)) && ((m & (mask - 1)) == 0)))) {
573 }
else if (e < 143) {
578 if ((m & 0x100000) && ((m & 0xFFFFF) || (m & 0x200000))) {
579 if ((
val & 0x7F) < 0x7F) {
586 }
else if ((e == 255) && (m == 0)) {
597 return val == 0b10000000;
606 uint32_t expo = (
val & 0x7C) >> 2;
607 uint32_t mant =
val & 0x03;
613 if ((mant & 0
x2) == 0) {
618 res |= (mant & 0x1) << 22;
630 std::memcpy(&float_res, &res,
sizeof(
float));
643 #if !defined(__CUDACC__) && !defined(__HIPCC__)
658 for (; size != 0; ++
src, ++d, --
size) {
666 for (; size != 0; ++
src, ++d, --
size) {
676 class numeric_limits<onnxruntime::Float8E4M3FN> {
711 static constexpr
bool is_specialized =
true;
714 static constexpr
bool is_exact =
false;
715 static constexpr
bool has_infinity =
false;
716 static constexpr
bool has_quiet_NaN =
true;
717 static constexpr
bool has_signaling_NaN =
false;
718 static constexpr
auto has_denorm =
true;
719 static constexpr
auto has_denorm_loss =
true;
720 static constexpr
auto round_style = round_to_nearest;
721 static constexpr
bool is_iec559 =
false;
722 static constexpr
bool is_bounded =
true;
723 static constexpr
bool is_modulo =
false;
724 static constexpr
int digits = 4;
726 static constexpr
int max_digits10 = 3;
727 static constexpr
int radix = 2;
728 static constexpr
int min_exponent = -5;
729 static constexpr
int min_exponent10 = -1;
730 static constexpr
int max_exponent = 8;
731 static constexpr
int max_exponent10 = 2;
732 static constexpr
auto traps =
false;
733 static constexpr
auto tinyness_before =
false;
737 class numeric_limits<onnxruntime::Float8E5M2> {
771 static constexpr
bool is_specialized =
true;
774 static constexpr
bool is_exact =
false;
775 static constexpr
bool has_infinity =
true;
776 static constexpr
bool has_quiet_NaN =
true;
777 static constexpr
bool has_signaling_NaN =
false;
778 static constexpr
auto has_denorm =
true;
779 static constexpr
auto has_denorm_loss =
true;
780 static constexpr
auto round_style = round_to_nearest;
781 static constexpr
bool is_iec559 =
false;
782 static constexpr
bool is_bounded =
true;
783 static constexpr
bool is_modulo =
false;
784 static constexpr
int digits = 3;
786 static constexpr
int max_digits10 = 2;
787 static constexpr
int radix = 2;
788 static constexpr
int min_exponent = -13;
789 static constexpr
int min_exponent10 = -4;
790 static constexpr
int max_exponent = 16;
791 static constexpr
int max_exponent10 = 4;
792 static constexpr
auto traps =
false;
793 static constexpr
auto tinyness_before =
false;
797 class numeric_limits<onnxruntime::Float8E4M3FNUZ> {
832 static constexpr
bool is_specialized =
true;
835 static constexpr
bool is_exact =
false;
836 static constexpr
bool has_infinity =
false;
837 static constexpr
bool has_quiet_NaN =
true;
838 static constexpr
bool has_signaling_NaN =
false;
839 static constexpr
auto has_denorm =
true;
840 static constexpr
auto has_denorm_loss =
true;
841 static constexpr
auto round_style = round_to_nearest;
842 static constexpr
bool is_iec559 =
false;
843 static constexpr
bool is_bounded =
true;
844 static constexpr
bool is_modulo =
false;
845 static constexpr
int digits = 4;
847 static constexpr
int max_digits10 = 3;
848 static constexpr
int radix = 2;
849 static constexpr
int min_exponent = -6;
850 static constexpr
int min_exponent10 = -1;
851 static constexpr
int max_exponent = 8;
852 static constexpr
int max_exponent10 = 2;
853 static constexpr
auto traps =
false;
854 static constexpr
auto tinyness_before =
false;
858 class numeric_limits<onnxruntime::Float8E5M2FNUZ> {
893 static constexpr
bool is_specialized =
true;
896 static constexpr
bool is_exact =
false;
897 static constexpr
bool has_infinity =
false;
898 static constexpr
bool has_quiet_NaN =
true;
899 static constexpr
bool has_signaling_NaN =
false;
900 static constexpr
auto has_denorm =
true;
901 static constexpr
auto has_denorm_loss =
true;
902 static constexpr
auto round_style = round_to_nearest;
903 static constexpr
bool is_iec559 =
false;
904 static constexpr
bool is_bounded =
true;
905 static constexpr
bool is_modulo =
false;
906 static constexpr
int digits = 3;
908 static constexpr
int max_digits10 = 2;
909 static constexpr
int radix = 2;
910 static constexpr
int min_exponent = -14;
911 static constexpr
int min_exponent10 = -4;
912 static constexpr
int max_exponent = 16;
913 static constexpr
int max_exponent10 = 4;
914 static constexpr
auto traps =
false;
915 static constexpr
auto tinyness_before =
false;
920 #endif // DISABLE_FLOAT8_TYPES
static constexpr onnxruntime::Float8E4M3FN infinity()
static constexpr onnxruntime::Float8E5M2 denorm_min()
bool_constant< is_integral< T >::value &&!std::is_same< T, bool >::value &&!std::is_same< T, char >::value &&!std::is_same< T, wchar_t >::value > is_integer
void FloatToFloat8E4M3FNUZ(const float *flt, Float8E4M3FNUZ *blf, size_t size, bool saturate)
static constexpr ORT_HOST_DEVICE FromBitsT FromBits()
static constexpr onnxruntime::Float8E5M2 lowest()
static constexpr onnxruntime::Float8E5M2FNUZ round_error()
static constexpr onnxruntime::Float8E4M3FNUZ lowest()
static constexpr onnxruntime::Float8E4M3FN denorm_min()
GLsizei const GLfloat * value
ORT_HOST_DEVICE bool operator<(const Float8E4M3FN &left, const Float8E4M3FN &right)
static constexpr onnxruntime::Float8E5M2FNUZ infinity()
ORT_HOST_DEVICE float ToFloat() const
static constexpr onnxruntime::Float8E4M3FN epsilon()
ORT_HOST_DEVICE float ToFloat() const
void Float8E4M3FNUZToFloat(const Float8E4M3FNUZ *blf, float *flt, size_t size)
ORT_HOST_DEVICE Float8E4M3FNUZ(float v, bool saturate=true)
static constexpr ORT_HOST_DEVICE FromBitsT FromBits()
static constexpr onnxruntime::Float8E4M3FN quiet_NaN()
static constexpr onnxruntime::Float8E4M3FNUZ max()
static constexpr onnxruntime::Float8E4M3FNUZ denorm_min()
static constexpr onnxruntime::Float8E5M2FNUZ min()
static constexpr ORT_HOST_DEVICE FromBitsT FromBits()
static constexpr onnxruntime::Float8E4M3FNUZ infinity()
constexpr ORT_HOST_DEVICE Float8E4M3FN(unsigned char bits, FromBitsT)
ORT_HOST_DEVICE bool operator==(const Float8E4M3FN &left, const Float8E4M3FN &right)
ORT_HOST_DEVICE bool IsNaN() const
ORT_HOST_DEVICE bool IsNaN() const
ORT_HOST_DEVICE bool operator!=(const Float8E4M3FN &left, const Float8E4M3FN &right)
void FloatToFloat8E4M3FN(const float *flt, Float8E4M3FN *blf, size_t size, bool saturate)
static constexpr onnxruntime::Float8E5M2 quiet_NaN()
ORT_HOST_DEVICE Float8E5M2FNUZ(float v, bool saturate=true)
static constexpr onnxruntime::Float8E5M2 epsilon()
static constexpr onnxruntime::Float8E4M3FNUZ min()
static constexpr onnxruntime::Float8E5M2 max()
static constexpr onnxruntime::Float8E4M3FNUZ round_error()
constexpr ORT_HOST_DEVICE Float8E4M3FNUZ(unsigned char bits, FromBitsT)
ORT_HOST_DEVICE Float8E5M2(float v, bool saturate=true)
ImageBuf OIIO_API saturate(const ImageBuf &src, float scale=0.0f, int firstchannel=0, ROI roi={}, int nthreads=0)
IMATH_HOSTDEVICE constexpr int sign(T a) IMATH_NOEXCEPT
GLboolean GLboolean GLboolean b
static constexpr onnxruntime::Float8E5M2FNUZ denorm_min()
ORT_HOST_DEVICE bool IsNaN() const
static constexpr ORT_HOST_DEVICE FromBitsT FromBits()
static constexpr onnxruntime::Float8E4M3FN min()
void FloatToFloat8E5M2(const float *flt, Float8E5M2 *blf, size_t size, bool saturate)
static constexpr onnxruntime::Float8E5M2FNUZ max()
ORT_HOST_DEVICE bool IsInfinity() const
ORT_HOST_DEVICE bool IsNaN() const
ORT_HOST_DEVICE Float8E4M3FN(float v, bool saturate=true)
std::integral_constant< bool, std::numeric_limits< T >::is_signed||std::is_same< T, int128_opt >::value > is_signed
IMATH_NAMESPACE::V2f IMATH_NAMESPACE::Box2i std::string this attribute is obsolete as of OpenEXR v3 float
static constexpr onnxruntime::Float8E5M2 round_error()
void Float8E5M2ToFloat(const Float8E5M2 *blf, float *flt, size_t size)
void FloatToFloat8E5M2FNUZ(const float *flt, Float8E5M2FNUZ *blf, size_t size, bool saturate)
static constexpr onnxruntime::Float8E5M2FNUZ lowest()
static constexpr onnxruntime::Float8E5M2FNUZ epsilon()
ORT_HOST_DEVICE float ToFloat() const
static constexpr onnxruntime::Float8E4M3FNUZ epsilon()
static constexpr onnxruntime::Float8E5M2 min()
static constexpr onnxruntime::Float8E4M3FNUZ quiet_NaN()
constexpr ORT_HOST_DEVICE Float8E5M2(unsigned char bits, FromBitsT)
void Float8E4M3FNToFloat(const Float8E4M3FN *blf, float *flt, size_t size)
static constexpr onnxruntime::Float8E4M3FN round_error()
void Float8E5M2FNUZToFloat(const Float8E5M2FNUZ *blf, float *flt, size_t size)
static constexpr onnxruntime::Float8E5M2 infinity()
static constexpr onnxruntime::Float8E5M2FNUZ quiet_NaN()
static constexpr onnxruntime::Float8E4M3FN max()
static constexpr onnxruntime::Float8E4M3FN lowest()
constexpr ORT_HOST_DEVICE Float8E5M2FNUZ(unsigned char bits, FromBitsT)
ORT_HOST_DEVICE float ToFloat() const