HDK
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
data_types.h
Go to the documentation of this file.
1 // Copyright (c) Microsoft Corporation. All rights reserved.
2 // Licensed under the MIT License.
3 
4 #pragma once
5 
6 #include <cstdint>
7 #include <cstring>
8 #include <string>
9 #include <type_traits>
10 #include <map>
11 #include <unordered_map>
12 #include <gsl/gsl>
13 #include "core/common/common.h"
14 #include "core/common/exceptions.h"
15 #include "core/framework/endian.h"
16 #include "core/framework/float8.h"
17 #include "core/framework/float16.h"
18 #include "core/framework/int4.h"
19 #include "core/graph/onnx_protobuf.h"
21 
22 struct OrtValue;
23 
24 namespace ONNX_NAMESPACE {
25 class TypeProto;
26 } // namespace ONNX_NAMESPACE
27 
28 namespace onnxruntime {
29 /// Predefined registered types
30 
31 #if !defined(DISABLE_ML_OPS)
32 
33 // maps (only used by ML ops)
34 using MapStringToString = std::map<std::string, std::string>;
35 using MapStringToInt64 = std::map<std::string, int64_t>;
36 using MapStringToFloat = std::map<std::string, float>;
37 using MapStringToDouble = std::map<std::string, double>;
38 using MapInt64ToString = std::map<int64_t, std::string>;
39 using MapInt64ToInt64 = std::map<int64_t, int64_t>;
40 using MapInt64ToFloat = std::map<int64_t, float>;
41 using MapInt64ToDouble = std::map<int64_t, double>;
42 
43 // vectors/sequences
44 using VectorMapStringToFloat = std::vector<MapStringToFloat>;
45 using VectorMapInt64ToFloat = std::vector<MapInt64ToFloat>;
46 
47 #endif
48 
49 using VectorString = std::vector<std::string>;
50 using VectorInt64 = std::vector<int64_t>;
51 
52 // Forward declarations
53 class DataTypeImpl;
54 class TensorTypeBase;
55 #if !defined(DISABLE_SPARSE_TENSORS)
57 #endif
59 class NonTensorTypeBase;
60 #if !defined(DISABLE_OPTIONAL_TYPE)
61 class OptionalTypeBase;
62 #endif
64 class Tensor;
65 class TensorSeq;
66 
67 // DataTypeImpl pointer as unique DataTypeImpl identifier.
68 using MLDataType = const DataTypeImpl*;
69 // be used with class MLValue
70 using DeleteFunc = void (*)(void*);
71 using CreateFunc = void* (*)();
72 
73 /**
74  * \brief Base class for MLDataType
75  *
76  */
77 class DataTypeImpl {
78  public:
79  enum class GeneralType {
80  kInvalid = 0,
81  kNonTensor = 1,
82  kTensor = 2,
83  kTensorSequence = 3,
84  kSparseTensor = 4,
85  kOptional = 5,
86  kPrimitive = 6,
87  };
88 
90  const size_t size_;
91 
92  protected:
94 
95  public:
96  virtual ~DataTypeImpl() = default;
97 
98  /**
99  * \brief this API will be used to check type compatibility at runtime
100  *
101  * \param type_proto a TypeProto instance that is constructed for a specific type
102  * will be checked against a TypeProto instance contained within a corresponding
103  * MLDataType instance.
104  */
105  virtual bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const = 0;
106 
107  size_t Size() const { return size_; }
108 
109  virtual DeleteFunc GetDeleteFunc() const = 0;
110 
111  /**
112  * \brief Retrieves an instance of TypeProto for
113  * a given MLDataType
114  * \returns optional TypeProto. Only ONNX types
115  has type proto, non-ONNX types will return nullptr.
116  */
117  virtual const ONNX_NAMESPACE::TypeProto* GetTypeProto() const = 0;
118 
119  bool IsTensorType() const {
120  return type_ == GeneralType::kTensor;
121  }
122 
123  bool IsTensorSequenceType() const {
125  }
126 
127  bool IsSparseTensorType() const {
129  }
130 
131  bool IsOptionalType() const {
132  return type_ == GeneralType::kOptional;
133  }
134 
135  bool IsNonTensorType() const {
136  return type_ == GeneralType::kNonTensor;
137  }
138 
139  bool IsPrimitiveDataType() const {
140  return type_ == GeneralType::kPrimitive;
141  }
142 
143  // Returns this if this is of tensor-type and null otherwise
144  const TensorTypeBase* AsTensorType() const;
145 
147 
148 #if !defined(DISABLE_SPARSE_TENSORS)
149  // Returns this if this is of sparse-tensor-type and null otherwise
151 #endif
152 
153 #if !defined(DISABLE_OPTIONAL_TYPE)
154  const OptionalTypeBase* AsOptionalType() const;
155 #endif
156 
157  const NonTensorTypeBase* AsNonTensorType() const;
158 
159  // Returns this if this is one of the primitive data types (specialization of PrimitiveDataTypeBase)
160  // and null otherwise
162 
163  // Return the type meta that we are using in the runtime.
164  template <typename T>
165  static MLDataType GetType();
166 
167  // Return the types for a concrete tensor type, like Tensor_Float
168  template <typename elemT>
169  static MLDataType GetTensorType();
170 
171  template <typename elemT>
173 
174 #if !defined(DISABLE_SPARSE_TENSORS)
175  // Return the MLDataType for a concrete sparse tensor type.
176  template <typename elemT>
178 #endif
179 
180  template <typename T, typename elemT>
181  static MLDataType GetOptionalType();
182 
183  /**
184  * Convert an ONNX TypeProto to onnxruntime DataTypeImpl.
185  * However, this conversion is lossy. Don't try to use 'this->GetTypeProto()' converting it back.
186  * Even though GetTypeProto() will not have the original information, it will still have enough to correctly
187  * map to MLDataType.
188  * \param proto
189  */
190  static MLDataType TypeFromProto(const ONNX_NAMESPACE::TypeProto& proto);
191 
192  static const TensorTypeBase* TensorTypeFromONNXEnum(int type);
194 #if !defined(DISABLE_SPARSE_TENSORS)
196 #endif
197 
198  static const char* ToString(MLDataType type);
199  static std::vector<std::string> ToString(const std::vector<MLDataType>& types);
200  // Registers ONNX_NAMESPACE::DataType (internalized string) with
201  // MLDataType. DataType is produced by internalizing an instance of
202  // TypeProto contained within MLDataType
203  static void RegisterDataType(MLDataType);
204  static MLDataType GetDataType(const std::string&);
205 
206  // IR4: includes all float types, includes float16, bfloat16
207  // IR9: includes float 8 types as well
208  static const std::vector<MLDataType>& AllTensorTypes(); // up to IR4 (no float 8), deprecated
209  static const std::vector<MLDataType>& AllTensorTypesIRv4();
210  static const std::vector<MLDataType>& AllTensorTypesIRv9();
211 
212  static const std::vector<MLDataType>& AllFixedSizeTensorTypes(); // up to IR4 (no float 8), deprecated
213  static const std::vector<MLDataType>& AllFixedSizeTensorTypesIRv4();
214  static const std::vector<MLDataType>& AllFixedSizeTensorTypesIRv9();
215 
216  static const std::vector<MLDataType>& AllSequenceTensorTypes(); // up to IR4 (no float 8), deprecated
217  static const std::vector<MLDataType>& AllSequenceTensorTypesIRv4();
218  static const std::vector<MLDataType>& AllSequenceTensorTypesIRv9();
219 
220  static const std::vector<MLDataType>& AllFixedSizeSequenceTensorTypes(); // up to IR4 (no float 8), deprecated
221  static const std::vector<MLDataType>& AllFixedSizeSequenceTensorTypesIRv4();
222  static const std::vector<MLDataType>& AllFixedSizeSequenceTensorTypesIRv9();
223 
224  static const std::vector<MLDataType>& AllNumericTensorTypes(); // up to IR4 (no float 8), deprecated
225  static const std::vector<MLDataType>& AllNumericTensorTypesIRv4();
226  static const std::vector<MLDataType>& AllNumericTensorTypesIRv9();
227 
228  static const std::vector<MLDataType>& AllIEEEFloatTensorTypes(); // float16, float, double
229 
230  static const std::vector<MLDataType>& AllTensorAndSequenceTensorTypes(); // up to IR4 (no float 8), deprecated
231  static const std::vector<MLDataType>& AllTensorAndSequenceTensorTypesIRv4();
232  static const std::vector<MLDataType>& AllTensorAndSequenceTensorTypesIRv9();
233 
234  static const std::vector<MLDataType>& AllOptionalAndTensorAndSequenceTensorTypes(); // up to IR4 (no float 8), deprecated
235  static const std::vector<MLDataType>& AllOptionalAndTensorAndSequenceTensorTypesIRv4();
236  static const std::vector<MLDataType>& AllOptionalAndTensorAndSequenceTensorTypesIRv9();
237 
238  static const std::vector<MLDataType>& AllFixedSizeTensorAndSequenceTensorTypes(); // up to IR4 (no float 8), deprecated
239  static const std::vector<MLDataType>& AllFixedSizeTensorAndSequenceTensorTypesIRv4();
240  static const std::vector<MLDataType>& AllFixedSizeTensorAndSequenceTensorTypesIRv9();
241 
242  static const std::vector<MLDataType>& AllOptionalTypes(); // up to IR4 (no float 8), deprecated
243  static const std::vector<MLDataType>& AllOptionalTypesIRv4();
244  static const std::vector<MLDataType>& AllOptionalTypesIRv9();
245 
246  static const std::vector<MLDataType>& AllTensorAndSequenceTensorAndOptionalTypes(); // up to IR4 (no float 8), deprecated
247  static const std::vector<MLDataType>& AllTensorAndSequenceTensorAndOptionalTypesIRv4();
248  static const std::vector<MLDataType>& AllTensorAndSequenceTensorAndOptionalTypesIRv9();
249 };
250 
251 std::ostream& operator<<(std::ostream& out, MLDataType data_type);
252 
253 /*
254  * Type registration helpers
255  */
256 namespace data_types_internal {
257 /// TensorType helpers
258 ///
259 
/// Is a given type on the list of types?
/// The first template argument is the type being looked up; the remaining
/// arguments form the list it is checked against.
///
/// IsAnyOf<T, Types...> derives from std::integral_constant<bool, ...> for every
/// list length, so it uniformly exposes `value` and `type` and converts to bool.
/// (Previously only the two-argument form derived from std::is_same; longer lists
/// exposed `value` alone.) Matching is exact (std::is_same): cv/ref-qualified
/// variants of a listed type do not match. An empty list yields false instead of
/// being an incomplete type.
template <typename T, typename... Types>
struct IsAnyOf : public std::disjunction<std::is_same<T, Types>...> {
};
276 
277 /// Tells if the specified type is one of fundamental types
278 /// that can be contained within a tensor.
279 /// We do not have raw fundamental types, rather a subset
280 /// of fundamental types is contained within tensors.
281 template <typename T>
282 struct IsTensorContainedType : public IsAnyOf<T, float, uint8_t, int8_t, uint16_t, int16_t,
283  int32_t, int64_t, std::string, bool, MLFloat16,
284  double, uint32_t, uint64_t, BFloat16,
285  Int4x2, UInt4x2
286 #if !defined(DISABLE_FLOAT8_TYPES)
287  ,
288  Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, Float8E5M2FNUZ
289 #endif
290  > {
291 };
292 
293 #if !defined(DISABLE_SPARSE_TENSORS)
294 /// Use "IsSparseTensorContainedType<T>::value" to test if a type T
295 /// is permitted as the element-type of a sparse-tensor.
296 
297 template <typename T>
298 struct IsSparseTensorContainedType : public IsAnyOf<T, float, uint8_t, int8_t, uint16_t, int16_t,
299  int32_t, int64_t, std::string, bool, MLFloat16,
300  double, uint32_t, uint64_t, BFloat16
301 #if !defined(DISABLE_FLOAT8_TYPES)
302  ,
303  Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, Float8E5M2FNUZ
304 #endif
305  > {
306 };
307 #endif
308 
309 #if !defined(DISABLE_OPTIONAL_TYPE)
310 /// Tells if the specified type is one of ORT types
311 /// that can be contained within an optional struct.
312 template <typename T>
313 struct IsOptionalOrtType : public IsAnyOf<T, Tensor, TensorSeq> {
314 };
315 #endif
316 
317 /// This template's Get() returns a corresponding MLDataType
318 /// It dispatches the call to either GetTensorType<>() or
319 /// GetType<>()
320 template <typename T, bool TensorContainedType>
322 
323 template <typename T>
324 struct GetMLDataType<T, true> {
325  static MLDataType Get() {
326  return DataTypeImpl::GetTensorType<T>();
327  }
328 };
329 
330 template <typename T>
331 struct GetMLDataType<T, false> {
332  static MLDataType Get() {
333  return DataTypeImpl::GetType<T>();
334  }
335 };
336 
338  static void Set(ONNX_NAMESPACE::TensorProto_DataType element_type,
339  ONNX_NAMESPACE::TypeProto& proto) {
340  proto.mutable_tensor_type()->set_elem_type(element_type);
341  }
342 };
343 
344 #if !defined(DISABLE_SPARSE_TENSORS)
346  static void Set(ONNX_NAMESPACE::TensorProto_DataType element_type,
347  ONNX_NAMESPACE::TypeProto& proto) {
348  proto.mutable_sparse_tensor_type()->set_elem_type(element_type);
349  }
350 };
351 #endif // !defined(DISABLE_SPARSE_TENSORS)
352 
353 #if !defined(DISABLE_ML_OPS)
354 /// Map helpers
355 
356 void CopyMutableMapValue(const ONNX_NAMESPACE::TypeProto&,
357  ONNX_NAMESPACE::TypeProto&);
358 
360  // V can be either a primitive type (in which case it is a tensor)
361  // or other preregistered types
362  template <typename V>
365  }
366 
367  static void Set(ONNX_NAMESPACE::TensorProto_DataType key_type, const ONNX_NAMESPACE::TypeProto* value_proto,
368  ONNX_NAMESPACE::TypeProto& proto) {
369  ORT_ENFORCE(value_proto != nullptr, "expected a registered ONNX type");
370  proto.mutable_map_type()->set_key_type(key_type);
371  CopyMutableMapValue(*value_proto, proto);
372  }
373 };
374 #endif
375 
376 /// Sequence helpers
377 
378 // Element type is a primitive type so we set it to a tensor<elemT>
379 void CopyMutableSeqElement(const ONNX_NAMESPACE::TypeProto&,
380  ONNX_NAMESPACE::TypeProto&);
381 
382 // helper to create TypeProto with minimal binary size impact
384  template <typename T>
387  }
388 
389  static void Set(const ONNX_NAMESPACE::TypeProto* elem_proto,
390  ONNX_NAMESPACE::TypeProto& proto) {
391  ORT_ENFORCE(elem_proto != nullptr, "expected a registered ONNX type");
392  CopyMutableSeqElement(*elem_proto, proto);
393  }
394 };
395 
396 /// Optional helpers
397 
398 void CopyMutableOptionalElement(const ONNX_NAMESPACE::TypeProto&,
399  ONNX_NAMESPACE::TypeProto&);
400 
401 // helper to create TypeProto with minimal binary size impact
403  template <typename T, typename elemT>
405  if constexpr (std::is_same<T, Tensor>::value) {
406  return DataTypeImpl::GetTensorType<elemT>();
407  } else {
408  static_assert(std::is_same<T, TensorSeq>::value, "Unsupported element type for optional type");
409  return DataTypeImpl::GetSequenceTensorType<elemT>();
410  }
411  }
412 
413  static void Set(const onnx::TypeProto* elem_proto, ONNX_NAMESPACE::TypeProto& proto) {
414  ORT_ENFORCE(elem_proto != nullptr, "expected a registered ONNX type");
415  CopyMutableOptionalElement(*elem_proto, proto);
416  }
417 };
418 
419 /// OpaqueTypes helpers
420 
421 void AssignOpaqueDomainName(const char* domain, const char* name,
422  ONNX_NAMESPACE::TypeProto& proto);
423 
424 } // namespace data_types_internal
425 
426 // The suppressed warning is: "The type with a virtual function needs either public virtual or protected nonvirtual destructor."
427 // However, we do not allocate this type on heap.
428 #if defined(_MSC_VER) && !defined(__clang__)
429 #pragma warning(push)
430 #pragma warning(disable : 26436)
431 #endif
432 /// All tensors base
433 class TensorTypeBase : public DataTypeImpl {
434  public:
435  static MLDataType Type();
436 
437  /// We first compare type_proto pointers and then
438  /// if they do not match try to account for the case
439  /// where TypeProto was created ad-hoc and not queried from MLDataType
440  bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override;
441 
442  DeleteFunc GetDeleteFunc() const override;
443 
444  const ONNX_NAMESPACE::TypeProto* GetTypeProto() const override;
445 
446  virtual MLDataType GetElementType() const {
447  // should never reach here.
448  ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented");
449  }
450 
452 
453  protected:
454  ONNX_NAMESPACE::TypeProto& MutableTypeProto();
455 
456  TensorTypeBase();
457  ~TensorTypeBase() override;
458 
459  private:
460  struct Impl;
461  Impl* impl_;
462 };
463 
464 /**
465  * \brief Tensor type. This type does not have a C++ type associated with
466  * it at registration time except the element type. One of the types mentioned
467  * above at IsTensorContainedType<> list is acceptable.
468  *
469  * \details
470  * Usage:
471  * ORT_REGISTER_TENSOR(ELEMENT_TYPE)
472  * Currently all of the Tensors irrespective of the dimensions are mapped to Tensor<type>
473  * type. IsCompatible() currently ignores shape.
474  */
475 
476 template <typename elemT>
477 class TensorType : public TensorTypeBase {
478  public:
480  "Requires one of the tensor fundamental types");
481 
482  static MLDataType Type();
483 
484  /// Tensors only can contain basic data types
485  /// that have been previously registered with ONNXRuntime
486  MLDataType GetElementType() const override {
487  return DataTypeImpl::GetType<elemT>();
488  }
489 
490  private:
491  TensorType() {
492  using namespace data_types_internal;
493  TensorTypeHelper::Set(utils::ToTensorProtoElementType<elemT>(), MutableTypeProto());
494  }
495 };
496 
#if defined(DISABLE_OPTIONAL_TYPE)

