HDK
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
graph.h
Go to the documentation of this file.
1 // Copyright (c) Microsoft Corporation. All rights reserved.
2 // Licensed under the MIT License.
3 
4 #pragma once
5 
6 #include <functional>
7 #include <limits>
8 #include <memory>
9 #include <string>
10 #include <type_traits>
11 #include <unordered_map>
12 #include <unordered_set>
13 #include <filesystem>
14 
15 #include "core/common/flatbuffers.h"
16 
17 #include <gsl/gsl>
18 
19 #include "core/common/common.h"
20 #include "core/common/path_string.h"
22 #if !defined(ORT_MINIMAL_BUILD)
24 #endif
26 #include "core/common/span_utils.h"
27 #include "core/common/status.h"
29 #include "core/graph/onnx_protobuf.h"
30 #include "core/graph/basic_types.h"
31 #include "core/graph/constants.h"
32 #include "core/graph/function.h"
33 #if !defined(ORT_MINIMAL_BUILD)
34 #include "core/graph/function_template.h"
35 #endif
36 #include "core/graph/graph_nodes.h"
37 #include "core/graph/node_arg.h"
38 #include "core/graph/ort_format_load_options.h"
39 
40 namespace onnxruntime {
41 class Graph;
42 struct IndexedSubGraph;
43 class Model;
44 class OpSignature;
45 
46 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
47 class RuntimeOptimizationRecordContainer;
48 #endif
49 
50 namespace fbs {
51 struct Graph;
52 struct Node;
53 struct NodeEdge;
54 } // namespace fbs
55 
56 /**
57 @class Node
58 Class representing a node in the graph.
59 */
60 class Node {
61  public:
62  /** Node types */
63  enum class Type {
64  Primitive = 0, ///< The node refers to a primitive operator.
65  Fused = 1, ///< The node refers to a function.
66  };
67 
68  explicit Node() = default;
69 
70 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
72  std::string_view op_type,
73  std::string_view description,
74  gsl::span<NodeArg* const> input_args,
75  gsl::span<NodeArg* const> output_args,
76  const NodeAttributes* attributes,
77  std::string_view domain) {
78  Init(name, op_type, description,
79  input_args,
80  output_args,
81  attributes, domain);
82  }
83 #endif
84 
85  ~Node() = default;
86 
87  /**
88  @class EdgeEnd
89  Class representing the end of an edge. It could be an input or output edge end of a node.
90  For the node's input edge end, it's the source end, as the destination end is the node itself.
91  For the node's output edge end, it's the destination end, as the source end is the node itself.
92  */
93  class EdgeEnd {
94  public:
95  /**
96  Construct an EdgeEnd
97  @param node The source node if this is an input edge to the current node,
98  or the destination node if this is an output edge from the current node.
99  @param src_arg_index The node arg index of source node of the edge.
100  @param dst_arg_index The node arg index of destination node of the edge.
101  */
102  EdgeEnd(const Node& node, int src_arg_index, int dst_arg_index) noexcept;
103 
104  /** Construct a control edge.
105  @param node The node the edge joins to the current node.
106  */
107  explicit EdgeEnd(const Node& node) noexcept;
108 
109  /** Gets the Node that this EdgeEnd refers to. */
110  const Node& GetNode() const noexcept { return *node_; }
111 
112  /** Gets the source arg index.
113  @returns the source arg index of <*this> edge.*/
114  int GetSrcArgIndex() const { return src_arg_index_; }
115 
116  /** Gets the destination arg index.
117  @returns the destination arg index of <*this> edge.*/
118  int GetDstArgIndex() const { return dst_arg_index_; }
119 
120  private:
121  const Node* node_;
122  const int src_arg_index_;
123  const int dst_arg_index_;
124  };
125 
126  /** Gets the Node's NodeIndex. */
127  NodeIndex Index() const noexcept { return index_; }
128 
129  /** Gets the Node's name. */
130  const std::string& Name() const noexcept { return name_; }
131 
132  /** Gets the Node's operator type. */
133  const std::string& OpType() const noexcept { return op_type_; }
134 
135  /** Gets the domain of the OperatorSet that specifies the operator returned by #OpType.
136  * @remarks If this is an ONNX operator the value will be kOnnxDomain not kOnnxDomainAlias
137  */
138  const std::string& Domain() const noexcept { return domain_; }
139 
140  /** Gets the path of the owning model if any. */
141  const std::filesystem::path& ModelPath() const noexcept;
142 
143  /** Gets the Node's execution priority.
144  @remarks Lower value means higher priority */
145  int Priority() const noexcept { return priority_; };
146 
147  /** Sets the execution priority of a node.
148  @remarks Lower value means higher priority */
149  void SetPriority(int priority) noexcept;
150 
151  /** Gets the node description. */
152  const std::string& Description() const noexcept { return description_; }
153 
154  /** Gets the Node's Node::Type. */
155  Node::Type NodeType() const noexcept { return node_type_; }
156 
157  /** Gets the opset version that the Node's operator was first defined in.
158  @returns Opset version. If -1 the Node's operator has not been set.
159  @remarks Prefer over Op()->SinceVersion() as Op() is disabled in a minimal build
160  */
161  int SinceVersion() const noexcept { return since_version_; }
162 
163  /** Sets the since version (opset version that the Node's operator was first defined in.) for this node.
164  @remarks Used during layout transformation for setting since version for layout transformed nodes with
165  domain kMSNHWC.
166  */
167  void SetSinceVersion(int since_version) noexcept { since_version_ = since_version; }
168 
169 #if !defined(ORT_MINIMAL_BUILD)
170  /** Gets the Node's OpSchema.
171  @remarks The graph containing this node must be resolved, otherwise nullptr will be returned. */
172  const ONNX_NAMESPACE::OpSchema* Op() const noexcept { return op_; }
173 
174  /** Create a copy of the called op's FunctionProto if it has one. Returns true if successful. */
175  bool TryGetFunctionProto(ONNX_NAMESPACE::FunctionProto& func_proto) const;
176 
177  bool CanBeInlined() const;
178 
179  /** Gets the function body if applicable otherwise nullptr. */
180  const Function* GetFunctionBody() const noexcept { return func_body_.get(); }
181 #endif
182 
183  /**
184  Helper to iterate through the container returned by #InputDefs() or #OutputDefs() and call the provided function.
185  @param node_args Collection of NodeArgs returned by #InputDefs() or #OutputDefs()
186  @param func Function to call for each valid NodeArg in the node_args. The function is called with the NodeArg
187  and the index number in the container.
188  @returns common::Status with success or error information.
189  @remarks Returns immediately on error.
190  */
191  static common::Status ForEachWithIndex(const ConstPointerContainer<std::vector<NodeArg*>>& node_args,
192  std::function<common::Status(const NodeArg& arg, size_t index)> func) {
193  for (size_t index = 0; index < node_args.size(); ++index) {
194  auto arg = node_args[index];
195  if (!arg->Exists())
196  continue;
197  ORT_RETURN_IF_ERROR(func(*arg, index));
198  }
199  return common::Status::OK();
200  }
201 
202  /** Gets the count of arguments for each of the Node's explicit inputs. */
203  const std::vector<int>& InputArgCount() const noexcept { return definitions_.input_arg_count; }
204 
205  /** Gets the Node's input definitions.
206  @remarks requires ConstPointerContainer wrapper to apply const to the NodeArg pointers so access is read-only. */
209  }
210 
211  /** Gets the implicit inputs to this Node.
212  If this Node contains a subgraph, these are the NodeArg's that are implicitly consumed by Nodes within that
213  subgraph. e.g. If and Loop operators.*/
216  }
217 
218  /** Gets the Node's output definitions.
219  @remarks requires ConstPointerContainer wrapper to apply const to the NodeArg pointers so access is read-only. */
222  }
223 
224 #if !defined(ORT_MINIMAL_BUILD)
225  /**
226  Helper to iterate through the container returned by #MutableInputDefs() or #MutableOutputDefs() and call the provided function.
227  @param node_args Collection of NodeArgs returned by #MutableInputDefs() or #MutableOutputDefs()
228  @param func Function to call for each valid NodeArg in the node_args. The function is called with the NodeArg
229  and the index number in the container.
230  @returns common::Status with success or error information.
231  @remarks Returns immediately on error.
232  */
233  static common::Status ForEachMutableWithIndex(std::vector<NodeArg*>& node_args,
234  std::function<common::Status(NodeArg& arg, size_t index)> func) {
235  for (size_t index = 0; index < node_args.size(); ++index) {
236  auto arg = node_args[index];
237  if (!arg->Exists())
238  continue;
239  ORT_RETURN_IF_ERROR(func(*arg, index));
240  }
241  return common::Status::OK();
242  }
243 
244  /** Gets a modifiable collection of the Node's implicit input definitions. */
245  std::vector<NodeArg*>& MutableImplicitInputDefs() noexcept {
246  return definitions_.implicit_input_defs;
247  }
248 #endif // !defined(ORT_MINIMAL_BUILD)
249 
250 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
251  /** Gets a modifiable count of arguments for each of the Node's explicit inputs.
252  @todo This should be removed in favor of a method that updates the input args and the count.
253  Currently these operations are separate which is not a good setup. */
254  std::vector<int>& MutableInputArgsCount() { return definitions_.input_arg_count; }
255 
256  /** Gets a modifiable collection of the Node's input definitions. */
257  std::vector<NodeArg*>& MutableInputDefs() noexcept {
258  return definitions_.input_defs;
259  }
260 
261  /** Gets a modifiable collection of the Node's output definitions. */
262  std::vector<NodeArg*>& MutableOutputDefs() noexcept {
263  return definitions_.output_defs;
264  }
265 #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
266 
267  /** Struct to provide sorting between EdgeEnd instances based on NodeIndex first, then the source arg index, and then the destination arg index. */
268  struct EdgeEndCompare {
269  bool operator()(const EdgeEnd& lhs, const EdgeEnd& rhs) const {
270  if (lhs.GetNode().Index() == rhs.GetNode().Index()) {
271  if (lhs.GetSrcArgIndex() == rhs.GetSrcArgIndex()) {
272  return lhs.GetDstArgIndex() < rhs.GetDstArgIndex();
273  }
274  return lhs.GetSrcArgIndex() < rhs.GetSrcArgIndex();
275  }
276  return lhs.GetNode().Index() < rhs.GetNode().Index();
277  }
278  };
279 
280  using EdgeSet = std::set<EdgeEnd, EdgeEndCompare>;
281  using EdgeConstIterator = EdgeSet::const_iterator;
282 
283  /**
284  @class NodeConstIterator
285  Class to provide const access to Node instances iterated via an EdgeConstIterator. */
287  public:
289 
290  bool operator==(const NodeConstIterator& p_other) const;
291 
292  bool operator!=(const NodeConstIterator& p_other) const;
293 
294  void operator++();
295  void operator--();
296 
297  const Node& operator*() const;
298  const Node* operator->() const;
299 
300  private:
301  EdgeConstIterator m_iter;
302  };
303 
304  // Functions defined to traverse a Graph as below.
305 
306  /** Gets an iterator to the beginning of the input nodes to this Node. */
307  NodeConstIterator InputNodesBegin() const noexcept { return NodeConstIterator(relationships_.input_edges.cbegin()); };
308  /** Gets an iterator to the end of the input nodes to this Node. */
309  NodeConstIterator InputNodesEnd() const noexcept { return NodeConstIterator(relationships_.input_edges.cend()); }
310 
311  /** Gets an iterator to the beginning of the output nodes from this Node. */
313  return NodeConstIterator(relationships_.output_edges.cbegin());
314  }
315 
316  /** Gets an iterator to the end of the output nodes from this Node. */
317  NodeConstIterator OutputNodesEnd() const noexcept { return NodeConstIterator(relationships_.output_edges.cend()); }
318 
319  /** Gets an iterator to the beginning of the input edges to this Node.
320  @remarks There are no nullptr entries in this collection. */
321  EdgeConstIterator InputEdgesBegin() const noexcept { return relationships_.input_edges.cbegin(); }
322 
323  /** Gets an iterator to the end of the input edges to this Node. */
324  EdgeConstIterator InputEdgesEnd() const noexcept { return relationships_.input_edges.cend(); }
325 
326  /** Gets an iterator to the beginning of the output edges from this Node.
327  @remarks There are no nullptr entries in this collection. */
328  EdgeConstIterator OutputEdgesBegin() const noexcept { return relationships_.output_edges.cbegin(); }
329 
330  /** Gets an iterator to the end of the output edges from this Node. */
331  EdgeConstIterator OutputEdgesEnd() const noexcept { return relationships_.output_edges.cend(); }
332 
333  /** Gets the Node's control inputs. */
334  const std::set<std::string>& ControlInputs() const noexcept { return relationships_.control_inputs; }
335 
336  /** Gets the number of input edges to this Node */
337  size_t GetInputEdgesCount() const noexcept { return relationships_.input_edges.size(); }
338 
339  /** Gets the number of output edges from this Node */
340  size_t GetOutputEdgesCount() const noexcept { return relationships_.output_edges.size(); }
341 
342  /** Adds an AttributeProto to this Node.
343  @remarks The attribute name is used as the key in the attribute map. */
344  void AddAttributeProto(ONNX_NAMESPACE::AttributeProto value);
345 
346  // keep this signature in sync with ADD_ATTR_SINGLE_INTERFACE below
347  /** Adds an attribute to this Node with the specified attribute name and value. */
348  void AddAttribute(std::string attr_name, int64_t value);
349 
350  // keep this signature in sync with ADD_ATTR_LIST_INTERFACE below
351  /** Adds an attribute to this Node with the specified attribute name and values. */
352  void AddAttribute(std::string attr_name, gsl::span<const int64_t> values);
353 
354 #define ADD_ATTR_SINGLE_INTERFACE(Type) \
355  void AddAttribute(std::string attr_name, Type value)
356 
357 #define ADD_ATTR_LIST_INTERFACE(Type) \
358  void AddAttribute(std::string attr_name, gsl::span<const Type> values)
359 
360 #define ADD_ATTR_INTERFACES(Type) \
361  ADD_ATTR_SINGLE_INTERFACE(Type); \
362  ADD_ATTR_LIST_INTERFACE(Type)
363 
364  ADD_ATTR_INTERFACES(float);
365  ADD_ATTR_INTERFACES(std::string);
366  ADD_ATTR_INTERFACES(ONNX_NAMESPACE::TensorProto);
367 #if !defined(DISABLE_SPARSE_TENSORS)
368  ADD_ATTR_INTERFACES(ONNX_NAMESPACE::SparseTensorProto);
369 #endif
370  ADD_ATTR_INTERFACES(ONNX_NAMESPACE::TypeProto);
371 
372  ADD_ATTR_SINGLE_INTERFACE(ONNX_NAMESPACE::GraphProto);
373 
374 #undef ADD_ATTR_SINGLE_INTERFACE
375 #undef ADD_ATTR_LIST_INTERFACE
376 #undef ADD_ATTR_INTERFACES
377 
378  // The below overload is made so the compiler does not attempt to resolve
379  // string literals with the gsl::span overload
380  template <size_t N>
381  void AddAttribute(std::string attr_name, const char (&value)[N]) {
382  this->AddAttribute(std::move(attr_name), std::string(value, N - 1));
383  }
384 
385  /** Gets the Node's attributes. */
386  const NodeAttributes& GetAttributes() const noexcept { return attributes_; }
387 
388 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
389  /** Remove the specified attribute from this Node */
390  bool ClearAttribute(const std::string& attr_name);
391 
392  /** Gets the Node's mutable attributes. */
393  NodeAttributes& GetMutableAttributes() noexcept { return attributes_; }
394 
395 #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
396 
397  /**
398  * Clears removable attributes. These are no longer needed after the initialization
399  * of the session. The function returns the number of removed attributes.
400  */
401  int PruneRemovableAttributes(gsl::span<const std::string> removable_attributes);
402 
403 #if !defined(ORT_MINIMAL_BUILD)
404 
405  /** Gets the Graph instance that is instantiated from a GraphProto attribute during Graph::Resolve.
406  @param attr_name Attribute name for the GraphProto attribute.
407  @returns nullptr if the Graph instance has not been instantiated or attribute does not contain a GraphProto.
408  */
409  const Graph* GetGraphAttribute(const std::string& attr_name) const;
410 
411  /** Gets the mutable Graph instance that is instantiated from a GraphProto attribute during Graph::Resolve.
412  @param attr_name Attribute name for the GraphProto attribute.
413  @returns nullptr if the Graph instance has not been instantiated or attribute does not contain a GraphProto.
414  */
415  Graph* GetMutableGraphAttribute(const std::string& attr_name);
416 #endif // !defined(ORT_MINIMAL_BUILD)
417 
418  /** Checks if the Node contains at least one subgraph (this is the case for control flow operators, such as If, Scan, Loop).
419  @returns true if the Node contains a subgraph.
420  */
421  bool ContainsSubgraph() const {
422  return !attr_to_subgraph_map_.empty();
423  }
424 
425  /** Get the const subgraphs from a node.
426  @remarks Creates a new vector so calling ContainsSubgraph first is preferred. */
427  std::vector<gsl::not_null<const Graph*>> GetSubgraphs() const;
428 
429  /** Gets a map of attribute name to the mutable Graph instances for all subgraphs of the Node.
430  @returns Map of the attribute name that defines the subgraph to the subgraph's Graph instance.
431  The map is empty if the Node has no subgraphs.
432  */
433  const std::unordered_map<std::string, gsl::not_null<Graph*>>& GetAttributeNameToMutableSubgraphMap() {
434  return attr_to_subgraph_map_;
435  }
436 
437  /** Gets a map of attribute name to the mutable Graph instances for all subgraphs of the Node.
438  * @returns a mutable map of mutable subgraphs.
439  */
440  std::unordered_map<std::string, gsl::not_null<Graph*>>& GetMutableMapOfAttributeNameToSubgraph() {
441  return attr_to_subgraph_map_;
442  }
443 
444  /** Gets a map of attribute name to the const Graph instances for all subgraphs of the Node.
445  @returns Map of the attribute name that defines the subgraph to the subgraph's Graph instance.
446  The map is empty if the Node has no subgraphs.
447  */
448  std::unordered_map<std::string, gsl::not_null<const Graph*>> GetAttributeNameToSubgraphMap() const;
449 
450  /** Gets the execution ProviderType that this node will be executed by. */
451  ProviderType GetExecutionProviderType() const noexcept { return execution_provider_type_; }
452 
453  /** Sets the execution ProviderType that this Node will be executed by. */
454  void SetExecutionProviderType(ProviderType execution_provider_type) {
455  execution_provider_type_ = execution_provider_type;
456  }
457 
458  /** Call the provided function for all explicit inputs, implicit inputs, and outputs of this Node.
459  If the NodeArg is an explicit or implicit input, is_input will be true when func is called.
460  @param include_missing_optional_defs Include NodeArgs that are optional and were not provided
461  i.e. NodeArg::Exists() == false.
462  */
463  void ForEachDef(std::function<void(const onnxruntime::NodeArg&, bool is_input)> func,
464  bool include_missing_optional_defs = false) const;
465 
466 #if !defined(ORT_MINIMAL_BUILD)
467  /** Replaces any matching definitions in the Node's explicit inputs or explicit outputs.
468  @param replacements Map of current NodeArg to replacement NodeArg.
469  */
470  void ReplaceDefs(const std::map<const onnxruntime::NodeArg*, onnxruntime::NodeArg*>& replacements);
471 
472  /** Gets the NodeProto representation of this Node.
473  @param update_subgraphs Update the GraphProto values for any subgraphs in the returned NodeProto.
474  If graph optimization has been run this is most likely required
475  to ensure the complete Graph is valid.
476  */
477  void ToProto(ONNX_NAMESPACE::NodeProto& proto, bool update_subgraphs = false) const;
478 
479  Status SaveToOrtFormat(flatbuffers::FlatBufferBuilder& builder,
480  flatbuffers::Offset<onnxruntime::fbs::Node>& fbs_node) const;
481 
482  flatbuffers::Offset<onnxruntime::fbs::NodeEdge>
483  SaveEdgesToOrtFormat(flatbuffers::FlatBufferBuilder& builder) const;
484 
485  void SetFunctionTemplate(const FunctionTemplate& func_template);
486 #endif
487 
488  static Status LoadFromOrtFormat(const onnxruntime::fbs::Node& fbs_node, Graph& graph,
489  const OrtFormatLoadOptions& load_options,
490  const logging::Logger& logger, std::unique_ptr<Node>& node);
491 
492  Status LoadFromOrtFormat(const onnxruntime::fbs::Node& fbs_node,
493  const OrtFormatLoadOptions& load_options,
494  const logging::Logger& logger);
495  Status LoadEdgesFromOrtFormat(const onnxruntime::fbs::NodeEdge& fbs_node_edgs, const Graph& graph);
496 
497  /**
498  @class Definitions
499  The input and output definitions for this Node.
500  */
501  class Definitions {
502  public:
503  Definitions() = default;
504 
505  /** The Node's explicit input definitions. */
506  std::vector<NodeArg*> input_defs;
507 
508  /**
509  The number of inputs for each argument of the operator or function which this node refers.
510  @remarks For example, #input_defs has 10 elements (inputs), and #input_arg_count is {4, 6}.
511  This means that 4 elements (inputs) of input_defs map to the first argument of the operator or function, and
512  the other 6 map to the second argument.
513  */
514  std::vector<int> input_arg_count;
515 
516  /** The Node's output definitions. */
517  std::vector<NodeArg*> output_defs;
518 
519  /** The Node's implicit input definitions if the Node contains one or more subgraphs
520  (i.e. GraphProto attributes) and the subgraph/s implicitly consume these values.
521  @remarks For example, a subgraph in an 'If' node gets all its input values via this mechanism rather than
522  there being explicit inputs to the 'If' node that are passed to the subgraph.
523  They are pseudo-inputs to this Node as it has an implicit dependency on them. */
524  std::vector<NodeArg*> implicit_input_defs;
525 
527 
528  private:
529  };
530 
531  /**
532  @class Relationships
533  Defines the relationships between this Node and other Nodes in the Graph.
534  */
536  public:
537  Relationships() = default;
538 
539  void Clear() noexcept {
540  input_edges.clear();
541  output_edges.clear();
542  control_inputs.clear();
543  }
544 
545  /** The edges for Nodes that provide inputs to this Node. */
547 
548  /** The edges for Nodes that receive outputs from this Node. */
550 
551  /** The Node names of the control inputs to this Node. */
552  std::set<std::string> control_inputs;
553 
554  private:
555  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Relationships);
556  };
557 
558  // NOTE: This friendship relationship should ONLY be used for calling methods of the Node class and not accessing
559  // the data members directly, so that the Node can maintain its internal invariants.
560  friend class Graph;
561  Node(NodeIndex index, Graph& graph) : index_(index), graph_(&graph), can_be_saved_(true) {}
562 
563  private:
565 
566 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
567  void Init(std::string_view name,
568  std::string_view op_type,
569  std::string_view description,
570  gsl::span<NodeArg* const> input_args,
571  gsl::span<NodeArg* const> output_args,
572  const NodeAttributes* attributes,
573  std::string_view domain);
574  void Init(std::string_view name,
575  std::string_view op_type,
576  std::string_view description,
577  gsl::span<NodeArg* const> input_args,
578  gsl::span<NodeArg* const> output_args,
579  NodeAttributes&& attributes,
580  std::string_view domain);
581 #endif
582 
583 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
584  // internal only method to allow selected classes to directly alter the input/output definitions and arg counts
585  Definitions& MutableDefinitions() noexcept;
586 
587  // internal only method to allow selected classes to directly alter the links between nodes.
588  Relationships& MutableRelationships() noexcept;
589 
590  void SetNodeType(Node::Type node_type) noexcept { node_type_ = node_type; }
591 #endif
592 
593  // create a Graph instance for an attribute that contains a GraphProto
594  void CreateSubgraph(const std::string& attr_name);
595 
596  std::vector<std::unique_ptr<Graph>>& MutableSubgraphs() noexcept { return subgraphs_; }
597 
598  // validate and update the input arg count
599  common::Status UpdateInputArgCount();
600 
601  const Definitions& GetDefinitions() const noexcept { return definitions_; }
602  const Relationships& GetRelationships() const noexcept { return relationships_; }
603 
604  // Node index. Default to impossible value rather than 0.
606 
607  // Node name.
608  std::string name_;
609 
610  // Node operator type.
611  std::string op_type_;
612 
613  // OperatorSet domain of op_type_.
614  std::string domain_;
615 
616 #if !defined(ORT_MINIMAL_BUILD)
617  // OperatorSchema that <*this> node refers to.
618  const ONNX_NAMESPACE::OpSchema* op_ = nullptr;
619 
620  // Reference to the function template defined in the model.
621  const FunctionTemplate* func_template_ = nullptr;
622 
623  // set/clear NodeProto that the Node was created from.
624  // Set by Graph ctor when loading a model from file.
625  // Cleared after first call to onnx::check_node in VerifyNodeAndOpMatch when the first Graph::Resolve runs.
626  void SetOriginalNodeProto(const ONNX_NAMESPACE::NodeProto* node_proto) {
627  original_node_proto_ = node_proto;
628  }
629 
630  const ONNX_NAMESPACE::NodeProto* GetOriginalNodeProto() const {
631  return original_node_proto_;
632  }
633 
634  // NodeProto that the Node was created from. We temporarily set this as a performance optimization to avoid calling
635  // Node::ToProto when running onnx::check_node in the first Graph::Resolve. At that point we know all the nodes are
636  // unchanged from the original model.
637  const ONNX_NAMESPACE::NodeProto* original_node_proto_ = nullptr;
638 #endif
639 
640  // Execution priority, lower value for higher priority
641  int priority_ = 0;
642 
643  // set from op_->SinceVersion() or via deserialization when OpSchema is not available
644  int since_version_ = -1;
645 
646  Node::Type node_type_ = Node::Type::Primitive;
647 
648  // The function body is owned by graph_
649  std::unique_ptr<Function> func_body_ = nullptr;
650 
651  // Node doc string.
652  std::string description_;
653 
654  // input/output defs and arg count
655  Definitions definitions_;
656 
657  // Relationships between this node and others in the graph
658  Relationships relationships_;
659 
660  // Device.
661  std::string execution_provider_type_;
662 
663  // Map from attribute name to attribute.
664  // This allows attribute adding and removing.
665  NodeAttributes attributes_;
666 
667  // Graph that contains this Node
668  Graph* graph_ = nullptr;
669 
670  // Map of attribute name to the Graph instance created from the GraphProto attribute
671  std::unordered_map<std::string, gsl::not_null<Graph*>> attr_to_subgraph_map_;
672 
673  // Graph instances for subgraphs that are owned by this Node
674  std::vector<std::unique_ptr<Graph>> subgraphs_;
675 
676  // Can be saved? The node cannot be saved anymore if removable attributes have been cleared.
677  bool can_be_saved_;
678 };
679 
680 /**
681 @class Graph
682 The Graph representation containing the graph inputs and outputs, the Node instances,
683 and the edges connecting the nodes.
684 */
685 class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve existing data member order for readability
686  public:
687  /** Gets the Graph name. */
688  const std::string& Name() const noexcept;
689 
690  /** Gets the Graph description. */
691  const std::string& Description() const noexcept;
692 
693  /** Gets the path of the owning model, if any. */
694  const std::filesystem::path& ModelPath() const;
695 
696  /** Returns true if this is a subgraph or false if it is a high-level graph. */
697  bool IsSubgraph() const { return parent_graph_ != nullptr; }
698 
699  /** Returns the parent graph if this is a subgraph */
700  const Graph* ParentGraph() const { return parent_graph_; }
701 
702  /** Returns the mutable parent graph if this is a subgraph */
703  Graph* MutableParentGraph() { return parent_graph_; }
704 
705  /** Returns the strict_shape_type_inference that was passed into the constructor. */
706  bool StrictShapeTypeInference() const { return strict_shape_type_inference_; }
707 
708 #if !defined(ORT_MINIMAL_BUILD)
709  /** Sets the Graph name. */
710  void SetName(const std::string& name);
711 
712  /** Sets the Graph description. */
713  void SetDescription(const std::string& description);
714 
715  /** Replaces the initializer tensor with the same name as the given initializer tensor.
716  The replacement initializer tensor must have the same type and shape as the existing initializer tensor.
717 
718  Note: This currently has linear time complexity. There is room for improvement but it would likely require changes to
719  how initializer tensors are stored and tracked.
720  */
721  common::Status ReplaceInitializedTensor(ONNX_NAMESPACE::TensorProto new_initializer);
722 
723 #if !defined(DISABLE_EXTERNAL_INITIALIZERS)
724  /** This function takes externally provided data for initializers with external data
725  * and replaces graph initializers with its content.
726  */
728 
729  /** This function takes externally provided files in memory for initializers with external
730  * data and replaces graph initializers with its content.
731  */
733  const InlinedHashMap<PathString, std::pair<char*, size_t>>& external_initializer_files);
734 #endif // !defined(DISABLE_EXTERNAL_INITIALIZERS)
735 
736 #endif // !defined(ORT_MINIMAL_BUILD)
737 
738 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
739  /** Add an initializer tensor to the Graph. */
740  void AddInitializedTensor(const ONNX_NAMESPACE::TensorProto& tensor_proto);
741 #endif
742 
743  /** Remove the initializer tensor with the provided name from the Graph. */
744  void RemoveInitializedTensor(const std::string& tensor_name);
745 
746  /** Check if a given name is an initializer tensor's name in this graph. */
747  bool IsInitializedTensor(const std::string& name) const;
748 
749 #if !defined(DISABLE_SPARSE_TENSORS)
750  /** Check if a given name is a sparse initializer's name in the model
751  * we currently convert sparse_initializer field in the model into dense Tensor instances.
752  * However, we sometimes want to check if this initializer was stored as sparse in the model.
753  */
754  bool IsSparseInitializer(const std::string& name) const;
755 #endif
756 
757  /** Gets an initializer tensor with the provided name.
758  @param[out] value Set to the TensorProto* if the initializer is found, or nullptr if not.
759  @returns True if found.
760  */
761  bool GetInitializedTensor(const std::string& tensor_name, const ONNX_NAMESPACE::TensorProto*& value) const;
762 
763  /** Gets all the initializer tensors in this Graph. */
764  const InitializedTensorSet& GetAllInitializedTensors() const noexcept { return name_to_initial_tensor_; }
765 
766  /** Removes all initializer tensors from this Graph and releases the memory they were using. */
767  void CleanAllInitializedTensors() noexcept;
768 
769  /** Returns true if an initializer value can be overridden by a graph input with the same name. */
770  bool CanOverrideInitializer() const noexcept { return ir_version_ >= 4; }
771 
772  /** returns the initializer's TensorProto if 'name' is an initializer, is constant and
773  cannot be overridden at runtime. If the initializer is not found or is not constant, a nullptr is returned.
774  @param check_outer_scope If true and the graph is a subgraph,
775  check ancestor graph/s for 'name' if not found in 'graph'.
776  */
777  const ONNX_NAMESPACE::TensorProto* GetConstantInitializer(const std::string& name, bool check_outer_scope) const;
778 
779  /** returns the initializer's TensorProto if 'name' is an initializer (both constant and overridable).
780  If the initializer is not found, a nullptr is returned.
781  @param check_outer_scope If true and the graph is a subgraph,
782  check ancestor graph/s for 'name' if not found in 'graph'.
783  @remarks check_outer_scope of true is not supported in a minimal build
784  */
785  const ONNX_NAMESPACE::TensorProto* GetInitializer(const std::string& name, bool check_outer_scope) const;
786 
787  /** Gets the Graph inputs excluding initializers.
788  These are the required inputs to the Graph as the initializers can be optionally overridden via graph inputs.
789  @remarks Contains no nullptr values.
  @returns Const reference to graph_inputs_excluding_initializers_. */
