6 #ifndef SHARED_PROVIDER
8 #include <unordered_map>
9 #include <unordered_set>
13 #include "core/framework/data_transfer.h"
14 #include "core/framework/external_data_loader.h"
17 namespace onnxruntime {
19 struct ComputeCapability;
21 struct KernelCreateInfo;
30 #include "core/framework/allocator_utils.h"
35 #include "core/framework/tuning_context.h"
39 namespace onnxruntime {
48 using ComputeFunc = std::function<Status(FunctionState, const OrtApi*, OrtKernelContext*)>;
131 virtual std::vector<std::unique_ptr<ComputeCapability>>
190 const std::string&
Type()
const {
return type_; }
221 gsl::span<const char* const> ) {
279 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
290 std::vector<NodeComputeInfo>& node_compute_funcs);
329 if (mem_type == OrtMemTypeCPUInput || mem_type == OrtMemTypeCPUOutput) {
351 const std::string type_;
virtual FusionStyle GetFusionStyle() const
virtual common::Status Sync() const
virtual common::Status OnRunStart(const onnxruntime::RunOptions &)
virtual std::vector< std::unique_ptr< ComputeCapability > > GetCapability(const onnxruntime::GraphViewer &graph_viewer, const IKernelLookup &kernel_lookup) const
virtual ~IExecutionProvider()=default
virtual const KernelCreateInfo * LookUpKernel(const Node &node) const =0
virtual bool IsGraphCaptured(int) const
std::function< int(ComputeContext *, FunctionState *)> CreateFunctionStateFunc
virtual DataLayout GetPreferredLayout() const
const OrtDevice default_device_
virtual common::Status Compile(const std::vector< FusedNodeAndGraph > &fused_nodes_and_graphs, std::vector< NodeComputeInfo > &node_compute_funcs)
virtual std::shared_ptr< KernelRegistry > GetKernelRegistry() const
virtual OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const
virtual common::Status SetEpDynamicOptions(gsl::span< const char *const >, gsl::span< const char *const >)
const std::string & Type() const
virtual std::unique_ptr< profiling::EpProfiler > GetProfiler()
virtual std::unique_ptr< onnxruntime::IExternalDataLoader > GetExternalDataLoader() const
GLint GLint GLsizei GLint GLenum GLenum type
const std::reference_wrapper< GraphViewer > filtered_graph
const logging::Logger * GetLogger() const
DestroyFunctionStateFunc release_state_func
virtual void RegisterStreamHandlers(IStreamCommandHandleRegistry &, AllocatorMap &) const
IExecutionProvider(const std::string &type, OrtDevice device)
absl::InlinedVector< T, N, Allocator > InlinedVector
virtual const void * GetExecutionHandle() const noexcept
std::map< OrtDevice, AllocatorPtr > AllocatorMap
virtual std::unique_ptr< onnxruntime::IDataTransfer > GetDataTransfer() const
virtual ProviderOptions GetProviderOptions() const
std::function< void(FunctionState)> DestroyFunctionStateFunc
virtual bool IsGraphCaptureEnabled() const
std::unordered_map< std::string, std::string > ProviderOptions
virtual int GetDeviceId() const
IExecutionProvider(const std::string &type)
CreateFunctionStateFunc create_state_func
virtual ITuningContext * GetTuningContext() const
std::function< Status(FunctionState, const OrtApi *, OrtKernelContext *)> ComputeFunc
virtual common::Status ReplayGraph(int)
void SetLogger(const logging::Logger *logger)
virtual const InlinedVector< const Node * > GetEpContextNodes() const
const std::reference_wrapper< onnxruntime::Node > fused_node
virtual common::Status OnRunEnd(bool, const onnxruntime::RunOptions &)
virtual void GetCustomOpDomainList(std::vector< OrtCustomOpDomain * > &) const
virtual bool ConcurrentRunSupported() const
virtual std::vector< AllocatorPtr > CreatePreferredAllocators()
virtual common::Status OnSessionInitializationEnd()