HDK
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
data_types.h
Go to the documentation of this file.
1 // Copyright (c) Microsoft Corporation. All rights reserved.
2 // Licensed under the MIT License.
3 
4 #pragma once
5 
#include <cstdint>
#include <cstring>
#include <iosfwd>
#include <map>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <vector>

#include "core/common/gsl.h"
#include "core/common/common.h"
#include "core/common/exceptions.h"
#include "core/framework/endian.h"
#include "core/framework/float16.h"

#if !defined(ORT_MINIMAL_BUILD)
#include "onnx/defs/schema.h"
#else
#include "onnx/defs/data_type_utils.h"
#endif
#include "onnx/onnx_pb.h"
#include "onnx/onnx-operators_pb.h"
25 
26 struct OrtValue;
27 
28 namespace ONNX_NAMESPACE {
29 class TypeProto;
30 } // namespace ONNX_NAMESPACE
31 
32 namespace onnxruntime {
33 /// Predefined registered types
34 
#ifndef DISABLE_ML_OPS
// Map and sequence-of-map types. They are only used by ML ops and are
// compiled out together with them when DISABLE_ML_OPS is defined.

// Maps keyed by std::string.
using MapStringToString = std::map<std::string, std::string>;
using MapStringToInt64 = std::map<std::string, int64_t>;
using MapStringToFloat = std::map<std::string, float>;
using MapStringToDouble = std::map<std::string, double>;

// Maps keyed by int64_t.
using MapInt64ToString = std::map<int64_t, std::string>;
using MapInt64ToInt64 = std::map<int64_t, int64_t>;
using MapInt64ToFloat = std::map<int64_t, float>;
using MapInt64ToDouble = std::map<int64_t, double>;

// Sequences whose elements are the float-valued maps above.
using VectorMapStringToFloat = std::vector<MapStringToFloat>;
using VectorMapInt64ToFloat = std::vector<MapInt64ToFloat>;
#endif  // DISABLE_ML_OPS

// Sequence types that exist in every build configuration.
using VectorString = std::vector<std::string>;
using VectorInt64 = std::vector<int64_t>;
55 
56 // Forward declarations
57 class DataTypeImpl;
58 class TensorTypeBase;
59 #if !defined(DISABLE_SPARSE_TENSORS)
61 #endif
63 class NonTensorTypeBase;
64 #if !defined(DISABLE_OPTIONAL_TYPE)
65 class OptionalTypeBase;
66 #endif
68 class Tensor;
69 class TensorSeq;
70 
71 // DataTypeImpl pointer as unique DataTypeImpl identifier.
72 using MLDataType = const DataTypeImpl*;
73 // be used with class MLValue
74 using DeleteFunc = void (*)(void*);
75 using CreateFunc = void* (*)();
76 
77 /**
78  * \brief Base class for MLDataType
79  *
80  */
81 class DataTypeImpl {
82  public:
83  enum class GeneralType {
84  kInvalid = 0,
85  kNonTensor = 1,
86  kTensor = 2,
87  kTensorSequence = 3,
88  kSparseTensor = 4,
89  kOptional = 5,
90  kPrimitive = 6,
91  };
92 
94  const size_t size_;
95 
96  protected:
98 
99  public:
100  virtual ~DataTypeImpl() = default;
101 
102  /**
103  * \brief this API will be used to check type compatibility at runtime
104  *
105  * \param type_proto a TypeProto instance that is constructed for a specific type
106  * will be checked against a TypeProto instance contained within a corresponding
107  * MLDataType instance.
108  */
109  virtual bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const = 0;
110 
111  size_t Size() const { return size_; }
112 
113  virtual DeleteFunc GetDeleteFunc() const = 0;
114 
115  /**
116  * \brief Retrieves an instance of TypeProto for
117  * a given MLDataType
118  * \returns optional TypeProto. Only ONNX types
119  has type proto, non-ONNX types will return nullptr.
120  */
121  virtual const ONNX_NAMESPACE::TypeProto* GetTypeProto() const = 0;
122 
123  bool IsTensorType() const {
124  return type_ == GeneralType::kTensor;
125  }
126 
127  bool IsTensorSequenceType() const {
129  }
130 
131  bool IsSparseTensorType() const {
133  }
134 
135  bool IsOptionalType() const {
136  return type_ == GeneralType::kOptional;
137  }
138 
139  bool IsNonTensorType() const {
140  return type_ == GeneralType::kNonTensor;
141  }
142 
143  bool IsPrimitiveDataType() const {
144  return type_ == GeneralType::kPrimitive;
145  }
146 
147  // Returns this if this is of tensor-type and null otherwise
148  const TensorTypeBase* AsTensorType() const;
149 
151 
152 #if !defined(DISABLE_SPARSE_TENSORS)
153  // Returns this if this is of sparse-tensor-type and null otherwise
155 #endif
156 
157 #if !defined(DISABLE_OPTIONAL_TYPE)
158  const OptionalTypeBase* AsOptionalType() const;
159 #endif
160 
161  const NonTensorTypeBase* AsNonTensorType() const;
162 
163  // Returns this if this is one of the primitive data types (specialization of PrimitiveDataTypeBase)
164  // and null otherwise
166 
167  // Return the type meta that we are using in the runtime.
168  template <typename T>
169  static MLDataType GetType();
170 
171  // Return the types for a concrete tensor type, like Tensor_Float
172  template <typename elemT>
173  static MLDataType GetTensorType();
174 
175  template <typename elemT>
177 
178 #if !defined(DISABLE_SPARSE_TENSORS)
179  // Return the MLDataType for a concrete sparse tensor type.
180  template <typename elemT>
182 #endif
183 
184  template <typename T, typename elemT>
185  static MLDataType GetOptionalType();
186 
187  /**
188  * Convert an ONNX TypeProto to onnxruntime DataTypeImpl.
189  * However, this conversion is lossy. Don't try to use 'this->GetTypeProto()' converting it back.
190  * Even though GetTypeProto() will not have the original information, it will still have enough to correctly
191  * map to MLDataType.
192  * \param proto
193  */
194  static MLDataType TypeFromProto(const ONNX_NAMESPACE::TypeProto& proto);
195 
196  static const TensorTypeBase* TensorTypeFromONNXEnum(int type);
198 #if !defined(DISABLE_SPARSE_TENSORS)
200 #endif
201 
202  static const char* ToString(MLDataType type);
203  static std::vector<std::string> ToString(const std::vector<MLDataType>& types);
204  // Registers ONNX_NAMESPACE::DataType (internalized string) with
205  // MLDataType. DataType is produced by internalizing an instance of
206  // TypeProto contained within MLDataType
207  static void RegisterDataType(MLDataType);
208  static MLDataType GetDataType(const std::string&);
209 
210  static const std::vector<MLDataType>& AllTensorTypes();
211  static const std::vector<MLDataType>& AllFixedSizeTensorTypes();
212  static const std::vector<MLDataType>& AllSequenceTensorTypes();
213  static const std::vector<MLDataType>& AllFixedSizeSequenceTensorTypes();
214  static const std::vector<MLDataType>& AllNumericTensorTypes();
215  static const std::vector<MLDataType>& AllIEEEFloatTensorTypes();
216  static const std::vector<MLDataType>& AllFixedSizeTensorExceptHalfTypes();
217  static const std::vector<MLDataType>& AllIEEEFloatTensorExceptHalfTypes();
218  static const std::vector<MLDataType>& AllTensorAndSequenceTensorTypes();
219  static const std::vector<MLDataType>& AllFixedSizeTensorAndSequenceTensorTypes();
220  static const std::vector<MLDataType>& AllOptionalTypes();
221  static const std::vector<MLDataType>& AllTensorAndSequenceTensorAndOptionalTypes();
222 };
223 
224 std::ostream& operator<<(std::ostream& out, MLDataType data_type);
225 
226 /*
227  * Type registration helpers
228  */
229 namespace data_types_internal {
230 /// TensorType helpers
231 ///
232 
/// Is a given type on the list of types?
/// The first template argument is the type being checked; the remaining
/// arguments are the list it is searched in.
///
/// IsAnyOf<T, Types...>::value is true when T is exactly (cv-qualifiers
/// included) one of Types..., and false otherwise, including for an empty list.
///
/// Built on std::disjunction, so every instantiation derives from an
/// std::integral_constant<bool, ...> and exposes ::value, ::type and a bool
/// conversion regardless of list length. Instantiation of the remaining
/// std::is_same<> checks stops at the first match.
template <typename T, typename... Types>
struct IsAnyOf : public std::disjunction<std::is_same<T, Types>...> {
};
249 
250 /// Tells if the specified type is one of fundamental types
251 /// that can be contained within a tensor.
252 /// We do not have raw fundamental types, rather a subset
253 /// of fundamental types is contained within tensors.
254 template <typename T>
255 struct IsTensorContainedType : public IsAnyOf<T, float, uint8_t, int8_t, uint16_t, int16_t,
256  int32_t, int64_t, std::string, bool, MLFloat16,
257  double, uint32_t, uint64_t, BFloat16> {
258 };
259 
260 #if !defined(DISABLE_SPARSE_TENSORS)
261 /// Use "IsSparseTensorContainedType<T>::value" to test if a type T
262 /// is permitted as the element-type of a sparse-tensor.
263 
264 template <typename T>
265 struct IsSparseTensorContainedType : public IsAnyOf<T, float, uint8_t, int8_t, uint16_t, int16_t,
266  int32_t, int64_t, std::string, bool, MLFloat16,
267  double, uint32_t, uint64_t, BFloat16> {
268 };
269 #endif
270 
271 #if !defined(DISABLE_OPTIONAL_TYPE)
272 /// Tells if the specified type is one of ORT types
273 /// that can be contained within an optional struct.
274 template <typename T>
275 struct IsOptionalOrtType : public IsAnyOf<T, Tensor, TensorSeq> {
276 };
277 #endif
278 
279 /// This template's Get() returns a corresponding MLDataType
280 /// It dispatches the call to either GetTensorType<>() or
281 /// GetType<>()
282 template <typename T, bool TensorContainedType>
284 
285 template <typename T>
286 struct GetMLDataType<T, true> {
287  static MLDataType Get() {
288  return DataTypeImpl::GetTensorType<T>();
289  }
290 };
291 
292 template <typename T>
293 struct GetMLDataType<T, false> {
294  static MLDataType Get() {
295  return DataTypeImpl::GetType<T>();
296  }
297 };
298 
300  static void Set(ONNX_NAMESPACE::TensorProto_DataType element_type,
301  ONNX_NAMESPACE::TypeProto& proto) {
302  proto.mutable_tensor_type()->set_elem_type(element_type);
303  }
304 };
305 
306 #if !defined(DISABLE_SPARSE_TENSORS)
308  static void Set(ONNX_NAMESPACE::TensorProto_DataType element_type,
309  ONNX_NAMESPACE::TypeProto& proto) {
310  proto.mutable_sparse_tensor_type()->set_elem_type(element_type);
311  }
312 };
313 #endif // !defined(DISABLE_SPARSE_TENSORS)
314 
315 #if !defined(DISABLE_ML_OPS)
316 /// Map helpers
317 
318 void CopyMutableMapValue(const ONNX_NAMESPACE::TypeProto&,
319  ONNX_NAMESPACE::TypeProto&);
320 
322  // V can be either a primitive type (in which case it is a tensor)
323  // or other preregistered types
324  template <typename V>
327  }
328 
329  static void Set(ONNX_NAMESPACE::TensorProto_DataType key_type, const ONNX_NAMESPACE::TypeProto* value_proto,
330  ONNX_NAMESPACE::TypeProto& proto) {
331  ORT_ENFORCE(value_proto != nullptr, "expected a registered ONNX type");
332  proto.mutable_map_type()->set_key_type(key_type);
333  CopyMutableMapValue(*value_proto, proto);
334  }
335 };
336 #endif
337 
338 /// Sequence helpers
339 
340 // Element type is a primitive type so we set it to a tensor<elemT>
341 void CopyMutableSeqElement(const ONNX_NAMESPACE::TypeProto&,
342  ONNX_NAMESPACE::TypeProto&);
343 
344 // helper to create TypeProto with minimal binary size impact
346  template <typename T>
349  }
350 
351  static void Set(const ONNX_NAMESPACE::TypeProto* elem_proto,
352  ONNX_NAMESPACE::TypeProto& proto) {
353  ORT_ENFORCE(elem_proto != nullptr, "expected a registered ONNX type");
354  CopyMutableSeqElement(*elem_proto, proto);
355  }
356 };
357 
358 /// Optional helpers
359 
360 void CopyMutableOptionalElement(const ONNX_NAMESPACE::TypeProto&,
361  ONNX_NAMESPACE::TypeProto&);
362 
363 // helper to create TypeProto with minimal binary size impact
365  template <typename T, typename elemT>
367  if constexpr (std::is_same<T, Tensor>::value) {
368  return DataTypeImpl::GetTensorType<elemT>();
369  } else {
370  static_assert(std::is_same<T, TensorSeq>::value, "Unsupported element type for optional type");
371  return DataTypeImpl::GetSequenceTensorType<elemT>();
372  }
373  }
374 
375  static void Set(const onnx::TypeProto* elem_proto, ONNX_NAMESPACE::TypeProto& proto) {
376  ORT_ENFORCE(elem_proto != nullptr, "expected a registered ONNX type");
377  CopyMutableOptionalElement(*elem_proto, proto);
378  }
379 };
380 
381 /// OpaqueTypes helpers
382 
383 void AssignOpaqueDomainName(const char* domain, const char* name,
384  ONNX_NAMESPACE::TypeProto& proto);
385 
386 } // namespace data_types_internal
387 
388 //The suppressed warning is: "The type with a virtual function needs either public virtual or protected nonvirtual destructor."
389 //However, we do not allocate this type on heap.
390 #if defined(_MSC_VER) && !defined(__clang__)
391 #pragma warning(push)
392 #pragma warning(disable : 26436)
393 #endif
394 /// All tensors base
395 class TensorTypeBase : public DataTypeImpl {
396  public:
397  static MLDataType Type();
398 
399  /// We first compare type_proto pointers and then
400  /// if they do not match try to account for the case
401  /// where TypeProto was created ad-hoc and not queried from MLDataType
402  bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override;
403 
404  DeleteFunc GetDeleteFunc() const override;
405 
406  const ONNX_NAMESPACE::TypeProto* GetTypeProto() const override;
407 
408  virtual MLDataType GetElementType() const {
409  // should never reach here.
410  ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented");
411  }
412 
414 
415  protected:
416  ONNX_NAMESPACE::TypeProto& MutableTypeProto();
417 
418  TensorTypeBase();
419  ~TensorTypeBase() override;
420 
421  private:
422  struct Impl;
423  Impl* impl_;
424 };
425 
426 /**
427  * \brief Tensor type. This type does not have a C++ type associated with
428  * it at registration time except the element type. One of the types mentioned
429  * above at IsTensorContainedType<> list is acceptable.
430  *
431  * \details
432  * Usage:
433  * ORT_REGISTER_TENSOR(ELEMENT_TYPE)
434  * Currently all of the Tensors irrespective of the dimensions are mapped to Tensor<type>
435  * type. IsCompatible() currently ignores shape.
436  */
437 
438 template <typename elemT>
439 class TensorType : public TensorTypeBase {
440  public:
442  "Requires one of the tensor fundamental types");
443 
444  static MLDataType Type();
445 
446  /// Tensors only can contain basic data types
447  /// that have been previously registered with ONNXRuntime
448  MLDataType GetElementType() const override {
449  return DataTypeImpl::GetType<elemT>();
450  }
451 
452  private:
453  TensorType() {
454  using namespace data_types_internal;
455  TensorTypeHelper::Set(utils::ToTensorProtoElementType<elemT>(), MutableTypeProto());
456  }
457 };
458 
#ifdef DISABLE_OPTIONAL_TYPE

