This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 2d53e6ac63 [Unity][Transform] Handle replacement at both var binding and usage (#16367)
2d53e6ac63 is described below
commit 2d53e6ac635b28ec64673cd69ac6995368b08cb3
Author: Eric Lunderberg <[email protected]>
AuthorDate: Mon Jan 8 19:31:14 2024 -0600
[Unity][Transform] Handle replacement at both var binding and usage (#16367)
Resolve a bug that caused undefined relax variables in the output of
`CanonicalizeBindings` for cases where `VisitVarDef(const Var&)`
replaces a variable, and `VisitExpr_(const VarNode*)` returns a value
with different struct info, both occurring within the same
`VarBinding`.
The ExprMutator is only allowed to update a variable's struct info
if the value bound to it has new struct info. When
CanonicalizeBindings replaces a trivial binding, this may provide
better struct info as a result.
Prior to this commit, `ExprMutator::ReEmitBinding` defined a
remap only for `binding->var->vid`, even if the derived class
defined a replacement by overriding `VisitVarDef`. If the derived
class defines a new variable binding by overriding `VisitVarDef`,
and also causes a variable replacement by overriding `VisitExpr`
and returning a value with different struct info, then
`ExprMutator` must provide a remap for both `binding->var->vid`
*AND* `new_var->vid`. The former may be referenced from the
unmodified graph, and the latter may be produced by the derived
class before delegating to the base class.
This commit updates `ExprMutator::ReEmitBinding` to define entries for
both replacements that may be required.
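For illustration, here is a minimal sketch of how the pass is
invoked, using the same public API that the regression test below
exercises. The module and variable names are illustrative only; the
test below constructs the exact Var/DataflowVar and struct-info
combination that triggered the bug.

    import tvm
    from tvm import relax
    from tvm.script import ir as I, relax as R

    @I.ir_module
    class Module:
        @R.function
        def main(x: R.Tensor((16,), "float32")):
            with R.dataflow():
                # Trivial binding: CanonicalizeBindings rewrites later
                # uses of `y` to refer directly to `x`.
                y = x
                z = R.add(y, y)
                R.output(z)
            return z

    # Before this fix, a mutator that both renamed the bound variable
    # and refined its struct info within a single VarBinding could
    # leave a dangling reference to the old variable in the output.
    after = relax.transform.CanonicalizeBindings()(Module)
    after.show()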
---
src/relax/ir/expr_functor.cc | 2 +
.../relax/test_transform_canonicalize_bindings.py | 186 +++++++++++++++------
2 files changed, 133 insertions(+), 55 deletions(-)
diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc
index 26ce952626..e01b710df1 100644
--- a/src/relax/ir/expr_functor.cc
+++ b/src/relax/ir/expr_functor.cc
@@ -682,7 +682,9 @@ void ExprMutator::ReEmitBinding(const VarBindingNode* binding, Expr new_value) {
if (!temp.same_as(new_var)) {
new_var = temp;
}
+
this->var_remap_[binding->var->vid] = new_var;
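+  // Also remap the vid of any replacement var produced by a derived
+  // class's VisitVarDef, since later usage sites may reference either
+  // vid.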
+ this->var_remap_[new_var->vid] = new_var;
builder_->EmitNormalized(VarBinding(new_var, new_value));
}
diff --git a/tests/python/relax/test_transform_canonicalize_bindings.py b/tests/python/relax/test_transform_canonicalize_bindings.py
index a674d0c7e4..20f45f44dd 100644
--- a/tests/python/relax/test_transform_canonicalize_bindings.py
+++ b/tests/python/relax/test_transform_canonicalize_bindings.py
@@ -22,7 +22,7 @@ import tvm.testing
import pytest
from tvm import relax
from tvm.ir.base import assert_structural_equal
-from tvm.script import relax as R, tir as T
+from tvm.script import ir as I, relax as R, tir as T
def verify(input, expected):
@@ -30,7 +30,7 @@ def verify(input, expected):
def test_simple_assignments():
- @tvm.script.ir_module
+ @I.ir_module
class TestChainAssignments:
@R.function
def main(x: R.Tensor):
@@ -41,7 +41,7 @@ def test_simple_assignments():
o = p
return o
- @tvm.script.ir_module
+ @I.ir_module
class Expected:
@R.function
def main(x: R.Tensor):
@@ -51,7 +51,7 @@ def test_simple_assignments():
def test_dataflow_block():
- @tvm.script.ir_module
+ @I.ir_module
class TestDataflowAssignments:
@R.function
def main(x: R.Tensor):
@@ -65,7 +65,7 @@ def test_dataflow_block():
R.output(n)
return n
- @tvm.script.ir_module
+ @I.ir_module
class Expected:
@R.function
def main(x: R.Tensor):
@@ -78,7 +78,7 @@ def test_dataflow_block():
def test_assign_to_output_in_dataflow_block():
- @tvm.script.ir_module
+ @I.ir_module
class TestDataflowAssignments:
@R.function
def main(x: R.Tensor):
@@ -92,7 +92,7 @@ def test_assign_to_output_in_dataflow_block():
R.output(n)
return n
- @tvm.script.ir_module
+ @I.ir_module
class Expected:
@R.function
def main(x: R.Tensor):
@@ -105,7 +105,7 @@ def test_assign_to_output_in_dataflow_block():
def test_ops():
- @tvm.script.ir_module
+ @I.ir_module
class TestOps:
@R.function
def main(x: R.Tensor, y: R.Tensor):
@@ -114,7 +114,7 @@ def test_ops():
z = R.add(w, q)
return R.add(q, z)
- @tvm.script.ir_module
+ @I.ir_module
class Expected:
@R.function
def main(x: R.Tensor, y: R.Tensor):
@@ -126,7 +126,7 @@ def test_ops():
@pytest.mark.xfail(reason="The lhs and rhs of an assignment should have the same struct info.")
def test_casting():
- @tvm.script.ir_module
+ @I.ir_module
class TestCasting:
@R.function
def main(x: R.Tensor) -> R.Object:
@@ -135,7 +135,7 @@ def test_casting():
z: R.Object = y
return z
- @tvm.script.ir_module
+ @I.ir_module
class Expected:
@R.function
def main(x: R.Tensor) -> R.Object:
@@ -147,7 +147,7 @@ def test_casting():
def test_match_cast():
- @tvm.script.ir_module
+ @I.ir_module
class TestMatchCast:
@R.function
def main(x: R.Tensor):
@@ -157,7 +157,7 @@ def test_match_cast():
w = z
return w
- @tvm.script.ir_module
+ @I.ir_module
class Expected:
@R.function
def main(x: R.Tensor):
@@ -170,7 +170,7 @@ def test_match_cast():
def test_same_shape():
- @tvm.script.ir_module
+ @I.ir_module
class TestSameShape:
@R.function
def main(x: R.Tensor(("m", "n"), "float32")):
@@ -182,7 +182,7 @@ def test_same_shape():
q = R.add(w, y)
return R.add(q, w)
- @tvm.script.ir_module
+ @I.ir_module
class Expected:
@R.function
def main(x: R.Tensor(("m", "n"), "float32")):
@@ -195,7 +195,7 @@ def test_same_shape():
def test_change_shape():
- @tvm.script.ir_module
+ @I.ir_module
class TestChangeShape:
@R.function
def main(x: R.Tensor(("m", "n"))):
@@ -207,7 +207,7 @@ def test_change_shape():
q = R.add(w, y)
return R.add(q, w)
- @tvm.script.ir_module
+ @I.ir_module
class Expected:
@R.function
def main(x: R.Tensor(("m", "n"))):
@@ -221,7 +221,7 @@ def test_change_shape():
def test_unwrap_tuple():
- @tvm.script.ir_module
+ @I.ir_module
class Before:
@R.function
def main(x: R.Tensor, y: R.Tensor):
@@ -231,7 +231,7 @@ def test_unwrap_tuple():
z = R.add(w, q)
return R.add(q, z)
- @tvm.script.ir_module
+ @I.ir_module
class Expected:
@R.function
def main(x: R.Tensor, y: R.Tensor):
@@ -243,7 +243,7 @@ def test_unwrap_tuple():
def test_basic_folding_example():
- @tvm.script.ir_module
+ @I.ir_module
class Input:
@R.function
def main() -> R.Tensor((), "int32"):
@@ -253,7 +253,7 @@ def test_basic_folding_example():
R.output(n)
return n
- @tvm.script.ir_module
+ @I.ir_module
class Expected:
@R.function
def main() -> R.Tensor((), "int32"):
@@ -266,7 +266,7 @@ def test_basic_folding_example():
def test_fold_match_cast():
- @tvm.script.ir_module
+ @I.ir_module
class Input:
@R.function
def main() -> R.Tensor((), "int32"):
@@ -276,7 +276,7 @@ def test_fold_match_cast():
R.output(n)
return n
- @tvm.script.ir_module
+ @I.ir_module
class Expected:
@R.function
def main() -> R.Tensor((), "int32"):
@@ -290,7 +290,7 @@ def test_fold_match_cast():
def test_unable_to_fold():
- @tvm.script.ir_module
+ @I.ir_module
class MultipleUse:
@R.function
def main() -> R.Tensor((), "int32"):
@@ -301,7 +301,7 @@ def test_unable_to_fold():
R.output(n)
return n
- @tvm.script.ir_module
+ @I.ir_module
class ComplexExpr:
@R.function
def main() -> R.Tensor((), "int32"):
@@ -317,7 +317,7 @@ def test_unable_to_fold():
def test_multiple_outputs():
- @tvm.script.ir_module
+ @I.ir_module
class Input:
@R.function
def main():
@@ -331,7 +331,7 @@ def test_multiple_outputs():
R.output(l, m, n)
return (l, m, n)
- @tvm.script.ir_module
+ @I.ir_module
class Expected:
@R.function
def main():
@@ -352,7 +352,7 @@ def test_single_output_multiple_nondataflow():
statement.
"""
- @tvm.script.ir_module
+ @I.ir_module
class Input:
@R.function
def main():
@@ -366,7 +366,7 @@ def test_single_output_multiple_nondataflow():
R.output(l, m, n)
return n
- @tvm.script.ir_module
+ @I.ir_module
class Expected:
@R.function
def main():
@@ -382,7 +382,7 @@ def test_single_output_multiple_nondataflow():
def test_multiply_used_in_outputs():
# cannot fold output in this case
- @tvm.script.ir_module
+ @I.ir_module
class UsedInMultipleOutputs:
@R.function
def main() -> R.Tensor((), "int32"):
@@ -403,7 +403,7 @@ def test_canonicalize_var_to_dataflow_var_if_legal():
`DataflowVar` outside of a `DataflowBlock`.
"""
- @tvm.script.ir_module
+ @I.ir_module
class Before:
@R.function
def main(x: R.Tensor):
@@ -413,7 +413,7 @@ def test_canonicalize_var_to_dataflow_var_if_legal():
R.output(y, z)
return z
- @tvm.script.ir_module
+ @I.ir_module
class Expected:
@R.function
def main(x: R.Tensor):
@@ -436,7 +436,7 @@ def test_update_dataflow_computations_if_var_replacement_occurs():
updated to remain well-formed.
"""
- @tvm.script.ir_module
+ @I.ir_module
class Before:
@R.function
def main(x: R.Tensor):
@@ -447,7 +447,7 @@ def test_update_dataflow_computations_if_var_replacement_occurs():
R.output(gv1, gv2)
return (gv1, gv2)
- @tvm.script.ir_module
+ @I.ir_module
class Expected:
@R.function
def main(x: R.Tensor):
@@ -471,7 +471,7 @@ def test_update_dataflow_computations_if_var_replacement_occurs_after_usage():
that causes it to be replaced.
"""
- @tvm.script.ir_module
+ @I.ir_module
class Before:
@R.function
def main(x: R.Tensor):
@@ -482,7 +482,7 @@ def test_update_dataflow_computations_if_var_replacement_occurs_after_usage():
R.output(gv1, gv2)
return (gv1, gv2)
- @tvm.script.ir_module
+ @I.ir_module
class Expected:
@R.function
def main(x: R.Tensor):
@@ -523,7 +523,7 @@ def test_replace_var_with_dataflow_if_all_usage_within_dataflow_block():
`test_canonicalize_var_to_dataflow_var_if_legal`.)
"""
- @tvm.script.ir_module
+ @I.ir_module
class Before:
@R.function
def main(x: R.Tensor):
@@ -534,7 +534,7 @@ def test_replace_var_with_dataflow_if_all_usage_within_dataflow_block():
R.output(gv1, gv2)
return gv2
- @tvm.script.ir_module
+ @I.ir_module
class Expected:
@R.function
def main(x: R.Tensor):
@@ -557,7 +557,7 @@ def test_canonicalize_var_to_dataflow_with_trivial_binding():
binding.
"""
- @tvm.script.ir_module
+ @I.ir_module
class Before:
@R.function
def main(x: R.Tensor):
@@ -568,7 +568,7 @@ def test_canonicalize_var_to_dataflow_with_trivial_binding():
R.output(gv1, gv2)
return gv2
- @tvm.script.ir_module
+ @I.ir_module
class Expected:
@R.function
def main(x: R.Tensor):
@@ -591,7 +591,7 @@ def test_canonicalize_with_updated_struct_info():
in order to provide better struct info.
"""
- @tvm.script.ir_module
+ @I.ir_module
class Before:
@R.function(private=True)
def main(A: R.Tensor(("n", 16), dtype="int32")) -> R.Tensor(("n", 16), dtype="int32"):
@@ -610,7 +610,7 @@ def test_canonicalize_with_updated_struct_info():
# version of `C` with `ndim=2`.
return C
- @tvm.script.ir_module
+ @I.ir_module
class Expected:
@R.function(private=True)
def main(A: R.Tensor(("n", 16), dtype="int32")) -> R.Tensor(("n", 16), dtype="int32"):
@@ -634,7 +634,7 @@ def test_canonicalize_trivial_binding_to_dataflow_var():
then canonicalization replaces the earlier DataflowVar with a Var.
"""
- @tvm.script.ir_module
+ @I.ir_module
class Before:
@R.function
def main(x: R.Tensor):
@@ -644,7 +644,7 @@ def test_canonicalize_trivial_binding_to_dataflow_var():
R.output(z)
return z
- @tvm.script.ir_module
+ @I.ir_module
class Expected:
@R.function
def main(x: R.Tensor):
@@ -664,7 +664,7 @@ def test_canonicalize_multiple_trivial_binding_to_dataflow_var():
exist multiple trivial bindings to the DataflowVar.
"""
- @tvm.script.ir_module
+ @I.ir_module
class Before:
@R.function
def main(w: R.Tensor):
@@ -675,7 +675,7 @@ def test_canonicalize_multiple_trivial_binding_to_dataflow_var():
R.output(y, z)
return (y, z)
- @tvm.script.ir_module
+ @I.ir_module
class Expected:
@R.function
def main(w: R.Tensor):
@@ -696,7 +696,7 @@ def test_canonicalize_trivial_var_binding_inside_dataflow_block():
cases both occur, should produce reasonable results.
"""
- @tvm.script.ir_module
+ @I.ir_module
class Before:
@R.function
def main(x: R.Tensor):
@@ -706,7 +706,7 @@ def test_canonicalize_trivial_var_binding_inside_dataflow_block():
R.output(y, z)
return z
- @tvm.script.ir_module
+ @I.ir_module
class Expected:
@R.function
def main(x: R.Tensor):
@@ -722,7 +722,7 @@ def test_canonicalize_trivial_var_binding_inside_dataflow_block():
def test_canonicalize_across_non_dataflow_tuple():
"""Canonicalize Var to DataflowVar inside DataflowBlock"""
- @tvm.script.ir_module
+ @I.ir_module
class Before:
@R.function
def main(x: R.Tensor):
@@ -733,7 +733,7 @@ def test_canonicalize_across_non_dataflow_tuple():
R.output(z, gv)
return gv
- @tvm.script.ir_module
+ @I.ir_module
class Expected:
@R.function
def main(x: R.Tensor):
@@ -753,7 +753,7 @@ def test_var_used_in_distinct_df_blocks():
but outside of the one where it was originally defined,
it should be exposed as an output."""
- @tvm.script.ir_module
+ @I.ir_module
class Before:
@R.function(pure=False)
def main(x: R.Tensor, y: R.Tensor) -> R.Tensor:
@@ -777,7 +777,7 @@ def test_var_used_in_distinct_df_blocks():
def test_inner_function():
- @tvm.script.ir_module
+ @I.ir_module
class Before:
@R.function(pure=False)
def main(x: R.Tensor, y: R.Tensor) -> R.Tensor:
@@ -811,7 +811,7 @@ def test_inner_function():
return c
# expected: we do not need to expose all the outputs
- @tvm.script.ir_module
+ @I.ir_module
class Expected:
@R.function(pure=False)
def main(x: R.Tensor, y: R.Tensor) -> R.Tensor:
@@ -849,7 +849,7 @@ def test_inner_function():
def test_canonicalize_inside_branches():
- @tvm.script.ir_module
+ @I.ir_module
class Before:
@R.function
def main(x: R.Tensor, y: R.Tensor) -> R.Tensor:
@@ -871,7 +871,7 @@ def test_canonicalize_inside_branches():
q = v
return q
- @tvm.script.ir_module
+ @I.ir_module
class Expected:
@R.function
def main(x: R.Tensor, y: R.Tensor) -> R.Tensor:
@@ -896,5 +896,81 @@ def test_canonicalize_inside_branches():
assert_structural_equal(Expected, after)
+def test_canonicalization_causes_struct_info_update():
+ """Regression test for failure mode causing undefined variable
+
+ The ExprMutator is only allowed to update a variable's struct info
+ if the value bound to it has new struct info. When
+ CanonicalizeBindings replaces a trivial binding, this may provide
+ better struct info as a result. If this happens, the
+ replacement must be handled both at the variable's binding site
+ and at each later usage site.
+
+ In previous implementations, ExprMutator::ReEmitBinding defined a
+ remap for `binding->var->vid`, even if the derived class defined a
+ replacement by overriding `VisitVarDef`. If the derived class
+ defines a new variable binding by overriding `VisitVarDef`, and
+ also causes a variable replacement by overriding `VisitExpr` and
+ returning a value with different struct info, then `ExprMutator`
+ must provide a remap for both `binding->var->vid` *AND*
+ `new_var->vid`. The
+ former may be present in the unmodified graph, and the latter may
+ be produced by the derived class before delegating to the base
+ class.
+ """
+
+ @I.ir_module
+ class Before:
+ @R.function
+ def transform_params(
+ A: R.Tensor(("vocab_size", 4096), dtype="float16"),
+ B: R.Tensor((6144, 4096), dtype="float16"),
+ ):
+ with R.dataflow():
+ # Trivial binding of `DataFlow = NonDataFlow`.
+ # Wherever `C` is used, Canonicalization will attempt
+ # to replace it with `B`.
+ C = B
+
+ # RHS contains `(A,C)`, which CanonicalizeBindings
+ # replaces with `(A,B)`. Because this changes the
+ # RHS, a new LHS (and new struct info!) will be
+ # generated.
+ D: R.Tuple(
+ R.Tensor(dtype="float16", ndim=2),
+ R.Tensor((6144, 4096), dtype="float16"),
+ ) = (A, C)
+
+ # Trivial binding of `NonDataFlow = DataFlow`. The
+ # definition of `D` will be replaced with a definition
+ # of `E`. This definition of `E` will then be updated
+ # to have a known shape.
+ E = D
+ R.output(E)
+
+ # By the time `E` is encountered at a usage site, the
+ # `ExprMutator` must have a remap from the old version of
+ # `E` (with `ndim=2`) to the new version of `E` (with
+ # `shape=[vocab_size, 4096]`).
+ return E
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def transform_params(
+ A: R.Tensor(("vocab_size", 4096), dtype="float16"),
+ B: R.Tensor((6144, 4096), dtype="float16"),
+ ):
+ vocab_size = T.int64()
+ with R.dataflow():
+ E: R.Tuple(
+ R.Tensor((vocab_size, 4096), dtype="float16"),
+ R.Tensor((6144, 4096), dtype="float16"),
+ ) = (A, B)
+
+ R.output(E)
+ return E
+
+ after = relax.transform.CanonicalizeBindings()(Before)
+ assert_structural_equal(Expected, after)
+
+
if __name__ == "__main__":
tvm.testing.main()