This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 9725144bef [Unity][Transform] Canonicalize and use CSE between pattern matches (#15904)
9725144bef is described below

commit 9725144bef4c544e453fa57e47cabece3c6906d2
Author: Eric Lunderberg <lunderb...@users.noreply.github.com>
AuthorDate: Fri Oct 13 07:04:31 2023 -0500

    [Unity][Transform] Canonicalize and use CSE between pattern matches (#15904)
    
    * [Unity][Transform] Canonicalize and use CSE between pattern matches
    
    The `PatternRewriter` is intended to iterate until no matching
    patterns remain.  Prior to this commit, this only involved repeating
    the pattern match rewrite rules.  However, intermediate results
    produced by pattern replacement could cause the iterative pattern
    matching to terminate early.
    
    * If two rewrite rules each introduce the same intermediate, there
      will exist two copies of that intermediate, which can prevent
      `only_used_by` patterns from matching.  Applying
      `EliminateCommonSubexpr` allows the pattern matching to continue.
    
    * Applying a rewrite rule may result in dangling intermediates that
      are no longer used.  These dangling intermediates may prevent the
      next application of a rewrite rule that uses the `only_used_by`
      constraint.  Applying `RemoveAllUnused` allows the pattern matching
      to continue.
    
    * A rewrite rule that returns a `relax::Var` or `relax::TupleGetItem`
      as the replacement introduces trivial var-to-var rebindings, which
      are not tracked by `PatternRewriter`.  Applying
      `CanonicalizeBindings` allows the pattern matching to continue.
    
    While this could be fixed externally by repeatedly applying
    `rewrite_call`, this would require re-inspecting the entire function,
    and not just the dataflow block in which the replacement occurred.
    
    * Fix tests for removing redundant reshapes
    
    * Fixed failing unit tests, along with an edge case in CSE
---
 .../relax/transform/remove_redundant_reshape.py    |  15 +-
 src/relax/ir/dataflow_matcher.cc                   |  97 +++++----
 src/relax/transform/eliminate_common_subexpr.cc    |  22 +-
 tests/python/relax/test_dataflow_pattern.py        | 239 +++++++++++++++++++--
 .../python/relax/test_optimize_layout_transform.py |  25 +--
 .../python/relax/test_remove_redundant_reshape.py  |  18 +-
 6 files changed, 333 insertions(+), 83 deletions(-)

diff --git a/python/tvm/relax/transform/remove_redundant_reshape.py b/python/tvm/relax/transform/remove_redundant_reshape.py
index 2274f8e5da..a48923df78 100644
--- a/python/tvm/relax/transform/remove_redundant_reshape.py
+++ b/python/tvm/relax/transform/remove_redundant_reshape.py
@@ -66,13 +66,18 @@ class RemoveRedundantReshape:
                 continue
 
             def rewriter(expr, matches):
-                args = matches[self.pattern]
+                arg = matches[self.input1]
+
                 if self.repeated_reshape in matches:
-                    return relax.op.reshape(matches[self.input1], args.args[1])
+                    output_shape = matches[self.repeated_reshape].args[1]
+                    return relax.op.reshape(arg, output_shape)
+
                 elif self.no_op_reshape in matches:
-                    if args.args[0].struct_info.shape:
-                        if structural_equal(args.args[0].struct_info.shape, args.args[1]):
-                            return args.args[0]
+                    output_shape = matches[self.no_op_reshape].args[1]
+                    if arg.struct_info.shape and structural_equal(
+                        arg.struct_info.shape, output_shape
+                    ):
+                        return arg
                 return expr
 
             updated_func = rewrite_call(self.pattern, rewriter, funct)
diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc
index e85e3c4d51..d1edb945ba 100644
--- a/src/relax/ir/dataflow_matcher.cc
+++ b/src/relax/ir/dataflow_matcher.cc
@@ -42,6 +42,7 @@
 #include <utility>
 #include <vector>
 
+#include "../transform/utils.h"
 #include "dataflow_matcher_impl.h"
 
 namespace tvm {
@@ -937,57 +938,76 @@ class PatternRewriter : ExprMutator {
     return Downcast<Function>(RemoveAllUnused(rewriter.VisitExpr(f)));
   }
 
-  void VisitBinding_(const VarBindingNode* binding) final {
-    bindings_.Set(binding->var, binding->value);
-    ExprMutator::VisitBinding_(binding);
-    if (auto it = memo_.find(binding->value.get()); it != memo_.end()) {
-      // We need to update the binding to pass to ExtractMatchedExpr, so that the rewritten
-      // expression can be subject to further pattern matchings.
-      bindings_.Set(binding->var, it->second);
+  Expr VisitExpr_(const SeqExprNode* seq) override {
+    if (ctx_) {
+      return ExprMutator::VisitExpr_(seq);
     }
-  }
 
-  Expr VisitExpr(const Expr& expr) final {
-    auto node = ExprMutator::VisitExpr(expr);
-    if (pattern_) {
-      if (auto matches_opt = ExtractMatchedExpr(pattern_.value(), node, bindings_)) {
-        Expr rewritten_expr = rewriter_func_(node, matches_opt.value());
-        if (!rewritten_expr.same_as(node)) {
-          rewritten_expr = builder_->Normalize(rewritten_expr);
-
-          // If the rewriter returns a variable (e.g. when rewriting
-          // from `R.add(x, R.const(0.0))` to `x`), the variable
-          // should be dereferenced to avoid trivial `var_2 = var_1`
-          // bindings.  This lookup is done using the builder_ instead
-          // of the bindings_, as the previous `builder_->Normalize`
-          // call may have introduced variable bindings.
-          if (auto opt_var = rewritten_expr.as<Var>()) {
-            if (auto binding = builder_->LookupBinding(opt_var.value())) {
-              rewritten_expr = binding.value();
-            }
-          }
-          memo_[expr.get()] = rewritten_expr;
-          return rewritten_expr;
+    auto cache = bindings_;
+    SeqExpr prev = GetRef<SeqExpr>(seq);
+
+    StructuralEqual struct_equal;
+
+    while (true) {
+      SeqExpr next = Downcast<SeqExpr>(builder_->Normalize(ExprMutator::VisitExpr_(prev.get())));
+      if (struct_equal(prev, next)) {
+        return std::move(next);
+      }
+
+      // Canonicalization may result in two previously-different
+      // expressions being recognized as identical.  Elimination of
+      // common subexpressions may result in trivial var-to-var
+      // bindings that can be canonicalized.  Therefore, iterate the
+      // simplification steps until converged.
+      while (true) {
+        auto start_of_loop = next;
+        next = Downcast<SeqExpr>(CanonicalizeBindings(next));
+        next = Downcast<SeqExpr>(EliminateCommonSubexpr(next));
+        next = Downcast<SeqExpr>(RemoveAllUnused(next));
+        if (struct_equal(start_of_loop, next)) {
+          break;
         }
       }
+
+      if (struct_equal(prev, next)) {
+        return std::move(next);
+      }
+
+      // Reset all knowledge of bindings that were collected from
+      // this SeqExpr.  The collected bindings are only valid after
+      // the point where they were collected, and we are repeating
+      // the mutation of this SeqExpr.
+      bindings_ = cache;
+      prev = next;
     }
-    return node;
   }
 
-  BindingBlock VisitBindingBlock_(const DataflowBlockNode* block_node) final {
+  BindingBlock VisitBindingBlock_(const DataflowBlockNode* block_node) override {
     if (ctx_) {
       return RewriteDataflowBlockFixedPoint(GetRef<DataflowBlock>(block_node));
+    } else {
+      return ExprMutator::VisitBindingBlock_(block_node);
     }
+  }
 
-    DataflowBlock prev = GetRef<DataflowBlock>(block_node);
-    while (true) {
-      DataflowBlock next = Downcast<DataflowBlock>(ExprMutator::VisitBindingBlock_(prev.get()));
-      if (StructuralEqual()(prev, next)) {
-        return std::move(next);
-      } else {
-        prev = next;
+  void VisitBinding_(const VarBindingNode* binding) override {
+    auto expr = VisitExpr(binding->value);
+    bindings_.Set(binding->var, expr);
+    ReEmitBinding(binding, expr);
+  }
+
+  Expr VisitExpr(const Expr& expr) override {
+    auto node = ExprMutator::VisitExpr(expr);
+
+    if (pattern_) {
+      if (auto matches_opt = ExtractMatchedExpr(pattern_.value(), node, bindings_)) {
+        Expr rewritten_expr = rewriter_func_(node, matches_opt.value());
+        if (!rewritten_expr.same_as(node)) {
+          return builder_->Normalize(rewritten_expr);
+        }
       }
     }
+    return node;
   }
 
  private:
@@ -1076,7 +1096,6 @@ class PatternRewriter : ExprMutator {
   PackedFunc rewriter_func_;
   std::unordered_set<const VarNode*> params_;
   Map<Var, Expr> bindings_;
-  std::unordered_map<const Object*, Expr> memo_;
 };
 
 Function RewriteBindings(const PatternContext& ctx, PackedFunc rewriter, Function f) {
diff --git a/src/relax/transform/eliminate_common_subexpr.cc b/src/relax/transform/eliminate_common_subexpr.cc
index 842470c463..2addb60697 100644
--- a/src/relax/transform/eliminate_common_subexpr.cc
+++ b/src/relax/transform/eliminate_common_subexpr.cc
@@ -105,7 +105,27 @@ class SubexprCounter : public ExprVisitor {
         count_map_[e] = count + 1;
       }
     }
-    ExprVisitor::VisitExpr(e);
+
+    // Only visit the interior of objects that we might still keep
+    // around.  Otherwise, double-counting these would lead to extra
+    // variable bindings.
+    //
+    // Before:
+    //     y = f(a+b)
+    //     z = f(a+b)
+    //
+    // Expected:
+    //     y = f(a+b)  // De-duped from (y==z)
+    //     z = y
+    //
+    // Erroneous output:
+    //     c = a+b    // Incorrect, a+b only has a single usage.
+    //     y = f(c)   // De-duped from (y==z)
+    //     z = y
+    //
+    if (auto it = count_map_.find(e); it == count_map_.end() || it->second < 2) {
+      ExprVisitor::VisitExpr(e);
+    }
   }
 
   // do not visit inner functions: we will do CSE within those
diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py
index 7b68655a2f..49b7d11a80 100644
--- a/tests/python/relax/test_dataflow_pattern.py
+++ b/tests/python/relax/test_dataflow_pattern.py
@@ -15,6 +15,9 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import functools
+import math
+
 import pytest
 
 import tvm.testing
@@ -1372,23 +1375,51 @@ def test_repeated_pattern_match():
     tvm.ir.assert_structural_equal(after, expected)
 
 
-def test_rewrite_without_trivial_binding():
-    """rewrite_call should avoid producing trivial "y = x" bindings"""
+bind_to_dataflow_var = tvm.testing.parameter(
+    by_dict={"var-to-var": False, "var-to-dataflow-var": True}
+)
 
-    @R.function(private=True)
-    def before(x: R.Tensor((1024,))):
-        with R.dataflow():
+
+def test_rewrite_without_trivial_binding(bind_to_dataflow_var):
+    """rewrite_call should avoid producing trivial "y = x" bindings
+
+    This may not be possible in all cases, and follows the same
+    rules as CanonicalizeBindings.  For example, a `relax.Var` that
+    is bound to a `relax.DataflowVar` may not be removed, to ensure
+    that the `relax.DataflowVar` is only used within a
+    `DataflowBlock`.
+    """
+
+    if bind_to_dataflow_var:
+
+        @R.function(private=True)
+        def before(x: R.Tensor((1024,))):
+            with R.dataflow():
+                a = R.add(x, x)
+                b = R.reshape(a, (1024,))
+                R.output(b)
+            return b
+
+        @R.function(private=True)
+        def expected(x: R.Tensor((1024,))):
+            with R.dataflow():
+                a = R.add(x, x)
+                b = a
+                R.output(b)
+            return b
+
+    else:
+
+        @R.function(private=True)
+        def before(x: R.Tensor((1024,))):
             a = R.add(x, x)
             b = R.reshape(a, (1024,))
-            R.output(b)
-        return b
+            return b
 
-    @R.function(private=True)
-    def expected(x: R.Tensor((1024,))):
-        with R.dataflow():
+        @R.function(private=True)
+        def expected(x: R.Tensor((1024,))):
             a = R.add(x, x)
-            R.output(a)
-        return a
+            return a
 
     pattern_arg = wildcard()
     pattern_shape_expr = wildcard()
@@ -1490,5 +1521,189 @@ def test_same_shape_pattern(same_shape_func_type):
         assert match is None
 
 
+def test_iterative_rewrite_without_trivial_binding():
+    """Avoid introducing common sub-expressions
+
+    Pattern replacement may produce the same intermediate, which
+    should appear only once in the final result.
+    """
+
+    @R.function(private=True)
+    def before(x: R.Tensor((1024,))):
+        with R.dataflow():
+            a = R.strided_slice(x, [0], [0], [512], [1])
+            b = R.strided_slice(x, [0], [512], [1024], [1])
+            c = R.add(a, b)
+            R.output(c)
+        return c
+
+    @R.function(private=True)
+    def expected(x: R.Tensor((1024,))):
+        with R.dataflow():
+            x_split = R.split(x, 2)
+            a = x_split[0]
+            b = x_split[1]
+            c = R.add(a, b)
+            R.output(c)
+        return c
+
+    pattern_arg = wildcard()
+    pattern = is_op("relax.strided_slice")(pattern_arg).has_attr(
+        {
+            "axes": [0],
+            "strides": [T.int64(1)],
+        }
+    )
+
+    def rewriter(expr, matches):
+        arg = matches[pattern_arg]
+        strided_slice = matches[pattern]
+
+        if arg.struct_info.shape is None:
+            return expr
+
+        size = arg.struct_info.shape[0]
+        begin = strided_slice.attrs.begin[0]
+        end = strided_slice.attrs.end[0]
+        if (
+            isinstance(size, tir.IntImm)
+            and isinstance(begin, tir.IntImm)
+            and isinstance(end, tir.IntImm)
+        ):
+            size = size.value
+            begin = begin.value
+            end = end.value
+        else:
+            return expr
+
+        gcd = functools.reduce(math.gcd, [begin, end, size])
+        if (end - begin) // gcd == 1:
+            return rx.op.split(arg, size // gcd)[begin // gcd]
+
+        return expr
+
+    after = rewrite_call(pattern, rewriter, before)
+    tvm.ir.assert_structural_equal(after, expected)
+
+
+def test_iterative_rewrite_with_removed_intermediates():
+    """Pattern replacement may require canonicalization
+
+    A pattern may replace a tuple returned by a function with a tuple
+    whose contents are known by Relax.  In that case, canonicalization
+    is required to unwrap the TupleGetItem instances into the known
+    contents.
+
+    This test case shows the intermediate results produced in the
+    process of pattern-matching.
+    """
+
+    @R.function(private=True)
+    def before(a: R.Tensor((1024,)), b: R.Tensor((1024,))):
+        with R.dataflow():
+            c = R.concat([a, b])
+            d = R.split(c, 2)
+            e = d[0]
+            f = d[1]
+            g = R.add(a, e)
+            h = R.add(f, g)
+            R.output(h)
+        return h
+
+    # First pattern rewrite.  The concat/split can be unwrapped, so
+    # `d` is rewritten from `R.split(c, 2)` into `(a, b)`.
+    #
+    # @R.function(private=True)
+    # def intermediate(a: R.Tensor((1024,)), b: R.Tensor((1024,))):
+    #     with R.dataflow():
+    #         c = R.concat([a, b])
+    #         d = (a,b)
+    #         e = d[0]
+    #         f = d[1]
+    #         g = R.add(a, e)
+    #         h = R.add(f, g)
+    #         R.output(h)
+
+    # Canonicalization step.  Because `d` is known to be `(a,b)`,
+    # canonicalization can rewrite `d[0]` into `a` and `d[1]` into
+    # `b`.
+    #
+    # @R.function(private=True)
+    # def intermediate(a: R.Tensor((1024,)), b: R.Tensor((1024,))):
+    #     with R.dataflow():
+    #         c = R.concat([a, b])
+    #         d = (a,b)
+    #         e = a
+    #         f = b
+    #         g = R.add(a, a)
+    #         h = R.add(b, g)
+    #         R.output(h)
+
+    # Dead-code-elimination step.  This technically isn't required
+    # until the pattern matching has converged, but performing it now
+    # prevents testing for matches on dead code.
+    #
+    # @R.function(private=True)
+    # def intermediate(a: R.Tensor((1024,)), b: R.Tensor((1024,))):
+    #     with R.dataflow():
+    #         g = R.add(a, a)
+    #         h = R.add(b, g)
+    #         R.output(h)
+
+    # Second pattern-matching step.  Now, the `R.add(a,a)` can match
+    # the other option in our pattern, and be rewritten as
+    # `R.multiply(a,R.const(2))`.
+    #
+    # @R.function(private=True)
+    # def intermediate(a: R.Tensor((1024,)), b: R.Tensor((1024,))):
+    #     with R.dataflow():
+    #         g = R.multiply(a, R.const(2))
+    #         h = R.add(b, g)
+    #         R.output(h)
+
+    # Canonicalization and dead-code-elimination are applied again,
+    # but have no effect this time.
+
+    @R.function(private=True)
+    def expected(a: R.Tensor((1024,)), b: R.Tensor((1024,))):
+        with R.dataflow():
+            g = R.multiply(a, R.const(2))
+            h = R.add(b, g)
+            R.output(h)
+        return h
+
+    pat_args = wildcard()
+
+    op_concat = is_op("relax.concat")
+    pat_concat = op_concat(pat_args).has_attr({"axis": 0})
+
+    op_split = is_op("relax.split")
+    pat_split = op_split(pat_concat).has_attr({"axis": 0, "indices_or_sections": T.int64(2)})
+
+    pat_unwrap_concat_split = pat_split
+
+    pat_arg = wildcard()
+    op_add = is_op("relax.add")
+    pat_add_self = op_add(pat_arg, pat_arg)
+
+    pattern = pat_unwrap_concat_split | pat_add_self
+
+    def rewriter(expr, matches):
+        if pat_unwrap_concat_split in matches:
+            args = matches[pat_args]
+
+            if len(args) == 2 and tvm.ir.structural_equal(args[0].struct_info, args[1].struct_info):
+                return args
+
+        elif pat_add_self in matches:
+            arg = matches[pat_arg]
+            return arg * rx.const(2)
+
+        return expr
+
+    after = rewrite_call(pattern, rewriter, before)
+    tvm.ir.assert_structural_equal(expected, after)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_optimize_layout_transform.py b/tests/python/relax/test_optimize_layout_transform.py
index 56c11984ee..08c9e31107 100644
--- a/tests/python/relax/test_optimize_layout_transform.py
+++ b/tests/python/relax/test_optimize_layout_transform.py
@@ -26,19 +26,15 @@ from tvm.script import ir as I, tir as T, relax as R
 
 
 def _run_pass_compare_output(Before, Expected):
-    fused_mod = OptimizeLayoutTransform()(Before)
-    if not relax.analysis.well_formed(fused_mod):
-        print("IRModule is not well-formed")
+    After = tvm.ir.transform.Sequential(
+        [
+            OptimizeLayoutTransform(),
+            DeadCodeElimination(),
+            FuseTIR(),
+        ]
+    )(Before)
 
-    fused_mod = DeadCodeElimination()(fused_mod)
-    if not relax.analysis.well_formed(fused_mod):
-        print("IRModule is not well-formed")
-
-    fused_mod = FuseTIR()(fused_mod)
-    if not relax.analysis.well_formed(fused_mod):
-        print("IRModule is not well-formed")
-
-    tvm.ir.assert_structural_equal(Expected, fused_mod)
+    tvm.ir.assert_structural_equal(Expected, After)
 
 
 def test_optimize_transform_layout_pass_one_arg():
@@ -129,12 +125,9 @@ def test_optimize_transform_layout_pass_one_arg():
                     (lv, lv1),
                     out_sinfo=R.Tensor((4, 4), dtype="float32"),
                 )
-                lv4: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
-                    y, index_map=lambda i: (i // 4, i % 4), pad_value=None
-                )
                 lv5 = R.call_tir(
                     Expected.relax_add_replacement,
-                    (lv4, lv2),
+                    (lv1, lv2),
                     out_sinfo=R.Tensor((4, 4), dtype="float32"),
                 )
                 lv2_1: R.Tensor((16,), dtype="float32") = R.layout_transform(
diff --git a/tests/python/relax/test_remove_redundant_reshape.py b/tests/python/relax/test_remove_redundant_reshape.py
index 806b563cb8..11e8c87cf1 100644
--- a/tests/python/relax/test_remove_redundant_reshape.py
+++ b/tests/python/relax/test_remove_redundant_reshape.py
@@ -41,9 +41,9 @@ def test_remove_redundant_reshape_pass_one_arg():
             with R.dataflow():
                 lv: R.Tensor((1, 1001), dtype="float16") = R.reshape(x, R.shape([1, 1001]))
                 lv1: R.Tensor((1, 1001), dtype="float16") = R.reshape(lv, R.shape([1, 1001]))
-                lv2: R.Tensor((1, 1001), dtype="float16") = R.reshape(lv1, R.shape([1, 1001]))
-                R.output(lv2)
-            return lv2
+                gv: R.Tensor((1, 1001), dtype="float16") = R.reshape(lv1, R.shape([1, 1001]))
+                R.output(gv)
+            return gv
 
     @I.ir_module
     class Expected:
@@ -52,9 +52,10 @@ def test_remove_redundant_reshape_pass_one_arg():
             x: R.Tensor((1, 1001, 1, 1), dtype="float16")
         ) -> R.Tensor((1, 1001), dtype="float16"):
             with R.dataflow():
-                lv1: R.Tensor((1, 1001), dtype="float16") = R.reshape(x, R.shape([1, 1001]))
-                R.output(lv1)
-            return lv1
+                lv: R.Tensor((1, 1001), dtype="float16") = R.reshape(x, R.shape([1, 1001]))
+                gv: R.Tensor((1, 1001), dtype="float16") = lv
+                R.output(gv)
+            return gv
 
     _run_pass_compare_output(Before, Expected)
 
@@ -106,10 +107,7 @@ def test_remove_redundant_reshape_pass_three_arg():
         def main(
             x: R.Tensor((1, 1001, 1, 1), dtype="float16")
         ) -> R.Tensor((1, 1001, 1, 1), dtype="float16"):
-            with R.dataflow():
-                lv: R.Tensor((1, 1001, 1, 1), dtype="float16") = x
-                R.output(lv)
-            return lv
+            return x
 
     _run_pass_compare_output(Before, Expected)
 

Reply via email to