This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 2d53e6ac63 [Unity][Transform] Handle replacement at both var binding and usage (#16367)
2d53e6ac63 is described below

commit 2d53e6ac635b28ec64673cd69ac6995368b08cb3
Author: Eric Lunderberg <[email protected]>
AuthorDate: Mon Jan 8 19:31:14 2024 -0600

    [Unity][Transform] Handle replacement at both var binding and usage (#16367)
    
    Resolve a bug that caused undefined relax variables in the output of
    `CanonicalizeBindings` for cases where `VisitVarDef(const Var&)`
    replaces a variable, and `VisitExpr_(const VarNode*)` returns a value
    with different struct info, both occurring within the same
    `VarBinding`.
    
    The ExprMutator is only allowed to update a variable's struct info
    if the value bound to it has new struct info.  When
    CanonicalizeBindings replaces a trivial binding, this may provide
    better struct info as a result.
    
    Prior to this commit, `ExprMutator::ReEmitBinding` defined a
    remap for `binding->var->vid`, even if the derived class defined a
    replacement by overriding `VisitVarDef`.  If the derived class
    defines a new variable binding by overriding `VisitVarDef`, and
    also causes a variable replacement by overriding `VisitExpr` and
    returning a type with different struct info, then `ExprMutator`
    must check for both `binding->var->vid` *AND* `new_var->vid`.  The
    former may be present in the unmodified graph, and the latter may
    be produced by the derived class before delegating to the base
    class.
    
    This commit updates `ExprMutator::ReEmitBinding` to define entries for
    both replacements that may be required.
---
 src/relax/ir/expr_functor.cc                       |   2 +
 .../relax/test_transform_canonicalize_bindings.py  | 186 +++++++++++++++------
 2 files changed, 133 insertions(+), 55 deletions(-)

diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc
index 26ce952626..e01b710df1 100644
--- a/src/relax/ir/expr_functor.cc
+++ b/src/relax/ir/expr_functor.cc
@@ -682,7 +682,9 @@ void ExprMutator::ReEmitBinding(const VarBindingNode* binding, Expr new_value) {
   if (!temp.same_as(new_var)) {
     new_var = temp;
   }
+
   this->var_remap_[binding->var->vid] = new_var;
+  this->var_remap_[new_var->vid] = new_var;
 
   builder_->EmitNormalized(VarBinding(new_var, new_value));
 }
diff --git a/tests/python/relax/test_transform_canonicalize_bindings.py b/tests/python/relax/test_transform_canonicalize_bindings.py
index a674d0c7e4..20f45f44dd 100644
--- a/tests/python/relax/test_transform_canonicalize_bindings.py
+++ b/tests/python/relax/test_transform_canonicalize_bindings.py
@@ -22,7 +22,7 @@ import tvm.testing
 import pytest
 from tvm import relax
 from tvm.ir.base import assert_structural_equal
-from tvm.script import relax as R, tir as T
+from tvm.script import ir as I, relax as R, tir as T
 
 
 def verify(input, expected):
@@ -30,7 +30,7 @@ def verify(input, expected):
 
 
 def test_simple_assignments():
-    @tvm.script.ir_module
+    @I.ir_module
     class TestChainAssignments:
         @R.function
         def main(x: R.Tensor):
@@ -41,7 +41,7 @@ def test_simple_assignments():
             o = p
             return o
 
-    @tvm.script.ir_module
+    @I.ir_module
     class Expected:
         @R.function
         def main(x: R.Tensor):
@@ -51,7 +51,7 @@ def test_simple_assignments():
 
 
 def test_dataflow_block():
-    @tvm.script.ir_module
+    @I.ir_module
     class TestDataflowAssignments:
         @R.function
         def main(x: R.Tensor):
@@ -65,7 +65,7 @@ def test_dataflow_block():
                 R.output(n)
             return n
 
-    @tvm.script.ir_module
+    @I.ir_module
     class Expected:
         @R.function
         def main(x: R.Tensor):
@@ -78,7 +78,7 @@ def test_dataflow_block():
 
 
 def test_assign_to_output_in_dataflow_block():
-    @tvm.script.ir_module
+    @I.ir_module
     class TestDataflowAssignments:
         @R.function
         def main(x: R.Tensor):
@@ -92,7 +92,7 @@ def test_assign_to_output_in_dataflow_block():
                 R.output(n)
             return n
 
-    @tvm.script.ir_module
+    @I.ir_module
     class Expected:
         @R.function
         def main(x: R.Tensor):
@@ -105,7 +105,7 @@ def test_assign_to_output_in_dataflow_block():
 
 
 def test_ops():
-    @tvm.script.ir_module
+    @I.ir_module
     class TestOps:
         @R.function
         def main(x: R.Tensor, y: R.Tensor):
@@ -114,7 +114,7 @@ def test_ops():
             z = R.add(w, q)
             return R.add(q, z)
 
-    @tvm.script.ir_module
+    @I.ir_module
     class Expected:
         @R.function
         def main(x: R.Tensor, y: R.Tensor):
@@ -126,7 +126,7 @@ def test_ops():
 
 @pytest.mark.xfail(reason="The lhs and rhs of an assignment should have the 
same struct info.")
 def test_casting():
-    @tvm.script.ir_module
+    @I.ir_module
     class TestCasting:
         @R.function
         def main(x: R.Tensor) -> R.Object:
@@ -135,7 +135,7 @@ def test_casting():
             z: R.Object = y
             return z
 
-    @tvm.script.ir_module
+    @I.ir_module
     class Expected:
         @R.function
         def main(x: R.Tensor) -> R.Object:
@@ -147,7 +147,7 @@ def test_casting():
 
 
 def test_match_cast():
-    @tvm.script.ir_module
+    @I.ir_module
     class TestMatchCast:
         @R.function
         def main(x: R.Tensor):
@@ -157,7 +157,7 @@ def test_match_cast():
             w = z
             return w
 
-    @tvm.script.ir_module
+    @I.ir_module
     class Expected:
         @R.function
         def main(x: R.Tensor):
@@ -170,7 +170,7 @@ def test_match_cast():
 
 
 def test_same_shape():
-    @tvm.script.ir_module
+    @I.ir_module
     class TestSameShape:
         @R.function
         def main(x: R.Tensor(("m", "n"), "float32")):
@@ -182,7 +182,7 @@ def test_same_shape():
             q = R.add(w, y)
             return R.add(q, w)
 
-    @tvm.script.ir_module
+    @I.ir_module
     class Expected:
         @R.function
         def main(x: R.Tensor(("m", "n"), "float32")):
@@ -195,7 +195,7 @@ def test_same_shape():
 
 
 def test_change_shape():
-    @tvm.script.ir_module
+    @I.ir_module
     class TestChangeShape:
         @R.function
         def main(x: R.Tensor(("m", "n"))):
@@ -207,7 +207,7 @@ def test_change_shape():
             q = R.add(w, y)
             return R.add(q, w)
 
-    @tvm.script.ir_module
+    @I.ir_module
     class Expected:
         @R.function
         def main(x: R.Tensor(("m", "n"))):
@@ -221,7 +221,7 @@ def test_change_shape():
 
 
 def test_unwrap_tuple():
-    @tvm.script.ir_module
+    @I.ir_module
     class Before:
         @R.function
         def main(x: R.Tensor, y: R.Tensor):
@@ -231,7 +231,7 @@ def test_unwrap_tuple():
             z = R.add(w, q)
             return R.add(q, z)
 
-    @tvm.script.ir_module
+    @I.ir_module
     class Expected:
         @R.function
         def main(x: R.Tensor, y: R.Tensor):
@@ -243,7 +243,7 @@ def test_unwrap_tuple():
 
 
 def test_basic_folding_example():
-    @tvm.script.ir_module
+    @I.ir_module
     class Input:
         @R.function
         def main() -> R.Tensor((), "int32"):
@@ -253,7 +253,7 @@ def test_basic_folding_example():
                 R.output(n)
             return n
 
-    @tvm.script.ir_module
+    @I.ir_module
     class Expected:
         @R.function
         def main() -> R.Tensor((), "int32"):
@@ -266,7 +266,7 @@ def test_basic_folding_example():
 
 
 def test_fold_match_cast():
-    @tvm.script.ir_module
+    @I.ir_module
     class Input:
         @R.function
         def main() -> R.Tensor((), "int32"):
@@ -276,7 +276,7 @@ def test_fold_match_cast():
                 R.output(n)
             return n
 
-    @tvm.script.ir_module
+    @I.ir_module
     class Expected:
         @R.function
         def main() -> R.Tensor((), "int32"):
@@ -290,7 +290,7 @@ def test_fold_match_cast():
 
 
 def test_unable_to_fold():
-    @tvm.script.ir_module
+    @I.ir_module
     class MultipleUse:
         @R.function
         def main() -> R.Tensor((), "int32"):
@@ -301,7 +301,7 @@ def test_unable_to_fold():
                 R.output(n)
             return n
 
-    @tvm.script.ir_module
+    @I.ir_module
     class ComplexExpr:
         @R.function
         def main() -> R.Tensor((), "int32"):
@@ -317,7 +317,7 @@ def test_unable_to_fold():
 
 
 def test_multiple_outputs():
-    @tvm.script.ir_module
+    @I.ir_module
     class Input:
         @R.function
         def main():
@@ -331,7 +331,7 @@ def test_multiple_outputs():
                 R.output(l, m, n)
             return (l, m, n)
 
-    @tvm.script.ir_module
+    @I.ir_module
     class Expected:
         @R.function
         def main():
@@ -352,7 +352,7 @@ def test_single_output_multiple_nondataflow():
     statement.
     """
 
-    @tvm.script.ir_module
+    @I.ir_module
     class Input:
         @R.function
         def main():
@@ -366,7 +366,7 @@ def test_single_output_multiple_nondataflow():
                 R.output(l, m, n)
             return n
 
-    @tvm.script.ir_module
+    @I.ir_module
     class Expected:
         @R.function
         def main():
@@ -382,7 +382,7 @@ def test_single_output_multiple_nondataflow():
 
 def test_multiply_used_in_outputs():
     # cannot fold output in this case
-    @tvm.script.ir_module
+    @I.ir_module
     class UsedInMultipleOutputs:
         @R.function
         def main() -> R.Tensor((), "int32"):
@@ -403,7 +403,7 @@ def test_canonicalize_var_to_dataflow_var_if_legal():
     `DataflowVar` outside of a `DataflowBlock`.
     """
 
-    @tvm.script.ir_module
+    @I.ir_module
     class Before:
         @R.function
         def main(x: R.Tensor):
@@ -413,7 +413,7 @@ def test_canonicalize_var_to_dataflow_var_if_legal():
                 R.output(y, z)
             return z
 
-    @tvm.script.ir_module
+    @I.ir_module
     class Expected:
         @R.function
         def main(x: R.Tensor):
@@ -436,7 +436,7 @@ def test_update_dataflow_computations_if_var_replacement_occurs():
     updated to remain well-formed.
     """
 
-    @tvm.script.ir_module
+    @I.ir_module
     class Before:
         @R.function
         def main(x: R.Tensor):
@@ -447,7 +447,7 @@ def test_update_dataflow_computations_if_var_replacement_occurs():
                 R.output(gv1, gv2)
             return (gv1, gv2)
 
-    @tvm.script.ir_module
+    @I.ir_module
     class Expected:
         @R.function
         def main(x: R.Tensor):
@@ -471,7 +471,7 @@ def test_update_dataflow_computations_if_var_replacement_occurs_after_usage():
     that causes it to be replaced.
     """
 
-    @tvm.script.ir_module
+    @I.ir_module
     class Before:
         @R.function
         def main(x: R.Tensor):
@@ -482,7 +482,7 @@ def test_update_dataflow_computations_if_var_replacement_occurs_after_usage():
                 R.output(gv1, gv2)
             return (gv1, gv2)
 
-    @tvm.script.ir_module
+    @I.ir_module
     class Expected:
         @R.function
         def main(x: R.Tensor):
@@ -523,7 +523,7 @@ def test_replace_var_with_dataflow_if_all_usage_within_dataflow_block():
     `test_canonicalize_var_to_dataflow_var_if_legal`.)
     """
 
-    @tvm.script.ir_module
+    @I.ir_module
     class Before:
         @R.function
         def main(x: R.Tensor):
@@ -534,7 +534,7 @@ def test_replace_var_with_dataflow_if_all_usage_within_dataflow_block():
                 R.output(gv1, gv2)
             return gv2
 
-    @tvm.script.ir_module
+    @I.ir_module
     class Expected:
         @R.function
         def main(x: R.Tensor):
@@ -557,7 +557,7 @@ def test_canonicalize_var_to_dataflow_with_trivial_binding():
     binding.
     """
 
-    @tvm.script.ir_module
+    @I.ir_module
     class Before:
         @R.function
         def main(x: R.Tensor):
@@ -568,7 +568,7 @@ def test_canonicalize_var_to_dataflow_with_trivial_binding():
                 R.output(gv1, gv2)
             return gv2
 
-    @tvm.script.ir_module
+    @I.ir_module
     class Expected:
         @R.function
         def main(x: R.Tensor):
@@ -591,7 +591,7 @@ def test_canonicalize_with_updated_struct_info():
     in order to provide better struct info.
     """
 
-    @tvm.script.ir_module
+    @I.ir_module
     class Before:
         @R.function(private=True)
         def main(A: R.Tensor(("n", 16), dtype="int32")) -> R.Tensor(("n", 16), 
dtype="int32"):
@@ -610,7 +610,7 @@ def test_canonicalize_with_updated_struct_info():
             # version of `C` with `ndim=2`.
             return C
 
-    @tvm.script.ir_module
+    @I.ir_module
     class Expected:
         @R.function(private=True)
         def main(A: R.Tensor(("n", 16), dtype="int32")) -> R.Tensor(("n", 16), 
dtype="int32"):
@@ -634,7 +634,7 @@ def test_canonicalize_trivial_binding_to_dataflow_var():
     then canonicalization replaces the earlier DataflowVar with a Var.
     """
 
-    @tvm.script.ir_module
+    @I.ir_module
     class Before:
         @R.function
         def main(x: R.Tensor):
@@ -644,7 +644,7 @@ def test_canonicalize_trivial_binding_to_dataflow_var():
                 R.output(z)
             return z
 
-    @tvm.script.ir_module
+    @I.ir_module
     class Expected:
         @R.function
         def main(x: R.Tensor):
@@ -664,7 +664,7 @@ def test_canonicalize_multiple_trivial_binding_to_dataflow_var():
     exist multiple trivial bindings to the DataflowVar.
     """
 
-    @tvm.script.ir_module
+    @I.ir_module
     class Before:
         @R.function
         def main(w: R.Tensor):
@@ -675,7 +675,7 @@ def test_canonicalize_multiple_trivial_binding_to_dataflow_var():
                 R.output(y, z)
             return (y, z)
 
-    @tvm.script.ir_module
+    @I.ir_module
     class Expected:
         @R.function
         def main(w: R.Tensor):
@@ -696,7 +696,7 @@ def test_canonicalize_trivial_var_binding_inside_dataflow_block():
     cases both occur, should produce reasonable results.
     """
 
-    @tvm.script.ir_module
+    @I.ir_module
     class Before:
         @R.function
         def main(x: R.Tensor):
@@ -706,7 +706,7 @@ def test_canonicalize_trivial_var_binding_inside_dataflow_block():
                 R.output(y, z)
             return z
 
-    @tvm.script.ir_module
+    @I.ir_module
     class Expected:
         @R.function
         def main(x: R.Tensor):
@@ -722,7 +722,7 @@ def 
test_canonicalize_trivial_var_binding_inside_dataflow_block():
 def test_canonicalize_across_non_dataflow_tuple():
     """Canonicalize Var to DataflowVar inside DataflowBlock"""
 
-    @tvm.script.ir_module
+    @I.ir_module
     class Before:
         @R.function
         def main(x: R.Tensor):
@@ -733,7 +733,7 @@ def test_canonicalize_across_non_dataflow_tuple():
                 R.output(z, gv)
             return gv
 
-    @tvm.script.ir_module
+    @I.ir_module
     class Expected:
         @R.function
         def main(x: R.Tensor):
@@ -753,7 +753,7 @@ def test_var_used_in_distinct_df_blocks():
     but outside of the one where it was originally defined,
     it should be exposed as an output."""
 
-    @tvm.script.ir_module
+    @I.ir_module
     class Before:
         @R.function(pure=False)
         def main(x: R.Tensor, y: R.Tensor) -> R.Tensor:
@@ -777,7 +777,7 @@ def test_var_used_in_distinct_df_blocks():
 
 
 def test_inner_function():
-    @tvm.script.ir_module
+    @I.ir_module
     class Before:
         @R.function(pure=False)
         def main(x: R.Tensor, y: R.Tensor) -> R.Tensor:
@@ -811,7 +811,7 @@ def test_inner_function():
             return c
 
     # expected: we do not need to expose all the outputs
-    @tvm.script.ir_module
+    @I.ir_module
     class Expected:
         @R.function(pure=False)
         def main(x: R.Tensor, y: R.Tensor) -> R.Tensor:
@@ -849,7 +849,7 @@ def test_inner_function():
 
 
 def test_canonicalize_inside_branches():
-    @tvm.script.ir_module
+    @I.ir_module
     class Before:
         @R.function
         def main(x: R.Tensor, y: R.Tensor) -> R.Tensor:
@@ -871,7 +871,7 @@ def test_canonicalize_inside_branches():
                 q = v
             return q
 
-    @tvm.script.ir_module
+    @I.ir_module
     class Expected:
         @R.function
         def main(x: R.Tensor, y: R.Tensor) -> R.Tensor:
@@ -896,5 +896,81 @@ def test_canonicalize_inside_branches():
     assert_structural_equal(Expected, after)
 
 
+def test_canonicalization_causes_struct_info_update():
+    """Regression test for failure mode causing undefined variable
+
+    The ExprMutator is only allowed to update a variable's struct info
+    if the value bound to it has new struct info.  When
+    CanonicalizeBindings replaces a trivial binding, this may provide
+    better struct info as a result.  If this happens, later usages must be remapped.
+
+    In previous implementations, ExprMutator::ReEmitBinding defined a
+    remap for `binding->var->vid`, even if the derived class defined a
+    replacement by overriding `VisitVarDef`.  If the derived class
+    defines a new variable binding by overriding `VisitVarDef`, and
+    also causes a variable replacement by overriding `VisitExpr` and
+    returning a type with different struct info, then `ExprMutator`
+    must check for both `binding->var->vid` *AND* `new_var->vid`.  The
+    former may be present in the unmodified graph, and the latter may
+    be produced by the derived class before delegating to the base
+    class.
+    """
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def transform_params(
+            A: R.Tensor(("vocab_size", 4096), dtype="float16"),
+            B: R.Tensor((6144, 4096), dtype="float16"),
+        ):
+            with R.dataflow():
+                # Trivial binding of `DataFlow = NonDataFlow`.
+                # Wherever `C` is used, Canonicalization will attempt
+                # to replace it with `B`.
+                C = B
+
+                # RHS contains `(A,C)`, which CanonicalizeBindings
+                # replaces with `(A,B)`.  Because this changes the
+                # RHS, a new LHS (and new struct info!) will be
+                # generated.
+                D: R.Tuple(
+                    R.Tensor(dtype="float16", ndim=2),
+                    R.Tensor((6144, 4096), dtype="float16"),
+                ) = (A, C)
+
+                # Trivial binding of `NonDataFlow = DataFlow`.  The
+                # definition of `D` will be replaced with a definition
+                # of `E`.  This definition of `E` will then be updated
+                # to have a known shape.
+                E = D
+                R.output(E)
+
+            # By the time `E` is encountered at a usage site, the
+            # `ExprMutator` must map the old version of `E` (with
+            # `ndim=2`) to the new version of `E` (with
+            # `shape=[vocab_size, 4096]`).
+            return E
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def transform_params(
+            A: R.Tensor(("vocab_size", 4096), dtype="float16"),
+            B: R.Tensor((6144, 4096), dtype="float16"),
+        ):
+            vocab_size = T.int64()
+            with R.dataflow():
+                E: R.Tuple(
+                    R.Tensor((vocab_size, 4096), dtype="float16"),
+                    R.Tensor((6144, 4096), dtype="float16"),
+                ) = (A, B)
+
+                R.output(E)
+            return E
+
+    after = relax.transform.CanonicalizeBindings()(Before)
+    assert_structural_equal(Expected, after)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to