790  const std::vector<const NodeArg*>& GetInputs() const noexcept { return graph_inputs_excluding_initializers_; }
791 
792  /** Gets the Graph inputs including initializers.
793  This is the full set of inputs, in the same order as defined in the GraphProto.
794  @remarks Contains no nullptr values.
  @returns Const reference to graph_inputs_including_initializers_. */
795  const std::vector<const NodeArg*>& GetInputsIncludingInitializers() const noexcept {
796  return graph_inputs_including_initializers_;
797  }
798 
799  /** Tells whether node_arg is present in the Graph inputs including initializers.
  The check is a pointer comparison against each entry of graph_inputs_including_initializers_. */
800  bool IsInputsIncludingInitializers(const NodeArg* node_arg) const noexcept {
801  const auto last_input = graph_inputs_including_initializers_.end();
802  return std::find(graph_inputs_including_initializers_.begin(), last_input, node_arg) != last_input;
803  }
804 
805  /** Gets the Graph inputs that are initializers.
806  These are overridable initializers. This is the difference between
807  graph_inputs_including_initializers_ and graph_inputs_excluding_initializers_.
808  @remarks Contains no nullptr values.
  @returns Const reference to graph_overridable_initializers_. */
809  const std::vector<const NodeArg*>& GetOverridableInitializers() const {
810  return graph_overridable_initializers_;
811  }
812 
813  /** Gets the Graph outputs.
814  @remarks Contains no nullptr values.
  @returns Const reference to graph_outputs_. */