// TODO is this still needed after removing kernel def hashes?
/// Shared base class for every type that is compiled out of this build.
/// DataTypeImpl::ToString must keep working for such types in a minimal build
/// so that the kernel hashes of ORT format models stay stable.
class DisabledTypeBase : public DataTypeImpl {
 public:
  static MLDataType Type();

  /// Always false for a disabled type. No kernel supporting the disabled type
  /// can then be matched to a model node that requires it, so loading such a
  /// model fails.
  bool IsCompatible(const ONNX_NAMESPACE::TypeProto& /*type_proto*/) const override {
    return false;
  }

  /// Always throws: deleting values is not supported for a disabled type.
  DeleteFunc GetDeleteFunc() const override {
    ORT_THROW("Type is disabled in this build.");
  }

  /// Must stay functional even though the type is disabled.
  const ONNX_NAMESPACE::TypeProto* GetTypeProto() const override;

  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(DisabledTypeBase);

 protected:
  /// Must stay functional even though the type is disabled.
  ONNX_NAMESPACE::TypeProto& MutableTypeProto();

  DisabledTypeBase(DataTypeImpl::GeneralType type, size_t size);
  ~DisabledTypeBase() override;

 private:
  struct Impl;
  Impl* impl_;
};

#endif  // DISABLE_OPTIONAL_TYPE
498 
499 #if !defined(DISABLE_SPARSE_TENSORS)
500 /// Common base-class for all sparse-tensors (with different element types).
502  public:
503  static MLDataType Type();
504 
505  bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override;
506 
507  DeleteFunc GetDeleteFunc() const override;
508 
509  const ONNX_NAMESPACE::TypeProto* GetTypeProto() const override;
510 
511  virtual MLDataType GetElementType() const {
512  // should never reach here.
513  ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented");
514  }
515 
517 
518  protected:
519  ONNX_NAMESPACE::TypeProto& MutableTypeProto();
520 
522  ~SparseTensorTypeBase() override;
523 
524  private:
525  struct Impl;
526  Impl* impl_;
527 };
528 
529 template <typename elemT>
531  public:
533  "Requires one of the sparse-tensor fundamental types");
534 
535  static MLDataType Type();
536 
537  /// Return a MLDataType representing the element-type
538  MLDataType GetElementType() const override {
539  return DataTypeImpl::GetType<elemT>();
540  }
541 
542  private:
543  SparseTensorType() {
544  using namespace data_types_internal;
545  SparseTensorTypeHelper::Set(utils::ToTensorProtoElementType<elemT>(), MutableTypeProto());
546  }
547 };
548 
549 #endif // !defined(DISABLE_SPARSE_TENSORS)
550 
551 /// Common base-class for all optional types.
552 
553 #if !defined(DISABLE_OPTIONAL_TYPE)
555  public:
556  static MLDataType Type();
557 
558  bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override;
559 
560  DeleteFunc GetDeleteFunc() const override {
561  // should never reach here.
562  ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented");
563  }
564 
565  const ONNX_NAMESPACE::TypeProto* GetTypeProto() const override;
566 
567  virtual MLDataType GetElementType() const {
568  // should never reach here.
569  ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented");
570  }
571 
572  OptionalTypeBase(const OptionalTypeBase&) = delete;
573  OptionalTypeBase& operator=(const OptionalTypeBase&) = delete;
574 
575  protected:
576  ONNX_NAMESPACE::TypeProto& MutableTypeProto();
577 
579  ~OptionalTypeBase() override;
580 
581  private:
582  struct Impl;
583  Impl* impl_;
584 };
585 #endif
586 
587 // Derive from OptionalTypeBase if the Optional type support is enabled,
588 // else derive from DisabledTypeBase
589 template <typename T, typename elemT>
591 #if !defined(DISABLE_OPTIONAL_TYPE)
592  public OptionalTypeBase
593 #else
594  public DisabledTypeBase
595 #endif
596 {
597  public:
598  static MLDataType Type();
599 
600 #if !defined(DISABLE_OPTIONAL_TYPE)
602  "Requires one of the supported types: Tensor or TensorSeq");
603 
605  "Requires one of the tensor fundamental types");
606 
607  MLDataType GetElementType() const override {
608  return data_types_internal::OptionalTypeHelper::GetElemType<T, elemT>();
609  }
610 #endif
611 
612  private:
613 #if !defined(DISABLE_OPTIONAL_TYPE)
614  OptionalType()
615 #else
616  OptionalType() : DisabledTypeBase { DataTypeImpl::GeneralType::kOptional, 0 }
617 #endif
618  {
619  using namespace data_types_internal;
620  OptionalTypeHelper::Set(OptionalTypeHelper::GetElemType<T, elemT>()->GetTypeProto(), MutableTypeProto());
621  }
622 }; // namespace onnxruntime
623 
624 /**
625  * \brief Provide a specialization for your C++ Non-tensor type
626  * so your implementation FromDataTypeContainer/ToDataTypeContainer
627  * functions correctly. Otherwise you get a default implementation
628  * which may not be what you need/want.
629  *
630  * This class is used to create OrtValue, fetch data from OrtValue via
631  * C/C++ APIs
632  */
633 template <class T>
635  static void FromContainer(MLDataType /*dtype*/, const void* /*data*/, size_t /*data_size*/, OrtValue& /*output*/) {
636  ORT_THROW("Not implemented");
637  }
638  static void ToContainer(const OrtValue& /*input*/, size_t /*data_size*/, void* /*data*/) {
639  ORT_THROW("Not implemented");
640  }
641 };
642 
643 /**
644  * \brief Base type for all non-tensors, maps, sequences and opaques
645  */
647  public:
648  DeleteFunc GetDeleteFunc() const override = 0;
649 
650  virtual CreateFunc GetCreateFunc() const = 0;
651 
652  const ONNX_NAMESPACE::TypeProto* GetTypeProto() const override;
653 
654  // \brief Override for Non-tensor types to initialize non-tensor CPP
655  // data representation from data. The caller of the interface
656  // should have a shared definition of the data which is used to initialize
657  // CPP data representation. This is used from C API.
658  //
659  // \param data - pointer to a data container structure non_tensor type specific
660  // \param data_size - size of the data container structure, used for rudimentary checks
661  // \param output - reference to a default constructed non-tensor type
662  // \returns OrtValue
663  // \throw if there is an error
664  virtual void FromDataContainer(const void* data, size_t data_size, OrtValue& output) const;
665 
666  // \brief Override for Non-tensor types to fetch data from the internal CPP data representation
667  // The caller of the interface should have a shared definition of the data which is used to initialize
668  // CPP data representation. This is used from C API.
669  //
670  // \param input - OrtValue containing data
671  // \param data_size - size of the structure that is being passed for receiving data, used for
672  // validation
673  // \param data - pointer to receiving data structure
674  virtual void ToDataContainer(const OrtValue& input, size_t data_size, void* data) const;
675 
676  NonTensorTypeBase(const NonTensorTypeBase&) = delete;
677  NonTensorTypeBase& operator=(const NonTensorTypeBase&) = delete;
678 
679  protected:
680  NonTensorTypeBase(size_t size);
681  ~NonTensorTypeBase() override;
682 
683  ONNX_NAMESPACE::TypeProto& MutableTypeProto();
684 
685  bool IsMapCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const;
686 
687  bool IsSequenceCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const;
688 
689  bool IsOpaqueCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const;
690 
691  private:
692  struct Impl;
693  Impl* impl_;
694 };
695 
696 // This is where T is the actual CPPRuntimeType
697 template <typename T>
699  private:
700  static void Delete(void* p) {
701  delete static_cast<T*>(p);
702  }
703 
704  public:
705  DeleteFunc GetDeleteFunc() const override {
706  return &Delete;
707  }
708 
709  CreateFunc GetCreateFunc() const override {
710  return []() -> void* { return new T(); };
711  }
712 
713  protected:
715 };
716 
717 #if !defined(DISABLE_ML_OPS)
718 /**
719  * \brief MapType. Use this type to register
720  * mapping types.
721  *
722  * \param T - cpp type that you wish to register as runtime MapType
723  *
724  * \details Usage: ORT_REGISTER_MAP(C++Type)
725  * The type is required to have mapped_type and
726  * key_type defined
727  */
728 template <typename CPPType>
729 class MapType : public NonTensorType<CPPType> {
730  public:
732  "Requires one of the tensor fundamental types as key");
733 
734  static MLDataType Type();
735 
736  bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override {
737  return this->IsMapCompatible(type_proto);
738  }
739 
740  private:
741  MapType() {
742  using namespace data_types_internal;
743  MapTypeHelper::Set(utils::ToTensorProtoElementType<typename CPPType::key_type>(),
744  MapTypeHelper::GetValueType<typename CPPType::mapped_type>()->GetTypeProto(),
745  this->MutableTypeProto());
746  }
747 };
748 #endif
749 
750 /**
751  * \brief SequenceType. Use to register sequence for non-tensor types.
752  *
753  * \param T - CPP type that you wish to register as Sequence
754  * runtime type.
755  *
756  * \details Usage: ORT_REGISTER_SEQ(C++Type)
757  * The type is required to have value_type defined
758  */
759 template <typename CPPType>
760 class SequenceType : public NonTensorType<CPPType> {
761  public:
762  static MLDataType Type();
763 
764  bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override {
765  return this->IsSequenceCompatible(type_proto);
766  }
767 
768  private:
769  SequenceType() {
770  using namespace data_types_internal;
771  SequenceTypeHelper::Set(SequenceTypeHelper::GetElemType<typename CPPType::value_type>()->GetTypeProto(),
772  this->MutableTypeProto());
773  }
774 };
775 
776 /**
777  * \brief SequenceTensorTypeBase serves as a base type class for
778  * Tensor sequences. Akin to TensorTypeBase.
779  * Runtime representation is always TensorSeq.
780  */
782  public:
783  static MLDataType Type();
784 
785  bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override;
786 
787  virtual MLDataType GetElementType() const {
788  // should never reach here.
789  ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented");
790  }
791 
792  DeleteFunc GetDeleteFunc() const override;
793 
794  const ONNX_NAMESPACE::TypeProto* GetTypeProto() const override;
795 
797  SequenceTensorTypeBase& operator=(const SequenceTensorTypeBase&) = delete;
798 
799  protected:
802 
803  ONNX_NAMESPACE::TypeProto& MutableTypeProto();
804 
805  private:
806  struct Impl;
807  Impl* impl_;
808 };
809 #if defined(_MSC_VER) && !defined(__clang__)
810 #pragma warning(pop)
811 #endif
812 /**
813  * \brief SequenceTensorType. Use to register sequence for non-tensor types.
814  *
815  * \param CPPRuntime - We always use TensorSeq
816  *
817  * \param TensorElemType - one of the primitive types
818  *
819  * \details Usage: ORT_REGISTER_SEQ_TENSOR_TYPE()
820  * The type is required to have value_type defined
821  */
822 template <typename TensorElemType>
824  public:
826  "Requires one of the tensor fundamental types");
827 
828  static MLDataType Type();
829 
830  /// Return a MLDataType representing the element-type
831  MLDataType GetElementType() const override {
832  return DataTypeImpl::GetType<TensorElemType>();
833  }
834 
835  private:
837  using namespace data_types_internal;
838  SequenceTypeHelper::Set(SequenceTypeHelper::GetElemType<TensorElemType>()->GetTypeProto(),
839  MutableTypeProto());
840  }
841 };
842 
843 /**
844  * \brief OpaqueType
845  *
846  * \tparam T - cpp runtume that implements the Opaque type
847  *
848  * \tparam const char D[] - domain must be extern to be unique
849  *
850  * \tparam const char N[] - name must be extern to be unique
851  *
852  * \details Only one CPP type can be associated with a particular
853  * OpaqueType registration
854  *
855  */
856 template <typename T, const char D[], const char N[]>
857 class OpaqueType : public NonTensorType<T> {
858  public:
859  static MLDataType Type();
860 
861  bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override {
862  return this->IsOpaqueCompatible(type_proto);
863  }
864 
865  void FromDataContainer(const void* data, size_t data_size, OrtValue& output) const override {
866  NonTensorTypeConverter<T>::FromContainer(this, data, data_size, output);
867  }
868 
869  void ToDataContainer(const OrtValue& input, size_t data_size, void* data) const override {
870  NonTensorTypeConverter<T>::ToContainer(input, data_size, data);
871  }
872 
873  private:
874  OpaqueType() {
876  }
877 };
878 
879 /**
880  * \brief PrimitiveDataTypeBase
881  * Base class for primitive Tensor contained types
882  *
883  * \details This class contains an integer constant that can be
884  * used for input data type dispatching
885  *
886  */
888  public:
889  bool IsCompatible(const ONNX_NAMESPACE::TypeProto&) const override {
890  return false;
891  }
892 
893  const ONNX_NAMESPACE::TypeProto* GetTypeProto() const final {
894  return nullptr;
895  }
896 
897  int32_t GetDataType() const {
898  return data_type_;
899  }
900 
901  protected:
902  PrimitiveDataTypeBase(size_t size, int32_t data_type)
903  : DataTypeImpl{GeneralType::kPrimitive, size}, data_type_{data_type} {}
904 
905  private:
906  const int32_t data_type_;
907 };
908 
909 /**
910  * \brief PrimitiveDataType
911  * Typed specialization for primitive types.
912  * Concrete instances of this class are used by Tensor.
913  *
914  * \param T - primitive data type
915  *
916  */
917 template <typename T>
919  private:
920  static void Delete(void* p) {
921  delete static_cast<T*>(p);
922  }
923 
924  public:
925  static MLDataType Type();
926 
927  DeleteFunc GetDeleteFunc() const override {
928  return &Delete;
929  }
930 
931  private:
933  : PrimitiveDataTypeBase{sizeof(T),
934  utils::ToTensorProtoElementType<T>()} {
935  }
936 };
937 
939  return IsTensorType() ? static_cast<const TensorTypeBase*>(this) : nullptr;
940 }
941 
943  return IsTensorSequenceType() ? static_cast<const SequenceTensorTypeBase*>(this) : nullptr;
944 }
945 
946 #if !defined(DISABLE_SPARSE_TENSORS)
948  return IsSparseTensorType() ? static_cast<const SparseTensorTypeBase*>(this) : nullptr;
949 }
950 #endif
951 
952 #if !defined(DISABLE_OPTIONAL_TYPE)
954  return IsOptionalType() ? static_cast<const OptionalTypeBase*>(this) : nullptr;
955 }
956 #endif
957 
959  return IsNonTensorType() ? static_cast<const NonTensorTypeBase*>(this) : nullptr;
960 }
961 
963  return IsPrimitiveDataType() ? static_cast<const PrimitiveDataTypeBase*>(this) : nullptr;
964 }
965 
// Type registration macros.
//
// An explicit specialization of a (member) function template may only be
// declared at the enclosing namespace scope. Because of that, a given
// template cannot simply be pre-instantiated at registration time, and macros
// are used instead. Each macro defines two things:
//   - Type() of the concrete type: a function-local static singleton whose
//     address is returned as the MLDataType;
//   - the matching DataTypeImpl::Get*Type<> specialization, which forwards to
//     that Type().
#define ORT_REGISTER_TENSOR_TYPE(ELEM_TYPE)             \
  template <>                                           \
  MLDataType TensorType<ELEM_TYPE>::Type() {            \
    static TensorType<ELEM_TYPE> singleton;             \
    return &singleton;                                  \
  }                                                     \
  template <>                                           \
  MLDataType DataTypeImpl::GetTensorType<ELEM_TYPE>() { \
    return TensorType<ELEM_TYPE>::Type();               \
  }

