HDK
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
data_types_internal.h
Go to the documentation of this file.
1 // Copyright (c) Microsoft Corporation. All rights reserved.
2 // Licensed under the MIT License.
3 
4 #pragma once
5 
#include <array>
#include <cassert>
#include <cstdint>
#include <map>
#include <string>
#include <type_traits>
#include <vector>

#include "boost/mp11.hpp"

#include "core/common/common.h"

#ifndef SHARED_PROVIDER
#include "core/common/type_list.h"
#include "core/graph/onnx_protobuf.h"
#endif
21 #endif
22 
23 namespace onnxruntime {
24 namespace utils {
25 
26 // The following primitives are strongly recommended for switching on tensor input datatypes for
27 // kernel implementations.
28 //
29 // 1) If you need to handle all of the primitive tensor contained datatypes, the best choice would be macros
30 // DispatchOnTensorType or DispatchOnTensorTypeWithReturn. Use inline wrappers so your function can be invoked as function<T>().
31 // 2) if you have a few types, use Tensor.IsDataType<T>()/IsDataTypeString() or use utils::IsPrimitiveDataType<T>()
32 // if you have a standalone MLDatatType with a sequence of if/else statements.
33 // 3) For something in between, we suggest to use CallDispatcher pattern.
34 //
35 // Invoking DataTypeImpl::GetType<T>() for switching on input types is discouraged and should be avoided.
36 // Every primitive type carries with it an integer constant that can be used for quick switching on types.
37 
38 #if !defined(DISABLE_FLOAT8_TYPES)
39 
40 #define DispatchOnTensorType(tensor_type, function, ...) \
41  switch (tensor_type->AsPrimitiveDataType()->GetDataType()) { \
42  case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: \
43  function<float>(__VA_ARGS__); \
44  break; \
45  case ONNX_NAMESPACE::TensorProto_DataType_BOOL: \
46  function<bool>(__VA_ARGS__); \
47  break; \
48  case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: \
49  function<double>(__VA_ARGS__); \
50  break; \
51  case ONNX_NAMESPACE::TensorProto_DataType_STRING: \
52  function<std::string>(__VA_ARGS__); \
53  break; \
54  case ONNX_NAMESPACE::TensorProto_DataType_INT8: \
55  function<int8_t>(__VA_ARGS__); \
56  break; \
57  case ONNX_NAMESPACE::TensorProto_DataType_UINT8: \
58  function<uint8_t>(__VA_ARGS__); \
59  break; \
60  case ONNX_NAMESPACE::TensorProto_DataType_INT16: \
61  function<int16_t>(__VA_ARGS__); \
62  break; \
63  case ONNX_NAMESPACE::TensorProto_DataType_UINT16: \
64  function<uint16_t>(__VA_ARGS__); \
65  break; \
66  case ONNX_NAMESPACE::TensorProto_DataType_INT32: \
67  function<int32_t>(__VA_ARGS__); \
68  break; \
69  case ONNX_NAMESPACE::TensorProto_DataType_UINT32: \
70  function<uint32_t>(__VA_ARGS__); \
71  break; \
72  case ONNX_NAMESPACE::TensorProto_DataType_INT64: \
73  function<int64_t>(__VA_ARGS__); \
74  break; \
75  case ONNX_NAMESPACE::TensorProto_DataType_UINT64: \
76  function<uint64_t>(__VA_ARGS__); \
77  break; \
78  case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: \
79  function<MLFloat16>(__VA_ARGS__); \
80  break; \
81  case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: \
82  function<BFloat16>(__VA_ARGS__); \
83  break; \
84  case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN: \
85  function<Float8E4M3FN>(__VA_ARGS__); \
86  break; \
87  case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ: \
88  function<Float8E4M3FNUZ>(__VA_ARGS__); \
89  break; \
90  case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2: \
91  function<Float8E5M2>(__VA_ARGS__); \
92  break; \
93  case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ: \
94  function<Float8E5M2FNUZ>(__VA_ARGS__); \
95  break; \
96  case ONNX_NAMESPACE::TensorProto_DataType_INT4: \
97  function<Int4x2>(__VA_ARGS__); \
98  break; \
99  case ONNX_NAMESPACE::TensorProto_DataType_UINT4: \
100  function<UInt4x2>(__VA_ARGS__); \
101  break; \
102  default: \
103  ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \
104  }
105 
// Same as DispatchOnTensorType, but assigns the result of function<T>(...)
// to retval. Throws (via ORT_ENFORCE) on an unrecognized element type.
// Wrapped in do { } while (0) so the macro is a single statement.
#define DispatchOnTensorTypeWithReturn(tensor_type, retval, function, ...) \
  do {                                                                     \
    switch (tensor_type->AsPrimitiveDataType()->GetDataType()) {           \
      case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:                     \
        retval = function<float>(__VA_ARGS__);                             \
        break;                                                             \
      case ONNX_NAMESPACE::TensorProto_DataType_BOOL:                      \
        retval = function<bool>(__VA_ARGS__);                              \
        break;                                                             \
      case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:                    \
        retval = function<double>(__VA_ARGS__);                            \
        break;                                                             \
      case ONNX_NAMESPACE::TensorProto_DataType_STRING:                    \
        retval = function<std::string>(__VA_ARGS__);                       \
        break;                                                             \
      case ONNX_NAMESPACE::TensorProto_DataType_INT8:                      \
        retval = function<int8_t>(__VA_ARGS__);                            \
        break;                                                             \
      case ONNX_NAMESPACE::TensorProto_DataType_UINT8:                     \
        retval = function<uint8_t>(__VA_ARGS__);                           \
        break;                                                             \
      case ONNX_NAMESPACE::TensorProto_DataType_UINT16:                    \
        retval = function<uint16_t>(__VA_ARGS__);                          \
        break;                                                             \
      case ONNX_NAMESPACE::TensorProto_DataType_INT16:                     \
        retval = function<int16_t>(__VA_ARGS__);                           \
        break;                                                             \
      case ONNX_NAMESPACE::TensorProto_DataType_INT32:                     \
        retval = function<int32_t>(__VA_ARGS__);                           \
        break;                                                             \
      case ONNX_NAMESPACE::TensorProto_DataType_UINT32:                    \
        retval = function<uint32_t>(__VA_ARGS__);                          \
        break;                                                             \
      case ONNX_NAMESPACE::TensorProto_DataType_INT64:                     \
        retval = function<int64_t>(__VA_ARGS__);                           \
        break;                                                             \
      case ONNX_NAMESPACE::TensorProto_DataType_UINT64:                    \
        retval = function<uint64_t>(__VA_ARGS__);                          \
        break;                                                             \
      case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:                   \
        retval = function<MLFloat16>(__VA_ARGS__);                         \
        break;                                                             \
      case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16:                  \
        retval = function<BFloat16>(__VA_ARGS__);                          \
        break;                                                             \
      case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN:              \
        retval = function<Float8E4M3FN>(__VA_ARGS__);                      \
        break;                                                             \
      case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ:            \
        retval = function<Float8E4M3FNUZ>(__VA_ARGS__);                    \
        break;                                                             \
      case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2:                \
        retval = function<Float8E5M2>(__VA_ARGS__);                        \
        break;                                                             \
      case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ:            \
        retval = function<Float8E5M2FNUZ>(__VA_ARGS__);                    \
        break;                                                             \
      case ONNX_NAMESPACE::TensorProto_DataType_INT4:                      \
        retval = function<Int4x2>(__VA_ARGS__);                            \
        break;                                                             \
      case ONNX_NAMESPACE::TensorProto_DataType_UINT4:                     \
        retval = function<UInt4x2>(__VA_ARGS__);                           \
        break;                                                             \
      default:                                                             \
        ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type);        \
    }                                                                      \
  } while (0)
171 
172 #else
173 
// Float8 types disabled: dispatch over the remaining primitive element types.
// Wrapped in do { } while (0) so the macro expands to a single statement and
// is safe inside an unbraced if/else.
#define DispatchOnTensorType(tensor_type, function, ...)               \
  do {                                                                 \
    switch (tensor_type->AsPrimitiveDataType()->GetDataType()) {       \
      case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:                 \
        function<float>(__VA_ARGS__);                                  \
        break;                                                         \
      case ONNX_NAMESPACE::TensorProto_DataType_BOOL:                  \
        function<bool>(__VA_ARGS__);                                   \
        break;                                                         \
      case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:                \
        function<double>(__VA_ARGS__);                                 \
        break;                                                         \
      case ONNX_NAMESPACE::TensorProto_DataType_STRING:                \
        function<std::string>(__VA_ARGS__);                            \
        break;                                                         \
      case ONNX_NAMESPACE::TensorProto_DataType_INT8:                  \
        function<int8_t>(__VA_ARGS__);                                 \
        break;                                                         \
      case ONNX_NAMESPACE::TensorProto_DataType_UINT8:                 \
        function<uint8_t>(__VA_ARGS__);                                \
        break;                                                         \
      case ONNX_NAMESPACE::TensorProto_DataType_INT16:                 \
        function<int16_t>(__VA_ARGS__);                                \
        break;                                                         \
      case ONNX_NAMESPACE::TensorProto_DataType_UINT16:                \
        function<uint16_t>(__VA_ARGS__);                               \
        break;                                                         \
      case ONNX_NAMESPACE::TensorProto_DataType_INT32:                 \
        function<int32_t>(__VA_ARGS__);                                \
        break;                                                         \
      case ONNX_NAMESPACE::TensorProto_DataType_UINT32:                \
        function<uint32_t>(__VA_ARGS__);                               \
        break;                                                         \
      case ONNX_NAMESPACE::TensorProto_DataType_INT64:                 \
        function<int64_t>(__VA_ARGS__);                                \
        break;                                                         \
      case ONNX_NAMESPACE::TensorProto_DataType_UINT64:                \
        function<uint64_t>(__VA_ARGS__);                               \
        break;                                                         \
      case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:               \
        function<MLFloat16>(__VA_ARGS__);                              \
        break;                                                         \
      case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16:              \
        function<BFloat16>(__VA_ARGS__);                               \
        break;                                                         \
      case ONNX_NAMESPACE::TensorProto_DataType_INT4:                  \
        function<Int4x2>(__VA_ARGS__);                                 \
        break;                                                         \
      case ONNX_NAMESPACE::TensorProto_DataType_UINT4:                 \
        function<UInt4x2>(__VA_ARGS__);                                \
        break;                                                         \
      default:                                                         \
        ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type);    \
    }                                                                  \
  } while (0)