// TODO is this still needed after removing kernel def hashes?
/// Base class shared by types that are compiled out of this build.
/// DataTypeImpl::ToString still has to work for them in a minimal build so that
/// the ORT format model kernel hashes stay stable.
class DisabledTypeBase : public DataTypeImpl {
 public:
  static MLDataType Type();

  /// Unconditionally false. Reporting incompatibility guarantees that no kernel
  /// supporting a disabled type gets matched to a model node that needs that
  /// type, so loading such a model fails.
  bool IsCompatible(const ONNX_NAMESPACE::TypeProto& /*type_proto*/) const override { return false; }

  /// Unconditionally throws: there is no deleter for a disabled type.
  DeleteFunc GetDeleteFunc() const override { ORT_THROW("Type is disabled in this build."); }

  /// Must keep working even though the type itself is disabled.
  const ONNX_NAMESPACE::TypeProto* GetTypeProto() const override;

  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(DisabledTypeBase);

 protected:
  /// Must keep working even though the type is disabled; derived types populate
  /// their TypeProto through this accessor.
  ONNX_NAMESPACE::TypeProto& MutableTypeProto();

  DisabledTypeBase(DataTypeImpl::GeneralType type, size_t size);
  ~DisabledTypeBase() override;

 private:
  struct Impl;   // defined out of line, not in this header
  Impl* impl_;
};

