6 #ifndef SHARED_PROVIDER
8 #include <unordered_map>
9 #include <unordered_set>
13 #include "core/framework/data_transfer.h"
16 namespace onnxruntime {
18 struct ComputeCapability;
20 struct KernelCreateInfo;
29 #include "core/framework/allocatormgr.h"
34 namespace onnxruntime {
43 using ComputeFunc = std::function<Status(FunctionState, const OrtApi*, OrtKernelContext*)>;
62 if (use_metadef_id_creator) {
63 metadef_id_generator_ = std::make_unique<ModelMetadefIdGenerator>();
74 return allocator_list_;
118 virtual std::vector<std::unique_ptr<ComputeCapability>>
245 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
256 std::vector<NodeComputeInfo>& node_compute_funcs);
310 using AllocatorMap = std::unordered_map<int, AllocatorPtr>;
311 AllocatorMap allocators_;
317 std::vector<AllocatorPtr> allocator_list_;
321 class ModelMetadefIdGenerator {
326 std::unordered_map<HashValue, HashValue> main_graph_hash_;
327 std::unordered_map<HashValue, int> model_metadef_id_;
330 std::unique_ptr<ModelMetadefIdGenerator> metadef_id_generator_;
virtual FusionStyle GetFusionStyle() const
virtual common::Status Sync() const
virtual std::vector< std::unique_ptr< ComputeCapability > > GetCapability(const onnxruntime::GraphViewer &graph_viewer, const IKernelLookup &kernel_lookup) const
virtual bool IsGraphCaptured() const
virtual int GenerateMetaDefId(const onnxruntime::GraphViewer &graph_viewer, HashValue &model_hash) const
virtual ~IExecutionProvider()=default
virtual const KernelCreateInfo * LookUpKernel(const Node &node) const =0
const std::vector< AllocatorPtr > & GetAllocators() const
GLsizei const GLchar *const * string
std::function< int(ComputeContext *, FunctionState *)> CreateFunctionStateFunc
void InsertAllocator(AllocatorPtr allocator)
virtual DataLayout GetPreferredLayout() const
void ReplaceAllocator(AllocatorPtr allocator)
virtual common::Status Compile(const std::vector< FusedNodeAndGraph > &fused_nodes_and_graphs, std::vector< NodeComputeInfo > &node_compute_funcs)
virtual std::shared_ptr< KernelRegistry > GetKernelRegistry() const
IExecutionProvider(const std::string &type, bool use_metadef_id_creator=false)
virtual void RegisterStreamHandlers(IStreamCommandHandleRegistry &) const
const std::string & Type() const
virtual std::unique_ptr< profiling::EpProfiler > GetProfiler()
const std::reference_wrapper< GraphViewer > filtered_graph
virtual AllocatorPtr GetAllocator(int device_id, OrtMemType mem_type) const
virtual common::Status OnRunStart()
const logging::Logger * GetLogger() const
DestroyFunctionStateFunc release_state_func
virtual void RegisterAllocator(AllocatorManager &)
virtual const void * GetExecutionHandle() const noexcept
virtual std::unique_ptr< onnxruntime::IDataTransfer > GetDataTransfer() const
virtual ProviderOptions GetProviderOptions() const
std::function< void(FunctionState)> DestroyFunctionStateFunc
virtual bool IsGraphCaptureEnabled() const
std::unordered_map< std::string, std::string > ProviderOptions
virtual int GetDeviceId() const
virtual common::Status ReplayGraph()
CreateFunctionStateFunc create_state_func
std::shared_ptr< IAllocator > AllocatorPtr
std::function< Status(FunctionState, const OrtApi *, OrtKernelContext *)> ComputeFunc
void SetLogger(const logging::Logger *logger)
const std::reference_wrapper< onnxruntime::Node > fused_node
virtual bool ConcurrentRunSupported() const
OrtMemType
Memory types for allocated memory, execution provider specific types should be extended in each provi...
virtual common::Status OnSessionInitializationEnd()
virtual common::Status OnRunEnd(bool)