HDK
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
to_tensor_proto_element_type.h
Go to the documentation of this file.
1 // Copyright (c) Microsoft Corporation. All rights reserved.
2 // Licensed under the MIT License.
3 
4 #pragma once
5 
6 #include <cstdint>
7 #include <string>
8 
9 #ifndef SHARED_PROVIDER
10 #include "core/graph/onnx_protobuf.h"
11 #endif
12 
13 #include "core/framework/float8.h"
14 #include "core/framework/float16.h"
15 
16 namespace onnxruntime {
17 namespace utils {
18 /** Gets the TensorProto_DataType corresponding to the template type `T`. */
19 template <typename T>
20 constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType() {
21  return ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED;
22 }
23 template <>
24 constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType<float>() {
25  return ONNX_NAMESPACE::TensorProto_DataType_FLOAT;
26 }
27 template <>
28 constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType<uint8_t>() {
29  return ONNX_NAMESPACE::TensorProto_DataType_UINT8;
30 }
31 template <>
32 constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType<int8_t>() {
33  return ONNX_NAMESPACE::TensorProto_DataType_INT8;
34 }
35 template <>
36 constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType<uint16_t>() {
37  return ONNX_NAMESPACE::TensorProto_DataType_UINT16;
38 }
39 template <>
40 constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType<int16_t>() {
41  return ONNX_NAMESPACE::TensorProto_DataType_INT16;
42 }
43 template <>
44 constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType<int32_t>() {
45  return ONNX_NAMESPACE::TensorProto_DataType_INT32;
46 }
47 template <>
48 constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType<int64_t>() {
49  return ONNX_NAMESPACE::TensorProto_DataType_INT64;
50 }
51 template <>
52 constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType<std::string>() {
53  return ONNX_NAMESPACE::TensorProto_DataType_STRING;
54 }
55 template <>
56 constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType<bool>() {
57  return ONNX_NAMESPACE::TensorProto_DataType_BOOL;
58 }
59 template <>
60 constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType<MLFloat16>() {
61  return ONNX_NAMESPACE::TensorProto_DataType_FLOAT16;
62 }
63 template <>
64 constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType<double>() {
65  return ONNX_NAMESPACE::TensorProto_DataType_DOUBLE;
66 }
67 template <>
68 constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType<uint32_t>() {
69  return ONNX_NAMESPACE::TensorProto_DataType_UINT32;
70 }
71 template <>
72 constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType<uint64_t>() {
73  return ONNX_NAMESPACE::TensorProto_DataType_UINT64;
74 }
75 template <>
76 constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType<BFloat16>() {
77  return ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16;
78 }
79 
80 #if !defined(DISABLE_FLOAT8_TYPES)
81 
82 template <>
83 constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType<Float8E4M3FN>() {
84  return ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN;
85 }
86 template <>
87 constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType<Float8E4M3FNUZ>() {
88  return ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ;
89 }
90 template <>
91 constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType<Float8E5M2>() {
92  return ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2;
93 }
94 template <>
95 constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType<Float8E5M2FNUZ>() {
96  return ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ;
97 }
98 
99 #endif
100 
101 } // namespace utils
102 } // namespace onnxruntime
constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType< uint8_t >()
constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType< Float8E4M3FN >()
constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType< Float8E4M3FNUZ >()
constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType< uint32_t >()
constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType< Float8E5M2FNUZ >()
constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType()
constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType< bool >()
constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType< int32_t >()
constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType< uint16_t >()
constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType< uint64_t >()
constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType< int8_t >()
constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType< MLFloat16 >()
constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType< Float8E5M2 >()
constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType< int64_t >()
constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType< int16_t >()
constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType< BFloat16 >()
constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType< double >()
constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType< float >()