HDK
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
execution_provider.h
Go to the documentation of this file.
1 // Copyright (c) Microsoft Corporation. All rights reserved.
2 // Licensed under the MIT License.
3 
4 #pragma once
5 
6 #ifndef SHARED_PROVIDER
7 #include <memory>
8 #include <unordered_map>
9 #include <unordered_set>
10 
12 #include "core/common/status.h"
13 #include "core/framework/data_transfer.h"
14 #include "core/framework/external_data_loader.h"
15 #include "core/framework/tensor.h"
16 
17 namespace onnxruntime {
18 class GraphViewer;
19 struct ComputeCapability;
20 class KernelRegistry;
21 struct KernelCreateInfo;
22 class Node;
23 } // namespace onnxruntime
24 #else
25 #include <memory>
26 #endif
27 
30 #include "core/framework/allocator_utils.h"
35 #include "core/framework/tuning_context.h"
36 
37 struct OrtRunOptions;
38 
39 namespace onnxruntime {
40 
41 /**
42  Logical device representation.
43 */
44 
45 // if we are export the fused function to dll, the function will still in the same binary as onnxruntime
46 // use std function to give execution provider some chance to capture some state.
47 using CreateFunctionStateFunc = std::function<int(ComputeContext*, FunctionState*)>;
48 using ComputeFunc = std::function<Status(FunctionState, const OrtApi*, OrtKernelContext*)>;
49 using DestroyFunctionStateFunc = std::function<void(FunctionState)>;
50 
55 };
56 
58 
// Tensor data layouts an execution provider may prefer.
enum class DataLayout {
  NCHW,   // channels-first; the default ONNX standard layout
  NHWC,   // channels-last
  NCHWC,  // blocked channels-first variant
};
64 
66  protected:
67  IExecutionProvider(const std::string& type)
68  : IExecutionProvider(type, OrtDevice()) {}
69 
70  IExecutionProvider(const std::string& type, OrtDevice device)
71  : default_device_(device), type_{type} {
72  }
73 
74  /*
75  default device for this ExecutionProvider
76  */
78 
79  public:
80  virtual ~IExecutionProvider() = default;
81 
82  /**
83  * Returns a data transfer object that implements methods to copy to and
84  * from this device.
85  * If no copy is required for the successful operation of this provider,
86  * return a nullptr.
87  */
88  virtual std::unique_ptr<onnxruntime::IDataTransfer> GetDataTransfer() const {
89  return nullptr;
90  }
91 
92  /**
93  * Returns an external data loader object that implements methods to load data from external sources.
94  *
95  * By default, framework will handle external data loading by loading the data into CPU memory and then copying
96  * it to the target device if required. So in most cases, it's not necessary to override this method. Specifically,
97  * in WebAssembly build, because the memory is limited and Web platform supports loading data from external sources
98  * directly into GPU memory, this method is overridden to provide a custom external data loader to avoid the extra
99  * CPU memory usage.
100  */
101  virtual std::unique_ptr<onnxruntime::IExternalDataLoader> GetExternalDataLoader() const {
102  return nullptr;
103  }
104 
105  /**
106  * Interface for performing kernel lookup within kernel registries.
107  * Abstracts away lower-level details about kernel registries and kernel matching.
108  */
110  public:
111  /**
112  * Given `node`, try to find a matching kernel for this EP.
113  * The return value is non-null if and only if a matching kernel was found.
114  */
115  virtual const KernelCreateInfo* LookUpKernel(const Node& node) const = 0;
116 
117  protected:
118  ~IKernelLookup() = default;
119  };
120 
121  /**
122  Get execution provider's capability for the specified <graph>.
123  Return a bunch of IndexedSubGraphs <*this> execution provider can run if
124  the sub-graph contains only one node or can fuse to run if the sub-graph
125  contains more than one node. The node indexes contained in sub-graphs may
126  have overlap, and it's ONNXRuntime's responsibility to do the partition
127  and decide whether a node will be assigned to <*this> execution provider.
128  For kernels registered in a kernel registry, `kernel_lookup` must be used
129  to find a matching kernel for this EP.
130  */
131  virtual std::vector<std::unique_ptr<ComputeCapability>>
132  GetCapability(const onnxruntime::GraphViewer& graph_viewer,
133  const IKernelLookup& kernel_lookup) const;
134 
135  /**
136  Get kernel registry per execution provider type.
137  The KernelRegistry share pointer returned is shared across sessions.
138 
139  NOTE: this approach was taken to achieve the following goals,
140  1. The execution provider type based kernel registry should be shared
141  across sessions.
142  Only one copy of this kind of kernel registry exists in ONNXRuntime
143  with multiple sessions/models.
144  2. Adding an execution provider into ONNXRuntime does not need to touch ONNXRuntime
145  framework/session code.
146  3. onnxruntime (framework/session) does not depend on any specific
147  execution provider lib.
148  */
149  virtual std::shared_ptr<KernelRegistry> GetKernelRegistry() const { return nullptr; }
150 
151  /**
152  Get the device id of current execution provider
153  */
154  virtual int GetDeviceId() const { return 0; };
155 
156  /**
157  Get execution provider's configuration options.
158  */
159  virtual ProviderOptions GetProviderOptions() const { return {}; }
160 
161  /**
162  Get provider specific custom op domain list.
163  Provider has the responsibility to release OrtCustomOpDomain instances it creates.
164 
165  NOTE: In the case of ONNX model having EP specific custom nodes and don't want to ask user to register those nodes,
166  EP might need to a way to register those custom nodes. This API is added for the purpose where EP can use it to
167  leverage ORT custom op to register those custom nodes with one or more custom op domains.
168 
169  For example, TensorRT EP uses this API to support TRT plugins where each custom op is mapped to TRT plugin and no
170  kernel implementation is needed for custom op since the real implementation is inside TRT. This custom op acts as
171  a role to help pass ONNX model validation.
172  */
173  virtual void GetCustomOpDomainList(std::vector<OrtCustomOpDomain*>& /*provider custom op domain list*/) const {};
174 
175  /**
176  Returns an opaque handle whose exact type varies based on the provider
177  and is interpreted accordingly by the corresponding kernel implementation.
178  For Direct3D operator kernels, this may return an IUnknown supporting
179  QueryInterface to ID3D12GraphicsCommandList1.
180  */
181  virtual const void* GetExecutionHandle() const noexcept {
182  return nullptr;
183  }
184 
185  /**
186  @return type of the execution provider; should match that set in the node
187  through the SetExecutionProvider API. Example valid return values are:
188  kCpuExecutionProvider, kCudaExecutionProvider
189  */
190  const std::string& Type() const { return type_; }
191 
192  /**
193  Blocks until the device has completed all preceding requested tasks.
194  Currently this is primarily used by the IOBinding object to ensure that all
195  inputs have been copied to the device before execution begins.
196  */
197  virtual common::Status Sync() const { return Status::OK(); }
198 
199  /**
200  Called when InferenceSession::Run started
201  NOTE that due to async execution in provider, the actual work of previous
202  Run may not be finished on device This function should be regarded as the
203  point after which a new Run would start to submit commands from CPU
204  */
205  virtual common::Status OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { return Status::OK(); }
206 
207  /**
208  Called when InferenceSession::Run ended
209  NOTE that due to async execution in provider, the actual work of this Run
210  may not be finished on device This function should be regarded as the point
211  that all commands of current Run has been submmited by CPU
212  */
213  virtual common::Status OnRunEnd(bool /*sync_stream*/, const onnxruntime::RunOptions& /*run_options*/) {
214  return Status::OK();
215  }
216 
217  /**
218  Called when InferenceSession::SetEpDynamicOptions is called
219  */
220  virtual common::Status SetEpDynamicOptions(gsl::span<const char* const> /*keys*/,
221  gsl::span<const char* const> /*values*/) {
222  return Status::OK();
223  }
224 
225  /**
226  Indicate whether the graph capturing mode (e.g., cuda graph) is enabled for
227  the provider.
228  */
229  virtual bool IsGraphCaptureEnabled() const { return false; }
230 
231  /**
232  Indicate whether the graph has been captured and instantiated.
233  */
234  virtual bool IsGraphCaptured(int /*graph_annotation_id*/) const { return false; }
235 
236  /**
237  Run the instantiated graph.
238  */
239  virtual common::Status ReplayGraph(int /*graph_annotation_id*/) {
240  return Status::OK();
241  }
242 
243  /**
244  Called when session creation is complete
245  This provides an opportunity for execution providers to optionally synchronize and
246  clean up its temporary resources to reduce memory and ensure the first run is fast.
247  */
249 
251  const std::reference_wrapper<onnxruntime::Node> fused_node;
252  // GraphViewer that filters the full graph to the nodes that are covered by 'node'
253  const std::reference_wrapper<GraphViewer> filtered_graph;
254  };
255 
256  // Fusion approach that is supported
257  // !!! The "Function" FusionStyle is deprecated.
258  // !!! If your EP is using this fusion style, please migrate it to "FilteredGraphViewer" style.
259  enum class FusionStyle {
260  // The node fusion will create an onnxruntime::Function based Node that contains a completely new Graph instance
261  // in the Node body. The original nodes and initializers are copied to the new Graph instance in Function::Body().
262  // A GraphProto can be produced from the Node body.
263  Function,
264 
265  // The node fusion will create a new Node that defines the inputs and outputs using the IndexedSubGraph
266  // that GetCapability returned. The Node will not be onnxruntime::Function based so will have no Body().
267  // Instead a GraphViewer that filters the full Graph to the fused Nodes will be created.
268  // This is significantly cheaper as it doesn't incur the cost of creating a new Graph instance,
269  // and can be supported in a minimal build.
271  };
272 
273  virtual FusionStyle GetFusionStyle() const {
// All ORT built-in EPs have migrated to the FilteredGraphViewer style.
// Newer EPs should avoid the Function style, as it is deprecated.
277  }
278 
279 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
280  /**
281  Given a collection of fused Nodes and the respective GraphViewer instance for the nodes that were fused,
282  return create_state/compute/release_state func for each node.
283  @remarks This is now the default interface when execution provider wants to compile nodes
284  for both minimal build and complete ort build.
285 
286  Do NOT cache the GraphViewer in FusedNodeAndGraph.filtered_graph in any of the NodeComputeInfo functions
287  as it is only valid for the duration of the call to Compile.
288  */
289  virtual common::Status Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
290  std::vector<NodeComputeInfo>& node_compute_funcs);
291 
292 #endif
293 
294  void SetLogger(const logging::Logger* logger) {
295  logger_ = logger;
296  }
297 
298  const logging::Logger* GetLogger() const {
299  return logger_;
300  }
301 
302  virtual std::unique_ptr<profiling::EpProfiler> GetProfiler() {
303  return {};
304  }
305 
306  virtual DataLayout GetPreferredLayout() const {
307  // NCHW is the default ONNX standard data layout. So default to it.
308  // EPs which prefer a different layout should override to return their preferred layout.
309  return DataLayout::NCHW;
310  }
311 
312  virtual void RegisterStreamHandlers(IStreamCommandHandleRegistry& /*stream_handle_registry*/, AllocatorMap&) const {}
313 
314  /** Does the EP support concurrent calls to InferenceSession::Run to execute the model.
315  */
316  virtual bool ConcurrentRunSupported() const { return true; }
317 
318  /**
319  * Return the tuning context which holds all TunableOp state.
320  */
321  virtual ITuningContext* GetTuningContext() const {
322  return nullptr;
323  }
324 
325  /**
326  * Return the appropriate OrtDevice object given OrtMemType.
327  */
328  virtual OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const {
329  if (mem_type == OrtMemTypeCPUInput || mem_type == OrtMemTypeCPUOutput) {
330  return OrtDevice(); // default return CPU device.
331  }
332  return default_device_;
333  };
334 
335  /**
336  * Create Preferred allocators for the current Execution Provider
337  * This function is a stateless function which creates new instances of Allocator, without storing them in EP.
338  */
339  virtual std::vector<AllocatorPtr> CreatePreferredAllocators() { return std::vector<AllocatorPtr>(); };
340 
341  /**
342  * Get the array of pointers for EPContext nodes
343  * EP needs to implement this if has the requirement to generate the context cache model. Otherwise leave it.
344  * Default return an empty vector if not provided by the Execution Provider
345  */
348  }
349 
350  private:
351  const std::string type_;
352 
353  // It will be set when this object is registered to a session
354  const logging::Logger* logger_ = nullptr;
355 };
356 } // namespace onnxruntime
virtual FusionStyle GetFusionStyle() const
virtual common::Status Sync() const
virtual common::Status OnRunStart(const onnxruntime::RunOptions &)
virtual std::vector< std::unique_ptr< ComputeCapability > > GetCapability(const onnxruntime::GraphViewer &graph_viewer, const IKernelLookup &kernel_lookup) const
virtual ~IExecutionProvider()=default
Definition: Node.h:52
virtual const KernelCreateInfo * LookUpKernel(const Node &node) const =0
virtual bool IsGraphCaptured(int) const
std::function< int(ComputeContext *, FunctionState *)> CreateFunctionStateFunc
virtual DataLayout GetPreferredLayout() const
virtual common::Status Compile(const std::vector< FusedNodeAndGraph > &fused_nodes_and_graphs, std::vector< NodeComputeInfo > &node_compute_funcs)
virtual std::shared_ptr< KernelRegistry > GetKernelRegistry() const
virtual OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const
virtual common::Status SetEpDynamicOptions(gsl::span< const char *const >, gsl::span< const char *const >)
const std::string & Type() const
virtual std::unique_ptr< profiling::EpProfiler > GetProfiler()
virtual std::unique_ptr< onnxruntime::IExternalDataLoader > GetExternalDataLoader() const
GLint GLint GLsizei GLint GLenum GLenum type
Definition: glcorearb.h:108
const std::reference_wrapper< GraphViewer > filtered_graph
const logging::Logger * GetLogger() const
DestroyFunctionStateFunc release_state_func
virtual void RegisterStreamHandlers(IStreamCommandHandleRegistry &, AllocatorMap &) const
IExecutionProvider(const std::string &type, OrtDevice device)
absl::InlinedVector< T, N, Allocator > InlinedVector
virtual const void * GetExecutionHandle() const noexcept
std::map< OrtDevice, AllocatorPtr > AllocatorMap
Definition: allocator.h:264
virtual std::unique_ptr< onnxruntime::IDataTransfer > GetDataTransfer() const
virtual ProviderOptions GetProviderOptions() const
std::function< void(FunctionState)> DestroyFunctionStateFunc
virtual bool IsGraphCaptureEnabled() const
std::unordered_map< std::string, std::string > ProviderOptions
IExecutionProvider(const std::string &type)
CreateFunctionStateFunc create_state_func
virtual ITuningContext * GetTuningContext() const
std::function< Status(FunctionState, const OrtApi *, OrtKernelContext *)> ComputeFunc
virtual common::Status ReplayGraph(int)
void SetLogger(const logging::Logger *logger)
virtual const InlinedVector< const Node * > GetEpContextNodes() const
const std::reference_wrapper< onnxruntime::Node > fused_node
virtual common::Status OnRunEnd(bool, const onnxruntime::RunOptions &)
virtual void GetCustomOpDomainList(std::vector< OrtCustomOpDomain * > &) const
virtual bool ConcurrentRunSupported() const
virtual std::vector< AllocatorPtr > CreatePreferredAllocators()
virtual common::Status OnSessionInitializationEnd()