op_kernel_context.h
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

namespace onnxruntime {
class IExecutionFrame;
class Stream;
namespace concurrency {
class ThreadPool;
}

class OpKernelContext {
 public:
  using ArgMap = std::unordered_map<std::string, size_t>;

  OpKernelContext(_Inout_ IExecutionFrame* frame, _In_ const OpKernel* kernel, _In_ Stream* stream,
                  _In_opt_ concurrency::ThreadPool* threadpool, _In_ const logging::Logger& logger);

  virtual ~OpKernelContext() = default;

  /**
  Return the number of inputs for a variadic argument.
  @param arg_num The operator argument number.
  @returns Number of inputs the argument has.
  */
  virtual int NumVariadicInputs(size_t arg_num) const;
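
  // Illustrative usage sketch (not part of the original header): reading every
  // input bound to the first, possibly variadic, operator argument. Assumes
  // those inputs are tensors and that "ctx" is the OpKernelContext* passed to
  // a kernel's Compute().
  //
  //   int n = ctx->NumVariadicInputs(0);          // inputs bound to argument 0
  //   for (int i = 0; i < n; ++i) {
  //     const Tensor* t = ctx->Input<Tensor>(i);  // flat input index
  //   }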

  virtual MLDataType InputType(int index) const;
  virtual MLDataType OutputType(int index) const;

  const OrtValue* GetInputOrtValue(int index) const {
    return GetInputMLValue(index);
  }

  template <typename T>
  const T* Input(int index) const {
    const OrtValue* p_ml_value = GetInputMLValue(index);
    ORT_TRY {
      return p_ml_value ? &(p_ml_value->Get<T>()) : nullptr;
    }
    ORT_CATCH(const std::exception& /*e*/) {
      ORT_THROW("Missing Input: " + kernel_->Node().InputDefs()[index]->Name());
    }
  }

  // Fetch a required input, enforcing that it is present.
  template <typename T>
  const T& RequiredInput(int index) const {
    const T* input_ptr = Input<T>(index);
    ORT_ENFORCE(input_ptr, "Required input at index ", index, " is not present.");
    return *input_ptr;
  }
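
  // Illustrative usage sketch (not part of the original header): "MyKernel" is
  // a hypothetical kernel whose Compute() distinguishes required from optional
  // inputs. Input<T>() returns nullptr for an absent optional input, while
  // RequiredInput<T>() throws via ORT_ENFORCE.
  //
  //   Status MyKernel::Compute(OpKernelContext* ctx) const {
  //     const Tensor& data = ctx->RequiredInput<Tensor>(0);  // must be present
  //     const Tensor* bias = ctx->Input<Tensor>(1);  // nullptr if absent
  //     return Status::OK();
  //   }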

  // Fetch the (non-tensor) output at the specified index.
  template <typename T>
  T* Output(int index) {
    if (index < 0 || index >= OutputCount())
      return nullptr;

    OrtValue* p_ml_value = GetOrCreateOutputMLValue(index);
    return p_ml_value ? p_ml_value->GetMutable<T>() : nullptr;
  }

  // If memory has not yet been allocated for an output tensor, it is allocated
  // on the fly using the given shape.
  // Returns nullptr if the output is an unused optional output.
  Tensor* Output(int index, const TensorShape& shape);
  Tensor* Output(int index, const std::vector<int64_t>& shape);
  Tensor* Output(int index, const std::initializer_list<int64_t>& shape);

  // Fetch a required tensor output, enforcing that it is present.
  Tensor& RequiredOutput(int index, const TensorShape& shape) {
    Tensor* output_ptr = Output(index, shape);
    ORT_ENFORCE(output_ptr, "Required output at index ", index, " is not present.");
    return *output_ptr;
  }
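
  // Illustrative usage sketch (not part of the original header): producing an
  // output shaped like the first input. The shape-taking Output() overloads
  // allocate the tensor on the fly if memory planning has not already done so.
  //
  //   const Tensor& X = ctx->RequiredInput<Tensor>(0);
  //   Tensor& Y = ctx->RequiredOutput(0, X.Shape());
  //   float* y_data = Y.MutableData<float>();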
77 
78 #if !defined(DISABLE_SPARSE_TENSORS)
79  // Fetch a sparse-tensor output corresponding to the specified index.
80  // shape must specify the shape of the underlying dense-tensor.
81  // Memory allocation for the output may happen when this method is invoked,
82  // unless static optimization pre-allocates it.
83  SparseTensor* OutputSparse(int index, const TensorShape& shape);
84 #endif

#if !defined(DISABLE_OPTIONAL_TYPE)
  // Use this API to output a "None" of a specific type (e.g. Tensor) at the specified index
  template <typename T>
  void OutputOptionalWithoutData(int index) {
    auto* output_ort_value = GetOutputMLValue(index);

    auto type = DataTypeImpl::GetType<T>();

    output_ort_value->Init(nullptr,  // This OrtValue is "None" and has no data
                           type,
                           type->GetDeleteFunc());
  }
#endif
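
  // Illustrative usage sketch (not part of the original header): a kernel with
  // an optional Tensor output marks it "None" when it chooses not to produce
  // it ("emit_extra" is a hypothetical flag).
  //
  //   if (!emit_extra)
  //     ctx->OutputOptionalWithoutData<Tensor>(1);  // output 1 becomes "None"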

  // Retrieve the indexed input's shape as obtained from memory planning before
  // actual computation. Returns false if the shape cannot be inferred.
  virtual bool TryGetInferredInputShape(int index, TensorShape& shape) const;

  // Retrieve the indexed output's shape as obtained from memory planning before
  // actual computation. Returns false if the shape cannot be inferred.
  virtual bool TryGetInferredOutputShape(int index, TensorShape& shape) const;
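
  // Illustrative usage sketch (not part of the original header): checking
  // whether memory planning already knows an output's shape.
  //
  //   TensorShape planned;
  //   if (ctx->TryGetInferredOutputShape(0, planned)) {
  //     Tensor* Y = ctx->Output(0, planned);  // allocate using the planned shape
  //   }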

  const logging::Logger& Logger() const {
    return *logger_;
  }

  // always >= 0
  virtual int InputCount() const {
    return static_cast<int>(kernel_->Node().InputDefs().size());
  }

  // always >= 0
  virtual int ImplicitInputCount() const {
    return static_cast<int>(kernel_->Node().ImplicitInputDefs().size());
  }

  // always >= 0
  virtual int OutputCount() const {
    return static_cast<int>(kernel_->Node().OutputDefs().size());
  }

  /**
  Return an allocator on device 0, with memtype of OrtMemTypeDefault.
  @remarks Use SafeInt when calculating the size of memory to allocate using AllocatorPtr->Alloc.
  */
  [[nodiscard]] virtual Status GetTempSpaceAllocator(AllocatorPtr* output) const;

  /**
  Return the allocator associated with the CPU EP, with memtype of OrtMemTypeDefault.
  @remarks Use SafeInt when calculating the size of memory to allocate using AllocatorPtr->Alloc.
  */
  [[nodiscard]] Status GetTempSpaceCPUAllocator(AllocatorPtr* output) const;
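
  // Illustrative usage sketch (not part of the original header): allocating
  // temporary scratch space inside Compute(). Per the remark above, SafeInt
  // guards the byte-count arithmetic; "element_count" is a hypothetical size.
  //
  //   AllocatorPtr alloc;
  //   ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&alloc));
  //   void* scratch = alloc->Alloc(SafeInt<size_t>(element_count) * sizeof(float));
  //   // ... use scratch ...
  //   alloc->Free(scratch);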

  /**
  Return the device id that the current kernel runs on.
  */
  virtual int GetDeviceId() const {
    return kernel_->Info().GetExecutionProvider()->GetDeviceId();
  }

  /**
  Return the compute stream associated with the EP that the kernel is partitioned to.
  For EPs that do not have a compute stream (e.g. the CPU EP), nullptr is returned.
  */
  [[nodiscard]] virtual Stream* GetComputeStream() const {
    return stream_;
  }
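
  // Illustrative usage sketch (not part of the original header): a GPU kernel
  // retrieving the EP's native stream. Assumes onnxruntime::Stream exposes
  // GetHandle(); on the CUDA EP the returned handle would wrap a cudaStream_t.
  //
  //   Stream* stream = ctx->GetComputeStream();
  //   if (stream != nullptr) {
  //     void* raw_handle = stream->GetHandle();  // EP-specific stream handle
  //   }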

  /**
  Returns the opset domain of the underlying kernel.
  **/
  const std::string& GetOpDomain() const;

  /**
  Returns the op type of the underlying kernel.
  **/
  const std::string& GetOpType() const;

  /**
  Returns the node name of the underlying kernel.
  **/
  const std::string& GetNodeName() const;

  /**
  Returns the intra-op threadpool, if available.
  */
  _Ret_maybenull_ onnxruntime::concurrency::ThreadPool* GetOperatorThreadPool() const { return threadpool_; }

  /**
  Returns whether deterministic computation is preferred.
  */
  virtual bool GetUseDeterministicCompute() const {
    return true;
  }

 protected:
  OpKernelContext(concurrency::ThreadPool* threadpool, const logging::Logger& logger, Stream* stream);

  onnxruntime::NodeIndex GetNodeIndex() const;

  virtual const OrtValue* GetInputMLValue(int index) const;
  const OrtValue* GetImplicitInputMLValue(int index) const;
  OrtValue* GetOutputMLValue(int index);

#ifdef ENABLE_ATEN
  Status SetOutputMLValue(int index, const OrtValue& ort_value);
#endif

  // Creates the OrtValue* based on the shape, if it does not already exist.
  virtual OrtValue* OutputMLValue(int index, const TensorShape& shape);

  virtual OrtValue* GetOrCreateOutputMLValue(int index);

 private:
  ORT_DISALLOW_COPY_AND_ASSIGNMENT(OpKernelContext);
  int GetInputArgIndex(int index) const;
  int GetImplicitInputArgIndex(int index) const;
  int GetOutputArgIndex(int index) const;

  IExecutionFrame* const execution_frame_{};
  const OpKernel* const kernel_{};
  concurrency::ThreadPool* const threadpool_{};
  const logging::Logger* const logger_{};

  // The argument starting index in ExecutionFrame.
  int node_input_start_index_{-1};
  int node_implicit_input_start_index_{-1};
  int node_output_start_index_{-1};

  Stream* stream_;
};

// Fetching an output tensor without a shape is not allowed unless it already exists.
template <>
inline Tensor* OpKernelContext::Output<Tensor>(int index) {
  OrtValue* p_ml_value = GetOutputMLValue(index);
  ORT_ENFORCE(p_ml_value, "Please fetch output tensor with specified shape.");
  return p_ml_value->GetMutable<Tensor>();
}

#if !defined(DISABLE_SPARSE_TENSORS)
template <>
inline SparseTensor* OpKernelContext::Output<SparseTensor>(int index) {
  OrtValue* p_ml_value = GetOutputMLValue(index);
  ORT_ENFORCE(p_ml_value, "Please fetch output sparse tensor with specified shape.");
  return p_ml_value->GetMutable<SparseTensor>();
}
#endif

}  // namespace onnxruntime