815  const std::vector<const NodeArg*>& GetOutputs() const noexcept { return graph_outputs_; }
816 
817  bool IsOutput(const NodeArg* node_arg) const noexcept {  // true if node_arg (compared by pointer) is in graph_outputs_
818  return std::any_of(graph_outputs_.begin(), graph_outputs_.end(), [node_arg](const NodeArg* candidate) { return candidate == node_arg; });
819  }
820 
821  /** Reports whether at least one output of the given Node is also a Graph output.
822  @remarks Cheaper than calling GetNodeOutputsInGraphOutputs: it returns on the first match and builds no vector.
823  */
824  bool NodeProducesGraphOutput(const Node& node) const {
825  const auto first_graph_output = graph_outputs_.cbegin();
826  const auto last_graph_output = graph_outputs_.cend();
827  for (auto produced : node.OutputDefs()) {
828  const bool is_graph_output = std::find(first_graph_output, last_graph_output, produced) != last_graph_output;
829  if (is_graph_output) return true;
830  }
831  // No output def of the node matched any Graph output.
832  return false; }
833 
834  /** Collects, in ascending order, the positions within node.OutputDefs() whose NodeArg is also a Graph output. */
835  std::vector<int> GetNodeOutputsInGraphOutputs(const Node& node) const {
836  const auto& graph_outputs = GetOutputs();
837  std::vector<int> matching_indexes;
838  int position = 0;
839  for (auto produced : node.OutputDefs()) {
840  if (std::find(graph_outputs.cbegin(), graph_outputs.cend(), produced) != graph_outputs.cend()) {
841  matching_indexes.push_back(position);
842  }
843  ++position;  // advance for every output def, matched or not
844  }
845 
846  return matching_indexes;
847  }
848 
849  /** Gets the NodeArgs that represent value_info instances in the Graph.
850  These are the values that are neither Graph inputs nor outputs.
851  @remarks Contains no nullptr values.
  @returns Const reference to the value_info_ set. */
852  const std::unordered_set<const NodeArg*>& GetValueInfo() const noexcept { return value_info_; }
853 
854 #if !defined(ORT_MINIMAL_BUILD)
855  void AddValueInfo(const NodeArg* new_value_info);
856 #endif
857 
858  /** Gets the Node with the specified node index. Forwards to NodeAtIndexImpl(node_index).
859  @returns Node instance if found. nullptr if node_index is invalid or node has been freed.
860  */
861  const Node* GetNode(NodeIndex node_index) const { return NodeAtIndexImpl(node_index); }
862 
863  /** Gets the mutable Node with the specified node index. Forwards to NodeAtIndexImpl(node_index).
864  @returns Mutable Node instance if found. nullptr if node_index is invalid or node has been freed.
865  */
866  Node* GetNode(NodeIndex node_index) { return NodeAtIndexImpl(node_index); }
867 
868  /** Get a GraphNodes instance that provides mutable access to all valid Nodes in the Graph.
  @returns Reference to the iterable_nodes_ member. */
869  GraphNodes& Nodes() noexcept { return iterable_nodes_; }
870 
871  /** Get a GraphNodes instance that provides const access to all valid Nodes in the Graph.
  @returns Const reference to the iterable_nodes_ member. */