#endif
536 
537 #if !defined(DISABLE_SPARSE_TENSORS)
538 /// Common base-class for all sparse-tensors (with different element types).
540  public:
541  static MLDataType Type();
542 
543  bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override;
544 
545  DeleteFunc GetDeleteFunc() const override;
546 
547  const ONNX_NAMESPACE::TypeProto* GetTypeProto() const override;
548 
549  virtual MLDataType GetElementType() const {
550  // should never reach here.
551  ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented");
552  }
553 
555 
556  protected:
557  ONNX_NAMESPACE::TypeProto& MutableTypeProto();
558 
560  ~SparseTensorTypeBase() override;
561 
562  private:
563  struct Impl;
564  Impl* impl_;
565 };
566 
567 template <typename elemT>
569  public:
571  "Requires one of the sparse-tensor fundamental types");
572 
573  static MLDataType Type();
574 
575  /// Return a MLDataType representing the element-type
576  MLDataType GetElementType() const override {
577  return DataTypeImpl::GetType<elemT>();
578  }
579 
580  private:
581  SparseTensorType() {
582  using namespace data_types_internal;
583  SparseTensorTypeHelper::Set(utils::ToTensorProtoElementType<elemT>(), MutableTypeProto());
584  }
585 };
586 
587 #endif // !defined(DISABLE_SPARSE_TENSORS)
588 
589 /// Common base-class for all optional types.
590 
591 #if !defined(DISABLE_OPTIONAL_TYPE)
593  public:
594  static MLDataType Type();
595 
596  bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override;
597 
598  DeleteFunc GetDeleteFunc() const override {
599  // should never reach here.
600  ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented");
601  }
602 
603  const ONNX_NAMESPACE::TypeProto* GetTypeProto() const override;
604 
605  virtual MLDataType GetElementType() const {
606  // should never reach here.
607  ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented");
608  }
609 
610  OptionalTypeBase(const OptionalTypeBase&) = delete;
611  OptionalTypeBase& operator=(const OptionalTypeBase&) = delete;
612 
613  protected:
614  ONNX_NAMESPACE::TypeProto& MutableTypeProto();
615 
617  ~OptionalTypeBase() override;
618 
619  private:
620  struct Impl;
621  Impl* impl_;
622 };
623 #endif
624 
625 // Derive from OptionalTypeBase if the Optional type support is enabled,
626 // else derive from DisabledTypeBase
627 template <typename T, typename elemT>
629 #if !defined(DISABLE_OPTIONAL_TYPE)
630  public OptionalTypeBase
631 #else
632  public DisabledTypeBase
633 #endif
634 {
635  public:
636  static MLDataType Type();
637 
638 #if !defined(DISABLE_OPTIONAL_TYPE)
640  "Requires one of the supported types: Tensor or TensorSeq");
641 
643  "Requires one of the tensor fundamental types");
644 
645  MLDataType GetElementType() const override {
646  return data_types_internal::OptionalTypeHelper::GetElemType<T, elemT>();
647  }
648 #endif
649 
650  private:
651 #if !defined(DISABLE_OPTIONAL_TYPE)
652  OptionalType()
653 #else
654  OptionalType() : DisabledTypeBase{DataTypeImpl::GeneralType::kOptional, 0}
655 #endif
656  {
657  using namespace data_types_internal;
658  OptionalTypeHelper::Set(OptionalTypeHelper::GetElemType<T, elemT>()->GetTypeProto(), MutableTypeProto());
659  }
660 }; // namespace onnxruntime
661 
662 /**
663  * \brief Provide a specialization for your C++ Non-tensor type
664  * so your implementation FromDataTypeContainer/ToDataTypeContainer
665  * functions correctly. Otherwise you get a default implementation
666  * which may not be what you need/want.
667  *
668  * This class is used to create OrtValue, fetch data from OrtValue via
669  * C/C++ APIs
670  */
671 template <class T>
673  static void FromContainer(MLDataType /*dtype*/, const void* /*data*/, size_t /*data_size*/, OrtValue& /*output*/) {
674  ORT_THROW("Not implemented");
675  }
676  static void ToContainer(const OrtValue& /*input*/, size_t /*data_size*/, void* /*data*/) {
677  ORT_THROW("Not implemented");
678  }
679 };
680 
681 /**
682  * \brief Base type for all non-tensors, maps, sequences and opaques
683  */
685  public:
686  DeleteFunc GetDeleteFunc() const override = 0;
687 
688  virtual CreateFunc GetCreateFunc() const = 0;
689 
690  const ONNX_NAMESPACE::TypeProto* GetTypeProto() const override;
691 
692  // \brief Override for Non-tensor types to initialize non-tensor CPP
693  // data representation from data. The caller of the interface
694  // should have a shared definition of the data which is used to initialize
695  // CPP data representation. This is used from C API.
696  //
697  // \param data - pointer to a data container structure non_tensor type specific
698  // \param data_size - size of the data container structure, used for rudimentary checks
699  // \param output - reference to a default constructed non-tensor type
700  // \returns OrtValue
701  // \throw if there is an error
702  virtual void FromDataContainer(const void* data, size_t data_size, OrtValue& output) const;
703 
704  // \brief Override for Non-tensor types to fetch data from the internal CPP data representation
705  // The caller of the interface should have a shared definition of the data which is used to initialize
706  // CPP data representation. This is used from C API.
707  //
708  // \param input - OrtValue containing data
709  // \param data_size - size of the structure that is being passed for receiving data, used for
710  // validation
711  // \param data - pointer to receiving data structure
712  virtual void ToDataContainer(const OrtValue& input, size_t data_size, void* data) const;
713 
714  NonTensorTypeBase(const NonTensorTypeBase&) = delete;
716 
717  protected:
718  NonTensorTypeBase(size_t size);
719  ~NonTensorTypeBase() override;
720 
721  ONNX_NAMESPACE::TypeProto& MutableTypeProto();
722 
723  bool IsMapCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const;
724 
725  bool IsSequenceCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const;
726 
727  bool IsOpaqueCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const;
728 
729  private:
730  struct Impl;
731  Impl* impl_;
732 };
733 
734 // This is where T is the actual CPPRuntimeType
735 template <typename T>
737  private:
738  static void Delete(void* p) {
739  delete static_cast<T*>(p);
740  }
741 
742  public:
743  DeleteFunc GetDeleteFunc() const override {
744  return &Delete;
745  }
746 
747  CreateFunc GetCreateFunc() const override {
748  return []() -> void* { return new T(); };
749  }
750 
751  protected:
753 };
754 
755 #if !defined(DISABLE_ML_OPS)
756 /**
757  * \brief MapType. Use this type to register
758  * mapping types.
759  *
760  * \param T - cpp type that you wish to register as runtime MapType
761  *
762  * \details Usage: ORT_REGISTER_MAP(C++Type)
763  * The type is required to have mapped_type and
764  * key_type defined
765  */
766 template <typename CPPType>
767 class MapType : public NonTensorType<CPPType> {
768  public:
770  "Requires one of the tensor fundamental types as key");
771 
772  static MLDataType Type();
773 
774  bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override {
775  return this->IsMapCompatible(type_proto);
776  }
777 
778  private:
779  MapType() {
780  using namespace data_types_internal;
781  MapTypeHelper::Set(utils::ToTensorProtoElementType<typename CPPType::key_type>(),
782  MapTypeHelper::GetValueType<typename CPPType::mapped_type>()->GetTypeProto(),
783  this->MutableTypeProto());
784  }
785 };
786 #endif
787 
788 /**
789  * \brief SequenceType. Use to register sequence for non-tensor types.
790  *
791  * \param T - CPP type that you wish to register as Sequence
792  * runtime type.
793  *
794  * \details Usage: ORT_REGISTER_SEQ(C++Type)
795  * The type is required to have value_type defined
796  */
797 template <typename CPPType>
798 class SequenceType : public NonTensorType<CPPType> {
799  public:
800  static MLDataType Type();
801 
802  bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override {
803  return this->IsSequenceCompatible(type_proto);
804  }
805 
806  private:
807  SequenceType() {
808  using namespace data_types_internal;
809  SequenceTypeHelper::Set(SequenceTypeHelper::GetElemType<typename CPPType::value_type>()->GetTypeProto(),
810  this->MutableTypeProto());
811  }
812 };
813 
814 /**
815  * \brief SequenceTensorTypeBase serves as a base type class for
816  * Tensor sequences. Akin to TensorTypeBase.
817  * Runtime representation is always TensorSeq.
818  */
820  public:
821  static MLDataType Type();
822 
823  bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override;
824 
825  virtual MLDataType GetElementType() const {
826  // should never reach here.
827  ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented");
828  }
829 
830  DeleteFunc GetDeleteFunc() const override;
831 
832  const ONNX_NAMESPACE::TypeProto* GetTypeProto() const override;
833 
836 
837  protected:
840 
841  ONNX_NAMESPACE::TypeProto& MutableTypeProto();
842 
843  private:
844  struct Impl;
845  Impl* impl_;
846 };
847 #if defined(_MSC_VER) && !defined(__clang__)
848 #pragma warning(pop)
849 #endif
850 /**
851  * \brief SequenceTensorType. Use to register a sequence of tensors whose element type is TensorElemType.
852  *
853  * \param CPPRuntime - We always use TensorSeq
854  *
855  * \param TensorElemType - one of the primitive types
856  *
857  * \details Usage: ORT_REGISTER_SEQ_TENSOR_TYPE()
858  * The type is required to have value_type defined
859  */
860 template <typename TensorElemType>
862  public:
864  "Requires one of the tensor fundamental types");
865 
866  static MLDataType Type();
867 
868  /// Return a MLDataType representing the element-type
869  MLDataType GetElementType() const override {
870  return DataTypeImpl::GetType<TensorElemType>();
871  }
872 
873  private:
875  using namespace data_types_internal;
876  SequenceTypeHelper::Set(SequenceTypeHelper::GetElemType<TensorElemType>()->GetTypeProto(),
877  MutableTypeProto());
878  }
879 };
880 
881 /**
882  * \brief OpaqueType
883  *
884  * \tparam T - C++ runtime type that implements the Opaque type
885  *
886  * \tparam const char D[] - domain must be extern to be unique
887  *
888  * \tparam const char N[] - name must be extern to be unique
889  *
890  * \details Only one CPP type can be associated with a particular
891  * OpaqueType registration
892  *
893  */
894 template <typename T, const char D[], const char N[]>
895 class OpaqueType : public NonTensorType<T> {
896  public:
897  static MLDataType Type();
898 
899  bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override {
900  return this->IsOpaqueCompatible(type_proto);
901  }
902 
903  void FromDataContainer(const void* data, size_t data_size, OrtValue& output) const override {
904  NonTensorTypeConverter<T>::FromContainer(this, data, data_size, output);
905  }
906 
907  void ToDataContainer(const OrtValue& input, size_t data_size, void* data) const override {
908  NonTensorTypeConverter<T>::ToContainer(input, data_size, data);
909  }
910 
911  private:
912  OpaqueType() {
914  }
915 };
916 
917 /**
918  * \brief PrimitiveDataTypeBase
919  * Base class for primitive Tensor contained types
920  *
921  * \details This class contains an integer constant that can be
922  * used for input data type dispatching. This class also stores the number of subelements per size units.
923  * Example: For int4, the size unit is 1 byte and the number of subelements is 2.
924  *
925  */
927  public:
928  bool IsCompatible(const ONNX_NAMESPACE::TypeProto&) const override {
929  return false;
930  }
931 
932  const ONNX_NAMESPACE::TypeProto* GetTypeProto() const final {
933  return nullptr;
934  }
935 
936  int32_t GetDataType() const {
937  return data_type_;
938  }
939 
940  int32_t GetNumSubElems() const {
941  return num_sub_elems_;
942  }
943 
944  bool HasSubElems() const {
945  return num_sub_elems_ > 1;
946  }
947 
948  protected:
949  PrimitiveDataTypeBase(size_t size, int32_t data_type, int32_t num_sub_elems)
950  : DataTypeImpl{GeneralType::kPrimitive, size}, data_type_{data_type}, num_sub_elems_{num_sub_elems} {}
951 
952  private:
953  const int32_t data_type_;
954  const int32_t num_sub_elems_; // > 1 for subbyte primitives, 1 for normal primitives.
955 };
956 
957 /**
958  * \brief PrimitiveDataType
959  * Typed specialization for primitive types.
960  * Concrete instances of this class are used by Tensor.
961  *
962  * \param T - primitive data type
963  *
964  */
965 template <typename T>
967  private:
968  static void Delete(void* p) {
969  delete static_cast<T*>(p);
970  }
971 
972  public:
973  static MLDataType Type();
974 
975  DeleteFunc GetDeleteFunc() const override {
976  return &Delete;
977  }
978 
979  private:
980  explicit PrimitiveDataType(int32_t num_sub_elems)
981  : PrimitiveDataTypeBase{sizeof(T),
982  utils::ToTensorProtoElementType<T>(), num_sub_elems} {
983  }
984 };
985 
987  return IsTensorType() ? static_cast<const TensorTypeBase*>(this) : nullptr;
988 }
989 
991  return IsTensorSequenceType() ? static_cast<const SequenceTensorTypeBase*>(this) : nullptr;
992 }
993 
994 #if !defined(DISABLE_SPARSE_TENSORS)
996  return IsSparseTensorType() ? static_cast<const SparseTensorTypeBase*>(this) : nullptr;
997 }
998 #endif
999 
1000 #if !defined(DISABLE_OPTIONAL_TYPE)
1002  return IsOptionalType() ? static_cast<const OptionalTypeBase*>(this) : nullptr;
1003 }
1004 #endif
1005 
1007  return IsNonTensorType() ? static_cast<const NonTensorTypeBase*>(this) : nullptr;
1008 }
1009 
1011  return IsPrimitiveDataType() ? static_cast<const PrimitiveDataTypeBase*>(this) : nullptr;
1012 }
1013 
// These registration macros exist because an explicit specialization of a
// member function template (or of a class template's member) can only be
// declared at the enclosing namespace scope. Pre-instantiating a type at
// registration time therefore cannot be done through an ordinary call; each
// macro emits the required specializations instead.
//
// ORT_REGISTER_TENSOR_TYPE: defines TensorType<ELEM>::Type() backed by a
// function-local static, and routes DataTypeImpl::GetTensorType<ELEM>() to it.
#define ORT_REGISTER_TENSOR_TYPE(ELEM)            \
  template <>                                     \
  MLDataType TensorType<ELEM>::Type() {           \
    static TensorType<ELEM> instance;             \
    return &instance;                             \
  }                                               \
  template <>                                     \
  MLDataType DataTypeImpl::GetTensorType<ELEM>() { \
    return TensorType<ELEM>::Type();              \
  }
