HDK
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
op_kernel_info.h
Go to the documentation of this file.
1 // Copyright (c) Microsoft Corporation. All rights reserved.
2 // Licensed under the MIT License.
3 
4 #pragma once
5 
11 #include "core/common/gsl.h"
12 
13 namespace onnxruntime {
14 
15 class OrtValueNameIdxMap;
16 class FuncManager;
17 class DataTransferManager;
18 struct AllocPlanPerValue;
19 
20 // A very light-weight class, which works as an aggregated
21 // view of all data needed for constructing a Kernel instance.
22 // NOTE: it does not own/hold any objects.
23 class OpKernelInfo : public OpNodeProtoHelper<ProtoHelperNodeContext> {
24  public:
25  explicit OpKernelInfo(const onnxruntime::Node& node,
26  const KernelDef& kernel_def,
27  const IExecutionProvider& execution_provider,
28  const std::unordered_map<int, OrtValue>& constant_initialized_tensors,
29  const OrtValueNameIdxMap& mlvalue_name_idx_map,
30  const DataTransferManager& data_transfer_mgr);
31 
32  OpKernelInfo(const OpKernelInfo& other);
33 
34  const OrtMemoryInfo& GetMemoryInfo(int device_id, OrtMemType mem_type) const;
35 
36  AllocatorPtr GetAllocator(int device_id, OrtMemType mem_type) const;
37 
38  const KernelDef& GetKernelDef() const;
39 
40  const IExecutionProvider* GetExecutionProvider() const noexcept;
41 
42  const DataTransferManager& GetDataTransferManager() const noexcept;
43 
44  const onnxruntime::Node& node() const noexcept;
45 
46  bool TryGetConstantInput(int input_index, const Tensor** constant_input_value) const;
47 
48  private:
49  ORT_DISALLOW_MOVE(OpKernelInfo);
50  ORT_DISALLOW_ASSIGNMENT(OpKernelInfo);
51 
52  const onnxruntime::Node& node_;
53  const KernelDef& kernel_def_;
54  // For non cpu/cuda case, this pointer should be set so that function kernel
55  // will delegate kernel compute call to <execution_provider> compute call.
56  gsl::not_null<const ::onnxruntime::IExecutionProvider*> execution_provider_;
57  const std::unordered_map<int, OrtValue>& constant_initialized_tensors_;
58  const OrtValueNameIdxMap& ort_value_name_idx_map_;
59  const DataTransferManager& data_transfer_mgr_;
60  ProtoHelperNodeContext proto_helper_context_;
61 };
62 
63 } // namespace onnxruntime
const IExecutionProvider * GetExecutionProvider() const noexcept
const DataTransferManager & GetDataTransferManager() const noexcept
OpKernelInfo(const onnxruntime::Node &node, const KernelDef &kernel_def, const IExecutionProvider &execution_provider, const std::unordered_map< int, OrtValue > &constant_initialized_tensors, const OrtValueNameIdxMap &mlvalue_name_idx_map, const DataTransferManager &data_transfer_mgr)
const onnxruntime::Node & node() const noexcept
AllocatorPtr GetAllocator(int device_id, OrtMemType mem_type) const
const OrtMemoryInfo & GetMemoryInfo(int device_id, OrtMemType mem_type) const
std::shared_ptr< IAllocator > AllocatorPtr
Definition: allocator.h:190
bool TryGetConstantInput(int input_index, const Tensor **constant_input_value) const
const KernelDef & GetKernelDef() const
OrtMemType
Memory types for allocated memory; execution-provider-specific types should be extended in each provider.