#ifndef DISABLE_SPARSE_TENSORS
#define ORT_REGISTER_SPARSE_TENSOR_TYPE(ELEM_TYPE)            \
  template <>                                                 \
  MLDataType SparseTensorType<ELEM_TYPE>::Type() {            \
    static SparseTensorType<ELEM_TYPE> singleton;             \
    return &singleton;                                        \
  }                                                           \
  template <>                                                 \
  MLDataType DataTypeImpl::GetSparseTensorType<ELEM_TYPE>() { \
    return SparseTensorType<ELEM_TYPE>::Type();               \
  }
#endif  // DISABLE_SPARSE_TENSORS

#define ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, TYPE)            \
  template <>                                                 \
  MLDataType OptionalType<ORT_TYPE, TYPE>::Type() {           \
    static OptionalType<ORT_TYPE, TYPE> singleton;            \
    return &singleton;                                        \
  }                                                           \
  template <>                                                 \
  MLDataType DataTypeImpl::GetOptionalType<ORT_TYPE, TYPE>() { \
    return OptionalType<ORT_TYPE, TYPE>::Type();              \
  }

#ifndef DISABLE_ML_OPS
#define ORT_REGISTER_MAP(TYPE)                 \
  template <>                                  \
  MLDataType MapType<TYPE>::Type() {           \
    static MapType<TYPE> singleton;            \
    return &singleton;                         \
  }                                            \
  template <>                                  \
  MLDataType DataTypeImpl::GetType<TYPE>() {   \
    return MapType<TYPE>::Type();              \
  }