872  const GraphNodes& Nodes() const noexcept { return iterable_nodes_; }
873 
874  /** Get a ConstGraphNodes instance that provides access to a filtered set of valid Nodes in the Graph.
  @param filter_func Filter that is moved into the returned ConstGraphNodes.
875  @remarks We can't use GraphNodes as that would provide mutable access to the nodes by default, and we can't prevent
876  that by returning a const instance of GraphNodes as we're creating a new instance here due to the filter
877  being something we don't control (i.e. we have to return a new instance so it can't be const).
878  */
879  ConstGraphNodes FilteredNodes(GraphNodes::NodeFilterFunc&& filter_func) const noexcept {
880  return ConstGraphNodes(nodes_, std::move(filter_func));
881  }
882 
883  /** Gets the maximum NodeIndex value used in the Graph.
884  WARNING: This actually returns the max index value used + 1.
  The value is nodes_.size() narrowed to int with a static_cast.
885  */
886  int MaxNodeIndex() const noexcept { return static_cast<int>(nodes_.size()); } // assume the casting won't overflow
887 
888  /** Gets the number of valid Nodes in the Graph (the num_of_nodes_ member).
889  @remarks This may be smaller than MaxNodeIndex(), as Nodes may be removed during optimization.
890  */
891  int NumberOfNodes() const noexcept { return num_of_nodes_; }
892 
893  /** Looks up the mutable NodeArg registered under the given name in node_args_.
894  @returns Pointer to the NodeArg when the name is known, nullptr otherwise. */
895  NodeArg* GetNodeArg(const std::string& name) {
896  const auto found = node_args_.find(name);
897  if (found == node_args_.end()) {
898  return nullptr;
899  }
900  return found->second.get();
901  }
902 
903  /** Gets the const NodeArg with the provided name.
904  @returns Pointer to const NodeArg if found, nullptr if not. */
905  const NodeArg* GetNodeArg(const std::string& name) const {
  // Reuse the non-const overload via const_cast; the result is handed back as pointer-to-const.
906  return const_cast<Graph*>(this)->GetNodeArg(name);
907  }
908 
909  // search this and up through any parent_graph_ instance for a NodeArg
910  NodeArg* GetNodeArgIncludingParentGraphs(const std::string& node_arg_name);
911 
912  /** Returns the NodeArg stored under `name`, creating one owned by this Graph when the name is new.
913  @param name The NodeArg name.
914  @param[in] p_arg_type Optional TypeProto, used only when a new NodeArg has to be created.
915  @returns NodeArg reference.
916  */
917  NodeArg& GetOrCreateNodeArg(const std::string& name, const ONNX_NAMESPACE::TypeProto* p_arg_type) {
918  auto [slot, inserted] = node_args_.emplace(name, nullptr);
919  if (inserted) {  // name was not present: fill the freshly inserted null placeholder
920  slot->second = std::make_unique<NodeArg>(name, p_arg_type);
921  }
922  return *(slot->second);
923  }
924 
925 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
926  /** Generate a unique name in this Graph for a NodeArg */
927  std::string GenerateNodeArgName(const std::string& base_name);
928 
929  /** Generate a unique name in this Graph for a Node */
930  std::string GenerateNodeName(const std::string& base_name);
931 #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
932 
933 #if !defined(ORT_MINIMAL_BUILD)
934  /** Copy a Node and add it to this Graph.
935  @param other Node to copy
936  @returns Reference to the Node that was created and added to this Graph.
937  @remarks Do not call AddNode and RemoveNode concurrently as they are not thread-safe.
938  */
939  Node& AddNode(const Node& other);
940 #endif
941 
942 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
943  /** Add a Node to this Graph.
944  @param name The Node name. Must be unique in this Graph.
945  @param op_type The operator type. e.g. ONNX operator name.
946  @param description Arbitrary description of the Node.
947  @param input_args The explicit inputs to this Node.
948  @param output_args The outputs from this Node.
949  @param attributes Optional NodeAttributes to add.
950  @param domain The domain for the op_type.
951  @returns Reference to the new Node.
952  @remarks Do not call AddNode and RemoveNode concurrently as they are not thread-safe.
953  */
954  Node& AddNode(const std::string& name,
955  const std::string& op_type,
956  const std::string& description,
957  gsl::span<NodeArg* const> input_args,
958  gsl::span<NodeArg* const> output_args,
959  const NodeAttributes* attributes = nullptr,
960  const std::string& domain = kOnnxDomain);
961 
962  Node& AddNode(const std::string& name,
963  const std::string& op_type,
964  const std::string& description,
965  gsl::span<NodeArg* const> input_args,
966  gsl::span<NodeArg* const> output_args,
967  NodeAttributes&& attributes,
968  const std::string& domain = kOnnxDomain);
969  Node& AddNode(const std::string& name,
970  const std::string& op_type,
971  const std::string& description,
972  std::initializer_list<NodeArg*> input_args,
973  std::initializer_list<NodeArg*> output_args,
974  const NodeAttributes* attributes = nullptr,
975  const std::string& domain = kOnnxDomain) {
  // Convenience overload for braced lists: both initializer_lists are converted with AsSpan
  // and all arguments are forwarded to another AddNode overload.
976  return AddNode(name, op_type, description,
977  AsSpan(input_args),
978  AsSpan(output_args),
979  attributes, domain);
980  }
981 
982  Node& AddNode(const std::string& name,
983  const std::string& op_type,
984  const std::string& description,
985  gsl::span<NodeArg* const> input_args,
986  std::initializer_list<NodeArg*> output_args,
987  const NodeAttributes* attributes = nullptr,
988  const std::string& domain = kOnnxDomain) {
  // Convenience overload: only output_args is an initializer_list; it is converted with AsSpan
  // and all arguments are forwarded to another AddNode overload.
989  return AddNode(name, op_type, description,
990  input_args,
991  AsSpan(output_args),
992  attributes, domain);
993  }
994 
995  Node& AddNode(const std::string& name,
996  const std::string& op_type,
997  const std::string& description,
998  std::initializer_list<NodeArg*> input_args,
999  gsl::span<NodeArg* const> output_args,
1000  const NodeAttributes* attributes = nullptr,
1001  const std::string& domain = kOnnxDomain) {
  // Convenience overload: only input_args is an initializer_list; it is converted with AsSpan
  // and all arguments are forwarded to another AddNode overload.
1002  return AddNode(name, op_type, description,
1003  AsSpan(input_args),
1004  output_args,
1005  attributes, domain);
1006  }
1007 
1008  /** Remove a Node from this Graph and free it.
1009  The output edges of this specified node MUST have been removed before removing the node.
1010  The input edges of this specified node are removed while removing the node. The process of
1011  removing a node from a graph should be,
1012  1. Remove out edges of this specified node.
1013  2. Remove this specified node.
1014  3. Add new input edges connected with all out nodes.
1015  @returns true if the node_index was valid
1016  @remarks Do not call AddNode and RemoveNode concurrently as they are not thread-safe.
1017  */
1018  bool RemoveNode(NodeIndex node_index);
1019 
1020  /** Add an edge between two Nodes.
1021  @param src_node_index NodeIndex of source Node that is providing output to the destination Node.
1022  @param dst_node_index NodeIndex of destination Node that is receiving input from the source Node.
1023  @param src_arg_index node arg index of source node.
1024  @param dst_arg_index node arg index of destination node.
1025  */
1026  void AddEdge(NodeIndex src_node_index, NodeIndex dst_node_index, int src_arg_index, int dst_arg_index);
1027 
1028  /** Remove an edge between two Nodes.
1029  @param src_node_index NodeIndex of source Node to remove an output edge from.
1030  @param dst_node_index NodeIndex of destination Node to remove an input edge from.
1031  @param src_arg_index node arg index of source node.
1032  @param dst_arg_index node arg index of destination node.
1033  */
1034  void RemoveEdge(NodeIndex src_node_index, NodeIndex dst_node_index, int src_arg_index, int dst_arg_index);
1035 #endif
1036 
1037 #if !defined(ORT_MINIMAL_BUILD)
1038  /**
1039  Add a control edge between two Nodes in this Graph.
1040  The source Node does not produce output that is directly consumed by the destination Node, however the
1041  destination Node must execute after the source node. The control edge allows this ordering to occur.
1042  */
1043  bool AddControlEdge(NodeIndex src_node_index, NodeIndex dst_node_index);
1044 #endif // !defined(ORT_MINIMAL_BUILD)
1045 
1046  /** Mark the Graph as needing Resolve() to be called.
1047  This should be done after modifying any aspect of the Graph that changes the Nodes or relationships between them.
  @returns Reference to this Graph. */
1048  Graph& SetGraphResolveNeeded() noexcept {
1049  graph_resolve_needed_ = true;
1050  return *this;
1051  }
1052 
1053  /** Gets flag indicating whether Graph::Resolve needs to be called before using the Graph.
  @returns Current value of graph_resolve_needed_. */
1054  bool GraphResolveNeeded() const noexcept {
1055  return graph_resolve_needed_;
1056  }
1057 
1058  /** Sets flag that Graph::graph_proto_ needs to be updated to reflect changes in the Graph.
  @returns Reference to this Graph. */
1059  Graph& SetGraphProtoSyncNeeded() noexcept {
1060  graph_proto_sync_needed_ = true;
1061  return *this;
1062  }
1063 
1064  /** Gets flag indicating whether Graph::graph_proto_ needs to be synchronized with this Graph instance.
  @returns Current value of graph_proto_sync_needed_. */
1065  bool GraphProtoSyncNeeded() const noexcept {
1066  return graph_proto_sync_needed_;
1067  }
1068 
1069  /** Performs a reverse depth-first search (DFS) traversal from a set of nodes, via their inputs,
1070  up to their source node/s.
1071  @param from NodeIndex values for a set of Nodes to traverse from.
1072  @param enter Visit function that will be invoked on a node when it is visited but its parents haven't been.
1073  @param leave Visit function invoked on the node after its parents have all been visited.
1074  @param comp Comparison function to stabilize the traversal order by making Node ordering deterministic.
1075  */
1076  void ReverseDFSFrom(gsl::span<NodeIndex const> from,
1077  const std::function<void(const Node*)>& enter,
1078  const std::function<void(const Node*)>& leave,
1079  const std::function<bool(const Node*, const Node*)>& comp = {}) const;
1080 
1081  /** Performs a reverse depth-first search (DFS) traversal from a set of nodes, via their inputs,
1082  up to their source node/s.
1083  @param from Set of Nodes to traverse from.
1084  @param enter Visit function that will be invoked on a node when it is visited but its parents haven't been.
1085  @param leave Visit function invoked on the node after its parents have all been visited.
1086  @param comp Comparison function to stabilize the traversal order by making Node ordering deterministic.
1087  */
1088  void ReverseDFSFrom(gsl::span<const Node* const> from,
1089  const std::function<void(const Node*)>& enter,
1090  const std::function<void(const Node*)>& leave,
1091  const std::function<bool(const Node*, const Node*)>& comp = {}) const;
1092 
1093  /** Performs a reverse depth-first search (DFS) traversal from a set of nodes, via their inputs,
1094  up to their source node/s.
1095  @param from Set of Nodes to traverse from.
1096  @param enter Visit function that will be invoked on a node when it is visited but its parents haven't been.
1097  @param leave Visit function invoked on the node after its parents have all been visited.
1098  @param stop Stop traversal from node n to input node p if stop(n, p) is true.
1099  @param comp Comparison function to stabilize the traversal order by making Node ordering deterministic.
1100  */
1101  void ReverseDFSFrom(gsl::span<const Node* const> from,
1102  const std::function<void(const Node*)>& enter,
1103  const std::function<void(const Node*)>& leave,
1104  const std::function<bool(const Node*, const Node*)>& comp,
1105  const std::function<bool(const Node*, const Node*)>& stop) const;
1106 
1107 #if !defined(ORT_MINIMAL_BUILD)
1108  /** Performs topological sort with Kahn's algorithm on the graph/s.
1109  @param enter Visit function that will be invoked on a node when it is visited.
1110  @param comp Comparison function to stabilize the traversal order by making Node ordering deterministic.
1111  */
1112  void KahnsTopologicalSort(const std::function<void(const Node*)>& enter,
1113  const std::function<bool(const Node*, const Node*)>& comp) const;
1114 
1115 #endif
1116 
1117 #ifdef ENABLE_TRAINING
1118  /**
1119  * @brief Performs topological sort with customized Kahn's algorithm on the graph/s.
1120  * This is a specialized version for training where need memory efficient topological sort.
1121  * @param yield_op The YieldOp used in ORTModule training.
1122  * @param shape_size_parents The shape size parents nodes.
1123  * @param node_orders The output node orders.
1124  */
1125  void MemoryEfficientTopologicalSort(const Node* yield_op,
1126  const InlinedHashMap<NodeIndex, InlinedVector<NodeIndex>>& shape_size_parents,
1127  std::vector<NodeIndex>& node_orders) const;
1128 #endif
1129 
1130  /** Gets the map of operator domains to their opset versions.
  @returns Const reference to domain_to_version_. */
