HDK
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
op_kernel.h
Go to the documentation of this file.
1 // Copyright (c) Microsoft Corporation. All rights reserved.
2 // Licensed under the MIT License.
3 
4 #pragma once
5 
6 #include "boost/mp11.hpp"
7 
8 // It is safe to include the below header even if SHARED_PROVIDER macro is enabled
9 // as it doesn't include any pb headers.
10 #include "core/framework/prepacked_weights_container.h"
11 
12 #ifndef SHARED_PROVIDER
13 #include <functional>
14 #include "core/common/exceptions.h"
16 #include "core/common/status.h"
22 #include "core/framework/tensor.h"
24 #include "core/graph/constants.h"
26 #if !defined(ORT_MINIMAL_BUILD)
27 #include "onnx/defs/schema.h"
28 #else
29 #include "onnx/defs/data_type_utils.h"
30 #endif
31 #include "onnx/onnx_pb.h"
32 #include "onnx/onnx-operators_pb.h"
33 #include "core/common/gsl.h"
34 namespace onnxruntime {
35 class OpKernelContext;
36 }
37 #endif
38 
39 namespace onnxruntime {
40 
41 std::unique_ptr<OpKernelInfo> CopyOpKernelInfo(const OpKernelInfo& info);
42 
43 class OpKernel {  // Abstract base class for operator kernels; subclasses must implement Compute().
44  public:
45  using DoneCallback = std::function<void()>;
46 
47  explicit OpKernel(const OpKernelInfo& info) : op_kernel_info_(CopyOpKernelInfo(info)) {}  // stores the result of CopyOpKernelInfo(info)
48  virtual ~OpKernel() = default;
49 
50  const onnxruntime::Node& Node() const;
51  const onnxruntime::KernelDef& KernelDef() const;
52 
53  [[nodiscard]] virtual Status Compute(_Inout_ OpKernelContext* context) const = 0;
54 
55  [[nodiscard]] virtual bool IsAsync() const {
56  // By default every kernel is synchronous; the base implementation always returns false.
57  return false;
58  }
59 
60  [[nodiscard]] virtual Status ComputeAsync(_Inout_ OpKernelContext*, DoneCallback) const {
61  ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented");  // NOTE(review): no return follows, so the macro presumably throws - confirm
62  }
63 
64  // Override this function to PrePack an initialized constant tensor into whatever format the kernel needs.
65  // For example, a MatMul kernel can pack the input B if it is constant, like the code below.
66  // Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
67  // /*out*/ bool& is_packed,
68  // /*out*/ PrePackedWeights* prepacked_weight_for_caching) override {
69  // is_packed = false;
70  // if (input_idx == 1) {
71  // is_packed = true;
72  // this->Pack(tensor, this->buffer_, alloc);
73  // if (prepacked_weight_for_caching) {
74  // // LOGIC TO CACHE `this->buffer_` SINCE THE KERNEL DOESN'T OWN THE PACKED WEIGHT
75  // }
76  // }
77  // return Status::OK();
78  // }
79  // Please refer to MatMulIntegerToFloatBase for a complete example.
80  // @param tensor: The initialized constant tensor
81  // @param input_idx: The input index of the tensor in this kernel
82  // @param alloc: The kernel's PrePack() method MUST use this allocator for allocating the pre-packed
83  // weights' buffers. The alloc that the PrePack() method will receive will be either
84  // the allocator tied to the session if the kernel owns the pre-packed buffer or an
85  // allocator shared between sessions if the pre-packed buffer is to be shared across sessions
86  // (i.e.) the kernel does not own the buffer.
87  // @param is_packed: Set it to true if the kernel packed the tensor, or to false otherwise.
88  // The kernel is responsible for keeping the packed data and related metadata if is_packed is true,
89  // and the original initialized constant tensor will be released and not accessible anymore in
90  // the Compute function.
91  // @param prepacked_weights: A PrePackedWeights instance will be provided to the kernel IF the pre-packed weights
92  // are meant to be stored in a shared container.
93 
94  virtual Status
95  PrePack(const Tensor& /*tensor*/, int /*input_idx*/, AllocatorPtr /*alloc*/,
96  /*out*/ bool& is_packed, /*out*/ PrePackedWeights* /*prepacked_weights*/) {
97  is_packed = false;  // default implementation packs nothing
98  return Status::OK();
99  }
100 
101  // Override this function to use a provided pre-packed weight.
102  // Status UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers,
103  // int input_idx,
104  // /*out*/ bool& used_shared_buffers) {
105  // used_shared_buffers = true;
106  // this->buffer_ = std::move(prepacked_buffers[0]);
107  // return Status::OK();
108  // }
109  // Please refer to MatMulIntegerToFloatBase for a complete example.
110  // @param prepacked_buffers: The pre-packed buffers to be used by this kernel for the provided input index.
111  // (Sometimes a single constant initializer may have multiple pre-packed buffers associated
112  // with it. It is up to the kernel developer to store them in any order of their choice in PrePack(),
113  // but the same order must be used for retrieval in UseSharedPrePackedBuffers().)
114  // @param input_idx: The input index of the tensor in this kernel
115  // @param used_shared_buffers: Boolean flag set by the kernel implementation indicating
116  // that the provided weight has been used by the kernel.
117  virtual Status UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& /*prepacked_buffers*/,
118  int /*input_idx*/,
119  /*out*/ bool& used_shared_buffers) {
120  used_shared_buffers = false;  // default implementation does not use the shared buffers
121  return Status::OK();
122  }
123 
124  const OrtMemoryInfo& Allocator(int id, OrtMemType mem_type) const;
125  const OpKernelInfo& Info() const {
126  return *op_kernel_info_;
127  }
128 
129  private:
130  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(OpKernel);
131  std::unique_ptr<OpKernelInfo> op_kernel_info_;  // set once in the constructor from CopyOpKernelInfo(info)
132 };
// Kernel factory signatures. Both take a FuncManager and the OpKernelInfo, write the new kernel to `out`
// and return a Status. KernelCreateFn is a std::function; KernelCreatePtrFn is a plain function pointer
// with the same signature.
133 class FuncManager;
134 using KernelCreateFn = std::function<Status(FuncManager& func_mgr, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out)>;
135 using KernelCreatePtrFn = std::add_pointer<Status(FuncManager& func_mgr, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out)>::type;
136 
138  std::unique_ptr<KernelDef> kernel_def; // Owned and stored in the global kernel registry.
141 
// Takes ownership of `definition` and stores the creation callback.
// Both parameters are by-value sinks, so both are moved into the members.
142  KernelCreateInfo(std::unique_ptr<KernelDef> definition,
143  KernelCreateFn create_func)
144  : kernel_def(std::move(definition)),
145  kernel_create_func(std::move(create_func)) {}  // move instead of copying the std::function
146 
148  : kernel_def(std::move(other.kernel_def)),
149  kernel_create_func(std::move(other.kernel_create_func)) {}
150 
151  KernelCreateInfo() = default;
152 };
153 
154 // Forward declarations for the non-specialized BuildKernelCreateInfo method.
155 template <typename T>
156 KernelCreateInfo BuildKernelCreateInfo();
157 
158 namespace ml {
159 template <typename T>
161 } // namespace ml
162 
163 namespace contrib {
164 template <typename T>
166 } // namespace contrib
167 
168 namespace contrib {
169 namespace cuda {
170 template <typename T>
172 } // namespace cuda
173 } // namespace contrib
174 
175 namespace contrib {
176 namespace rocm {
177 template <typename T>
179 } // namespace rocm
180 } // namespace contrib
181 
182 namespace contrib {
183 namespace snpe {
184 template <typename T>
186 } // namespace snpe
187 } // namespace contrib
188 
190 
191 // Naming convention for operator kernel classes: <provider>_<name>_<domain>_ver<ver>, built by token pasting (##).
192 #define ONNX_OPERATOR_KERNEL_CLASS_NAME(provider, domain, ver, name) \
193  provider##_##name##_##domain##_ver##ver
194 
// Shorthands that forward to ONNX_OPERATOR_KERNEL_EX with kCpuExecutionProvider and the
// kOnnxDomain / kMLDomain / kMSDomain domain respectively.
195 #define ONNX_CPU_OPERATOR_KERNEL(name, ver, builder, ...) \
196  ONNX_OPERATOR_KERNEL_EX(name, kOnnxDomain, ver, kCpuExecutionProvider, builder, __VA_ARGS__)
197 
198 #define ONNX_CPU_OPERATOR_ML_KERNEL(name, ver, builder, ...) \
199  ONNX_OPERATOR_KERNEL_EX(name, kMLDomain, ver, kCpuExecutionProvider, builder, __VA_ARGS__)
200 
201 #define ONNX_CPU_OPERATOR_MS_KERNEL(name, ver, builder, ...) \
202  ONNX_OPERATOR_KERNEL_EX(name, kMSDomain, ver, kCpuExecutionProvider, builder, __VA_ARGS__)
203 
// Forward-declares the class named by ONNX_OPERATOR_KERNEL_CLASS_NAME and defines the
// BuildKernelCreateInfo<> specialization for it. The specialization returns a KernelCreateInfo made of
// `builder` (after SetName/SetDomain/SinceVersion/Provider/Build) and a capture-less lambda, cast to
// KernelCreatePtrFn, that assigns std::make_unique<__VA_ARGS__>(info) to `out` and returns Status::OK().
// __VA_ARGS__ is the kernel type to instantiate.
// NOTE(review): presumably variadic so that template types containing commas can be passed - confirm.
204 #define ONNX_OPERATOR_KERNEL_EX(name, domain, ver, provider, builder, ...) \
205  class ONNX_OPERATOR_KERNEL_CLASS_NAME(provider, domain, ver, name); \
206  template <> \
207  KernelCreateInfo \
208  BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(provider, domain, ver, name)>() { \
209  return KernelCreateInfo( \
210  builder.SetName(#name) \
211  .SetDomain(domain) \
212  .SinceVersion(ver) \
213  .Provider(provider) \
214  .Build(), \
215  static_cast<KernelCreatePtrFn>( \
216  [](FuncManager&, \
217  const OpKernelInfo& info, \
218  std::unique_ptr<OpKernel>& out) -> Status { \
219  out = std::make_unique<__VA_ARGS__>(info); \
220  return Status::OK(); \
221  })); \
222  }
223 
// Versioned kernels: class name is <provider>_<name>_<domain>_ver<startver>_<endver>.
224 #define ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(provider, domain, startver, endver, name) \
225  provider##_##name##_##domain##_ver##startver##_##endver
226 
// Shorthands that forward to ONNX_OPERATOR_VERSIONED_KERNEL_EX with kCpuExecutionProvider and the
// kOnnxDomain / kMLDomain domain respectively.
227 #define ONNX_CPU_OPERATOR_VERSIONED_KERNEL(name, startver, endver, builder, ...) \
228  ONNX_OPERATOR_VERSIONED_KERNEL_EX(name, kOnnxDomain, startver, endver, kCpuExecutionProvider, builder, __VA_ARGS__)
229 
230 #define ONNX_CPU_OPERATOR_VERSIONED_ML_KERNEL(name, startver, endver, builder, ...) \
231  ONNX_OPERATOR_VERSIONED_KERNEL_EX(name, kMLDomain, startver, endver, kCpuExecutionProvider, builder, __VA_ARGS__)
232 
// Forward-declares the versioned kernel class and defines its BuildKernelCreateInfo<> specialization.
// Both startver and endver are passed to builder.SinceVersion(); the creation lambda assigns
// std::make_unique<__VA_ARGS__>(info) to `out` and returns Status::OK().
233 #define ONNX_OPERATOR_VERSIONED_KERNEL_EX(name, domain, startver, endver, provider, builder, ...) \
234  class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(provider, domain, startver, endver, name); \
235  template <> \
236  KernelCreateInfo \
237  BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(provider, domain, startver, endver, name)>() { \
238  return KernelCreateInfo( \
239  builder.SetName(#name) \
240  .SetDomain(domain) \
241  .SinceVersion(startver, endver) \
242  .Provider(provider) \
243  .Build(), \
244  static_cast<KernelCreatePtrFn>([](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status { out = std::make_unique<__VA_ARGS__>(info); return Status::OK(); })); \
245  }
246 
// Typed kernels: class name is <provider>_<name>_<domain>_ver<ver>_<type>.
247 #define ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type, name) \
248  provider##_##name##_##domain##_ver##ver##_##type
249 
// Shorthands that forward to ONNX_OPERATOR_TYPED_KERNEL_EX with kCpuExecutionProvider and the
// kOnnxDomain / kMLDomain / kMSDomain domain respectively.
250 #define ONNX_CPU_OPERATOR_TYPED_KERNEL(name, ver, type, builder, ...) \
251  ONNX_OPERATOR_TYPED_KERNEL_EX(name, kOnnxDomain, ver, type, kCpuExecutionProvider, builder, __VA_ARGS__)
252 
253 #define ONNX_CPU_OPERATOR_TYPED_ML_KERNEL(name, ver, type, builder, ...) \
254  ONNX_OPERATOR_TYPED_KERNEL_EX(name, kMLDomain, ver, type, kCpuExecutionProvider, builder, __VA_ARGS__)
255 
256 #define ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(name, ver, type, builder, ...) \
257  ONNX_OPERATOR_TYPED_KERNEL_EX(name, kMSDomain, ver, type, kCpuExecutionProvider, builder, __VA_ARGS__)
258 
// Forward-declares the typed kernel class and defines its BuildKernelCreateInfo<> specialization.
// Within this macro `type` only contributes to the generated class name; it is not passed to `builder`.
// NOTE(review): any type constraint presumably has to be set on `builder` by the caller - confirm.
259 #define ONNX_OPERATOR_TYPED_KERNEL_EX(name, domain, ver, type, provider, builder, ...) \
260  class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type, name); \
261  template <> \
262  KernelCreateInfo \
263  BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type, name)>() { \
264  return KernelCreateInfo( \
265  builder.SetName(#name) \
266  .SetDomain(domain) \
267  .SinceVersion(ver) \
268  .Provider(provider) \
269  .Build(), \
270  static_cast<KernelCreatePtrFn>([](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status { out = std::make_unique<__VA_ARGS__>(info); return Status::OK(); })); \
271  }
272 
// Two-typed kernels: class name is <provider>_<name>_<domain>_ver<ver>_<type1>_<type2>.
273 #define ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type1, type2, name) \
274  provider##_##name##_##domain##_ver##ver##_##type1##_##type2
275 
// Forward-declares the two-typed kernel class and defines its BuildKernelCreateInfo<> specialization.
// Within this macro type1/type2 only contribute to the generated class name; they are not passed to `builder`.
276 #define ONNX_OPERATOR_TWO_TYPED_KERNEL_EX(name, domain, ver, type1, type2, provider, builder, ...) \
277  class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type1, type2, name); \
278  template <> \
279  KernelCreateInfo \
280  BuildKernelCreateInfo<ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type1, type2, name)>() { \
281  return KernelCreateInfo( \
282  builder.SetName(#name) \
283  .SetDomain(domain) \
284  .SinceVersion(ver) \
285  .Provider(provider) \
286  .Build(), \
287  static_cast<KernelCreatePtrFn>([](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status { out = std::make_unique<__VA_ARGS__>(info); return Status::OK(); })); \
288  }
289 
// Versioned typed kernels: class name is <provider>_<name>_<domain>_ver<startver>_<endver>_<type>.
290 #define ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, type, name) \
291  provider##_##name##_##domain##_ver##startver##_##endver##_##type
292 
// Shorthands that forward to ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX with kCpuExecutionProvider and the
// kOnnxDomain / kMLDomain / kMSDomain domain respectively.
293 #define ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(name, startver, endver, type, builder, ...) \
294  ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(name, kOnnxDomain, startver, endver, type, kCpuExecutionProvider, builder, \
295  __VA_ARGS__)
296 
297 #define ONNX_CPU_OPERATOR_VERSIONED_TYPED_ML_KERNEL(name, startver, endver, type, builder, ...) \
298  ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(name, kMLDomain, startver, endver, type, kCpuExecutionProvider, builder, \
299  __VA_ARGS__)
300 
301 #define ONNX_CPU_OPERATOR_VERSIONED_TYPED_MS_KERNEL(name, startver, endver, type, builder, ...) \
302  ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(name, kMSDomain, startver, endver, type, kCpuExecutionProvider, builder, \
303  __VA_ARGS__)
304 
// Forward-declares the versioned typed kernel class and defines its BuildKernelCreateInfo<> specialization.
// Both startver and endver are passed to builder.SinceVersion(); `type` only contributes to the class name.
305 #define ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(name, domain, startver, endver, type, provider, builder, ...) \
306  class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, type, name); \
307  template <> \
308  KernelCreateInfo \
309  BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, \
310  type, name)>() { \
311  return KernelCreateInfo( \
312  builder.SetName(#name) \
313  .SetDomain(domain) \
314  .SinceVersion(startver, endver) \
315  .Provider(provider) \
316  .Build(), \
317  static_cast<KernelCreatePtrFn>([](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status { out = std::make_unique<__VA_ARGS__>(info); return Status::OK(); })); \
318  }
319 
// Versioned two-typed kernels: class name is <provider>_<name>_<domain>_ver<startver>_<endver>_<type1>_<type2>.
320 #define ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, type1, type2, name) \
321  provider##_##name##_##domain##_ver##startver##_##endver##_##type1##_##type2
322 
// Forward-declares the versioned two-typed kernel class and defines its BuildKernelCreateInfo<> specialization.
// Both startver and endver are passed to builder.SinceVersion(); type1/type2 only contribute to the class name.
323 #define ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_EX(name, domain, startver, endver, type1, type2, \
324  provider, builder, ...) \
325  class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, type1, type2, name); \
326  template <> \
327  KernelCreateInfo \
328  BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, \
329  type1, type2, name)>() { \
330  return KernelCreateInfo( \
331  builder.SetName(#name) \
332  .SetDomain(domain) \
333  .SinceVersion(startver, endver) \
334  .Provider(provider) \
335  .Build(), \
336  static_cast<KernelCreatePtrFn>([](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status { out = std::make_unique<__VA_ARGS__>(info); return Status::OK(); })); \
337  }
338 
339 template <typename... Types>
341  std::vector<MLDataType> operator()() const {
342  return {DataTypeImpl::GetTensorType<Types>()...};
343  }
344 };
345 
346 #if !defined(DISABLE_SPARSE_TENSORS)
347 template <typename... Types>
349  std::vector<MLDataType> operator()() const {
350  return {DataTypeImpl::GetSparseTensorType<Types>()...};
351  }
352 };
353 #endif
354 
355 // Use within macro definitions to create a custom vector of constraints.
356 // Example: #define REG_KERNEL(OP, VERSION, KERNEL_CLASS, Type, ...)
357 // .TypeConstraint("T", BuildKernelDefConstraints<Type, __VA_ARGS__>())
// Returns the result of invoking a default-constructed BuildKernelDefConstraintsImpl<Types...>.
358 template <typename... Types>
359 inline std::vector<MLDataType> BuildKernelDefConstraints() {
360  return BuildKernelDefConstraintsImpl<Types...>{}();
361 }
362 
363 #if !defined(DISABLE_SPARSE_TENSORS)
// Sparse-tensor variant: returns the result of invoking a default-constructed
// BuildKernelDefSparseConstraintsImpl<Types...>. Compiled only when DISABLE_SPARSE_TENSORS is not defined.
364 template <typename... Types>
365 inline std::vector<MLDataType> BuildKernelDefSparseConstraints() {
366  return BuildKernelDefSparseConstraintsImpl<Types...>{}();
367 }
368 #endif
369 
370 // Version of BuildKernelDefConstraints() that takes a type list L instead of a parameter pack.
// boost::mp11::mp_apply instantiates BuildKernelDefConstraintsImpl with the elements of L; the
// resulting object is default-constructed and invoked.
371 template <typename L>
372 inline std::vector<MLDataType> BuildKernelDefConstraintsFromTypeList() {
373  return boost::mp11::mp_apply<BuildKernelDefConstraintsImpl, L>{}();
374 }
375 
376 #if !defined(DISABLE_SPARSE_TENSORS)
// Sparse-tensor variant that takes a type list L: boost::mp11::mp_apply instantiates
// BuildKernelDefSparseConstraintsImpl with the elements of L, and the resulting object is
// default-constructed and invoked. Compiled only when DISABLE_SPARSE_TENSORS is not defined.
377 template <typename L>
378 inline std::vector<MLDataType> BuildKernelDefSparseConstraintsFromTypeList() {
379  return boost::mp11::mp_apply<BuildKernelDefSparseConstraintsImpl, L>{}();
380 }
381 #endif
382 
383 } // namespace onnxruntime
384 
385 #ifndef SHARED_PROVIDER
387 #endif
KernelCreateInfo BuildKernelCreateInfo()
KernelCreateInfo BuildKernelCreateInfo()
std::unique_ptr< KernelDef > kernel_def
Definition: op_kernel.h:138
KernelCreateInfo(KernelCreateInfo &&other) noexcept
Definition: op_kernel.h:147
KernelCreateInfo BuildKernelCreateInfo()
std::function< void()> DoneCallback
Definition: op_kernel.h:45
const OrtMemoryInfo & Allocator(int id, OrtMemType mem_type) const
virtual Status UseSharedPrePackedBuffers(std::vector< BufferUniquePtr > &, int, bool &used_shared_buffers)
Definition: op_kernel.h:117
KernelCreateInfo BuildKernelCreateInfo()
std::vector< MLDataType > BuildKernelDefConstraintsFromTypeList()
Definition: op_kernel.h:372
const onnxruntime::KernelDef & KernelDef() const
OpKernel(const OpKernelInfo &info)
Definition: op_kernel.h:47
virtual Status ComputeAsync(_Inout_ OpKernelContext *, DoneCallback) const
Definition: op_kernel.h:60
KernelCreateInfo BuildKernelCreateInfo()
std::unique_ptr< OpKernelInfo > CopyOpKernelInfo(const OpKernelInfo &info)
std::vector< MLDataType > BuildKernelDefConstraints()
Definition: op_kernel.h:359
#define _Inout_
virtual Status PrePack(const Tensor &, int, AllocatorPtr, bool &is_packed, PrePackedWeights *)
Definition: op_kernel.h:95
std::vector< MLDataType > operator()() const
Definition: op_kernel.h:341
KernelCreateInfo BuildKernelCreateInfo()
virtual ~OpKernel()=default
std::function< Status(FuncManager &func_mgr, const OpKernelInfo &info, std::unique_ptr< OpKernel > &out)> KernelCreateFn
Definition: op_kernel.h:134
std::vector< MLDataType > BuildKernelDefSparseConstraintsFromTypeList()
Definition: op_kernel.h:378
KernelCreateInfo(*)( BuildKernelCreateInfoFn)
Definition: op_kernel.h:189
std::shared_ptr< IAllocator > AllocatorPtr
Definition: allocator.h:190
KernelCreateInfo(std::unique_ptr< KernelDef > definition, KernelCreateFn create_func)
Definition: op_kernel.h:142
KernelCreateFn kernel_create_func
Definition: op_kernel.h:139
const onnxruntime::Node & Node() const
virtual Status Compute(_Inout_ OpKernelContext *context) const =0
std::vector< MLDataType > BuildKernelDefSparseConstraints()
Definition: op_kernel.h:365
std::add_pointer< Status(FuncManager &func_mgr, const OpKernelInfo &info, std::unique_ptr< OpKernel > &out)>::type KernelCreatePtrFn
Definition: op_kernel.h:135
std::vector< MLDataType > operator()() const
Definition: op_kernel.h:349
type
Definition: core.h:1059
virtual bool IsAsync() const
Definition: op_kernel.h:55
OrtMemType
Memory types for allocated memory, execution provider specific types should be extended in each provi...
const OpKernelInfo & Info() const
Definition: op_kernel.h:125