#endif  // DISABLE_ML_OPS

#define ORT_REGISTER_SEQ(TYPE)                 \
  template <>                                  \
  MLDataType SequenceType<TYPE>::Type() {      \
    static SequenceType<TYPE> singleton;       \
    return &singleton;                         \
  }                                            \
  template <>                                  \
  MLDataType DataTypeImpl::GetType<TYPE>() {   \
    return SequenceType<TYPE>::Type();         \
  }

#define ORT_REGISTER_SEQ_TENSOR_TYPE(ELEM_TYPE)                 \
  template <>                                                   \
  MLDataType SequenceTensorType<ELEM_TYPE>::Type() {            \
    static SequenceTensorType<ELEM_TYPE> singleton;             \
    return &singleton;                                          \
  }                                                             \
  template <>                                                   \
  MLDataType DataTypeImpl::GetSequenceTensorType<ELEM_TYPE>() { \
    return SequenceTensorType<ELEM_TYPE>::Type();               \
  }

#define ORT_REGISTER_PRIM_TYPE(TYPE)           \
  template <>                                  \
  MLDataType PrimitiveDataType<TYPE>::Type() { \
    static PrimitiveDataType<TYPE> singleton;  \
    return &singleton;                         \
  }                                            \
  template <>                                  \
  MLDataType DataTypeImpl::GetType<TYPE>() {   \
    return PrimitiveDataType<TYPE>::Type();    \
  }

