HDK
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
rule_based_graph_transformer.h
Go to the documentation of this file.
1 // Copyright (c) Microsoft Corporation. All rights reserved.
2 // Licensed under the MIT License.
3 
4 #pragma once
5 
6 #include "core/common/common.h"
10 
11 namespace onnxruntime {
12 
13 /**
14 @class RuleBasedGraphTransformer
15 
16 Rule-based graph transformer that provides an API to register rewrite rules
17 and an API to apply all applicable rules to a Graph.
18 
19 Represents a GraphTransformer determined by a set of rewrite rules.
20 The transformer will apply all the rewrite rules iteratively as determined by the underlying rewriting strategy.
21 Several rewriting strategies are possible when traversing the graph and applying rewrite rules,
22 each with different trade-offs. At the moment, we define one that performs a top-down traversal of nodes.
23 
24 @TODO: Is a bottom-up traversal more efficient?
25 @TODO: Is it worth adding the max number of passes a rule should be applied for?
26 @TODO: We need to define a contract about whether a rewrite rule is allowed to leave
27  the graph in an inconsistent state (this will determine when and where we will be
28  calling Graph::Resolve()).
29 */
31  public:
33  const InlinedHashSet<std::string_view>& compatible_execution_providers = {})
34  : GraphTransformer(name, compatible_execution_providers) {}
35 
36  /** Registers a rewrite rule in this transformer. */
37  Status Register(std::unique_ptr<RewriteRule> rule);
38 
39  /** Gets the list of registered rewrite rules that will be triggered on nodes with the given op type
40  by this rule-based transformer.
41  @returns a pointer to the vector of rewrite rules registered for the given op type, or nullptr if no rule is registered for it. */
43  auto rules = op_type_to_rules_.find(op_type);
44  return (rules != op_type_to_rules_.cend()) ? &rules->second : nullptr;
45  }
46 
47  /** Gets the rewrite rules that are evaluated on all nodes irrespective of their op type.
48  @returns a pointer to the vector containing all such rewrite rules; never nullptr (the returned vector may be empty). */
50  return &any_op_type_rules_;
51  }
52 
53  /** Returns the total number of rules that are registered in this transformer. */
54  size_t RulesCount() const;
55 
56  protected:
57  /** Applies the given set of rewrite rules on the Node of this Graph.
58  @param[in] graph The Graph.
59  @param[in] node The Node to apply the rules to.
60  @param[in] rules The vector of RewriteRules that will be applied to the Node.
61  @param[out] rule_effect Enum that indicates whether and how the graph was modified as a result of
62  applying rules on this node.
63  @returns Status indicating success or providing error information. */
65  gsl::span<const std::reference_wrapper<const RewriteRule>> rules,
66  RewriteRule::RewriteRuleEffect& rule_effect, const logging::Logger& logger) const;
67 
68  private:
69  using RuleEffect = RewriteRule::RewriteRuleEffect;
70 
71  // The list of unique pointers for all rules (so that rules can be registered for several op types).
73  // Map that associates a node's op type with the vector of rules that are registered to be triggered for that node.
75  // Rules that will be evaluated regardless of the op type of the node.
77 
78  // Performs a single top-down traversal of the graph and applies all registered rules.
79  common::Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
80 };
81 
82 } // namespace onnxruntime
const InlinedVector< std::reference_wrapper< const RewriteRule > > * GetRewriteRulesForOpType(const std::string &op_type) const
Status Register(std::unique_ptr< RewriteRule > rule)
GLsizei const GLchar *const * string
Definition: glcorearb.h:814
absl::InlinedVector< T, N, Allocator > InlinedVector
GLuint const GLchar * name
Definition: glcorearb.h:786
common::Status ApplyRulesOnNode(Graph &graph, Node &node, gsl::span< const std::reference_wrapper< const RewriteRule >> rules, RewriteRule::RewriteRuleEffect &rule_effect, const logging::Logger &logger) const
RuleBasedGraphTransformer(const std::string &name, const InlinedHashSet< std::string_view > &compatible_execution_providers={})
const InlinedVector< std::reference_wrapper< const RewriteRule > > * GetAnyOpRewriteRules() const
GLenum GLenum GLsizei void GLsizei void void * span
Definition: glad.h:5135
GraphTransformer(const std::string &name, const InlinedHashSet< std::string_view > &compatible_execution_providers={}) noexcept