10 #include <unordered_map>
11 #include <unordered_set>
16 #pragma warning(disable : 4244)
19 #if !defined(ORT_MINIMAL_BUILD)
20 #include "onnx/defs/schema.h"
23 #include "onnx/defs/data_type_utils.h"
25 #include "onnx/onnx_pb.h"
26 #include "onnx/onnx-operators_pb.h"
37 #include "core/common/path.h"
44 #if !defined(ORT_MINIMAL_BUILD)
45 #include "core/graph/function_template.h"
49 #include "core/graph/ort_format_load_options.h"
51 namespace flatbuffers {
52 class FlatBufferBuilder;
57 namespace onnxruntime {
59 struct IndexedSubGraph;
63 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
64 class RuntimeOptimizationRecordContainer;
/** Compiler-generated (defaulted) default constructor.
    Being marked `explicit`, it cannot be selected for copy-list-initialization from an empty
    brace list (e.g. `Node n = {};`). */
85 explicit Node() =
default;
87 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
91 gsl::span<NodeArg* const> input_args,
92 gsl::span<NodeArg* const> output_args,
96 std::vector<NodeArg*>{input_args.begin(), input_args.end()},
97 std::vector<NodeArg*>{output_args.begin(), output_args.end()},
119 EdgeEnd(
const Node& node,
int src_arg_index,
int dst_arg_index) noexcept;
139 const int src_arg_index_;
140 const int dst_arg_index_;
186 #if !defined(ORT_MINIMAL_BUILD)
/** Returns the stored `op_` schema pointer unchanged. Declared const and noexcept.
    @return The value of `op_`. The `op_` member shown further down this listing is initialised to
            nullptr, so callers should expect that a null pointer can be returned. */
189 const ONNX_NAMESPACE::OpSchema*
Op()
const noexcept {
return op_; }
210 for (
size_t index = 0; index < node_args.size(); ++index) {
211 auto arg = node_args[index];
241 #if !defined(ORT_MINIMAL_BUILD)
252 for (
size_t index = 0; index < node_args.size(); ++index) {
253 auto arg = node_args[index];
265 #endif // !defined(ORT_MINIMAL_BUILD)
267 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
282 #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
297 using EdgeSet = std::set<EdgeEnd, EdgeEndCompare>;
371 #define ADD_ATTR_SINGLE_INTERFACE(Type) \
372 void AddAttribute(std::string attr_name, Type value)
374 #define ADD_ATTR_LIST_INTERFACE(Type) \
375 void AddAttribute(std::string attr_name, gsl::span<const Type> values)
377 #define ADD_ATTR_INTERFACES(Type) \
378 ADD_ATTR_SINGLE_INTERFACE(Type); \
379 ADD_ATTR_LIST_INTERFACE(Type)
384 #if !defined(DISABLE_SPARSE_TENSORS)
391 #undef ADD_ATTR_SINGLE_INTERFACE
392 #undef ADD_ATTR_LIST_INTERFACE
393 #undef ADD_ATTR_INTERFACES
405 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
408 #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
410 #if !defined(ORT_MINIMAL_BUILD)
425 #endif // !defined(ORT_MINIMAL_BUILD)
431 return !attr_to_subgraph_map_.empty();
436 std::vector<gsl::not_null<const Graph*>>
GetSubgraphs()
const;
443 return attr_to_subgraph_map_;
457 execution_provider_type_ = execution_provider_type;
466 bool include_missing_optional_defs =
false)
const;
468 #if !defined(ORT_MINIMAL_BUILD)
472 void ReplaceDefs(
const std::map<const onnxruntime::NodeArg*, onnxruntime::NodeArg*>& replacements);
479 void ToProto(ONNX_NAMESPACE::NodeProto& proto,
bool update_subgraphs =
false)
const;
491 const OrtFormatLoadOptions& load_options,
495 const OrtFormatLoadOptions& load_options,
568 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
572 const std::vector<NodeArg*>& input_args,
573 const std::vector<NodeArg*>& output_args,
578 Definitions& MutableDefinitions() noexcept;
581 Relationships& MutableRelationships() noexcept;
/** Overwrites `node_type_` with the supplied value. Declared noexcept.
    @param node_type The Node::Type to store. */
583 void SetNodeType(
Node::
Type node_type) noexcept { node_type_ = node_type; }
/** Returns `subgraphs_` by const reference. Declared noexcept.
    The vector itself cannot be modified through the result (no insertion, removal or reseating of
    elements), but its elements are std::unique_ptr<Graph> with a non-const pointee, so the Graph
    objects they point to can still be modified. */
589 const std::vector<std::unique_ptr<Graph>>& MutableSubgraphs() noexcept {
return subgraphs_; }
/** Read-only accessor: returns `definitions_` by const reference. Declared const and noexcept. */
594 const Definitions& GetDefinitions()
const noexcept {
return definitions_; }
/** Read-only accessor: returns `relationships_` by const reference. Declared const and noexcept. */
595 const Relationships& GetRelationships()
const noexcept {
return relationships_; }
609 #if !defined(ORT_MINIMAL_BUILD)
611 const ONNX_NAMESPACE::OpSchema* op_ =
nullptr;
614 const FunctionTemplate* func_template_ =
nullptr;
621 int since_version_ = -1;
626 std::unique_ptr<Function> func_body_ =
nullptr;
632 Definitions definitions_;
635 Relationships relationships_;
645 Graph* graph_ =
nullptr;
648 std::unordered_map<std::string, gsl::not_null<Graph*>> attr_to_subgraph_map_;
651 std::vector<std::unique_ptr<Graph>> subgraphs_;
682 #if !defined(ORT_MINIMAL_BUILD)
697 #if !defined(DISABLE_EXTERNAL_INITIALIZERS)
702 #endif // !defined(DISABLE_EXTERNAL_INITIALIZERS)
704 #endif // !defined(ORT_MINIMAL_BUILD)
706 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
717 #if !defined(DISABLE_SPARSE_TENSORS)
/** Returns `graph_inputs_excluding_initializers_` by const reference. Declared const and noexcept.
    NOTE(review): going by the member's name, this is the input list without initializers, in
    contrast to `graph_inputs_including_initializers_` - confirm against the full header. */
759 const std::vector<const NodeArg*>&
GetInputs()
const noexcept {
return graph_inputs_excluding_initializers_; }
765 return graph_inputs_including_initializers_;
770 return std::find(graph_inputs_including_initializers_.begin(),
771 graph_inputs_including_initializers_.end(), node_arg) != graph_inputs_including_initializers_.end();
779 return graph_overridable_initializers_;
/** Read-only accessor: returns `graph_outputs_` by const reference. Declared const and noexcept.
    The returned reference refers to the member itself, not a copy. */
784 const std::vector<const NodeArg*>&
GetOutputs()
const noexcept {
return graph_outputs_; }
787 return std::find(graph_outputs_.begin(), graph_outputs_.end(), node_arg) != graph_outputs_.end();
794 auto end_outputs = graph_outputs_.cend();
796 if (
std::find(graph_outputs_.cbegin(), end_outputs, output_def) != end_outputs) {
806 std::vector<int> indexes;
809 indexes.push_back(output_idx);
823 #if !defined(ORT_MINIMAL_BUILD)
865 auto iter = node_args_.find(name);
866 if (iter != node_args_.end()) {
867 return iter->second.get();
887 auto iter = node_args_.find(name);
888 if (iter != node_args_.end()) {
889 return *(iter->second);
891 auto result = node_args_.insert(std::make_pair(name, std::make_unique<NodeArg>(name, p_arg_type)));
892 return *(
result.first->second);
895 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
901 #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
903 #if !defined(ORT_MINIMAL_BUILD)
912 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
927 gsl::span<NodeArg* const> input_args,
928 gsl::span<NodeArg* const> output_args,
935 std::initializer_list<NodeArg*> input_args,
936 std::initializer_list<NodeArg*> output_args,
939 return AddNode(name, op_type, description,
948 gsl::span<NodeArg* const> input_args,
949 std::initializer_list<NodeArg*> output_args,
952 return AddNode(name, op_type, description,
961 std::initializer_list<NodeArg*> input_args,
962 gsl::span<NodeArg* const> output_args,
965 return AddNode(name, op_type, description,
1000 #if !defined(ORT_MINIMAL_BUILD)
1007 #endif // !defined(ORT_MINIMAL_BUILD)
1012 graph_resolve_needed_ =
true;
1018 return graph_resolve_needed_;
1023 graph_proto_sync_needed_ =
true;
1029 return graph_proto_sync_needed_;
1040 const std::function<
void(
const Node*)>& enter,
1041 const std::function<
void(
const Node*)>& leave,
1042 const std::function<
bool(
const Node*,
const Node*)>& comp = {})
const;
1052 const std::function<
void(
const Node*)>& enter,
1053 const std::function<
void(
const Node*)>& leave,
1054 const std::function<
bool(
const Node*,
const Node*)>& comp = {})
const;
1065 const std::function<
void(
const Node*)>& enter,
1066 const std::function<
void(
const Node*)>& leave,
1067 const std::function<
bool(
const Node*,
const Node*)>& comp,
1068 const std::function<
bool(
const Node*,
const Node*)>& stop)
const;
1070 #if !defined(ORT_MINIMAL_BUILD)
1076 const std::function<
bool(
const Node*,
const Node*)>& comp)
const;
1082 return domain_to_version_;
1085 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1101 #if !defined(ORT_MINIMAL_BUILD)
1113 size_t initializer_size_threshold)
const;
1155 void SetInputs(gsl::span<const NodeArg* const> inputs);
1157 void SetInputs(std::initializer_list<const NodeArg*> inputs) {
1162 return owning_model_;
1173 void SetOutputs(gsl::span<const NodeArg* const> outputs);
1179 #endif // !defined(ORT_MINIMAL_BUILD)
1181 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1186 return GetProducerNodeImpl(*
this, node_arg_name);
1190 return GetProducerNodeImpl(*
this, node_arg_name);
1194 auto iter = node_arg_to_producer_node_.find(node_arg_name);
1196 if (iter != node_arg_to_producer_node_.end()) {
1197 iter->second = node_index;
1199 node_arg_to_producer_node_[node_arg_name] = node_index;
1204 return GetConsumerNodesImpl(*
this, node_arg_name);
1209 node_arg_to_consumer_nodes_[node_arg_name].insert(consumer->
Index());
1214 node_arg_to_consumer_nodes_[node_arg_name].erase(consumer->
Index());
1216 #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1218 #if !defined(ORT_MINIMAL_BUILD)
1220 return GetConsumerNodesImpl(*
this, node_arg_name);
1225 auto& nodes_for_arg = node_arg_to_consumer_nodes_[node_arg_name];
1226 if (!nodes_for_arg.empty()) {
1227 nodes_for_arg.clear();
1230 nodes_for_arg.reserve(nodes.size());
1231 for (
Node* node : nodes) {
1232 nodes_for_arg.insert(node->Index());
1272 return Resolve(default_options);
1276 return outer_scope_node_arg_names_;
1282 #endif // !defined(ORT_MINIMAL_BUILD)
1289 if (!parent_node_)
return false;
1291 return std::any_of(implicit_input_defs.cbegin(), implicit_input_defs.cend(),
1292 [&name](
const NodeArg* implicit_input) {
1293 return implicit_input->Name() == name;
1297 #if !defined(ORT_MINIMAL_BUILD)
1304 Graph(
Graph& parent_graph,
const Node& parent_node, ONNX_NAMESPACE::GraphProto& subgraph_proto);
1308 ONNX_NAMESPACE::GraphProto& subgraph_proto,
1309 const std::unordered_map<std::string, int>& domain_version_map,
1311 bool strict_shape_type_inference);
1317 const std::unordered_map<std::string, int>& domain_to_version,
1318 #
if !defined(ORT_MINIMAL_BUILD)
1321 const OrtFormatLoadOptions& load_options,
1326 Graph& parent_graph,
const Node& parent_node,
1327 const OrtFormatLoadOptions& load_options,
1330 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1332 return runtime_optimizations_;
1336 return runtime_optimizations_;
1338 #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1348 const std::unordered_map<std::string, int>& domain_to_version,
1349 #
if !defined(ORT_MINIMAL_BUILD)
1352 Graph* parent_graph,
const Node* parent_node,
1354 bool strict_shape_type_inference);
1358 const OrtFormatLoadOptions& load_options);
1360 #if !defined(ORT_MINIMAL_BUILD)
1364 ONNX_NAMESPACE::GraphProto* graph_proto,
1365 const std::unordered_map<std::string, int>& domain_to_version,
1369 bool strict_shape_type_inference);
1373 ONNX_NAMESPACE::GraphProto* graph_proto,
1374 const std::unordered_map<std::string, int>& domain_to_version,
1377 Graph* parent_graph,
1378 const Node* parent_node,
1380 bool strict_shape_type_inference);
1385 void InitializeStateFromModelFileGraphProto();
1388 Node&
AddNode(
const ONNX_NAMESPACE::NodeProto& node_proto,
1398 graph_resolve_needed_ = needed;
1403 graph_proto_sync_needed_ = needed;
1412 struct ResolveContext {
1413 ResolveContext(
const Graph& owning_graph) : graph{owning_graph} {
1416 std::unordered_map<std::string_view, std::pair<Node*, int>> output_args;
1417 std::unordered_set<std::string_view> inputs_and_initializers;
1418 std::unordered_map<std::string_view, NodeIndex> node_name_to_index;
1419 std::unordered_set<Node*> nodes_with_subgraphs;
1430 output_args.clear();
1431 inputs_and_initializers.clear();
1432 node_name_to_index.clear();
1433 nodes_with_subgraphs.clear();
1437 bool IsInputInitializerOrOutput(
const std::string&
name,
bool check_ancestors)
const;
1447 void ComputeOverridableInitializers();
1449 #if !defined(ORT_MINIMAL_BUILD)
1452 common::Status BuildConnections(std::unordered_set<std::string>& outer_scope_node_args_consumed);
1462 common::Status PerformTypeAndShapeInferencing(
const ResolveOptions& options);
1465 void FindAllSubgraphs(std::vector<Graph*>& subgraphs);
1470 common::Status InferAndVerifyTypeMatch(
Node& node,
const ONNX_NAMESPACE::OpSchema& op,
const ResolveOptions& options);
1474 const std::vector<const ONNX_NAMESPACE::TypeProto*>& input_types,
1475 std::vector<const ONNX_NAMESPACE::TypeProto*>& output_types,
1476 const Graph::ResolveOptions& options);
1487 common::Status VerifyNodeAndOpMatch(
const ResolveOptions& options);
1494 common::Status SetOuterScopeNodeArgs(
const std::unordered_set<std::string>& outer_scope_node_args);
1497 Status ReplaceInitializedTensorImpl(ONNX_NAMESPACE::TensorProto new_initializer,
bool is_external);
1500 void CleanUnusedInitializersAndNodeArgs(
const std::unordered_set<std::string>* initializer_names_to_preserve =
nullptr);
1502 std::vector<NodeArg*> CreateNodeArgs(
const google::protobuf::RepeatedPtrField<std::string>& names,
1505 void ToGraphProtoInternal(ONNX_NAMESPACE::GraphProto& graph_proto)
const;
1507 #endif // !defined(ORT_MINIMAL_BUILD)
1509 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1510 Status PopulateNodeArgToProducerConsumerLookupsFromNodes();
1512 template <
typename TInstance>
1513 static auto GetConsumerNodesImpl(
1514 TInstance& instance,
const std::string& node_arg_name) -> std::vector<decltype(instance.GetNode(0))> {
1515 std::vector<decltype(instance.GetNode(0))> results;
1516 auto iter = instance.node_arg_to_consumer_nodes_.find(node_arg_name);
1517 if (iter != instance.node_arg_to_consumer_nodes_.end()) {
1518 results.reserve(iter->second.size());
1519 for (
auto node_index : iter->second) {
1520 results.push_back(instance.GetNode(node_index));
1526 template <
typename TInstance>
1527 static auto GetProducerNodeImpl(
1528 TInstance& instance,
const std::string& node_arg_name) -> decltype(instance.GetNode(0)) {
1529 auto iter = instance.node_arg_to_producer_node_.find(node_arg_name);
1530 if (iter != instance.node_arg_to_producer_node_.end()) {
1531 auto node_index = iter->second;
1532 return instance.GetNode(node_index);
1537 gsl::not_null<Node*> AllocateNode();
1543 Node& CreateFusedSubGraphNode(
const IndexedSubGraph& sub_graph,
const std::string& fused_node_name);
1544 #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1551 ORT_ENFORCE(node_index < nodes_.size(),
"Validating no unexpected access using an invalid node_index. Got:",
1552 node_index,
" Max:", nodes_.size());
1553 return nodes_[node_index].get();
1556 const Model& owning_model_;
1563 ONNX_NAMESPACE::GraphProto* graph_proto_;
1566 ONNX_NAMESPACE::GraphProto deserialized_proto_data_;
1570 std::unordered_set<std::reference_wrapper<const std::string>,
1571 std::hash<std::string>, std::equal_to<std::string>>
1572 sparse_tensor_names_;
1574 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1577 std::unique_ptr<RuntimeOptimizationRecordContainer> runtime_optimizations_ptr_;
1578 RuntimeOptimizationRecordContainer& runtime_optimizations_;
1579 #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1581 #if !defined(ORT_MINIMAL_BUILD)
1587 InlinedVector<std::unique_ptr<ONNX_NAMESPACE::OpSchema>> fused_schemas_containers_;
1590 InlinedHashMap<std::string, std::reference_wrapper<ONNX_NAMESPACE::OpSchema>> reusable_fused_schema_map_;
1591 #endif // !defined(ORT_MINIMAL_BUILD)
1595 std::vector<std::unique_ptr<Node>> nodes_;
1598 GraphNodes iterable_nodes_{nodes_};
1604 int num_of_nodes_ = 0;
1607 bool graph_resolve_needed_ =
false;
1609 bool graph_proto_sync_needed_ =
false;
1612 std::vector<NodeIndex> nodes_in_topological_order_;
1615 std::vector<const NodeArg*> graph_inputs_including_initializers_;
1616 bool graph_inputs_manually_set_ =
false;
1619 std::vector<const NodeArg*> graph_inputs_excluding_initializers_;
1623 std::vector<const NodeArg*> graph_overridable_initializers_;
1626 std::vector<const NodeArg*> graph_outputs_;
1627 bool graph_outputs_manually_set_ =
false;
1630 std::unordered_set<const NodeArg*> value_info_;
1633 std::unordered_map<std::string, std::unique_ptr<NodeArg>> node_args_;
1635 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1636 int name_generator_ = 0;
1640 std::unordered_set<std::string> generated_node_names_;
1644 std::unordered_set<std::string> generated_node_arg_names_;
1647 std::unordered_map<std::string, NodeIndex> node_arg_to_producer_node_;
1650 std::unordered_map<std::string, std::unordered_set<NodeIndex>> node_arg_to_consumer_nodes_;
1651 #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1653 const std::unordered_map<std::string, int> domain_to_version_;
1656 Version ir_version_{ONNX_NAMESPACE::Version::IR_VERSION};
1658 ResolveContext resolve_context_{*
this};
1661 Graph* parent_graph_;
1663 const Node* parent_node_;
1667 std::unordered_set<std::string> outer_scope_node_arg_names_;
1670 int num_resolves_ = 0;
1672 const logging::Logger& logger_;
1678 const bool strict_shape_type_inference_;
1681 const bool is_loaded_from_model_file_;
1684 #if !defined(ORT_MINIMAL_BUILD)
1689 std::ostream&
operator<<(std::ostream& out,
const NodeArg& node_arg);
1707 std::ostream&
operator<<(std::ostream& out,
const Graph& graph);
constexpr auto AsSpan(C &c)
void SetNodeArgType(NodeArg &arg, const ONNX_NAMESPACE::TypeProto &type_proto)
bool IsOuterScopeValue(const std::string &name) const
bool IsInitializedTensor(const std::string &name) const
void UpdateProducerNode(const std::string &node_arg_name, NodeIndex node_index)
std::unordered_map< std::string, const ONNX_NAMESPACE::TensorProto * > InitializedTensorSet
The node refers to a primitive operator.
const std::string & ProviderType
const Node * GetNode(NodeIndex node_index) const
void AddAttributeProto(ONNX_NAMESPACE::AttributeProto value)
const InitializedTensorSet & GetAllInitializedTensors() const noexcept
void ForEachDef(std::function< void(const onnxruntime::NodeArg &, bool is_input)> func, bool include_missing_optional_defs=false) const
IOnnxRuntimeOpSchemaCollectionPtr GetSchemaRegistry() const
ConstPointerContainer< std::vector< NodeArg * > > InputDefs() const noexcept
const std::vector< int > & InputArgCount() const noexcept
void SetInputs(gsl::span< const NodeArg *const > inputs)
const ONNX_NAMESPACE::GraphProto & ToGraphProto()
const Function * GetFunctionBody() const noexcept
std::shared_ptr< IOnnxRuntimeOpSchemaCollection > IOnnxRuntimeOpSchemaCollectionPtr
Node(NodeIndex index, Graph &graph)
int MaxNodeIndex() const noexcept
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Graph)
The node refers to a function.
const std::string & Description() const noexcept
void SetFunctionTemplate(const FunctionTemplate &func_template)
NodeIndex Index() const noexcept
void RemoveEdge(NodeIndex src_node_index, NodeIndex dst_node_index, int src_arg_index, int dst_arg_index)
int GetDstArgIndex() const
const Node * operator->() const
void UpdateConsumerNodes(const std::string &node_arg_name, std::initializer_list< Node * > nodes)
Node & FuseSubGraph(const IndexedSubGraph &sub_graph, const std::string &fused_node_name)
const RuntimeOptimizationRecordContainer & RuntimeOptimizations() const
bool NodeProducesGraphOutput(const Node &node) const
GLsizei const GLchar *const * string
const Graph * GetGraphAttribute(const std::string &attr_name) const
const NodeAttributes & GetAttributes() const noexcept
const Node * ParentNode() const
size_t GetOutputEdgesCount() const noexcept
common::Status InjectExternalInitializedTensors(const InlinedHashMap< std::string, OrtValue > &external_initializers)
RuntimeOptimizationRecordContainer & MutableRuntimeOptimizations()
static common::Status ForEachMutableWithIndex(std::vector< NodeArg * > &node_args, std::function< common::Status(NodeArg &arg, size_t index)> func)
bool SetOpSchemaFromRegistryForNode(Node &node)
Status SaveToOrtFormat(flatbuffers::FlatBufferBuilder &builder, flatbuffers::Offset< onnxruntime::fbs::Node > &fbs_node) const
#define ORT_ENFORCE(condition,...)
bool IsSparseInitializer(const std::string &name) const
const std::unordered_map< std::string, int > & DomainToVersionMap() const noexcept
void FinalizeFuseSubGraph(const IndexedSubGraph &sub_graph, Node &fused_node)
void ToProto(ONNX_NAMESPACE::NodeProto &proto, bool update_subgraphs=false) const
**But if you need a result
NodeArg * GetNodeArg(const std::string &name)
std::vector< NodeArg * > & MutableImplicitInputDefs() noexcept
auto arg(const Char *name, const T &arg) -> detail::named_arg< Char, T >
const NodeArg * GetNodeArg(const std::string &name) const
NodeConstIterator OutputNodesEnd() const noexcept
static Status LoadFromOrtFormat(const onnxruntime::fbs::Graph &fbs_graph, const Model &owning_model, const std::unordered_map< std::string, int > &domain_to_version, IOnnxRuntimeOpSchemaCollectionPtr schema_registry, const OrtFormatLoadOptions &load_options, const logging::Logger &logger, std::unique_ptr< Graph > &graph)
std::vector< Node * > GetMutableConsumerNodes(const std::string &node_arg_name)
std::unordered_map< std::string, gsl::not_null< const Graph * > > GetAttributeNameToSubgraphMap() const
ConstPointerContainer< std::vector< NodeArg * > > OutputDefs() const noexcept
bool IsInputsIncludingInitializers(const NodeArg *node_arg) const noexcept
const ONNX_NAMESPACE::OpSchema * Op() const noexcept
const std::string & OpType() const noexcept
void SetOutputs(gsl::span< const NodeArg *const > outputs)
void SetExecutionProviderType(ProviderType execution_provider_type)
void AddConsumerNode(const std::string &node_arg_name, Node *consumer)
NodeConstIterator InputNodesBegin() const noexcept
basic_string_view< char > string_view
ConstPointerContainer< std::vector< NodeArg * > > ImplicitInputDefs() const noexcept
Node * GetNode(NodeIndex node_index)
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Definitions)
const std::unordered_set< const NodeArg * > & GetValueInfo() const noexcept
const std::string & Description() const noexcept
GraphNodes & Nodes() noexcept
Graph & SetGraphResolveNeeded() noexcept
Status LoadEdgesFromOrtFormat(const onnxruntime::fbs::NodeEdge &fbs_node_edgs, const Graph &graph)
std::vector< gsl::not_null< const Graph * > > GetSubgraphs() const
NodeConstIterator OutputNodesBegin() const noexcept
std::set< EdgeEnd, EdgeEndCompare > EdgeSet
NodeAttributes & GetMutableAttributes() noexcept
bool RemoveNode(NodeIndex node_index)
Node & AddNode(const std::string &name, const std::string &op_type, const std::string &description, std::initializer_list< NodeArg * > input_args, std::initializer_list< NodeArg * > output_args, const NodeAttributes *attributes=nullptr, const std::string &domain=kOnnxDomain)
ADD_ATTR_SINGLE_INTERFACE(ONNX_NAMESPACE::GraphProto)
bool TryGetFunctionProto(ONNX_NAMESPACE::FunctionProto &func_proto) const
int Priority() const noexcept
int NumberOfNodes() const noexcept
void UpdateConsumerNodes(const std::string &node_arg_name, gsl::span< Node *const > nodes)
void AddAttribute(std::string attr_name, int64_t value)
std::string GenerateNodeName(const std::string &base_name)
const Node & operator*() const
void AddEdge(NodeIndex src_node_index, NodeIndex dst_node_index, int src_arg_index, int dst_arg_index)
std::vector< int > & MutableInputArgsCount()
#define ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TypeName)
bool ClearAttribute(const std::string &attr_name)
std::vector< int > GetNodeOutputsInGraphOutputs(const Node &node) const
bool AddControlEdge(NodeIndex src_node_index, NodeIndex dst_node_index)
static common::Status ForEachWithIndex(const ConstPointerContainer< std::vector< NodeArg * >> &node_args, std::function< common::Status(const NodeArg &arg, size_t index)> func)
const ONNX_NAMESPACE::TensorProto * GetInitializer(const std::string &name, bool check_outer_scope) const
Graph * MutableParentGraph()
void SetSinceVersion(int since_version) noexcept
const std::string & Name() const noexcept
std::unordered_map< std::string, ONNX_NAMESPACE::AttributeProto > NodeAttributes
const std::unordered_map< std::string, gsl::not_null< Graph * > > & GetAttributeNameToMutableSubgraphMap()
bool GetInitializedTensor(const std::string &tensor_name, const ONNX_NAMESPACE::TensorProto *&value) const
const Path & ModelPath() const noexcept
void RemoveConsumerNode(const std::string &node_arg_name, Node *consumer)
size_t GetInputEdgesCount() const noexcept
constexpr const char * kOnnxDomain
GLuint const GLchar * name
std::set< std::string > control_inputs
bool no_proto_sync_required
std::function< bool(NodeIndex)> NodeFilterFunc
Status InlineFunction(Node &node)
flatbuffers::Offset< onnxruntime::fbs::NodeEdge > SaveEdgesToOrtFormat(flatbuffers::FlatBufferBuilder &builder) const
bool GraphProtoSyncNeeded() const noexcept
Graph & SetGraphProtoSyncNeeded() noexcept
const std::vector< const NodeArg * > & GetOutputs() const noexcept
void RemoveInitializedTensor(const std::string &tensor_name)
std::vector< int > input_arg_count
bool IsOutput(const NodeArg *node_arg) const noexcept
common::Status SaveToOrtFormat(flatbuffers::FlatBufferBuilder &builder, flatbuffers::Offset< onnxruntime::fbs::Graph > &fbs_graph) const
NodeArg * GetNodeArgIncludingParentGraphs(const std::string &node_arg_name)
const ONNX_NAMESPACE::TensorProto * GetConstantInitializer(const std::string &name, bool check_outer_scope) const
void KahnsTopologicalSort(const std::function< void(const Node *)> &enter, const std::function< bool(const Node *, const Node *)> &comp) const
const Model & GetModel() const
void AddOuterScopeNodeArg(const std::string &name)
const std::unordered_set< std::string > * initializer_names_to_preserve
common::Status ReplaceInitializedTensor(ONNX_NAMESPACE::TensorProto new_initializer)
const Node * GetProducerNode(const std::string &node_arg_name) const
void SetOutputs(std::initializer_list< const NodeArg * > outputs)
Status UpdateShapeInference(Node &node)
bool GraphResolveNeeded() const noexcept
const std::string & Name() const noexcept
std::string GenerateNodeArgName(const std::string &base_name)
std::vector< NodeArg * > implicit_input_defs
Graph * GetMutableGraphAttribute(const std::string &attr_name)
ProviderType GetExecutionProviderType() const noexcept
GLenum GLsizei GLsizei GLint * values
void SetPriority(int priority) noexcept
void SetInputs(std::initializer_list< const NodeArg * > inputs)
bool operator==(const NodeConstIterator &p_other) const
Node * GetMutableProducerNode(const std::string &node_arg_name)
const Path & ModelPath() const
Node & AddNode(const Node &other)
const std::string & Domain() const noexcept
NodeArg & GetOrCreateNodeArg(const std::string &name, const ONNX_NAMESPACE::TypeProto *p_arg_type)
void SetName(const std::string &name)
const Node & GetNode() const noexcept
bool StrictShapeTypeInference() const
#define ORT_RETURN_IF_ERROR(expr)
std::vector< const Node * > GetConsumerNodes(const std::string &node_arg_name) const
ImageBuf OIIO_API max(Image_or_Const A, Image_or_Const B, ROI roi={}, int nthreads=0)
const std::unordered_set< std::string > & GetOuterScopeNodeArgNames() const noexcept
GA_API const UT_StringHolder N
bool operator()(const EdgeEnd &lhs, const EdgeEnd &rhs) const
ADD_ATTR_INTERFACES(float)
NodeConstIterator InputNodesEnd() const noexcept
const std::set< std::string > & ControlInputs() const noexcept
const logging::Logger & GetLogger() const
std::vector< NodeArg * > & MutableOutputDefs() noexcept
EdgeEnd(const Node &node, int src_arg_index, int dst_arg_index) noexcept
static Status LoadFromOrtFormat(const onnxruntime::fbs::Node &fbs_node, Graph &graph, const OrtFormatLoadOptions &load_options, const logging::Logger &logger, std::unique_ptr< Node > &node)
const std::vector< const NodeArg * > & GetOverridableInitializers() const
EdgeConstIterator InputEdgesBegin() const noexcept
const std::vector< const NodeArg * > & GetInputs() const noexcept
NodeConstIterator(EdgeConstIterator p_iter)
int SinceVersion() const noexcept
#define ORT_IGNORE_RETURN_VALUE(fn)
Node & BeginFuseSubGraph(const IndexedSubGraph &sub_graph, const std::string &fused_node_name)
std::vector< NodeArg * > output_defs
std::vector< NodeArg * > input_defs
std::ostream & operator<<(std::ostream &out, AllocKind alloc_kind)
void AddInitializedTensor(const ONNX_NAMESPACE::TensorProto &tensor_proto)
Node & AddNode(const std::string &name, const std::string &op_type, const std::string &description, gsl::span< NodeArg *const > input_args, std::initializer_list< NodeArg * > output_args, const NodeAttributes *attributes=nullptr, const std::string &domain=kOnnxDomain)
const GraphNodes & Nodes() const noexcept
void ReplaceDefs(const std::map< const onnxruntime::NodeArg *, onnxruntime::NodeArg * > &replacements)
EdgeConstIterator InputEdgesEnd() const noexcept
EdgeConstIterator OutputEdgesBegin() const noexcept
std::unordered_map< std::string, ONNX_NAMESPACE::TypeProto > ArgNameToTypeMap
void CleanAllInitializedTensors() noexcept
void SetDescription(const std::string &description)
bool ContainsSubgraph() const
ConstGraphNodes FilteredNodes(GraphNodes::NodeFilterFunc &&filter_func) const noexcept
std::vector< NodeArg * > & MutableInputDefs() noexcept
bool CanOverrideInitializer() const noexcept
ONNX_NAMESPACE::GraphProto ToGraphProtoWithExternalInitializers(const std::string &external_file_name, size_t initializer_size_threshold) const
Node::Type NodeType() const noexcept
Node(std::string_view name, std::string_view op_type, std::string_view description, gsl::span< NodeArg *const > input_args, gsl::span< NodeArg *const > output_args, const NodeAttributes *attributes, std::string_view domain)
const std::vector< const NodeArg * > & GetInputsIncludingInitializers() const noexcept
void AddValueInfo(const NodeArg *new_value_info)
bool operator!=(const NodeConstIterator &p_other) const
EdgeConstIterator OutputEdgesEnd() const noexcept
EdgeSet::const_iterator EdgeConstIterator
Node & AddNode(const std::string &name, const std::string &op_type, const std::string &description, std::initializer_list< NodeArg * > input_args, gsl::span< NodeArg *const > output_args, const NodeAttributes *attributes=nullptr, const std::string &domain=kOnnxDomain)
int GetSrcArgIndex() const
const Graph * ParentGraph() const
FMT_CONSTEXPR auto find(Ptr first, Ptr last, T value, Ptr &out) -> bool
void ReverseDFSFrom(gsl::span< NodeIndex const > from, const std::function< void(const Node *)> &enter, const std::function< void(const Node *)> &leave, const std::function< bool(const Node *, const Node *)> &comp={}) const
void AddAttribute(std::string attr_name, const char(&value)[N])
bool CanBeInlined() const
bool Exists() const noexcept