HDK
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
ort_kernel_invoker.h
Go to the documentation of this file.
1 // Copyright (c) Microsoft Corporation. All rights reserved.
2 // Licensed under the MIT License.
3 
4 #pragma once
5 
6 #include <string>
7 #include <vector>
8 
9 #include "core/common/common.h"
11 #include "core/framework/tensor.h"
13 #include "core/graph/constants.h"
15 #include "core/graph/basic_types.h"
16 #include "core/graph/model.h"
17 
18 namespace onnxruntime {
19 #ifdef __GNUC__
20 #pragma GCC diagnostic push
21 #endif
22 
23 class ORTInvoker {
24  public:
25  ORTInvoker(std::shared_ptr<IExecutionProvider> execution_provider,
26  const logging::Logger& logger,
27  const IOnnxRuntimeOpSchemaRegistryList& custom_op_registries) :
28  execution_provider_(std::move(execution_provider)), logger_(logger), custom_op_registries_(custom_op_registries) {
29  if (!execution_provider_) {
30  ORT_THROW("Execution provider is nullptr");
31  }
32  }
33 
35  return *execution_provider_;
36  }
37 
38  common::Status Invoke(const std::string& op_name,
39  //optional inputs / outputs?
40  const std::vector<OrtValue>& inputs,
41  std::vector<OrtValue>& outputs,
42  const NodeAttributes* attributes,
43  const std::string& domain = kOnnxDomain,
44  const int version = -1);
45 
46  private:
47  std::shared_ptr<IExecutionProvider> execution_provider_;
48  const logging::Logger& logger_;
49  // custom ops for current execution provider
50  // we need the op schema to resolve the output type during invoke
51  const IOnnxRuntimeOpSchemaRegistryList& custom_op_registries_;
52 };
53 
54 #ifdef __GNUC__
55 #pragma GCC diagnostic pop
56 #endif
57 } // namespace onnxruntime
GLsizei const GLchar *const * string
Definition: glcorearb.h:814
ORTInvoker(std::shared_ptr< IExecutionProvider > execution_provider, const logging::Logger &logger, const IOnnxRuntimeOpSchemaRegistryList &custom_op_registries)
std::unordered_map< std::string, ONNX_NAMESPACE::AttributeProto > NodeAttributes
Definition: basic_types.h:42
common::Status Invoke(const std::string &op_name, const std::vector< OrtValue > &inputs, std::vector< OrtValue > &outputs, const NodeAttributes *attributes, const std::string &domain=kOnnxDomain, const int version=-1)
constexpr const char * kOnnxDomain
Definition: constants.h:12
GT_API const UT_StringHolder version
#define ORT_THROW(...)
Definition: common.h:163
IExecutionProvider & GetCurrentExecutionProvider()