HDK
Main Page
Related Pages
Modules
Namespaces
Classes
Files
Examples
File List
File Members
All
Classes
Namespaces
Files
Functions
Variables
Typedefs
Enumerations
Enumerator
Friends
Macros
Groups
Pages
rewrite_rule.h
Go to the documentation of this file.
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3
4
#pragma once
5
6
#include <cstdint>
#include <string>
#include <vector>

#include "core/common/common.h"
#include "core/graph/graph_viewer.h"
8
9
namespace
onnxruntime {
10
11
/**
12
@class RewriteRule
13
14
The base class for a rewrite rule. A rewrite rule represents a semantics-preserving transformation of a
15
computation graph. It can be used to represent, for example, the elimination of operators that serve as
16
no-ops (e.g., dropout during inference), as well as inlining of "function" definitions or the dual operation
17
of replacing a complex expression by an equivalent function-call). Unlike the more general GraphTransformer,
18
a rewrite rule is a more local transformation that is triggered on a particular node of the graph.
19
20
Each rule has a set of conditions and a body. The conditions have to be satisfied for the body of the rule
21
to be triggered. Therefore, when creating a new rewrite rule, two main functions have to be implemented:
22
- SatisfyCondition defines the condition checks. It is advisable to add the more selective checks first,
23
because those will lead to discarding fast rules that cannot be applied on a node.
24
- Apply is the actual body of the rule that will be executed if SatisfyCondition returns true for a particular
25
node. Note that additional, more complex checks can be included in the Apply if putting them in the
26
SatisfyCondition would lead to duplicate work (e.g., when we make a check on a Node attribute but we need
27
that attribute to execute the rule too).
28
In general, simple fast checks are a better fit for SatisfyCondition, whereas more complex ones can be added
29
in the Apply.
30
31
In order to avoid evaluating the SatisfyCondition for each rule and each node of the graph, each rewrite rule
32
should specify the target op types for which a rule will be evaluated, by overriding the TargetOpTypes() function.
33
If the op type of a node is not included in the target op types of a rule, that rule would not be considered at all.
34
If the list of op types is left empty, that rule will be triggered for every op type.
35
*/
36
class
RewriteRule
{
37
public
:
38
/**
39
@class RewriteRuleEffect
40
41
Class used to indicate the effect of rule application on a graph's node.
42
*/
43
enum class
RewriteRuleEffect
: uint8_t {
44
kNone
,
// The rewrite rule has not modified the graph.
45
kUpdatedCurrentNode
,
// The rewrite rule updated (but did not remove) the node on which it was triggered.
46
kRemovedCurrentNode
,
// The rewrite rule removed the node on which it was triggered.
47
kModifiedRestOfGraph
// The rewrite rule modified nodes other than the one it was triggered on.
48
};
49
50
RewriteRule
(
const
std::string
&
name
) : name_(name) {}
51
52
virtual
~RewriteRule
() =
default
;
53
54
/** Gets the name of this rewrite rule. */
55
const
std::string
&
Name
()
const
noexcept {
56
return
name_;
57
}
58
59
/** Returns the node op types for which this rule will be triggered. If the op type of a node is not included in the
60
target op types of a rule, that rule would not be considered at all. Returning an empty list indicates that we
61
will attempt to trigger the rule for every op type. */
62
virtual
std::vector<std::string>
TargetOpTypes
()
const
noexcept = 0;
63
64
/** Checks if the condition of the rule is satisfied, and if so applies the body of the rule.
65
@param[in] graph The Graph.
66
@param[in] node The Node to apply the rewrite to.
67
@param[out] rule_effect Enum to indicate if and how the graph was modified as a result of the rule application.
68
@returns Status indicating success or providing error information */
69
common::
Status
CheckConditionAndApply
(
Graph
& graph,
Node
& node,
RewriteRuleEffect
& rule_effect,
const
logging::Logger& logger)
const
{
70
return
SatisfyCondition(graph, node, logger) ? Apply(graph, node, rule_effect, logger) :
Status::OK
();
71
}
72
73
private
:
74
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(
RewriteRule
);
75
76
const
std::string
name_;
77
78
/** Checks if the Node of the given Graph satisfies the conditions of this rule. The body of the rule will be
79
evaluated if this condition function returns true. This can include a more complex pattern matching (conditions
80
on the ascending or descending nodes of the node for which this rule was triggered) or some other properties
81
of the nodes. */
82
virtual
bool
SatisfyCondition(
const
Graph
& graph,
const
Node
& node,
const
logging::Logger
& logger)
const
= 0;
83
84
/** This is the actual body of the rule that performs the graph transformation. The transformation happens in-place.
85
The return-value of node may be different from the input-value due to rewriting.
86
The value of "rule_effect" indicates whether and how the graph was modified by the rule. */
87
virtual
common::Status
Apply(
Graph
& graph,
Node
& node,
RewriteRuleEffect
& rule_effect,
const
logging::Logger
& logger)
const
= 0;
88
};
89
}
// namespace onnxruntime
onnxruntime::RewriteRule::~RewriteRule
virtual ~RewriteRule()=default
onnxruntime::RewriteRule::RewriteRule
RewriteRule(const std::string &name)
Definition:
rewrite_rule.h:50
onnxruntime::RewriteRule::RewriteRuleEffect::kNone
onnxruntime::RewriteRule::RewriteRuleEffect::kRemovedCurrentNode
string
GLsizei const GLchar *const * string
Definition:
glcorearb.h:814
onnxruntime::RewriteRule::CheckConditionAndApply
common::Status CheckConditionAndApply(Graph &graph, Node &node, RewriteRuleEffect &rule_effect, const logging::Logger &logger) const
Definition:
rewrite_rule.h:69
graph_viewer.h
onnxruntime::RewriteRule
Definition:
rewrite_rule.h:36
onnxruntime::logging::Logger
Definition:
logging.h:210
common.h
RewriteRuleEffect
onnxruntime::Graph
Definition:
graph.h:659
onnxruntime::common::OK
Definition:
status.h:36
name
GLuint const GLchar * name
Definition:
glcorearb.h:786
GU_Flatten2::Status
Status
Definition:
GU_Flatten2.h:36
onnxruntime::RewriteRule::TargetOpTypes
virtual std::vector< std::string > TargetOpTypes() const noexcept=0
onnxruntime::Node
Definition:
graph.h:77
onnxruntime::common::Status
Definition:
status.h:114
onnxruntime::RewriteRule::RewriteRuleEffect::kModifiedRestOfGraph
const
#define const
Definition:
zconf.h:214
onnxruntime::RewriteRule::Name
const std::string & Name() const noexcept
Definition:
rewrite_rule.h:55
onnxruntime::RewriteRule::RewriteRuleEffect::kUpdatedCurrentNode
onnxruntime
core
optimizer
rewrite_rule.h
Generated on Thu May 9 2024 03:15:04 for HDK by
1.8.6