10 #include <type_traits>
11 #include <unordered_map>
12 #include <unordered_set>
15 #include "core/common/flatbuffers.h"
20 #include "core/common/path_string.h"
22 #if !defined(ORT_MINIMAL_BUILD)
29 #include "core/graph/onnx_protobuf.h"
33 #if !defined(ORT_MINIMAL_BUILD)
34 #include "core/graph/function_template.h"
38 #include "core/graph/ort_format_load_options.h"
40 namespace onnxruntime {
42 struct IndexedSubGraph;
46 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
47 class RuntimeOptimizationRecordContainer;
68 explicit Node() =
default;
70 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
74 gsl::span<NodeArg* const> input_args,
75 gsl::span<NodeArg* const> output_args,
78 Init(name, op_type, description,
102 EdgeEnd(
const Node& node,
int src_arg_index,
int dst_arg_index) noexcept;
122 const int src_arg_index_;
123 const int dst_arg_index_;
130 const std::string&
Name() const noexcept {
return name_; }
133 const std::string&
OpType() const noexcept {
return op_type_; }
138 const std::string&
Domain() const noexcept {
return domain_; }
145 int Priority() const noexcept {
return priority_; };
152 const std::string&
Description() const noexcept {
return description_; }
169 #if !defined(ORT_MINIMAL_BUILD)
172 const ONNX_NAMESPACE::OpSchema*
Op() const noexcept {
return op_; }
193 for (
size_t index = 0; index < node_args.size(); ++index) {
194 auto arg = node_args[index];
224 #if !defined(ORT_MINIMAL_BUILD)
235 for (
size_t index = 0; index < node_args.size(); ++index) {
236 auto arg = node_args[index];
248 #endif // !defined(ORT_MINIMAL_BUILD)
250 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
265 #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
280 using EdgeSet = std::set<EdgeEnd, EdgeEndCompare>;
354 #define ADD_ATTR_SINGLE_INTERFACE(Type) \
355 void AddAttribute(std::string attr_name, Type value)
357 #define ADD_ATTR_LIST_INTERFACE(Type) \
358 void AddAttribute(std::string attr_name, gsl::span<const Type> values)
360 #define ADD_ATTR_INTERFACES(Type) \
361 ADD_ATTR_SINGLE_INTERFACE(Type); \
362 ADD_ATTR_LIST_INTERFACE(Type)
367 #if !defined(DISABLE_SPARSE_TENSORS)
374 #undef ADD_ATTR_SINGLE_INTERFACE
375 #undef ADD_ATTR_LIST_INTERFACE
376 #undef ADD_ATTR_INTERFACES
388 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
395 #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
403 #if !defined(ORT_MINIMAL_BUILD)
416 #endif // !defined(ORT_MINIMAL_BUILD)
422 return !attr_to_subgraph_map_.empty();
427 std::vector<gsl::not_null<const Graph*>>
GetSubgraphs()
const;
434 return attr_to_subgraph_map_;
441 return attr_to_subgraph_map_;
455 execution_provider_type_ = execution_provider_type;
464 bool include_missing_optional_defs =
false)
const;
466 #if !defined(ORT_MINIMAL_BUILD)
470 void ReplaceDefs(
const std::map<const onnxruntime::NodeArg*, onnxruntime::NodeArg*>& replacements);
477 void ToProto(ONNX_NAMESPACE::NodeProto& proto,
bool update_subgraphs =
false)
const;
480 flatbuffers::Offset<onnxruntime::fbs::Node>& fbs_node)
const;
482 flatbuffers::Offset<onnxruntime::fbs::NodeEdge>
489 const OrtFormatLoadOptions& load_options,
493 const OrtFormatLoadOptions& load_options,
566 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
570 gsl::span<NodeArg* const> input_args,
571 gsl::span<NodeArg* const> output_args,
577 gsl::span<NodeArg* const> input_args,
578 gsl::span<NodeArg* const> output_args,
583 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
585 Definitions& MutableDefinitions() noexcept;
588 Relationships& MutableRelationships() noexcept;
590 void SetNodeType(
Node::
Type node_type) noexcept { node_type_ = node_type; }
594 void CreateSubgraph(
const std::string& attr_name);
596 std::vector<std::unique_ptr<Graph>>& MutableSubgraphs() noexcept {
return subgraphs_; }
601 const Definitions& GetDefinitions() const noexcept {
return definitions_; }
602 const Relationships& GetRelationships() const noexcept {
return relationships_; }
611 std::string op_type_;
616 #if !defined(ORT_MINIMAL_BUILD)
618 const ONNX_NAMESPACE::OpSchema* op_ =
nullptr;
621 const FunctionTemplate* func_template_ =
nullptr;
626 void SetOriginalNodeProto(
const ONNX_NAMESPACE::NodeProto* node_proto) {
627 original_node_proto_ = node_proto;
630 const ONNX_NAMESPACE::NodeProto* GetOriginalNodeProto()
const {
631 return original_node_proto_;
637 const ONNX_NAMESPACE::NodeProto* original_node_proto_ =
nullptr;
644 int since_version_ = -1;
649 std::unique_ptr<Function> func_body_ =
nullptr;
652 std::string description_;
655 Definitions definitions_;
658 Relationships relationships_;
661 std::string execution_provider_type_;
668 Graph* graph_ =
nullptr;
671 std::unordered_map<std::string, gsl::not_null<Graph*>> attr_to_subgraph_map_;
674 std::vector<std::unique_ptr<Graph>> subgraphs_;
688 const std::string&
Name()
const noexcept;
708 #if !defined(ORT_MINIMAL_BUILD)
723 #if !defined(DISABLE_EXTERNAL_INITIALIZERS)
733 const InlinedHashMap<PathString, std::pair<char*, size_t>>& external_initializer_files);
734 #endif // !defined(DISABLE_EXTERNAL_INITIALIZERS)
736 #endif // !defined(ORT_MINIMAL_BUILD)
738 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
749 #if !defined(DISABLE_SPARSE_TENSORS)
785 const ONNX_NAMESPACE::TensorProto*
GetInitializer(
const std::string&
name,
bool check_outer_scope)
const;
790 const std::vector<const NodeArg*>&
GetInputs() const noexcept {
return graph_inputs_excluding_initializers_; }
796 return graph_inputs_including_initializers_;
801 return std::find(graph_inputs_including_initializers_.begin(),
802 graph_inputs_including_initializers_.end(), node_arg) != graph_inputs_including_initializers_.end();
810 return graph_overridable_initializers_;
815 const std::vector<const NodeArg*>&
GetOutputs() const noexcept {
return graph_outputs_; }
818 return std::find(graph_outputs_.begin(), graph_outputs_.end(), node_arg) != graph_outputs_.end();
825 auto end_outputs = graph_outputs_.cend();
827 if (
std::find(graph_outputs_.cbegin(), end_outputs, output_def) != end_outputs) {
837 std::vector<int> indexes;
840 indexes.push_back(output_idx);
852 const std::unordered_set<const NodeArg*>&
GetValueInfo() const noexcept {
return value_info_; }
854 #if !defined(ORT_MINIMAL_BUILD)
886 int MaxNodeIndex() const noexcept {
return static_cast<int>(nodes_.size()); }
896 auto iter = node_args_.find(name);
897 if (iter != node_args_.end()) {
898 return iter->second.get();
918 auto insert_result = node_args_.emplace(name,
nullptr);
919 if (insert_result.second) {
920 insert_result.first->second = std::make_unique<NodeArg>(name, p_arg_type);
922 return *(insert_result.first->second);
925 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
931 #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
933 #if !defined(ORT_MINIMAL_BUILD)
942 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
955 const std::string& op_type,
956 const std::string& description,
957 gsl::span<NodeArg* const> input_args,
958 gsl::span<NodeArg* const> output_args,
963 const std::string& op_type,
964 const std::string& description,
965 gsl::span<NodeArg* const> input_args,
966 gsl::span<NodeArg* const> output_args,
970 const std::string& op_type,
971 const std::string& description,
972 std::initializer_list<NodeArg*> input_args,
973 std::initializer_list<NodeArg*> output_args,
976 return AddNode(name, op_type, description,
983 const std::string& op_type,
984 const std::string& description,
985 gsl::span<NodeArg* const> input_args,
986 std::initializer_list<NodeArg*> output_args,
989 return AddNode(name, op_type, description,
996 const std::string& op_type,
997 const std::string& description,
998 std::initializer_list<NodeArg*> input_args,
999 gsl::span<NodeArg* const> output_args,
1002 return AddNode(name, op_type, description,
1005 attributes, domain);
1037 #if !defined(ORT_MINIMAL_BUILD)
1044 #endif // !defined(ORT_MINIMAL_BUILD)
1049 graph_resolve_needed_ =
true;
1055 return graph_resolve_needed_;
1060 graph_proto_sync_needed_ =
true;
1066 return graph_proto_sync_needed_;
1077 const std::function<
void(
const Node*)>& enter,
1078 const std::function<
void(
const Node*)>& leave,
1079 const std::function<
bool(
const Node*,
const Node*)>& comp = {})
const;
1089 const std::function<
void(
const Node*)>& enter,
1090 const std::function<
void(
const Node*)>& leave,
1091 const std::function<
bool(
const Node*,
const Node*)>& comp = {})
const;
1102 const std::function<
void(
const Node*)>& enter,
1103 const std::function<
void(
const Node*)>& leave,
1104 const std::function<
bool(
const Node*,
const Node*)>& comp,
1105 const std::function<
bool(
const Node*,
const Node*)>& stop)
const;
1107 #if !defined(ORT_MINIMAL_BUILD)
1113 const std::function<
bool(
const Node*,
const Node*)>& comp)
const;
1117 #ifdef ENABLE_TRAINING
1125 void MemoryEfficientTopologicalSort(
const Node* yield_op,
1126 const InlinedHashMap<
NodeIndex, InlinedVector<NodeIndex>>& shape_size_parents,
1127 std::vector<NodeIndex>& node_orders)
const;
1132 return domain_to_version_;
1135 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1151 #if !defined(ORT_MINIMAL_BUILD)
1189 size_t initializer_size_threshold,
1194 size_t initializer_size_threshold)
const {
1274 void SetInputs(gsl::span<const NodeArg* const> inputs);
1276 void SetInputs(std::initializer_list<const NodeArg*> inputs) {
1281 return owning_model_;
1292 void SetOutputs(gsl::span<const NodeArg* const> outputs);
1298 #endif // !defined(ORT_MINIMAL_BUILD)
1300 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1305 return GetProducerNodeImpl(*
this, node_arg_name);
1309 return GetProducerNodeImpl(*
this, node_arg_name);
1313 auto iter = node_arg_to_producer_node_.find(node_arg_name);
1315 if (iter != node_arg_to_producer_node_.end()) {
1316 iter->second = node_index;
1318 node_arg_to_producer_node_[node_arg_name] = node_index;
1323 return GetConsumerNodesImpl(*
this, node_arg_name);
1328 node_arg_to_consumer_nodes_[node_arg_name].insert(consumer->
Index());
1333 node_arg_to_consumer_nodes_[node_arg_name].erase(consumer->
Index());
1335 #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1337 #if !defined(ORT_MINIMAL_BUILD)
1339 return GetConsumerNodesImpl(*
this, node_arg_name);
1344 auto& nodes_for_arg = node_arg_to_consumer_nodes_[node_arg_name];
1345 if (!nodes_for_arg.empty()) {
1346 nodes_for_arg.clear();
1349 nodes_for_arg.reserve(nodes.size());
1350 for (
Node* node : nodes) {
1351 nodes_for_arg.insert(node->Index());
1391 return Resolve(default_options);
1395 return outer_scope_node_arg_names_;
1399 flatbuffers::Offset<onnxruntime::fbs::Graph>& fbs_graph)
const;
1401 #endif // !defined(ORT_MINIMAL_BUILD)
1408 if (!parent_node_)
return false;
1410 return std::any_of(implicit_input_defs.cbegin(), implicit_input_defs.cend(),
1411 [&name](
const NodeArg* implicit_input) {
1412 return implicit_input->Name() == name;
1416 #if !defined(ORT_MINIMAL_BUILD)
1423 Graph(
Graph& parent_graph,
const Node& parent_node, ONNX_NAMESPACE::GraphProto& subgraph_proto);
1427 ONNX_NAMESPACE::GraphProto& subgraph_proto,
1428 const std::unordered_map<std::string, int>& domain_version_map,
1430 bool strict_shape_type_inference);
1436 const std::unordered_map<std::string, int>& domain_to_version,
1437 #
if !defined(ORT_MINIMAL_BUILD)
1440 const OrtFormatLoadOptions& load_options,
1445 Graph& parent_graph,
const Node& parent_node,
1446 const OrtFormatLoadOptions& load_options,
1449 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1451 return runtime_optimizations_;
1455 return runtime_optimizations_;
1462 #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1472 const std::unordered_map<std::string, int>& domain_to_version,
1473 #
if !defined(ORT_MINIMAL_BUILD)
1476 Graph* parent_graph,
const Node* parent_node,
1478 bool strict_shape_type_inference);
1482 const OrtFormatLoadOptions& load_options);
1484 #if !defined(ORT_MINIMAL_BUILD)
1488 ONNX_NAMESPACE::GraphProto* graph_proto,
1489 const std::unordered_map<std::string, int>& domain_to_version,
1493 bool strict_shape_type_inference);
1497 ONNX_NAMESPACE::GraphProto* graph_proto,
1498 const std::unordered_map<std::string, int>& domain_to_version,
1501 Graph* parent_graph,
1502 const Node* parent_node,
1504 bool strict_shape_type_inference);
1509 void InitializeStateFromModelFileGraphProto();
1512 Node&
AddNode(
const ONNX_NAMESPACE::NodeProto& node_proto,
1519 Status AddConstantProtoAsInitializer(
const ONNX_NAMESPACE::NodeProto& constant_node_proto,
1520 std::optional<std::string_view> new_name);
1524 Version IrVersion() const noexcept {
1529 graph_resolve_needed_ = needed;
1534 graph_proto_sync_needed_ = needed;
1543 struct ResolveContext {
1544 ResolveContext(
const Graph& owning_graph) : graph{owning_graph} {
1547 std::unordered_map<std::string_view, std::pair<Node*, int>> output_args;
1548 std::unordered_set<std::string_view> inputs_and_initializers;
1549 std::unordered_map<std::string_view, NodeIndex> node_name_to_index;
1550 std::unordered_set<Node*> nodes_with_subgraphs;
1554 bool IsLocalValue(
const std::string&
name)
const;
1561 output_args.clear();
1562 inputs_and_initializers.clear();
1563 node_name_to_index.clear();
1564 nodes_with_subgraphs.clear();
1568 bool IsInputInitializerOrOutput(
const std::string&
name,
bool check_ancestors)
const;
1578 void ComputeOverridableInitializers();
1580 #if !defined(ORT_MINIMAL_BUILD)
1583 common::Status BuildConnections(std::unordered_set<std::string>& outer_scope_node_args_consumed);
1593 common::Status PerformTypeAndShapeInferencing(
const ResolveOptions& options);
1595 common::Status InferAndVerifyTypeMatch(
Node& node,
const ONNX_NAMESPACE::OpSchema&
op,
const ResolveOptions& options);
1599 const std::vector<const ONNX_NAMESPACE::TypeProto*>& input_types,
1600 std::vector<const ONNX_NAMESPACE::TypeProto*>& output_types,
1601 const Graph::ResolveOptions& options);
1612 common::Status VerifyNodeAndOpMatch(
const ResolveOptions& options);
1619 common::Status SetOuterScopeNodeArgs(
const std::unordered_set<std::string>& outer_scope_node_args);
1622 Status ReplaceInitializedTensorImpl(ONNX_NAMESPACE::TensorProto new_initializer,
bool is_external);
1624 std::vector<NodeArg*> CreateNodeArgs(
const google::protobuf::RepeatedPtrField<std::string>& names,
1627 void ToGraphProtoInternal(ONNX_NAMESPACE::GraphProto& graph_proto)
const;
1629 #endif // !defined(ORT_MINIMAL_BUILD)
1631 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1634 void FindAllSubgraphs(std::vector<Graph*>& subgraphs);
1640 void CleanUnusedInitializersAndNodeArgs(
const std::unordered_set<std::string>* initializer_names_to_preserve =
nullptr);
1642 Status PopulateNodeArgToProducerConsumerLookupsFromNodes();
1644 template <
typename TInstance>
1645 static auto GetConsumerNodesImpl(
1646 TInstance& instance,
const std::string& node_arg_name) -> std::vector<decltype(instance.GetNode(0))> {
1647 std::vector<decltype(instance.GetNode(0))> results;
1648 auto iter = instance.node_arg_to_consumer_nodes_.find(node_arg_name);
1649 if (iter != instance.node_arg_to_consumer_nodes_.end()) {
1650 results.reserve(iter->second.size());
1651 for (
auto node_index : iter->second) {
1652 results.push_back(instance.GetNode(node_index));
1658 template <
typename TInstance>
1659 static auto GetProducerNodeImpl(
1660 TInstance& instance,
const std::string& node_arg_name) -> decltype(instance.GetNode(0)) {
1661 auto iter = instance.node_arg_to_producer_node_.find(node_arg_name);
1662 if (iter != instance.node_arg_to_producer_node_.end()) {
1663 auto node_index = iter->second;
1664 return instance.GetNode(node_index);
1669 gsl::not_null<Node*> AllocateNode();
1675 Node& CreateFusedSubGraphNode(
const IndexedSubGraph& sub_graph,
const std::string& fused_node_name);
1676 #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1683 ORT_ENFORCE(node_index < nodes_.size(),
"Validating no unexpected access using an invalid node_index. Got:",
1684 node_index,
" Max:", nodes_.size());
1685 return nodes_[node_index].get();
1688 const Model& owning_model_;
1695 ONNX_NAMESPACE::GraphProto* graph_proto_;
1698 ONNX_NAMESPACE::GraphProto deserialized_proto_data_;
1702 std::unordered_set<std::reference_wrapper<const std::string>,
1703 std::hash<std::string>, std::equal_to<std::string>>
1704 sparse_tensor_names_;
1706 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1709 std::unique_ptr<RuntimeOptimizationRecordContainer> runtime_optimizations_ptr_;
1710 RuntimeOptimizationRecordContainer& runtime_optimizations_;
1711 #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1713 #if !defined(ORT_MINIMAL_BUILD)
1719 InlinedVector<std::unique_ptr<ONNX_NAMESPACE::OpSchema>> fused_schemas_containers_;
1722 InlinedHashMap<std::string, std::reference_wrapper<ONNX_NAMESPACE::OpSchema>> reusable_fused_schema_map_;
1723 #endif // !defined(ORT_MINIMAL_BUILD)
1727 std::vector<std::unique_ptr<Node>> nodes_;
1730 GraphNodes iterable_nodes_{nodes_};
1736 int num_of_nodes_ = 0;
1739 bool graph_resolve_needed_ =
false;
1741 bool graph_proto_sync_needed_ =
false;
1744 std::vector<NodeIndex> nodes_in_topological_order_;
1747 std::vector<const NodeArg*> graph_inputs_including_initializers_;
1748 bool graph_inputs_manually_set_ =
false;
1751 std::vector<const NodeArg*> graph_inputs_excluding_initializers_;
1755 std::vector<const NodeArg*> graph_overridable_initializers_;
1758 std::vector<const NodeArg*> graph_outputs_;
1759 bool graph_outputs_manually_set_ =
false;
1762 std::unordered_set<const NodeArg*> value_info_;
1765 std::unordered_map<std::string, std::unique_ptr<NodeArg>> node_args_;
1767 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1768 int name_generator_ = 0;
1772 std::unordered_set<std::string> generated_node_names_;
1776 std::unordered_set<std::string> generated_node_arg_names_;
1779 std::unordered_map<std::string, NodeIndex> node_arg_to_producer_node_;
1782 std::unordered_map<std::string, std::unordered_set<NodeIndex>> node_arg_to_consumer_nodes_;
1783 #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1785 const std::unordered_map<std::string, int> domain_to_version_;
1788 Version ir_version_{ONNX_NAMESPACE::Version::IR_VERSION};
1790 ResolveContext resolve_context_{*
this};
1793 Graph* parent_graph_;
1795 const Node* parent_node_;
1799 std::unordered_set<std::string> outer_scope_node_arg_names_;
1802 int num_resolves_ = 0;
1804 const logging::Logger& logger_;
1810 const bool strict_shape_type_inference_;
1813 const bool is_loaded_from_model_file_;
1816 #if !defined(ORT_MINIMAL_BUILD)
1821 std::ostream&
operator<<(std::ostream& out,
const NodeArg& node_arg);
1839 std::ostream&
operator<<(std::ostream& out,
const Graph& graph);
constexpr auto AsSpan(C &c)
void SetNodeArgType(NodeArg &arg, const ONNX_NAMESPACE::TypeProto &type_proto)
bool IsOuterScopeValue(const std::string &name) const
bool IsInitializedTensor(const std::string &name) const
void UpdateProducerNode(const std::string &node_arg_name, NodeIndex node_index)
std::unordered_map< std::string, const ONNX_NAMESPACE::TensorProto * > InitializedTensorSet
The node refers to a primitive operator.
const std::string & ProviderType
const Node * GetNode(NodeIndex node_index) const
void AddAttributeProto(ONNX_NAMESPACE::AttributeProto value)
const InitializedTensorSet & GetAllInitializedTensors() const noexcept
void ForEachDef(std::function< void(const onnxruntime::NodeArg &, bool is_input)> func, bool include_missing_optional_defs=false) const
IOnnxRuntimeOpSchemaCollectionPtr GetSchemaRegistry() const
ConstPointerContainer< std::vector< NodeArg * > > InputDefs() const noexcept
const std::vector< int > & InputArgCount() const noexcept
void SetInputs(gsl::span< const NodeArg *const > inputs)
const ONNX_NAMESPACE::GraphProto & ToGraphProto()
const Function * GetFunctionBody() const noexcept
std::shared_ptr< IOnnxRuntimeOpSchemaCollection > IOnnxRuntimeOpSchemaCollectionPtr
Node(NodeIndex index, Graph &graph)
int MaxNodeIndex() const noexcept
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Graph)
The node refers to a function.
const std::string & Description() const noexcept
void SetFunctionTemplate(const FunctionTemplate &func_template)
ONNX_NAMESPACE::GraphProto ToGraphProtoWithExternalInitializers(const std::filesystem::path &external_file_path, const std::filesystem::path &model_file_path, size_t initializer_size_threshold, const OffsetAlignmentInfo &align_info) const
NodeIndex Index() const noexcept
void RemoveEdge(NodeIndex src_node_index, NodeIndex dst_node_index, int src_arg_index, int dst_arg_index)
int GetDstArgIndex() const
const Node * operator->() const
void UpdateConsumerNodes(const std::string &node_arg_name, std::initializer_list< Node * > nodes)
Node & FuseSubGraph(const IndexedSubGraph &sub_graph, const std::string &fused_node_name)
const RuntimeOptimizationRecordContainer & RuntimeOptimizations() const
bool NodeProducesGraphOutput(const Node &node) const
GLsizei const GLfloat * value
const Graph * GetGraphAttribute(const std::string &attr_name) const
const NodeAttributes & GetAttributes() const noexcept
GLsizei const GLchar *const * path
const Node * ParentNode() const
size_t GetOutputEdgesCount() const noexcept
common::Status InjectExternalInitializedTensors(const InlinedHashMap< std::string, OrtValue > &external_initializers)
RuntimeOptimizationRecordContainer & MutableRuntimeOptimizations()
static common::Status ForEachMutableWithIndex(std::vector< NodeArg * > &node_args, std::function< common::Status(NodeArg &arg, size_t index)> func)
bool SetOpSchemaFromRegistryForNode(Node &node)
Status SaveToOrtFormat(flatbuffers::FlatBufferBuilder &builder, flatbuffers::Offset< onnxruntime::fbs::Node > &fbs_node) const
#define ORT_ENFORCE(condition,...)
bool IsSparseInitializer(const std::string &name) const
const std::unordered_map< std::string, int > & DomainToVersionMap() const noexcept
void FinalizeFuseSubGraph(const IndexedSubGraph &sub_graph, Node &fused_node)
void ToProto(ONNX_NAMESPACE::NodeProto &proto, bool update_subgraphs=false) const
FMT_CONSTEXPR auto find(Ptr first, Ptr last, T value, Ptr &out) -> bool
NodeArg * GetNodeArg(const std::string &name)
std::vector< NodeArg * > & MutableImplicitInputDefs() noexcept
auto arg(const Char *name, const T &arg) -> detail::named_arg< Char, T >
const NodeArg * GetNodeArg(const std::string &name) const
NodeConstIterator OutputNodesEnd() const noexcept
static Status LoadFromOrtFormat(const onnxruntime::fbs::Graph &fbs_graph, const Model &owning_model, const std::unordered_map< std::string, int > &domain_to_version, IOnnxRuntimeOpSchemaCollectionPtr schema_registry, const OrtFormatLoadOptions &load_options, const logging::Logger &logger, std::unique_ptr< Graph > &graph)
std::vector< Node * > GetMutableConsumerNodes(const std::string &node_arg_name)
std::unordered_map< std::string, gsl::not_null< const Graph * > > GetAttributeNameToSubgraphMap() const
ConstPointerContainer< std::vector< NodeArg * > > OutputDefs() const noexcept
bool IsInputsIncludingInitializers(const NodeArg *node_arg) const noexcept
const ONNX_NAMESPACE::OpSchema * Op() const noexcept
const std::string & OpType() const noexcept
const std::filesystem::path & ModelPath() const
void SetOutputs(gsl::span< const NodeArg *const > outputs)
void SetExecutionProviderType(ProviderType execution_provider_type)
void AddConsumerNode(const std::string &node_arg_name, Node *consumer)
NodeConstIterator InputNodesBegin() const noexcept
basic_string_view< char > string_view
ConstPointerContainer< std::vector< NodeArg * > > ImplicitInputDefs() const noexcept
Node * GetNode(NodeIndex node_index)
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Definitions)
const std::unordered_set< const NodeArg * > & GetValueInfo() const noexcept
const std::string & Description() const noexcept
GraphNodes & Nodes() noexcept
Graph & SetGraphResolveNeeded() noexcept
Status LoadEdgesFromOrtFormat(const onnxruntime::fbs::NodeEdge &fbs_node_edgs, const Graph &graph)
std::vector< gsl::not_null< const Graph * > > GetSubgraphs() const
NodeConstIterator OutputNodesBegin() const noexcept
int64_t allocation_granularity
std::set< EdgeEnd, EdgeEndCompare > EdgeSet
common::Status InjectExternalInitializersFromFilesInMemory(const InlinedHashMap< PathString, std::pair< char *, size_t >> &external_initializer_files)
NodeAttributes & GetMutableAttributes() noexcept
bool RemoveNode(NodeIndex node_index)
Node & AddNode(const std::string &name, const std::string &op_type, const std::string &description, std::initializer_list< NodeArg * > input_args, std::initializer_list< NodeArg * > output_args, const NodeAttributes *attributes=nullptr, const std::string &domain=kOnnxDomain)
ADD_ATTR_SINGLE_INTERFACE(ONNX_NAMESPACE::GraphProto)
bool TryGetFunctionProto(ONNX_NAMESPACE::FunctionProto &func_proto) const
int Priority() const noexcept
int NumberOfNodes() const noexcept
void UpdateConsumerNodes(const std::string &node_arg_name, gsl::span< Node *const > nodes)
void AddAttribute(std::string attr_name, int64_t value)
std::string GenerateNodeName(const std::string &base_name)
const Node & operator*() const
void AddEdge(NodeIndex src_node_index, NodeIndex dst_node_index, int src_arg_index, int dst_arg_index)
std::vector< int > & MutableInputArgsCount()
#define ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TypeName)
bool ClearAttribute(const std::string &attr_name)
std::vector< int > GetNodeOutputsInGraphOutputs(const Node &node) const
bool AddControlEdge(NodeIndex src_node_index, NodeIndex dst_node_index)
static common::Status ForEachWithIndex(const ConstPointerContainer< std::vector< NodeArg * >> &node_args, std::function< common::Status(const NodeArg &arg, size_t index)> func)
const ONNX_NAMESPACE::TensorProto * GetInitializer(const std::string &name, bool check_outer_scope) const
Graph * MutableParentGraph()
void SetSinceVersion(int since_version) noexcept
const std::string & Name() const noexcept
std::unordered_map< std::string, ONNX_NAMESPACE::AttributeProto > NodeAttributes
const std::unordered_map< std::string, gsl::not_null< Graph * > > & GetAttributeNameToMutableSubgraphMap()
bool GetInitializedTensor(const std::string &tensor_name, const ONNX_NAMESPACE::TensorProto *&value) const
Status RemovedUnusedInitializersOrtFormat()
void RemoveConsumerNode(const std::string &node_arg_name, Node *consumer)
size_t GetInputEdgesCount() const noexcept
constexpr const char * kOnnxDomain
GLuint const GLchar * name
std::set< std::string > control_inputs
bool no_proto_sync_required
std::function< bool(NodeIndex)> NodeFilterFunc
Status InlineFunction(Node &node)
ONNX_NAMESPACE::GraphProto ToGraphProtoWithExternalInitializers(const std::filesystem::path &external_file_path, const std::filesystem::path &model_file_path, size_t initializer_size_threshold) const
flatbuffers::Offset< onnxruntime::fbs::NodeEdge > SaveEdgesToOrtFormat(flatbuffers::FlatBufferBuilder &builder) const
bool GraphProtoSyncNeeded() const noexcept
Status InlineFunctionProto(const ONNX_NAMESPACE::FunctionProto &func_to_inline)
Graph & SetGraphProtoSyncNeeded() noexcept
const std::vector< const NodeArg * > & GetOutputs() const noexcept
void RemoveInitializedTensor(const std::string &tensor_name)
std::vector< int > input_arg_count
bool IsOutput(const NodeArg *node_arg) const noexcept
common::Status SaveToOrtFormat(flatbuffers::FlatBufferBuilder &builder, flatbuffers::Offset< onnxruntime::fbs::Graph > &fbs_graph) const
NodeArg * GetNodeArgIncludingParentGraphs(const std::string &node_arg_name)
const ONNX_NAMESPACE::TensorProto * GetConstantInitializer(const std::string &name, bool check_outer_scope) const
void KahnsTopologicalSort(const std::function< void(const Node *)> &enter, const std::function< bool(const Node *, const Node *)> &comp) const
const Model & GetModel() const
void AddOuterScopeNodeArg(const std::string &name)
const std::unordered_set< std::string > * initializer_names_to_preserve
common::Status ReplaceInitializedTensor(ONNX_NAMESPACE::TensorProto new_initializer)
const Node * GetProducerNode(const std::string &node_arg_name) const
void SetOutputs(std::initializer_list< const NodeArg * > outputs)
Status UpdateShapeInference(Node &node)
bool GraphResolveNeeded() const noexcept
const std::string & Name() const noexcept
std::string GenerateNodeArgName(const std::string &base_name)
std::vector< NodeArg * > implicit_input_defs
Graph * GetMutableGraphAttribute(const std::string &attr_name)
ProviderType GetExecutionProviderType() const noexcept
GLenum GLsizei GLsizei GLint * values
void SetPriority(int priority) noexcept
void SetInputs(std::initializer_list< const NodeArg * > inputs)
bool operator==(const NodeConstIterator &p_other) const
Node * GetMutableProducerNode(const std::string &node_arg_name)
Node & AddNode(const Node &other)
const std::string & Domain() const noexcept
NodeArg & GetOrCreateNodeArg(const std::string &name, const ONNX_NAMESPACE::TypeProto *p_arg_type)
void SetName(const std::string &name)
const Node & GetNode() const noexcept
bool StrictShapeTypeInference() const
#define ORT_RETURN_IF_ERROR(expr)
std::vector< const Node * > GetConsumerNodes(const std::string &node_arg_name) const
ImageBuf OIIO_API max(Image_or_Const A, Image_or_Const B, ROI roi={}, int nthreads=0)
const std::unordered_set< std::string > & GetOuterScopeNodeArgNames() const noexcept
GA_API const UT_StringHolder N
bool operator()(const EdgeEnd &lhs, const EdgeEnd &rhs) const
ADD_ATTR_INTERFACES(float)
NodeConstIterator InputNodesEnd() const noexcept
const std::set< std::string > & ControlInputs() const noexcept
const logging::Logger & GetLogger() const
std::vector< NodeArg * > & MutableOutputDefs() noexcept
EdgeEnd(const Node &node, int src_arg_index, int dst_arg_index) noexcept
const std::filesystem::path & ModelPath() const noexcept
static Status LoadFromOrtFormat(const onnxruntime::fbs::Node &fbs_node, Graph &graph, const OrtFormatLoadOptions &load_options, const logging::Logger &logger, std::unique_ptr< Node > &node)
const std::vector< const NodeArg * > & GetOverridableInitializers() const
EdgeConstIterator InputEdgesBegin() const noexcept
const std::vector< const NodeArg * > & GetInputs() const noexcept
NodeConstIterator(EdgeConstIterator p_iter)
int SinceVersion() const noexcept
#define ORT_IGNORE_RETURN_VALUE(fn)
Node & BeginFuseSubGraph(const IndexedSubGraph &sub_graph, const std::string &fused_node_name)
std::vector< NodeArg * > output_defs
std::vector< NodeArg * > input_defs
std::ostream & operator<<(std::ostream &out, AllocKind alloc_kind)
void AddInitializedTensor(const ONNX_NAMESPACE::TensorProto &tensor_proto)
Node & AddNode(const std::string &name, const std::string &op_type, const std::string &description, gsl::span< NodeArg *const > input_args, std::initializer_list< NodeArg * > output_args, const NodeAttributes *attributes=nullptr, const std::string &domain=kOnnxDomain)
const GraphNodes & Nodes() const noexcept
void ReplaceDefs(const std::map< const onnxruntime::NodeArg *, onnxruntime::NodeArg * > &replacements)
EdgeConstIterator InputEdgesEnd() const noexcept
EdgeConstIterator OutputEdgesBegin() const noexcept
std::unordered_map< std::string, ONNX_NAMESPACE::TypeProto > ArgNameToTypeMap
void CleanAllInitializedTensors() noexcept
void SetDescription(const std::string &description)
bool ContainsSubgraph() const
ConstGraphNodes FilteredNodes(GraphNodes::NodeFilterFunc &&filter_func) const noexcept
std::vector< NodeArg * > & MutableInputDefs() noexcept
bool CanOverrideInitializer() const noexcept
Node::Type NodeType() const noexcept
Node(std::string_view name, std::string_view op_type, std::string_view description, gsl::span< NodeArg *const > input_args, gsl::span< NodeArg *const > output_args, const NodeAttributes *attributes, std::string_view domain)
const std::vector< const NodeArg * > & GetInputsIncludingInitializers() const noexcept
void AddValueInfo(const NodeArg *new_value_info)
bool operator!=(const NodeConstIterator &p_other) const
Status InlineIfSubgraph(bool condition_value, Node &if_node, const logging::Logger &logger)
EdgeConstIterator OutputEdgesEnd() const noexcept
EdgeSet::const_iterator EdgeConstIterator
int PruneRemovableAttributes(gsl::span< const std::string > removable_attributes)
Node & AddNode(const std::string &name, const std::string &op_type, const std::string &description, std::initializer_list< NodeArg * > input_args, gsl::span< NodeArg *const > output_args, const NodeAttributes *attributes=nullptr, const std::string &domain=kOnnxDomain)
int GetSrcArgIndex() const
const Graph * ParentGraph() const
std::unordered_map< std::string, gsl::not_null< Graph * > > & GetMutableMapOfAttributeNameToSubgraph()
void ReverseDFSFrom(gsl::span< NodeIndex const > from, const std::function< void(const Node *)> &enter, const std::function< void(const Node *)> &leave, const std::function< bool(const Node *, const Node *)> &comp={}) const
void AddAttribute(std::string attr_name, const char(&value)[N])
bool CanBeInlined() const
bool Exists() const noexcept