HDK
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
graph_viewer.h
Go to the documentation of this file.
1 // Copyright (c) Microsoft Corporation. All rights reserved.
2 // Licensed under the MIT License.
3 
4 #pragma once
5 #include <unordered_set>
6 #include <filesystem>
7 
8 #include "core/graph/graph.h"
9 #include "core/framework/session_options.h"
10 
// Forward declarations so this header does not need to include their full definitions.
namespace onnxruntime {
// Not referenced by the declarations visible in this header; presumably kept for includers — verify before removing.
class Function;
// Describes a subset of graph nodes; GraphViewer can be constructed to view only that subset.
struct IndexedSubGraph;
}  // namespace onnxruntime
15 
16 namespace onnxruntime {
17 
18 // use value-based compare to make sure transformer output order is consistent
19 struct NodeCompare {
20  bool operator()(const Node* n1, const Node* n2) const;
21 };
22 
23 /**
24 @class GraphViewer
25 Class that provides a read-only view of the Graph.
26 @remarks If the underlying Graph is changed, GetNodesInTopologicalOrder and GetRootNodes may become invalid.
27 */
28 class GraphViewer {
29  public:
30  /**
31  Construct a GraphViewer from the provided Graph instance.
32  */
33  explicit GraphViewer(const Graph& graph);
34 
35  /**
36  Construct a GraphViewer from the provided Graph instance, filtering to the nodes specified in the IndexedSubGraph
37  */
38  explicit GraphViewer(const Graph& graph, const IndexedSubGraph& filter_info);
39 
40  /** Gets the Graph name. */
41  const std::string& Name() const noexcept;
42 
43  /** Gets the Graph description. */
44  const std::string& Description() const noexcept;
45 
46  /** Gets the path of the owning model if any **/
47  const std::filesystem::path& ModelPath() const noexcept { return graph_->ModelPath(); }
48 
49  /**
50  Gets a tensor created from an initializer.
51  @param tensor_name The tensor name
52  @param[out] value Sets the pointer to the TensorProto if found, or nullptr if not.
53  @returns True if found. False if not.
54  */
55  bool GetInitializedTensor(const std::string& tensor_name, const ONNX_NAMESPACE::TensorProto*& value) const;
56 
57  /** Returns true if an initializer value can be overridden by a graph input with the same name. */
58  bool CanOverrideInitializer() const noexcept;
59 
60  /**
61  Gets the Graph inputs, excluding initializers.
62  @returns Collection of NodeArg pointers for the graph inputs, excluding inputs that have matching initializers.
63  @remarks No nullptr values in the returned collection. The order will be the same as in the GraphProto.
64  Inputs are for filter_info_ if set.
65  */
66  const std::vector<const NodeArg*>& GetInputs() const noexcept;
67 
68  /**
69  Gets the Graph inputs, including any initializers.
70  @returns Collection of NodeArg pointers for all the graph inputs.
71  @remarks No nullptr values in the returned collection. The order will be the same as in the GraphProto.
72  Inputs are for filter_info_ if set.
73  */
74  const std::vector<const NodeArg*>& GetInputsIncludingInitializers() const noexcept;
75 
76  /**
77  Gets the Graph outputs.
78  @returns Collection of NodeArg pointers for all the graph outputs.
79  @remarks No nullptr values in the returned collection. The order will be the same as in the GraphProto.
80  Outputs are for filter_info_ if set.
81  */
82  const std::vector<const NodeArg*>& GetOutputs() const noexcept;
83 
84  /** Returns true if one or more of the Node outputs are Graph outputs.
85  */
86  bool NodeProducesGraphOutput(const Node& node) const;
87 
88  /** Gets all ValueInfo NodeArg instances in the Graph.
89  @remarks NOT filtered using filter_info_.
90  */
91  const std::unordered_set<const NodeArg*>& GetValueInfo() const noexcept;
92 
93  /**
94  Gets the Node instance at the specified index.
95  @param node_index Index to retrieve Node from.
96  @remarks May return nullptr if index no longer points to a valid node due to the node being freed, or if
97  node is excluded by filter_info_.
98  */
99  const Node* GetNode(NodeIndex node_index) const;
100 
101  /** Gets an iterator over all the valid Nodes in the Graph.
102  @remarks Nodes are filtered using filter_info_ if set.
103  */
104  const ConstGraphNodes& Nodes() const noexcept;
105 
106  /** Gets the number of valid nodes in the Graph.
107  @remarks Returns the number of nodes in filter_info_ if set.
108  */
109  int NumberOfNodes() const noexcept;
110 
111  /** Gets the maximum NodeIndex value used by Nodes in the Graph. */
112  int MaxNodeIndex() const noexcept;
113 
114  /** Gets the NodeIndex values for the Graph nodes, sorted into topological order.
115  @remarks Filtered using filter_info_ if set.
116  */
117  const std::vector<NodeIndex>& GetNodesInTopologicalOrder(ExecutionOrder order = ExecutionOrder::DEFAULT) const;
118 
119  /**
120  Gets the NodeIndex values for the root nodes in the Graph.
121  The root nodes are the topmost nodes in the Graph that receive inputs from the Graph inputs
122  and no other nodes in the Graph.
123  @remarks Not supported if filter_info_ is set.
124  */
125  const std::vector<NodeIndex>& GetRootNodes() const;
126 
127  /** Gets all tensors created from initializers. */
128  const InitializedTensorSet& GetAllInitializedTensors() const noexcept;
129 
130  /**
131  Gets the NodeArg instance for the given name.
132  @returns A NodeArg if found, a nullptr if not.
133  */
134  const NodeArg* GetNodeArg(const std::string& name) const;
135 
136  /** Gets the map of operator domains to their opset versions. */
137  const std::unordered_map<std::string, int>& DomainToVersionMap() const noexcept {
138  return graph_->DomainToVersionMap();
139  }
140 
141  /** Checks if this is a Subgraph */
142  bool IsSubgraph() const;
143 
144  /** Get the internal graph*/
145  const Graph& GetGraph() const { return *graph_; }
146 
147 #if !defined(ORT_MINIMAL_BUILD)
148  const std::unordered_set<std::string>& GetOuterScopeNodeArgNames() const noexcept;
149 #endif
150 
151  /**
152  returns true if 'name' is an initializer, and is constant and cannot be overridden at runtime.
153  @param check_outer_scope If true and the 'graph_' is a subgraph, check parent graph/s for 'name'
154  if the name is not found in 'graph_'.
155  */
156  bool IsConstantInitializer(const std::string& name, bool check_outer_scope) const;
157 
158  /** Check if a given name is an initializer tensor's name in this graph. */
159  bool IsInitializedTensor(const std::string& name) const;
160 
161  /** returns the initializer's TensorProto if 'name' is an initializer, is constant and
162  cannot be overridden at runtime. If the initializer is not found or is not constant, a nullptr is returned.
163  @param check_outer_scope If true and the graph is a subgraph,
164  check ancestor graph/s for 'name' if not found in 'graph'.
165  @remarks This function will return the result from GetConstantInitializer of the underlying Graph,
166  if a const initializer is part of the underlying Graph but not part of this GraphViewer,
167  it will still be returned instead of nullptr
168  */
169  const ONNX_NAMESPACE::TensorProto* GetConstantInitializer(const std::string& name,
170  bool check_outer_scope = true) const;
171 
172  /** Get the Node containing this Graph if IsSubgraph is true. Returns nullptr otherwise. */
173  const Node* ParentNode() const noexcept { return graph_->ParentNode(); }
174 
175 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
176  /** Get the consumer nodes of a node arg */
177  std::vector<const Node*> GetConsumerNodes(const std::string& node_arg_name) const {
178  return graph_->GetConsumerNodes(node_arg_name);
179  }
180 
181  /** Get the producer node of a node arg */
182  const Node* GetProducerNode(const std::string& node_arg_name) const {
183  return graph_->GetProducerNode(node_arg_name);
184  }
185 #endif
186 
187  /** Get the filter info that restricts the graph viewer to a subset of nodes if set.
188  @returns Filter info or nullptr
189  */
190  const IndexedSubGraph* GetFilterInfo() const { return filter_info_; }
191 
192 #if !defined(ORT_MINIMAL_BUILD)
194 #endif
195 
196  private:
197  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GraphViewer);
198  GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info);
199 
200  const Graph* graph_;
201  ConstGraphNodes graph_nodes_;
202 
203  // The NodeIndex values of the graph nodes sorted in topological order.
204  std::vector<NodeIndex> nodes_in_topological_order_;
205 
206 #if !defined(ORT_MINIMAL_BUILD)
207  // The NodeIndex values of the graph nodes sorted in topological order with priority.
208  std::vector<NodeIndex> nodes_in_topological_order_with_priority_;
209 #endif
210 
211 #ifdef ENABLE_TRAINING
212  // The NodeIndex values of the graph nodes sorted in memory efficient topological order.
213  std::vector<NodeIndex> nodes_in_mem_efficient_topological_order_;
214 #endif
215 
216  // Graph root nodes.
217  std::vector<NodeIndex> root_nodes_;
218 
219  // if we're limiting the view to an IndexedSubGraph we need to create a few pieces of infrastructure that would
220  // usually come from the full graph
221  const IndexedSubGraph* filter_info_{nullptr};
222  using FilteredNodeSet = InlinedHashSet<NodeIndex>;
223  FilteredNodeSet filtered_node_indices_;
224  std::vector<const NodeArg*> filtered_node_inputs_;
225  std::vector<const NodeArg*> filtered_node_inputs_including_initializers_;
226  std::vector<const NodeArg*> filtered_node_outputs_;
227  InitializedTensorSet filtered_initializers_;
228 };
229 } // namespace onnxruntime
const std::filesystem::path & ModelPath() const noexcept
Definition: graph_viewer.h:47
std::unordered_map< std::string, const ONNX_NAMESPACE::TensorProto * > InitializedTensorSet
Definition: basic_types.h:35
const IndexedSubGraph * GetFilterInfo() const
Definition: graph_viewer.h:190
const Node * ParentNode() const noexcept
Definition: graph_viewer.h:173
IOnnxRuntimeOpSchemaCollectionPtr GetSchemaRegistry() const
bool IsConstantInitializer(const std::string &name, bool check_outer_scope) const
std::shared_ptr< IOnnxRuntimeOpSchemaCollection > IOnnxRuntimeOpSchemaCollectionPtr
Definition: basic_types.h:46
const InitializedTensorSet & GetAllInitializedTensors() const noexcept
bool operator()(const Node *n1, const Node *n2) const
const std::vector< const NodeArg * > & GetInputs() const noexcept
GLsizei const GLfloat * value
Definition: glcorearb.h:824
int NumberOfNodes() const noexcept
GLsizei const GLchar *const * path
Definition: glcorearb.h:3341
const Node * ParentNode() const
Definition: graph.h:1404
bool NodeProducesGraphOutput(const Node &node) const
const std::unordered_map< std::string, int > & DomainToVersionMap() const noexcept
Definition: graph.h:1131
bool IsInitializedTensor(const std::string &name) const
const std::filesystem::path & ModelPath() const
int MaxNodeIndex() const noexcept
const std::string & Description() const noexcept
const std::string & Name() const noexcept
IOnnxRuntimeOpSchemaCollectionPtr GetSchemaRegistry() const
Definition: graph_viewer.h:193
const ONNX_NAMESPACE::TensorProto * GetConstantInitializer(const std::string &name, bool check_outer_scope=true) const
GraphViewer(const Graph &graph)
const Node * GetProducerNode(const std::string &node_arg_name) const
Definition: graph_viewer.h:182
const ConstGraphNodes & Nodes() const noexcept
const std::unordered_set< const NodeArg * > & GetValueInfo() const noexcept
const std::vector< NodeIndex > & GetRootNodes() const
GLdouble GLdouble GLint GLint order
Definition: glad.h:2676
GLuint const GLchar * name
Definition: glcorearb.h:786
bool CanOverrideInitializer() const noexcept
const std::vector< const NodeArg * > & GetInputsIncludingInitializers() const noexcept
bool IsSubgraph() const
const Node * GetProducerNode(const std::string &node_arg_name) const
Definition: graph.h:1304
const std::unordered_set< std::string > & GetOuterScopeNodeArgNames() const noexcept
const std::vector< const NodeArg * > & GetOutputs() const noexcept
std::vector< const Node * > GetConsumerNodes(const std::string &node_arg_name) const
Definition: graph.h:1322
const Node * GetNode(NodeIndex node_index) const
const Graph & GetGraph() const
Definition: graph_viewer.h:145
const NodeArg * GetNodeArg(const std::string &name) const
std::unordered_map< std::string, int > DomainToVersionMap
const std::vector< NodeIndex > & GetNodesInTopologicalOrder(ExecutionOrder order=ExecutionOrder::DEFAULT) const
std::vector< const Node * > GetConsumerNodes(const std::string &node_arg_name) const
Definition: graph_viewer.h:177
size_t NodeIndex
Definition: basic_types.h:32
bool GetInitializedTensor(const std::string &tensor_name, const ONNX_NAMESPACE::TensorProto *&value) const