This is an automated email from the ASF dual-hosted git repository. masahi pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push: new 53251c8 [PatternLang] Convert PatternGrouper to do pre-order, non-recursive analysis (#5653) 53251c8 is described below commit 53251c87b2a7be53d00a968629bfc688585d8e4e Author: Matthew Brookhart <mbrookh...@octoml.ai> AuthorDate: Fri May 22 20:17:39 2020 -0700 [PatternLang] Convert PatternGrouper to do pre-order, non-recursive analysis (#5653) * make the PatternGrouper iterate over the input Expr in a non-recursive pre-order fashion * add a comment --- src/relay/ir/dataflow_matcher.cc | 43 +++++++++++++++++++---------- tests/python/relay/test_dataflow_pattern.py | 32 ++++++++++++++++++++- 2 files changed, 59 insertions(+), 16 deletions(-) diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index 980935c..2f25733 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -43,6 +43,7 @@ class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Ex explicit DFPatternMatcher(const Expr& root_expr) : expr_graph_(CreateIndexedGraph(root_expr)) {} bool Match(const DFPattern& pattern, const Expr& expr); Map<DFPattern, Array<Expr>> GetMemo() { return Map<DFPattern, Array<Expr>>(memo_); } + const IndexedGraph<Expr> expr_graph_; protected: bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override; @@ -63,7 +64,6 @@ class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Ex std::unordered_map<DFPattern, Array<Expr>, ObjectHash, ObjectEqual> memo_; std::vector<DFPattern> matched_nodes_; - IndexedGraph<Expr> expr_graph_; bool memoize_ = true; }; @@ -291,7 +291,7 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex // Recursively find the Dominator parent along all inputs paths.
bool DFPatternMatcher::MatchesPath(const DominatorPatternNode* op, const Expr& expr) { auto call_node = expr.as<CallNode>(); - for (auto node : expr_graph_.node_map_[expr]->inputs_) { + for (auto node : expr_graph_.node_map_.at(expr)->inputs_) { if (!(call_node && node->ref_ == call_node->op)) { memoize_ = true; if (VisitDFPattern(op->parent, node->ref_)) { @@ -315,7 +315,7 @@ bool DFPatternMatcher::DominatesParent(const DominatorPatternNode* op, const Exp while (!stack.empty()) { Expr current = stack.top(); stack.pop(); - for (auto node : expr_graph_.node_map_[current]->dominator_children_) { + for (auto node : expr_graph_.node_map_.at(current)->dominator_children_) { if (visited.count(node->ref_) == 0) { if (VisitDFPattern(op->parent, node->ref_)) { return true; @@ -412,7 +412,7 @@ TVM_REGISTER_GLOBAL("relay.dataflow_pattern.match").set_body_typed(MatchPattern) * This is primarily needed to support the post-dominator analysis required for dominator pattern * matching. */ -class PatternGrouper : protected MixedModeVisitor { +class PatternGrouper { public: /* \brief Internal Group class for storing analysis */ struct Group { @@ -432,26 +432,39 @@ class PatternGrouper : protected MixedModeVisitor { const std::vector<Group>& GroupMatches(const DFPattern& pattern, const Expr& pre) { groups_ = {Group()}; gid_assignments_.clear(); - visit_counter_.clear(); pattern_ = pattern; pattern_graph_ = CreateIndexedGraph(pattern_); auto matcher = DFPatternMatcher(pre); matcher_ = &matcher; - this->VisitExpr(pre); + this->VisitExprs(); return this->groups_; } protected: - using ExprVisitor::VisitExpr_; - void VisitLeaf(const Expr& pre) override { - if (matcher_->Match(pattern_, pre)) { - CreateGroup(pre); - } - } - void VisitExpr_(const FunctionNode* op) override { - if (op->attrs->dict.count(attr::kPartitionedFromPattern) == 0) { - ExprVisitor::VisitExpr_(op); + /* \brief Iteratively traverse the Expression in pre-order to find subgraphs + * + * If we traverse the graph in 
post-order, we can run into situtations where a small subgraph will + * match the pattern. Due to options like AltPattern, a larger subgraph with more nodes later in + * the graph may also match the pattern. With post-order traversal, we mark the smaller subgraph + * as matched and fail to catch the larger subgraph. This problem is fixed by using pre-order + * traversal. + */ + void VisitExprs() { + std::unordered_set<Expr, ObjectHash, ObjectEqual> pre_partitioned; + for (size_t i = matcher_->expr_graph_.topological_order_.size(); i != 0; --i) { + size_t index = i - 1; + Expr current = matcher_->expr_graph_.topological_order_.at(index)->ref_; + if (auto op = current.as<FunctionNode>()) { + if (op->attrs->dict.count(attr::kPartitionedFromPattern) != 0) { + pre_partitioned.insert(current); + PostOrderVisit(op->body, + [&pre_partitioned](const Expr& expr) { pre_partitioned.insert(expr); }); + } + } + if (pre_partitioned.count(current) == 0 && matcher_->Match(pattern_, current)) { + CreateGroup(current); + } } } /* \brief Creates a new set of nodes based on Group inputs, used to create functions and perform diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py index 3a605e4..17c8df4 100644 --- a/tests/python/relay/test_dataflow_pattern.py +++ b/tests/python/relay/test_dataflow_pattern.py @@ -980,6 +980,36 @@ def test_partition_check_types(): relu = run_opt_pass(relu, relay.transform.InferType()) assert relu == pattern.partition(relu, check=check) +def test_partition_option(): + x = relay.var('x') + w = relay.var('w') + b = relay.var('b') + + conv2d = is_op('nn.conv2d')(wildcard(), wildcard()) + bias = conv2d.optional(lambda x: is_op('nn.bias_add')(x, wildcard())) + pattern1 = is_op('nn.relu')(bias) + + conv2d = is_op('nn.conv2d')(wildcard(), wildcard()) + bias = is_op('nn.bias_add')(conv2d, wildcard()) + pattern2 = bias.optional(lambda x: is_op('nn.relu')(x)) + + def conv_bias_relu(x, w, b): + conv2d = 
relay.op.nn.conv2d(x, w) + bias_add = relay.op.nn.bias_add(conv2d, b) + relu = relay.op.nn.relu(bias_add) + return relu + relu = conv_bias_relu(x, w, b) + + xf = relay.var('x') + wf = relay.var('w') + bf = relay.var('b') + func = relay.Function([xf, wf, bf], conv_bias_relu(xf, wf, bf)).with_attr("PartitionedFromPattern","nn.conv2d_nn.bias_add_nn.relu_") + + assert pattern1.match(relu) + assert tvm.ir.structural_equal(func(x, w, b), pattern1.partition(relu)) + + assert pattern2.match(relu) + assert tvm.ir.structural_equal(func(x, w, b), pattern2.partition(relu)) if __name__ == "__main__": test_match_op() @@ -1014,4 +1044,4 @@ if __name__ == "__main__": test_partition_double_batchnorm() test_partition_check() test_partition_check_types() - + test_partition_option()