1029 
#if !defined(DISABLE_SPARSE_TENSORS)
// Defines SparseTensorType<ELEM>::Type() backed by a function-local static, and
// routes DataTypeImpl::GetSparseTensorType<ELEM>() to it.
#define ORT_REGISTER_SPARSE_TENSOR_TYPE(ELEM)            \
  template <>                                            \
  MLDataType SparseTensorType<ELEM>::Type() {            \
    static SparseTensorType<ELEM> instance;              \
    return &instance;                                    \
  }                                                      \
  template <>                                            \
  MLDataType DataTypeImpl::GetSparseTensorType<ELEM>() { \
    return SparseTensorType<ELEM>::Type();               \
  }
#endif
1042 
// Defines OptionalType<ORT_TYPE, ELEM>::Type() backed by a function-local static,
// and routes DataTypeImpl::GetOptionalType<ORT_TYPE, ELEM>() to it.
#define ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, ELEM)              \
  template <>                                                   \
  MLDataType OptionalType<ORT_TYPE, ELEM>::Type() {             \
    static OptionalType<ORT_TYPE, ELEM> instance;               \
    return &instance;                                           \
  }                                                             \
  template <>                                                   \
  MLDataType DataTypeImpl::GetOptionalType<ORT_TYPE, ELEM>() {  \
    return OptionalType<ORT_TYPE, ELEM>::Type();                \
  }
