This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 7e68b4d  [PatternMatcher] Support matching tuples, call nodes, and 
functions with variable numbers of inputs (#7754)
7e68b4d is described below

commit 7e68b4d413dfc0ec0042634a40d94e9bab59dbde
Author: Matthew Brookhart <[email protected]>
AuthorDate: Sat Apr 3 03:12:13 2021 -0600

    [PatternMatcher] Support matching tuples, call nodes, and functions with 
variable numbers of inputs (#7754)
    
    * Allow TuplePattern to have null fields and match any tuple
    
    * support matching functions and call nodes with variable numbers of 
parameters
    
    * remove development code that was commented out
    
    * add docs for fuzzy matching
---
 docs/langref/relay_pattern.rst                |  16 ++++
 python/tvm/relay/dataflow_pattern/__init__.py |   5 +-
 src/relay/ir/dataflow_matcher.cc              | 107 ++++++++++++++++++--------
 src/relay/ir/dataflow_pattern_functor.cc      |  18 +++--
 src/relay/ir/indexed_graph.cc                 |  18 +++--
 tests/python/relay/test_dataflow_pattern.py   |  80 +++++++++++++++++--
 6 files changed, 192 insertions(+), 52 deletions(-)

diff --git a/docs/langref/relay_pattern.rst b/docs/langref/relay_pattern.rst
index d77a519..efb9804 100644
--- a/docs/langref/relay_pattern.rst
+++ b/docs/langref/relay_pattern.rst
@@ -307,6 +307,22 @@ The final example is matching diamonds with a 
post-dominator relationship. We em
         assert diamond.match(out)
 
 
+Matching Fuzzy Patterns
+=======================
+
+The Dominator analysis above lets one match a subgraph of Relay AST that 
doesn't correspond to a set of pattern nodes exactly 1-to-1. There are a few 
other places where we support such "fuzzy" matching.
+
+Tuples, Functions, and Call nodes with any number of inputs can be matched by 
passing `None` as the argument value, i.e.::
+
+    tuple_pattern = is_tuple(None)
+    func_pattern = FunctionPattern(None, wildcard() + wildcard())
+    call_pattern = func_pattern(None)
+
+These patterns allow matching more generic classes of patterns by constraining 
the use of the arguments rather than the number of arguments.
+
+Additionally, we support matching Functions with fuzzy bodies, i.e., a 
function body that is under-constrained by the pattern. The pattern 
`FunctionPattern([is_var(), is_var()], wildcard() + wildcard())` will match 
`relay.Function([x, y], x + y)`, but it will also match `relay.Function([x, y], 
x * x + y)`. In the second case, the pattern doesn't perfectly constrain the 
body of the function, so the resulting match is fuzzy.
+
+
 Pattern Language Design
 =======================
 
diff --git a/python/tvm/relay/dataflow_pattern/__init__.py 
b/python/tvm/relay/dataflow_pattern/__init__.py
index d4a8481..b368f4e 100644
--- a/python/tvm/relay/dataflow_pattern/__init__.py
+++ b/python/tvm/relay/dataflow_pattern/__init__.py
@@ -47,7 +47,10 @@ class DFPattern(Node):
     """Base class of all Patterns."""
 
     def __call__(self, *args):
-        return CallPattern(self, list(args))
+        args = list(args)
+        if len(args) == 1 and args[0] is None:
+            args = None
+        return CallPattern(self, args)
 
     def __or__(self, other):
         return AltPattern(self, other)
diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc
index 43a6473..6ed24d5 100644
--- a/src/relay/ir/dataflow_matcher.cc
+++ b/src/relay/ir/dataflow_matcher.cc
@@ -242,6 +242,7 @@ bool DFPatternMatcher::VisitDFPattern_(const 
CallPatternNode* op, const Expr& ex
     }
     return false;
   };
+
   // logic
   auto watermark = matched_nodes_.size();
   if (const auto* call_node = expr.as<CallNode>()) {
@@ -253,13 +254,15 @@ bool DFPatternMatcher::VisitDFPattern_(const 
CallPatternNode* op, const Expr& ex
                                             const Array<Expr> expr_args) {
         bool matches = true;
         size_t i = 0;
-        if (pattern_args.size() == expr_args.size()) {
-          while (matches && i < pattern_args.size()) {
-            matches &= VisitDFPattern(pattern_args[i], expr_args[i]);
-            ++i;
+        if (pattern_args.defined()) {
+          if (pattern_args.size() == expr_args.size()) {
+            while (matches && i < pattern_args.size()) {
+              matches &= VisitDFPattern(pattern_args[i], expr_args[i]);
+              ++i;
+            }
+          } else {
+            matches = false;
           }
-        } else {
-          matches = false;
         }
         if (!matches) {
           ClearMap(watermark2);
@@ -381,14 +384,16 @@ bool DFPatternMatcher::VisitDFPattern_(const 
FunctionPatternNode* op, const Expr
   bool matches = false;
   if (const auto* func = expr.as<FunctionNode>()) {
     matches = true;
-    size_t i = 0;
-    if (op->params.size() == func->params.size()) {
-      while (matches && i < op->params.size()) {
-        matches &= VisitDFPattern(op->params[i], func->params[i]);
-        ++i;
+    if (op->params.defined()) {
+      size_t i = 0;
+      if (op->params.size() == func->params.size()) {
+        while (matches && i < op->params.size()) {
+          matches &= VisitDFPattern(op->params[i], func->params[i]);
+          ++i;
+        }
+      } else {
+        matches = false;
       }
-    } else {
-      matches = false;
     }
     if (matches) {
       matches &= VisitDFPattern(op->body, func->body);
@@ -409,12 +414,16 @@ bool DFPatternMatcher::VisitDFPattern_(const 
TupleGetItemPatternNode* op, const
 bool DFPatternMatcher::VisitDFPattern_(const TuplePatternNode* op, const Expr& 
expr) {
   bool matches = false;
   if (const auto* tuple_node = expr.as<TupleNode>()) {
-    if (op->fields.size() == tuple_node->fields.size()) {
-      matches = true;
-      size_t i = 0;
-      while (matches && i < op->fields.size()) {
-        matches &= VisitDFPattern(op->fields[i], tuple_node->fields[i]);
-        ++i;
+    matches = true;
+    if (op->fields.defined()) {
+      if (op->fields.size() == tuple_node->fields.size()) {
+        size_t i = 0;
+        while (matches && i < op->fields.size()) {
+          matches &= VisitDFPattern(op->fields[i], tuple_node->fields[i]);
+          ++i;
+        }
+      } else {
+        matches = false;
       }
     }
   }
@@ -657,7 +666,6 @@ class PatternGrouper {
     int var_number = 0;
 
     auto node_map = matcher_->GetMemo();
-
     // Get fuzzy patterns
     std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> fuzzy_matches;
     for (auto node : pattern_graph_.topological_order_) {
@@ -669,11 +677,13 @@ class PatternGrouper {
           }
         }
       }
-      // Don't treat Function params as input variables for partition
-      if (auto op = node->ref_.as<FunctionPatternNode>()) {
-        for (auto fuzzy_op : op->params) {
-          for (auto match : node_map[fuzzy_op]) {
-            fuzzy_matches.insert(match);
+      // Don't treat Function params or body as input variables for partition
+      if (node->ref_.as<FunctionPatternNode>()) {
+        auto matches = node_map[node->ref_];
+        for (auto match : matches) {
+          auto graph = CreateIndexedGraph(match.as<FunctionNode>()->body);
+          for (auto node : graph.topological_order_) {
+            fuzzy_matches.insert(node->ref_);
           }
         }
       }
@@ -686,22 +696,46 @@ class PatternGrouper {
 
     std::unordered_map<Expr, Var, ObjectPtrHash, ObjectPtrEqual> inputs;
     Array<Var> params;
+
     for (auto node : pattern_graph_.topological_order_) {
-      if (node->inputs_.size() == 0) {
+      auto make_input = [&](const Expr& input) {
+        if (fuzzy_matches.count(input) == 0 && input.as<OpNode>() == nullptr &&
+            input.as<FunctionNode>() == nullptr && !EmbedConst(input, 
node->ref_)) {
+          inputs[input] =
+              Var("FunctionVar_" + std::to_string(graph_number_) + "_" + 
std::to_string(var_number),
+                  NullValue<Type>());
+          group.args.push_back(input);
+          params.push_back(inputs[input]);
+          var_number++;
+        }
+      };
+      auto tuple = node->ref_.as<TuplePatternNode>();
+      auto call = node->ref_.as<CallPatternNode>();
+      if (tuple && !tuple->fields.defined()) {
         if (node_map.count(node->ref_)) {
           auto matches = node_map[node->ref_];
           for (auto match : matches) {
-            if (fuzzy_matches.count(match) == 0 && match.as<OpNode>() == 
nullptr &&
-                match.as<FunctionNode>() == nullptr && !EmbedConst(match, 
node->ref_)) {
-              inputs[match] = Var(
-                  "FunctionVar_" + std::to_string(graph_number_) + "_" + 
std::to_string(var_number),
-                  NullValue<Type>());
-              group.args.push_back(match);
-              params.push_back(inputs[match]);
-              var_number++;
+            for (auto input : match.as<TupleNode>()->fields) {
+              make_input(input);
             }
           }
         }
+      } else if (call && !call->args.defined()) {
+        if (node_map.count(node->ref_)) {
+          auto matches = node_map[node->ref_];
+          for (auto match : matches) {
+            for (auto input : match.as<CallNode>()->args) {
+              make_input(input);
+            }
+          }
+        }
+      } else if (node->inputs_.size() == 0) {
+        if (node_map.count(node->ref_)) {
+          auto matches = node_map[node->ref_];
+          for (auto match : matches) {
+            make_input(match);
+          }
+        }
       }
     }
 
@@ -898,6 +932,11 @@ class PatternPartitioner : protected MixedModeMutator {
  public:
   Expr Partition(const DFPattern& pattern, const Expr& pre, const Map<String, 
ObjectRef>& attrs,
                  PackedFunc check) {
+    if (pattern.as<FunctionPatternNode>()) {
+      LOG(WARNING) << "Partitioning a Function that isn't called doesn't make 
sense, skipping"
+                   << pattern;
+      return pre;
+    }
     auto grouper = PatternGrouper();
     groups_ = grouper.GroupMatches(pattern, pre);
     gid_assignments_ = grouper.GetGIDAssignments();
diff --git a/src/relay/ir/dataflow_pattern_functor.cc 
b/src/relay/ir/dataflow_pattern_functor.cc
index 828e867..290f72d 100644
--- a/src/relay/ir/dataflow_pattern_functor.cc
+++ b/src/relay/ir/dataflow_pattern_functor.cc
@@ -45,8 +45,10 @@ void DFPatternVisitor::VisitDFPattern_(const 
AttrPatternNode* op) { VisitDFPatte
 
 void DFPatternVisitor::VisitDFPattern_(const CallPatternNode* op) {
   VisitDFPattern(op->op);
-  for (auto arg : op->args) {
-    VisitDFPattern(arg);
+  if (op->args.defined()) {
+    for (auto arg : op->args) {
+      VisitDFPattern(arg);
+    }
   }
 }
 
@@ -63,8 +65,10 @@ void DFPatternVisitor::VisitDFPattern_(const 
DominatorPatternNode* op) {
 void DFPatternVisitor::VisitDFPattern_(const ExprPatternNode* op) {}
 
 void DFPatternVisitor::VisitDFPattern_(const FunctionPatternNode* op) {
-  for (auto param : op->params) {
-    VisitDFPattern(param);
+  if (op->params.defined()) {
+    for (auto param : op->params) {
+      VisitDFPattern(param);
+    }
   }
   VisitDFPattern(op->body);
 }
@@ -76,8 +80,10 @@ void DFPatternVisitor::VisitDFPattern_(const 
TupleGetItemPatternNode* op) {
 }
 
 void DFPatternVisitor::VisitDFPattern_(const TuplePatternNode* op) {
-  for (auto field : op->fields) {
-    VisitDFPattern(field);
+  if (op->fields.defined()) {
+    for (auto field : op->fields) {
+      VisitDFPattern(field);
+    }
   }
 }
 
diff --git a/src/relay/ir/indexed_graph.cc b/src/relay/ir/indexed_graph.cc
index 36789e6..e4d9585 100644
--- a/src/relay/ir/indexed_graph.cc
+++ b/src/relay/ir/indexed_graph.cc
@@ -242,8 +242,10 @@ IndexedGraph<DFPattern> CreateIndexedGraph(const 
DFPattern& pattern) {
 
     void VisitDFPattern_(const CallPatternNode* op, NodePtr parent) override {
       VisitDFPattern(op->op, graph_.node_map_[GetRef<DFPattern>(op)]);
-      for (auto arg : op->args) {
-        VisitDFPattern(arg, graph_.node_map_[GetRef<DFPattern>(op)]);
+      if (op->args.defined()) {
+        for (auto arg : op->args) {
+          VisitDFPattern(arg, graph_.node_map_[GetRef<DFPattern>(op)]);
+        }
       }
     }
 
@@ -262,8 +264,10 @@ IndexedGraph<DFPattern> CreateIndexedGraph(const 
DFPattern& pattern) {
     void VisitDFPattern_(const ExprPatternNode* op, NodePtr parent) override {}
 
     void VisitDFPattern_(const FunctionPatternNode* op, NodePtr parent) 
override {
-      for (auto param : op->params) {
-        VisitDFPattern(param, graph_.node_map_[GetRef<DFPattern>(op)]);
+      if (op->params.defined()) {
+        for (auto param : op->params) {
+          VisitDFPattern(param, graph_.node_map_[GetRef<DFPattern>(op)]);
+        }
       }
       VisitDFPattern(op->body, graph_.node_map_[GetRef<DFPattern>(op)]);
     }
@@ -277,8 +281,10 @@ IndexedGraph<DFPattern> CreateIndexedGraph(const 
DFPattern& pattern) {
     }
 
     void VisitDFPattern_(const TuplePatternNode* op, NodePtr parent) override {
-      for (auto field : op->fields) {
-        VisitDFPattern(field, graph_.node_map_[GetRef<DFPattern>(op)]);
+      if (op->fields.defined()) {
+        for (auto field : op->fields) {
+          VisitDFPattern(field, graph_.node_map_[GetRef<DFPattern>(op)]);
+        }
       }
     }
 
diff --git a/tests/python/relay/test_dataflow_pattern.py 
b/tests/python/relay/test_dataflow_pattern.py
index a8e4b65..8e2c74a 100644
--- a/tests/python/relay/test_dataflow_pattern.py
+++ b/tests/python/relay/test_dataflow_pattern.py
@@ -196,6 +196,11 @@ def test_match_call():
     add_pattern = is_op("add")(wildcard(), wildcard())
     assert add_pattern.match(x + y)
 
+    # Match call with any number of inputs
+    call_pattern = wildcard()(None)
+    assert call_pattern.match(relay.op.nn.relu(x))
+    assert call_pattern.match(relay.op.add(x, y))
+
 
 def test_no_match_call():
     x = relay.var("x")
@@ -212,6 +217,11 @@ def test_match_func():
     func_pattern = FunctionPattern([wc1, wc2], wc1 + wc2)
     assert func_pattern.match(relay.Function([x, y], x + y))
 
+    # Match Function with any number of inputs
+    func_pattern = FunctionPattern(None, wildcard())
+    assert func_pattern.match(relay.Function([x], x))
+    assert func_pattern.match(relay.Function([x, y], x + y))
+
 
 def test_no_match_func():
     x = relay.var("x")
@@ -369,6 +379,13 @@ def test_match_tuple():
     assert 
tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple((x, y, 
z)), 1))
     assert 
tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple((x, y, 
z)), 2))
 
+    # Match tuple with any inputs
+    tuple_pattern = is_tuple(None)
+    concat_pattern = is_op("concatenate")(tuple_pattern)
+    assert concat_pattern.match(relay.op.concatenate(relay.expr.Tuple((x,)), 
axis=0))
+    assert concat_pattern.match(relay.op.concatenate(relay.expr.Tuple((x, y)), 
axis=0))
+    assert concat_pattern.match(relay.op.concatenate(relay.expr.Tuple((x, y, 
z)), axis=0))
+
 
 def test_no_match_tuple():
     x = relay.var("x")
@@ -1375,6 +1392,63 @@ def test_partition_overused():
     assert pattern.partition(out) == out
 
 
+def test_partition_fuzzy_tuple():
+    x = relay.var("x")
+    y = relay.var("y")
+    z = x + y
+    tuple_pattern = is_tuple(None)
+    concat_pattern = is_op("concatenate")(tuple_pattern)
+
+    xp = relay.var("xp")
+    yp = relay.var("yp")
+    zp = relay.var("zp")
+
+    def create_func(args, body):
+        return relay.Function(args, body).with_attr("PartitionedFromPattern", 
"Tuple_concatenate_")
+
+    def concat(*args):
+        return relay.op.concatenate(relay.expr.Tuple(args), axis=0)
+
+    one = concat_pattern.partition(concat(x))
+    assert tvm.ir.structural_equal(one, create_func([xp], concat(xp))(x))
+    two = concat_pattern.partition(concat(x, y))
+    assert tvm.ir.structural_equal(two, create_func([xp, yp], concat(xp, 
yp))(x, y))
+    three = concat_pattern.partition(concat(x, y, z))
+    assert tvm.ir.structural_equal(three, create_func([xp, yp, zp], concat(xp, 
yp, zp))(x, y, z))
+
+
+def test_partition_fuzzy_function_args():
+
+    func_pattern = FunctionPattern(None, wildcard() + wildcard())(None) + 
wildcard()
+    x = relay.var("x")
+    y = relay.var("y")
+    z = relay.var("z")
+    b = relay.var("b")
+    xp = relay.var("xp")
+    yp = relay.var("yp")
+    zp = relay.var("zp")
+
+    def create_func(call):
+        N = len(call.op.params)
+        new_params = [relay.var(str(i)) for i in range(N + 1)]
+        label = "add_FunctionCall_add_"
+        if N == 3:
+            label = "add_" + label
+        return relay.Function(
+            new_params, relay.Call(call.op, (new_params[0:-1])) + 
new_params[-1]
+        ).with_attr("PartitionedFromPattern", label)(*([x, y, z][0:N] + [b]))
+
+    f1 = relay.Function([xp], xp + xp)(x)
+    one = func_pattern.partition(f1 + b)
+    assert tvm.ir.structural_equal(one, create_func(f1))
+    f2 = relay.Function([xp, yp], xp + yp)(x, y)
+    two = func_pattern.partition(f2 + b)
+    assert tvm.ir.structural_equal(two, create_func(f2))
+    f3 = relay.Function([xp, yp, zp], xp + yp + zp)(x, y, z)
+    three = func_pattern.partition(f3 + b)
+    assert tvm.ir.structural_equal(three, create_func(f3))
+
+
 def test_partition_check():
     pattern = is_op("nn.relu")(is_op("nn.conv2d")(is_var("input"), wildcard()))
 
@@ -1529,10 +1603,6 @@ def test_rewrite_function_with_fuzzy_body():
     assert tvm.ir.structural_equal(x + w, x + w)
 
 
[email protected](
-    """TODO(mbrookhart): The current partitioner can't properly handle 
-                       the partitioned inputs on the fuzzy body"""
-)
 def test_partition_function_with_fuzzy_body():
     """
     Allow Rewriting a function with a fuzzy body via dominator analysis
@@ -1560,7 +1630,7 @@ def test_partition_function_with_fuzzy_body():
     w2 = relay.var("w2")
     b2 = relay.var("b2")
     func2 = relay.Function([x2, w2, b2], func(x2, w2) + b2).with_attr(
-        "PartitionedFromPattern", "FunctionCall_add_"
+        "PartitionedFromPattern", "nn.conv2d_FunctionCall_add_"
     )
     expr2 = func2(x, w, b) + b
     assert tvm.ir.structural_equal(pattern.partition(expr), expr2)

Reply via email to