HDK
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
float16.h
Go to the documentation of this file.
1 // Copyright (c) Microsoft Corporation. All rights reserved.
2 // Licensed under the MIT License.
3 #pragma once
4 
5 #include "endian.h"
6 #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
7 #include "cuda_bf16.h"
8 #endif
9 
10 #if !defined(__CUDACC__) && !defined(__HIPCC__)
11 #include "core/common/narrow.h"
12 #endif
13 
14 #include "core/common/common.h"
15 
16 namespace onnxruntime {
17 
// ORT_HOST_DEVICE marks a function as callable from both host and device code.
// It expands to nothing unless a CUDA (__CUDACC__) or HIP (__HIPCC__) compiler is in use.
#if !defined(__CUDACC__) && !defined(__HIPCC__)
#define ORT_HOST_DEVICE
#else
#define ORT_HOST_DEVICE __host__ __device__
#endif
23 
// MLFloat16
// 16-bit floating point value held as its raw bit pattern in `val`.
struct MLFloat16 {
  // Raw 16-bit pattern; zero-initialized by default.
  uint16_t val{0};

  MLFloat16() = default;
  // Wraps an existing bit pattern; no numeric conversion takes place.
  explicit constexpr MLFloat16(uint16_t x) : val(x) {}
  // Conversion from float; declared here, defined out of line.
  explicit MLFloat16(float f);

  // Conversion to float; declared here, defined out of line.
  float ToFloat() const;

  // Implicit widening conversion, forwards to ToFloat().
  operator float() const { return ToFloat(); }
};

// Equality and inequality compare the stored bit patterns.
inline bool operator==(const MLFloat16& left, const MLFloat16& right) { return left.val == right.val; }
inline bool operator!=(const MLFloat16& left, const MLFloat16& right) { return left.val != right.val; }

// Numeric less-than, treating `val` as an IEEE 754 binary16 pattern
// (sign bit 0x8000, exponent mask 0x7C00).
// A plain comparison of the raw patterns would rank every negative value above
// every positive one because the sign bit is the most significant bit.
// NaN operands are unordered (result is false); -0 and +0 are not less than each other.
inline bool operator<(const MLFloat16& left, const MLFloat16& right) {
  constexpr uint16_t kSignMask = 0x8000;
  constexpr uint16_t kMagnitudeMask = 0x7FFF;
  constexpr uint16_t kInfinityBits = 0x7C00;

  const uint16_t left_mag = static_cast<uint16_t>(left.val & kMagnitudeMask);
  const uint16_t right_mag = static_cast<uint16_t>(right.val & kMagnitudeMask);

  // Any magnitude above the infinity pattern is a NaN.
  if (left_mag > kInfinityBits || right_mag > kInfinityBits) {
    return false;
  }

  const bool left_negative = (left.val & kSignMask) != 0;
  const bool right_negative = (right.val & kSignMask) != 0;

  if (left_negative != right_negative) {
    // The negative operand is smaller unless both operands are zero.
    return left_negative && (left_mag != 0 || right_mag != 0);
  }

  // Same sign: a larger magnitude means a smaller value when both are negative.
  return left_negative ? (left_mag > right_mag) : (left_mag < right_mag);
}
40 
41 // BFloat16
42 struct BFloat16 {
43  uint16_t val{0};
44 #if defined(__HIP__)
45  ORT_HOST_DEVICE BFloat16() = default;
46 #else
47  BFloat16() = default;
48 #endif
49 
50  struct FromBitsT {};
51  static constexpr ORT_HOST_DEVICE FromBitsT FromBits() { return FromBitsT(); }
52  constexpr ORT_HOST_DEVICE BFloat16(unsigned short bits, FromBitsT) : val(bits) {}
53 
54  inline ORT_HOST_DEVICE BFloat16(float v) {
55 #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
56  val = __bfloat16_as_ushort(__float2bfloat16(v));
57 #elif defined(__HIP__)
58  // We should be using memcpy in order to respect the strict aliasing rule but it fails in the HIP environment.
59  if (v != v) { // isnan
60  val = UINT16_C(0x7FC0);
61  } else {
62  union {
63  uint32_t U32;
64  float F32;
65  };
66 
67  F32 = v;
68  uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF);
69  val = static_cast<uint16_t>((U32 + rounding_bias) >> 16);
70  }
71 #else
72  if constexpr(endian::native == endian::little) {
73  std::memcpy(&val, reinterpret_cast<char*>(&v) + sizeof(uint16_t), sizeof(uint16_t));
74  }
75  else {
76  std::memcpy(&val, &v, sizeof(uint16_t));
77  }
78 #endif
79  }
80 
81  inline ORT_HOST_DEVICE float ToFloat() const {
82 #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
83  return __bfloat162float(*reinterpret_cast<const __nv_bfloat16*>(&val));
84 #elif defined(__HIP__)
85  // We should be using memcpy in order to respect the strict aliasing rule but it fails in the HIP environment.
86  float result = 0;
87  uint32_t tmp = val;
88  tmp <<= 16;
89  float* tempRes = reinterpret_cast<float*>(&tmp);
90  result = *tempRes;
91  return result;
92 #else
93  float result;
94  char* const first = reinterpret_cast<char*>(&result);
95  char* const second = first + sizeof(uint16_t);
96  if constexpr(endian::native == endian::little) {
97  std::memset(first, 0, sizeof(uint16_t));
98  std::memcpy(second, &val, sizeof(uint16_t));
99  }
100  else {
101  std::memcpy(first, &val, sizeof(uint16_t));
102  std::memset(second, 0, sizeof(uint16_t));
103  }
104  return result;
105 #endif
106  }
107 
108  inline ORT_HOST_DEVICE operator float() const { return ToFloat(); }
109 
110 #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
111  ORT_HOST_DEVICE BFloat16(const __nv_bfloat16& value) { val = *reinterpret_cast<const unsigned short*>(&value); }
112  explicit ORT_HOST_DEVICE operator __nv_bfloat16() const { return *reinterpret_cast<const __nv_bfloat16*>(&val); }
113 #endif
114 };
115 
116 inline ORT_HOST_DEVICE bool operator==(const BFloat16& left, const BFloat16& right) { return left.val == right.val; }
117 inline ORT_HOST_DEVICE bool operator!=(const BFloat16& left, const BFloat16& right) { return left.val != right.val; }
118 inline ORT_HOST_DEVICE bool operator<(const BFloat16& left, const BFloat16& right) { return left.val < right.val; }
119 
120 
121 // User defined suffixes to make it easier to declare
122 // initializers with MLFloat16 and BFloat16 from unsigned short
123 // E.g 10_f16 or 10_b16
124 #if !defined(__CUDACC__) && !defined(__HIPCC__)
125 inline MLFloat16 operator"" _f16(unsigned long long int v) {
126  return MLFloat16(narrow<uint16_t>(v));
127 }
128 
129 inline MLFloat16 operator"" _fp16(long double v) {
130  return MLFloat16(static_cast<float>(v));
131 }
132 
133 inline BFloat16 operator"" _b16(unsigned long long int v) {
134  return BFloat16(narrow<uint16_t>(v), BFloat16::FromBits());
135 }
136 
137 inline BFloat16 operator"" _bfp16(long double v) {
138  return BFloat16(static_cast<float>(v));
139 }
140 
141 #endif
142 
143 inline void BFloat16ToFloat(const BFloat16* blf, float* flt, size_t size) {
144  auto src = blf;
145  auto d = flt;
146  for (; size != 0; ++src, ++d, --size) {
147  *d = src->ToFloat();
148  }
149 }
150 
151 inline void FloatToBFloat16(const float* flt, BFloat16* blf, size_t size) {
152  auto src = flt;
153  auto d = blf;
154  for (; size != 0; ++src, ++d, --size) {
155  new (d) BFloat16(*src);
156  }
157 }
158 
159 } // namespace onnxruntime
GLint first
Definition: glcorearb.h:405
float ToFloat() const
constexpr ORT_HOST_DEVICE BFloat16(unsigned short bits, FromBitsT)
Definition: float16.h:52
GLint left
Definition: glcorearb.h:2005
const GLdouble * v
Definition: glcorearb.h:837
void BFloat16ToFloat(const BFloat16 *blf, float *flt, size_t size)
Definition: float16.h:143
GLsizei const GLfloat * value
Definition: glcorearb.h:824
bool operator!=(const MLFloat16 &left, const MLFloat16 &right)
Definition: float16.h:38
ORT_HOST_DEVICE BFloat16(float v)
Definition: float16.h:54
GLdouble right
Definition: glad.h:2817
**But if you need a result
Definition: thread.h:613
void FloatToBFloat16(const float *flt, BFloat16 *blf, size_t size)
Definition: float16.h:151
GLfloat f
Definition: glcorearb.h:1926
#define ORT_HOST_DEVICE
Definition: float16.h:21
bool operator<(const MLFloat16 &left, const MLFloat16 &right)
Definition: float16.h:39
GLint GLenum GLint x
Definition: glcorearb.h:409
bool operator==(const MLFloat16 &left, const MLFloat16 &right)
Definition: float16.h:37
GLsizeiptr size
Definition: glcorearb.h:664
GLuint GLfloat * val
Definition: glcorearb.h:1608
constexpr MLFloat16(uint16_t x)
Definition: float16.h:29
Definition: core.h:1131
static constexpr ORT_HOST_DEVICE FromBitsT FromBits()
Definition: float16.h:51
ORT_HOST_DEVICE float ToFloat() const
Definition: float16.h:81
GLenum src
Definition: glcorearb.h:1793