1053 
#if !defined(DISABLE_ML_OPS)
// Defines MapType<CPP_MAP>::Type() backed by a function-local static, and makes
// DataTypeImpl::GetType<CPP_MAP>() return it.
#define ORT_REGISTER_MAP(CPP_MAP)                \
  template <>                                    \
  MLDataType MapType<CPP_MAP>::Type() {          \
    static MapType<CPP_MAP> instance;            \
    return &instance;                            \
  }                                              \
  template <>                                    \
  MLDataType DataTypeImpl::GetType<CPP_MAP>() {  \
    return MapType<CPP_MAP>::Type();             \
  }
#endif
1066 
// Defines SequenceType<CPP_SEQ>::Type() backed by a function-local static, and
// makes DataTypeImpl::GetType<CPP_SEQ>() return it.
#define ORT_REGISTER_SEQ(CPP_SEQ)                \
  template <>                                    \
  MLDataType SequenceType<CPP_SEQ>::Type() {     \
    static SequenceType<CPP_SEQ> instance;       \
    return &instance;                            \
  }                                              \
  template <>                                    \
  MLDataType DataTypeImpl::GetType<CPP_SEQ>() {  \
    return SequenceType<CPP_SEQ>::Type();        \
  }
1077 
// Defines SequenceTensorType<ELEM>::Type() backed by a function-local static, and
// routes DataTypeImpl::GetSequenceTensorType<ELEM>() to it.
#define ORT_REGISTER_SEQ_TENSOR_TYPE(ELEM)                  \
  template <>                                               \
  MLDataType SequenceTensorType<ELEM>::Type() {             \
    static SequenceTensorType<ELEM> instance;               \
    return &instance;                                       \
  }                                                         \
  template <>                                               \
  MLDataType DataTypeImpl::GetSequenceTensorType<ELEM>() {  \
    return SequenceTensorType<ELEM>::Type();                \
  }
