This is an automated email from the ASF dual-hosted git repository.

sslyu pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 233ac7aee0 [Unity] Avoid trivial var-to-var bindings in CanonicalizeBindings (#15840)
233ac7aee0 is described below

commit 233ac7aee09e65e42c18596b9877de8c0476b148
Author: Eric Lunderberg <[email protected]>
AuthorDate: Mon Oct 2 17:09:05 2023 -0500

    [Unity] Avoid trivial var-to-var bindings in CanonicalizeBindings (#15840)
    
    Prior to this commit, the `relax.transform.CanonicalizeBindings`
    transform would detect trivial bindings `var_y = var_x` and replace
    later uses of `var_y` with `var_x`.  However, the trivial binding
    `var_y = var_x` itself would be left in the canonicalized function.
    
    This commit updates the `CanonicalizeBindings` transform to remove
    trivial bindings.  This is not intended as a full dead-code
    elimination, as that is better handled as a separate pass, but is
    instead intended to avoid introduction of dead code during
    canonicalization.
---
 src/relax/transform/canonicalize_bindings.cc       | 49 ++++++++++------------
 .../relax/test_transform_canonicalize_bindings.py  | 41 ++----------------
 2 files changed, 26 insertions(+), 64 deletions(-)

diff --git a/src/relax/transform/canonicalize_bindings.cc b/src/relax/transform/canonicalize_bindings.cc
index ea5a612e1a..d8e3a9ba98 100644
--- a/src/relax/transform/canonicalize_bindings.cc
+++ b/src/relax/transform/canonicalize_bindings.cc
@@ -38,16 +38,6 @@ class BindingCanonicalizer : public ExprMutator {
 
   using ExprMutator::VisitExpr_;
 
-  Expr VisitExpr_(const VarNode* op) override {
-    // remap first
-    Var v = Downcast<Var>(ExprMutator::VisitExpr_(op));
-    if (!CanCanonicalizeVar(v)) {
-      return Downcast<Expr>(v);
-    }
-    // visit again in case we need to do a substitution in the value
-    return ExprMutator::VisitExpr_(LookupBinding(v).as<VarNode>());
-  }
-
   Expr VisitExpr_(const TupleGetItemNode* tuple_get_item) override {
     if (auto tuple_var = tuple_get_item->tuple.as<Var>()) {
       if (auto tuple_value = LookupBinding(tuple_var.value())) {
@@ -71,12 +61,14 @@ class BindingCanonicalizer : public ExprMutator {
     Expr new_value = this->VisitExpr(binding->value);
     Var new_var = this->VisitVarDef(binding->var);
 
-    if (new_var.same_as(binding->var) && new_value.same_as(binding->value)) {
+    if (auto opt_var = new_value.as<Var>();
+        opt_var && CanCanonicalizeVar(new_var, opt_var.value())) {
+      var_remap_[new_var->vid] = opt_var.value();
+    } else if (new_var.same_as(binding->var) && 
new_value.same_as(binding->value)) {
       this->builder_->EmitNormalized(GetRef<VarBinding>(binding));
-      return;
+    } else {
+      this->builder_->EmitNormalized(VarBinding(new_var, new_value));
     }
-
-    this->builder_->EmitNormalized(VarBinding(new_var, new_value));
   }
 
   void VisitBinding_(const MatchCastNode* binding) override {
@@ -84,9 +76,19 @@ class BindingCanonicalizer : public ExprMutator {
     // we can canonicalize to a var binding
     Expr new_value = this->VisitExpr(binding->value);
 
-    // if the LHS and RHS have the same struct info, we canonicalize to a var 
binding instead
-    if (StructuralEqual()(binding->struct_info, GetStructInfo(new_value))) {
-      builder_->EmitNormalized(VarBinding(binding->var, new_value));
+    bool has_same_struct_info = StructuralEqual()(binding->struct_info, 
GetStructInfo(new_value));
+
+    if (has_same_struct_info) {
+      if (auto parent = new_value.as<Var>();
+          parent && CanCanonicalizeVar(binding->var, parent.value())) {
+        // LHS and RHS have the same struct info, and occur in a
+        // context where the RHS can replace the LHS.
+        var_remap_[binding->var->vid] = parent.value();
+      } else {
+        // LHS and RHS have the same struct info, but the RHS is not a
+        // drop-in replacement for the LHS.
+        builder_->EmitNormalized(VarBinding(binding->var, new_value));
+      }
     } else if (new_value.same_as(binding->value)) {
       builder_->EmitNormalized(GetRef<MatchCast>(binding));
     } else {
@@ -104,24 +106,17 @@ class BindingCanonicalizer : public ExprMutator {
     return !(both_present || neither_present) || (both_present && 
!check_eq(obj1, obj2));
   }
 
-  bool CanCanonicalizeVar(Var v) {
-    Optional<Expr> value = LookupBinding(v);
-    // can replace only if the value is also a var
-    if (!value || !value.as<VarNode>()) {
-      return false;
-    }
-    Var parent_var = Downcast<Var>(value);
-
+  bool CanCanonicalizeVar(Var var, Var parent_var) {
     // Cases when we conservatively do not unify:
     // 1. checked_type_ or shape_ of the child differs from that of the parent
     //    In this case, we could be overriding user annotations.
     // 2. If the child is a Var and the parent is a DataflowVar.
     //    That could result in a DataflowVar leaving the current DataflowBlock.
-    bool annotations_differ = AnnotationsDiffer(v->struct_info_, 
parent_var->struct_info_,
+    bool annotations_differ = AnnotationsDiffer(var->struct_info_, 
parent_var->struct_info_,
                                                 [&](const ObjectRef& lhs, 
const ObjectRef& rhs) {
                                                   return 
tvm::StructuralEqual()(lhs, rhs);
                                                 });
-    bool var_to_dataflow = (!v.as<DataflowVarNode>() && 
parent_var.as<DataflowVarNode>());
+    bool var_to_dataflow = (!var.as<DataflowVarNode>() && 
parent_var.as<DataflowVarNode>());
     return !annotations_differ && !var_to_dataflow;
   }
 };
diff --git a/tests/python/relax/test_transform_canonicalize_bindings.py b/tests/python/relax/test_transform_canonicalize_bindings.py
index 91396ccb13..52bf5a6e43 100644
--- a/tests/python/relax/test_transform_canonicalize_bindings.py
+++ b/tests/python/relax/test_transform_canonicalize_bindings.py
@@ -36,17 +36,10 @@ def test_simple_assignments():
             o = p
             return o
 
-    # a little annoying to have these unused bindings around
-    # but they can be eliminated in a separate pass
     @tvm.script.ir_module
     class Expected:
         @R.function
         def main(x: R.Tensor):
-            y = x
-            z = x
-            q = x
-            p = x
-            o = x
             return x
 
     new_mod = relax.transform.CanonicalizeBindings()(TestChainAssignments)
@@ -68,19 +61,16 @@ def test_dataflow_block():
                 R.output(n)
             return n
 
-    # a little annoying to have these unused bindings around
-    # but they can be eliminated in a separate pass
     @tvm.script.ir_module
     class Expected:
         @R.function
         def main(x: R.Tensor):
             with R.dataflow():
                 y = R.const(1)
-                z = y
-                o = y
-                p = y
-                m = y
-                # we can't get rid of n because it leaves the block
+                # We can't get rid of n because it leaves the block.
+                # CanonicalizeBindings does not do a full dead-code
+                # elimination, and only does local analysis of trivial
+                # bindings that it may produce.
                 n = y
                 R.output(n)
             return n
@@ -108,15 +98,6 @@ def test_assign_to_output_indataflow_block():
     class Expected:
         @R.function
         def main(x: R.Tensor):
-            with R.dataflow():
-                y = x
-                z = x
-                o = x
-                p = x
-                m = x
-                # we can't get rid of n because it leaves the block
-                n = x
-                R.output(n)
             return x
 
     new_mod = relax.transform.CanonicalizeBindings()(TestDataflowAssignments)
@@ -137,8 +118,6 @@ def test_ops():
     class Expected:
         @R.function
         def main(x: R.Tensor, y: R.Tensor):
-            w = y
-            q = x
             z = R.add(y, x)
             return R.add(x, z)
 
@@ -161,7 +140,6 @@ def test_casting():
     class Expected:
         @R.function
         def main(x: R.Tensor) -> R.Object:
-            y = x
             # Cannot unify because the cast indicates user intent
             z: R.Object = x
             return z
@@ -185,11 +163,9 @@ def test_match_cast():
     class Expected:
         @R.function
         def main(x: R.Tensor):
-            q = x
             # can't get rid of z because its shape_ is different from x's
             m, n = T.int64(), T.int64()
             z = R.match_cast(x, R.Tensor((m, n)))
-            w = z
             return z
 
     new_mod = relax.transform.CanonicalizeBindings()(TestMatchCast)
@@ -213,11 +189,6 @@ def test_same_shape():
     class Expected:
         @R.function
         def main(x: R.Tensor(("m", "n"), "float32")):
-            m, n = T.int64(), T.int64()
-            y = x
-            # canonicalized into a var binding
-            z = x
-            w = x
             q = R.add(x, x)
             return R.add(q, x)
 
@@ -242,10 +213,8 @@ def test_change_shape():
     class Expected:
         @R.function
         def main(x: R.Tensor(("m", "n"))):
-            y = x
             o, p = T.int64(), T.int64()
             z = R.match_cast(x, R.Tensor((o, p)))
-            w = z
             # the shape_ field on q will need to be updated
             q = R.add(z, x)
             return R.add(q, z)
@@ -270,8 +239,6 @@ def test_unwrap_tuple():
         @R.function
         def main(x: R.Tensor, y: R.Tensor):
             tuple_var = (x, y)
-            w = x
-            q = y
             z = R.add(x, y)
             return R.add(y, z)
 

Reply via email to