HDK
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
op_node_proto_helper.h
Go to the documentation of this file.
1 // Copyright (c) Microsoft Corporation. All rights reserved.
2 // Licensed under the MIT License.
3 
4 #pragma once
5 
6 #ifndef SHARED_PROVIDER
7 #include "core/common/status.h"
10 #include "core/common/gsl.h"
11 #endif
12 
// ORT_HAVE_ATTRIBUTE(x): portable probe for use inside #if / #elif.
// Expands to __has_attribute(x) where the compiler provides it, and to the
// constant 0 otherwise, so callers never need to test for __has_attribute.
#if !defined(__has_attribute)
#define ORT_HAVE_ATTRIBUTE(x) 0
#else
#define ORT_HAVE_ATTRIBUTE(x) __has_attribute(x)
#endif
18 
19 #if ORT_HAVE_ATTRIBUTE(nodiscard)
20 #define MUST_USE_RESULT [[nodiscard]]
21 #elif defined(__clang__) && ORT_HAVE_ATTRIBUTE(warn_unused_result)
22 #define MUST_USE_RESULT __attribute__((warn_unused_result))
23 #else
24 #define MUST_USE_RESULT
25 #endif
26 
// Forward declaration only.
// NOTE(review): IMLOpKernel is not referenced anywhere in the visible part of
// this header; confirm it is still needed here.
27 class IMLOpKernel;
28 
29 namespace onnxruntime {
30 
31 /**
32  A set of wrappers with common signatures for use with both OpKernelInfo
33  (which uses this class as its base class) and InferenceContext. Used by ABI
34  kernels for both shape/type inference and kernel construction.
35 */
36 template <class Impl_t>
38  public:
39  explicit OpNodeProtoHelper(const Impl_t* impl) : impl_(impl) {}
40 
41  /**
42  Get a single attribute
43  Call this function for a required attribute or when a default value for an optional attribute is specified in the op schema
44  */
45  template <typename T>
47 
48  /**
49  Get a single attribute
50  Call this function only when a default value for an optional attribute isn't specified in the op schema
51  */
52  template <typename T>
53  T GetAttrOrDefault(const std::string& name, const T& default_value) const {
54  T tmp;
55  return GetAttr<T>(name, &tmp).IsOK() ? tmp : default_value;
56  }
57 
58  /**
59  Get a single attribute
60  Call this function only when a default value for an optional attribute isn't specified in the op schema
61  */
62  template <typename T>
63  void GetAttrOrDefault(const std::string& name, T* value, const T& default_value) const {
64  if (!GetAttr<T>(name, value).IsOK())
65  *value = default_value;
66  }
67 
68  /**
69  Get repeated attributes
70  Call this function only when a default value for an optional attribute isn't specified in the op schema
71  */
72  template <typename T>
73  MUST_USE_RESULT std::vector<T> GetAttrsOrDefault(const std::string& name, const std::vector<T>& default_value = std::vector<T>{}) const {
74  std::vector<T> tmp;
75  return GetAttrs<T>(name, tmp).IsOK() ? tmp : default_value;
76  }
77 
78  /// <summary>
79  /// Return a gsl::span that points to an array of primitive types held by AttributeProto
80  /// This function allows to avoid copying big attributes locally into a kernel and operate on
81  /// AttributeProto data directly.
82  ///
83  /// Does not apply to strings, Tensors and Sparse Tensors that require special treatment.
84  /// </summary>
85  /// <typeparam name="T">Primitive type contained in the array</typeparam>
86  /// <param name="name">Attribute name</param>
87  /// <param name="values">Attribute data in a span, out parameter</param>
88  /// <returns>Status</returns>
89  template <typename T>
90  MUST_USE_RESULT Status GetAttrsAsSpan(const std::string& name, gsl::span<const T>& values) const;
91 
93 
96  return GetAttrs(name, tmp).IsOK() ? tmp : default_value;
97  }
98 
99  /**
100  Get repeated attributes
101  */
102  template <typename T>
103  MUST_USE_RESULT Status GetAttrs(const std::string& name, std::vector<T>& values) const;
104 
105  template <typename T>
106  MUST_USE_RESULT Status GetAttrs(const std::string& name, gsl::span<T> values) const;
107 
109  std::vector<std::reference_wrapper<const std::string>>& refs) const;
110 
111  uint32_t GetPrimitiveAttrElementCount(ONNX_NAMESPACE::AttributeProto_AttributeType type,
112  const std::string& name) const noexcept;
113 
114  bool HasPrimitiveAttribute(ONNX_NAMESPACE::AttributeProto_AttributeType type,
115  const std::string& name) const noexcept;
116 
117  uint32_t GetInputCount() const {
118  return gsl::narrow_cast<uint32_t>(impl_->getNumInputs());
119  }
120 
121  uint32_t GetOutputCount() const {
122  return gsl::narrow_cast<uint32_t>(impl_->getNumOutputs());
123  }
124 
125  const ONNX_NAMESPACE::TypeProto* GetInputType(size_t index) const {
126  return impl_->getInputType(index);
127  }
128 
129  const ONNX_NAMESPACE::TypeProto* GetOutputType(size_t index) const {
130  // Work around lack of a const method from the onnx InferenceContext interface
131  return const_cast<Impl_t*>(impl_)->getOutputType(index);
132  }
133 
134  // Try to query an attribute, returning nullptr if it doesn't exist
135  const ONNX_NAMESPACE::AttributeProto* TryGetAttribute(const std::string& name) const {
136  return impl_->getAttribute(name);
137  }
138 
139  const ONNX_NAMESPACE::AttributeProto* GetAttribute(const std::string& name) const {
140  const ONNX_NAMESPACE::AttributeProto* attr = TryGetAttribute(name);
141  ORT_ENFORCE(attr != nullptr);
142  return attr;
143  }
144 
145  private:
146  OpNodeProtoHelper() = delete;
147  const Impl_t* impl_ = nullptr;
148 };
149 
150 // The methods on the following class are called by OpNodeProtoHelper, implementing
151 // the same signatures as InferenceContext other than const-ness.
153  public:
154  explicit ProtoHelperNodeContext(const onnxruntime::Node& node) : node_(node) {}
155  ProtoHelperNodeContext() = delete;
156 
157  const ONNX_NAMESPACE::AttributeProto* getAttribute(const std::string& name) const;
158  size_t getNumInputs() const;
159  const ONNX_NAMESPACE::TypeProto* getInputType(size_t index) const;
160  size_t getNumOutputs() const;
161  const ONNX_NAMESPACE::TypeProto* getOutputType(size_t index) const;
162 
163  private:
164  const onnxruntime::Node& node_;
165 };
166 
167 } // namespace onnxruntime
MUST_USE_RESULT Status GetAttrs(const std::string &name, TensorShapeVector &out) const
void GetAttrOrDefault(const std::string &name, T *value, const T &default_value) const
ProtoHelperNodeContext(const onnxruntime::Node &node)
T GetAttrOrDefault(const std::string &name, const T &default_value) const
uint32_t GetPrimitiveAttrElementCount(ONNX_NAMESPACE::AttributeProto_AttributeType type, const std::string &name) const noexcept
GLsizei const GLchar *const * string
Definition: glcorearb.h:814
const ONNX_NAMESPACE::AttributeProto * TryGetAttribute(const std::string &name) const
#define ORT_ENFORCE(condition,...)
Definition: common.h:173
bool HasPrimitiveAttribute(ONNX_NAMESPACE::AttributeProto_AttributeType type, const std::string &name) const noexcept
MUST_USE_RESULT TensorShapeVector GetAttrsOrDefault(const std::string &name, const TensorShapeVector &default_value=TensorShapeVector{}) const
MUST_USE_RESULT Status GetAttrsStringRefs(const std::string &name, std::vector< std::reference_wrapper< const std::string >> &refs) const
const ONNX_NAMESPACE::TypeProto * getInputType(size_t index) const
MUST_USE_RESULT Status GetAttr(const std::string &name, T *value) const
absl::InlinedVector< int64_t, kTensorShapeSmallBufferElementsSize > TensorShapeVector
Definition: tensor_shape.h:46
const ONNX_NAMESPACE::TypeProto * GetInputType(size_t index) const
GLuint const GLchar * name
Definition: glcorearb.h:786
MUST_USE_RESULT Status GetAttrsAsSpan(const std::string &name, gsl::span< const T > &values) const
Return a gsl::span that points to an array of primitive types held by AttributeProto This function al...
MUST_USE_RESULT std::vector< T > GetAttrsOrDefault(const std::string &name, const std::vector< T > &default_value=std::vector< T >{}) const
const ONNX_NAMESPACE::AttributeProto * GetAttribute(const std::string &name) const
GLenum GLsizei GLsizei GLint * values
Definition: glcorearb.h:1602
const ONNX_NAMESPACE::TypeProto * GetOutputType(size_t index) const
const ONNX_NAMESPACE::TypeProto * getOutputType(size_t index) const
GLuint index
Definition: glcorearb.h:786
const ONNX_NAMESPACE::AttributeProto * getAttribute(const std::string &name) const
Definition: core.h:1131
#define MUST_USE_RESULT
type
Definition: core.h:1059