HDK
execution_provider.h
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#ifndef SHARED_PROVIDER
#include <memory>
#include <unordered_map>
#include <unordered_set>

#include "core/common/status.h"
#include "core/framework/data_transfer.h"
#include "core/framework/tensor.h"

namespace onnxruntime {
class GraphViewer;
struct ComputeCapability;
class KernelRegistry;
struct KernelCreateInfo;
class Node;
}  // namespace onnxruntime
#else
#include <memory>
#endif

#include "core/framework/allocatormgr.h"

namespace onnxruntime {

/**
   Logical device representation.
*/

// If we export the fused function to a DLL, the function will still be in the same binary as onnxruntime.
// std::function is used to give the execution provider a chance to capture state.
using CreateFunctionStateFunc = std::function<int(ComputeContext*, FunctionState*)>;
using ComputeFunc = std::function<Status(FunctionState, const OrtApi*, OrtKernelContext*)>;
using DestroyFunctionStateFunc = std::function<void(FunctionState)>;

struct NodeComputeInfo {
  CreateFunctionStateFunc create_state_func;
  ComputeFunc compute_func;
  DestroyFunctionStateFunc release_state_func;
};
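
// To show how these pieces fit together: a minimal sketch of populating a
// NodeComputeInfo, as an EP's Compile() would. 'MyKernelState' and the lambda
// bodies are hypothetical; only the signatures come from the typedefs above.
//
//   NodeComputeInfo info;
//   info.create_state_func = [](ComputeContext* ctx, FunctionState* state) {
//     *state = new MyKernelState();  // hypothetical per fused-node state
//     return 0;                      // non-zero conventionally signals failure
//   };
//   info.compute_func = [](FunctionState state, const OrtApi* api, OrtKernelContext* ctx) {
//     auto* s = static_cast<MyKernelState*>(state);
//     // ... read inputs / write outputs through the OrtApi ...
//     return Status::OK();
//   };
//   info.release_state_func = [](FunctionState state) {
//     delete static_cast<MyKernelState*>(state);
//   };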

enum class DataLayout {
  NCHW,
  NHWC,
  NCHWC,
};

class IExecutionProvider {
 protected:
  IExecutionProvider(const std::string& type, bool use_metadef_id_creator = false)
      : type_{type} {
    if (use_metadef_id_creator) {
      metadef_id_generator_ = std::make_unique<ModelMetadefIdGenerator>();
    }
  }

 public:
  virtual ~IExecutionProvider() = default;

  /**
     Get all IAllocators for <*this> execution provider.
  */
  const std::vector<AllocatorPtr>& GetAllocators() const {
    return allocator_list_;
  }

  /**
   * Get an allocator with the specified device id and MemType. Returns nullptr if it doesn't exist.
   */
  virtual AllocatorPtr GetAllocator(int device_id, OrtMemType mem_type) const;
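
  // Usage sketch (illustrative): fetch the default allocator for device 0 and
  // allocate through it. OrtMemTypeDefault is the standard default memory type.
  //
  //   AllocatorPtr alloc = ep.GetAllocator(/*device_id*/ 0, OrtMemTypeDefault);
  //   if (alloc) {
  //     void* buffer = alloc->Alloc(num_bytes);
  //     // ... use buffer ...
  //     alloc->Free(buffer);
  //   }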

  /**
   * Returns a data transfer object that implements methods to copy to and
   * from this device.
   * If no copy is required for the successful operation of this provider,
   * return nullptr.
   */
  virtual std::unique_ptr<onnxruntime::IDataTransfer> GetDataTransfer() const {
    return nullptr;
  }
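
  // A device-backed EP would typically override this to return its own transfer
  // object, e.g. (sketch; MyDataTransfer is a hypothetical IDataTransfer
  // implementation that copies between host and device memory):
  //
  //   std::unique_ptr<onnxruntime::IDataTransfer> GetDataTransfer() const override {
  //     return std::make_unique<MyDataTransfer>();
  //   }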

  /**
   * Interface for performing kernel lookup within kernel registries.
   * Abstracts away lower-level details about kernel registries and kernel matching.
   */
  class IKernelLookup {
   public:
    /**
     * Given `node`, try to find a matching kernel for this EP.
     * The return value is non-null if and only if a matching kernel was found.
     */
    virtual const KernelCreateInfo* LookUpKernel(const Node& node) const = 0;

   protected:
    ~IKernelLookup() = default;
  };

  /**
     Get the execution provider's capability for the specified <graph>.
     Returns a set of IndexedSubGraphs that <*this> execution provider can run
     if the sub-graph contains only one node, or can fuse and run if the
     sub-graph contains more than one node. The node indexes contained in the
     sub-graphs may overlap, and it is ONNXRuntime's responsibility to do the
     partitioning and decide whether a node will be assigned to <*this>
     execution provider. For kernels registered in a kernel registry,
     `kernel_lookup` must be used to find a matching kernel for this EP.
  */
  virtual std::vector<std::unique_ptr<ComputeCapability>>
  GetCapability(const onnxruntime::GraphViewer& graph_viewer,
                const IKernelLookup& kernel_lookup) const;
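
  // Sketch of a common single-node GetCapability pattern: claim each node for
  // which `kernel_lookup` finds a kernel registered to this EP. 'MyEp' is
  // hypothetical and error handling is elided.
  //
  //   std::vector<std::unique_ptr<ComputeCapability>>
  //   MyEp::GetCapability(const GraphViewer& graph_viewer,
  //                       const IKernelLookup& kernel_lookup) const {
  //     std::vector<std::unique_ptr<ComputeCapability>> result;
  //     for (const auto& node : graph_viewer.Nodes()) {
  //       if (kernel_lookup.LookUpKernel(node) != nullptr) {
  //         auto sub_graph = std::make_unique<IndexedSubGraph>();
  //         sub_graph->nodes.push_back(node.Index());
  //         result.push_back(std::make_unique<ComputeCapability>(std::move(sub_graph)));
  //       }
  //     }
  //     return result;
  //   }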

  /**
     Get the kernel registry for this execution provider type.
     The KernelRegistry shared pointer returned is shared across sessions.

     NOTE: this approach was taken to achieve the following goals:
     1. The execution-provider-type-based kernel registry should be shared
        across sessions; only one copy of this kind of kernel registry exists
        in ONNXRuntime with multiple sessions/models.
     2. Adding an execution provider into ONNXRuntime does not require touching
        ONNXRuntime framework/session code.
     3. onnxruntime (framework/session) does not depend on any specific
        execution provider lib.
  */
  virtual std::shared_ptr<KernelRegistry> GetKernelRegistry() const { return nullptr; }
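
  // EPs that register kernels typically return a function-local static so the
  // same registry instance is shared across sessions, e.g. (sketch;
  // CreateMyKernelRegistry is a hypothetical factory):
  //
  //   std::shared_ptr<KernelRegistry> MyEp::GetKernelRegistry() const {
  //     static std::shared_ptr<KernelRegistry> registry = CreateMyKernelRegistry();
  //     return registry;
  //   }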

  /**
     Get the device id of the current execution provider.
  */
  virtual int GetDeviceId() const { return 0; }

  /**
     Get the execution provider's configuration options.
  */
  virtual ProviderOptions GetProviderOptions() const { return {}; }

  /**
     Returns an opaque handle whose exact type varies based on the provider
     and is interpreted accordingly by the corresponding kernel implementation.
     For Direct3D operator kernels, this may return an IUnknown supporting
     QueryInterface to ID3D12GraphicsCommandList1.
  */
  virtual const void* GetExecutionHandle() const noexcept {
    return nullptr;
  }

  /**
     @return type of the execution provider; should match that set in the node
     through the SetExecutionProvider API. Example valid return values are:
     kCpuExecutionProvider, kCudaExecutionProvider
  */
  const std::string& Type() const { return type_; }

  /**
     Blocks until the device has completed all preceding requested tasks.
     Currently this is primarily used by the IOBinding object to ensure that all
     inputs have been copied to the device before execution begins.
  */
  virtual common::Status Sync() const { return Status::OK(); }

  /**
     Called when InferenceSession::Run has started.
     NOTE: due to async execution in the provider, the actual work of the previous
     Run may not be finished on the device. This function should be regarded as
     the point after which a new Run starts to submit commands from the CPU.
  */
  virtual common::Status OnRunStart() { return Status::OK(); }

  /**
     Called when InferenceSession::Run has ended.
     NOTE: due to async execution in the provider, the actual work of this Run
     may not be finished on the device. This function should be regarded as the
     point at which all commands of the current Run have been submitted by the CPU.
  */
  virtual common::Status OnRunEnd(bool /*sync_stream*/) { return Status::OK(); }
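
  // Sketch of how an EP with asynchronous execution might use this hook
  // (hypothetical; SynchronizeDeviceStream is an assumed EP-internal helper):
  //
  //   common::Status MyEp::OnRunEnd(bool sync_stream) {
  //     if (sync_stream) {
  //       ORT_RETURN_IF_ERROR(SynchronizeDeviceStream());  // block until submitted work completes
  //     }
  //     return Status::OK();
  //   }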

  /**
     Indicate whether the graph capturing mode (e.g., CUDA graphs) is enabled
     for this provider. Currently only the CUDA execution provider supports it.
  */
  virtual bool IsGraphCaptureEnabled() const { return false; }

  /**
     Indicate whether the graph has been captured and instantiated. Currently
     only the CUDA execution provider supports it.
  */
  virtual bool IsGraphCaptured() const { return false; }

  /**
     Run the instantiated graph. Currently only the CUDA execution provider
     supports it.
  */
  virtual common::Status ReplayGraph() { return Status::OK(); }
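
  // Putting the three graph-capture methods together, the expected interplay
  // during a Run is roughly (simplified sketch inferred from the comments above):
  //
  //   if (ep.IsGraphCaptureEnabled()) {
  //     if (ep.IsGraphCaptured()) {
  //       ORT_RETURN_IF_ERROR(ep.ReplayGraph());  // replay the instantiated graph
  //     } else {
  //       // execute normally; the EP captures the graph during this Run
  //     }
  //   }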

  /**
     Called when session creation is complete.
     This provides an opportunity for execution providers to optionally synchronize and
     clean up temporary resources to reduce memory and ensure the first run is fast.
  */
  virtual common::Status OnSessionInitializationEnd() { return Status::OK(); }

  void InsertAllocator(AllocatorPtr allocator);
  void ReplaceAllocator(AllocatorPtr allocator);

  struct FusedNodeAndGraph {
    const std::reference_wrapper<onnxruntime::Node> fused_node;
    // GraphViewer that filters the full graph to the nodes that are covered by 'fused_node'
    const std::reference_wrapper<GraphViewer> filtered_graph;
  };

  // Fusion approach that is supported.
  // !!! The "Function" FusionStyle is deprecated.
  // !!! If your EP is using this fusion style, please migrate it to the "FilteredGraphViewer" style.
  enum class FusionStyle {
    // The node fusion will create an onnxruntime::Function based Node that contains a completely new Graph instance
    // in the Node body. The original nodes and initializers are copied to the new Graph instance in Function::Body().
    // A GraphProto can be produced from the Node body.
    Function,

    // The node fusion will create a new Node that defines the inputs and outputs using the IndexedSubGraph
    // that GetCapability returned. The Node will not be onnxruntime::Function based so will have no Body().
    // Instead a GraphViewer that filters the full Graph to the fused Nodes will be created.
    // This is significantly cheaper as it doesn't incur the cost of creating a new Graph instance,
    // and can be supported in a minimal build.
    FilteredGraphViewer,
  };

  virtual FusionStyle GetFusionStyle() const {
    // All the ORT built-in EPs have migrated to the FilteredGraphViewer style.
    // For newer EPs, please avoid using the Function style as it is deprecated.
    return FusionStyle::FilteredGraphViewer;
  }

#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
  /**
     Given a collection of fused Nodes and the respective GraphViewer instance for the nodes that were fused,
     return create_state/compute/release_state funcs for each node.
     @remarks This is now the default interface to use when an execution provider wants to compile nodes
              for both the minimal build and the complete ORT build.

              Do NOT cache the GraphViewer in FusedNodeAndGraph.filtered_graph in any of the NodeComputeInfo functions
              as it is only valid for the duration of the call to Compile.
  */
  virtual common::Status Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
                                 std::vector<NodeComputeInfo>& node_compute_funcs);
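
  // Sketch of a Compile override wiring each fused node to a NodeComputeInfo
  // (hypothetical EP; CompileToDeviceProgram is an assumed helper; see the
  // NodeComputeInfo sketch near the top of this file for the callbacks):
  //
  //   common::Status MyEp::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
  //                                std::vector<NodeComputeInfo>& node_compute_funcs) {
  //     for (const auto& fused : fused_nodes_and_graphs) {
  //       const GraphViewer& graph = fused.filtered_graph;  // only valid during this call
  //       auto program = CompileToDeviceProgram(graph);
  //       NodeComputeInfo info;
  //       // ... populate create_state/compute/release_state funcs from 'program' ...
  //       node_compute_funcs.push_back(std::move(info));
  //     }
  //     return Status::OK();
  //   }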

#endif

  void SetLogger(const logging::Logger* logger) {
    logger_ = logger;
  }

  const logging::Logger* GetLogger() const {
    return logger_;
  }

  /** Generate a unique id that can be used in a MetaDef name. Values are unique for a model instance.
      The model hash is also returned if you wish to include that in the MetaDef name to ensure uniqueness across models.
      @param graph_viewer[in] Graph viewer that GetCapability was called with. Can be for the main graph or a nested graph.
      @param model_hash[out] Returns the hash for the main (i.e. top level) graph in the model.
                             This is created using the model path if available,
                             or the model input names and the output names from all nodes in the main graph.
      @remarks e.g. the TensorRT Execution Provider is used in multiple sessions and the underlying infrastructure caches
               compiled kernels, so the name must be unique and deterministic across models and sessions.
               NOTE: Ideally this would be a protected method, but to work across the EP bridge it has to be public and
               virtual, and ModelMetadefIdGenerator must be defined in the header as well.
  */
  virtual int GenerateMetaDefId(const onnxruntime::GraphViewer& graph_viewer, HashValue& model_hash) const;
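
  // Typical use when naming a MetaDef for a fused sub-graph inside GetCapability
  // (sketch; the "MyEp_" naming scheme is illustrative, not prescribed):
  //
  //   HashValue model_hash = 0;
  //   int id = GenerateMetaDefId(graph_viewer, model_hash);
  //   std::string name = "MyEp_" + std::to_string(model_hash) + "_" + std::to_string(id);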

  /**
     Register allocators for the EP, potentially re-using existing allocators for a device from allocator_manager.
     If the EP implements this it should generally delay creating any allocators until this is called.
  */
  virtual void RegisterAllocator(AllocatorManager& /*allocator_manager*/);

  virtual std::unique_ptr<profiling::EpProfiler> GetProfiler() {
    return {};
  }

  virtual DataLayout GetPreferredLayout() const {
    // NCHW is the default ONNX standard data layout, so default to it.
    // EPs that prefer a different layout should override this to return their preferred layout.
    return DataLayout::NCHW;
  }
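
  // e.g. an EP whose kernels operate on channels-last data would override this as:
  //
  //   DataLayout GetPreferredLayout() const override { return DataLayout::NHWC; }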

  virtual void RegisterStreamHandlers(IStreamCommandHandleRegistry& /*stream_handle_registry*/) const {}

  /** Does the EP support concurrent calls to InferenceSession::Run to execute the model. */
  virtual bool ConcurrentRunSupported() const { return true; }

 private:
  const std::string type_;

  // Allocator lookup is done by combining the device id and OrtMemType.
  // There's also an implicit connection to the underlying OrtDevice involved that is dependent on the EP.
  // e.g. for a CPU based EP, 'default' memory is a CPU device, and for a GPU based EP 'default' memory is a
  // GPU device.
  using AllocatorMap = std::unordered_map<int, AllocatorPtr>;
  AllocatorMap allocators_;

  // Set when this object is registered to a session.
  const logging::Logger* logger_ = nullptr;

  // Convenience list of the allocators so GetAllocators doesn't have to build a new vector each time;
  // contains the same instances as allocators_.
  std::vector<AllocatorPtr> allocator_list_;

  // Helper to generate ids that are unique to the model and deterministic, even if the execution provider
  // is shared across multiple sessions.
  class ModelMetadefIdGenerator {
   public:
    int GenerateId(const onnxruntime::GraphViewer& graph_viewer, HashValue& model_hash);

   private:
    std::unordered_map<HashValue, HashValue> main_graph_hash_;  // map graph instance hash to model contents hash
    std::unordered_map<HashValue, int> model_metadef_id_;       // current unique id for model
  };

  std::unique_ptr<ModelMetadefIdGenerator> metadef_id_generator_;
};
}  // namespace onnxruntime