HDK
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
graph_transformer.h
Go to the documentation of this file.
1 // Copyright (c) Microsoft Corporation. All rights reserved.
2 // Licensed under the MIT License.
3 
4 #pragma once
5 #include <string>
6 
7 #include "core/common/common.h"
11 
12 namespace onnxruntime {
13 
14 /**
15 @class GraphTransformer
16 
17 The interface for in-place transformation of a Graph.
18 */
19 class GraphTransformer {
20  public:
21  GraphTransformer(const std::string& name,
22  const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept
23  : name_(name), compatible_provider_types_(compatible_execution_providers) {
24  }
25 
26  virtual ~GraphTransformer() = default;
27 
28  /** Gets the name of this graph transformer. */
29  const std::string& Name() const noexcept {
30  return name_;
31  }
32 
33  const InlinedHashSet<std::string_view>& GetCompatibleExecutionProviders() const noexcept {
34  return compatible_provider_types_;
35  }
36 
37  /** Apply the in-place transformation defined by this transformer to the provided Graph instance.
38  @param[out] modified Set to true if the Graph was modified.
39  @returns Status with success or error information.
40  */
41  common::Status Apply(Graph& graph, bool& modified, const logging::Logger& logger) const;
42 
43  virtual bool ShouldOnlyApplyOnce() const { return false; }
44 
45  protected:
46  /** Calls ApplyImpl on each subgraph of @p node at level graph_level + 1, propagating the first failing Status via ORT_RETURN_IF_ERROR. */
47  common::Status Recurse(Node& node, bool& modified, int graph_level, const logging::Logger& logger) const {
48  const int subgraph_level = graph_level + 1;
49  for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) {
50  auto& subgraph = *entry.second;
51  ORT_RETURN_IF_ERROR(ApplyImpl(subgraph, modified, subgraph_level, logger));
52  }
53 
54  return Status::OK();
55  }
56 
57  private:
58  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GraphTransformer);
59 
60  // Apply the transform to the graph.
61  // graph_level is 0 for the main graph, and is incremented when descending into the subgraph of a node.
62  // You MUST call Recurse for all valid Nodes in the graph to ensure any subgraphs in control flow nodes
63  // (Scan/If/Loop) are processed as well.
64  // You should avoid calling Graph::Resolve in ApplyImpl unless you are 100% sure it's required. In most cases
65  // the call to Graph::Resolve in Apply prior to ApplyImpl being called, and after ApplyImpl for the main graph
66  // completes (if 'modified' is true) should suffice.
67  virtual common::Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger)
68  const = 0;
69 
70  const std::string name_;
71  const InlinedHashSet<std::string_view> compatible_provider_types_;
72 };
73 } // namespace onnxruntime
common::Status Apply(Graph &graph, bool &modified, const logging::Logger &logger) const
GLsizei const GLchar *const * string
Definition: glcorearb.h:814
common::Status Recurse(Node &node, bool &modified, int graph_level, const logging::Logger &logger) const
virtual bool ShouldOnlyApplyOnce() const
virtual ~GraphTransformer()=default
const InlinedHashSet< std::string_view > & GetCompatibleExecutionProviders() const noexcept
const std::unordered_map< std::string, gsl::not_null< Graph * > > & GetAttributeNameToMutableSubgraphMap()
Definition: graph.h:442
GLuint const GLchar * name
Definition: glcorearb.h:786
#define ORT_RETURN_IF_ERROR(expr)
Definition: common.h:234
#define const
Definition: zconf.h:214
const std::string & Name() const noexcept
GraphTransformer(const std::string &name, const InlinedHashSet< std::string_view > &compatible_execution_providers={}) noexcept