1131  const std::unordered_map<std::string, int>& DomainToVersionMap() const noexcept {
1132  return domain_to_version_;
1133  }
1134 
1135 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1136  /**
1137  Create a single Node that will be the result of a fusion of multiple nodes in this Graph.
1138  @param sub_graph An IndexedSubGraph instance with details of the nodes to fuse.
1139  @param fused_node_name The name for the new Node.
1140  @returns Node with fused subgraph.
1141  @remarks As a new Graph instance for the fused nodes is not created, a GraphViewer can be constructed with the
1142  IndexedSubGraph information to provide a view of the subgraph. The original nodes are left in place
1143  while this is in use.
1144  Call FinalizeFuseSubGraph to remove them once the fused replacement node is fully created.
1145  */
1146  Node& BeginFuseSubGraph(const IndexedSubGraph& sub_graph, const std::string& fused_node_name);
1147 
1148  void FinalizeFuseSubGraph(const IndexedSubGraph& sub_graph, Node& fused_node);
1149 #endif
1150 
1151 #if !defined(ORT_MINIMAL_BUILD)
1152  /** Gets the GraphProto representation of this Graph. */
1153  const ONNX_NAMESPACE::GraphProto& ToGraphProto();
1154  ONNX_NAMESPACE::GraphProto ToGraphProto() const;
1155 
1156  // Options to align external initializer offset.
1157  // For models running on CPU, ORT will try to use mmap to load external initializers.
1158  // To use mmap, external initializers need to be offset aligned.
1159  // ORT saves external initializers into a single data file; each initializer is accessed with
1160  // offset (start position of initializer) and length (byte length of initializer) of the data file.
1161  // To use mmap, each offset needs to be aligned, which means the offset needs to be divisible by the
1162  // allocation granularity (64KB for Windows and 4KB for other OSes).
1163  // With align_offset set to true, ORT will align the offset for large initializers when
1164  // saving an ONNX model with an external data file.
1165  struct OffsetAlignmentInfo {
1166  // Offset will always be page aligned and allocation granularity aligned for mmap support.
1167  // This is done by padding previous tensor data with zeros keeping same length.
1168  bool align_offset = false;
1169  // Alignment threshold for size of data.
1170  // Having a low threshold will waste file space for small initializers.
1171  // Only when a tensor's data size is > align_threshold will it be force aligned.
1172  // Default to 1MB.
1173  int64_t align_threshold = 1048576;
1174  // The allocation Granularity for mmap() support.
1175  // Typically 64KB for Windows & 4KB for other OSes. Default to 64KB.
1176  int64_t allocation_granularity = 65536;
1177  };
1178 
1179  /** Gets the GraphProto representation of this Graph
1180  @param external_file_path File path of the binary file to use for initializers.
1181  @param model_file_path path of the model file.
1182  @param initializer_size_threshold initializers larger or equal to this threshold (in bytes) are saved
1183  in the external file. Initializer smaller than this threshold are included in the onnx file.
1184  @param align_info offset alignment info.
1185  @returns GraphProto serialization of the graph.
1186  */
1187  ONNX_NAMESPACE::GraphProto ToGraphProtoWithExternalInitializers(const std::filesystem::path& external_file_path,
1188  const std::filesystem::path& model_file_path,
1189  size_t initializer_size_threshold,
1190  const OffsetAlignmentInfo& align_info) const;
1191 
1192  ONNX_NAMESPACE::GraphProto ToGraphProtoWithExternalInitializers(const std::filesystem::path& external_file_path,
1193  const std::filesystem::path& model_file_path,
1194  size_t initializer_size_threshold) const {
  // Overload without align_info: forwards a default-constructed OffsetAlignmentInfo to the four-argument overload.
1195  OffsetAlignmentInfo default_options;
1196  return ToGraphProtoWithExternalInitializers(external_file_path, model_file_path, initializer_size_threshold, default_options);
1197  }
1198 
1199  /** Gets the ISchemaRegistry instances being used with this Graph. */
1200  IOnnxRuntimeOpSchemaCollectionPtr GetSchemaRegistry() const;
1201 
1202  /**
1203  Looks up the op schema in the schema registry and sets it for the given node.
1204  @param node The node to update.
1205  @return Whether the node's op schema was set to a valid value.
1206  */
1207  bool SetOpSchemaFromRegistryForNode(Node& node);
1208 
1209  /**
1210  Create a single Function based Node that is the result of a fusion of multiple nodes in this Graph.
1211  A new Graph instance will be created for the fused nodes.
1212  @param sub_graph An IndexedSubGraph instance with details of the nodes to fuse. Ownership is transferred to the new Node
1213  @param fused_node_name The name for the new Node.
1214  @returns Function based Node with fused subgraph. The Node body will contain a Function instance.
1215  */
1216  Node& FuseSubGraph(const IndexedSubGraph& sub_graph, const std::string& fused_node_name);
1217 
1218  /**
1219  Directly insert one of the If node branches into this Graph.
1220  `If` node condition must be a constant. The function would
1221  rename the nodes of the corresponding subgraph to make sure there is no conflict.
1222 
1223  Explicit and implicit inputs references stay the same.
1224 
1225  All of the outputs of the subgraph being inlined should be renamed
1226  to the outputs of the If node.
1227 
1228  The function will process any subgraphs in each of the nodes being inlined,
1229  and will rename any references to the new names introduced.
1230 
1231  @param condition_value If condition value
1232  @param if_node - the node that contains the graph_to_inline. This node is going
1233  to be deleted and replaced by the corresponding graph (either then or else)
1234  @param logger
1235  */
1236  Status InlineIfSubgraph(bool condition_value, Node& if_node, const logging::Logger& logger);
1237 
1238  /**
1239  Directly insert the nodes in the function Node provided into this Graph.
1240  The Graph needs to be Resolve()d after this call.
1241  @param node Node with Node::Type of Node::Type::Fused
1242  @returns Status indicating success or providing an error message.
1243  */
1244  Status InlineFunction(Node& node);
1245 
1246  /**
1247  Directly insert the nodes in the function proto provided into the graph.
1248  The function converts Constant nodes into the initializers in the graph.
1249  It then creates a node in the graph for each of the function nodes.
1250  All of the names are expected to be specialized, and, therefore unique.
1251  See function_utils::Specialize().
1252 
1253  The Graph needs to be Resolve()d after this call.
1254  @param func_to_inline
1255  @returns Status indicating success or providing an error message.
1256  */
1257 
1258  Status InlineFunctionProto(const ONNX_NAMESPACE::FunctionProto& func_to_inline);
1259 
1260  /** Mark a NodeArg name as coming from the outer scope when programmatically constructing a Graph that will
1261  be used as a GraphProto attribute in another Node.
1262  e.g. when creating a Graph instance that will be used as a subgraph in a control flow operator, it is necessary to
1263  define placeholder NodeArgs for outer scope values. This prevents these values from becoming explicit graph inputs
1264  when the Graph is resolved.
  @param name NodeArg name to insert into outer_scope_node_arg_names_; the result of the insert is deliberately ignored.
1265  */
1266  void AddOuterScopeNodeArg(const std::string& name) {
1267  ORT_IGNORE_RETURN_VALUE(outer_scope_node_arg_names_.insert(name));
1268  }
1269 
1270  /** Explicitly set graph inputs.
1271  @param inputs NodeArgs that represent complete graph inputs which need to be explicitly ordered.
1272  @remarks Note that the input order matters for subgraphs.
1273  */
1274  void SetInputs(gsl::span<const NodeArg* const> inputs);
1275 
1276  void SetInputs(std::initializer_list<const NodeArg*> inputs) {
  // Braced-list convenience overload: converts with AsSpan and forwards to the span overload of SetInputs.
1277  SetInputs(AsSpan(inputs));
1278  }
1279 
1280  const Model& GetModel() const {
  // Returns the owning_model_ member by const reference.
1281  return owning_model_;
1282  }
1283 
1284  const logging::Logger& GetLogger() const {
  // Returns the logger_ member by const reference.
1285  return logger_;
1286  }
1287 
1288  /** Explicitly set graph outputs.
1289  @param outputs NodeArgs that represent complete graph outputs which need to be explicitly ordered.
1290  @remarks Note that the output order matters for subgraphs.
1291  */
1292  void SetOutputs(gsl::span<const NodeArg* const> outputs);
1293 
1294  void SetOutputs(std::initializer_list<const NodeArg*> outputs) {
  // Braced-list convenience overload: converts with AsSpan and forwards to the span overload of SetOutputs.
1295  SetOutputs(AsSpan(outputs));
1296  }
1297 
1298 #endif // !defined(ORT_MINIMAL_BUILD)
1299 
1300 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1301  /** Sets the type of a NodeArg, replacing existing type/shape if any */
1302  void SetNodeArgType(NodeArg& arg, const ONNX_NAMESPACE::TypeProto& type_proto);
1303 
1304  const Node* GetProducerNode(const std::string& node_arg_name) const {
  // Const lookup: forwards this Graph and the name to GetProducerNodeImpl.
1305  return GetProducerNodeImpl(*this, node_arg_name);
1306  }
1307 
1308  Node* GetMutableProducerNode(const std::string& node_arg_name) {
  // Mutable counterpart of GetProducerNode: forwards this Graph and the name to GetProducerNodeImpl.
1309  return GetProducerNodeImpl(*this, node_arg_name);
1310  }
1311 
1312  void UpdateProducerNode(const std::string& node_arg_name, NodeIndex node_index) {
1313  // Record node_index as the producer of node_arg_name, overwriting any previous entry.
1314  const auto existing = node_arg_to_producer_node_.find(node_arg_name);
1315  if (existing == node_arg_to_producer_node_.end()) {
1316  node_arg_to_producer_node_[node_arg_name] = node_index;
1317  } else {
1318  existing->second = node_index;
1319  }
1320  }
1321 
1322  std::vector<const Node*> GetConsumerNodes(const std::string& node_arg_name) const {
  // Const lookup: forwards this Graph and the name to GetConsumerNodesImpl; the vector is returned by value.
1323  return GetConsumerNodesImpl(*this, node_arg_name);
1324  }
1325 
1326  // Without removing the existing consumers, add a consumer to the given node arg name.
  // The consumer is recorded by its Index(); `consumer` is dereferenced, so it must not be null.
1327  void AddConsumerNode(const std::string& node_arg_name, Node* consumer) {
1328  node_arg_to_consumer_nodes_[node_arg_name].insert(consumer->Index());
1329  }
1330 
1331  // Remove a consumer from the set of consumers recorded for the given node arg name.
  // The consumer is identified by its Index(); `consumer` is dereferenced, so it must not be null.
  // NOTE(review): the lookup uses operator[]; if node_arg_to_consumer_nodes_ is a standard map type this
  // creates an empty entry when node_arg_name has none yet - confirm that is acceptable.
1332  void RemoveConsumerNode(const std::string& node_arg_name, Node* consumer) {
1333  node_arg_to_consumer_nodes_[node_arg_name].erase(consumer->Index());
1334  }
1335 #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1336 
1337 #if !defined(ORT_MINIMAL_BUILD)
1338  std::vector<Node*> GetMutableConsumerNodes(const std::string& node_arg_name) {
  // Mutable counterpart of GetConsumerNodes: forwards this Graph and the name to GetConsumerNodesImpl.
1339  return GetConsumerNodesImpl(*this, node_arg_name);
1340  }
1341 
1342  void UpdateConsumerNodes(const std::string& node_arg_name, gsl::span<Node* const> nodes) {
1343  // Replace the consumers recorded for the arg: whatever was stored before is discarded,
1344  // then the Index() of every supplied node is recorded.
1345  auto& consumer_indexes = node_arg_to_consumer_nodes_[node_arg_name];
1346  consumer_indexes.clear();
1347 
1348  // Size the set up front for the incoming nodes.
1349  consumer_indexes.reserve(nodes.size());
1350  std::for_each(nodes.begin(), nodes.end(), [&consumer_indexes](Node* consumer) {
1351  consumer_indexes.insert(consumer->Index());
1352  });
1353  }
1354 
1355  void UpdateConsumerNodes(const std::string& node_arg_name, std::initializer_list<Node*> nodes) {
  // Braced-list convenience overload: converts with AsSpan and forwards to the span overload.
1356  UpdateConsumerNodes(node_arg_name, AsSpan(nodes));
1357  }
1358 
1359  /** During constant folding it may become possible to infer the shape for a node.
1360  To avoid running a full Resolve allow an individual node to have the shape inferencing re-run.
1361  */
1362  Status UpdateShapeInference(Node& node);
1363 
1364  // Options to control Graph::Resolve.
1365  struct ResolveOptions {
1366  // Whether to override existing types with inferred types.
1367  bool override_types = false;
1368  // Names of initializers to keep even if unused (optional).
1369  const std::unordered_set<std::string>* initializer_names_to_preserve = nullptr;
1370  // Whether to set that no proto sync is required after resolving.
1371  // Useful for resolving right after loading from a GraphProto.
1372  bool no_proto_sync_required = false;
1373  };
1374 
1375  /**
1376  Resolve this Graph to ensure it is completely valid, fully initialized, and able to be executed.
1377  1. Run through all validation rules.
1378  a. Node name and node output's names should be unique.
1379  b. Attribute match between node and op definition.
1380  c. Input/Output match between node and op definition.
1381  d. Graph is acyclic and sort nodes in topological order.
1382  2. Check & Setup inner nodes' dependency.
1383  3. Cleanup function definition lists.
1384  Note: the weights for training can't be cleaned during resolve.
1385  @returns common::Status with success or error information.
1386  */
1387  common::Status Resolve(const ResolveOptions& options);
1388 
1389  common::Status Resolve() {  // resolve with default-constructed ResolveOptions
1390  ResolveOptions default_options;
1391  return Resolve(default_options);
1392  }
1393 
1394  const std::unordered_set<std::string>& GetOuterScopeNodeArgNames() const noexcept {
  // Returns outer_scope_node_arg_names_ (the set AddOuterScopeNodeArg inserts into) by const reference.
1395  return outer_scope_node_arg_names_;
1396  }
1397 
1398  common::Status SaveToOrtFormat(flatbuffers::FlatBufferBuilder& builder,
1399  flatbuffers::Offset<onnxruntime::fbs::Graph>& fbs_graph) const;
1400 
1401 #endif // !defined(ORT_MINIMAL_BUILD)
1402 
1403  /** Returns the Node containing the GraphProto for this Graph instance if IsSubgraph is true.
  @returns The parent_node_ pointer; IsOuterScopeValue treats a null value as "no parent node". */
