HDK
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
kernel_registry.h
Go to the documentation of this file.
1 // Copyright (c) Microsoft Corporation. All rights reserved.
2 // Licensed under the MIT License.
3 
4 #pragma once
5 
6 #include <string_view>
7 
9 
10 namespace onnxruntime {
11 
// Map from the string key produced by KernelRegistry::GetMapKey() to the kernel creation info.
// It is a multimap because several kernels can share one key: the key omits the opset version and the type
// constraints (e.g. Clip<float> and Clip<int> both register under the same key).
12 using KernelCreateMap = std::multimap<std::string, KernelCreateInfo>;
// NOTE(review): not referenced in the visible part of this header; presumably pairs a kernel key with its
// kernel-def hash value — confirm against its users.
13 using KernelDefHashes = std::vector<std::pair<std::string, HashValue>>;
14 
// Forward declaration only: the resolver is used solely by const reference in this header.
15 class IKernelTypeStrResolver;
16 
17 /**
18  * Each provider has a KernelRegistry. Often, the KernelRegistry only belongs to that specific provider.
19  */
21  public:
// Creates a registry with an empty kernel map; kernels are added afterwards via Register().
22  KernelRegistry() = default;
23 
24  // Register a kernel with kernel definition and function to create the kernel.
// `kernel_def_builder` describes the kernel (op, domain, provider, type constraints); `kernel_creator` is the
// factory invoked later to instantiate it. NOTE(review): presumably returns a non-OK Status when registration
// fails (e.g. a conflicting definition) — confirm in the .cpp.
25  Status Register(KernelDefBuilder& kernel_def_builder, const KernelCreateFn& kernel_creator);
26 
// Registers an already-assembled KernelCreateInfo. It is taken by rvalue reference, so callers hand it over
// with std::move.
27  Status Register(KernelCreateInfo&& create_info);
28 
29  // TODO(edgchen1) for TryFindKernel(), consider using `out` != nullptr as indicator of whether kernel was found and
30  // Status as an indication of failure
31 
32  // Check if an execution provider can create kernel for a node and return the kernel if so
// As the TODO above notes, "no matching kernel" is currently reported through a non-OK Status. On success `*out`
// receives the matching KernelCreateInfo. NOTE(review): that pointer presumably refers to storage owned by this
// registry, so it should not outlive it — confirm in the .cpp.
33  Status TryFindKernel(const Node& node, ProviderType exec_provider,
34  const IKernelTypeStrResolver& kernel_type_str_resolver,
35  const KernelCreateInfo** out) const;
36 
37  static bool HasImplementationOf(const KernelRegistry& r, const Node& node,
38  ProviderType exec_provider,
39  const IKernelTypeStrResolver& kernel_type_str_resolver) {
40  const KernelCreateInfo* info;
41  Status st = r.TryFindKernel(node, exec_provider, kernel_type_str_resolver, &info);
42  return st.IsOK();
43  }
44 
45 #if !defined(ORT_MINIMAL_BUILD)
46  // Find KernelCreateInfo in instant mode
// Looks a kernel up by op name, domain and opset version plus explicit type constraints, rather than from a graph
// Node. Not compiled in ORT_MINIMAL_BUILD builds.
// NOTE(review): `type_constraints` is presumably keyed by type-constraint name (e.g. "T") — confirm in the .cpp.
// NOTE(review): `version` is taken as `const int&`; plain `int` would be simpler, but the out-of-view definition
// must change in lockstep.
47  Status TryFindKernel(const std::string& op_name, const std::string& domain, const int& version,
48  const std::unordered_map<std::string, MLDataType>& type_constraints,
49  ProviderType exec_provider, const KernelCreateInfo** out) const;
50 #endif // !defined(ORT_MINIMAL_BUILD)
51 
52  bool IsEmpty() const { return kernel_creator_fn_map_.empty(); }
53 
#ifdef onnxruntime_PYBIND_EXPORT_OPSCHEMA
// Read-only view of every registered kernel. It exists for the opkernel documentation generator, which lists
// the operators registered for a provider, and is only compiled when onnxruntime_PYBIND_EXPORT_OPSCHEMA is set.
const KernelCreateMap& GetKernelCreateMap() const { return kernel_creator_fn_map_; }
#endif  // onnxruntime_PYBIND_EXPORT_OPSCHEMA
60 
61  private:
62  // Check whether the types of inputs/outputs of the given node match the extra
63  // type-constraints of the given kernel. This serves two purposes: first, to
64  // select the right kernel implementation based on the types of the arguments
65  // when we have multiple kernels, e.g., Clip<float> and Clip<int>; second, to
66  // accommodate (and check) mapping of ONNX (specification) type to the onnxruntime
67  // implementation type (e.g., if we want to implement ONNX's float16 as a regular
68  // float in onnxruntime). (The second, however, requires a globally uniform mapping.)
69  //
70  // Note that this is not intended for type-checking the node against the ONNX
71  // type specification of the corresponding op, which is done before this check.
//
// Returns true when `node` satisfies the type constraints of `kernel_def`.
//   kernel_type_str_resolver - resolves the kernel's type strings for `node`.
//   error_str                - NOTE(review): presumably receives a description of the mismatch when false is
//                              returned — confirm in the .cpp.
72  static bool VerifyKernelDef(const Node& node,
73  const KernelDef& kernel_def,
74  const IKernelTypeStrResolver& kernel_type_str_resolver,
75  std::string& error_str);
76 
77  static std::string GetMapKey(std::string_view op_name, std::string_view domain, std::string_view provider) {
78  std::string key(op_name);
79  // use the kOnnxDomainAlias of 'ai.onnx' instead of kOnnxDomain's empty string
80  key.append(1, ' ').append(domain.empty() ? kOnnxDomainAlias : domain).append(1, ' ').append(provider);
81  return key;
82  }
83 
84  static std::string GetMapKey(const KernelDef& kernel_def) {
85  return GetMapKey(kernel_def.OpName(), kernel_def.Domain(), kernel_def.Provider());
86  }
87  // Kernel create function map from op name to kernel creation info.
88  // key is "<op_name> <domain> <provider>" (single-space separated) as built by GetMapKey(); an empty
// domain is stored as kOnnxDomainAlias. Several entries may share a key, hence the multimap.
89  KernelCreateMap kernel_creator_fn_map_;
90 };
91 } // namespace onnxruntime
const std::string & ProviderType
Definition: basic_types.h:35
constexpr const char * kOnnxDomainAlias
Definition: constants.h:14
std::multimap< std::string, KernelCreateInfo > KernelCreateMap
Definition: Node.h:52
GLsizei const GLchar *const * string
Definition: glcorearb.h:814
static bool HasImplementationOf(const KernelRegistry &r, const Node &node, ProviderType exec_provider, const IKernelTypeStrResolver &kernel_type_str_resolver)
basic_string_view< char > string_view
Definition: core.h:522
Status TryFindKernel(const Node &node, ProviderType exec_provider, const IKernelTypeStrResolver &kernel_type_str_resolver, const KernelCreateInfo **out) const
std::vector< std::pair< std::string, HashValue >> KernelDefHashes
Status Register(KernelDefBuilder &kernel_def_builder, const KernelCreateFn &kernel_creator)
std::function< Status(FuncManager &func_mgr, const OpKernelInfo &info, std::unique_ptr< OpKernel > &out)> KernelCreateFn
Definition: op_kernel.h:134
GT_API const UT_StringHolder version
GT_API const UT_StringHolder st
GLboolean r
Definition: glcorearb.h:1222