This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 53251c8  [PatternLang] Convert PatternGrouper to do pre-order, 
non-recursive analysis (#5653)
53251c8 is described below

commit 53251c87b2a7be53d00a968629bfc688585d8e4e
Author: Matthew Brookhart <mbrookh...@octoml.ai>
AuthorDate: Fri May 22 20:17:39 2020 -0700

    [PatternLang] Convert PatternGrouper to do pre-order, non-recursive 
analysis (#5653)
    
    * make the PatternGrouper iterate over the input Expr in a non-recursive 
pre-order fashion
    
    * add a comment
---
 src/relay/ir/dataflow_matcher.cc            | 43 +++++++++++++++++++----------
 tests/python/relay/test_dataflow_pattern.py | 32 ++++++++++++++++++++-
 2 files changed, 59 insertions(+), 16 deletions(-)

diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc
index 980935c..2f25733 100644
--- a/src/relay/ir/dataflow_matcher.cc
+++ b/src/relay/ir/dataflow_matcher.cc
@@ -43,6 +43,7 @@ class DFPatternMatcher : public DFPatternFunctor<bool(const 
DFPattern&, const Ex
   explicit DFPatternMatcher(const Expr& root_expr) : 
expr_graph_(CreateIndexedGraph(root_expr)) {}
   bool Match(const DFPattern& pattern, const Expr& expr);
   Map<DFPattern, Array<Expr>> GetMemo() { return Map<DFPattern, 
Array<Expr>>(memo_); }
+  const IndexedGraph<Expr> expr_graph_;
 
  protected:
   bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
@@ -63,7 +64,6 @@ class DFPatternMatcher : public DFPatternFunctor<bool(const 
DFPattern&, const Ex
 
   std::unordered_map<DFPattern, Array<Expr>, ObjectHash, ObjectEqual> memo_;
   std::vector<DFPattern> matched_nodes_;
-  IndexedGraph<Expr> expr_graph_;
   bool memoize_ = true;
 };
 
@@ -291,7 +291,7 @@ bool DFPatternMatcher::VisitDFPattern_(const 
CallPatternNode* op, const Expr& ex
 // Recursively find the Dominator parent along all inputs paths.
 bool DFPatternMatcher::MatchesPath(const DominatorPatternNode* op, const Expr& 
expr) {
   auto call_node = expr.as<CallNode>();
-  for (auto node : expr_graph_.node_map_[expr]->inputs_) {
+  for (auto node : expr_graph_.node_map_.at(expr)->inputs_) {
     if (!(call_node && node->ref_ == call_node->op)) {
       memoize_ = true;
       if (VisitDFPattern(op->parent, node->ref_)) {
@@ -315,7 +315,7 @@ bool DFPatternMatcher::DominatesParent(const 
DominatorPatternNode* op, const Exp
   while (!stack.empty()) {
     Expr current = stack.top();
     stack.pop();
-    for (auto node : expr_graph_.node_map_[current]->dominator_children_) {
+    for (auto node : expr_graph_.node_map_.at(current)->dominator_children_) {
       if (visited.count(node->ref_) == 0) {
         if (VisitDFPattern(op->parent, node->ref_)) {
           return true;
@@ -412,7 +412,7 @@ 
TVM_REGISTER_GLOBAL("relay.dataflow_pattern.match").set_body_typed(MatchPattern)
  * This is primarily needed to support the post-dominator analysis required 
for dominator pattern
  * matching.
  */
-class PatternGrouper : protected MixedModeVisitor {
+class PatternGrouper {
  public:
   /* \brief Internal Group class for storing analysis */
   struct Group {
@@ -432,26 +432,39 @@ class PatternGrouper : protected MixedModeVisitor {
   const std::vector<Group>& GroupMatches(const DFPattern& pattern, const Expr& 
pre) {
     groups_ = {Group()};
     gid_assignments_.clear();
-    visit_counter_.clear();
 
     pattern_ = pattern;
     pattern_graph_ = CreateIndexedGraph(pattern_);
     auto matcher = DFPatternMatcher(pre);
     matcher_ = &matcher;
-    this->VisitExpr(pre);
+    this->VisitExprs();
     return this->groups_;
   }
 
  protected:
-  using ExprVisitor::VisitExpr_;
-  void VisitLeaf(const Expr& pre) override {
-    if (matcher_->Match(pattern_, pre)) {
-      CreateGroup(pre);
-    }
-  }
-  void VisitExpr_(const FunctionNode* op) override {
-    if (op->attrs->dict.count(attr::kPartitionedFromPattern) == 0) {
-      ExprVisitor::VisitExpr_(op);
+  /* \brief Iteratively traverse the Expression in pre-order to find subgraphs
+   *
+   * If we traverse the graph in post-order, we can run into situations where 
a small subgraph will
+   * match the pattern. Due to options like AltPattern, a larger subgraph with 
more nodes later in
+   * the graph may also match the pattern. With post-order traversal, we mark 
the smaller subgraph
+   * as matched and fail to catch the larger subgraph. This problem is fixed 
by using pre-order
+   * traversal.
+   */
+  void VisitExprs() {
+    std::unordered_set<Expr, ObjectHash, ObjectEqual> pre_partitioned;
+    for (size_t i = matcher_->expr_graph_.topological_order_.size(); i != 0; 
--i) {
+      size_t index = i - 1;
+      Expr current = matcher_->expr_graph_.topological_order_.at(index)->ref_;
+      if (auto op = current.as<FunctionNode>()) {
+        if (op->attrs->dict.count(attr::kPartitionedFromPattern) != 0) {
+          pre_partitioned.insert(current);
+          PostOrderVisit(op->body,
+                         [&pre_partitioned](const Expr& expr) { 
pre_partitioned.insert(expr); });
+        }
+      }
+      if (pre_partitioned.count(current) == 0 && matcher_->Match(pattern_, 
current)) {
+        CreateGroup(current);
+      }
     }
   }
   /* \brief Creates a new set of nodes based on Group inputs, used to create 
functions and perform
diff --git a/tests/python/relay/test_dataflow_pattern.py 
b/tests/python/relay/test_dataflow_pattern.py
index 3a605e4..17c8df4 100644
--- a/tests/python/relay/test_dataflow_pattern.py
+++ b/tests/python/relay/test_dataflow_pattern.py
@@ -980,6 +980,36 @@ def test_partition_check_types():
     relu = run_opt_pass(relu, relay.transform.InferType())
     assert relu == pattern.partition(relu, check=check)
 
+def test_partition_option():
+    x = relay.var('x')
+    w = relay.var('w')
+    b = relay.var('b')
+
+    conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
+    bias = conv2d.optional(lambda x: is_op('nn.bias_add')(x, wildcard()))
+    pattern1 = is_op('nn.relu')(bias)
+
+    conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
+    bias = is_op('nn.bias_add')(conv2d, wildcard())
+    pattern2 = bias.optional(lambda x: is_op('nn.relu')(x))
+
+    def conv_bias_relu(x, w, b):
+        conv2d = relay.op.nn.conv2d(x, w)
+        bias_add = relay.op.nn.bias_add(conv2d, b)
+        relu = relay.op.nn.relu(bias_add)
+        return relu
+    relu = conv_bias_relu(x, w, b)
+
+    xf = relay.var('x')
+    wf = relay.var('w')
+    bf = relay.var('b')
+    func = relay.Function([xf, wf, bf], conv_bias_relu(xf, wf, 
bf)).with_attr("PartitionedFromPattern","nn.conv2d_nn.bias_add_nn.relu_")
+
+    assert pattern1.match(relu)
+    assert tvm.ir.structural_equal(func(x, w, b), pattern1.partition(relu))
+
+    assert pattern2.match(relu)
+    assert tvm.ir.structural_equal(func(x, w, b), pattern2.partition(relu))
 
 if __name__ == "__main__":
     test_match_op()
@@ -1014,4 +1044,4 @@ if __name__ == "__main__":
     test_partition_double_batchnorm()
     test_partition_check()
     test_partition_check_types()
-
+    test_partition_option()

Reply via email to