6 #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
10 #if !defined(__CUDACC__) && !defined(__HIPCC__)
16 namespace onnxruntime {
18 #if defined(__CUDACC__) || defined(__HIPCC__)
19 #define ORT_HOST_DEVICE __host__ __device__
21 #define ORT_HOST_DEVICE
55 #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
56 val = __bfloat16_as_ushort(__float2bfloat16(v));
57 #elif defined(__HIP__)
60 val = UINT16_C(0x7FC0);
68 uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF);
69 val =
static_cast<uint16_t
>((U32 + rounding_bias) >> 16);
73 std::memcpy(&
val, reinterpret_cast<char*>(&v) +
sizeof(uint16_t),
sizeof(uint16_t));
76 std::memcpy(&
val, &v,
sizeof(uint16_t));
82 #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
83 return __bfloat162float(*reinterpret_cast<const __nv_bfloat16*>(&
val));
84 #elif defined(__HIP__)
89 float* tempRes =
reinterpret_cast<float*
>(&tmp);
94 char*
const first =
reinterpret_cast<char*
>(&
result);
95 char*
const second = first +
sizeof(uint16_t);
97 std::memset(first, 0,
sizeof(uint16_t));
98 std::memcpy(second, &
val,
sizeof(uint16_t));
101 std::memcpy(first, &
val,
sizeof(uint16_t));
102 std::memset(second, 0,
sizeof(uint16_t));
110 #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
112 explicit ORT_HOST_DEVICE operator __nv_bfloat16()
const {
return *
reinterpret_cast<const __nv_bfloat16*
>(&
val); }
124 #if !defined(__CUDACC__) && !defined(__HIPCC__)
133 inline BFloat16 operator"" _b16(
unsigned long long int v) {
138 return BFloat16(static_cast<float>(v));
146 for (; size != 0; ++
src, ++d, --
size) {
154 for (; size != 0; ++
src, ++d, --
size) {
constexpr ORT_HOST_DEVICE BFloat16(unsigned short bits, FromBitsT)
void BFloat16ToFloat(const BFloat16 *blf, float *flt, size_t size)
GLsizei const GLfloat * value
bool operator!=(const MLFloat16 &left, const MLFloat16 &right)
ORT_HOST_DEVICE BFloat16(float v)
**But if you need a result
void FloatToBFloat16(const float *flt, BFloat16 *blf, size_t size)
bool operator<(const MLFloat16 &left, const MLFloat16 &right)
bool operator==(const MLFloat16 &left, const MLFloat16 &right)
constexpr MLFloat16(uint16_t x)
static constexpr ORT_HOST_DEVICE FromBitsT FromBits()
ORT_HOST_DEVICE float ToFloat() const