227 
// Float8 types disabled: assigning variant of DispatchOnTensorType.
// Wrapped in do { } while (0) so the macro is a single statement.
#define DispatchOnTensorTypeWithReturn(tensor_type, retval, function, ...) \
  do {                                                                     \
    switch (tensor_type->AsPrimitiveDataType()->GetDataType()) {           \
      case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:                     \
        retval = function<float>(__VA_ARGS__);                             \
        break;                                                             \
      case ONNX_NAMESPACE::TensorProto_DataType_BOOL:                      \
        retval = function<bool>(__VA_ARGS__);                              \
        break;                                                             \
      case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:                    \
        retval = function<double>(__VA_ARGS__);                            \
        break;                                                             \
      case ONNX_NAMESPACE::TensorProto_DataType_STRING:                    \
        retval = function<std::string>(__VA_ARGS__);                       \
        break;                                                             \
      case ONNX_NAMESPACE::TensorProto_DataType_INT8:                      \
        retval = function<int8_t>(__VA_ARGS__);                            \
        break;                                                             \
      case ONNX_NAMESPACE::TensorProto_DataType_UINT8:                     \
        retval = function<uint8_t>(__VA_ARGS__);                           \
        break;                                                             \
      case ONNX_NAMESPACE::TensorProto_DataType_UINT16:                    \
        retval = function<uint16_t>(__VA_ARGS__);                          \
        break;                                                             \
      case ONNX_NAMESPACE::TensorProto_DataType_INT16:                     \
        retval = function<int16_t>(__VA_ARGS__);                           \
        break;                                                             \
      case ONNX_NAMESPACE::TensorProto_DataType_INT32:                     \
        retval = function<int32_t>(__VA_ARGS__);                           \
        break;                                                             \
      case ONNX_NAMESPACE::TensorProto_DataType_UINT32:                    \
        retval = function<uint32_t>(__VA_ARGS__);                          \
        break;                                                             \
      case ONNX_NAMESPACE::TensorProto_DataType_INT64:                     \
        retval = function<int64_t>(__VA_ARGS__);                           \
        break;                                                             \
      case ONNX_NAMESPACE::TensorProto_DataType_UINT64:                    \
        retval = function<uint64_t>(__VA_ARGS__);                          \
        break;                                                             \
      case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:                   \
        retval = function<MLFloat16>(__VA_ARGS__);                         \
        break;                                                             \
      case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16:                  \
        retval = function<BFloat16>(__VA_ARGS__);                          \
        break;                                                             \
      case ONNX_NAMESPACE::TensorProto_DataType_INT4:                      \
        retval = function<Int4x2>(__VA_ARGS__);                            \
        break;                                                             \
      case ONNX_NAMESPACE::TensorProto_DataType_UINT4:                     \
        retval = function<UInt4x2>(__VA_ARGS__);                           \
        break;                                                             \
      default:                                                             \
        ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type);        \
    }                                                                      \
  } while (0)
