10 #include <type_traits>
13 #include "boost/mp11.hpp"
17 #ifndef SHARED_PROVIDER
18 #include "core/common/type_list.h"
20 #if !defined(ORT_MINIMAL_BUILD)
21 #include "onnx/defs/schema.h"
23 #include "onnx/defs/data_type_utils.h"
25 #include "onnx/onnx_pb.h"
26 #include "onnx/onnx-operators_pb.h"
29 namespace onnxruntime {
// Invokes function<T>(__VA_ARGS__) where T is the C++ type corresponding to
// the primitive (tensor element) type of |tensor_type| (an MLDataType).
// Exactly one case runs per invocation (each case ends with `break;`);
// an unrecognized element type fails loudly via ORT_ENFORCE.
#define DispatchOnTensorType(tensor_type, function, ...)          \
  switch (tensor_type->AsPrimitiveDataType()->GetDataType()) {    \
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:              \
      function<float>(__VA_ARGS__);                               \
      break;                                                      \
    case ONNX_NAMESPACE::TensorProto_DataType_BOOL:               \
      function<bool>(__VA_ARGS__);                                \
      break;                                                      \
    case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:             \
      function<double>(__VA_ARGS__);                              \
      break;                                                      \
    case ONNX_NAMESPACE::TensorProto_DataType_STRING:             \
      function<std::string>(__VA_ARGS__);                         \
      break;                                                      \
    case ONNX_NAMESPACE::TensorProto_DataType_INT8:               \
      function<int8_t>(__VA_ARGS__);                              \
      break;                                                      \
    case ONNX_NAMESPACE::TensorProto_DataType_UINT8:              \
      function<uint8_t>(__VA_ARGS__);                             \
      break;                                                      \
    case ONNX_NAMESPACE::TensorProto_DataType_INT16:              \
      function<int16_t>(__VA_ARGS__);                             \
      break;                                                      \
    case ONNX_NAMESPACE::TensorProto_DataType_UINT16:             \
      function<uint16_t>(__VA_ARGS__);                            \
      break;                                                      \
    case ONNX_NAMESPACE::TensorProto_DataType_INT32:              \
      function<int32_t>(__VA_ARGS__);                             \
      break;                                                      \
    case ONNX_NAMESPACE::TensorProto_DataType_UINT32:             \
      function<uint32_t>(__VA_ARGS__);                            \
      break;                                                      \
    case ONNX_NAMESPACE::TensorProto_DataType_INT64:              \
      function<int64_t>(__VA_ARGS__);                             \
      break;                                                      \
    case ONNX_NAMESPACE::TensorProto_DataType_UINT64:             \
      function<uint64_t>(__VA_ARGS__);                            \
      break;                                                      \
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:            \
      function<MLFloat16>(__VA_ARGS__);                           \
      break;                                                      \
    case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16:           \
      function<BFloat16>(__VA_ARGS__);                            \
      break;                                                      \
    default:                                                      \
      ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \
  }
// As DispatchOnTensorType, but assigns the result of function<T>(__VA_ARGS__)
// to |retval|. Exactly one case runs per invocation (each case ends with
// `break;`); an unrecognized element type fails loudly via ORT_ENFORCE.
#define DispatchOnTensorTypeWithReturn(tensor_type, retval, function, ...) \
  switch (tensor_type->AsPrimitiveDataType()->GetDataType()) {             \
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:                       \
      retval = function<float>(__VA_ARGS__);                               \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_BOOL:                       \
      retval = function<bool>(__VA_ARGS__);                               \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:                      \
      retval = function<double>(__VA_ARGS__);                              \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_STRING:                      \
      retval = function<std::string>(__VA_ARGS__);                         \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_INT8:                        \
      retval = function<int8_t>(__VA_ARGS__);                              \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_UINT8:                       \
      retval = function<uint8_t>(__VA_ARGS__);                             \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_UINT16:                      \
      retval = function<uint16_t>(__VA_ARGS__);                            \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_INT16:                       \
      retval = function<int16_t>(__VA_ARGS__);                             \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_INT32:                       \
      retval = function<int32_t>(__VA_ARGS__);                             \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_UINT32:                      \
      retval = function<uint32_t>(__VA_ARGS__);                            \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_INT64:                       \
      retval = function<int64_t>(__VA_ARGS__);                             \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_UINT64:                      \
      retval = function<uint64_t>(__VA_ARGS__);                            \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:                     \
      retval = function<MLFloat16>(__VA_ARGS__);                           \
      break;                                                               \
    case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16:                    \
      retval = function<BFloat16>(__VA_ARGS__);                            \
      break;                                                               \
    default:                                                               \
      ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type);          \
  }