1404  const Node* ParentNode() const { return parent_node_; }
1405 
1406  /** Tells whether `name` matches the name of one of the parent node's implicit input defs. */
1407  bool IsOuterScopeValue(const std::string& name) const {
1408  // Without a parent node there are no implicit inputs to compare against.
1409  if (parent_node_ == nullptr) return false;
1410  for (const NodeArg* outer_value : parent_node_->ImplicitInputDefs()) {
1411  if (outer_value->Name() == name) return true;
1412  }
1413  return false;
1414  }
1415 
1416 #if !defined(ORT_MINIMAL_BUILD)
1417  /** Construct a Graph instance for a subgraph that is created from a GraphProto attribute in a Node.
1418  Inherits some properties from the parent graph.
1419  @param parent_graph The Graph containing the Node that has the GraphProto attribute.
1420  @param parent_node The Node that has the GraphProto attribute.
1421  @param subgraph_proto The GraphProto from the Node attribute.
1422  */
1423  Graph(Graph& parent_graph, const Node& parent_node, ONNX_NAMESPACE::GraphProto& subgraph_proto);
1424 
1425  Graph(const Model& owning_model,
1426  IOnnxRuntimeOpSchemaCollectionPtr schema_registry,
1427  ONNX_NAMESPACE::GraphProto& subgraph_proto,
1428  const std::unordered_map<std::string, int>& domain_version_map,
1429  const logging::Logger& logger,
1430  bool strict_shape_type_inference);
1431 #endif
1432 
1433  virtual ~Graph();
1434 
1435  static Status LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph, const Model& owning_model,
1436  const std::unordered_map<std::string, int>& domain_to_version,
1437 #if !defined(ORT_MINIMAL_BUILD)
1438  IOnnxRuntimeOpSchemaCollectionPtr schema_registry,
1439 #endif
1440  const OrtFormatLoadOptions& load_options,
1441  const logging::Logger& logger, std::unique_ptr<Graph>& graph);
1442 
1443  // deserialize a subgraph
1444  static Status LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph,
1445  Graph& parent_graph, const Node& parent_node,
1446  const OrtFormatLoadOptions& load_options,
1447  const logging::Logger& logger, std::unique_ptr<Graph>& graph);
1448 
1449 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1450  const RuntimeOptimizationRecordContainer& RuntimeOptimizations() const {
  // Read-only access: returns the runtime_optimizations_ member by const reference.
1451  return runtime_optimizations_;
1452  }
1453 
1454  RuntimeOptimizationRecordContainer& MutableRuntimeOptimizations() {
  // Mutable access: returns the runtime_optimizations_ member by reference.
1455  return runtime_optimizations_;
1456  }
1457 
1458  // We don't run Graph::Resolve() on an ORT format model, but a compiling EP may copy initializers to its
1459  // compiled model during partitioning, leaving them unused in the ORT Graph. To allow the memory to be freed
1460  // we need to manually run the cleanup that would usually happen as part of Graph::Resolve.
1461  void CleanUnusedInitializersAndNodeArgs(const std::unordered_set<std::string>* initializer_names_to_preserve = nullptr);
1462 #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1463 
1464  // This friendship relationship should only be used to call Graph::Graph and
1465  // Graph::LoadGraph. All other access should be via the public API.
1466  friend class Model;
1467 
1468  Graph() = delete;
1469 
1470  // Create empty Graph instance to re-create from ORT format serialized data.
1471  Graph(const Model& owning_model,
1472  const std::unordered_map<std::string, int>& domain_to_version,
1473 #if !defined(ORT_MINIMAL_BUILD)
1474  IOnnxRuntimeOpSchemaCollectionPtr schema_registry,
1475 #endif
1476  Graph* parent_graph, const Node* parent_node,
1477  const logging::Logger& logger,
1478  bool strict_shape_type_inference);
1479 
1480  // Populate Graph instance from ORT format serialized data.
1481  Status LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph,
1482  const OrtFormatLoadOptions& load_options);
1483 
1484 #if !defined(ORT_MINIMAL_BUILD)
1485  // Constructor: Given a <GraphProto> loaded from model file, construct
1486  // a <Graph> object. Used by Model to create a Graph instance.
1487  Graph(const Model& owning_model,
1488  ONNX_NAMESPACE::GraphProto* graph_proto,
1489  const std::unordered_map<std::string, int>& domain_to_version,
1490  Version ir_version,
1491  IOnnxRuntimeOpSchemaCollectionPtr schema_registry,
1492  const logging::Logger& logger,
1493  bool strict_shape_type_inference);
1494 
1495  // internal use by the Graph class only
1496  Graph(const Model& owning_model,
1497  ONNX_NAMESPACE::GraphProto* graph_proto,
1498  const std::unordered_map<std::string, int>& domain_to_version,
1499  Version ir_version,
1500  IOnnxRuntimeOpSchemaCollectionPtr schema_registry,
1501  Graph* parent_graph,
1502  const Node* parent_node,
1503  const logging::Logger& logger,
1504  bool strict_shape_type_inference);
1505 
1507 
1508  private:
1509  void InitializeStateFromModelFileGraphProto();
1510 
1511  // Add node with specified <node_proto>.
1512  Node& AddNode(const ONNX_NAMESPACE::NodeProto& node_proto,
1513  const ArgNameToTypeMap& name_to_type);
1514 
1515  /** Helper that converts and adds constant node proto to an initializer in the graph.
1516  @param constant_node_proto Constant node to convert
1517  @param new_name use the new name for the initializer.
1518  */
1519  Status AddConstantProtoAsInitializer(const ONNX_NAMESPACE::NodeProto& constant_node_proto,
1520  std::optional<std::string_view> new_name);
1521 
1522 #endif
1523 
1524  Version IrVersion() const noexcept {  // Returns the IR version stored in ir_version_.
1525  return ir_version_;
1526  }
1527 
1528  Graph& GraphResolveNeeded(bool needed) noexcept {  // Sets the graph_resolve_needed_ flag to `needed`.
1529  graph_resolve_needed_ = needed;
1530  return *this;  // returning *this allows the call to be chained
1531  }
1532 
1533  Graph& GraphProtoSyncNeeded(bool needed) noexcept {  // Sets the graph_proto_sync_needed_ flag to `needed`.
1534  graph_proto_sync_needed_ = needed;
1535  return *this;  // returning *this allows the call to be chained
1536  }
1537 
1538  // During the Resolve of a Graph it is necessary to recursively descend into subgraphs (created from GraphProto
1539  // Node attributes in the Graph) if present.
1540  // The ResolveContext holds the collection of values for the current Graph instance, be it the main graph
1541  // or a subgraph, so that the various operations that are part of the Resolve can work iteratively or
1542  // recursively as needed.
1543  struct ResolveContext {
1544  explicit ResolveContext(const Graph& owning_graph) : graph{owning_graph} {  // explicit: a Graph must never convert to a ResolveContext implicitly
1545  }
1546 
1547  std::unordered_map<std::string_view, std::pair<Node*, int>> output_args;  // keyed by name; presumably (producing node, output index) - confirm against graph.cc
1548  std::unordered_set<std::string_view> inputs_and_initializers;  // string_view keys are non-owning: the named strings must outlive the entries
1549  std::unordered_map<std::string_view, NodeIndex> node_name_to_index;
1550  std::unordered_set<Node*> nodes_with_subgraphs;
1551 
1552  // check if the provided name is an input/initializer/node output of this Graph instance during Graph::Resolve.
1553  // Graph::node_args_ can have stale entries so we can't rely on that.
1554  bool IsLocalValue(const std::string& name) const;
1555 
1556  // check if an ancestor graph has a valid value with the provided name during Graph::Resolve.
1557  // Once Graph::Resolve completes Graph::IsOuterScopeValue can be used and is more efficient.
1558  bool IsOuterScopeValue(const std::string& name) const;
1559 
1560  void Clear() {  // empties all four lookup containers; the owning-graph reference is left untouched
1561  output_args.clear();
1562  inputs_and_initializers.clear();
1563  node_name_to_index.clear();
1564  nodes_with_subgraphs.clear();
1565  }
1566 
1567  private:
1568  bool IsInputInitializerOrOutput(const std::string& name, bool check_ancestors) const;
1569 
1570  const Graph& graph;  // the Graph this context was constructed for; bound once in the constructor
1571  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ResolveContext);
1572  };
1573 
1574  // Initialize all the graph inputs, initializers and outputs
1575  common::Status InitInputsInitializersOutputs();
1576 
1577  // Initialize overridable initializers container
1578  void ComputeOverridableInitializers();
1579 
1580 #if !defined(ORT_MINIMAL_BUILD)
1581  // Build and verify node connection (edges).
1582  // Verify NodeArg name/type/shape matching correctly.
1583  common::Status BuildConnections(std::unordered_set<std::string>& outer_scope_node_args_consumed);
1584 
1585  common::Status VerifyNoDuplicateName();
1586 
1587  // Check whether <*this> graph is acyclic while performing a topological sort.
1588  // Depth-first going from bottom up through the graph and checking whether there are any back edges.
1589  // NodesInTopologicalOrder is updated with the nodes' indexes in topological
1590  // order if <Status> returned is "OK", otherwise it's undefined.
1591  common::Status PerformTopologicalSortAndCheckIsAcyclic();
1592 
1593  common::Status PerformTypeAndShapeInferencing(const ResolveOptions& options);
1594 
1595  common::Status InferAndVerifyTypeMatch(Node& node, const ONNX_NAMESPACE::OpSchema& op, const ResolveOptions& options);
1596 
1597  // perform type and shape inferencing on the subgraph and Resolve to validate
1598  static common::Status InferAndVerifySubgraphTypes(const Node& node, Graph& subgraph,
1599  const std::vector<const ONNX_NAMESPACE::TypeProto*>& input_types,
1600  std::vector<const ONNX_NAMESPACE::TypeProto*>& output_types,
1601  const Graph::ResolveOptions& options);
1602 
1603  // Apply type-inference and type-checking to all inputs and initializers:
1604  common::Status TypeCheckInputsAndInitializers();
1605 
1606  // Compute the set of input and initializer names and check for duplicate names
1607  common::Status VerifyInputAndInitializerNames();
1608 
1609  // Infer and set type information across <*this> graph if needed, and verify type/attribute
1610  // information matches between node and op.
1611 
1612  common::Status VerifyNodeAndOpMatch(const ResolveOptions& options);
1613 
1614  // Set graph inputs/outputs when resolving a graph.
1615  common::Status SetGraphInputsOutputs();
1616 
1617  // recursively accumulate and set the outer scope node args in the resolve context for all subgraphs
1618  // so they can be used to resolve outer scope dependencies when running BuildConnections for the subgraphs.
1619  common::Status SetOuterScopeNodeArgs(const std::unordered_set<std::string>& outer_scope_node_args);
1620 
1621  // Implementation for initializer replacement
1622  Status ReplaceInitializedTensorImpl(ONNX_NAMESPACE::TensorProto new_initializer, bool is_external);
1623 
1624  std::vector<NodeArg*> CreateNodeArgs(const google::protobuf::RepeatedPtrField<std::string>& names,
1625  const ArgNameToTypeMap& name_to_type_map);
1626 
1627  void ToGraphProtoInternal(ONNX_NAMESPACE::GraphProto& graph_proto) const;
1628 
1629 #endif // !defined(ORT_MINIMAL_BUILD)
1630 
1631 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1632 
1633  // Recursively find all subgraphs including nested subgraphs
1634  void FindAllSubgraphs(std::vector<Graph*>& subgraphs);
1635 
1636  // Iterate this Graph instance and all subgraphs, calling the provided function for each.
1637  common::Status ForThisAndAllSubgraphs(const std::vector<Graph*>& subgraphs, std::function<Status(Graph&)> func);
1638 
1639  // Clear all unused initializers and NodeArgs
1640  void CleanUnusedInitializersAndNodeArgs(const std::unordered_set<std::string>* initializer_names_to_preserve = nullptr);
1641 
1642  Status PopulateNodeArgToProducerConsumerLookupsFromNodes();
1643 
1644  template <typename TInstance>  // TInstance must expose node_arg_to_consumer_nodes_ and GetNode(index)
1645  static auto GetConsumerNodesImpl(  // Collects the nodes recorded as consumers of node_arg_name.
1646  TInstance& instance, const std::string& node_arg_name) -> std::vector<decltype(instance.GetNode(0))> {  // element type is whatever instance.GetNode returns
1647  std::vector<decltype(instance.GetNode(0))> results;
1648  auto iter = instance.node_arg_to_consumer_nodes_.find(node_arg_name);
1649  if (iter != instance.node_arg_to_consumer_nodes_.end()) {
1650  results.reserve(iter->second.size());  // size the vector once for all recorded consumers
1651  for (auto node_index : iter->second) {
1652  results.push_back(instance.GetNode(node_index));
1653  }
1654  }
1655  return results;  // empty when node_arg_name has no entry in the lookup
1656  }
1657 
1658  template <typename TInstance>  // TInstance must expose node_arg_to_producer_node_ and GetNode(index)
1659  static auto GetProducerNodeImpl(  // Looks up the node recorded as the producer of node_arg_name.
1660  TInstance& instance, const std::string& node_arg_name) -> decltype(instance.GetNode(0)) {  // return type is whatever instance.GetNode returns
1661  auto iter = instance.node_arg_to_producer_node_.find(node_arg_name);
1662  if (iter != instance.node_arg_to_producer_node_.end()) {
1663  auto node_index = iter->second;
1664  return instance.GetNode(node_index);
1665  }
1666  return nullptr;  // no producer recorded for node_arg_name
1667  }
1668 
1669  gsl::not_null<Node*> AllocateNode();
1670 
1671  // Release the node.
1672  // @returns false if node_index was invalid.
1673  bool ReleaseNode(NodeIndex node_index);
1674 
1675  Node& CreateFusedSubGraphNode(const IndexedSubGraph& sub_graph, const std::string& fused_node_name);
1676 #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1677 
1678  Node* NodeAtIndexImpl(NodeIndex node_index) const {  // Bounds-checked lookup of nodes_[node_index].
1679  // If we are trying to access a node that doesn't exist there's (most
1680  // likely) either a logic issue or a graph consistency/correctness issue.
1681  // Use ORT_ENFORCE to prove that or uncover scenarios where we actually
1682  // expect attempts to retrieve a non-existent node.
1683  ORT_ENFORCE(node_index < nodes_.size(), "Validating no unexpected access using an invalid node_index. Got:",
1684  node_index, " Max:", nodes_.size());
1685  return nodes_[node_index].get();  // may be nullptr: entries of nodes_ can be null (see the nodes_ member comment)
1686  }
1687 
1688  const Model& owning_model_;
1689 
1690  // GraphProto to store name, version, initializer.
1691  // When serializing <*this> Graph to a GraphProto, the nodes and
1692  // functions in <Graph> will also be fed into <graph_proto_> so that
1693  // it's consistent with <*this> graph.
1694  // This pointer is owned by parent model.
1695  ONNX_NAMESPACE::GraphProto* graph_proto_;
1696 
1697  // GraphProto that provides storage for the ONNX proto types deserialized from a flexbuffer/flatbuffer
1698  ONNX_NAMESPACE::GraphProto deserialized_proto_data_;
1699 
1700  InitializedTensorSet name_to_initial_tensor_;
1701 
1702  std::unordered_set<std::reference_wrapper<const std::string>,
1703  std::hash<std::string>, std::equal_to<std::string>>
1704  sparse_tensor_names_;
1705 
1706 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1707  // Runtime optimization storage.
1708  // Note: runtime_optimizations_ == *runtime_optimizations_ptr_ and must be initialized
1709  std::unique_ptr<RuntimeOptimizationRecordContainer> runtime_optimizations_ptr_;
1710  RuntimeOptimizationRecordContainer& runtime_optimizations_;
1711 #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1712 
1713 #if !defined(ORT_MINIMAL_BUILD)
1714  IOnnxRuntimeOpSchemaCollectionPtr schema_registry_;
1715 
1716  // Currently, to make the ORT in-memory graph work, we have to create a temporary op schema
1717  // for the fused kernel. This is not ideal, but as a short-term solution we host
1718  // those schemas here.
1719  InlinedVector<std::unique_ptr<ONNX_NAMESPACE::OpSchema>> fused_schemas_containers_;
1720  // In some cases a fused sub-graph occurs multiple times in one model, so we use a map
1721  // to store reusable schemas for lookup.
1722  InlinedHashMap<std::string, std::reference_wrapper<ONNX_NAMESPACE::OpSchema>> reusable_fused_schema_map_;
1723 #endif // !defined(ORT_MINIMAL_BUILD)
1724 
1725  // Graph nodes.
1726  // Element in <nodes_> may be nullptr due to graph optimization.
1727  std::vector<std::unique_ptr<Node>> nodes_;
1728 
1729  // Wrapper of Graph nodes to provide iteration services that hide nullptr entries
1730  GraphNodes iterable_nodes_{nodes_};
1731 
1732  // Number of nodes.
1733  // Normally this is smaller than the size of <nodes_>, as some
1734  // elements in <nodes_> may be removed when doing graph optimization,
1735  // or some elements may be merged, etc.
1736  int num_of_nodes_ = 0;
1737 
1738  // A flag indicating whether <*this> graph needs to be resolved.
1739  bool graph_resolve_needed_ = false;
1740 
1741  bool graph_proto_sync_needed_ = false;
1742 
1743  // The topological order of node index used to do node and op match verification temporarily.
1744  std::vector<NodeIndex> nodes_in_topological_order_;
1745 
1746  // Full list of graph inputs. Matches number and order of inputs in the GraphProto.
1747  std::vector<const NodeArg*> graph_inputs_including_initializers_;
1748  bool graph_inputs_manually_set_ = false;
1749 
1750  // Graph inputs excluding initializers.
1751  std::vector<const NodeArg*> graph_inputs_excluding_initializers_;
1752 
1753  // Overridable Initializers. The difference between graph_inputs_including_initializers_
1754  // and graph_inputs_excluding_initializers_
1755  std::vector<const NodeArg*> graph_overridable_initializers_;
1756 
1757  // Graph outputs.
1758  std::vector<const NodeArg*> graph_outputs_;
1759  bool graph_outputs_manually_set_ = false;
1760 
1761  // Graph value_info.
1762  std::unordered_set<const NodeArg*> value_info_;
1763 
1764  // All node args owned by <*this> graph. Key is node arg name.
1765  std::unordered_map<std::string, std::unique_ptr<NodeArg>> node_args_;
1766 
1767 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1768  int name_generator_ = 0;
1769 
1770  // Strings which have been used as node names.
1771  // New node name should not conflict with this set.
1772  std::unordered_set<std::string> generated_node_names_;
1773 
1774  // Strings which have been used as node_arg names.
1775  // New node_arg name should not conflict with this set.
1776  std::unordered_set<std::string> generated_node_arg_names_;
1777 
1778  // node arg to its producer node
1779  std::unordered_map<std::string, NodeIndex> node_arg_to_producer_node_;
1780 
1781  // node arg to its consumer nodes
1782  std::unordered_map<std::string, std::unordered_set<NodeIndex>> node_arg_to_consumer_nodes_;
1783 #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1784 
1785  const std::unordered_map<std::string, int> domain_to_version_;
1786 
1787  // Model IR version.
1788  Version ir_version_{ONNX_NAMESPACE::Version::IR_VERSION};
1789 
1790  ResolveContext resolve_context_{*this};
1791 
1792  // the parent graph if this is a subgraph.
1793  Graph* parent_graph_;
1794  // the node containing the graph if parent_graph_ is not nullptr
1795  const Node* parent_node_;
1796 
1797  // NodeArgs that come from outer scope. Used when building a graph so that
1798  // these don't get recorded as graph inputs in the GraphProto.
1799  std::unordered_set<std::string> outer_scope_node_arg_names_;
1800 
1801  // number of times Resolve has run.
1802  int num_resolves_ = 0;
1803 
1804  const logging::Logger& logger_;
1805 
1806  // If true, all inconsistencies encountered during shape and type inference
1807  // will be exposed to the caller as failures. If false, in some cases
1808  // warnings will be logged but processing will continue and no error will
1809  // be returned.
1810  const bool strict_shape_type_inference_;
1811 
1812  // distinguishes between graph loaded from model file and graph created from scratch
1813  const bool is_loaded_from_model_file_;
1814 };
1815 
1816 #if !defined(ORT_MINIMAL_BUILD)
1817 // Print NodeArg as
1818 // name : type
1819 // For example,
1820 // "110": tensor(float)
1821 std::ostream& operator<<(std::ostream& out, const NodeArg& node_arg);
1822 // Print Node as,
1823 // (operator's name, operator's type, domain, version) : (input0, input1, ...) -> (output0, output1, ...)
1824 // For example,
1825 // ("Add_14", Add, "", 7) : ("110": tensor(float),"109": tensor(float),) -> ("111": tensor(float),)
1826 std::ostream& operator<<(std::ostream& out, const Node& node);
1827 // Print Graph as, for example,
1828 // Inputs:
1829 // "Input": tensor(float)
1830 // Nodes:
1831 // ("add0", Add, "", 7) : ("Input": tensor(float),"Bias": tensor(float),) -> ("add0_out": tensor(float),)
1832 // ("matmul", MatMul, "", 9) : ("add0_out": tensor(float),"matmul_weight": tensor(float),) -> ("matmul_out": tensor(float),)
1833 // ("add1", Add, "", 7) : ("matmul_out": tensor(float),"add_weight": tensor(float),) -> ("add1_out": tensor(float),)
1834 // ("reshape", Reshape, "", 5) : ("add1_out": tensor(float),"concat_out": tensor(int64),) -> ("Result": tensor(float),)
1835 // Outputs:
1836 // "Result": tensor(float)
1837 // Inputs' and outputs' format is described in document of NodeArg's operator<< above.
1838 // Node format is described in Node's operator<< above.
1839 std::ostream& operator<<(std::ostream& out, const Graph& graph);
1840 #endif
1841 
1842 } // namespace onnxruntime
constexpr auto AsSpan(C &c)
Definition: span_utils.h:42
void SetNodeArgType(NodeArg &arg, const ONNX_NAMESPACE::TypeProto &type_proto)
bool IsOuterScopeValue(const std::string &name) const
Definition: graph.h:1407
bool IsInitializedTensor(const std::string &name) const
void UpdateProducerNode(const std::string &node_arg_name, NodeIndex node_index)
Definition: graph.h:1312
std::unordered_map< std::string, const ONNX_NAMESPACE::TensorProto * > InitializedTensorSet
Definition: basic_types.h:35
The node refers to a primitive operator.
const std::string & ProviderType
Definition: basic_types.h:37
const Node * GetNode(NodeIndex node_index) const
Definition: graph.h:861
void AddAttributeProto(ONNX_NAMESPACE::AttributeProto value)
const InitializedTensorSet & GetAllInitializedTensors() const noexcept
Definition: graph.h:764
void ForEachDef(std::function< void(const onnxruntime::NodeArg &, bool is_input)> func, bool include_missing_optional_defs=false) const
IOnnxRuntimeOpSchemaCollectionPtr GetSchemaRegistry() const
ConstPointerContainer< std::vector< NodeArg * > > InputDefs() const noexcept
Definition: graph.h:207
const std::vector< int > & InputArgCount() const noexcept
Definition: graph.h:203
void SetInputs(gsl::span< const NodeArg *const > inputs)
const ONNX_NAMESPACE::GraphProto & ToGraphProto()
const Function * GetFunctionBody() const noexcept
Definition: graph.h:180
std::shared_ptr< IOnnxRuntimeOpSchemaCollection > IOnnxRuntimeOpSchemaCollectionPtr
Definition: basic_types.h:46
Node(NodeIndex index, Graph &graph)
Definition: graph.h:561
int MaxNodeIndex() const noexcept
Definition: graph.h:886
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Graph)
The node refers to a function.
const std::string & Description() const noexcept
Definition: graph.h:152
bool IsSubgraph() const
Definition: graph.h:697
void SetFunctionTemplate(const FunctionTemplate &func_template)
ONNX_NAMESPACE::GraphProto ToGraphProtoWithExternalInitializers(const std::filesystem::path &external_file_path, const std::filesystem::path &model_file_path, size_t initializer_size_threshold, const OffsetAlignmentInfo &align_info) const
NodeIndex Index() const noexcept
Definition: graph.h:127
void RemoveEdge(NodeIndex src_node_index, NodeIndex dst_node_index, int src_arg_index, int dst_arg_index)
int GetDstArgIndex() const
Definition: graph.h:118
Definition: Node.h:52
const Node * operator->() const
void UpdateConsumerNodes(const std::string &node_arg_name, std::initializer_list< Node * > nodes)
Definition: graph.h:1355
Node & FuseSubGraph(const IndexedSubGraph &sub_graph, const std::string &fused_node_name)
const RuntimeOptimizationRecordContainer & RuntimeOptimizations() const
Definition: graph.h:1450
bool NodeProducesGraphOutput(const Node &node) const
Definition: graph.h:824
GLsizei const GLfloat * value
Definition: glcorearb.h:824
const Graph * GetGraphAttribute(const std::string &attr_name) const
const NodeAttributes & GetAttributes() const noexcept
Definition: graph.h:386
GLsizei const GLchar *const * path
Definition: glcorearb.h:3341
const Node * ParentNode() const
Definition: graph.h:1404
size_t GetOutputEdgesCount() const noexcept
Definition: graph.h:340
common::Status InjectExternalInitializedTensors(const InlinedHashMap< std::string, OrtValue > &external_initializers)
RuntimeOptimizationRecordContainer & MutableRuntimeOptimizations()
Definition: graph.h:1454
static common::Status ForEachMutableWithIndex(std::vector< NodeArg * > &node_args, std::function< common::Status(NodeArg &arg, size_t index)> func)
Definition: graph.h:233
bool SetOpSchemaFromRegistryForNode(Node &node)
Status SaveToOrtFormat(flatbuffers::FlatBufferBuilder &builder, flatbuffers::Offset< onnxruntime::fbs::Node > &fbs_node) const
#define ORT_ENFORCE(condition,...)
Definition: common.h:172
bool IsSparseInitializer(const std::string &name) const
const std::unordered_map< std::string, int > & DomainToVersionMap() const noexcept
Definition: graph.h:1131
void FinalizeFuseSubGraph(const IndexedSubGraph &sub_graph, Node &fused_node)
void ToProto(ONNX_NAMESPACE::NodeProto &proto, bool update_subgraphs=false) const
FMT_CONSTEXPR auto find(Ptr first, Ptr last, T value, Ptr &out) -> bool
Definition: core.h:2138
NodeArg * GetNodeArg(const std::string &name)
Definition: graph.h:895
std::vector< NodeArg * > & MutableImplicitInputDefs() noexcept
Definition: graph.h:245
auto arg(const Char *name, const T &arg) -> detail::named_arg< Char, T >
Definition: core.h:1859
const NodeArg * GetNodeArg(const std::string &name) const
Definition: graph.h:905
NodeConstIterator OutputNodesEnd() const noexcept
Definition: graph.h:317
static Status LoadFromOrtFormat(const onnxruntime::fbs::Graph &fbs_graph, const Model &owning_model, const std::unordered_map< std::string, int > &domain_to_version, IOnnxRuntimeOpSchemaCollectionPtr schema_registry, const OrtFormatLoadOptions &load_options, const logging::Logger &logger, std::unique_ptr< Graph > &graph)
std::vector< Node * > GetMutableConsumerNodes(const std::string &node_arg_name)
Definition: graph.h:1338
std::unordered_map< std::string, gsl::not_null< const Graph * > > GetAttributeNameToSubgraphMap() const
ConstPointerContainer< std::vector< NodeArg * > > OutputDefs() const noexcept
Definition: graph.h:220
bool IsInputsIncludingInitializers(const NodeArg *node_arg) const noexcept
Definition: graph.h:800
const ONNX_NAMESPACE::OpSchema * Op() const noexcept
Definition: graph.h:172
const std::string & OpType() const noexcept
Definition: graph.h:133
const std::filesystem::path & ModelPath() const
friend class Model
Definition: graph.h:1466
void SetOutputs(gsl::span< const NodeArg *const > outputs)
void SetExecutionProviderType(ProviderType execution_provider_type)
Definition: graph.h:454
void AddConsumerNode(const std::string &node_arg_name, Node *consumer)
Definition: graph.h:1327
NodeConstIterator InputNodesBegin() const noexcept
Definition: graph.h:307
basic_string_view< char > string_view
Definition: core.h:501
ConstPointerContainer< std::vector< NodeArg * > > ImplicitInputDefs() const noexcept
Definition: graph.h:214
Node * GetNode(NodeIndex node_index)
Definition: graph.h:866
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Definitions)
const std::unordered_set< const NodeArg * > & GetValueInfo() const noexcept
Definition: graph.h:852
const std::string & Description() const noexcept
GraphNodes & Nodes() noexcept
Definition: graph.h:869
Graph & SetGraphResolveNeeded() noexcept
Definition: graph.h:1048
Status LoadEdgesFromOrtFormat(const onnxruntime::fbs::NodeEdge &fbs_node_edgs, const Graph &graph)
std::vector< gsl::not_null< const Graph * > > GetSubgraphs() const
NodeConstIterator OutputNodesBegin() const noexcept
Definition: graph.h:312
std::set< EdgeEnd, EdgeEndCompare > EdgeSet
Definition: graph.h:280
common::Status InjectExternalInitializersFromFilesInMemory(const InlinedHashMap< PathString, std::pair< char *, size_t >> &external_initializer_files)
NodeAttributes & GetMutableAttributes() noexcept
Definition: graph.h:393
bool RemoveNode(NodeIndex node_index)
Node & AddNode(const std::string &name, const std::string &op_type, const std::string &description, std::initializer_list< NodeArg * > input_args, std::initializer_list< NodeArg * > output_args, const NodeAttributes *attributes=nullptr, const std::string &domain=kOnnxDomain)
Definition: graph.h:969
ADD_ATTR_SINGLE_INTERFACE(ONNX_NAMESPACE::GraphProto)
bool TryGetFunctionProto(ONNX_NAMESPACE::FunctionProto &func_proto) const
int Priority() const noexcept
Definition: graph.h:145
int NumberOfNodes() const noexcept
Definition: graph.h:891
void UpdateConsumerNodes(const std::string &node_arg_name, gsl::span< Node *const > nodes)
Definition: graph.h:1342
static Status OK()
Definition: status.h:160
common::Status Resolve()
Definition: graph.h:1389
void AddAttribute(std::string attr_name, int64_t value)
std::string GenerateNodeName(const std::string &base_name)
void AddEdge(NodeIndex src_node_index, NodeIndex dst_node_index, int src_arg_index, int dst_arg_index)
std::vector< int > & MutableInputArgsCount()
Definition: graph.h:254
#define ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TypeName)
Definition: common.h:219
bool ClearAttribute(const std::string &attr_name)
std::vector< int > GetNodeOutputsInGraphOutputs(const Node &node) const
Definition: graph.h:835
bool AddControlEdge(NodeIndex src_node_index, NodeIndex dst_node_index)
static common::Status ForEachWithIndex(const ConstPointerContainer< std::vector< NodeArg * >> &node_args, std::function< common::Status(const NodeArg &arg, size_t index)> func)
Definition: graph.h:191
const ONNX_NAMESPACE::TensorProto * GetInitializer(const std::string &name, bool check_outer_scope) const
Graph * MutableParentGraph()
Definition: graph.h:703
void SetSinceVersion(int since_version) noexcept
Definition: graph.h:167
const std::string & Name() const noexcept
std::unordered_map< std::string, ONNX_NAMESPACE::AttributeProto > NodeAttributes
Definition: basic_types.h:44
const std::unordered_map< std::string, gsl::not_null< Graph * > > & GetAttributeNameToMutableSubgraphMap()
Definition: graph.h:433
bool GetInitializedTensor(const std::string &tensor_name, const ONNX_NAMESPACE::TensorProto *&value) const
Status RemovedUnusedInitializersOrtFormat()
void RemoveConsumerNode(const std::string &node_arg_name, Node *consumer)
Definition: graph.h:1332
size_t GetInputEdgesCount() const noexcept
Definition: graph.h:337
constexpr const char * kOnnxDomain
Definition: constants.h:14
GLuint const GLchar * name
Definition: glcorearb.h:786
std::set< std::string > control_inputs
Definition: graph.h:552
Status InlineFunction(Node &node)
ONNX_NAMESPACE::GraphProto ToGraphProtoWithExternalInitializers(const std::filesystem::path &external_file_path, const std::filesystem::path &model_file_path, size_t initializer_size_threshold) const
Definition: graph.h:1192
flatbuffers::Offset< onnxruntime::fbs::NodeEdge > SaveEdgesToOrtFormat(flatbuffers::FlatBufferBuilder &builder) const
bool GraphProtoSyncNeeded() const noexcept
Definition: graph.h:1065
Status InlineFunctionProto(const ONNX_NAMESPACE::FunctionProto &func_to_inline)
Graph & SetGraphProtoSyncNeeded() noexcept
Definition: graph.h:1059
const std::vector< const NodeArg * > & GetOutputs() const noexcept
Definition: graph.h:815
void RemoveInitializedTensor(const std::string &tensor_name)
std::vector< int > input_arg_count
Definition: graph.h:514
bool IsOutput(const NodeArg *node_arg) const noexcept
Definition: graph.h:817
common::Status SaveToOrtFormat(flatbuffers::FlatBufferBuilder &builder, flatbuffers::Offset< onnxruntime::fbs::Graph > &fbs_graph) const
NodeArg * GetNodeArgIncludingParentGraphs(const std::string &node_arg_name)
const ONNX_NAMESPACE::TensorProto * GetConstantInitializer(const std::string &name, bool check_outer_scope) const
void KahnsTopologicalSort(const std::function< void(const Node *)> &enter, const std::function< bool(const Node *, const Node *)> &comp) const
const Model & GetModel() const
Definition: graph.h:1280
void AddOuterScopeNodeArg(const std::string &name)
Definition: graph.h:1266
const std::unordered_set< std::string > * initializer_names_to_preserve
Definition: graph.h:1369
common::Status ReplaceInitializedTensor(ONNX_NAMESPACE::TensorProto new_initializer)
const Node * GetProducerNode(const std::string &node_arg_name) const
Definition: graph.h:1304
void SetOutputs(std::initializer_list< const NodeArg * > outputs)
Definition: graph.h:1294
Status UpdateShapeInference(Node &node)
bool GraphResolveNeeded() const noexcept
Definition: graph.h:1054
const std::string & Name() const noexcept
Definition: graph.h:130
GLenum func
Definition: glcorearb.h:783
std::string GenerateNodeArgName(const std::string &base_name)
std::vector< NodeArg * > implicit_input_defs
Definition: graph.h:524
Graph * GetMutableGraphAttribute(const std::string &attr_name)
ProviderType GetExecutionProviderType() const noexcept
Definition: graph.h:451
GLenum GLsizei GLsizei GLint * values
Definition: glcorearb.h:1602
void SetPriority(int priority) noexcept
void SetInputs(std::initializer_list< const NodeArg * > inputs)
Definition: graph.h:1276
bool operator==(const NodeConstIterator &p_other) const
Node * GetMutableProducerNode(const std::string &node_arg_name)
Definition: graph.h:1308
Node & AddNode(const Node &other)
const std::string & Domain() const noexcept
Definition: graph.h:138
NodeArg & GetOrCreateNodeArg(const std::string &name, const ONNX_NAMESPACE::TypeProto *p_arg_type)
Definition: graph.h:917
void SetName(const std::string &name)
GLuint index
Definition: glcorearb.h:786
const Node & GetNode() const noexcept
Definition: graph.h:110
bool StrictShapeTypeInference() const
Definition: graph.h:706
#define ORT_RETURN_IF_ERROR(expr)
Definition: common.h:233
std::vector< const Node * > GetConsumerNodes(const std::string &node_arg_name) const
Definition: graph.h:1322
ImageBuf OIIO_API max(Image_or_Const A, Image_or_Const B, ROI roi={}, int nthreads=0)
const std::unordered_set< std::string > & GetOuterScopeNodeArgNames() const noexcept
Definition: graph.h:1394
GA_API const UT_StringHolder N
bool operator()(const EdgeEnd &lhs, const EdgeEnd &rhs) const
Definition: graph.h:269
ADD_ATTR_INTERFACES(float)
NodeConstIterator InputNodesEnd() const noexcept
Definition: graph.h:309
const std::set< std::string > & ControlInputs() const noexcept
Definition: graph.h:334
const logging::Logger & GetLogger() const
Definition: graph.h:1284
std::vector< NodeArg * > & MutableOutputDefs() noexcept
Definition: graph.h:262
EdgeEnd(const Node &node, int src_arg_index, int dst_arg_index) noexcept
const std::filesystem::path & ModelPath() const noexcept
int64_t Version
Definition: basic_types.h:33
static Status LoadFromOrtFormat(const onnxruntime::fbs::Node &fbs_node, Graph &graph, const OrtFormatLoadOptions &load_options, const logging::Logger &logger, std::unique_ptr< Node > &node)
const std::vector< const NodeArg * > & GetOverridableInitializers() const
Definition: graph.h:809
EdgeConstIterator InputEdgesBegin() const noexcept
Definition: graph.h:321
const std::vector< const NodeArg * > & GetInputs() const noexcept
Definition: graph.h:790
NodeConstIterator(EdgeConstIterator p_iter)
int SinceVersion() const noexcept
Definition: graph.h:161
#define ORT_IGNORE_RETURN_VALUE(fn)
Definition: common.h:78
Node & BeginFuseSubGraph(const IndexedSubGraph &sub_graph, const std::string &fused_node_name)
std::vector< NodeArg * > output_defs
Definition: graph.h:517
std::vector< NodeArg * > input_defs
Definition: graph.h:506
std::ostream & operator<<(std::ostream &out, AllocKind alloc_kind)
void AddInitializedTensor(const ONNX_NAMESPACE::TensorProto &tensor_proto)
Node & AddNode(const std::string &name, const std::string &op_type, const std::string &description, gsl::span< NodeArg *const > input_args, std::initializer_list< NodeArg * > output_args, const NodeAttributes *attributes=nullptr, const std::string &domain=kOnnxDomain)
Definition: graph.h:982
const GraphNodes & Nodes() const noexcept
Definition: graph.h:872
friend class Graph
Definition: graph.h:560
void ReplaceDefs(const std::map< const onnxruntime::NodeArg *, onnxruntime::NodeArg * > &replacements)
EdgeConstIterator InputEdgesEnd() const noexcept
Definition: graph.h:324
EdgeConstIterator OutputEdgesBegin() const noexcept
Definition: graph.h:328
std::unordered_map< std::string, ONNX_NAMESPACE::TypeProto > ArgNameToTypeMap
Definition: basic_types.h:36
void CleanAllInitializedTensors() noexcept
void SetDescription(const std::string &description)
bool ContainsSubgraph() const
Definition: graph.h:421
ConstGraphNodes FilteredNodes(GraphNodes::NodeFilterFunc &&filter_func) const noexcept
Definition: graph.h:879
std::vector< NodeArg * > & MutableInputDefs() noexcept
Definition: graph.h:257
bool CanOverrideInitializer() const noexcept
Definition: graph.h:770
Node::Type NodeType() const noexcept
Definition: graph.h:155
Node(std::string_view name, std::string_view op_type, std::string_view description, gsl::span< NodeArg *const > input_args, gsl::span< NodeArg *const > output_args, const NodeAttributes *attributes, std::string_view domain)
Definition: graph.h:71
const std::vector< const NodeArg * > & GetInputsIncludingInitializers() const noexcept
Definition: graph.h:795
void AddValueInfo(const NodeArg *new_value_info)
bool operator!=(const NodeConstIterator &p_other) const
Status InlineIfSubgraph(bool condition_value, Node &if_node, const logging::Logger &logger)
EdgeConstIterator OutputEdgesEnd() const noexcept
Definition: graph.h:331
EdgeSet::const_iterator EdgeConstIterator
Definition: graph.h:281
int PruneRemovableAttributes(gsl::span< const std::string > removable_attributes)
size_t NodeIndex
Definition: basic_types.h:32
Node & AddNode(const std::string &name, const std::string &op_type, const std::string &description, std::initializer_list< NodeArg * > input_args, gsl::span< NodeArg *const > output_args, const NodeAttributes *attributes=nullptr, const std::string &domain=kOnnxDomain)
Definition: graph.h:995
int GetSrcArgIndex() const
Definition: graph.h:114
const Graph * ParentGraph() const
Definition: graph.h:700
std::unordered_map< std::string, gsl::not_null< Graph * > > & GetMutableMapOfAttributeNameToSubgraph()
Definition: graph.h:440
void ReverseDFSFrom(gsl::span< NodeIndex const > from, const std::function< void(const Node *)> &enter, const std::function< void(const Node *)> &leave, const std::function< bool(const Node *, const Node *)> &comp={}) const
void AddAttribute(std::string attr_name, const char(&value)[N])
Definition: graph.h:381
bool CanBeInlined() const
bool Exists() const noexcept