1088 
// Registers an ordinary primitive type (one element per sizeof(TYPE) storage unit).
// The expansion is exactly ORT_REGISTER_PRIM_SUBBYTE_TYPE with NUM_SUB_ELEMS == 1:
// it defines PrimitiveDataType<TYPE>::Type() and DataTypeImpl::GetType<TYPE>().
// Delegating instead of duplicating the body keeps the two registration paths
// from drifting apart. Macros expand at the point of use, so it does not matter
// that ORT_REGISTER_PRIM_SUBBYTE_TYPE is #defined further down this header.
#define ORT_REGISTER_PRIM_TYPE(TYPE) ORT_REGISTER_PRIM_SUBBYTE_TYPE(TYPE, 1)
1099 
// Registers a sub-byte primitive: NUM_SUB_ELEMS packed elements share a single
// storage unit of sizeof(TYPE) bytes.
// Examples:
// - Int4x2 stores 2 packed 4-bit elements in 1 byte: ORT_*_SUBBYTE_TYPE(Int4x2, 2)
// - [not supported] Int3x8 could store 8 packed 3-bit elements in 3 bytes: ORT_*_SUBBYTE_TYPE(Int3x8, 8)
#define ORT_REGISTER_PRIM_SUBBYTE_TYPE(TYPE, NUM_SUB_ELEMS)   \
  template <>                                               \
  MLDataType PrimitiveDataType<TYPE>::Type() {              \
    static PrimitiveDataType<TYPE> instance(NUM_SUB_ELEMS); \
    return &instance;                                       \
  }                                                         \
  template <>                                               \
  MLDataType DataTypeImpl::GetType<TYPE>() {                \
    return PrimitiveDataType<TYPE>::Type();                 \
  }
