11 #include <unordered_map>
18 #if !defined(ORT_MINIMAL_BUILD)
19 #include "onnx/defs/schema.h"
21 #include "onnx/defs/data_type_utils.h"
23 #include "onnx/onnx_pb.h"
24 #include "onnx/onnx-operators_pb.h"
28 namespace ONNX_NAMESPACE {
32 namespace onnxruntime {
35 #if !defined(DISABLE_ML_OPS)
59 #if !defined(DISABLE_SPARSE_TENSORS)
64 #if !defined(DISABLE_OPTIONAL_TYPE)
109 virtual bool IsCompatible(
const ONNX_NAMESPACE::TypeProto& type_proto)
const = 0;
121 virtual const ONNX_NAMESPACE::TypeProto*
GetTypeProto()
const = 0;
152 #if !defined(DISABLE_SPARSE_TENSORS)
157 #if !defined(DISABLE_OPTIONAL_TYPE)
168 template <
typename T>
172 template <
typename elemT>
175 template <
typename elemT>
178 #if !defined(DISABLE_SPARSE_TENSORS)
180 template <
typename elemT>
184 template <
typename T,
typename elemT>
198 #if !defined(DISABLE_SPARSE_TENSORS)
203 static std::vector<std::string>
ToString(
const std::vector<MLDataType>&
types);
229 namespace data_types_internal {
236 template <
typename T,
typename... Types>
240 template <
typename T,
typename Tail>
241 struct IsAnyOf<
T, Tail> :
public std::is_same<T, Tail> {
244 template <
typename T,
typename H,
typename... Tail>
254 template <
typename T>
256 int32_t, int64_t, std::string, bool, MLFloat16,
257 double, uint32_t, uint64_t, BFloat16> {
260 #if !defined(DISABLE_SPARSE_TENSORS)
264 template <
typename T>
266 int32_t, int64_t, std::string, bool, MLFloat16,
267 double, uint32_t, uint64_t, BFloat16> {
271 #if !defined(DISABLE_OPTIONAL_TYPE)
274 template <
typename T>
282 template <
typename T,
bool TensorContainedType>
285 template <
typename T>
288 return DataTypeImpl::GetTensorType<T>();
292 template <
typename T>
295 return DataTypeImpl::GetType<T>();
300 static void Set(ONNX_NAMESPACE::TensorProto_DataType element_type,
301 ONNX_NAMESPACE::TypeProto& proto) {
302 proto.mutable_tensor_type()->set_elem_type(element_type);
306 #if !defined(DISABLE_SPARSE_TENSORS)
308 static void Set(ONNX_NAMESPACE::TensorProto_DataType element_type,
309 ONNX_NAMESPACE::TypeProto& proto) {
310 proto.mutable_sparse_tensor_type()->set_elem_type(element_type);
313 #endif // !defined(DISABLE_SPARSE_TENSORS)
315 #if !defined(DISABLE_ML_OPS)
319 ONNX_NAMESPACE::TypeProto&);
324 template <
typename V>
329 static void Set(ONNX_NAMESPACE::TensorProto_DataType key_type,
const ONNX_NAMESPACE::TypeProto* value_proto,
330 ONNX_NAMESPACE::TypeProto& proto) {
331 ORT_ENFORCE(value_proto !=
nullptr,
"expected a registered ONNX type");
332 proto.mutable_map_type()->set_key_type(key_type);
342 ONNX_NAMESPACE::TypeProto&);
346 template <
typename T>
351 static void Set(
const ONNX_NAMESPACE::TypeProto* elem_proto,
352 ONNX_NAMESPACE::TypeProto& proto) {
353 ORT_ENFORCE(elem_proto !=
nullptr,
"expected a registered ONNX type");
361 ONNX_NAMESPACE::TypeProto&);
365 template <
typename T,
typename elemT>
368 return DataTypeImpl::GetTensorType<elemT>();
371 return DataTypeImpl::GetSequenceTensorType<elemT>();
375 static void Set(
const onnx::TypeProto* elem_proto, ONNX_NAMESPACE::TypeProto& proto) {
376 ORT_ENFORCE(elem_proto !=
nullptr,
"expected a registered ONNX type");
384 ONNX_NAMESPACE::TypeProto& proto);
390 #if defined(_MSC_VER) && !defined(__clang__)
391 #pragma warning(push)
392 #pragma warning(disable : 26436)
402 bool IsCompatible(
const ONNX_NAMESPACE::TypeProto& type_proto)
const override;
406 const ONNX_NAMESPACE::TypeProto*
GetTypeProto()
const override;
438 template <
typename elemT>
442 "Requires one of the tensor fundamental types");
449 return DataTypeImpl::GetType<elemT>();
454 using namespace data_types_internal;
455 TensorTypeHelper::Set(utils::ToTensorProtoElementType<elemT>(),
MutableTypeProto());
459 #if defined(DISABLE_OPTIONAL_TYPE)
468 bool IsCompatible(
const ONNX_NAMESPACE::TypeProto&)
const override {
477 ORT_THROW(
"Type is disabled in this build.");
481 const ONNX_NAMESPACE::TypeProto*
GetTypeProto()
const override;
490 ~DisabledTypeBase()
override;
499 #if !defined(DISABLE_SPARSE_TENSORS)
505 bool IsCompatible(
const ONNX_NAMESPACE::TypeProto& type_proto)
const override;
509 const ONNX_NAMESPACE::TypeProto*
GetTypeProto()
const override;
529 template <
typename elemT>
533 "Requires one of the sparse-tensor fundamental types");
539 return DataTypeImpl::GetType<elemT>();
544 using namespace data_types_internal;
545 SparseTensorTypeHelper::Set(utils::ToTensorProtoElementType<elemT>(),
MutableTypeProto());
549 #endif // !defined(DISABLE_SPARSE_TENSORS)
553 #if !defined(DISABLE_OPTIONAL_TYPE)
558 bool IsCompatible(
const ONNX_NAMESPACE::TypeProto& type_proto)
const override;
565 const ONNX_NAMESPACE::TypeProto*
GetTypeProto()
const override;
589 template <
typename T,
typename elemT>
591 #if !defined(DISABLE_OPTIONAL_TYPE)
594 public DisabledTypeBase
600 #if !defined(DISABLE_OPTIONAL_TYPE)
602 "Requires one of the supported types: Tensor or TensorSeq");
605 "Requires one of the tensor fundamental types");
608 return data_types_internal::OptionalTypeHelper::GetElemType<T, elemT>();
613 #if !defined(DISABLE_OPTIONAL_TYPE)
619 using namespace data_types_internal;
652 const ONNX_NAMESPACE::TypeProto*
GetTypeProto()
const override;
664 virtual void FromDataContainer(
const void*
data,
size_t data_size,
OrtValue& output)
const;
674 virtual void ToDataContainer(
const OrtValue& input,
size_t data_size,
void* data)
const;
685 bool IsMapCompatible(
const ONNX_NAMESPACE::TypeProto& type_proto)
const;
687 bool IsSequenceCompatible(
const ONNX_NAMESPACE::TypeProto& type_proto)
const;
689 bool IsOpaqueCompatible(
const ONNX_NAMESPACE::TypeProto& type_proto)
const;
697 template <
typename T>
700 static void Delete(
void* p) {
701 delete static_cast<T*
>(p);
710 return []() ->
void* {
return new T(); };
717 #if !defined(DISABLE_ML_OPS)
728 template <
typename CPPType>
732 "Requires one of the tensor fundamental types as key");
736 bool IsCompatible(
const ONNX_NAMESPACE::TypeProto& type_proto)
const override {
737 return this->IsMapCompatible(type_proto);
742 using namespace data_types_internal;
743 MapTypeHelper::Set(utils::ToTensorProtoElementType<typename CPPType::key_type>(),
744 MapTypeHelper::GetValueType<typename CPPType::mapped_type>()->
GetTypeProto(),
759 template <
typename CPPType>
764 bool IsCompatible(
const ONNX_NAMESPACE::TypeProto& type_proto)
const override {
765 return this->IsSequenceCompatible(type_proto);
770 using namespace data_types_internal;
771 SequenceTypeHelper::Set(SequenceTypeHelper::GetElemType<typename CPPType::value_type>()->
GetTypeProto(),
785 bool IsCompatible(
const ONNX_NAMESPACE::TypeProto& type_proto)
const override;
794 const ONNX_NAMESPACE::TypeProto*
GetTypeProto()
const override;
809 #if defined(_MSC_VER) && !defined(__clang__)
822 template <
typename TensorElemType>
826 "Requires one of the tensor fundamental types");
832 return DataTypeImpl::GetType<TensorElemType>();
837 using namespace data_types_internal;
838 SequenceTypeHelper::Set(SequenceTypeHelper::GetElemType<TensorElemType>()->
GetTypeProto(),
856 template <
typename T, const
char D[], const
char N[]>
861 bool IsCompatible(
const ONNX_NAMESPACE::TypeProto& type_proto)
const override {
862 return this->IsOpaqueCompatible(type_proto);
906 const int32_t data_type_;
917 template <
typename T>
920 static void Delete(
void* p) {
921 delete static_cast<T*
>(p);
934 utils::ToTensorProtoElementType<T>()} {
946 #if !defined(DISABLE_SPARSE_TENSORS)
952 #if !defined(DISABLE_OPTIONAL_TYPE)
971 #define ORT_REGISTER_TENSOR_TYPE(ELEM_TYPE) \
973 MLDataType TensorType<ELEM_TYPE>::Type() { \
974 static TensorType<ELEM_TYPE> tensor_type; \
975 return &tensor_type; \
978 MLDataType DataTypeImpl::GetTensorType<ELEM_TYPE>() { \
979 return TensorType<ELEM_TYPE>::Type(); \
982 #if !defined(DISABLE_SPARSE_TENSORS)
983 #define ORT_REGISTER_SPARSE_TENSOR_TYPE(ELEM_TYPE) \
985 MLDataType SparseTensorType<ELEM_TYPE>::Type() { \
986 static SparseTensorType<ELEM_TYPE> tensor_type; \
987 return &tensor_type; \
990 MLDataType DataTypeImpl::GetSparseTensorType<ELEM_TYPE>() { \
991 return SparseTensorType<ELEM_TYPE>::Type(); \
995 #define ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, TYPE) \
997 MLDataType OptionalType<ORT_TYPE, TYPE>::Type() { \
998 static OptionalType<ORT_TYPE, TYPE> optional_type; \
999 return &optional_type; \
1002 MLDataType DataTypeImpl::GetOptionalType<ORT_TYPE, TYPE>() { \
1003 return OptionalType<ORT_TYPE, TYPE>::Type(); \
1006 #if !defined(DISABLE_ML_OPS)
1007 #define ORT_REGISTER_MAP(TYPE) \
1009 MLDataType MapType<TYPE>::Type() { \
1010 static MapType<TYPE> map_type; \
1014 MLDataType DataTypeImpl::GetType<TYPE>() { \
1015 return MapType<TYPE>::Type(); \
1019 #define ORT_REGISTER_SEQ(TYPE) \
1021 MLDataType SequenceType<TYPE>::Type() { \
1022 static SequenceType<TYPE> sequence_type; \
1023 return &sequence_type; \
1026 MLDataType DataTypeImpl::GetType<TYPE>() { \
1027 return SequenceType<TYPE>::Type(); \
1030 #define ORT_REGISTER_SEQ_TENSOR_TYPE(ELEM_TYPE) \
1032 MLDataType SequenceTensorType<ELEM_TYPE>::Type() { \
1033 static SequenceTensorType<ELEM_TYPE> sequence_tensor_type; \
1034 return &sequence_tensor_type; \
1037 MLDataType DataTypeImpl::GetSequenceTensorType<ELEM_TYPE>() { \
1038 return SequenceTensorType<ELEM_TYPE>::Type(); \
1041 #define ORT_REGISTER_PRIM_TYPE(TYPE) \
1043 MLDataType PrimitiveDataType<TYPE>::Type() { \
1044 static PrimitiveDataType<TYPE> prim_data_type; \
1045 return &prim_data_type; \
1048 MLDataType DataTypeImpl::GetType<TYPE>() { \
1049 return PrimitiveDataType<TYPE>::Type(); \
1052 #define ORT_REGISTER_OPAQUE_TYPE(CPPType, Domain, Name) \
1054 MLDataType OpaqueType<CPPType, Domain, Name>::Type() { \
1055 static OpaqueType<CPPType, Domain, Name> opaque_type; \
1056 return &opaque_type; \
1059 MLDataType DataTypeImpl::GetType<CPPType>() { \
1060 return OpaqueType<CPPType, Domain, Name>::Type(); \
void AssignOpaqueDomainName(const char *domain, const char *name, ONNX_NAMESPACE::TypeProto &proto)
OpaqueTypes helpers.
static const std::vector< MLDataType > & AllFixedSizeTensorExceptHalfTypes()
std::vector< int64_t > VectorInt64
static void RegisterDataType(MLDataType)
static const TensorTypeBase * TensorTypeFromONNXEnum(int type)
virtual const ONNX_NAMESPACE::TypeProto * GetTypeProto() const =0
Retrieves an instance of TypeProto for a given MLDataType.
static const char * ToString(MLDataType type)
bool IsTensorSequenceType() const
Base class for MLDataType.
bool IsCompatible(const ONNX_NAMESPACE::TypeProto &type_proto) const override
this API will be used to check type compatibility at runtime
virtual MLDataType GetElementType() const
static void ToContainer(const OrtValue &, size_t, void *)
void FromDataContainer(const void *data, size_t data_size, OrtValue &output) const override
SequenceTensorTypeBase serves as a base type class for Tensor sequences. Akin to TensorTypeBase. Runtime representation is always TensorSeq.
ONNX_NAMESPACE::TypeProto & MutableTypeProto()
static void FromContainer(MLDataType, const void *, size_t, OrtValue &)
GLsizei const GLchar *const * string
void ToDataContainer(const OrtValue &input, size_t data_size, void *data) const override
GLsizei const GLfloat * value
std::map< std::string, float > MapStringToFloat
virtual bool IsCompatible(const ONNX_NAMESPACE::TypeProto &type_proto) const =0
this API will be used to check type compatibility at runtime
static MLDataType GetTensorType()
std::map< int64_t, int64_t > MapInt64ToInt64
MapType. Use this type to register mapping types.
void CopyMutableSeqElement(const ONNX_NAMESPACE::TypeProto &, ONNX_NAMESPACE::TypeProto &)
Sequence helpers.
#define ORT_ENFORCE(condition,...)
MLDataType GetElementType() const override
Return a MLDataType representing the element-type.
static MLDataType GetElemType()
CreateFunc GetCreateFunc() const override
static const std::vector< MLDataType > & AllOptionalTypes()
const ONNX_NAMESPACE::TypeProto * GetTypeProto() const final
Retrieves an instance of TypeProto for a given MLDataType.
DeleteFunc GetDeleteFunc() const override
virtual MLDataType GetElementType() const
Common base-class for all sparse-tensors (with different element types).
static MLDataType GetSequenceTensorType()
int32_t GetDataType() const
void CopyMutableOptionalElement(const ONNX_NAMESPACE::TypeProto &, ONNX_NAMESPACE::TypeProto &)
Optional helpers.
const ONNX_NAMESPACE::TypeProto * GetTypeProto() const override
Retrieves an instance of TypeProto for a given MLDataType.
static const std::vector< MLDataType > & AllIEEEFloatTensorTypes()
static const std::vector< MLDataType > & AllFixedSizeTensorTypes()
static void Set(const onnx::TypeProto *elem_proto, ONNX_NAMESPACE::TypeProto &proto)
static const std::vector< MLDataType > & AllTensorTypes()
const SparseTensorTypeBase * AsSparseTensorType() const
static MLDataType GetSparseTensorType()
PrimitiveDataTypeBase(size_t size, int32_t data_type)
static const std::vector< MLDataType > & AllTensorAndSequenceTensorTypes()
static const std::vector< MLDataType > & AllSequenceTensorTypes()
std::map< int64_t, float > MapInt64ToFloat
void(*)(void *) DeleteFunc
bool IsCompatible(const ONNX_NAMESPACE::TypeProto &type_proto) const override
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TensorTypeBase)
static void Set(const ONNX_NAMESPACE::TypeProto *elem_proto, ONNX_NAMESPACE::TypeProto &proto)
virtual MLDataType GetElementType() const
static const std::vector< MLDataType > & AllFixedSizeSequenceTensorTypes()
std::map< std::string, int64_t > MapStringToInt64
const NonTensorTypeBase * AsNonTensorType() const
Common base-class for all optional types.
PrimitiveDataType Typed specialization for primitive types. Concrete instances of this class are used...
bool IsNonTensorType() const
const TensorTypeBase * AsTensorType() const
static const std::vector< MLDataType > & AllIEEEFloatTensorExceptHalfTypes()
const PrimitiveDataTypeBase * AsPrimitiveDataType() const
bool IsTensorType() const
MLDataType GetElementType() const override
static MLDataType GetDataType(const std::string &)
DeleteFunc GetDeleteFunc() const override
STATIC_INLINE uint64_t H(uint64_t x, uint64_t y, uint64_t mul, int r)
PrimitiveDataTypeBase Base class for primitive Tensor contained types.
GLuint const GLchar * name
std::map< std::string, std::string > MapStringToString
Predefined registered types.
static MLDataType GetValueType()
static const SequenceTensorTypeBase * SequenceTensorTypeFromONNXEnum(int type)
void CopyMutableMapValue(const ONNX_NAMESPACE::TypeProto &, ONNX_NAMESPACE::TypeProto &)
Map helpers.
static MLDataType GetOptionalType()
const DataTypeImpl * MLDataType
static const std::vector< MLDataType > & AllTensorAndSequenceTensorAndOptionalTypes()
bool IsCompatible(const ONNX_NAMESPACE::TypeProto &type_proto) const override
this API will be used to check type compatibility at runtime
std::map< int64_t, double > MapInt64ToDouble
static const std::vector< MLDataType > & AllNumericTensorTypes()
virtual MLDataType GetElementType() const
std::map< int64_t, std::string > MapInt64ToString
SequenceType. Use to register sequence for non-tensor types.
bool IsCompatible(const ONNX_NAMESPACE::TypeProto &type_proto) const override
this API will be used to check type compatibility at runtime
const OptionalTypeBase * AsOptionalType() const
SequenceTensorType. Use to register sequence for non-tensor types.
DataTypeImpl(GeneralType type, size_t size)
bool IsSparseTensorType() const
MLDataType GetElementType() const override
virtual ~DataTypeImpl()=default
std::map< std::string, double > MapStringToDouble
GA_API const UT_StringHolder N
static void Set(ONNX_NAMESPACE::TensorProto_DataType element_type, ONNX_NAMESPACE::TypeProto &proto)
bool IsCompatible(const ONNX_NAMESPACE::TypeProto &) const override
this API will be used to check type compatibility at runtime
static const std::vector< MLDataType > & AllFixedSizeTensorAndSequenceTensorTypes()
Tensor type. This type does not have a C++ type associated with it at registration time except the el...
DeleteFunc GetDeleteFunc() const override
DeleteFunc GetDeleteFunc() const override
bool IsPrimitiveDataType() const
GLsizei GLenum GLenum * types
std::ostream & operator<<(std::ostream &out, AllocKind alloc_kind)
Base type for all non-tensors, maps, sequences and opaques.
const SequenceTensorTypeBase * AsSequenceTensorType() const
std::vector< std::string > VectorString
static MLDataType TypeFromProto(const ONNX_NAMESPACE::TypeProto &proto)
static void Set(ONNX_NAMESPACE::TensorProto_DataType element_type, ONNX_NAMESPACE::TypeProto &proto)
static void Set(ONNX_NAMESPACE::TensorProto_DataType key_type, const ONNX_NAMESPACE::TypeProto *value_proto, ONNX_NAMESPACE::TypeProto &proto)
std::vector< MapInt64ToFloat > VectorMapInt64ToFloat
static MLDataType GetElemType()
~TensorTypeBase() override
static const SparseTensorTypeBase * SparseTensorTypeFromONNXEnum(int type)
std::vector< MapStringToFloat > VectorMapStringToFloat
virtual DeleteFunc GetDeleteFunc() const =0
Provide a specialization for your C++ Non-tensor type so your implementation FromDataTypeContainer/To...
bool IsOptionalType() const
MLDataType GetElementType() const override
Return a MLDataType representing the element-type.
static MLDataType GetType()