281 
282 #endif
283 
284 ////////////////////////////////////////////////////////////////////////////////
285 /// Use the following primitives if you have a few types to switch on so you
286 // can write a short sequence of if/else statements.
287 
288 // This is a frequently used check so we make a separate utility function.
289 inline bool IsDataTypeString(MLDataType dt_type) {
290  auto prim_type = dt_type->AsPrimitiveDataType();
291  return (prim_type != nullptr && prim_type->GetDataType() == ONNX_NAMESPACE::TensorProto_DataType_STRING);
292 }
293 
294 // Test if MLDataType is a concrete type of PrimitiveDataTypeBase
295 // and it is T
296 template <class T>
297 inline bool IsPrimitiveDataType(MLDataType dt_type) {
298  auto prim_type = dt_type->AsPrimitiveDataType();
299  return (prim_type != nullptr && prim_type->GetDataType() == ToTensorProtoElementType<T>());
300 }
301 
302 // Use after AsPrimitiveDataType() is successful
303 // Check if PrimitiveDataTypeBase is of type T
304 template <class T>
305 inline bool IsPrimitiveDataType(const PrimitiveDataTypeBase* prim_type) {
306  assert(prim_type != nullptr);
307  return prim_type->GetDataType() == ToTensorProtoElementType<T>();
308 }
309 
310 // This implementation contains a workaround for GCC bug https://gcc.gnu.org/bugzilla/show_bug.cgi?id=47226
311 // GCC until very recently does not support template parameter pack expansion within lambda context.
312 namespace mltype_dispatcher_internal {
313 
314 // T - type handled by this helper
316  int32_t dt_type_; // Type currently dispatched
317  size_t called_;
318 
319  public:
320  explicit CallableDispatchableHelper(int32_t dt_type) noexcept : dt_type_(dt_type), called_(0) {}
321 
322  // Must return integer to be in a expandable context
323  template <class T, class Fn, class... Args>
324  int Invoke(Fn&& fn, Args&&... args) {
325  if (utils::ToTensorProtoElementType<T>() == dt_type_) {
326  std::forward<Fn>(fn)(std::forward<Args>(args)...);
327  ++called_;
328  }
329  return 0;
330  }
331 
332  void CheckCalledOnce() const {
333  ORT_ENFORCE(called_ == 1, "Unsupported data type: ", dt_type_);
334  }
335 };
336 
337 // Default policy is to throw an exception.
338 // Other policies may set the second result argument accordingly.
339 template <class Ret>
341  void operator()(int32_t dt_type, Ret& /*result*/) const {
342  ORT_THROW("Unsupported data type: ", dt_type);
343  }
344 };
345 
346 // Helper with the result type
347 template <class Ret, class UnsupportedPolicy>
349  int32_t dt_type_; // Type currently dispatched
350  size_t called_;
351  Ret result_;
352 
353  public:
354  explicit CallableDispatchableRetHelper(int32_t dt_type) noexcept : dt_type_(dt_type), called_(0), result_() {}
355 
356  Ret Get() {
357  // No type was invoked
358  if (called_ == 0) {
359  UnsupportedPolicy()(dt_type_, result_);
360  }
361  return result_;
362  }
363 
364  // Must return integer to be in a expandable context
365  template <class T, class Fn, class... Args>
366  int Invoke(Fn&& fn, Args&&... args) {
367  if (utils::ToTensorProtoElementType<T>() == dt_type_) {
368  result_ = std::forward<Fn>(fn)(std::forward<Args>(args)...);
369  ++called_;
370  }
371  return 0;
372  }
373 };
374 
375 template <typename T>
377  std::integral_constant<ONNX_NAMESPACE::TensorProto_DataType, ToTensorProtoElementType<T>()>;
378 
380  std::integral_constant<ONNX_NAMESPACE::TensorProto_DataType, ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED>;
381 
382 } // namespace mltype_dispatcher_internal
383 
384 /**
385  * This class helps to efficiently dispatch calls to implementation function
386  * objects with a tensor element type template argument.
387  *
388  * The constructor accepts a value corresponding to a tensor element type.
389  * For example, it can be obtained from:
390  * input_tensor->GetElementType()
391  *
392  * The Invoke member functions will instantiate and invoke the provided
393  * function object template, Fn. Fn must be default constructible. Fn must also
394  * have a tensor element type template argument. This type template argument
395  * will be the type that corresponds to the value given in the constructor.
396  * These functions accept and forward arbitrary function arguments. They ensure
397  * that Fn is called once with the type specified in the constructor.
398  *
399  * @tparam Types The types supported by the implementation. This should be a
400  * set of ONNX tensor element types that are supported by ORT.
401  */
402 template <typename... Types>
404  using SupportedTypeList = TypeList<Types...>;
405  using SupportedTensorProtoElementTypeList =
406  boost::mp11::mp_transform<
408 
409  static_assert(
410  boost::mp11::mp_and<
411  boost::mp11::mp_is_set<SupportedTensorProtoElementTypeList>,
412  boost::mp11::mp_not<
413  boost::mp11::mp_set_contains<
414  SupportedTensorProtoElementTypeList,
416  "Types must map to a unique set of ONNX tensor element data types supported by ORT.");
417 
418  int32_t dt_type_;
419 
420  public:
421  /**
422  * Constructor.
423  * @param dt_type The value corresponding to the tensor element type to be
424  * dispatched to. This can be obtained from
425  * input_tensor->GetElementType() or
426  * utils::ToTensorProtoElementType<T>().
427  */
428  explicit MLTypeCallDispatcher(int32_t dt_type) noexcept : dt_type_(dt_type) {}
429 
430  /**
431  * Invokes Fn<T> with the specified arguments.
432  *
433  * @tparam Fn The function object template.
434  * @tparam Args The argument types.
435  */
436  template <template <typename...> class Fn, typename... Args>
437  void Invoke(Args&&... args) const {
438  InvokeWithLeadingTemplateArgs<Fn, TypeList<>>(std::forward<Args>(args)...);
439  }
440 
441  /**
442  * Invokes Fn<..., T> with leading template arguments and the specified
443  * arguments.
444  *
445  * @tparam Fn The function object template.
446  * @tparam LeadingTemplateArgTypeList A type list of the leading template
447  * arguments.
448  * @tparam Args The argument types.
449  */
450  template <template <typename...> class Fn, typename LeadingTemplateArgTypeList, typename... Args>
451  void InvokeWithLeadingTemplateArgs(Args&&... args) const {
452  static_assert(
454  "LeadingTemplateArgTypeList must be a type list (e.g., onnxruntime::TypeList<T1, T2, ...>).");
455 
457 
458  // given LeadingTemplateArgTypeList is a type list L<U1, U2, ...>,
459  // call helper.Invoke() with Fn<U1, U2, ..., T> for each T in Types
460  static_cast<void>(std::array<int, sizeof...(Types)>{
461  helper.template Invoke<Types>(
462  boost::mp11::mp_apply<Fn, boost::mp11::mp_push_back<LeadingTemplateArgTypeList, Types>>(),
463  std::forward<Args>(args)...)...});
464 
465  // avoid "unused parameter" warning for the case where Types is empty
466  static_cast<void>(std::array<int, sizeof...(Args)>{(ORT_UNUSED_PARAMETER(args), 0)...});
467 
468  helper.CheckCalledOnce();
469  }
470 
471  /**
472  * Invokes Fn<T> with the specified arguments and returns the result.
473  *
474  * @tparam Ret The return type. Fn should return a type convertible to Ret.
475  * @tparam Fn The function object template.
476  * @tparam Args The argument types.
477  */
478  template <class Ret, template <typename...> class Fn, typename... Args>
479  Ret InvokeRet(Args&&... args) const {
482  std::forward<Args>(args)...);
483  }
484 
485  /**
486  * Invokes Fn<T> with the specified arguments and returns the result.
487  *
488  * @tparam Ret The return type. Fn should return a type convertible to Ret.
489  * @tparam Fn The function object template.
490  * @tparam UnsupportedPolicy The policy used to handle unsupported types.
491  * See mltype_dispatcher_internal::UnsupportedTypeDefaultPolicy
492  * for an example.
493  * @tparam Args The argument types.
494  */
495  template <class Ret, template <typename...> class Fn, class UnsupportedPolicy, typename... Args>
496  Ret InvokeRetWithUnsupportedPolicy(Args&&... args) const {
498  Ret, Fn, UnsupportedPolicy, TypeList<>>(
499  std::forward<Args>(args)...);
500  }
501 
502  /**
503  * Invokes Fn<..., T> with leading template arguments and the specified
504  * arguments and returns the result.
505  *
506  * @tparam Ret The return type. Fn should return a type convertible to Ret.
507  * @tparam Fn The function object template.
508  * @tparam LeadingTemplateArgTypeList A type list of the leading template
509  * arguments.
510  * @tparam Args The argument types.
511  */
512  template <class Ret, template <typename...> class Fn, typename LeadingTemplateArgTypeList, typename... Args>
513  Ret InvokeRetWithLeadingTemplateArgs(Args&&... args) const {
515  Ret, Fn, mltype_dispatcher_internal::UnsupportedTypeDefaultPolicy<Ret>, LeadingTemplateArgTypeList>(
516  std::forward<Args>(args)...);
517  }
518 
519  /**
520  * Invokes Fn<..., T> with leading template arguments and the specified
521  * arguments and returns the result.
522  *
523  * @tparam Ret The return type. Fn should return a type convertible to Ret.
524  * @tparam Fn The function object template.
525  * @tparam UnsupportedPolicy The policy used to handle unsupported types.
526  * See mltype_dispatcher_internal::UnsupportedTypeDefaultPolicy
527  * for an example.
528  * @tparam LeadingTemplateArgTypeList A type list of the leading template
529  * arguments.
530  * @tparam Args The argument types.
531  */
532  template <class Ret,
533  template <typename...> class Fn,
534  class UnsupportedPolicy,
535  typename LeadingTemplateArgTypeList,
536  typename... Args>
539 
540  // given LeadingTemplateArgTypeList is a type list L<U1, U2, ...>,
541  // call helper.Invoke() with Fn<U1, U2, ..., T> for each T in Types
542  static_cast<void>(std::array<int, sizeof...(Types)>{
543  helper.template Invoke<Types>(
544  boost::mp11::mp_apply<Fn, boost::mp11::mp_push_back<LeadingTemplateArgTypeList, Types>>(),
545  std::forward<Args>(args)...)...});
546 
547  // avoid "unused parameter" warning for the case where Types is empty
548  static_cast<void>(std::array<int, sizeof...(Args)>{(ORT_UNUSED_PARAMETER(args), 0)...});
549 
550  return helper.Get();
551  }
552 };
553 
// Convenience alias: given a type list L<T...>, produces the corresponding
// dispatcher type MLTypeCallDispatcher<T...>.
template <typename L>
using MLTypeCallDispatcherFromTypeList = boost::mp11::mp_apply<MLTypeCallDispatcher, L>;
557 
558 namespace data_types_internal {
559 
// Kind of container a TypeNode entry describes.
enum class ContainerType : uint16_t {
  kUndefined = 0,
  kTensor = 1,
  kMap = 2,
  kSequence = 3,
  kOpaque = 4,
  kOptional = 5
};

// One entry in the flattened description of a (possibly nested) type.
// type_ is the container kind (tensor, map, sequence, ...).
// prim_type_ is a TensorProto_DataType value whose meaning depends on type_:
//  - for a Tensor it is the contained element type
//  - for a Map it is the key type; the next entry describes the map value
//  - for a Sequence it is unused; the next entry describes the element
// A Tensor entry is always last, since it holds a primitive element type.
class TypeNode {
  ContainerType type_;
  uint16_t prim_type_;

 public:
  TypeNode(ContainerType type, int32_t prim_type) noexcept
      : type_{type}, prim_type_{static_cast<uint16_t>(prim_type)} {}

  bool IsType(ContainerType type) const noexcept { return type_ == type; }

  bool IsPrimType(int32_t prim_type) const noexcept {
    return prim_type_ == static_cast<uint16_t>(prim_type);
  }
};
595 
596 } // namespace data_types_internal
597 
598 ////////////////////////////////////////////////////////////////////
599 /// Provides generic interface to test whether MLDataType is a Sequence,
600 /// Map or an Opaque type including arbitrary recursive definitions
601 /// without querying DataTypeImpl::GetType<T> for all known complex types
602 
603 // T is a sequence contained element type
604 // If returns true then we know that the runtime
605 // representation is std::vector<T>
606 // T itself can be a runtime representation of another
607 // sequence, map, opaque type or a tensor
608 //
609 // That is it can be std::vector or a std::map
610 // If T is a primitive type sequence is tested whether it contains
611 // tensors of that type
612 //
613 // If T is an opaque type, then it is only tested to be opaque but not exactly
614 // a specific opaque type. To Test for a specific Opaque type use IsOpaqueType() below
615 //
616 // This class examines the supplied MLDataType and records
617 // its information in a vector so any subsequent checks for Sequences and Maps
618 // are quick.
620  using Cont = std::vector<data_types_internal::TypeNode>;
621  Cont types_;
622 
623  // Default IsContainerOfType is for Opaque type
624  template <class T>
625  struct IsContainerOfType {
626  static bool check(const Cont& c, size_t index) {
627  if (index >= c.size()) {
628  return false;
629  }
630  return c[index].IsType(data_types_internal::ContainerType::kOpaque);
631  }
632  };
633 
634  // Handles the case where sequence element is also a sequence
635  template <class T>
636  struct IsContainerOfType<std::vector<T>> {
637  static bool check(const Cont& c, size_t index) {
638  if (index >= c.size()) {
639  return false;
640  }
641  if (c[index].IsType(data_types_internal::ContainerType::kSequence)) {
642  ORT_ENFORCE(++index < c.size(), "Sequence is missing type entry for its element");
643  constexpr int32_t prim_type = ToTensorProtoElementType<T>();
644  // Check if this is a primitive type and it matches
645  if constexpr (prim_type != ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED) {
646  return c[index].IsType(data_types_internal::ContainerType::kTensor) &&
647  c[index].IsPrimType(prim_type);
648  } else {
649  // T is not primitive, check next entry for non-primitive proto
650  return IsContainerOfType<T>::check(c, index);
651  }
652  }
653  return false;
654  }
655  };
656 
657  template <class K, class V>
658  struct IsContainerOfType<std::map<K, V>> {
659  static bool check(const Cont& c, size_t index) {
660  static_assert(ToTensorProtoElementType<K>() != ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED,
661  "Map Key can not be a non-primitive type");
662  if (index >= c.size()) {
663  return false;
664  }
665  if (!c[index].IsType(data_types_internal::ContainerType::kMap)) {
666  return false;
667  }
668  constexpr int32_t key_type = ToTensorProtoElementType<K>();
669  if (!c[index].IsPrimType(key_type)) {
670  return false;
671  }
672  ORT_ENFORCE(++index < c.size(), "Map is missing type entry for its value");
673  constexpr int32_t val_type = ToTensorProtoElementType<V>();
674  if constexpr (val_type != ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED) {
675  return c[index].IsType(data_types_internal::ContainerType::kTensor) &&
676  c[index].IsPrimType(val_type);
677  } else
678  return IsContainerOfType<V>::check(c, index);
679  }
680  };
681 
682  public:
683  explicit ContainerChecker(MLDataType);
684  ~ContainerChecker() = default;
685 
686  bool IsMap() const noexcept {
687  assert(!types_.empty());
688  return types_[0].IsType(data_types_internal::ContainerType::kMap);
689  }
690 
691  bool IsSequence() const noexcept {
692  assert(!types_.empty());
693  return types_[0].IsType(data_types_internal::ContainerType::kSequence);
694  }
695 
696  template <class T>
697  bool IsSequenceOf() const {
698  assert(!types_.empty());
699  return IsContainerOfType<std::vector<T>>::check(types_, 0);
700  }
701 
702  template <class K, class V>
703  bool IsMapOf() const {
704  assert(!types_.empty());
705  return IsContainerOfType<std::map<K, V>>::check(types_, 0);
706  }
707 };
708 
709 bool IsOpaqueType(MLDataType ml_type, const char* domain, const char* name);
710 
711 } // namespace utils
712 } // namespace onnxruntime
type
Definition: core.h:556
typedef int(APIENTRYP RE_PFNGLXSWAPINTERVALSGIPROC)(int)
Base class for MLDataType.
Definition: data_types.h:77
MLTypeCallDispatcher(int32_t dt_type) noexcept
Ret InvokeRetWithUnsupportedPolicy(Args &&...args) const
Ret InvokeRetWithUnsupportedPolicyAndLeadingTemplateArgs(Args &&...args) const
GLsizei const GLfloat * value
Definition: glcorearb.h:824
bool IsType(ContainerType type) const noexcept
#define ORT_ENFORCE(condition,...)
Definition: common.h:172
Ret InvokeRetWithLeadingTemplateArgs(Args &&...args) const
bool IsPrimType(int32_t prim_type) const noexcept
boost::mp11::mp_apply< MLTypeCallDispatcher, L > MLTypeCallDispatcherFromTypeList
GLint GLint GLsizei GLint GLenum GLenum type
Definition: glcorearb.h:108
void InvokeWithLeadingTemplateArgs(Args &&...args) const
bool IsPrimitiveDataType(MLDataType dt_type)
const PrimitiveDataTypeBase * AsPrimitiveDataType() const
Definition: data_types.h:1010
std::integral_constant< ONNX_NAMESPACE::TensorProto_DataType, ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED > UndefinedTensorProtoElementTypeConstant
#define ORT_UNUSED_PARAMETER(x)
Definition: common.h:47
bool IsOpaqueType(MLDataType ml_type, const char *domain, const char *name)
PrimitiveDataTypeBase Base class for primitive Tensor contained types.
Definition: data_types.h:926
GLuint const GLchar * name
Definition: glcorearb.h:786
std::integral_constant< ONNX_NAMESPACE::TensorProto_DataType, ToTensorProtoElementType< T >()> TensorProtoElementTypeConstant
#define ORT_THROW(...)
Definition: common.h:162
bool IsDataTypeString(MLDataType dt_type)
Use the following primitives if you have a few types to switch on so you.
TypeNode(ContainerType type, int32_t prim_type) noexcept
GLuint index
Definition: glcorearb.h:786
**If you just want to fire and args
Definition: thread.h:618