147 return (prim_type !=
nullptr && prim_type->GetDataType() == ONNX_NAMESPACE::TensorProto_DataType_STRING);
155 return (prim_type !=
nullptr && prim_type->GetDataType() == ToTensorProtoElementType<T>());
162 assert(prim_type !=
nullptr);
163 return prim_type->
GetDataType() == ToTensorProtoElementType<T>();
168 namespace mltype_dispatcher_internal {
179 template <
class T,
class Fn,
class... Args>
181 if (utils::ToTensorProtoElementType<T>() == dt_type_) {
182 std::forward<Fn>(fn)(std::forward<Args>(
args)...);
189 ORT_ENFORCE(called_ == 1,
"Unsupported data type: ", dt_type_);
198 ORT_THROW(
"Unsupported data type: ", dt_type);
203 template <
class Ret,
class UnsupportedPolicy>
215 UnsupportedPolicy()(dt_type_, result_);
221 template <
class T,
class Fn,
class... Args>
223 if (utils::ToTensorProtoElementType<T>() == dt_type_) {
224 result_ = std::forward<Fn>(fn)(std::forward<Args>(
args)...);
231 template <
typename T>
233 std::integral_constant<ONNX_NAMESPACE::TensorProto_DataType, ToTensorProtoElementType<T>()>;
236 std::integral_constant<ONNX_NAMESPACE::TensorProto_DataType, ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED>;
258 template <
typename... Types>
260 using SupportedTypeList = TypeList<Types...>;
261 using SupportedTensorProtoElementTypeList =
262 boost::mp11::mp_transform<
267 boost::mp11::mp_is_set<SupportedTensorProtoElementTypeList>,
269 boost::mp11::mp_set_contains<
270 SupportedTensorProtoElementTypeList,
272 "Types must map to a unique set of ONNX tensor element data types supported by ORT.");
292 template <
template <
typename...>
class Fn,
typename... Args>
294 InvokeWithLeadingTemplateArgs<Fn, TypeList<>>(std::forward<Args>(
args)...);
306 template <
template <
typename...>
class Fn,
typename LeadingTemplateArgTypeList,
typename... Args>
310 "LeadingTemplateArgTypeList must be a type list (e.g., onnxruntime::TypeList<T1, T2, ...>).");
316 static_cast<void>(std::array<
int,
sizeof...(Types)>{
317 helper.template Invoke<Types>(
318 boost::mp11::mp_apply<Fn, boost::mp11::mp_push_back<LeadingTemplateArgTypeList, Types>>(),
319 std::forward<Args>(
args)...)...});
334 template <
class Ret,
template <
typename...>
class Fn,
typename... Args>
338 std::forward<Args>(
args)...);
351 template <
class Ret,
template <
typename...>
class Fn,
class UnsupportedPolicy,
typename... Args>
354 Ret, Fn, UnsupportedPolicy, TypeList<>>(
355 std::forward<Args>(
args)...);
368 template <
class Ret,
template <
typename...>
class Fn,
typename LeadingTemplateArgTypeList,
typename... Args>
372 std::forward<Args>(
args)...);
389 template <
typename...>
class Fn,
390 class UnsupportedPolicy,
391 typename LeadingTemplateArgTypeList,
398 static_cast<void>(std::array<
int,
sizeof...(Types)>{
399 helper.template Invoke<Types>(
400 boost::mp11::mp_apply<Fn, boost::mp11::mp_push_back<LeadingTemplateArgTypeList, Types>>(),
401 std::forward<Args>(
args)...)...});
411 template <
typename L>
namespace data_types_internal {

// Kind of entry in a flattened container-type description.
// NOTE(review): reconstructed enumerator set — confirm values against the
// out-of-line ContainerChecker constructor that populates TypeNode entries.
enum class ContainerType : uint16_t {
  kUndefined = 0,
  kTensor = 1,
  kMap = 2,
  kSequence = 3,
  kOpaque = 4
};

// One entry in a flattened container type description: the container kind
// plus (where applicable) the primitive element type it carries.
class TypeNode {
  ContainerType type_;
  uint16_t prim_type_;  // ONNX TensorProto_DataType, narrowed for compactness

