6 #include "boost/mp11.hpp"
10 #include "core/framework/prepacked_weights_container.h"
12 #ifndef SHARED_PROVIDER
26 #if !defined(ORT_MINIMAL_BUILD)
27 #include "onnx/defs/schema.h"
29 #include "onnx/defs/data_type_utils.h"
31 #include "onnx/onnx_pb.h"
32 #include "onnx/onnx-operators_pb.h"
34 namespace onnxruntime {
35 class OpKernelContext;
39 namespace onnxruntime {
55 [[nodiscard]]
virtual bool IsAsync()
const {
96 bool& is_packed, PrePackedWeights* ) {
119 bool& used_shared_buffers) {
120 used_shared_buffers =
false;
126 return *op_kernel_info_;
130 ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(
OpKernel);
131 std::unique_ptr<OpKernelInfo> op_kernel_info_;
134 using KernelCreateFn = std::function<Status(FuncManager& func_mgr, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out)>;
135 using KernelCreatePtrFn = std::add_pointer<Status(FuncManager& func_mgr, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out)>::
type;
155 template <
typename T>
159 template <
typename T>
164 template <
typename T>
170 template <
typename T>
177 template <
typename T>
184 template <
typename T>
// Expands, by token pasting, to the identifier of a non-versioned kernel class:
//   <provider>_<name>_<domain>_ver<ver>
// e.g. (kCpuExecutionProvider, kOnnxDomain, 7, Foo) -> kCpuExecutionProvider_Foo_kOnnxDomain_ver7.
// Note the pasted order puts `name` before `domain`, unlike the parameter order.
// ONNX_OPERATOR_KERNEL_EX uses this identifier to forward-declare the kernel
// class and to specialize BuildKernelCreateInfo<...>() for it.
192 #define ONNX_OPERATOR_KERNEL_CLASS_NAME(provider, domain, ver, name) \
193 provider##_##name##_##domain##_ver##ver
// Shorthand for ONNX_OPERATOR_KERNEL_EX with domain kOnnxDomain and provider
// kCpuExecutionProvider. The variadic tail names the kernel class; the _EX
// macro constructs it with std::make_unique<__VA_ARGS__>(info).
195 #define ONNX_CPU_OPERATOR_KERNEL(name, ver, builder, ...) \
196 ONNX_OPERATOR_KERNEL_EX(name, kOnnxDomain, ver, kCpuExecutionProvider, builder, __VA_ARGS__)
// Same as ONNX_CPU_OPERATOR_KERNEL, but for the kMLDomain domain.
198 #define ONNX_CPU_OPERATOR_ML_KERNEL(name, ver, builder, ...) \
199 ONNX_OPERATOR_KERNEL_EX(name, kMLDomain, ver, kCpuExecutionProvider, builder, __VA_ARGS__)
// Same as ONNX_CPU_OPERATOR_KERNEL, but for the kMSDomain domain.
201 #define ONNX_CPU_OPERATOR_MS_KERNEL(name, ver, builder, ...) \
202 ONNX_OPERATOR_KERNEL_EX(name, kMSDomain, ver, kCpuExecutionProvider, builder, __VA_ARGS__)
204 #define ONNX_OPERATOR_KERNEL_EX(name, domain, ver, provider, builder, ...) \
205 class ONNX_OPERATOR_KERNEL_CLASS_NAME(provider, domain, ver, name); \
208 BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(provider, domain, ver, name)>() { \
209 return KernelCreateInfo( \
210 builder.SetName(#name) \
213 .Provider(provider) \
215 static_cast<KernelCreatePtrFn>( \
217 const OpKernelInfo& info, \
218 std::unique_ptr<OpKernel>& out) -> Status { \
219 out = std::make_unique<__VA_ARGS__>(info); \
220 return Status::OK(); \
// Expands, by token pasting, to the identifier of a kernel class that covers
// a version range:
//   <provider>_<name>_<domain>_ver<startver>_<endver>
// Including both bounds keeps each versioned registration's class name distinct
// from the non-versioned form produced by ONNX_OPERATOR_KERNEL_CLASS_NAME.
224 #define ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(provider, domain, startver, endver, name) \
225 provider##_##name##_##domain##_ver##startver##_##endver
// Shorthand for ONNX_OPERATOR_VERSIONED_KERNEL_EX with domain kOnnxDomain and
// provider kCpuExecutionProvider. The _EX macro passes startver/endver to the
// builder's .SinceVersion(startver, endver). The variadic tail names the
// kernel class.
227 #define ONNX_CPU_OPERATOR_VERSIONED_KERNEL(name, startver, endver, builder, ...) \
228 ONNX_OPERATOR_VERSIONED_KERNEL_EX(name, kOnnxDomain, startver, endver, kCpuExecutionProvider, builder, __VA_ARGS__)
// Same as ONNX_CPU_OPERATOR_VERSIONED_KERNEL, but for the kMLDomain domain.
230 #define ONNX_CPU_OPERATOR_VERSIONED_ML_KERNEL(name, startver, endver, builder, ...) \
231 ONNX_OPERATOR_VERSIONED_KERNEL_EX(name, kMLDomain, startver, endver, kCpuExecutionProvider, builder, __VA_ARGS__)
233 #define ONNX_OPERATOR_VERSIONED_KERNEL_EX(name, domain, startver, endver, provider, builder, ...) \
234 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(provider, domain, startver, endver, name); \
237 BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(provider, domain, startver, endver, name)>() { \
238 return KernelCreateInfo( \
239 builder.SetName(#name) \
241 .SinceVersion(startver, endver) \
242 .Provider(provider) \
244 static_cast<KernelCreatePtrFn>([](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status { out = std::make_unique<__VA_ARGS__>(info); return Status::OK(); })); \
// Expands, by token pasting, to the identifier of a kernel class specialized
// for one type:
//   <provider>_<name>_<domain>_ver<ver>_<type>
// `type` must therefore be a single token that is valid inside an identifier.
247 #define ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type, name) \
248 provider##_##name##_##domain##_ver##ver##_##type
// Shorthand for ONNX_OPERATOR_TYPED_KERNEL_EX with domain kOnnxDomain and
// provider kCpuExecutionProvider. `type` becomes part of the generated class
// name (ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME). The variadic tail names the
// kernel class.
250 #define ONNX_CPU_OPERATOR_TYPED_KERNEL(name, ver, type, builder, ...) \
251 ONNX_OPERATOR_TYPED_KERNEL_EX(name, kOnnxDomain, ver, type, kCpuExecutionProvider, builder, __VA_ARGS__)
// Same as ONNX_CPU_OPERATOR_TYPED_KERNEL, but for the kMLDomain domain.
253 #define ONNX_CPU_OPERATOR_TYPED_ML_KERNEL(name, ver, type, builder, ...) \
254 ONNX_OPERATOR_TYPED_KERNEL_EX(name, kMLDomain, ver, type, kCpuExecutionProvider, builder, __VA_ARGS__)
// Same as ONNX_CPU_OPERATOR_TYPED_KERNEL, but for the kMSDomain domain.
256 #define ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(name, ver, type, builder, ...) \
257 ONNX_OPERATOR_TYPED_KERNEL_EX(name, kMSDomain, ver, type, kCpuExecutionProvider, builder, __VA_ARGS__)
259 #define ONNX_OPERATOR_TYPED_KERNEL_EX(name, domain, ver, type, provider, builder, ...) \
260 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type, name); \
263 BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type, name)>() { \
264 return KernelCreateInfo( \
265 builder.SetName(#name) \
268 .Provider(provider) \
270 static_cast<KernelCreatePtrFn>([](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status { out = std::make_unique<__VA_ARGS__>(info); return Status::OK(); })); \
// Expands, by token pasting, to the identifier of a kernel class specialized
// for a pair of types:
//   <provider>_<name>_<domain>_ver<ver>_<type1>_<type2>
// type1 and type2 must each be a single identifier-safe token.
273 #define ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type1, type2, name) \
274 provider##_##name##_##domain##_ver##ver##_##type1##_##type2
276 #define ONNX_OPERATOR_TWO_TYPED_KERNEL_EX(name, domain, ver, type1, type2, provider, builder, ...) \
277 class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type1, type2, name); \
280 BuildKernelCreateInfo<ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type1, type2, name)>() { \
281 return KernelCreateInfo( \
282 builder.SetName(#name) \
285 .Provider(provider) \
287 static_cast<KernelCreatePtrFn>([](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status { out = std::make_unique<__VA_ARGS__>(info); return Status::OK(); })); \
// Expands, by token pasting, to the identifier of a type-specialized kernel
// class that covers a version range:
//   <provider>_<name>_<domain>_ver<startver>_<endver>_<type>
290 #define ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, type, name) \
291 provider##_##name##_##domain##_ver##startver##_##endver##_##type
293 #define ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(name, startver, endver, type, builder, ...) \
294 ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(name, kOnnxDomain, startver, endver, type, kCpuExecutionProvider, builder, \
297 #define ONNX_CPU_OPERATOR_VERSIONED_TYPED_ML_KERNEL(name, startver, endver, type, builder, ...) \
298 ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(name, kMLDomain, startver, endver, type, kCpuExecutionProvider, builder, \
301 #define ONNX_CPU_OPERATOR_VERSIONED_TYPED_MS_KERNEL(name, startver, endver, type, builder, ...) \
302 ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(name, kMSDomain, startver, endver, type, kCpuExecutionProvider, builder, \
305 #define ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(name, domain, startver, endver, type, provider, builder, ...) \
306 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, type, name); \
309 BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, \
311 return KernelCreateInfo( \
312 builder.SetName(#name) \
314 .SinceVersion(startver, endver) \
315 .Provider(provider) \
317 static_cast<KernelCreatePtrFn>([](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status { out = std::make_unique<__VA_ARGS__>(info); return Status::OK(); })); \
// Expands, by token pasting, to the identifier of a two-type-specialized
// kernel class that covers a version range:
//   <provider>_<name>_<domain>_ver<startver>_<endver>_<type1>_<type2>
320 #define ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, type1, type2, name) \
321 provider##_##name##_##domain##_ver##startver##_##endver##_##type1##_##type2
323 #define ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_EX(name, domain, startver, endver, type1, type2, \
324 provider, builder, ...) \
325 class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, type1, type2, name); \
328 BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, \
329 type1, type2, name)>() { \
330 return KernelCreateInfo( \
331 builder.SetName(#name) \
333 .SinceVersion(startver, endver) \
334 .Provider(provider) \
336 static_cast<KernelCreatePtrFn>([](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status { out = std::make_unique<__VA_ARGS__>(info); return Status::OK(); })); \
339 template <
typename... Types>
342 return {DataTypeImpl::GetTensorType<Types>()...};
346 #if !defined(DISABLE_SPARSE_TENSORS)
347 template <
typename... Types>
350 return {DataTypeImpl::GetSparseTensorType<Types>()...};
358 template <
typename... Types>
363 #if !defined(DISABLE_SPARSE_TENSORS)
364 template <
typename... Types>
371 template <
typename L>
373 return boost::mp11::mp_apply<BuildKernelDefConstraintsImpl, L>{}();
376 #if !defined(DISABLE_SPARSE_TENSORS)
377 template <
typename L>
379 return boost::mp11::mp_apply<BuildKernelDefSparseConstraintsImpl, L>{}();
385 #ifndef SHARED_PROVIDER
KernelCreateInfo BuildKernelCreateInfo()
KernelCreateInfo BuildKernelCreateInfo()
std::unique_ptr< KernelDef > kernel_def
KernelCreateInfo(KernelCreateInfo &&other) noexcept
KernelCreateInfo BuildKernelCreateInfo()
std::function< void()> DoneCallback
const OrtMemoryInfo & Allocator(int id, OrtMemType mem_type) const
virtual Status UseSharedPrePackedBuffers(std::vector< BufferUniquePtr > &, int, bool &used_shared_buffers)
KernelCreateInfo BuildKernelCreateInfo()
std::vector< MLDataType > BuildKernelDefConstraintsFromTypeList()
const onnxruntime::KernelDef & KernelDef() const
OpKernel(const OpKernelInfo &info)
virtual Status ComputeAsync(_Inout_ OpKernelContext *, DoneCallback) const
KernelCreateInfo BuildKernelCreateInfo()
std::unique_ptr< OpKernelInfo > CopyOpKernelInfo(const OpKernelInfo &info)
std::vector< MLDataType > BuildKernelDefConstraints()
virtual Status PrePack(const Tensor &, int, AllocatorPtr, bool &is_packed, PrePackedWeights *)
std::vector< MLDataType > operator()() const
KernelCreateInfo BuildKernelCreateInfo()
virtual ~OpKernel()=default
std::function< Status(FuncManager &func_mgr, const OpKernelInfo &info, std::unique_ptr< OpKernel > &out)> KernelCreateFn
std::vector< MLDataType > BuildKernelDefSparseConstraintsFromTypeList()
KernelCreateInfo(*)( BuildKernelCreateInfoFn)
std::shared_ptr< IAllocator > AllocatorPtr
KernelCreateInfo(std::unique_ptr< KernelDef > definition, KernelCreateFn create_func)
KernelCreateFn kernel_create_func
const onnxruntime::Node & Node() const
virtual Status Compute(_Inout_ OpKernelContext *context) const =0
std::vector< MLDataType > BuildKernelDefSparseConstraints()
std::add_pointer< Status(FuncManager &func_mgr, const OpKernelInfo &info, std::unique_ptr< OpKernel > &out)>::type KernelCreatePtrFn
KernelCreateInfo()=default
std::vector< MLDataType > operator()() const
virtual bool IsAsync() const
OrtMemType
Memory types for allocated memory, execution provider specific types should be extended in each provi...
const OpKernelInfo & Info() const