This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e7748aa [Relay] Extract dataflow matcher data structure into header (#8774)
e7748aa is described below
commit e7748aac40bd4c263882323393ea8896837614a9
Author: Haichen Shen <[email protected]>
AuthorDate: Tue Aug 17 23:16:54 2021 -0700
[Relay] Extract dataflow matcher data structure into header (#8774)
* extract dataflow matcher data structure into a header file
* lint
* lint
---
src/relay/ir/dataflow_matcher.cc | 615 ++++++++++++++---------------------
src/relay/ir/dataflow_matcher_impl.h | 164 ++++++++++
2 files changed, 417 insertions(+), 362 deletions(-)
diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc
index 5ce06d9..d7f130f 100644
--- a/src/relay/ir/dataflow_matcher.cc
+++ b/src/relay/ir/dataflow_matcher.cc
@@ -29,50 +29,12 @@
#include <stack>
-#include "indexed_graph.h"
+#include "dataflow_matcher_impl.h"
namespace tvm {
namespace relay {
// Pattern Matcher
-
-class DominatorMatcher;
-
-class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const
Expr&)> {
- public:
- explicit DFPatternMatcher(const Expr& root_expr) :
expr_graph_(CreateIndexedGraph(root_expr)) {}
- bool Match(const DFPattern& pattern, const Expr& expr);
- Map<DFPattern, Array<Expr>> GetMemo() { return Map<DFPattern,
Array<Expr>>(memo_); }
- const IndexedGraph<Expr> expr_graph_;
-
- protected:
- bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
- bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override;
- bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override;
- bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override;
- bool VisitDFPattern_(const ConstantPatternNode* op, const Expr& expr)
override;
- bool VisitDFPattern_(const DataTypePatternNode* op, const Expr& expr)
override;
- bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr)
override;
- bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
- bool VisitDFPattern_(const FunctionPatternNode* op, const Expr& expr)
override;
- bool VisitDFPattern_(const IfPatternNode* op, const Expr& expr) override;
- bool VisitDFPattern_(const LetPatternNode* op, const Expr& expr) override;
- bool VisitDFPattern_(const ShapePatternNode* op, const Expr& expr) override;
- bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr)
override;
- bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
- bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
- bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
- bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr)
override;
-
- void ClearMap(size_t watermark);
- bool MatchesPath(const DominatorPatternNode* op, const Expr& expr);
- bool DominatesParent(const DominatorPatternNode* op, const Expr& expr);
-
- std::unordered_map<DFPattern, Array<Expr>, ObjectPtrHash, ObjectPtrEqual>
memo_;
- std::vector<DFPattern> matched_nodes_;
- bool memoize_ = true;
-};
-
bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
memo_.clear();
matched_nodes_.clear();
@@ -542,304 +504,251 @@ bool MatchPattern(DFPattern pattern, Expr expr) {
TVM_REGISTER_GLOBAL("relay.dataflow_pattern.match").set_body_typed(MatchPattern);
-/*!
- * \brief PatternGrouper does pre-rewriting pattern matching and analysis
- *
- * This class creates a number of groups of matched expressions, ensures they
don't overlap, and
- * returns them to the caller for post-analysis rewriting.
- *
- * This is primarily needed to support the post-dominator analysis required
for dominator pattern
- * matching.
- */
-class PatternGrouper {
+/*! \brief Creates a new set of nodes based on Group inputs, used to create
functions and perform
+ * group overlap analysis */
+class MatchExtractor : public ExprMutator {
public:
- /*! \brief Internal Group class for storing analysis */
- struct Group {
- Expr root_node;
- int gid;
- Map<DFPattern, Array<Expr>> matched_nodes;
- std::string name;
- Function function;
- Array<Expr> args;
- };
-
- /*! \brief Return the group assignments of expressions */
- const std::unordered_map<Expr, int, ObjectPtrHash, ObjectPtrEqual>&
GetGIDAssignments() {
- return gid_assignments_;
+ explicit MatchExtractor(
+ const std::unordered_map<Expr, Var, ObjectPtrHash, ObjectPtrEqual>&
inputs)
+ : inputs_(inputs) {}
+ const std::unordered_map<Expr, Expr, ObjectPtrHash, ObjectPtrEqual>&
GetMemo() {
+ return this->memo_;
}
- /*! \brief Group expressions that match the pattern */
- const std::unordered_map<int, Group>& GroupMatches(const DFPattern& pattern,
const Expr& pre) {
- groups_.clear();
- gid_assignments_.clear();
+ const std::string& GetName() { return name_; }
- pattern_ = pattern;
- pattern_graph_ = CreateIndexedGraph(pattern_);
- auto matcher = DFPatternMatcher(pre);
- matcher_ = &matcher;
- this->VisitExprs();
- return this->groups_;
+ protected:
+ Expr VisitExpr(const Expr& pre) override {
+ if (inputs_.count(pre)) {
+ return inputs_.at(pre);
+ }
+ return ExprMutator::VisitExpr(pre);
}
+ Expr VisitExpr_(const TupleNode* op) override {
+ auto out = ExprMutator::VisitExpr_(op);
+ name_ += "Tuple_";
+ return out;
+ };
+ Expr VisitExpr_(const FunctionNode* op) override {
+ auto out = ExprMutator::VisitExpr_(op);
+ name_ += "Function";
+ return out;
+ };
+ Expr VisitExpr_(const CallNode* call_node) override {
+ auto out = ExprMutator::VisitExpr_(call_node);
+ if (auto operation = call_node->op.as<OpNode>()) {
+ name_ += operation->name + "_";
+ } else {
+ name_ += "Call_";
+ }
+ return out;
+ };
+ Expr VisitExpr_(const LetNode* op) override {
+ auto out = ExprMutator::VisitExpr_(op);
+ name_ += "Let_";
+ return out;
+ };
+ Expr VisitExpr_(const IfNode* op) override {
+ auto out = ExprMutator::VisitExpr_(op);
+ name_ += "If_";
+ return out;
+ };
+ Expr VisitExpr_(const TupleGetItemNode* op) override {
+ auto out = ExprMutator::VisitExpr_(op);
+ name_ += "TupleGetItem" + std::to_string(op->index) + "_";
+ return out;
+ };
+ Expr VisitExpr_(const MatchNode* op) override {
+ auto out = ExprMutator::VisitExpr_(op);
+ name_ += "Match_";
+ return out;
+ };
+ std::string name_;
+ const std::unordered_map<Expr, Var, ObjectPtrHash, ObjectPtrEqual> inputs_;
+};
- protected:
- /*! \brief Iteratively traverse the Expression in pre-order to find subgraphs
- *
- * If we traverse the graph in post-order, we can run into situtations where
a small subgraph will
- * match the pattern. Due to options like AltPattern, a larger subgraph with
more nodes later in
- * the graph may also match the pattern. With post-order traversal, we mark
the smaller subgraph
- * as matched and fail to catch the larger subgraph. This problem is fixed
by using pre-order
- * traversal.
- */
- void VisitExprs() {
- std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> pre_partitioned;
- for (size_t i = matcher_->expr_graph_.topological_order_.size(); i != 0;
--i) {
- size_t index = i - 1;
- Expr current = matcher_->expr_graph_.topological_order_.at(index)->ref_;
- if (gid_assignments_.count(current) == 0) { // Don't visit nodes we've
already grouped
- if (auto op = current.as<FunctionNode>()) {
- if (op->attrs.defined() &&
op->attrs->dict.count(attr::kPartitionedFromPattern) != 0) {
- pre_partitioned.insert(current);
- PostOrderVisit(op->body,
- [&pre_partitioned](const Expr& expr) {
pre_partitioned.insert(expr); });
- }
- }
- if (pre_partitioned.count(current) == 0 && matcher_->Match(pattern_,
current)) {
- CreateGroup(current);
+/*! \brief Group expressions that match the pattern */
+const std::unordered_map<int, PatternGrouper::Group>&
PatternGrouper::GroupMatches(
+ const DFPattern& pattern, const Expr& pre) {
+ groups_.clear();
+ gid_assignments_.clear();
+
+ pattern_ = pattern;
+ pattern_graph_ = CreateIndexedGraph(pattern_);
+ auto matcher = DFPatternMatcher(pre);
+ matcher_ = &matcher;
+ this->VisitExprs();
+ return this->groups_;
+}
+
+void PatternGrouper::VisitExprs() {
+ std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> pre_partitioned;
+ for (size_t i = matcher_->expr_graph_.topological_order_.size(); i != 0;
--i) {
+ size_t index = i - 1;
+ Expr current = matcher_->expr_graph_.topological_order_.at(index)->ref_;
+ if (gid_assignments_.count(current) == 0) { // Don't visit nodes we've
already grouped
+ if (auto op = current.as<FunctionNode>()) {
+ if (op->attrs.defined() &&
op->attrs->dict.count(attr::kPartitionedFromPattern) != 0) {
+ pre_partitioned.insert(current);
+ PostOrderVisit(op->body,
+ [&pre_partitioned](const Expr& expr) {
pre_partitioned.insert(expr); });
}
}
+ if (pre_partitioned.count(current) == 0 && matcher_->Match(pattern_,
current)) {
+ CreateGroup(current);
+ }
}
}
- /*! \brief Creates a new set of nodes based on Group inputs, used to create
functions and perform
- * group overlap analysis */
- class MatchExtractor : public ExprMutator {
- public:
- explicit MatchExtractor(
- const std::unordered_map<Expr, Var, ObjectPtrHash, ObjectPtrEqual>&
inputs)
- : inputs_(inputs) {}
- const std::unordered_map<Expr, Expr, ObjectPtrHash, ObjectPtrEqual>&
GetMemo() {
- return this->memo_;
- }
- const std::string& GetName() { return name_; }
+}
- protected:
- Expr VisitExpr(const Expr& pre) override {
- if (inputs_.count(pre)) {
- return inputs_.at(pre);
+void PatternGrouper::CreateGroup(const Expr& expr) {
+ int var_number = 0;
+
+ auto node_map = matcher_->GetMemo();
+ // Get fuzzy patterns
+ std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> fuzzy_matches;
+ for (auto node : pattern_graph_.topological_order_) {
+ // Don't treat fuzzy Dominator patterns input variables for partition
+ if (auto op = node->ref_.as<DominatorPatternNode>()) {
+ for (auto fuzzy_op : {op->parent, op->path}) {
+ for (auto match : node_map[fuzzy_op]) {
+ fuzzy_matches.insert(match);
+ }
}
- return ExprMutator::VisitExpr(pre);
}
- Expr VisitExpr_(const TupleNode* op) override {
- auto out = ExprMutator::VisitExpr_(op);
- name_ += "Tuple_";
- return out;
- };
- Expr VisitExpr_(const FunctionNode* op) override {
- auto out = ExprMutator::VisitExpr_(op);
- name_ += "Function";
- return out;
- };
- Expr VisitExpr_(const CallNode* call_node) override {
- auto out = ExprMutator::VisitExpr_(call_node);
- if (auto operation = call_node->op.as<OpNode>()) {
- name_ += operation->name + "_";
- } else {
- name_ += "Call_";
+ // Don't treat Function params or body as input variables for partition
+ if (node->ref_.as<FunctionPatternNode>()) {
+ auto matches = node_map[node->ref_];
+ for (auto match : matches) {
+ auto graph = CreateIndexedGraph(match.as<FunctionNode>()->body);
+ for (auto node : graph.topological_order_) {
+ fuzzy_matches.insert(node->ref_);
+ }
}
- return out;
- };
- Expr VisitExpr_(const LetNode* op) override {
- auto out = ExprMutator::VisitExpr_(op);
- name_ += "Let_";
- return out;
- };
- Expr VisitExpr_(const IfNode* op) override {
- auto out = ExprMutator::VisitExpr_(op);
- name_ += "If_";
- return out;
- };
- Expr VisitExpr_(const TupleGetItemNode* op) override {
- auto out = ExprMutator::VisitExpr_(op);
- name_ += "TupleGetItem" + std::to_string(op->index) + "_";
- return out;
- };
- Expr VisitExpr_(const MatchNode* op) override {
- auto out = ExprMutator::VisitExpr_(op);
- name_ += "Match_";
- return out;
- };
- std::string name_;
- const std::unordered_map<Expr, Var, ObjectPtrHash, ObjectPtrEqual> inputs_;
- };
+ }
+ }
- /*! \brief Create a group based on a matched expression */
- void CreateGroup(const Expr& expr) {
- int var_number = 0;
-
- auto node_map = matcher_->GetMemo();
- // Get fuzzy patterns
- std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> fuzzy_matches;
- for (auto node : pattern_graph_.topological_order_) {
- // Don't treat fuzzy Dominator patterns input variables for partition
- if (auto op = node->ref_.as<DominatorPatternNode>()) {
- for (auto fuzzy_op : {op->parent, op->path}) {
- for (auto match : node_map[fuzzy_op]) {
- fuzzy_matches.insert(match);
- }
- }
+ // Create input variables
+ Group group;
+ group.root_node = expr;
+ group.matched_nodes = node_map;
+
+ std::unordered_map<Expr, Var, ObjectPtrHash, ObjectPtrEqual> inputs;
+ Array<Var> params;
+
+ for (auto node : pattern_graph_.topological_order_) {
+ auto make_input = [&](const Expr& input) {
+ if (fuzzy_matches.count(input) == 0 && input.as<OpNode>() == nullptr &&
+ input.as<FunctionNode>() == nullptr && !EmbedConst(input,
node->ref_)) {
+ inputs[input] =
+ Var("FunctionVar_" + std::to_string(graph_number_) + "_" +
std::to_string(var_number),
+ NullValue<Type>());
+ group.args.push_back(input);
+ params.push_back(inputs[input]);
+ var_number++;
}
- // Don't treat Function params or body as input variables for partition
- if (node->ref_.as<FunctionPatternNode>()) {
+ };
+ auto tuple = node->ref_.as<TuplePatternNode>();
+ auto call = node->ref_.as<CallPatternNode>();
+ if (tuple && !tuple->fields.defined()) {
+ if (node_map.count(node->ref_)) {
auto matches = node_map[node->ref_];
for (auto match : matches) {
- auto graph = CreateIndexedGraph(match.as<FunctionNode>()->body);
- for (auto node : graph.topological_order_) {
- fuzzy_matches.insert(node->ref_);
+ for (auto input : match.as<TupleNode>()->fields) {
+ make_input(input);
}
}
}
- }
-
- // Create input variables
- Group group;
- group.root_node = expr;
- group.matched_nodes = node_map;
-
- std::unordered_map<Expr, Var, ObjectPtrHash, ObjectPtrEqual> inputs;
- Array<Var> params;
-
- for (auto node : pattern_graph_.topological_order_) {
- auto make_input = [&](const Expr& input) {
- if (fuzzy_matches.count(input) == 0 && input.as<OpNode>() == nullptr &&
- input.as<FunctionNode>() == nullptr && !EmbedConst(input,
node->ref_)) {
- inputs[input] =
- Var("FunctionVar_" + std::to_string(graph_number_) + "_" +
std::to_string(var_number),
- NullValue<Type>());
- group.args.push_back(input);
- params.push_back(inputs[input]);
- var_number++;
- }
- };
- auto tuple = node->ref_.as<TuplePatternNode>();
- auto call = node->ref_.as<CallPatternNode>();
- if (tuple && !tuple->fields.defined()) {
- if (node_map.count(node->ref_)) {
- auto matches = node_map[node->ref_];
- for (auto match : matches) {
- for (auto input : match.as<TupleNode>()->fields) {
- make_input(input);
- }
- }
- }
- } else if (call && !call->args.defined()) {
- if (node_map.count(node->ref_)) {
- auto matches = node_map[node->ref_];
- for (auto match : matches) {
- for (auto input : match.as<CallNode>()->args) {
- make_input(input);
- }
+ } else if (call && !call->args.defined()) {
+ if (node_map.count(node->ref_)) {
+ auto matches = node_map[node->ref_];
+ for (auto match : matches) {
+ for (auto input : match.as<CallNode>()->args) {
+ make_input(input);
}
}
- } else if (node->inputs_.size() == 0) {
- if (node_map.count(node->ref_)) {
- auto matches = node_map[node->ref_];
- for (auto match : matches) {
- make_input(match);
- }
+ }
+ } else if (node->inputs_.size() == 0) {
+ if (node_map.count(node->ref_)) {
+ auto matches = node_map[node->ref_];
+ for (auto match : matches) {
+ make_input(match);
}
}
}
+ }
- graph_number_++;
-
- // Extract a Function. Used in Partition directly,
- // used to determine Group overlap in other passes
- auto extractor = MatchExtractor(inputs);
- auto body = extractor.Mutate(expr);
-
- group.function = Function(params, body, NullValue<Type>(),
Array<TypeVar>());
- group.name = extractor.GetName();
- // Check to make sure we aren't overlapping with another group or creating
an invalid fusion
- // The MatchExtractor will create a new graph by replacing nodes that
match the inputs of the
- // pattern with the input FunctionVar* Variables. The resulting
memoization map will only
- // contain nodes in the expression that matched the pattern. If a
non-input node of the pattern
- // (i.e., some piece of computation) overlaps with the nodes in a previous
group, we'll have a
- // situation where we try to rewrite the same node twice in the second
rewriting or parition
- // pass. This isn't valid, so we check for it here. We ignore Ops,
functions, and constants
- // because they exist more globally outside of the fusion.
- // Similiarly, if interior nodes in a group are used outside of the group
fusing to a single
- // output would create an invalid graph tranformation, so we block the
creation of such groups.
- auto memo = extractor.GetMemo();
- for (auto kv : memo) {
- // Check to ensure that this node isn't an input or a global
- if (inputs.count(kv.first) == 0 && kv.first.as<OpNode>() == nullptr &&
- kv.first.as<FunctionNode>() == nullptr &&
kv.first.as<ConstantNode>() == nullptr) {
- if (gid_assignments_.count(kv.first) != 0) {
- // check to see if the node is use in other groups
- // Exit due to overlapping partitions
- return;
- } else if (kv.second != body) {
- // if the node isn't the output of the group
- auto node = matcher_->expr_graph_.node_map_.at(kv.first);
- for (auto* output : node->outputs_) {
- // and the node is used by nodes outside of the group
- if (memo.count(output->ref_) == 0 &&
- !matcher_->expr_graph_.node_map_.at(expr)->Dominates(output)) {
- // Exit because nodes in this pattern's body are used outside
the pattern
- // fusing it would be invalid
- return;
- }
+ graph_number_++;
+
+ // Extract a Function. Used in Partition directly,
+ // used to determine Group overlap in other passes
+ auto extractor = MatchExtractor(inputs);
+ auto body = extractor.Mutate(expr);
+
+ group.function = Function(params, body, NullValue<Type>(), Array<TypeVar>());
+ group.name = extractor.GetName();
+ // Check to make sure we aren't overlapping with another group or creating
an invalid fusion
+ // The MatchExtractor will create a new graph by replacing nodes that match
the inputs of the
+ // pattern with the input FunctionVar* Variables. The resulting memoization
map will only
+ // contain nodes in the expression that matched the pattern. If a non-input
node of the pattern
+ // (i.e., some piece of computation) overlaps with the nodes in a previous
group, we'll have a
+ // situation where we try to rewrite the same node twice in the second
rewriting or parition
+ // pass. This isn't valid, so we check for it here. We ignore Ops,
functions, and constants
+ // because they exist more globally outside of the fusion.
+ // Similiarly, if interior nodes in a group are used outside of the group
fusing to a single
+ // output would create an invalid graph tranformation, so we block the
creation of such groups.
+ auto memo = extractor.GetMemo();
+ for (auto kv : memo) {
+ // Check to ensure that this node isn't an input or a global
+ if (inputs.count(kv.first) == 0 && kv.first.as<OpNode>() == nullptr &&
+ kv.first.as<FunctionNode>() == nullptr && kv.first.as<ConstantNode>()
== nullptr) {
+ if (gid_assignments_.count(kv.first) != 0) {
+ // check to see if the node is use in other groups
+ // Exit due to overlapping partitions
+ return;
+ } else if (kv.second != body) {
+ // if the node isn't the output of the group
+ auto node = matcher_->expr_graph_.node_map_.at(kv.first);
+ for (auto* output : node->outputs_) {
+ // and the node is used by nodes outside of the group
+ if (memo.count(output->ref_) == 0 &&
+ !matcher_->expr_graph_.node_map_.at(expr)->Dominates(output)) {
+ // Exit because nodes in this pattern's body are used outside the
pattern
+ // fusing it would be invalid
+ return;
}
}
}
}
- // Assign Group Ids
- group.gid = ++gid_;
- for (auto kv : extractor.GetMemo()) {
- gid_assignments_[kv.first] = gid_;
- }
+ }
+ // Assign Group Ids
+ group.gid = ++gid_;
+ for (auto kv : extractor.GetMemo()) {
+ gid_assignments_[kv.first] = gid_;
+ }
+
+ // Save Group
+ groups_[group.gid] = std::move(group);
+}
- // Save Group
- groups_[group.gid] = std::move(group);
- }
-
- /*! \brief EmbedConst implements rules for embedding constants into
partitioned functions or
- * lifting them into the function arguments.
- *
- * The rules depend on what pattern the ConstantNode matched.
- *
- * The basic rules are:
- * If the constant matches ExprPattern(relay.const(*)) or a
ConstantPattern(), embed the constant
- * in the partitioned function. If the constant matched an AltPattern,
recursively check the
- * matched side of the pattern. For any other matching pattern (i.e,
wildcard, VarPattern, etc),
- * lift the constant into the arguments of the partitioned function.
- */
- bool EmbedConst(const Expr& expr, const DFPattern pattern) {
- bool embed = false;
- if (expr.as<ConstantNode>()) {
- if (pattern.as<ConstantPatternNode>() != nullptr) {
+bool PatternGrouper::EmbedConst(const Expr& expr, const DFPattern pattern) {
+ bool embed = false;
+ if (expr.as<ConstantNode>()) {
+ if (pattern.as<ConstantPatternNode>() != nullptr) {
+ embed = true;
+ } else if (auto expr_pat = pattern.as<ExprPatternNode>()) {
+ if (expr_pat->expr.as<ConstantNode>()) {
embed = true;
- } else if (auto expr_pat = pattern.as<ExprPatternNode>()) {
- if (expr_pat->expr.as<ConstantNode>()) {
- embed = true;
- }
- } else if (auto alt_pat = pattern.as<AltPatternNode>()) {
- if (matcher_->Match(alt_pat->left, expr)) {
- embed = EmbedConst(expr, alt_pat->left);
- } else {
- embed = EmbedConst(expr, alt_pat->right);
- }
+ }
+ } else if (auto alt_pat = pattern.as<AltPatternNode>()) {
+ if (matcher_->Match(alt_pat->left, expr)) {
+ embed = EmbedConst(expr, alt_pat->left);
+ } else {
+ embed = EmbedConst(expr, alt_pat->right);
}
}
- return embed;
}
- // Internal State
- DFPattern pattern_;
- std::unordered_map<int, Group> groups_;
- std::unordered_map<Expr, int, ObjectPtrHash, ObjectPtrEqual>
gid_assignments_;
- DFPatternMatcher* matcher_ = nullptr;
- IndexedGraph<DFPattern> pattern_graph_;
- int gid_ = 0;
- int graph_number_ = 0;
-};
+ return embed;
+}
// Rewrite
@@ -858,72 +767,54 @@
TVM_REGISTER_GLOBAL("relay.dataflow_pattern.DFPatternCallback")
return DFPatternCallback(pattern, function, require_type);
});
-/*!
- * \brief PatternRewriter rewrites the expression by finding matches and
allowing user callback
- * function to rewrite those matches
- *
- * The class uses PatternGrouper to support the dominator pattern.
- */
-class PatternRewriter : protected MixedModeMutator {
- public:
- PatternRewriter(IRModule mod) : mod_(mod) {}
- /*! \brief Rewrite can take a number of callbacks and will repeatedly
rewrite the graph with the
- * callbacks until it stops changing */
- Expr Rewrite(const Array<DFPatternCallback>& callbacks, const Expr& pre) {
- auto post = pre;
- auto last = post;
- // rewrite the graph until it stops changing to make sure all rewrites are
complete
- int count = 0;
- bool equal = true;
- static auto* structural_equal =
runtime::Registry::Get("node.StructuralEqual");
- ICHECK(structural_equal) << "node.StructuralEqual is not registered.";
- do {
- last = post;
- for (auto callback : callbacks) {
- callback_ = callback;
- if (callback_->require_type) {
- post = InferTypeWithModule(post, mod_);
- }
- auto grouper = PatternGrouper();
- groups_ = grouper.GroupMatches(callback_->pattern, post);
- gid_assignments_ = grouper.GetGIDAssignments();
- memo_.clear();
- post = this->VisitExpr(post);
- count++;
- }
- equal = (*structural_equal)(last, post, false, true);
- } while (!equal && count < 100);
- if (count >= 100) {
- LOG(FATAL) << "Observed 100 rewrite passes, possible conflicting
passes?";
+Expr PatternRewriter::Rewrite(const Array<DFPatternCallback>& callbacks, const
Expr& pre) {
+ auto post = pre;
+ auto last = post;
+ // rewrite the graph until it stops changing to make sure all rewrites are
complete
+ int count = 0;
+ bool equal = true;
+ static auto* structural_equal =
runtime::Registry::Get("node.StructuralEqual");
+ ICHECK(structural_equal) << "node.StructuralEqual is not registered.";
+ do {
+ last = post;
+ for (auto callback : callbacks) {
+ callback_ = callback;
+ if (callback_->require_type) {
+ post = InferTypeWithModule(post, mod_);
+ }
+ auto grouper = PatternGrouper();
+ groups_ = grouper.GroupMatches(callback_->pattern, post);
+ gid_assignments_ = grouper.GetGIDAssignments();
+ memo_.clear();
+ post = this->VisitExpr(post);
+ count++;
}
- return post;
+ equal = (*structural_equal)(last, post, false, true);
+ } while (!equal && count < 100);
+ if (count >= 100) {
+ LOG(FATAL) << "Observed 100 rewrite passes, possible conflicting passes?";
}
+ return post;
+}
- protected:
- Expr DispatchVisitExpr(const Expr& pre) override {
- auto post = MixedModeMutator::DispatchVisitExpr(pre);
- if (gid_assignments_.count(pre) && pre ==
groups_[gid_assignments_[pre]].root_node) {
- // Convert the pre-rewrite node map to a post-rewrite node map
- auto group = groups_[gid_assignments_[pre]];
- std::unordered_map<DFPattern, Array<Expr>, ObjectPtrHash,
ObjectPtrEqual> node_map;
- for (auto kv : group.matched_nodes) {
- Array<Expr> tmp;
- for (size_t i = 0; i < kv.second.size(); ++i) {
- tmp.push_back(this->memo_[kv.second[i]]);
- }
- node_map.insert({kv.first, tmp});
+Expr PatternRewriter::DispatchVisitExpr(const Expr& pre) {
+ auto post = MixedModeMutator::DispatchVisitExpr(pre);
+ if (gid_assignments_.count(pre) && pre ==
groups_[gid_assignments_[pre]].root_node) {
+ // Convert the pre-rewrite node map to a post-rewrite node map
+ auto group = groups_[gid_assignments_[pre]];
+ std::unordered_map<DFPattern, Array<Expr>, ObjectPtrHash, ObjectPtrEqual>
node_map;
+ for (auto kv : group.matched_nodes) {
+ Array<Expr> tmp;
+ for (size_t i = 0; i < kv.second.size(); ++i) {
+ tmp.push_back(this->memo_[kv.second[i]]);
}
- // run the user callback function
- return callback_->function(pre, post, Map<DFPattern,
Array<Expr>>(node_map));
+ node_map.insert({kv.first, tmp});
}
- return post;
+ // run the user callback function
+ return callback_->function(pre, post, Map<DFPattern,
Array<Expr>>(node_map));
}
-
- IRModule mod_;
- DFPatternCallback callback_;
- std::unordered_map<int, PatternGrouper::Group> groups_;
- std::unordered_map<Expr, int, ObjectPtrHash, ObjectPtrEqual>
gid_assignments_;
-};
+ return post;
+}
Expr RewritePatterns(Array<DFPatternCallback> callbacks, Expr expr, IRModule
mod) {
return PatternRewriter(mod).Rewrite(callbacks, expr);
diff --git a/src/relay/ir/dataflow_matcher_impl.h
b/src/relay/ir/dataflow_matcher_impl.h
new file mode 100644
index 0000000..d993d47
--- /dev/null
+++ b/src/relay/ir/dataflow_matcher_impl.h
@@ -0,0 +1,164 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tvm/relay/dataflow_matcher_impl.h
+ * \brief The auxiliary data structure for dataflow matcher.
+ */
+#ifndef TVM_RELAY_IR_DATAFLOW_MATCHER_IMPL_H_
+#define TVM_RELAY_IR_DATAFLOW_MATCHER_IMPL_H_
+
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/dataflow_pattern.h>
+#include <tvm/relay/dataflow_pattern_functor.h>
+
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "indexed_graph.h"
+
+namespace tvm {
+namespace relay {
+
+class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const
Expr&)> {
+ public:
+ explicit DFPatternMatcher(const Expr& root_expr) :
expr_graph_(CreateIndexedGraph(root_expr)) {}
+ bool Match(const DFPattern& pattern, const Expr& expr);
+ Map<DFPattern, Array<Expr>> GetMemo() { return Map<DFPattern,
Array<Expr>>(memo_); }
+ const IndexedGraph<Expr> expr_graph_;
+
+ protected:
+ bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
+ bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override;
+ bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override;
+ bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override;
+ bool VisitDFPattern_(const ConstantPatternNode* op, const Expr& expr)
override;
+ bool VisitDFPattern_(const DataTypePatternNode* op, const Expr& expr)
override;
+ bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr)
override;
+ bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
+ bool VisitDFPattern_(const FunctionPatternNode* op, const Expr& expr)
override;
+ bool VisitDFPattern_(const IfPatternNode* op, const Expr& expr) override;
+ bool VisitDFPattern_(const LetPatternNode* op, const Expr& expr) override;
+ bool VisitDFPattern_(const ShapePatternNode* op, const Expr& expr) override;
+ bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr)
override;
+ bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
+ bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
+ bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
+ bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr)
override;
+
+ void ClearMap(size_t watermark);
+ bool MatchesPath(const DominatorPatternNode* op, const Expr& expr);
+ bool DominatesParent(const DominatorPatternNode* op, const Expr& expr);
+
+ std::unordered_map<DFPattern, Array<Expr>, ObjectPtrHash, ObjectPtrEqual>
memo_;
+ std::vector<DFPattern> matched_nodes_;
+ bool memoize_ = true;
+};
+
+/*!
+ * \brief PatternGrouper does pre-rewriting pattern matching and analysis
+ *
+ * This class creates a number of groups of matched expressions, ensures they
don't overlap, and
+ * returns them to the caller for post-analysis rewriting.
+ *
+ * This is primarily needed to support the post-dominator analysis required
for dominator pattern
+ * matching.
+ */
+class PatternGrouper {
+ public:
+ /*! \brief Internal Group class for storing analysis */
+ struct Group {
+ Expr root_node;
+ int gid;
+ Map<DFPattern, Array<Expr>> matched_nodes;
+ std::string name;
+ Function function;
+ Array<Expr> args;
+ };
+
+ /*! \brief Return the group assignments of expressions */
+ inline const std::unordered_map<Expr, int, ObjectPtrHash, ObjectPtrEqual>&
GetGIDAssignments() {
+ return gid_assignments_;
+ }
+ /*! \brief Group expressions that match the pattern */
+ const std::unordered_map<int, Group>& GroupMatches(const DFPattern& pattern,
const Expr& pre);
+
+ protected:
+ /*! \brief Iteratively traverse the Expression in pre-order to find subgraphs
+ *
+ * If we traverse the graph in post-order, we can run into situtations where
a small subgraph will
+ * match the pattern. Due to options like AltPattern, a larger subgraph with
more nodes later in
+ * the graph may also match the pattern. With post-order traversal, we mark
the smaller subgraph
+ * as matched and fail to catch the larger subgraph. This problem is fixed
by using pre-order
+ * traversal.
+ */
+ void VisitExprs();
+
+ /*! \brief Create a group based on a matched expression */
+ void CreateGroup(const Expr& expr);
+
+ /*! \brief EmbedConst implements rules for embedding constants into
partitioned functions or
+ * lifting them into the function arguments.
+ *
+ * The rules depend on what pattern the ConstantNode matched.
+ *
+ * The basic rules are:
+ * If the constant matches ExprPattern(relay.const(*)) or a
ConstantPattern(), embed the constant
+ * in the partitioned function. If the constant matched an AltPattern,
recursively check the
+ * matched side of the pattern. For any other matching pattern (i.e,
wildcard, VarPattern, etc),
+ * lift the constant into the arguments of the partitioned function.
+ */
+ bool EmbedConst(const Expr& expr, const DFPattern pattern);
+ // Internal State
+ DFPattern pattern_;
+ std::unordered_map<int, Group> groups_;
+ std::unordered_map<Expr, int, ObjectPtrHash, ObjectPtrEqual>
gid_assignments_;
+ DFPatternMatcher* matcher_ = nullptr;
+ IndexedGraph<DFPattern> pattern_graph_;
+ int gid_ = 0;
+ int graph_number_ = 0;
+};
+
+/*!
+ * \brief PatternRewriter rewrites the expression by finding matches and
allowing user callback
+ * function to rewrite those matches
+ *
+ * The class uses PatternGrouper to support the dominator pattern.
+ */
+class PatternRewriter : protected MixedModeMutator {
+ public:
+ explicit PatternRewriter(IRModule mod) : mod_(mod) {}
+ /*! \brief Rewrite can take a number of callbacks and will repeatedly
rewrite the graph with the
+ * callbacks until it stops changing */
+ virtual Expr Rewrite(const Array<DFPatternCallback>& callbacks, const Expr&
pre);
+
+ protected:
+ virtual Expr DispatchVisitExpr(const Expr& pre);
+
+ IRModule mod_;
+ DFPatternCallback callback_;
+ std::unordered_map<int, PatternGrouper::Group> groups_;
+ std::unordered_map<Expr, int, ObjectPtrHash, ObjectPtrEqual>
gid_assignments_;
+};
+
+} // namespace relay
+} // namespace tvm
+
+#endif // TVM_RELAY_IR_DATAFLOW_MATCHER_IMPL_H_