10 #include <type_traits>
13 #include "boost/mp11.hpp"
17 #ifndef SHARED_PROVIDER
18 #include "core/common/type_list.h"
20 #include "core/graph/onnx_protobuf.h"
23 namespace onnxruntime {
38 #if !defined(DISABLE_FLOAT8_TYPES)
// DispatchOnTensorType(tensor_type, function, ...)
// Switches on tensor_type->AsPrimitiveDataType()->GetDataType() and, for the
// matching ONNX TensorProto element type, calls function<T>(__VA_ARGS__) where
// T is the C++ type listed for that case. This definition sits under
// !defined(DISABLE_FLOAT8_TYPES) and includes the four Float8 cases.
// A value matching none of the cases is reported through
// ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type).
// tensor_type is expanded more than once, so pass a side-effect-free expression.
// NOTE(review): tensor_type and the result of AsPrimitiveDataType() are
// dereferenced without a null check here - presumably callers guarantee a
// primitive type; confirm.
// NOTE(review): no `break;`, `default:` or closing brace is visible in this
// view of the macro - verify against the full header that each case terminates.
40 #define DispatchOnTensorType(tensor_type, function, ...) \
41 switch (tensor_type->AsPrimitiveDataType()->GetDataType()) { \
42 case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: \
43 function<float>(__VA_ARGS__); \
45 case ONNX_NAMESPACE::TensorProto_DataType_BOOL: \
46 function<bool>(__VA_ARGS__); \
48 case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: \
49 function<double>(__VA_ARGS__); \
51 case ONNX_NAMESPACE::TensorProto_DataType_STRING: \
52 function<std::string>(__VA_ARGS__); \
54 case ONNX_NAMESPACE::TensorProto_DataType_INT8: \
55 function<int8_t>(__VA_ARGS__); \
57 case ONNX_NAMESPACE::TensorProto_DataType_UINT8: \
58 function<uint8_t>(__VA_ARGS__); \
60 case ONNX_NAMESPACE::TensorProto_DataType_INT16: \
61 function<int16_t>(__VA_ARGS__); \
63 case ONNX_NAMESPACE::TensorProto_DataType_UINT16: \
64 function<uint16_t>(__VA_ARGS__); \
66 case ONNX_NAMESPACE::TensorProto_DataType_INT32: \
67 function<int32_t>(__VA_ARGS__); \
69 case ONNX_NAMESPACE::TensorProto_DataType_UINT32: \
70 function<uint32_t>(__VA_ARGS__); \
72 case ONNX_NAMESPACE::TensorProto_DataType_INT64: \
73 function<int64_t>(__VA_ARGS__); \
75 case ONNX_NAMESPACE::TensorProto_DataType_UINT64: \
76 function<uint64_t>(__VA_ARGS__); \
78 case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: \
79 function<MLFloat16>(__VA_ARGS__); \
81 case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: \
82 function<BFloat16>(__VA_ARGS__); \
84 case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN: \
85 function<Float8E4M3FN>(__VA_ARGS__); \
87 case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ: \
88 function<Float8E4M3FNUZ>(__VA_ARGS__); \
90 case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2: \
91 function<Float8E5M2>(__VA_ARGS__); \
93 case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ: \
94 function<Float8E5M2FNUZ>(__VA_ARGS__); \
96 case ONNX_NAMESPACE::TensorProto_DataType_INT4: \
97 function<Int4x2>(__VA_ARGS__); \
99 case ONNX_NAMESPACE::TensorProto_DataType_UINT4: \
100 function<UInt4x2>(__VA_ARGS__); \
103 ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \
// DispatchOnTensorTypeWithReturn(tensor_type, retval, function, ...)
// Same switch as DispatchOnTensorType, but each case assigns the call result:
// retval = function<T>(__VA_ARGS__), with T the C++ type listed for the case.
// This definition includes the four Float8 cases.
// retval must be an assignable expression; tensor_type is expanded more than
// once, so pass a side-effect-free expression.
// A value matching none of the cases is reported through
// ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type).
// NOTE(review): no `break;`, `default:` or closing brace is visible in this
// view of the macro - verify against the full header that each case terminates.
106 #define DispatchOnTensorTypeWithReturn(tensor_type, retval, function, ...) \
107 switch (tensor_type->AsPrimitiveDataType()->GetDataType()) { \
108 case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: \
109 retval = function<float>(__VA_ARGS__); \
111 case ONNX_NAMESPACE::TensorProto_DataType_BOOL: \
112 retval = function<bool>(__VA_ARGS__); \
114 case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: \
115 retval = function<double>(__VA_ARGS__); \
117 case ONNX_NAMESPACE::TensorProto_DataType_STRING: \
118 retval = function<std::string>(__VA_ARGS__); \
120 case ONNX_NAMESPACE::TensorProto_DataType_INT8: \
121 retval = function<int8_t>(__VA_ARGS__); \
123 case ONNX_NAMESPACE::TensorProto_DataType_UINT8: \
124 retval = function<uint8_t>(__VA_ARGS__); \
126 case ONNX_NAMESPACE::TensorProto_DataType_UINT16: \
127 retval = function<uint16_t>(__VA_ARGS__); \
129 case ONNX_NAMESPACE::TensorProto_DataType_INT16: \
130 retval = function<int16_t>(__VA_ARGS__); \
132 case ONNX_NAMESPACE::TensorProto_DataType_INT32: \
133 retval = function<int32_t>(__VA_ARGS__); \
135 case ONNX_NAMESPACE::TensorProto_DataType_UINT32: \
136 retval = function<uint32_t>(__VA_ARGS__); \
138 case ONNX_NAMESPACE::TensorProto_DataType_INT64: \
139 retval = function<int64_t>(__VA_ARGS__); \
141 case ONNX_NAMESPACE::TensorProto_DataType_UINT64: \
142 retval = function<uint64_t>(__VA_ARGS__); \
144 case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: \
145 retval = function<MLFloat16>(__VA_ARGS__); \
147 case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: \
148 retval = function<BFloat16>(__VA_ARGS__); \
150 case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN: \
151 retval = function<Float8E4M3FN>(__VA_ARGS__); \
153 case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ: \
154 retval = function<Float8E4M3FNUZ>(__VA_ARGS__); \
156 case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2: \
157 retval = function<Float8E5M2>(__VA_ARGS__); \
159 case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ: \
160 retval = function<Float8E5M2FNUZ>(__VA_ARGS__); \
162 case ONNX_NAMESPACE::TensorProto_DataType_INT4: \
163 retval = function<Int4x2>(__VA_ARGS__); \
165 case ONNX_NAMESPACE::TensorProto_DataType_UINT4: \
166 retval = function<UInt4x2>(__VA_ARGS__); \
169 ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \
// DispatchOnTensorType(tensor_type, function, ...) - variant without the
// Float8 cases. Switches on tensor_type->AsPrimitiveDataType()->GetDataType()
// and calls function<T>(__VA_ARGS__) for the matching element type.
// NOTE(review): presumably this is the alternative branch used when
// DISABLE_FLOAT8_TYPES is defined; the selecting directive is not visible in
// this view - confirm.
// A value matching none of the cases is reported through
// ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type).
// tensor_type is expanded more than once, so pass a side-effect-free expression.
// NOTE(review): no `break;`, `default:` or closing brace is visible in this
// view of the macro - verify against the full header that each case terminates.
174 #define DispatchOnTensorType(tensor_type, function, ...) \
175 switch (tensor_type->AsPrimitiveDataType()->GetDataType()) { \
176 case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: \
177 function<float>(__VA_ARGS__); \
179 case ONNX_NAMESPACE::TensorProto_DataType_BOOL: \
180 function<bool>(__VA_ARGS__); \
182 case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: \
183 function<double>(__VA_ARGS__); \
185 case ONNX_NAMESPACE::TensorProto_DataType_STRING: \
186 function<std::string>(__VA_ARGS__); \
188 case ONNX_NAMESPACE::TensorProto_DataType_INT8: \
189 function<int8_t>(__VA_ARGS__); \
191 case ONNX_NAMESPACE::TensorProto_DataType_UINT8: \
192 function<uint8_t>(__VA_ARGS__); \
194 case ONNX_NAMESPACE::TensorProto_DataType_INT16: \
195 function<int16_t>(__VA_ARGS__); \
197 case ONNX_NAMESPACE::TensorProto_DataType_UINT16: \
198 function<uint16_t>(__VA_ARGS__); \
200 case ONNX_NAMESPACE::TensorProto_DataType_INT32: \
201 function<int32_t>(__VA_ARGS__); \
203 case ONNX_NAMESPACE::TensorProto_DataType_UINT32: \
204 function<uint32_t>(__VA_ARGS__); \
206 case ONNX_NAMESPACE::TensorProto_DataType_INT64: \
207 function<int64_t>(__VA_ARGS__); \
209 case ONNX_NAMESPACE::TensorProto_DataType_UINT64: \
210 function<uint64_t>(__VA_ARGS__); \
212 case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: \
213 function<MLFloat16>(__VA_ARGS__); \
215 case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: \
216 function<BFloat16>(__VA_ARGS__); \
218 case ONNX_NAMESPACE::TensorProto_DataType_INT4: \
219 function<Int4x2>(__VA_ARGS__); \
221 case ONNX_NAMESPACE::TensorProto_DataType_UINT4: \
222 function<UInt4x2>(__VA_ARGS__); \
225 ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \
// DispatchOnTensorTypeWithReturn(tensor_type, retval, function, ...) - variant
// without the Float8 cases. Each case assigns retval = function<T>(__VA_ARGS__)
// for the element type reported by
// tensor_type->AsPrimitiveDataType()->GetDataType().
// NOTE(review): presumably this is the alternative branch used when
// DISABLE_FLOAT8_TYPES is defined; the selecting directive is not visible in
// this view - confirm.
// retval must be an assignable expression; tensor_type is expanded more than
// once, so pass a side-effect-free expression.
// A value matching none of the cases is reported through
// ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type).
// NOTE(review): no `break;`, `default:` or closing brace is visible in this
// view of the macro - verify against the full header that each case terminates.
228 #define DispatchOnTensorTypeWithReturn(tensor_type, retval, function, ...) \
229 switch (tensor_type->AsPrimitiveDataType()->GetDataType()) { \
230 case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: \
231 retval = function<float>(__VA_ARGS__); \
233 case ONNX_NAMESPACE::TensorProto_DataType_BOOL: \
234 retval = function<bool>(__VA_ARGS__); \
236 case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: \
237 retval = function<double>(__VA_ARGS__); \
239 case ONNX_NAMESPACE::TensorProto_DataType_STRING: \
240 retval = function<std::string>(__VA_ARGS__); \
242 case ONNX_NAMESPACE::TensorProto_DataType_INT8: \
243 retval = function<int8_t>(__VA_ARGS__); \
245 case ONNX_NAMESPACE::TensorProto_DataType_UINT8: \
246 retval = function<uint8_t>(__VA_ARGS__); \
248 case ONNX_NAMESPACE::TensorProto_DataType_UINT16: \
249 retval = function<uint16_t>(__VA_ARGS__); \
251 case ONNX_NAMESPACE::TensorProto_DataType_INT16: \
252 retval = function<int16_t>(__VA_ARGS__); \
254 case ONNX_NAMESPACE::TensorProto_DataType_INT32: \
255 retval = function<int32_t>(__VA_ARGS__); \
257 case ONNX_NAMESPACE::TensorProto_DataType_UINT32: \
258 retval = function<uint32_t>(__VA_ARGS__); \
260 case ONNX_NAMESPACE::TensorProto_DataType_INT64: \
261 retval = function<int64_t>(__VA_ARGS__); \
263 case ONNX_NAMESPACE::TensorProto_DataType_UINT64: \
264 retval = function<uint64_t>(__VA_ARGS__); \
266 case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: \
267 retval = function<MLFloat16>(__VA_ARGS__); \
269 case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: \
270 retval = function<BFloat16>(__VA_ARGS__); \
272 case ONNX_NAMESPACE::TensorProto_DataType_INT4: \
273 retval = function<Int4x2>(__VA_ARGS__); \
275 case ONNX_NAMESPACE::TensorProto_DataType_UINT4: \
276 retval = function<UInt4x2>(__VA_ARGS__); \
279 ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \
// NOTE(review): only the return statements of these predicates are visible in
// this view; the enclosing signatures and the origin of prim_type are not.
// True only when prim_type is non-null and reports the STRING element type.
291 return (prim_type !=
nullptr && prim_type->GetDataType() == ONNX_NAMESPACE::TensorProto_DataType_STRING);
// True only when prim_type is non-null and its element type equals the one
// that T maps to via ToTensorProtoElementType<T>().
299 return (prim_type !=
nullptr && prim_type->GetDataType() == ToTensorProtoElementType<T>());
// This variant requires a non-null prim_type: the check is an assert, so it is
// compiled out when NDEBUG is defined and the dereference below is unguarded.
306 assert(prim_type !=
nullptr);
307 return prim_type->
GetDataType() == ToTensorProtoElementType<T>();
// Internal helpers for the type dispatcher.
// NOTE(review): only fragments of these helpers are visible in this view;
// class headers, member declarations and return statements are not.
312 namespace mltype_dispatcher_internal {
// Calls fn with the forwarded args, but only when the tensor element type that
// T maps to equals the stored dt_type_.
323 template <
class T,
class Fn,
class... Args>
325 if (utils::ToTensorProtoElementType<T>() == dt_type_) {
326 std::forward<Fn>(fn)(std::forward<Args>(
args)...);
// Fails with "Unsupported data type: <dt_type_>" unless called_ is exactly 1.
// NOTE(review): called_ presumably counts the matching calls made above; its
// update is not visible here - confirm.
333 ORT_ENFORCE(called_ == 1,
"Unsupported data type: ", dt_type_);
// Reports an unsupported dt_type by throwing through ORT_THROW.
342 ORT_THROW(
"Unsupported data type: ", dt_type);
347 template <
class Ret,
class UnsupportedPolicy>
// Hands dt_type_ and result_ to a default-constructed UnsupportedPolicy.
// NOTE(review): the condition guarding this call is not visible here.
359 UnsupportedPolicy()(dt_type_, result_);
// Same type match as the void helper above, but the value returned by fn is
// stored in result_.
365 template <
class T,
class Fn,
class... Args>
367 if (utils::ToTensorProtoElementType<T>() == dt_type_) {
368 result_ = std::forward<Fn>(fn)(std::forward<Args>(
args)...);
// Compile-time constant carrying the TensorProto_DataType that T maps to.
375 template <
typename T>
377 std::integral_constant<ONNX_NAMESPACE::TensorProto_DataType, ToTensorProtoElementType<T>()>;
// Compile-time constant for TensorProto_DataType_UNDEFINED.
380 std::integral_constant<ONNX_NAMESPACE::TensorProto_DataType, ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED>;
// Dispatcher parameterised on the C++ types it can dispatch to.
// NOTE(review): only fragments of the class are visible in this view; member
// signatures and several statements are not.
402 template <
typename... Types>
// The dispatchable types, kept as a TypeList.
404 using SupportedTypeList = TypeList<Types...>;
// A list derived from the types with boost::mp11::mp_transform.
// NOTE(review): the transform's arguments are not visible here; presumably each
// type is mapped to its tensor element type constant - confirm.
405 using SupportedTensorProtoElementTypeList =
406 boost::mp11::mp_transform<
// Compile-time validation of that list: mp_is_set requires it to contain no
// duplicates. The use of the mp_set_contains result is not fully visible here.
411 boost::mp11::mp_is_set<SupportedTensorProtoElementTypeList>,
413 boost::mp11::mp_set_contains<
414 SupportedTensorProtoElementTypeList,
416 "Types must map to a unique set of ONNX tensor element data types supported by ORT.");
// Convenience form: delegates to InvokeWithLeadingTemplateArgs with an empty
// leading TypeList, forwarding all arguments.
436 template <
template <
typename...>
class Fn,
typename... Args>
438 InvokeWithLeadingTemplateArgs<Fn, TypeList<>>(std::forward<Args>(
args)...);
// Form that prepends LeadingTemplateArgTypeList to Fn's template arguments.
450 template <
template <
typename...>
class Fn,
typename LeadingTemplateArgTypeList,
typename... Args>
454 "LeadingTemplateArgTypeList must be a type list (e.g., onnxruntime::TypeList<T1, T2, ...>).");
// Expands helper.Invoke<T> once for every T in Types. Each expansion passes a
// default-constructed Fn instantiated with LeadingTemplateArgTypeList plus T
// (mp_push_back, then mp_apply). The int results are gathered in a std::array
// that is discarded immediately, so only the calls' side effects matter.
// NOTE(review): args is forwarded into every expansion; presumably only the
// expansion whose T matches the runtime type passes them on to Fn - confirm.
460 static_cast<void>(std::array<
int,
sizeof...(Types)>{
461 helper.template Invoke<Types>(
462 boost::mp11::mp_apply<Fn, boost::mp11::mp_push_back<LeadingTemplateArgTypeList, Types>>(),
463 std::forward<Args>(
args)...)...});
// Value-returning form; only the forwarded-argument tail of its delegating
// call is visible here.
478 template <
class Ret,
template <
typename...>
class Fn,
typename... Args>
482 std::forward<Args>(
args)...);
// Value-returning form with a caller-supplied UnsupportedPolicy; delegates with
// an empty leading TypeList.
495 template <
class Ret,
template <
typename...>
class Fn,
class UnsupportedPolicy,
typename... Args>
498 Ret, Fn, UnsupportedPolicy, TypeList<>>(
499 std::forward<Args>(
args)...);
// Value-returning form with leading template arguments; only the
// forwarded-argument tail of its delegating call is visible here.
512 template <
class Ret,
template <
typename...>
class Fn,
typename LeadingTemplateArgTypeList,
typename... Args>
516 std::forward<Args>(
args)...);
// Value-returning form taking both an UnsupportedPolicy and leading template
// arguments. The expansion below mirrors the void form: one helper.Invoke<T>
// per T in Types, results discarded through a temporary std::array.
533 template <
typename...>
class Fn,
534 class UnsupportedPolicy,
535 typename LeadingTemplateArgTypeList,
542 static_cast<void>(std::array<
int,
sizeof...(Types)>{
543 helper.template Invoke<Types>(
544 boost::mp11::mp_apply<Fn, boost::mp11::mp_push_back<LeadingTemplateArgTypeList, Types>>(),
545 std::forward<Args>(
args)...)...});
// NOTE(review): the declaration this template header introduces is not visible
// in this view.
555 template <
typename L>
// NOTE(review): only single statements from the members of this namespace are
// visible in this view; their signatures are not.
558 namespace data_types_internal {
// The primitive type is stored narrowed to uint16_t.
// NOTE(review): the cast would truncate values outside the uint16_t range;
// presumably all element type ids fit - confirm.
584 prim_type_ =
static_cast<uint16_t
>(prim_type);
// Compares the stored container kind with the requested one.
588 return type_ ==
type;
// Compares the stored primitive type with the argument after applying the same
// uint16_t narrowing used when storing it.
592 return prim_type_ ==
static_cast<uint16_t
>(prim_type);
// Flat sequence of TypeNode entries that the container checks walk by index.
620 using Cont = std::vector<data_types_internal::TypeNode>;
// IsContainerOfType<...>::check(c, index) tests the TypeNode entries of c,
// starting at index, against a C++ container type.
// NOTE(review): template headers, the bodies of the bounds checks and the
// closing braces are not visible in this view.
// Primary template: only its leading bounds check is visible here.
625 struct IsContainerOfType {
626 static bool check(
const Cont&
c,
size_t index) {
627 if (index >= c.size()) {
// Specialization for std::vector<T>.
636 struct IsContainerOfType<std::vector<T>> {
637 static bool check(
const Cont&
c,
size_t index) {
638 if (index >= c.size()) {
// Advance to the next entry, which must exist: it describes the element type.
642 ORT_ENFORCE(++index < c.size(),
"Sequence is missing type entry for its element");
643 constexpr int32_t prim_type = ToTensorProtoElementType<T>();
// T maps to a defined tensor element type: the entry's primitive type is
// compared with it (the start of that expression is not visible here).
645 if constexpr (prim_type != ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED) {
647 c[index].IsPrimType(prim_type);
// Otherwise the check is delegated to IsContainerOfType<T> at the already
// advanced index.
650 return IsContainerOfType<T>::check(c, index);
// Specialization for std::map<K, V>.
657 template <
class K,
class V>
658 struct IsContainerOfType<std::map<K, V>> {
659 static bool check(
const Cont&
c,
size_t index) {
// The key type must map to a defined tensor element type; enforced at compile time.
660 static_assert(ToTensorProtoElementType<K>() != ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED,
661 "Map Key can not be a non-primitive type");
662 if (index >= c.size()) {
// The entry at index must carry the key's primitive type.
668 constexpr int32_t key_type = ToTensorProtoElementType<K>();
669 if (!c[index].IsPrimType(key_type)) {
// Advance to the next entry, which must exist: it describes the value type.
672 ORT_ENFORCE(++index < c.size(),
"Map is missing type entry for its value");
673 constexpr int32_t val_type = ToTensorProtoElementType<V>();
// V maps to a defined tensor element type: the entry's primitive type is
// compared with it (the start of that expression is not visible here).
674 if constexpr (val_type != ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED) {
676 c[index].IsPrimType(val_type);
// Otherwise the check is delegated to IsContainerOfType<V> at the already
// advanced index.
678 return IsContainerOfType<V>::check(c, index);
// NOTE(review): these are body fragments of several members; their signatures
// are not visible in this view.
// Each member first asserts that types_ is non-empty. The assert is compiled
// out when NDEBUG is defined.
687 assert(!types_.empty());
692 assert(!types_.empty());
698 assert(!types_.empty());
// Checks the whole types_ list, from entry 0, against std::vector<T>.
699 return IsContainerOfType<std::vector<T>>::check(types_, 0);
702 template <
class K,
class V>
704 assert(!types_.empty());
// Checks the whole types_ list, from entry 0, against std::map<K, V>.
705 return IsContainerOfType<std::map<K, V>>::check(types_, 0);
typedef int(APIENTRYP RE_PFNGLXSWAPINTERVALSGIPROC)(int)
Base class for MLDataType.
MLTypeCallDispatcher(int32_t dt_type) noexcept
Ret InvokeRetWithUnsupportedPolicy(Args &&...args) const
int Invoke(Fn &&fn, Args &&...args)
Ret InvokeRetWithUnsupportedPolicyAndLeadingTemplateArgs(Args &&...args) const
GLsizei const GLfloat * value
bool IsType(ContainerType type) const noexcept
#define ORT_ENFORCE(condition,...)
void operator()(int32_t dt_type, Ret &) const
Ret InvokeRetWithLeadingTemplateArgs(Args &&...args) const
bool IsPrimType(int32_t prim_type) const noexcept
int32_t GetDataType() const
bool IsSequence() const noexcept
CallableDispatchableHelper(int32_t dt_type) noexcept
boost::mp11::mp_apply< MLTypeCallDispatcher, L > MLTypeCallDispatcherFromTypeList
GLint GLint GLsizei GLint GLenum GLenum type
void InvokeWithLeadingTemplateArgs(Args &&...args) const
bool IsPrimitiveDataType(MLDataType dt_type)
const PrimitiveDataTypeBase * AsPrimitiveDataType() const
std::integral_constant< ONNX_NAMESPACE::TensorProto_DataType, ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED > UndefinedTensorProtoElementTypeConstant
#define ORT_UNUSED_PARAMETER(x)
bool IsOpaqueType(MLDataType ml_type, const char *domain, const char *name)
PrimitiveDataTypeBase Base class for primitive Tensor contained types.
GLuint const GLchar * name
std::integral_constant< ONNX_NAMESPACE::TensorProto_DataType, ToTensorProtoElementType< T >()> TensorProtoElementTypeConstant
CallableDispatchableRetHelper(int32_t dt_type) noexcept
bool IsDataTypeString(MLDataType dt_type)
Use the following primitives if you have a few types to switch on so you.
void CheckCalledOnce() const
TypeNode(ContainerType type, int32_t prim_type) noexcept
int Invoke(Fn &&fn, Args &&...args)
Ret InvokeRet(Args &&...args) const
**If you just want to fire and args
bool IsMap() const noexcept
void Invoke(Args &&...args) const
bool IsSequenceOf() const