HDK
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
ort_kernel_invoker.h
Go to the documentation of this file.
1 // Copyright (c) Microsoft Corporation. All rights reserved.
2 // Licensed under the MIT License.
3 
4 #pragma once
5 
6 #include <string>
7 #include <vector>
8 
9 #include "core/common/common.h"
11 #include "core/framework/tensor.h"
13 #include "core/graph/constants.h"
15 #include "core/graph/basic_types.h"
16 #include "core/graph/model.h"
17 
18 namespace onnxruntime {
19 #ifdef __GNUC__
20 #pragma GCC diagnostic push
21 #endif
22 
23 class ORTInvoker {
24  public:
25  ORTInvoker(std::shared_ptr<IExecutionProvider> execution_provider,
26  const logging::Logger& logger,
27  const IOnnxRuntimeOpSchemaRegistryList& custom_op_registries)
28  : execution_provider_(std::move(execution_provider)),
29  logger_(logger),
30  custom_op_registries_(custom_op_registries) {
31  if (!execution_provider_) {
32  ORT_THROW("Execution provider is nullptr");
33  }
34  }
35 
37  return *execution_provider_;
38  }
39 
40  common::Status Invoke(const std::string& op_name,
41  // optional inputs / outputs?
42  const std::vector<OrtValue>& inputs,
43  std::vector<OrtValue>& outputs,
44  const NodeAttributes* attributes,
45  const std::string& domain = kOnnxDomain,
46  const int version = -1);
47 
48  private:
49  std::shared_ptr<IExecutionProvider> execution_provider_;
50  const logging::Logger& logger_;
51  // custom ops for current execution provider
52  // we need the op schema to resolve the output type during invoke
53  const IOnnxRuntimeOpSchemaRegistryList& custom_op_registries_;
54 };
55 
56 #ifdef __GNUC__
57 #pragma GCC diagnostic pop
58 #endif
59 } // namespace onnxruntime
ORTInvoker(std::shared_ptr< IExecutionProvider > execution_provider, const logging::Logger &logger, const IOnnxRuntimeOpSchemaRegistryList &custom_op_registries)
std::unordered_map< std::string, ONNX_NAMESPACE::AttributeProto > NodeAttributes
Definition: basic_types.h:44
common::Status Invoke(const std::string &op_name, const std::vector< OrtValue > &inputs, std::vector< OrtValue > &outputs, const NodeAttributes *attributes, const std::string &domain=kOnnxDomain, const int version=-1)
constexpr const char * kOnnxDomain
Definition: constants.h:14
GT_API const UT_StringHolder version
#define ORT_THROW(...)
Definition: common.h:162
IExecutionProvider & GetCurrentExecutionProvider()