1114 
// Defines OpaqueType<CPPType, Domain, Name>::Type() backed by a function-local
// static, and makes DataTypeImpl::GetType<CPPType>() return it.
#define ORT_REGISTER_OPAQUE_TYPE(CPPType, Domain, Name)      \
  template <>                                                \
  MLDataType OpaqueType<CPPType, Domain, Name>::Type() {     \
    static OpaqueType<CPPType, Domain, Name> instance;       \
    return &instance;                                        \
  }                                                          \
  template <>                                                \
  MLDataType DataTypeImpl::GetType<CPPType>() {              \
    return OpaqueType<CPPType, Domain, Name>::Type();        \
  }
1125 } // namespace onnxruntime
void AssignOpaqueDomainName(const char *domain, const char *name, ONNX_NAMESPACE::TypeProto &proto)
OpaqueTypes helpers.
std::vector< int64_t > VectorInt64
Definition: data_types.h:50
static void RegisterDataType(MLDataType)
static const TensorTypeBase * TensorTypeFromONNXEnum(int type)
PrimitiveDataTypeBase(size_t size, int32_t data_type, int32_t num_sub_elems)
Definition: data_types.h:949
virtual const ONNX_NAMESPACE::TypeProto * GetTypeProto() const =0
Retrieves an instance of TypeProto for a given MLDataType.
static const char * ToString(MLDataType type)
bool IsTensorSequenceType() const
Definition: data_types.h:123
Base class for MLDataType.
Definition: data_types.h:77
static const std::vector< MLDataType > & AllNumericTensorTypesIRv9()
size_t Size() const
Definition: data_types.h:107
bool IsCompatible(const ONNX_NAMESPACE::TypeProto &type_proto) const override
this API will be used to check type compatibility at runtime
Definition: data_types.h:774
virtual MLDataType GetElementType() const
Definition: data_types.h:549
static void ToContainer(const OrtValue &, size_t, void *)
Definition: data_types.h:676
static const std::vector< MLDataType > & AllNumericTensorTypesIRv4()
void
Definition: png.h:1083
void FromDataContainer(const void *data, size_t data_size, OrtValue &output) const override
Definition: data_types.h:903
static const std::vector< MLDataType > & AllOptionalTypesIRv9()
SequenceTensorTypeBase serves as a base type class for Tensor sequences. Akin to TensorTypeBase. Runtime representation is always TensorSeq.
Definition: data_types.h:819
ONNX_NAMESPACE::TypeProto & MutableTypeProto()
static void FromContainer(MLDataType, const void *, size_t, OrtValue &)
Definition: data_types.h:673
static const std::vector< MLDataType > & AllOptionalAndTensorAndSequenceTensorTypes()
void ToDataContainer(const OrtValue &input, size_t data_size, void *data) const override
Definition: data_types.h:907
GLsizei const GLfloat * value
Definition: glcorearb.h:824
std::map< std::string, float > MapStringToFloat
Definition: data_types.h:36
virtual bool IsCompatible(const ONNX_NAMESPACE::TypeProto &type_proto) const =0
this API will be used to check type compatibility at runtime
static const std::vector< MLDataType > & AllTensorTypesIRv9()
static MLDataType GetTensorType()
std::map< int64_t, int64_t > MapInt64ToInt64
Definition: data_types.h:39
#define ORT_NOT_IMPLEMENTED(...)
Definition: common.h:166
MapType. Use this type to register mapping types.
Definition: data_types.h:767
void CopyMutableSeqElement(const ONNX_NAMESPACE::TypeProto &, ONNX_NAMESPACE::TypeProto &)
Sequence helpers.
static const std::vector< MLDataType > & AllTensorAndSequenceTensorAndOptionalTypesIRv4()
#define ORT_ENFORCE(condition,...)
Definition: common.h:172
MLDataType GetElementType() const override
Return a MLDataType representing the element-type.
Definition: data_types.h:576
CreateFunc GetCreateFunc() const override
Definition: data_types.h:747
static const std::vector< MLDataType > & AllOptionalTypes()
const ONNX_NAMESPACE::TypeProto * GetTypeProto() const final
Retrieves an instance of TypeProto for a given MLDataType.
Definition: data_types.h:932
DeleteFunc GetDeleteFunc() const override
Definition: data_types.h:598
virtual MLDataType GetElementType() const
Definition: data_types.h:825
Common base-class for all sparse-tensors (with different element types).
Definition: data_types.h:539
static const std::vector< MLDataType > & AllTensorTypesIRv4()
static const std::vector< MLDataType > & AllOptionalTypesIRv4()
static MLDataType GetSequenceTensorType()
void CopyMutableOptionalElement(const ONNX_NAMESPACE::TypeProto &, ONNX_NAMESPACE::TypeProto &)
Optional helpers.
const ONNX_NAMESPACE::TypeProto * GetTypeProto() const override
Retrieves an instance of TypeProto for a given MLDataType.
static const std::vector< MLDataType > & AllTensorAndSequenceTensorTypesIRv9()
static MLDataType Type()
static const std::vector< MLDataType > & AllIEEEFloatTensorTypes()
static const std::vector< MLDataType > & AllFixedSizeTensorTypesIRv4()
static const std::vector< MLDataType > & AllFixedSizeTensorTypes()
static void Set(const onnx::TypeProto *elem_proto, ONNX_NAMESPACE::TypeProto &proto)
Definition: data_types.h:413
static const std::vector< MLDataType > & AllTensorTypes()
const SparseTensorTypeBase * AsSparseTensorType() const
Definition: data_types.h:995
All tensors base.
Definition: data_types.h:433
GLint GLint GLsizei GLint GLenum GLenum type
Definition: glcorearb.h:108
static MLDataType GetSparseTensorType()
static const std::vector< MLDataType > & AllSequenceTensorTypesIRv4()
static const std::vector< MLDataType > & AllTensorAndSequenceTensorTypes()
static const std::vector< MLDataType > & AllSequenceTensorTypes()
std::map< int64_t, float > MapInt64ToFloat
Definition: data_types.h:40
void(*)(void *) DeleteFunc
Definition: data_types.h:70
bool IsCompatible(const ONNX_NAMESPACE::TypeProto &type_proto) const override
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TensorTypeBase)
static void Set(const ONNX_NAMESPACE::TypeProto *elem_proto, ONNX_NAMESPACE::TypeProto &proto)
Definition: data_types.h:389
virtual MLDataType GetElementType() const
Definition: data_types.h:446
static const std::vector< MLDataType > & AllFixedSizeSequenceTensorTypes()
std::map< std::string, int64_t > MapStringToInt64
Definition: data_types.h:35
const GeneralType type_
Definition: data_types.h:89
const NonTensorTypeBase * AsNonTensorType() const
Definition: data_types.h:1006
Common base-class for all optional types.
Definition: data_types.h:592
static const std::vector< MLDataType > & AllTensorAndSequenceTensorTypesIRv4()
PrimitiveDataType Typed specialization for primitive types. Concrete instances of this class are used...
Definition: data_types.h:966
bool IsNonTensorType() const
Definition: data_types.h:135
const TensorTypeBase * AsTensorType() const
Definition: data_types.h:986
const PrimitiveDataTypeBase * AsPrimitiveDataType() const
Definition: data_types.h:1010
bool IsTensorType() const
Definition: data_types.h:119
MLDataType GetElementType() const override
Definition: data_types.h:645
static MLDataType GetDataType(const std::string &)
DeleteFunc GetDeleteFunc() const override
STATIC_INLINE uint64_t H(uint64_t x, uint64_t y, uint64_t mul, int r)
Definition: farmhash.h:762
PrimitiveDataTypeBase Base class for primitive Tensor contained types.
Definition: data_types.h:926
GLuint const GLchar * name
Definition: glcorearb.h:786
std::map< std::string, std::string > MapStringToString
Predefined registered types.
Definition: data_types.h:34
static const SequenceTensorTypeBase * SequenceTensorTypeFromONNXEnum(int type)
void CopyMutableMapValue(const ONNX_NAMESPACE::TypeProto &, ONNX_NAMESPACE::TypeProto &)
Map helpers.
static MLDataType GetOptionalType()
void *(*)() CreateFunc
Definition: data_types.h:71
const DataTypeImpl * MLDataType
Definition: data_types.h:68
static const std::vector< MLDataType > & AllSequenceTensorTypesIRv9()
static const std::vector< MLDataType > & AllTensorAndSequenceTensorAndOptionalTypes()
bool IsCompatible(const ONNX_NAMESPACE::TypeProto &type_proto) const override
this API will be used to check type compatibility at runtime
Definition: data_types.h:899
std::map< int64_t, double > MapInt64ToDouble
Definition: data_types.h:41
#define ORT_THROW(...)
Definition: common.h:162
static const std::vector< MLDataType > & AllFixedSizeSequenceTensorTypesIRv9()
static const std::vector< MLDataType > & AllNumericTensorTypes()
GLsizeiptr size
Definition: glcorearb.h:664
virtual MLDataType GetElementType() const
Definition: data_types.h:605
std::map< int64_t, std::string > MapInt64ToString
Definition: data_types.h:38
SequenceType. Use to register sequence for non-tensor types.
Definition: data_types.h:798
bool IsCompatible(const ONNX_NAMESPACE::TypeProto &type_proto) const override
this API will be used to check type compatibility at runtime
Definition: data_types.h:802
const OptionalTypeBase * AsOptionalType() const
Definition: data_types.h:1001
SequenceTensorType. Use to register sequences of tensors with a given element type.
Definition: data_types.h:861
LeafData & operator=(const LeafData &)=delete
DataTypeImpl(GeneralType type, size_t size)
Definition: data_types.h:93
bool IsSparseTensorType() const
Definition: data_types.h:127
MLDataType GetElementType() const override
Definition: data_types.h:486
virtual ~DataTypeImpl()=default
std::map< std::string, double > MapStringToDouble
Definition: data_types.h:37
GA_API const UT_StringHolder N
static void Set(ONNX_NAMESPACE::TensorProto_DataType element_type, ONNX_NAMESPACE::TypeProto &proto)
Definition: data_types.h:338
static const std::vector< MLDataType > & AllFixedSizeTensorTypesIRv9()
bool IsCompatible(const ONNX_NAMESPACE::TypeProto &) const override
this API will be used to check type compatibility at runtime
Definition: data_types.h:928
static const std::vector< MLDataType > & AllFixedSizeTensorAndSequenceTensorTypes()
Tensor type. This type does not have a C++ type associated with it at registration time except the el...
Definition: data_types.h:477
static const std::vector< MLDataType > & AllOptionalAndTensorAndSequenceTensorTypesIRv4()
static const std::vector< MLDataType > & AllFixedSizeTensorAndSequenceTensorTypesIRv4()
DeleteFunc GetDeleteFunc() const override
Definition: data_types.h:743
DeleteFunc GetDeleteFunc() const override
Definition: data_types.h:975
bool IsPrimitiveDataType() const
Definition: data_types.h:139
static const std::vector< MLDataType > & AllOptionalAndTensorAndSequenceTensorTypesIRv9()
static const std::vector< MLDataType > & AllTensorAndSequenceTensorAndOptionalTypesIRv9()
static const std::vector< MLDataType > & AllFixedSizeSequenceTensorTypesIRv4()
GLsizei GLenum GLenum * types
Definition: glcorearb.h:2542
static MLDataType Type()
std::ostream & operator<<(std::ostream &out, AllocKind alloc_kind)
Base type for all non-tensors, maps, sequences and opaques.
Definition: data_types.h:684
const SequenceTensorTypeBase * AsSequenceTensorType() const
Definition: data_types.h:990
std::vector< std::string > VectorString
Definition: data_types.h:49
static MLDataType TypeFromProto(const ONNX_NAMESPACE::TypeProto &proto)
static void Set(ONNX_NAMESPACE::TensorProto_DataType element_type, ONNX_NAMESPACE::TypeProto &proto)
Definition: data_types.h:346
static void Set(ONNX_NAMESPACE::TensorProto_DataType key_type, const ONNX_NAMESPACE::TypeProto *value_proto, ONNX_NAMESPACE::TypeProto &proto)
Definition: data_types.h:367
std::vector< MapInt64ToFloat > VectorMapInt64ToFloat
Definition: data_types.h:45
static const SparseTensorTypeBase * SparseTensorTypeFromONNXEnum(int type)
std::vector< MapStringToFloat > VectorMapStringToFloat
Definition: data_types.h:44
static const std::vector< MLDataType > & AllFixedSizeTensorAndSequenceTensorTypesIRv9()
virtual DeleteFunc GetDeleteFunc() const =0
Provide a specialization for your C++ Non-tensor type so your implementation FromDataTypeContainer/To...
Definition: data_types.h:672
Definition: format.h:1821
bool IsOptionalType() const
Definition: data_types.h:131
MLDataType GetElementType() const override
Return a MLDataType representing the element-type.
Definition: data_types.h:869
static MLDataType GetType()