HDK
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
graph.h
Go to the documentation of this file.
1 // Copyright (c) Microsoft Corporation. All rights reserved.
2 // Licensed under the MIT License.
3 
4 #pragma once
5 
#include <algorithm>
#include <functional>
#include <limits>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <string_view>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#include <vector>
12 
13 #ifdef _WIN32
14 #pragma warning(push)
15 // disable some warnings from protobuf to pass Windows build
16 #pragma warning(disable : 4244)
17 #endif
18 
19 #if !defined(ORT_MINIMAL_BUILD)
20 #include "onnx/defs/schema.h"
22 #else
23 #include "onnx/defs/data_type_utils.h"
24 #endif
25 #include "onnx/onnx_pb.h"
26 #include "onnx/onnx-operators_pb.h"
27 
28 #ifdef _WIN32
29 #pragma warning(pop)
30 #endif
31 
32 #include "core/common/gsl.h"
33 
34 #include "core/common/common.h"
37 #include "core/common/path.h"
38 #include "core/common/span_utils.h"
39 #include "core/common/status.h"
41 #include "core/graph/basic_types.h"
42 #include "core/graph/constants.h"
43 #include "core/graph/function.h"
44 #if !defined(ORT_MINIMAL_BUILD)
45 #include "core/graph/function_template.h"
46 #endif
47 #include "core/graph/graph_nodes.h"
48 #include "core/graph/node_arg.h"
49 #include "core/graph/ort_format_load_options.h"
50 
51 namespace flatbuffers {
52 class FlatBufferBuilder;
53 template <typename T>
54 struct Offset;
55 } // namespace flatbuffers
56 
57 namespace onnxruntime {
58 class Graph;
59 struct IndexedSubGraph;
60 class Model;
61 class OpSignature;
62 
63 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
64 class RuntimeOptimizationRecordContainer;
65 #endif
66 
67 namespace fbs {
68 struct Graph;
69 struct Node;
70 struct NodeEdge;
71 } // namespace fbs
72 
73 /**
74 @class Node
75 Class representing a node in the graph.
76 */
77 class Node {
78  public:
79  /** Node types */
80  enum class Type {
81  Primitive = 0, ///< The node refers to a primitive operator.
82  Fused = 1, ///< The node refers to a function.
83  };
84 
85  explicit Node() = default;
86 
87 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
89  std::string_view op_type,
90  std::string_view description,
91  gsl::span<NodeArg* const> input_args,
92  gsl::span<NodeArg* const> output_args,
93  const NodeAttributes* attributes,
94  std::string_view domain) {
95  Init(std::string{name}, std::string{op_type}, std::string{description},
96  std::vector<NodeArg*>{input_args.begin(), input_args.end()},
97  std::vector<NodeArg*>{output_args.begin(), output_args.end()},
98  attributes, std::string{domain});
99  }
100 #endif
101 
102  ~Node() = default;
103 
104  /**
105  @class EdgeEnd
106  Class representing the end of an edge. It could be an input or output edge end of a node.
107  For the node's input edge end, it's the source end, as the destination end is the node itself.
108  For the node's output edge end, it's the destination end, as the source end is the node itself.
109  */
110  class EdgeEnd {
111  public:
112  /**
113  Construct an EdgeEnd
114  @param node The source node if this is an input edge to the current node,
115  or the destination node if this is an output edge from the current node.
116  @param src_arg_index The node arg index of source node of the edge.
117  @param dst_arg_index The node arg index of destination node of the edge.
118  */
119  EdgeEnd(const Node& node, int src_arg_index, int dst_arg_index) noexcept;
120 
121  /** Construct a control edge.
122  @param node The node the edge joins to the current node.
123  */
124  explicit EdgeEnd(const Node& node) noexcept;
125 
126  /** Gets the Node that this EdgeEnd refers to. */
127  const Node& GetNode() const noexcept { return *node_; }
128 
129  /** Gets the source arg index.
130  @returns the source arg index of <*this> edge.*/
131  int GetSrcArgIndex() const { return src_arg_index_; }
132 
133  /** Gets the destination arg index.
134  @returns the destination arg index of <*this> edge.*/
135  int GetDstArgIndex() const { return dst_arg_index_; }
136 
137  private:
138  const Node* node_;
139  const int src_arg_index_;
140  const int dst_arg_index_;
141  };
142 
143  /** Gets the Node's NodeIndex. */
144  NodeIndex Index() const noexcept { return index_; }
145 
146  /** Gets the Node's name. */
147  const std::string& Name() const noexcept { return name_; }
148 
149  /** Gets the Node's operator type. */
150  const std::string& OpType() const noexcept { return op_type_; }
151 
152  /** Gets the domain of the OperatorSet that specifies the operator returned by #OpType.
153  * @remarks If this is an ONNX operator the value will be kOnnxDomain not kOnnxDomainAlias
154  */
155  const std::string& Domain() const noexcept { return domain_; }
156 
157  /** Gets the path of the owning model if any. */
158  const Path& ModelPath() const noexcept;
159 
160  /** Gets the Node's execution priority.
161  @remarks Lower value means higher priority */
162  int Priority() const noexcept { return priority_; };
163 
164  /** Sets the execution priority of a node.
165  @remarks Lower value means higher priority */
166  void SetPriority(int priority) noexcept;
167 
168  /** Gets the node description. */
169  const std::string& Description() const noexcept { return description_; }
170 
171  /** Gets the Node's Node::Type. */
172  Node::Type NodeType() const noexcept { return node_type_; }
173 
174  /** Gets the opset version that the Node's operator was first defined in.
175  @returns Opset version. If -1 the Node's operator has not been set.
176  @remarks Prefer over Op()->SinceVersion() as Op() is disabled in a minimal build
177  */
178  int SinceVersion() const noexcept { return since_version_; }
179 
180  /** Sets the since version (opset version that the Node's operator was first defined in.) for this node.
181  @remarks Used during layout transformation for setting since version for layout transformed nodes with
182  domain kMSNHWC.
183  */
184  void SetSinceVersion(int since_version) noexcept { since_version_ = since_version; }
185 
186 #if !defined(ORT_MINIMAL_BUILD)
187  /** Gets the Node's OpSchema.
188  @remarks The graph containing this node must be resolved, otherwise nullptr will be returned. */
189  const ONNX_NAMESPACE::OpSchema* Op() const noexcept { return op_; }
190 
191  /** Create a copy of the called op's FunctionProto if it has one. Returns true if successful. */
192  bool TryGetFunctionProto(ONNX_NAMESPACE::FunctionProto& func_proto) const;
193 
194  bool CanBeInlined() const;
195 
196  /** Gets the function body if applicable otherwise nullptr. */
197  const Function* GetFunctionBody() const noexcept { return func_body_.get(); }
198 #endif
199 
200  /**
201  Helper to iterate through the container returned by #InputDefs() or #OutputDefs() and call the provided function.
202  @param node_args Collection of NodeArgs returned by #InputDefs() or #OutputDefs()
203  @param func Function to call for each valid NodeArg in the node_args. The function is called with the NodeArg
204  and the index number in the container.
205  @returns common::Status with success or error information.
206  @remarks Returns immediately on error.
207  */
208  static common::Status ForEachWithIndex(const ConstPointerContainer<std::vector<NodeArg*>>& node_args,
209  std::function<common::Status(const NodeArg& arg, size_t index)> func) {
210  for (size_t index = 0; index < node_args.size(); ++index) {
211  auto arg = node_args[index];
212  if (!arg->Exists())
213  continue;
214  ORT_RETURN_IF_ERROR(func(*arg, index));
215  }
216  return common::Status::OK();
217  }
218 
219  /** Gets the count of arguments for each of the Node's explicit inputs. */
220  const std::vector<int>& InputArgCount() const noexcept { return definitions_.input_arg_count; }
221 
222  /** Gets the Node's input definitions.
223  @remarks requires ConstPointerContainer wrapper to apply const to the NodeArg pointers so access is read-only. */
226  }
227 
228  /** Gets the implicit inputs to this Node.
229  If this Node contains a subgraph, these are the NodeArg's that are implicitly consumed by Nodes within that
230  subgraph. e.g. If and Loop operators.*/
233  }
234 
235  /** Gets the Node's output definitions.
236  @remarks requires ConstPointerContainer wrapper to apply const to the NodeArg pointers so access is read-only. */
239  }
240 
241 #if !defined(ORT_MINIMAL_BUILD)
242  /**
243  Helper to iterate through the container returned by #MutableInputDefs() or #MutableOutputDefs() and call the provided function.
244  @param node_args Collection of NodeArgs returned by #MutableInputDefs() or #MutableOutputDefs()
245  @param func Function to call for each valid NodeArg in the node_args. The function is called with the NodeArg
246  and the index number in the container.
247  @returns common::Status with success or error information.
248  @remarks Returns immediately on error.
249  */
250  static common::Status ForEachMutableWithIndex(std::vector<NodeArg*>& node_args,
251  std::function<common::Status(NodeArg& arg, size_t index)> func) {
252  for (size_t index = 0; index < node_args.size(); ++index) {
253  auto arg = node_args[index];
254  if (!arg->Exists())
255  continue;
256  ORT_RETURN_IF_ERROR(func(*arg, index));
257  }
258  return common::Status::OK();
259  }
260 
261  /** Gets a modifiable collection of the Node's implicit input definitions. */
262  std::vector<NodeArg*>& MutableImplicitInputDefs() noexcept {
263  return definitions_.implicit_input_defs;
264  }
265 #endif // !defined(ORT_MINIMAL_BUILD)
266 
267 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
268  /** Gets a modifiable count of arguments for each of the Node's explicit inputs.
269  @todo This should be removed in favor of a method that updates the input args and the count.
270  Currently these operations are separate which is not a good setup. */
271  std::vector<int>& MutableInputArgsCount() { return definitions_.input_arg_count; }
272 
273  /** Gets a modifiable collection of the Node's input definitions. */
274  std::vector<NodeArg*>& MutableInputDefs() noexcept {
275  return definitions_.input_defs;
276  }
277 
278  /** Gets a modifiable collection of the Node's output definitions. */
279  std::vector<NodeArg*>& MutableOutputDefs() noexcept {
280  return definitions_.output_defs;
281  }
282 #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
283 
284  /** Struct to provide sorting between EdgeEnd instances based on NodeIndex first, and NodeArg::Name second. */
285  struct EdgeEndCompare {
286  bool operator()(const EdgeEnd& lhs, const EdgeEnd& rhs) const {
287  if (lhs.GetNode().Index() == rhs.GetNode().Index()) {
288  if (lhs.GetSrcArgIndex() == rhs.GetSrcArgIndex()) {
289  return lhs.GetDstArgIndex() < rhs.GetDstArgIndex();
290  }
291  return lhs.GetSrcArgIndex() < rhs.GetSrcArgIndex();
292  }
293  return lhs.GetNode().Index() < rhs.GetNode().Index();
294  }
295  };
296 
297  using EdgeSet = std::set<EdgeEnd, EdgeEndCompare>;
298  using EdgeConstIterator = EdgeSet::const_iterator;
299 
300  /**
301  @class NodeConstIterator
302  Class to provide const access to Node instances iterated via an EdgeConstIterator. */
304  public:
306 
307  bool operator==(const NodeConstIterator& p_other) const;
308 
309  bool operator!=(const NodeConstIterator& p_other) const;
310 
311  void operator++();
312  void operator--();
313 
314  const Node& operator*() const;
315  const Node* operator->() const;
316 
317  private:
318  EdgeConstIterator m_iter;
319  };
320 
321  // Functions defined to traverse a Graph as below.
322 
323  /** Gets an iterator to the beginning of the input nodes to this Node. */
324  NodeConstIterator InputNodesBegin() const noexcept { return NodeConstIterator(relationships_.input_edges.cbegin()); };
325  /** Gets an iterator to the end of the input nodes to this Node. */
326  NodeConstIterator InputNodesEnd() const noexcept { return NodeConstIterator(relationships_.input_edges.cend()); }
327 
328  /** Gets an iterator to the beginning of the output nodes from this Node. */
330  return NodeConstIterator(relationships_.output_edges.cbegin());
331  }
332 
333  /** Gets an iterator to the end of the output nodes from this Node. */
334  NodeConstIterator OutputNodesEnd() const noexcept { return NodeConstIterator(relationships_.output_edges.cend()); }
335 
336  /** Gets an iterator to the beginning of the input edges to this Node.
337  @remarks There are no nullptr entries in this collection. */
338  EdgeConstIterator InputEdgesBegin() const noexcept { return relationships_.input_edges.cbegin(); }
339 
340  /** Gets an iterator to the end of the input edges to this Node. */
341  EdgeConstIterator InputEdgesEnd() const noexcept { return relationships_.input_edges.cend(); }
342 
343  /** Gets an iterator to the beginning of the output edges from this Node.
344  @remarks There are no nullptr entries in this collection. */
345  EdgeConstIterator OutputEdgesBegin() const noexcept { return relationships_.output_edges.cbegin(); }
346 
347  /** Gets an iterator to the end of the output edges from this Node. */
348  EdgeConstIterator OutputEdgesEnd() const noexcept { return relationships_.output_edges.cend(); }
349 
350  /** Gets the Node's control inputs. */
351  const std::set<std::string>& ControlInputs() const noexcept { return relationships_.control_inputs; }
352 
353  /** Gets the number of input edges to this Node */
354  size_t GetInputEdgesCount() const noexcept { return relationships_.input_edges.size(); }
355 
356  /** Gets the number of output edges from this Node */
357  size_t GetOutputEdgesCount() const noexcept { return relationships_.output_edges.size(); }
358 
359  /** Adds an AttributeProto to this Node.
360  @remarks The attribute name is used as the key in the attribute map. */
361  void AddAttributeProto(ONNX_NAMESPACE::AttributeProto value);
362 
363  // keep this signature in sync with ADD_ATTR_SINGLE_INTERFACE below
364  /** Adds an attribute to this Node with the specified attribute name and value. */
365  void AddAttribute(std::string attr_name, int64_t value);
366 
367  // keep this signature in sync with ADD_ATTR_LIST_INTERFACE below
368  /** Adds an attribute to this Node with the specified attribute name and values. */
369  void AddAttribute(std::string attr_name, gsl::span<const int64_t> values);
370 
371 #define ADD_ATTR_SINGLE_INTERFACE(Type) \
372  void AddAttribute(std::string attr_name, Type value)
373 
374 #define ADD_ATTR_LIST_INTERFACE(Type) \
375  void AddAttribute(std::string attr_name, gsl::span<const Type> values)
376 
377 #define ADD_ATTR_INTERFACES(Type) \
378  ADD_ATTR_SINGLE_INTERFACE(Type); \
379  ADD_ATTR_LIST_INTERFACE(Type)
380 
381  ADD_ATTR_INTERFACES(float);
383  ADD_ATTR_INTERFACES(ONNX_NAMESPACE::TensorProto);
384 #if !defined(DISABLE_SPARSE_TENSORS)
385  ADD_ATTR_INTERFACES(ONNX_NAMESPACE::SparseTensorProto);
386 #endif
387  ADD_ATTR_INTERFACES(ONNX_NAMESPACE::TypeProto);
388 
389  ADD_ATTR_SINGLE_INTERFACE(ONNX_NAMESPACE::GraphProto);
390 
391 #undef ADD_ATTR_SINGLE_INTERFACE
392 #undef ADD_ATTR_LIST_INTERFACE
393 #undef ADD_ATTR_INTERFACES
394 
395  // The below overload is made so the compiler does not attempt to resolve
396  // string literals with the gsl::span overload
397  template <size_t N>
398  void AddAttribute(std::string attr_name, const char (&value)[N]) {
399  this->AddAttribute(std::move(attr_name), std::string(value, N - 1));
400  }
401 
402  /** Gets the Node's attributes. */
403  const NodeAttributes& GetAttributes() const noexcept { return attributes_; }
404 
405 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
406  /** Remove the specified attribute from this Node */
407  bool ClearAttribute(const std::string& attr_name);
408 #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
409 
410 #if !defined(ORT_MINIMAL_BUILD)
411  /** Gets the Node's mutable attributes. */
412  NodeAttributes& GetMutableAttributes() noexcept { return attributes_; }
413 
414  /** Gets the Graph instance that is instantiated from a GraphProto attribute during Graph::Resolve.
415  @param attr_name Attribute name for the GraphProto attribute.
416  @returns nullptr if the Graph instance has not been instantiated or attribute does not contain a GraphProto.
417  */
418  const Graph* GetGraphAttribute(const std::string& attr_name) const;
419 
420  /** Gets the mutable Graph instance that is instantiated from a GraphProto attribute during Graph::Resolve.
421  @param attr_name Attribute name for the GraphProto attribute.
422  @returns nullptr if the Graph instance has not been instantiated or attribute does not contain a GraphProto.
423  */
424  Graph* GetMutableGraphAttribute(const std::string& attr_name);
425 #endif // !defined(ORT_MINIMAL_BUILD)
426 
427  /** Checks if the Node contains at least one subgraph (this is the case for control flow operators, such as If, Scan, Loop).
428  @returns true if the Node contains a subgraph.
429  */
430  bool ContainsSubgraph() const {
431  return !attr_to_subgraph_map_.empty();
432  }
433 
434  /** Get the const subgraphs from a node.
435  @remarks Creates a new vector so calling ContainsSubgraphs first is preferred. */
436  std::vector<gsl::not_null<const Graph*>> GetSubgraphs() const;
437 
438  /** Gets a map of attribute name to the mutable Graph instances for all subgraphs of the Node.
439  @returns Map of the attribute name that defines the subgraph to the subgraph's Graph instance.
440  nullptr if the Node has no subgraphs.
441  */
442  const std::unordered_map<std::string, gsl::not_null<Graph*>>& GetAttributeNameToMutableSubgraphMap() {
443  return attr_to_subgraph_map_;
444  }
445 
446  /** Gets a map of attribute name to the const Graph instances for all subgraphs of the Node.
447  @returns Map of the attribute name that defines the subgraph to the subgraph's Graph instance.
448  nullptr if the Node has no subgraphs.
449  */
450  std::unordered_map<std::string, gsl::not_null<const Graph*>> GetAttributeNameToSubgraphMap() const;
451 
452  /** Gets the execution ProviderType that this node will be executed by. */
453  ProviderType GetExecutionProviderType() const noexcept { return execution_provider_type_; }
454 
455  /** Sets the execution ProviderType that this Node will be executed by. */
456  void SetExecutionProviderType(ProviderType execution_provider_type) {
457  execution_provider_type_ = execution_provider_type;
458  }
459 
460  /** Call the provided function for all explicit inputs, implicit inputs, and outputs of this Node.
461  If the NodeArg is an explicit or implicit input, is_input will be true when func is called.
462  @param include_missing_optional_defs Include NodeArgs that are optional and were not provided
463  i.e. NodeArg::Exists() == false.
464  */
465  void ForEachDef(std::function<void(const onnxruntime::NodeArg&, bool is_input)> func,
466  bool include_missing_optional_defs = false) const;
467 
468 #if !defined(ORT_MINIMAL_BUILD)
469  /** Replaces any matching definitions in the Node's explicit inputs or explicit outputs.
470  @param replacements Map of current NodeArg to replacement NodeArg.
471  */
472  void ReplaceDefs(const std::map<const onnxruntime::NodeArg*, onnxruntime::NodeArg*>& replacements);
473 
474  /** Gets the NodeProto representation of this Node.
475  @param update_subgraphs Update the GraphProto values for any subgraphs in the returned NodeProto.
476  If graph optimization has been run this is most likely required
477  to ensure the complete Graph is valid.
478  */
479  void ToProto(ONNX_NAMESPACE::NodeProto& proto, bool update_subgraphs = false) const;
480 
481  Status SaveToOrtFormat(flatbuffers::FlatBufferBuilder& builder,
483 
485  SaveEdgesToOrtFormat(flatbuffers::FlatBufferBuilder& builder) const;
486 
487  void SetFunctionTemplate(const FunctionTemplate& func_template);
488 #endif
489 
490  static Status LoadFromOrtFormat(const onnxruntime::fbs::Node& fbs_node, Graph& graph,
491  const OrtFormatLoadOptions& load_options,
492  const logging::Logger& logger, std::unique_ptr<Node>& node);
493 
494  Status LoadFromOrtFormat(const onnxruntime::fbs::Node& fbs_node,
495  const OrtFormatLoadOptions& load_options,
496  const logging::Logger& logger);
497  Status LoadEdgesFromOrtFormat(const onnxruntime::fbs::NodeEdge& fbs_node_edgs, const Graph& graph);
498 
499  /**
500  @class Definitions
501  The input and output definitions for this Node.
502  */
503  class Definitions {
504  public:
505  Definitions() = default;
506 
507  /** The Node's explicit input definitions. */
508  std::vector<NodeArg*> input_defs;
509 
510  /**
511  The number of inputs for each argument of the operator or function which this node refers.
512  @remarks For example, #input_defs has 10 elements (inputs), and #input_arg_count is {4, 6}.
513  This means that 4 elements (inputs) of input_defs map to the first argument of the operator or function, and
514  the other 6 map to the second argument.
515  */
516  std::vector<int> input_arg_count;
517 
518  /** The Node's output definitions. */
519  std::vector<NodeArg*> output_defs;
520 
521  /** The Node's implicit input definitions if the Node contains one or more subgraphs
522  (i.e. GraphProto attributes) and the subgraph/s implicitly consume these values.
523  @remarks For example, a subgraph in an 'If' node gets all its input values via this mechanism rather than
524  there being explicit inputs to the 'If' node that are passed to the subgraph.
525  They are pseudo-inputs to this Node as it has an implicit dependency on them. */
526  std::vector<NodeArg*> implicit_input_defs;
527 
529 
530  private:
531  };
532 
533  /**
534  @class Relationships
535  Defines the relationships between this Node and other Nodes in the Graph.
536  */
538  public:
539  Relationships() = default;
540 
541  void Clear() noexcept {
542  input_edges.clear();
543  output_edges.clear();
544  control_inputs.clear();
545  }
546 
547  /** The edges for Nodes that provide inputs to this Node. */
549 
550  /** The edges for Nodes that receive outputs from this Node. */
552 
553  /** The Node names of the control inputs to this Node. */
554  std::set<std::string> control_inputs;
555 
556  private:
557  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Relationships);
558  };
559 
560  // NOTE: This friendship relationship should ONLY be used for calling methods of the Node class and not accessing
561  // the data members directly, so that the Node can maintain its internal invariants.
562  friend class Graph;
563  Node(NodeIndex index, Graph& graph) : index_(index), graph_(&graph) {}
564 
565  private:
567 
568 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
569  void Init(const std::string& name,
570  const std::string& op_type,
571  const std::string& description,
572  const std::vector<NodeArg*>& input_args,
573  const std::vector<NodeArg*>& output_args,
574  const NodeAttributes* attributes,
575  const std::string& domain);
576 
577  // internal only method to allow selected classes to directly alter the input/output definitions and arg counts
578  Definitions& MutableDefinitions() noexcept;
579 
580  // internal only method to allow selected classes to directly alter the links between nodes.
581  Relationships& MutableRelationships() noexcept;
582 
583  void SetNodeType(Node::Type node_type) noexcept { node_type_ = node_type; }
584 #endif
585 
586  // create a Graph instance for an attribute that contains a GraphProto
587  void CreateSubgraph(const std::string& attr_name);
588 
589  const std::vector<std::unique_ptr<Graph>>& MutableSubgraphs() noexcept { return subgraphs_; }
590 
591  // validate and update the input arg count
592  common::Status UpdateInputArgCount();
593 
594  const Definitions& GetDefinitions() const noexcept { return definitions_; }
595  const Relationships& GetRelationships() const noexcept { return relationships_; }
596 
597  // Node index. Default to impossible value rather than 0.
599 
600  // Node name.
601  std::string name_;
602 
603  // Node operator type.
604  std::string op_type_;
605 
606  // OperatorSet domain of op_type_.
607  std::string domain_;
608 
609 #if !defined(ORT_MINIMAL_BUILD)
610  // OperatorSchema that <*this> node refers to.
611  const ONNX_NAMESPACE::OpSchema* op_ = nullptr;
612 
613  // Reference to the function template defined in the model.
614  const FunctionTemplate* func_template_ = nullptr;
615 #endif
616 
617  // Execution priority, lower value for higher priority
618  int priority_ = 0;
619 
620  // set from op_->SinceVersion() or via deserialization when OpSchema is not available
621  int since_version_ = -1;
622 
623  Node::Type node_type_ = Node::Type::Primitive;
624 
625  // The function body is owned by graph_
626  std::unique_ptr<Function> func_body_ = nullptr;
627 
628  // Node doc string.
629  std::string description_;
630 
631  // input/output defs and arg count
632  Definitions definitions_;
633 
634  // Relationships between this node and others in the graph
635  Relationships relationships_;
636 
637  // Device.
638  std::string execution_provider_type_;
639 
640  // Map from attribute name to attribute.
641  // This allows attribute adding and removing.
642  NodeAttributes attributes_;
643 
644  // Graph that contains this Node
645  Graph* graph_ = nullptr;
646 
647  // Map of attribute name to the Graph instance created from the GraphProto attribute
648  std::unordered_map<std::string, gsl::not_null<Graph*>> attr_to_subgraph_map_;
649 
650  // Graph instances for subgraphs that are owned by this Node
651  std::vector<std::unique_ptr<Graph>> subgraphs_;
652 };
653 
654 /**
655 @class Graph
656 The Graph representation containing the graph inputs and outputs, the Node instances,
657 and the edges connecting the nodes.
658 */
659 class Graph {
660  public:
661  /** Gets the Graph name. */
662  const std::string& Name() const noexcept;
663 
664  /** Gets the Graph description. */
665  const std::string& Description() const noexcept;
666 
667  /** Gets the path of the owning model, if any. */
668  const Path& ModelPath() const;
669 
670  /** Returns true if this is a subgraph or false if it is a high-level graph. */
671  bool IsSubgraph() const { return parent_graph_ != nullptr; }
672 
673  /** Returns the parent graph if this is a subgraph */
674  const Graph* ParentGraph() const { return parent_graph_; }
675 
676  /** Returns the mutable parent graph if this is a subgraph */
677  Graph* MutableParentGraph() { return parent_graph_; }
678 
679  /** Returns the strict_shape_type_inference that was passed into the constructor. */
680  bool StrictShapeTypeInference() const { return strict_shape_type_inference_; }
681 
682 #if !defined(ORT_MINIMAL_BUILD)
683  /** Sets the Graph name. */
684  void SetName(const std::string& name);
685 
686  /** Gets the Graph description. */
687  void SetDescription(const std::string& description);
688 
689  /** Replaces the initializer tensor with the same name as the given initializer tensor.
690  The replacement initializer tensor must have the same type and shape as the existing initializer tensor.
691 
692  Note: This currently has linear time complexity. There is room for improvement but it would likely require changes to
693  how initializer tensors are stored and tracked.
694  */
695  common::Status ReplaceInitializedTensor(ONNX_NAMESPACE::TensorProto new_initializer);
696 
697 #if !defined(DISABLE_EXTERNAL_INITIALIZERS)
698  /** This function takes externally provided data for initializers with external data
699  * and replaces graph initializers with its content.
700  */
702 #endif // !defined(DISABLE_EXTERNAL_INITIALIZERS)
703 
704 #endif // !defined(ORT_MINIMAL_BUILD)
705 
706 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
707  /** Add an initializer tensor to the Graph. */
708  void AddInitializedTensor(const ONNX_NAMESPACE::TensorProto& tensor_proto);
709 #endif
710 
711  /** Remove the initializer tensor with the provided name from the Graph. */
712  void RemoveInitializedTensor(const std::string& tensor_name);
713 
714  /** Check if a given name is an initializer tensor's name in this graph. */
715  bool IsInitializedTensor(const std::string& name) const;
716 
717 #if !defined(DISABLE_SPARSE_TENSORS)
718  /** Check if a given name is a sparse initializer's name in the model
719  * we currently convert sparse_initializer field in the model into dense Tensor instances.
720  * However, we sometimes want to check if this initializer was stored as sparse in the model.
721  */
722  bool IsSparseInitializer(const std::string& name) const;
723 #endif
724 
725  /** Gets an initializer tensor with the provided name.
726  @param[out] value Set to the TensorProto* if the initializer is found, or nullptr if not.
727  @returns True if found.
728  */
729  bool GetInitializedTensor(const std::string& tensor_name, const ONNX_NAMESPACE::TensorProto*& value) const;
730 
731  /** Gets all the initializer tensors in this Graph. */
732  const InitializedTensorSet& GetAllInitializedTensors() const noexcept { return name_to_initial_tensor_; }
733 
734  /** Removes all initializer tensors from this Graph and releases the memory they were using. */
735  void CleanAllInitializedTensors() noexcept;
736 
737  /** Returns true if an initializer value can be overridden by a graph input with the same name. */
738  bool CanOverrideInitializer() const noexcept { return ir_version_ >= 4; }
739 
740  /** returns the initializer's TensorProto if 'name' is an initializer, is constant and
741  cannot be overridden at runtime. If the initializer is not found or is not constant, a nullptr is returned.
742  @param check_outer_scope If true and the graph is a subgraph,
743  check ancestor graph/s for 'name' if not found in 'graph'.
744  @remarks check_outer_scope of true is not supported in a minimal build
745  */
746  const ONNX_NAMESPACE::TensorProto* GetConstantInitializer(const std::string& name, bool check_outer_scope) const;
747 
748  /** returns the initializer's TensorProto if 'name' is an initializer (both constant and overridable).
749  If the initializer is not found, a nullptr is returned.
750  @param check_outer_scope If true and the graph is a subgraph,
751  check ancestor graph/s for 'name' if not found in 'graph'.
752  @remarks check_outer_scope of true is not supported in a minimal build
753  */
754  const ONNX_NAMESPACE::TensorProto* GetInitializer(const std::string& name, bool check_outer_scope) const;
755 
756  /** Gets the Graph inputs excluding initializers.
757  These are the required inputs to the Graph as the initializers can be optionally overridden via graph inputs.
758  @remarks Contains no nullptr values. */
759  const std::vector<const NodeArg*>& GetInputs() const noexcept { return graph_inputs_excluding_initializers_; }
760 
761  /** Gets the Graph inputs including initializers.
762  This is the full set of inputs, in the same order as defined in the GraphProto.
763  @remarks Contains no nullptr values. */
764  const std::vector<const NodeArg*>& GetInputsIncludingInitializers() const noexcept {
765  return graph_inputs_including_initializers_;
766  }
767 
768  /** Return true if "node_arg" is a input or an initializer. Otherwise, returns false. */
769  bool IsInputsIncludingInitializers(const NodeArg* node_arg) const noexcept {
770  return std::find(graph_inputs_including_initializers_.begin(),
771  graph_inputs_including_initializers_.end(), node_arg) != graph_inputs_including_initializers_.end();
772  }
773 
774  /** Gets the Graph inputs that are initializers
775  These are overridable initializers. This is a difference between
776  graph_inputs_including_initializers_ and graph_inputs_excluding_initializers_
777  @remarks Contains no nullptr values. */
778  const std::vector<const NodeArg*>& GetOverridableInitializers() const {
779  return graph_overridable_initializers_;
780  }
781 
782  /** Gets the Graph outputs.
783  @remarks Contains no nullptr values.*/
784  const std::vector<const NodeArg*>& GetOutputs() const noexcept { return graph_outputs_; }
785 
786  bool IsOutput(const NodeArg* node_arg) const noexcept {
787  return std::find(graph_outputs_.begin(), graph_outputs_.end(), node_arg) != graph_outputs_.end();
788  }
789 
790  /** Returns true if one or more of the Node outputs are Graph outputs.
791  @remarks Cheaper than calling GetNodeOutputsInGraphOutputs.
792  */
793  bool NodeProducesGraphOutput(const Node& node) const {
794  auto end_outputs = graph_outputs_.cend();
795  for (auto output_def : node.OutputDefs()) {
796  if (std::find(graph_outputs_.cbegin(), end_outputs, output_def) != end_outputs) {
797  return true;
798  }
799  }
800  return false;
801  }
802 
803  /** Returns a vector with the indexes of the outputs of the given Node that are also Graph outputs. */
804  std::vector<int> GetNodeOutputsInGraphOutputs(const Node& node) const {
805  int output_idx = 0;
806  std::vector<int> indexes;
807  for (auto output_def : node.OutputDefs()) {
808  if (std::find(GetOutputs().cbegin(), GetOutputs().cend(), output_def) != GetOutputs().cend()) {
809  indexes.push_back(output_idx);
810  }
811 
812  ++output_idx;
813  }
814 
815  return indexes;
816  }
817 
818  /** Gets the NodeArgs that represent value_info instances in the Graph.
819  These are the values that are neither Graph inputs nor outputs.
820  @remarks Contains no nullptr values. */
821  const std::unordered_set<const NodeArg*>& GetValueInfo() const noexcept { return value_info_; }
822 
823 #if !defined(ORT_MINIMAL_BUILD)
824  void AddValueInfo(const NodeArg* new_value_info);
825 #endif
826 
827  /** Gets the Node with the specified node index.
828  @returns Node instance if found. nullptr if node_index is invalid or node has been freed.
829  */
830  const Node* GetNode(NodeIndex node_index) const { return NodeAtIndexImpl(node_index); }
831 
832  /** Gets the mutable Node with the specified node index.
833  @returns Mutable Node instance if found. nullptr if node_index is invalid or node has been freed.
834  */
835  Node* GetNode(NodeIndex node_index) { return NodeAtIndexImpl(node_index); }
836 
837  /** Get a GraphNodes instance that provides mutable access to all valid Nodes in the Graph. */
838  GraphNodes& Nodes() noexcept { return iterable_nodes_; }
839 
840  /** Get a GraphNodes instance that provides const access to all valid Nodes in the Graph. */
841  const GraphNodes& Nodes() const noexcept { return iterable_nodes_; }
842 
843  /** Get a ConstGraphNodes instance that provides access to a filtered set of valid Nodes in the Graph.
844  @remarks We can't use GraphNodes as that would provide mutable access to the nodes by default, and we can't prevent
845  that by returning a const instance of GraphNodes as we're creating a new instance here due to the filter
846  being something we don't control (i.e. we have to return a new instance so it can't be const).
847  */
849  return ConstGraphNodes(nodes_, std::move(filter_func));
850  }
851 
852  /** Gets the maximum NodeIndex value used in the Graph.
853  WARNING: This actually returns the max index value used + 1.
854  */
855  int MaxNodeIndex() const noexcept { return static_cast<int>(nodes_.size()); } // assume the casting won't overflow
856 
857  /** Gets the number of valid Nodes in the Graph.
858  @remarks This may be smaller than MaxNodeIndex(), as Nodes may be removed during optimization.
859  */
860  int NumberOfNodes() const noexcept { return num_of_nodes_; }
861 
862  /** Gets the mutable NodeArg with the provided name.
863  @returns Pointer to NodeArg if found, nullptr if not. */
865  auto iter = node_args_.find(name);
866  if (iter != node_args_.end()) {
867  return iter->second.get();
868  }
869  return nullptr;
870  }
871 
872  /** Gets the const NodeArg with the provided name.
873  @returns Pointer to const NodeArg if found, nullptr if not. */
874  const NodeArg* GetNodeArg(const std::string& name) const {
875  return const_cast<Graph*>(this)->GetNodeArg(name);
876  }
877 
878  // search this and up through any parent_graph_ instance for a NodeArg
879  NodeArg* GetNodeArgIncludingParentGraphs(const std::string& node_arg_name);
880 
881  /** Gets a mutable NodeArg by name. Creates a new NodeArg that is owned by this Graph if not found.
882  @param name The NodeArg name.
883  @param[in] p_arg_type Optional TypeProto to use if the NodeArg needs to be created.
884  @returns NodeArg reference.
885  */
886  NodeArg& GetOrCreateNodeArg(const std::string& name, const ONNX_NAMESPACE::TypeProto* p_arg_type) {
887  auto iter = node_args_.find(name);
888  if (iter != node_args_.end()) {
889  return *(iter->second);
890  }
891  auto result = node_args_.insert(std::make_pair(name, std::make_unique<NodeArg>(name, p_arg_type)));
892  return *(result.first->second);
893  }
894 
895 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
896  /** Generate a unique name in this Graph for a NodeArg */
897  std::string GenerateNodeArgName(const std::string& base_name);
898 
899  /** Generate a unique name in this Graph for a Node */
900  std::string GenerateNodeName(const std::string& base_name);
901 #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
902 
903 #if !defined(ORT_MINIMAL_BUILD)
904  /** Copy a Node and add it to this Graph.
905  @param other Node to copy
906  @returns Reference to the Node that was created and added to this Graph.
907  @remarks Do not call AddNode and Remove Node concurrently as they are not thread-safe.
908  */
909  Node& AddNode(const Node& other);
910 #endif
911 
912 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
913  /** Add a Node to this Graph.
914  @param name The Node name. Must be unique in this Graph.
915  @param op_type The operator type. e.g. ONNX operator name.
916  @param description Arbitrary description of the Node.
917  @param input_args The explicit inputs to this Node.
918  @param output_args The outputs from this Node.
919  @param attributes Optional NodeAttributes to add.
920  @param domain The domain for the op_type.
921  @returns Reference to the new Node.
922  @remarks Do not call AddNode and Remove Node concurrently as they are not thread-safe.
923  */
924  Node& AddNode(const std::string& name,
925  const std::string& op_type,
926  const std::string& description,
927  gsl::span<NodeArg* const> input_args,
928  gsl::span<NodeArg* const> output_args,
929  const NodeAttributes* attributes = nullptr,
930  const std::string& domain = kOnnxDomain);
931 
933  const std::string& op_type,
934  const std::string& description,
935  std::initializer_list<NodeArg*> input_args,
936  std::initializer_list<NodeArg*> output_args,
937  const NodeAttributes* attributes = nullptr,
938  const std::string& domain = kOnnxDomain) {
939  return AddNode(name, op_type, description,
940  AsSpan(input_args),
941  AsSpan(output_args),
942  attributes, domain);
943  }
944 
946  const std::string& op_type,
947  const std::string& description,
948  gsl::span<NodeArg* const> input_args,
949  std::initializer_list<NodeArg*> output_args,
950  const NodeAttributes* attributes = nullptr,
951  const std::string& domain = kOnnxDomain) {
952  return AddNode(name, op_type, description,
953  input_args,
954  AsSpan(output_args),
955  attributes, domain);
956  }
957 
959  const std::string& op_type,
960  const std::string& description,
961  std::initializer_list<NodeArg*> input_args,
962  gsl::span<NodeArg* const> output_args,
963  const NodeAttributes* attributes = nullptr,
964  const std::string& domain = kOnnxDomain) {
965  return AddNode(name, op_type, description,
966  AsSpan(input_args),
967  output_args,
968  attributes, domain);
969  }
970 
971  /** Remove a Node from this Graph and free it.
972  The output edges of this specified node MUST have been removed before removing the node.
973  The input edges of this specified node is removed while removing the node. The process of
974  removing a node from a graph should be,
975  1. Remove out edges of this specified node.
976  2. Remove this specified node.
977  3. Add new input edges connected with all out nodes.
978  @returns true if the node_index was valid
979  @remarks Do not call AddNode and Remove Node concurrently as they are not thread-safe.
980  */
981  bool RemoveNode(NodeIndex node_index);
982 
983  /** Add an edge between two Nodes.
984  @param src_node_index NodeIndex of source Node that is providing output to the destination Node.
985  @param dst_node_index NodeIndex of destination Node that is receiving input from the source Node.
986  @param src_arg_index node arg index of source node.
987  @param dst_arg_index node arg index of destination node.
988  */
989  void AddEdge(NodeIndex src_node_index, NodeIndex dst_node_index, int src_arg_index, int dst_arg_index);
990 
991  /** Remove an edge between two Nodes.
992  @param src_node_index NodeIndex of source Node to remove an output edge from.
993  @param dst_node_index NodeIndex of destination Node to remove an input edge from.
994  @param src_arg_index node arg index of source node.
995  @param dst_arg_index node arg index of destination node.
996  */
997  void RemoveEdge(NodeIndex src_node_index, NodeIndex dst_node_index, int src_arg_index, int dst_arg_index);
998 #endif
999 
1000 #if !defined(ORT_MINIMAL_BUILD)
1001  /**
1002  Add a control edge between two Nodes in this Graph.
1003  The source Node does not produce output that is directly consumed by the destination Node, however the
1004  destination Node must execute after the source node. The control edge allows this ordering to occur.
1005  */
1006  bool AddControlEdge(NodeIndex src_node_index, NodeIndex dst_node_index);
1007 #endif // !defined(ORT_MINIMAL_BUILD)
1008 
1009  /** Mark the Graph as needing Resolve() to be called.
1010  This should be done after modifying any aspect of the Graph that changes the Nodes or relationships between them. */
1012  graph_resolve_needed_ = true;
1013  return *this;
1014  }
1015 
1016  /** Gets flag indicating whether Graph::Resolve needs to be called before using the Graph. */
1017  bool GraphResolveNeeded() const noexcept {
1018  return graph_resolve_needed_;
1019  }
1020 
1021  /** Sets flag that Graph::graph_proto_ needs to be updated to reflect changes in the Graph. */
1023  graph_proto_sync_needed_ = true;
1024  return *this;
1025  }
1026 
1027  /** Gets flag indicating whether Graph::graph_proto_ needs to be synchronized with this Graph instance. */
1028  bool GraphProtoSyncNeeded() const noexcept {
1029  return graph_proto_sync_needed_;
1030  }
1031 
1032  /** Performs a reverse depth-first search (DFS) traversal from a set of nodes, via their inputs,
1033  up to their source node/s.
1034  @param from NodeIndex values for a set of Nodes to traverse from.
1035  @param enter Visit function that will be invoked on a node when it is visited but its parents haven't been.
1036  @param leave Visit function invoked on the node after its parents have all been visited.
1037  @param comp Comparison function to stabilize the traversal order by making Node ordering deterministic.
1038  */
1039  void ReverseDFSFrom(gsl::span<NodeIndex const> from,
1040  const std::function<void(const Node*)>& enter,
1041  const std::function<void(const Node*)>& leave,
1042  const std::function<bool(const Node*, const Node*)>& comp = {}) const;
1043 
1044  /** Performs a reverse depth-first search (DFS) traversal from a set of nodes, via their inputs,
1045  up to their source node/s.
1046  @param from Set of Nodes to traverse from.
1047  @param enter Visit function that will be invoked on a node when it is visited but its parents haven't been.
1048  @param leave Visit function invoked on the node after its parents have all been visited.
1049  @param comp Comparison function to stabilize the traversal order by making Node ordering deterministic.
1050  */
1051  void ReverseDFSFrom(gsl::span<const Node* const> from,
1052  const std::function<void(const Node*)>& enter,
1053  const std::function<void(const Node*)>& leave,
1054  const std::function<bool(const Node*, const Node*)>& comp = {}) const;
1055 
1056  /** Performs a reverse depth-first search (DFS) traversal from a set of nodes, via their inputs,
1057  up to their source node/s.
1058  @param from Set of Nodes to traverse from.
1059  @param enter Visit function that will be invoked on a node when it is visited but its parents haven't been.
1060  @param leave Visit function invoked on the node after its parents have all been visited.
1061  @param stop Stop traversal from node n to input node p if stop(n, p) is true.
1062  @param comp Comparison function to stabilize the traversal order by making Node ordering deterministic.
1063  */
1064  void ReverseDFSFrom(gsl::span<const Node* const> from,
1065  const std::function<void(const Node*)>& enter,
1066  const std::function<void(const Node*)>& leave,
1067  const std::function<bool(const Node*, const Node*)>& comp,
1068  const std::function<bool(const Node*, const Node*)>& stop) const;
1069 
1070 #if !defined(ORT_MINIMAL_BUILD)
1071  /** Performs topological sort with Kahn's algorithm on the graph/s.
1072  @param enter Visit function that will be invoked on a node when it is visited.
1073  @param comp Comparison function to stabilize the traversal order by making Node ordering deterministic.
1074  */
1075  void KahnsTopologicalSort(const std::function<void(const Node*)>& enter,
1076  const std::function<bool(const Node*, const Node*)>& comp) const;
1077 
1078 #endif
1079 
1080  /** Gets the map of operator domains to their opset versions. */
1081  const std::unordered_map<std::string, int>& DomainToVersionMap() const noexcept {
1082  return domain_to_version_;
1083  }
1084 
1085 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1086  /**
1087  Create a single Node that will be the result of the a fusion of multiple nodes in this Graph.
1088  @param sub_graph A IndexSubGraph instance with details of the nodes to fuse.
1089  @param fused_node_name The name for the new Node.
1090  @returns Node with fused subgraph.
1091  @remarks As a new Graph instance for the fused nodes is not created, a GraphViewer can be constructed with the
1092  IndexedSubGraph information to provide a view of the subgraph. The original nodes are left in place
1093  while this is in use.
1094  Call FinalizeFuseSubGraph to remove them once the fused replacement node is fully created.
1095  */
1096  Node& BeginFuseSubGraph(const IndexedSubGraph& sub_graph, const std::string& fused_node_name);
1097 
1098  void FinalizeFuseSubGraph(const IndexedSubGraph& sub_graph, Node& fused_node);
1099 #endif
1100 
1101 #if !defined(ORT_MINIMAL_BUILD)
1102  /** Gets the GraphProto representation of this Graph. */
1103  const ONNX_NAMESPACE::GraphProto& ToGraphProto();
1104  ONNX_NAMESPACE::GraphProto ToGraphProto() const;
1105 
1106  /** Gets the GraphProto representation of this Graph
1107  @params external_file_name name of the binary file to use for initializers
1108  @param initializer_size_threshold initializers larger or equal to this threshold (in bytes) are saved
1109  in the external file. Initializer smaller than this threshold are included in the onnx file.
1110  @returns GraphProto serialization of the graph.
1111  */
1112  ONNX_NAMESPACE::GraphProto ToGraphProtoWithExternalInitializers(const std::string& external_file_name,
1113  size_t initializer_size_threshold) const;
1114 
1115  /** Gets the ISchemaRegistry instances being used with this Graph. */
1117 
1118  /**
1119  Looks up the op schema in the schema registry and sets it for the given node.
1120  @param node The node to update.
1121  @return Whether the node's op schema was set to a valid value.
1122  */
1124 
1125  /**
1126  Create a single Function based Node that is the result of the a fusion of multiple nodes in this Graph.
1127  A new Graph instance will be created for the fused nodes.
1128  @param sub_graph A IndexSubGraph instance with details of the nodes to fuse. Ownership is transferred to the new Node
1129  @param fused_node_name The name for the new Node.
1130  @returns Function based Node with fused subgraph. The Node body will contain a Function instance.
1131  */
1132  Node& FuseSubGraph(const IndexedSubGraph& sub_graph, const std::string& fused_node_name);
1133 
1134  /**
1135  Directly insert the nodes in the function Node provided into this Graph.
1136  @param node Node with Node::Type of Node::Type::Fused
1137  @returns Status indicating success or providing an error message.
1138  */
1139  Status InlineFunction(Node& node);
1140 
1141  /** Mark a NodeArg name as coming from the outer scope when programmatically constructing a Graph that will
1142  be used as a GraphProto attribute in another Node..
1143  e.g. when creating a Graph instance that will be used as a subgraph in a control flow operator, it is necessary to
1144  define placeholder NodeArgs for outer scope values. This prevents these values from becoming explicit graph inputs
1145  when the Graph is resolved.
1146  */
1148  ORT_IGNORE_RETURN_VALUE(outer_scope_node_arg_names_.insert(name));
1149  }
1150 
1151  /** Explicitly set graph inputs.
1152  @param inputs NodeArgs that represent complete graph inputs which need to be explicitly ordered.
1153  @remarks Note that the input order matters for subgraphs.
1154  */
1155  void SetInputs(gsl::span<const NodeArg* const> inputs);
1156 
1157  void SetInputs(std::initializer_list<const NodeArg*> inputs) {
1158  SetInputs(AsSpan(inputs));
1159  }
1160 
1161  const Model& GetModel() const {
1162  return owning_model_;
1163  }
1164 
1165  const logging::Logger& GetLogger() const {
1166  return logger_;
1167  }
1168 
1169  /** Explicitly set graph outputs.
1170  @param outputs NodeArgs that represent complete graph outputs which need to be explicitly ordered.
1171  @remarks Note that the output order matters for subgraphs.
1172  */
1173  void SetOutputs(gsl::span<const NodeArg* const> outputs);
1174 
1175  void SetOutputs(std::initializer_list<const NodeArg*> outputs) {
1176  SetOutputs(AsSpan(outputs));
1177  }
1178 
1179 #endif // !defined(ORT_MINIMAL_BUILD)
1180 
1181 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1182  /** Sets the type of a NodeArg, replacing existing type/shape if any */
1183  void SetNodeArgType(NodeArg& arg, const ONNX_NAMESPACE::TypeProto& type_proto);
1184 
1185  const Node* GetProducerNode(const std::string& node_arg_name) const {
1186  return GetProducerNodeImpl(*this, node_arg_name);
1187  }
1188 
1189  Node* GetMutableProducerNode(const std::string& node_arg_name) {
1190  return GetProducerNodeImpl(*this, node_arg_name);
1191  }
1192 
1193  void UpdateProducerNode(const std::string& node_arg_name, NodeIndex node_index) {
1194  auto iter = node_arg_to_producer_node_.find(node_arg_name);
1195 
1196  if (iter != node_arg_to_producer_node_.end()) {
1197  iter->second = node_index;
1198  } else {
1199  node_arg_to_producer_node_[node_arg_name] = node_index;
1200  }
1201  }
1202 
1203  std::vector<const Node*> GetConsumerNodes(const std::string& node_arg_name) const {
1204  return GetConsumerNodesImpl(*this, node_arg_name);
1205  }
1206 
1207  // Without removing the existing consumers, add a consumer to the give node arg name.
1208  void AddConsumerNode(const std::string& node_arg_name, Node* consumer) {
1209  node_arg_to_consumer_nodes_[node_arg_name].insert(consumer->Index());
1210  }
1211 
1212  // Remove a consumer from the set
1213  void RemoveConsumerNode(const std::string& node_arg_name, Node* consumer) {
1214  node_arg_to_consumer_nodes_[node_arg_name].erase(consumer->Index());
1215  }
1216 #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1217 
1218 #if !defined(ORT_MINIMAL_BUILD)
1219  std::vector<Node*> GetMutableConsumerNodes(const std::string& node_arg_name) {
1220  return GetConsumerNodesImpl(*this, node_arg_name);
1221  }
1222 
1223  void UpdateConsumerNodes(const std::string& node_arg_name, gsl::span<Node* const> nodes) {
1224  // Replace nodes for the arg
1225  auto& nodes_for_arg = node_arg_to_consumer_nodes_[node_arg_name];
1226  if (!nodes_for_arg.empty()) {
1227  nodes_for_arg.clear();
1228  }
1229 
1230  nodes_for_arg.reserve(nodes.size());
1231  for (Node* node : nodes) {
1232  nodes_for_arg.insert(node->Index());
1233  }
1234  }
1235 
1236  void UpdateConsumerNodes(const std::string& node_arg_name, std::initializer_list<Node*> nodes) {
1237  UpdateConsumerNodes(node_arg_name, AsSpan(nodes));
1238  }
1239 
1240  /** During constant folding it may become possible to infer the shape for a node.
1241  To avoid running a full Resolve allow an individual node to have the shape inferencing re-run.
1242  */
1244 
1245  // Options to control Graph::Resolve.
1247  // Whether to override existing types with inferred types.
1248  bool override_types = false;
1249  // Names of initializers to keep even if unused (optional).
1250  const std::unordered_set<std::string>* initializer_names_to_preserve = nullptr;
1251  // Whether to set that no proto sync is required after resolving.
1252  // Useful for resolving right after loading from a GraphProto.
1254  };
1255 
1256  /**
1257  Resolve this Graph to ensure it is completely valid, fully initialized, and able to be executed.
1258  1. Run through all validation rules.
1259  a. Node name and node output's names should be unique.
1260  b. Attribute match between node and op definition.
1261  c. Input/Output match between node and op definition.
1262  d. Graph is acyclic and sort nodes in topological order.
1263  2. Check & Setup inner nodes' dependency.
1264  3. Cleanup function definition lists.
1265  Note: the weights for training can't be cleaned during resolve.
1266  @returns common::Status with success or error information.
1267  */
1268  common::Status Resolve(const ResolveOptions& options);
1269 
1271  ResolveOptions default_options;
1272  return Resolve(default_options);
1273  }
1274 
1275  const std::unordered_set<std::string>& GetOuterScopeNodeArgNames() const noexcept {
1276  return outer_scope_node_arg_names_;
1277  }
1278 
1279  common::Status SaveToOrtFormat(flatbuffers::FlatBufferBuilder& builder,
1281 
1282 #endif // !defined(ORT_MINIMAL_BUILD)
1283 
1284  /** Returns the Node containing the GraphProto for this Graph instance if IsSubgraph is true */
1285  const Node* ParentNode() const { return parent_node_; }
1286 
1287  /** Returns true if the name is for a value that is coming from outer scope */
1288  bool IsOuterScopeValue(const std::string& name) const {
1289  if (!parent_node_) return false;
1290  const auto& implicit_input_defs = parent_node_->ImplicitInputDefs();
1291  return std::any_of(implicit_input_defs.cbegin(), implicit_input_defs.cend(),
1292  [&name](const NodeArg* implicit_input) {
1293  return implicit_input->Name() == name;
1294  });
1295  }
1296 
1297 #if !defined(ORT_MINIMAL_BUILD)
1298  /** Construct a Graph instance for a subgraph that is created from a GraphProto attribute in a Node.
1299  Inherits some properties from the parent graph.
1300  @param parent_graph The Graph containing the Node that has the GraphProto attribute.
1301  @param parent_node The Node that has the GraphProto attribute.
1302  @param subgraph_proto The GraphProto from the Node attribute.
1303  */
1304  Graph(Graph& parent_graph, const Node& parent_node, ONNX_NAMESPACE::GraphProto& subgraph_proto);
1305 
1306  Graph(const Model& owning_model,
1307  IOnnxRuntimeOpSchemaCollectionPtr schema_registry,
1308  ONNX_NAMESPACE::GraphProto& subgraph_proto,
1309  const std::unordered_map<std::string, int>& domain_version_map,
1310  const logging::Logger& logger,
1311  bool strict_shape_type_inference);
1312 #endif
1313 
1314  virtual ~Graph();
1315 
1316  static Status LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph, const Model& owning_model,
1317  const std::unordered_map<std::string, int>& domain_to_version,
1318 #if !defined(ORT_MINIMAL_BUILD)
1319  IOnnxRuntimeOpSchemaCollectionPtr schema_registry,
1320 #endif
1321  const OrtFormatLoadOptions& load_options,
1322  const logging::Logger& logger, std::unique_ptr<Graph>& graph);
1323 
1324  // deserialize a subgraph
1325  static Status LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph,
1326  Graph& parent_graph, const Node& parent_node,
1327  const OrtFormatLoadOptions& load_options,
1328  const logging::Logger& logger, std::unique_ptr<Graph>& graph);
1329 
1330 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1331  const RuntimeOptimizationRecordContainer& RuntimeOptimizations() const {
1332  return runtime_optimizations_;
1333  }
1334 
1335  RuntimeOptimizationRecordContainer& MutableRuntimeOptimizations() {
1336  return runtime_optimizations_;
1337  }
1338 #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1339 
1340  // This friendship relationship should only be used to call Graph::Graph and
1341  // Graph::LoadGraph All other access should be via the public API.
1342  friend class Model;
1343 
1344  Graph() = delete;
1345 
1346  // Create empty Graph instance to re-create from ORT format serialized data.
1347  Graph(const Model& owning_model,
1348  const std::unordered_map<std::string, int>& domain_to_version,
1349 #if !defined(ORT_MINIMAL_BUILD)
1350  IOnnxRuntimeOpSchemaCollectionPtr schema_registry,
1351 #endif
1352  Graph* parent_graph, const Node* parent_node,
1353  const logging::Logger& logger,
1354  bool strict_shape_type_inference);
1355 
1356  // Populate Graph instance from ORT format serialized data.
1357  Status LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph,
1358  const OrtFormatLoadOptions& load_options);
1359 
1360 #if !defined(ORT_MINIMAL_BUILD)
1361  // Constructor: Given a <GraphProto> loaded from model file, construct
1362  // a <Graph> object. Used by Model to create a Graph instance.
1363  Graph(const Model& owning_model,
1364  ONNX_NAMESPACE::GraphProto* graph_proto,
1365  const std::unordered_map<std::string, int>& domain_to_version,
1366  Version ir_version,
1367  IOnnxRuntimeOpSchemaCollectionPtr schema_registry,
1368  const logging::Logger& logger,
1369  bool strict_shape_type_inference);
1370 
1371  // internal use by the Graph class only
1372  Graph(const Model& owning_model,
1373  ONNX_NAMESPACE::GraphProto* graph_proto,
1374  const std::unordered_map<std::string, int>& domain_to_version,
1375  Version ir_version,
1376  IOnnxRuntimeOpSchemaCollectionPtr schema_registry,
1377  Graph* parent_graph,
1378  const Node* parent_node,
1379  const logging::Logger& logger,
1380  bool strict_shape_type_inference);
1381 
1383 
1384  private:
1385  void InitializeStateFromModelFileGraphProto();
1386 
1387  // Add node with specified <node_proto>.
1388  Node& AddNode(const ONNX_NAMESPACE::NodeProto& node_proto,
1389  const ArgNameToTypeMap& name_to_type);
1390 
1391 #endif
1392 
1393  Version IrVersion() const noexcept {
1394  return ir_version_;
1395  }
1396 
1397  Graph& GraphResolveNeeded(bool needed) noexcept {
1398  graph_resolve_needed_ = needed;
1399  return *this;
1400  }
1401 
1402  Graph& GraphProtoSyncNeeded(bool needed) noexcept {
1403  graph_proto_sync_needed_ = needed;
1404  return *this;
1405  }
1406 
1407  // During the Resolve of a Graph it is necessary to recursively descend into subgraphs (created from GraphProto
1408  // Node attributes in the Graph) if present.
1409  // The ResolveContext holds the collection of values for the current Graph instance, be it the main graph
1410  // or a subgraph, so that the various operations that are part of the Resolve can work iteratively or
1411  // recursively as needed.
1412  struct ResolveContext {
1413  ResolveContext(const Graph& owning_graph) : graph{owning_graph} {
1414  }
1415 
1416  std::unordered_map<std::string_view, std::pair<Node*, int>> output_args;
1417  std::unordered_set<std::string_view> inputs_and_initializers;
1418  std::unordered_map<std::string_view, NodeIndex> node_name_to_index;
1419  std::unordered_set<Node*> nodes_with_subgraphs;
1420 
1421  // check if the provided name is an input/initialize/node output of this Graph instance during Graph::Resolve.
1422  // Graph::node_args_ can have stale entries so we can't rely on that.
1423  bool IsLocalValue(const std::string& name) const;
1424 
1425  // check if an ancestor graph has a valid value with the provided name during Graph::Resolve.
1426  // Once Graph::Resolve completes Graph::IsOuterScopeValue can be used and is more efficient.
1427  bool IsOuterScopeValue(const std::string& name) const;
1428 
1429  void Clear() {
1430  output_args.clear();
1431  inputs_and_initializers.clear();
1432  node_name_to_index.clear();
1433  nodes_with_subgraphs.clear();
1434  }
1435 
1436  private:
1437  bool IsInputInitializerOrOutput(const std::string& name, bool check_ancestors) const;
1438 
1439  const Graph& graph;
1440  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ResolveContext);
1441  };
1442 
1443  // Initialize all the graph inputs, initializers and outputs
1444  common::Status InitInputsInitializersOutputs();
1445 
1446  // Initialize overridable initializers container
1447  void ComputeOverridableInitializers();
1448 
1449 #if !defined(ORT_MINIMAL_BUILD)
1450  // Build and verify node connection (edges).
1451  // Verify NodeArg name/type/shape matching correctly.
1452  common::Status BuildConnections(std::unordered_set<std::string>& outer_scope_node_args_consumed);
1453 
1454  common::Status VerifyNoDuplicateName();
1455 
1456  // Check whether <*this> graph is acyclic while performing a topological sort.
1457  // Depth-first going from bottom up through the graph and checking whether there are any back edges.
1458  // NodesInTopologicalOrder is updated with the nodes' indexes in topological
1459  // order if <Status> returned is "OK", otherwise it's undefined.
1460  common::Status PerformTopologicalSortAndCheckIsAcyclic();
1461 
1462  common::Status PerformTypeAndShapeInferencing(const ResolveOptions& options);
1463 
1464  // Recursively find all subgraphs including nested subgraphs
1465  void FindAllSubgraphs(std::vector<Graph*>& subgraphs);
1466 
1467  // Iterate this Graph instance and all subgraphs, calling the provided function for each.
1468  common::Status ForThisAndAllSubgraphs(const std::vector<Graph*>& subgraphs, std::function<Status(Graph&)> func);
1469 
1470  common::Status InferAndVerifyTypeMatch(Node& node, const ONNX_NAMESPACE::OpSchema& op, const ResolveOptions& options);
1471 
1472  // perform type and shape inferencing on the subgraph and Resolve to validate
1473  static common::Status InferAndVerifySubgraphTypes(const Node& node, Graph& subgraph,
1474  const std::vector<const ONNX_NAMESPACE::TypeProto*>& input_types,
1475  std::vector<const ONNX_NAMESPACE::TypeProto*>& output_types,
1476  const Graph::ResolveOptions& options);
1477 
1478  // Apply type-inference and type-checking to all inputs and initializers:
1479  common::Status TypeCheckInputsAndInitializers();
1480 
1481  // Compute set of input and initializer names and checking for duplicate names
1482  common::Status VerifyInputAndInitializerNames();
1483 
1484  // Infer and set type information across <*this> graph if needed, and verify type/attribute
1485  // information matches between node and op.
1486 
1487  common::Status VerifyNodeAndOpMatch(const ResolveOptions& options);
1488 
1489  // Set graph inputs/outputs when resolving a graph..
1490  common::Status SetGraphInputsOutputs();
1491 
1492  // recursively accumulate and set the outer scope node args in the resolve context for all subgraphs
1493  // so they can be used to resolve outer scope dependencies when running BuildConnections for the subgraphs.
1494  common::Status SetOuterScopeNodeArgs(const std::unordered_set<std::string>& outer_scope_node_args);
1495 
1496  // Implementation for initializer replacement
1497  Status ReplaceInitializedTensorImpl(ONNX_NAMESPACE::TensorProto new_initializer, bool is_external);
1498 
1499  // Clear all unused initializers and NodeArgs
1500  void CleanUnusedInitializersAndNodeArgs(const std::unordered_set<std::string>* initializer_names_to_preserve = nullptr);
1501 
1502  std::vector<NodeArg*> CreateNodeArgs(const google::protobuf::RepeatedPtrField<std::string>& names,
1503  const ArgNameToTypeMap& name_to_type_map);
1504 
1505  void ToGraphProtoInternal(ONNX_NAMESPACE::GraphProto& graph_proto) const;
1506 
1507 #endif // !defined(ORT_MINIMAL_BUILD)
1508 
1509 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1510  Status PopulateNodeArgToProducerConsumerLookupsFromNodes();
1511 
// Shared implementation behind the const and non-const consumer lookups
// (presumably GetConsumerNodes / GetMutableConsumerNodes, called with *this).
// The element type of the result is whatever `instance.GetNode(...)` returns, so a
// const `instance` yields pointers-to-const and a non-const one yields mutable pointers.
// Returns an empty vector when no consumers are recorded for `node_arg_name`.
// The order of the result follows iteration over the stored unordered_set and is
// therefore unspecified.
template <typename TInstance>
static auto GetConsumerNodesImpl(
    TInstance& instance, const std::string& node_arg_name) -> std::vector<decltype(instance.GetNode(0))> {
  using NodePtr = decltype(instance.GetNode(0));
  std::vector<NodePtr> consumers;

  const auto lookup = instance.node_arg_to_consumer_nodes_.find(node_arg_name);
  if (lookup == instance.node_arg_to_consumer_nodes_.end()) {
    return consumers;  // nothing recorded for this NodeArg
  }

  const auto& consumer_indices = lookup->second;
  consumers.reserve(consumer_indices.size());
  for (const auto consumer_index : consumer_indices) {
    consumers.push_back(instance.GetNode(consumer_index));
  }
  return consumers;
}
1525 
// Shared implementation behind the const and non-const producer lookups
// (presumably GetProducerNode / GetMutableProducerNode, called with *this).
// Returns the node recorded as the producer of `node_arg_name`, or nullptr when no
// producer is recorded. The pointer's constness follows the GetNode overload
// selected by the constness of `instance`.
template <typename TInstance>
static auto GetProducerNodeImpl(
    TInstance& instance, const std::string& node_arg_name) -> decltype(instance.GetNode(0)) {
  const auto& producer_lookup = instance.node_arg_to_producer_node_;
  const auto match = producer_lookup.find(node_arg_name);
  return match == producer_lookup.end() ? nullptr : instance.GetNode(match->second);
}
1536 
1537  gsl::not_null<Node*> AllocateNode();
1538 
1539  // Release the node.
1540  // @returns false if node_index was invalid.
1541  bool ReleaseNode(NodeIndex node_index);
1542 
1543  Node& CreateFusedSubGraphNode(const IndexedSubGraph& sub_graph, const std::string& fused_node_name);
1544 #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1545 
1546  Node* NodeAtIndexImpl(NodeIndex node_index) const {
1547  // if we are trying to access a node that doesn't exist there's (most
1548  // likely) either a logic issue or a graph consistency/correctness issue.
1549  // use ORT_ENFORCE to prove that or uncover scenarios where we actually
1550  // expect attempts to retrieve a non-existent node.
1551  ORT_ENFORCE(node_index < nodes_.size(), "Validating no unexpected access using an invalid node_index. Got:",
1552  node_index, " Max:", nodes_.size());
1553  return nodes_[node_index].get();
1554  }
1555 
1556  const Model& owning_model_;
1557 
1558  // GraphProto to store name, version, initializer.
1559  // When serializing <*this> Graph to a GraphProto, the nodes and
1560  // functions in <Graph> will also be fed into <graph_proto_> so that
1561  // it's consistent with <*this> graph.
1562  // This pointer is owned by parent model.
1563  ONNX_NAMESPACE::GraphProto* graph_proto_;
1564 
1565  // GraphProto that provides storage for the ONNX proto types deserialized from a flexbuffer/flatbuffer
1566  ONNX_NAMESPACE::GraphProto deserialized_proto_data_;
1567 
1568  InitializedTensorSet name_to_initial_tensor_;
1569 
1570  std::unordered_set<std::reference_wrapper<const std::string>,
1571  std::hash<std::string>, std::equal_to<std::string>>
1572  sparse_tensor_names_;
1573 
1574 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1575  // Runtime optimization storage.
1576  // Note: runtime_optimizations_ == *runtime_optimizations_ptr_, so runtime_optimizations_ptr_ must be initialized (non-null) before runtime_optimizations_ is bound to it.
1577  std::unique_ptr<RuntimeOptimizationRecordContainer> runtime_optimizations_ptr_;
1578  RuntimeOptimizationRecordContainer& runtime_optimizations_;
1579 #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1580 
1581 #if !defined(ORT_MINIMAL_BUILD)
1582  IOnnxRuntimeOpSchemaCollectionPtr schema_registry_;
1583 
1584  // Currently, to make the ORT in-memory graph work, we have to create a temporary op schema
1585  // for the fused kernel. This is not ideal, but as a short-term solution we host
1586  // those schemas here.
1587  InlinedVector<std::unique_ptr<ONNX_NAMESPACE::OpSchema>> fused_schemas_containers_;
1588  // In some cases a fused sub-graph occurs multiple times in one model, so we use a map
1589  // to store reusable schemas for lookup.
1590  InlinedHashMap<std::string, std::reference_wrapper<ONNX_NAMESPACE::OpSchema>> reusable_fused_schema_map_;
1591 #endif // !defined(ORT_MINIMAL_BUILD)
1592 
1593  // Graph nodes.
1594  // Element in <nodes_> may be nullptr due to graph optimization.
1595  std::vector<std::unique_ptr<Node>> nodes_;
1596 
1597  // Wrapper of Graph nodes to provide iteration services that hide nullptr entries
1598  GraphNodes iterable_nodes_{nodes_};
1599 
1600  // Number of nodes.
1601  // Normally this is smaller than the size of <nodes_>, as some
1602  // elements in <nodes_> may be removed when doing graph optimization,
1603  // or some elements may be merged, etc.
1604  int num_of_nodes_ = 0;
1605 
1606  // A flag indicating whether <*this> graph needs to be resolved.
1607  bool graph_resolve_needed_ = false;
1608 
1609  bool graph_proto_sync_needed_ = false;
1610 
1611  // The topological order of node index used to do node and op match verification temporarily.
1612  std::vector<NodeIndex> nodes_in_topological_order_;
1613 
1614  // Full list of graph inputs. Matches number and order of inputs in the GraphProto.
1615  std::vector<const NodeArg*> graph_inputs_including_initializers_;
1616  bool graph_inputs_manually_set_ = false;
1617 
1618  // Graph inputs excluding initializers.
1619  std::vector<const NodeArg*> graph_inputs_excluding_initializers_;
1620 
1621  // Overridable Initializers. The difference between graph_inputs_including_initializers_
1622  // and graph_inputs_excluding_initializers_
1623  std::vector<const NodeArg*> graph_overridable_initializers_;
1624 
1625  // Graph outputs.
1626  std::vector<const NodeArg*> graph_outputs_;
1627  bool graph_outputs_manually_set_ = false;
1628 
1629  // Graph value_info.
1630  std::unordered_set<const NodeArg*> value_info_;
1631 
1632  // All node args owned by <*this> graph. Key is node arg name.
1633  std::unordered_map<std::string, std::unique_ptr<NodeArg>> node_args_;
1634 
1635 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1636  int name_generator_ = 0;
1637 
1638  // Strings which have been used as node names.
1639  // New node name should not conflict with this set.
1640  std::unordered_set<std::string> generated_node_names_;
1641 
1642  // Strings which have been used as node_arg names.
1643  // New node_arg name should not conflict with this set.
1644  std::unordered_set<std::string> generated_node_arg_names_;
1645 
1646  // node arg to its producer node
1647  std::unordered_map<std::string, NodeIndex> node_arg_to_producer_node_;
1648 
1649  // node arg to its consumer nodes
1650  std::unordered_map<std::string, std::unordered_set<NodeIndex>> node_arg_to_consumer_nodes_;
1651 #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1652 
1653  const std::unordered_map<std::string, int> domain_to_version_;
1654 
1655  // Model IR version.
1656  Version ir_version_{ONNX_NAMESPACE::Version::IR_VERSION};
1657 
1658  ResolveContext resolve_context_{*this};
1659 
1660  // the parent graph if this is a subgraph.
1661  Graph* parent_graph_;
1662  // the node containing the graph if parent_graph_ is not nullptr
1663  const Node* parent_node_;
1664 
1665  // NodeArgs that come from outer scope. Used when building a graph so that
1666  // these don't get recorded as graph inputs in the GraphProto.
1667  std::unordered_set<std::string> outer_scope_node_arg_names_;
1668 
1669  // number of times Resolve has run.
1670  int num_resolves_ = 0;
1671 
1672  const logging::Logger& logger_;
1673 
1674  // If true, all inconsistencies encountered during shape and type inference
1675  // will be exposed to the caller as failures. If false, in some cases
1676  // warnings will be logged but processing will continue and no error will
1677  // be returned.
1678  const bool strict_shape_type_inference_;
1679 
1680  // distinguishes between graph loaded from model file and graph created from scratch
1681  const bool is_loaded_from_model_file_;
1682 };
1683 
1684 #if !defined(ORT_MINIMAL_BUILD)
1685 // Print NodeArg as
1686 // name : type
1687 // For example,
1688 // "110": tensor(float)
1689 std::ostream& operator<<(std::ostream& out, const NodeArg& node_arg);
1690 // Print Node as,
1691 // (operator's name, operator's type, domain, version) : (input0, input1, ...) -> (output0, output1, ...)
1692 // For example,
1693 // ("Add_14", Add, "", 7) : ("110": tensor(float),"109": tensor(float),) -> ("111": tensor(float),)
1694 std::ostream& operator<<(std::ostream& out, const Node& node);
1695 // Print Graph as, for example,
1696 // Inputs:
1697 // "Input": tensor(float)
1698 // Nodes:
1699 // ("add0", Add, "", 7) : ("Input": tensor(float),"Bias": tensor(float),) -> ("add0_out": tensor(float),)
1700 // ("matmul", MatMul, "", 9) : ("add0_out": tensor(float),"matmul_weight": tensor(float),) -> ("matmul_out": tensor(float),)
1701 // ("add1", Add, "", 7) : ("matmul_out": tensor(float),"add_weight": tensor(float),) -> ("add1_out": tensor(float),)
1702 // ("reshape", Reshape, "", 5) : ("add1_out": tensor(float),"concat_out": tensor(int64),) -> ("Result": tensor(float),)
1703 // Outputs:
1704 // "Result": tensor(float)
1705 // Inputs' and outputs' format is described in the documentation of NodeArg's operator<< above.
1706 // Node format is described in Node's operator<< above.
1707 std::ostream& operator<<(std::ostream& out, const Graph& graph);
1708 #endif
1709 
1710 } // namespace onnxruntime
constexpr auto AsSpan(C &c)
Definition: span_utils.h:41
void SetNodeArgType(NodeArg &arg, const ONNX_NAMESPACE::TypeProto &type_proto)
bool IsOuterScopeValue(const std::string &name) const
Definition: graph.h:1288
bool IsInitializedTensor(const std::string &name) const
void UpdateProducerNode(const std::string &node_arg_name, NodeIndex node_index)
Definition: graph.h:1193
std::unordered_map< std::string, const ONNX_NAMESPACE::TensorProto * > InitializedTensorSet
Definition: basic_types.h:33
The node refers to a primitive operator.
const std::string & ProviderType
Definition: basic_types.h:35
const Node * GetNode(NodeIndex node_index) const
Definition: graph.h:830
void AddAttributeProto(ONNX_NAMESPACE::AttributeProto value)
const InitializedTensorSet & GetAllInitializedTensors() const noexcept
Definition: graph.h:732
void ForEachDef(std::function< void(const onnxruntime::NodeArg &, bool is_input)> func, bool include_missing_optional_defs=false) const
IOnnxRuntimeOpSchemaCollectionPtr GetSchemaRegistry() const
ConstPointerContainer< std::vector< NodeArg * > > InputDefs() const noexcept
Definition: graph.h:224
const std::vector< int > & InputArgCount() const noexcept
Definition: graph.h:220
void SetInputs(gsl::span< const NodeArg *const > inputs)
const ONNX_NAMESPACE::GraphProto & ToGraphProto()
const Function * GetFunctionBody() const noexcept
Definition: graph.h:197
std::shared_ptr< IOnnxRuntimeOpSchemaCollection > IOnnxRuntimeOpSchemaCollectionPtr
Definition: basic_types.h:44
Node(NodeIndex index, Graph &graph)
Definition: graph.h:563
int MaxNodeIndex() const noexcept
Definition: graph.h:855
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Graph)
The node refers to a function.
const std::string & Description() const noexcept
Definition: graph.h:169
bool IsSubgraph() const
Definition: graph.h:671
void SetFunctionTemplate(const FunctionTemplate &func_template)
NodeIndex Index() const noexcept
Definition: graph.h:144
void RemoveEdge(NodeIndex src_node_index, NodeIndex dst_node_index, int src_arg_index, int dst_arg_index)
int GetDstArgIndex() const
Definition: graph.h:135
Definition: Node.h:52
const Node * operator->() const
void UpdateConsumerNodes(const std::string &node_arg_name, std::initializer_list< Node * > nodes)
Definition: graph.h:1236
Node & FuseSubGraph(const IndexedSubGraph &sub_graph, const std::string &fused_node_name)
const RuntimeOptimizationRecordContainer & RuntimeOptimizations() const
Definition: graph.h:1331
bool NodeProducesGraphOutput(const Node &node) const
Definition: graph.h:793
GLsizei const GLchar *const * string
Definition: glcorearb.h:814
const Graph * GetGraphAttribute(const std::string &attr_name) const
const NodeAttributes & GetAttributes() const noexcept
Definition: graph.h:403
const Node * ParentNode() const
Definition: graph.h:1285
size_t GetOutputEdgesCount() const noexcept
Definition: graph.h:357
common::Status InjectExternalInitializedTensors(const InlinedHashMap< std::string, OrtValue > &external_initializers)
RuntimeOptimizationRecordContainer & MutableRuntimeOptimizations()
Definition: graph.h:1335
static common::Status ForEachMutableWithIndex(std::vector< NodeArg * > &node_args, std::function< common::Status(NodeArg &arg, size_t index)> func)
Definition: graph.h:250
bool SetOpSchemaFromRegistryForNode(Node &node)
Status SaveToOrtFormat(flatbuffers::FlatBufferBuilder &builder, flatbuffers::Offset< onnxruntime::fbs::Node > &fbs_node) const
#define ORT_ENFORCE(condition,...)
Definition: common.h:173
bool IsSparseInitializer(const std::string &name) const
const std::unordered_map< std::string, int > & DomainToVersionMap() const noexcept
Definition: graph.h:1081
void FinalizeFuseSubGraph(const IndexedSubGraph &sub_graph, Node &fused_node)
void ToProto(ONNX_NAMESPACE::NodeProto &proto, bool update_subgraphs=false) const
**But if you need a result
Definition: thread.h:613
NodeArg * GetNodeArg(const std::string &name)
Definition: graph.h:864
std::vector< NodeArg * > & MutableImplicitInputDefs() noexcept
Definition: graph.h:262
auto arg(const Char *name, const T &arg) -> detail::named_arg< Char, T >
Definition: core.h:1736
const NodeArg * GetNodeArg(const std::string &name) const
Definition: graph.h:874
NodeConstIterator OutputNodesEnd() const noexcept
Definition: graph.h:334
static Status LoadFromOrtFormat(const onnxruntime::fbs::Graph &fbs_graph, const Model &owning_model, const std::unordered_map< std::string, int > &domain_to_version, IOnnxRuntimeOpSchemaCollectionPtr schema_registry, const OrtFormatLoadOptions &load_options, const logging::Logger &logger, std::unique_ptr< Graph > &graph)
std::vector< Node * > GetMutableConsumerNodes(const std::string &node_arg_name)
Definition: graph.h:1219
std::unordered_map< std::string, gsl::not_null< const Graph * > > GetAttributeNameToSubgraphMap() const
ConstPointerContainer< std::vector< NodeArg * > > OutputDefs() const noexcept
Definition: graph.h:237
bool IsInputsIncludingInitializers(const NodeArg *node_arg) const noexcept
Definition: graph.h:769
const ONNX_NAMESPACE::OpSchema * Op() const noexcept
Definition: graph.h:189
const std::string & OpType() const noexcept
Definition: graph.h:150
friend class Model
Definition: graph.h:1342
void SetOutputs(gsl::span< const NodeArg *const > outputs)
void SetExecutionProviderType(ProviderType execution_provider_type)
Definition: graph.h:456
void AddConsumerNode(const std::string &node_arg_name, Node *consumer)
Definition: graph.h:1208
NodeConstIterator InputNodesBegin() const noexcept
Definition: graph.h:324
basic_string_view< char > string_view
Definition: core.h:522
ConstPointerContainer< std::vector< NodeArg * > > ImplicitInputDefs() const noexcept
Definition: graph.h:231
Node * GetNode(NodeIndex node_index)
Definition: graph.h:835
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Definitions)
const std::unordered_set< const NodeArg * > & GetValueInfo() const noexcept
Definition: graph.h:821
const std::string & Description() const noexcept
GraphNodes & Nodes() noexcept
Definition: graph.h:838
Graph & SetGraphResolveNeeded() noexcept
Definition: graph.h:1011
Status LoadEdgesFromOrtFormat(const onnxruntime::fbs::NodeEdge &fbs_node_edgs, const Graph &graph)
std::vector< gsl::not_null< const Graph * > > GetSubgraphs() const
NodeConstIterator OutputNodesBegin() const noexcept
Definition: graph.h:329
std::set< EdgeEnd, EdgeEndCompare > EdgeSet
Definition: graph.h:297
NodeAttributes & GetMutableAttributes() noexcept
Definition: graph.h:412
bool RemoveNode(NodeIndex node_index)
Node & AddNode(const std::string &name, const std::string &op_type, const std::string &description, std::initializer_list< NodeArg * > input_args, std::initializer_list< NodeArg * > output_args, const NodeAttributes *attributes=nullptr, const std::string &domain=kOnnxDomain)
Definition: graph.h:932
ADD_ATTR_SINGLE_INTERFACE(ONNX_NAMESPACE::GraphProto)
bool TryGetFunctionProto(ONNX_NAMESPACE::FunctionProto &func_proto) const
int Priority() const noexcept
Definition: graph.h:162
int NumberOfNodes() const noexcept
Definition: graph.h:860
void UpdateConsumerNodes(const std::string &node_arg_name, gsl::span< Node *const > nodes)
Definition: graph.h:1223
static Status OK()
Definition: status.h:163
common::Status Resolve()
Definition: graph.h:1270
void AddAttribute(std::string attr_name, int64_t value)
std::string GenerateNodeName(const std::string &base_name)
void AddEdge(NodeIndex src_node_index, NodeIndex dst_node_index, int src_arg_index, int dst_arg_index)
std::vector< int > & MutableInputArgsCount()
Definition: graph.h:271
#define ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TypeName)
Definition: common.h:220
bool ClearAttribute(const std::string &attr_name)
std::vector< int > GetNodeOutputsInGraphOutputs(const Node &node) const
Definition: graph.h:804
bool AddControlEdge(NodeIndex src_node_index, NodeIndex dst_node_index)
static common::Status ForEachWithIndex(const ConstPointerContainer< std::vector< NodeArg * >> &node_args, std::function< common::Status(const NodeArg &arg, size_t index)> func)
Definition: graph.h:208
const ONNX_NAMESPACE::TensorProto * GetInitializer(const std::string &name, bool check_outer_scope) const
Graph * MutableParentGraph()
Definition: graph.h:677
void SetSinceVersion(int since_version) noexcept
Definition: graph.h:184
const std::string & Name() const noexcept
std::unordered_map< std::string, ONNX_NAMESPACE::AttributeProto > NodeAttributes
Definition: basic_types.h:42
const std::unordered_map< std::string, gsl::not_null< Graph * > > & GetAttributeNameToMutableSubgraphMap()
Definition: graph.h:442
bool GetInitializedTensor(const std::string &tensor_name, const ONNX_NAMESPACE::TensorProto *&value) const
const Path & ModelPath() const noexcept
void RemoveConsumerNode(const std::string &node_arg_name, Node *consumer)
Definition: graph.h:1213
size_t GetInputEdgesCount() const noexcept
Definition: graph.h:354
constexpr const char * kOnnxDomain
Definition: constants.h:12
GLuint const GLchar * name
Definition: glcorearb.h:786
std::set< std::string > control_inputs
Definition: graph.h:554
Status InlineFunction(Node &node)
flatbuffers::Offset< onnxruntime::fbs::NodeEdge > SaveEdgesToOrtFormat(flatbuffers::FlatBufferBuilder &builder) const
bool GraphProtoSyncNeeded() const noexcept
Definition: graph.h:1028
Graph & SetGraphProtoSyncNeeded() noexcept
Definition: graph.h:1022
const std::vector< const NodeArg * > & GetOutputs() const noexcept
Definition: graph.h:784
void RemoveInitializedTensor(const std::string &tensor_name)
std::vector< int > input_arg_count
Definition: graph.h:516
bool IsOutput(const NodeArg *node_arg) const noexcept
Definition: graph.h:786
common::Status SaveToOrtFormat(flatbuffers::FlatBufferBuilder &builder, flatbuffers::Offset< onnxruntime::fbs::Graph > &fbs_graph) const
NodeArg * GetNodeArgIncludingParentGraphs(const std::string &node_arg_name)
const ONNX_NAMESPACE::TensorProto * GetConstantInitializer(const std::string &name, bool check_outer_scope) const
void KahnsTopologicalSort(const std::function< void(const Node *)> &enter, const std::function< bool(const Node *, const Node *)> &comp) const
const Model & GetModel() const
Definition: graph.h:1161
void AddOuterScopeNodeArg(const std::string &name)
Definition: graph.h:1147
const std::unordered_set< std::string > * initializer_names_to_preserve
Definition: graph.h:1250
common::Status ReplaceInitializedTensor(ONNX_NAMESPACE::TensorProto new_initializer)
const Node * GetProducerNode(const std::string &node_arg_name) const
Definition: graph.h:1185
void SetOutputs(std::initializer_list< const NodeArg * > outputs)
Definition: graph.h:1175
Status UpdateShapeInference(Node &node)
bool GraphResolveNeeded() const noexcept
Definition: graph.h:1017
const std::string & Name() const noexcept
Definition: graph.h:147
GLenum func
Definition: glcorearb.h:783
std::string GenerateNodeArgName(const std::string &base_name)
std::vector< NodeArg * > implicit_input_defs
Definition: graph.h:526
Graph * GetMutableGraphAttribute(const std::string &attr_name)
ProviderType GetExecutionProviderType() const noexcept
Definition: graph.h:453
GLenum GLsizei GLsizei GLint * values
Definition: glcorearb.h:1602
void SetPriority(int priority) noexcept
void SetInputs(std::initializer_list< const NodeArg * > inputs)
Definition: graph.h:1157
bool operator==(const NodeConstIterator &p_other) const
Node * GetMutableProducerNode(const std::string &node_arg_name)
Definition: graph.h:1189
const Path & ModelPath() const
Node & AddNode(const Node &other)
const std::string & Domain() const noexcept
Definition: graph.h:155
NodeArg & GetOrCreateNodeArg(const std::string &name, const ONNX_NAMESPACE::TypeProto *p_arg_type)
Definition: graph.h:886
void SetName(const std::string &name)
GLuint index
Definition: glcorearb.h:786
const Node & GetNode() const noexcept
Definition: graph.h:127
bool StrictShapeTypeInference() const
Definition: graph.h:680
#define ORT_RETURN_IF_ERROR(expr)
Definition: common.h:234
std::vector< const Node * > GetConsumerNodes(const std::string &node_arg_name) const
Definition: graph.h:1203
ImageBuf OIIO_API max(Image_or_Const A, Image_or_Const B, ROI roi={}, int nthreads=0)
const std::unordered_set< std::string > & GetOuterScopeNodeArgNames() const noexcept
Definition: graph.h:1275
GA_API const UT_StringHolder N
bool operator()(const EdgeEnd &lhs, const EdgeEnd &rhs) const
Definition: graph.h:286
ADD_ATTR_INTERFACES(float)
NodeConstIterator InputNodesEnd() const noexcept
Definition: graph.h:326
const std::set< std::string > & ControlInputs() const noexcept
Definition: graph.h:351
const logging::Logger & GetLogger() const
Definition: graph.h:1165
std::vector< NodeArg * > & MutableOutputDefs() noexcept
Definition: graph.h:279
EdgeEnd(const Node &node, int src_arg_index, int dst_arg_index) noexcept
int64_t Version
Definition: basic_types.h:31
static Status LoadFromOrtFormat(const onnxruntime::fbs::Node &fbs_node, Graph &graph, const OrtFormatLoadOptions &load_options, const logging::Logger &logger, std::unique_ptr< Node > &node)
const std::vector< const NodeArg * > & GetOverridableInitializers() const
Definition: graph.h:778
EdgeConstIterator InputEdgesBegin() const noexcept
Definition: graph.h:338
const std::vector< const NodeArg * > & GetInputs() const noexcept
Definition: graph.h:759
NodeConstIterator(EdgeConstIterator p_iter)
int SinceVersion() const noexcept
Definition: graph.h:178
Definition: core.h:1131
#define ORT_IGNORE_RETURN_VALUE(fn)
Definition: common.h:79
Node & BeginFuseSubGraph(const IndexedSubGraph &sub_graph, const std::string &fused_node_name)
std::vector< NodeArg * > output_defs
Definition: graph.h:519
#define const
Definition: zconf.h:214
std::vector< NodeArg * > input_defs
Definition: graph.h:508
std::ostream & operator<<(std::ostream &out, AllocKind alloc_kind)
void AddInitializedTensor(const ONNX_NAMESPACE::TensorProto &tensor_proto)
Node & AddNode(const std::string &name, const std::string &op_type, const std::string &description, gsl::span< NodeArg *const > input_args, std::initializer_list< NodeArg * > output_args, const NodeAttributes *attributes=nullptr, const std::string &domain=kOnnxDomain)
Definition: graph.h:945
const GraphNodes & Nodes() const noexcept
Definition: graph.h:841
friend class Graph
Definition: graph.h:562
void ReplaceDefs(const std::map< const onnxruntime::NodeArg *, onnxruntime::NodeArg * > &replacements)
EdgeConstIterator InputEdgesEnd() const noexcept
Definition: graph.h:341
EdgeConstIterator OutputEdgesBegin() const noexcept
Definition: graph.h:345
std::unordered_map< std::string, ONNX_NAMESPACE::TypeProto > ArgNameToTypeMap
Definition: basic_types.h:34
void CleanAllInitializedTensors() noexcept
void SetDescription(const std::string &description)
bool ContainsSubgraph() const
Definition: graph.h:430
ConstGraphNodes FilteredNodes(GraphNodes::NodeFilterFunc &&filter_func) const noexcept
Definition: graph.h:848
std::vector< NodeArg * > & MutableInputDefs() noexcept
Definition: graph.h:274
bool CanOverrideInitializer() const noexcept
Definition: graph.h:738
ONNX_NAMESPACE::GraphProto ToGraphProtoWithExternalInitializers(const std::string &external_file_name, size_t initializer_size_threshold) const
Node::Type NodeType() const noexcept
Definition: graph.h:172
Node(std::string_view name, std::string_view op_type, std::string_view description, gsl::span< NodeArg *const > input_args, gsl::span< NodeArg *const > output_args, const NodeAttributes *attributes, std::string_view domain)
Definition: graph.h:88
const std::vector< const NodeArg * > & GetInputsIncludingInitializers() const noexcept
Definition: graph.h:764
void AddValueInfo(const NodeArg *new_value_info)
bool operator!=(const NodeConstIterator &p_other) const
EdgeConstIterator OutputEdgesEnd() const noexcept
Definition: graph.h:348
EdgeSet::const_iterator EdgeConstIterator
Definition: graph.h:298
size_t NodeIndex
Definition: basic_types.h:30
Node & AddNode(const std::string &name, const std::string &op_type, const std::string &description, std::initializer_list< NodeArg * > input_args, gsl::span< NodeArg *const > output_args, const NodeAttributes *attributes=nullptr, const std::string &domain=kOnnxDomain)
Definition: graph.h:958
int GetSrcArgIndex() const
Definition: graph.h:131
const Graph * ParentGraph() const
Definition: graph.h:674
FMT_CONSTEXPR auto find(Ptr first, Ptr last, T value, Ptr &out) -> bool
Definition: core.h:2089
void ReverseDFSFrom(gsl::span< NodeIndex const > from, const std::function< void(const Node *)> &enter, const std::function< void(const Node *)> &leave, const std::function< bool(const Node *, const Node *)> &comp={}) const
void AddAttribute(std::string attr_name, const char(&value)[N])
Definition: graph.h:398
bool CanBeInlined() const
bool Exists() const noexcept