#define ORT_REGISTER_OPAQUE_TYPE(CPPType, Domain, Name)      \
  template <>                                                \
  MLDataType OpaqueType<CPPType, Domain, Name>::Type() {     \
    static OpaqueType<CPPType, Domain, Name> singleton;      \
    return &singleton;                                       \
  }                                                          \
  template <>                                                \
  MLDataType DataTypeImpl::GetType<CPPType>() {              \
    return OpaqueType<CPPType, Domain, Name>::Type();        \
  }
1062 } // namespace onnxruntime
void AssignOpaqueDomainName(const char *domain, const char *name, ONNX_NAMESPACE::TypeProto &proto)
OpaqueTypes helpers.
static const std::vector< MLDataType > & AllFixedSizeTensorExceptHalfTypes()
std::vector< int64_t > VectorInt64
Definition: data_types.h:54
static void RegisterDataType(MLDataType)
static const TensorTypeBase * TensorTypeFromONNXEnum(int type)
virtual const ONNX_NAMESPACE::TypeProto * GetTypeProto() const =0
Retrieves an instance of TypeProto for a given MLDataType.
static const char * ToString(MLDataType type)
bool IsTensorSequenceType() const
Definition: data_types.h:127
Base class for MLDataType.
Definition: data_types.h:81
size_t Size() const
Definition: data_types.h:111
bool IsCompatible(const ONNX_NAMESPACE::TypeProto &type_proto) const override
this API will be used to check type compatibility at runtime
Definition: data_types.h:736
virtual MLDataType GetElementType() const
Definition: data_types.h:511
static void ToContainer(const OrtValue &, size_t, void *)
Definition: data_types.h:638
void
Definition: png.h:1083
void FromDataContainer(const void *data, size_t data_size, OrtValue &output) const override
Definition: data_types.h:865
SequenceTensorTypeBase serves as a base type class for Tensor sequences. Akin to TensorTypeBase. Runtime representation is always TensorSeq.
Definition: data_types.h:781
ONNX_NAMESPACE::TypeProto & MutableTypeProto()
static void FromContainer(MLDataType, const void *, size_t, OrtValue &)
Definition: data_types.h:635
GLsizei const GLchar *const * string
Definition: glcorearb.h:814
void ToDataContainer(const OrtValue &input, size_t data_size, void *data) const override
Definition: data_types.h:869
GLsizei const GLfloat * value
Definition: glcorearb.h:824
std::map< std::string, float > MapStringToFloat
Definition: data_types.h:40
virtual bool IsCompatible(const ONNX_NAMESPACE::TypeProto &type_proto) const =0
this API will be used to check type compatibility at runtime
static MLDataType GetTensorType()
std::map< int64_t, int64_t > MapInt64ToInt64
Definition: data_types.h:43
MapType. Use this type to register mapping types.
Definition: data_types.h:729
void CopyMutableSeqElement(const ONNX_NAMESPACE::TypeProto &, ONNX_NAMESPACE::TypeProto &)
Sequence helpers.
#define ORT_ENFORCE(condition,...)
Definition: common.h:173
MLDataType GetElementType() const override
Return a MLDataType representing the element-type.
Definition: data_types.h:538
CreateFunc GetCreateFunc() const override
Definition: data_types.h:709
static const std::vector< MLDataType > & AllOptionalTypes()
const ONNX_NAMESPACE::TypeProto * GetTypeProto() const final
Retrieves an instance of TypeProto for a given MLDataType.
Definition: data_types.h:893
DeleteFunc GetDeleteFunc() const override
Definition: data_types.h:560
virtual MLDataType GetElementType() const
Definition: data_types.h:787
Common base-class for all sparse-tensors (with different element types).
Definition: data_types.h:501
static MLDataType GetSequenceTensorType()
void CopyMutableOptionalElement(const ONNX_NAMESPACE::TypeProto &, ONNX_NAMESPACE::TypeProto &)
Optional helpers.
const ONNX_NAMESPACE::TypeProto * GetTypeProto() const override
Retrieves an instance of TypeProto for a given MLDataType.
static MLDataType Type()
static const std::vector< MLDataType > & AllIEEEFloatTensorTypes()
static const std::vector< MLDataType > & AllFixedSizeTensorTypes()
static void Set(const onnx::TypeProto *elem_proto, ONNX_NAMESPACE::TypeProto &proto)
Definition: data_types.h:375
static const std::vector< MLDataType > & AllTensorTypes()
const SparseTensorTypeBase * AsSparseTensorType() const
Definition: data_types.h:947
All tensors base.
Definition: data_types.h:395
static MLDataType GetSparseTensorType()
PrimitiveDataTypeBase(size_t size, int32_t data_type)
Definition: data_types.h:902
static const std::vector< MLDataType > & AllTensorAndSequenceTensorTypes()
static const std::vector< MLDataType > & AllSequenceTensorTypes()
std::map< int64_t, float > MapInt64ToFloat
Definition: data_types.h:44
void(*)(void *) DeleteFunc
Definition: data_types.h:74
bool IsCompatible(const ONNX_NAMESPACE::TypeProto &type_proto) const override
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TensorTypeBase)
static void Set(const ONNX_NAMESPACE::TypeProto *elem_proto, ONNX_NAMESPACE::TypeProto &proto)
Definition: data_types.h:351
virtual MLDataType GetElementType() const
Definition: data_types.h:408
static const std::vector< MLDataType > & AllFixedSizeSequenceTensorTypes()
std::map< std::string, int64_t > MapStringToInt64
Definition: data_types.h:39
const GeneralType type_
Definition: data_types.h:93
const NonTensorTypeBase * AsNonTensorType() const
Definition: data_types.h:958
Common base-class for all optional types.
Definition: data_types.h:554
PrimitiveDataType Typed specialization for primitive types. Concrete instances of this class are used...
Definition: data_types.h:918
bool IsNonTensorType() const
Definition: data_types.h:139
const TensorTypeBase * AsTensorType() const
Definition: data_types.h:938
static const std::vector< MLDataType > & AllIEEEFloatTensorExceptHalfTypes()
const PrimitiveDataTypeBase * AsPrimitiveDataType() const
Definition: data_types.h:962
bool IsTensorType() const
Definition: data_types.h:123
MLDataType GetElementType() const override
Definition: data_types.h:607
static MLDataType GetDataType(const std::string &)
DeleteFunc GetDeleteFunc() const override
STATIC_INLINE uint64_t H(uint64_t x, uint64_t y, uint64_t mul, int r)
Definition: farmhash.h:701
PrimitiveDataTypeBase Base class for primitive Tensor contained types.
Definition: data_types.h:887
GLuint const GLchar * name
Definition: glcorearb.h:786
std::map< std::string, std::string > MapStringToString
Predefined registered types.
Definition: data_types.h:38
static const SequenceTensorTypeBase * SequenceTensorTypeFromONNXEnum(int type)
void CopyMutableMapValue(const ONNX_NAMESPACE::TypeProto &, ONNX_NAMESPACE::TypeProto &)
Map helpers.
static MLDataType GetOptionalType()
void *(*)() CreateFunc
Definition: data_types.h:75
const DataTypeImpl * MLDataType
Definition: data_types.h:72
static const std::vector< MLDataType > & AllTensorAndSequenceTensorAndOptionalTypes()
bool IsCompatible(const ONNX_NAMESPACE::TypeProto &type_proto) const override
this API will be used to check type compatibility at runtime
Definition: data_types.h:861
std::map< int64_t, double > MapInt64ToDouble
Definition: data_types.h:45
#define ORT_THROW(...)
Definition: common.h:163
static const std::vector< MLDataType > & AllNumericTensorTypes()
GLsizeiptr size
Definition: glcorearb.h:664
virtual MLDataType GetElementType() const
Definition: data_types.h:567
std::map< int64_t, std::string > MapInt64ToString
Definition: data_types.h:42
SequenceType. Use to register sequence for non-tensor types.
Definition: data_types.h:760
bool IsCompatible(const ONNX_NAMESPACE::TypeProto &type_proto) const override
this API will be used to check type compatibility at runtime
Definition: data_types.h:764
const OptionalTypeBase * AsOptionalType() const
Definition: data_types.h:953
SequenceTensorType. Use to register sequences of tensors with a given element type.
Definition: data_types.h:823
DataTypeImpl(GeneralType type, size_t size)
Definition: data_types.h:97
bool IsSparseTensorType() const
Definition: data_types.h:131
MLDataType GetElementType() const override
Definition: data_types.h:448
virtual ~DataTypeImpl()=default
std::map< std::string, double > MapStringToDouble
Definition: data_types.h:41
GA_API const UT_StringHolder N
static void Set(ONNX_NAMESPACE::TensorProto_DataType element_type, ONNX_NAMESPACE::TypeProto &proto)
Definition: data_types.h:300
bool IsCompatible(const ONNX_NAMESPACE::TypeProto &) const override
this API will be used to check type compatibility at runtime
Definition: data_types.h:889
static const std::vector< MLDataType > & AllFixedSizeTensorAndSequenceTensorTypes()
Tensor type. This type does not have a C++ type associated with it at registration time except the el...
Definition: data_types.h:439
DeleteFunc GetDeleteFunc() const override
Definition: data_types.h:705
DeleteFunc GetDeleteFunc() const override
Definition: data_types.h:927
bool IsPrimitiveDataType() const
Definition: data_types.h:143
Definition: core.h:1131
GLsizei GLenum GLenum * types
Definition: glcorearb.h:2542
static MLDataType Type()
#define const
Definition: zconf.h:214
std::ostream & operator<<(std::ostream &out, AllocKind alloc_kind)
Base type for all non-tensors, maps, sequences and opaques.
Definition: data_types.h:646
const SequenceTensorTypeBase * AsSequenceTensorType() const
Definition: data_types.h:942
std::vector< std::string > VectorString
Definition: data_types.h:53
static MLDataType TypeFromProto(const ONNX_NAMESPACE::TypeProto &proto)
static void Set(ONNX_NAMESPACE::TensorProto_DataType element_type, ONNX_NAMESPACE::TypeProto &proto)
Definition: data_types.h:308
static void Set(ONNX_NAMESPACE::TensorProto_DataType key_type, const ONNX_NAMESPACE::TypeProto *value_proto, ONNX_NAMESPACE::TypeProto &proto)
Definition: data_types.h:329
type
Definition: core.h:1059
std::vector< MapInt64ToFloat > VectorMapInt64ToFloat
Definition: data_types.h:49
static const SparseTensorTypeBase * SparseTensorTypeFromONNXEnum(int type)
std::vector< MapStringToFloat > VectorMapStringToFloat
Definition: data_types.h:48
virtual DeleteFunc GetDeleteFunc() const =0
Provide a specialization for your C++ non-tensor type so that your FromDataContainer/ToDataContainer implementations work correctly...
Definition: data_types.h:634
Definition: format.h:895
bool IsOptionalType() const
Definition: data_types.h:135
MLDataType GetElementType() const override
Return a MLDataType representing the element-type.
Definition: data_types.h:831
static MLDataType GetType()