 public:
  TypeNode(ContainerType type, int32_t prim_type) noexcept {
    type_ = type;
    prim_type_ = static_cast<uint16_t>(prim_type);
  }

  bool IsType(ContainerType type) const noexcept {
    return type_ == type;
  }

  bool IsPrimType(int32_t prim_type) const noexcept {
    return prim_type_ == static_cast<uint16_t>(prim_type);
  }
};

}  // namespace data_types_internal
476 using Cont = std::vector<data_types_internal::TypeNode>;
481 struct IsContainerOfType {
482 static bool check(
const Cont& c,
size_t index) {
483 if (index >= c.size()) {
492 struct IsContainerOfType<std::vector<T>> {
493 static bool check(
const Cont& c,
size_t index) {
494 if (index >= c.size()) {
498 ORT_ENFORCE(++index < c.size(),
"Sequence is missing type entry for its element");
499 constexpr int32_t prim_type = ToTensorProtoElementType<T>();
501 if constexpr(prim_type != ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED) {
503 c[index].IsPrimType(prim_type);
507 return IsContainerOfType<T>::check(c, index);
514 template <
class K,
class V>
515 struct IsContainerOfType<std::map<K, V>> {
516 static bool check(
const Cont& c,
size_t index) {
517 static_assert(ToTensorProtoElementType<K>() != ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED,
518 "Map Key can not be a non-primitive type");
519 if (index >= c.size()) {
525 constexpr int32_t key_type = ToTensorProtoElementType<K>();
526 if (!c[index].IsPrimType(key_type)) {
529 ORT_ENFORCE(++index < c.size(),
"Map is missing type entry for its value");
530 constexpr int32_t val_type = ToTensorProtoElementType<V>();
531 if constexpr(val_type != ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED) {
533 c[index].IsPrimType(val_type);
535 else return IsContainerOfType<V>::check(c, index);
544 assert(!types_.empty());
549 assert(!types_.empty());
555 assert(!types_.empty());
556 return IsContainerOfType<std::vector<T>>::check(types_, 0);
559 template <
class K,
class V>
561 assert(!types_.empty());
562 return IsContainerOfType<std::map<K, V>>::check(types_, 0);
typedef int(APIENTRYP RE_PFNGLXSWAPINTERVALSGIPROC)(int)
Base class for MLDataType.
MLTypeCallDispatcher(int32_t dt_type) noexcept
Ret InvokeRetWithUnsupportedPolicy(Args &&...args) const
int Invoke(Fn &&fn, Args &&...args)
Ret InvokeRetWithUnsupportedPolicyAndLeadingTemplateArgs(Args &&...args) const
GLsizei const GLfloat * value
bool IsType(ContainerType type) const noexcept
#define ORT_ENFORCE(condition,...)
void operator()(int32_t dt_type, Ret &) const
Ret InvokeRetWithLeadingTemplateArgs(Args &&...args) const
bool IsPrimType(int32_t prim_type) const noexcept
int32_t GetDataType() const
bool IsSequence() const noexcept
CallableDispatchableHelper(int32_t dt_type) noexcept
boost::mp11::mp_apply< MLTypeCallDispatcher, L > MLTypeCallDispatcherFromTypeList
void InvokeWithLeadingTemplateArgs(Args &&...args) const
bool IsPrimitiveDataType(MLDataType dt_type)
const PrimitiveDataTypeBase * AsPrimitiveDataType() const
std::integral_constant< ONNX_NAMESPACE::TensorProto_DataType, ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED > UndefinedTensorProtoElementTypeConstant
#define ORT_UNUSED_PARAMETER(x)
bool IsOpaqueType(MLDataType ml_type, const char *domain, const char *name)
PrimitiveDataTypeBase Base class for primitive Tensor contained types.
GLuint const GLchar * name
std::integral_constant< ONNX_NAMESPACE::TensorProto_DataType, ToTensorProtoElementType< T >()> TensorProtoElementTypeConstant
CallableDispatchableRetHelper(int32_t dt_type) noexcept
bool IsDataTypeString(MLDataType dt_type)
Use the following primitives if you have a few types to switch on so you.
TypeNode(ContainerType type, int32_t prim_type) noexcept
int Invoke(Fn &&fn, Args &&...args)
Ret InvokeRet(Args &&...args) const
**If you just want to fire and args
bool IsMap() const noexcept
void Invoke(Args &&...args) const
bool IsSequenceOf() const