This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 36ebcd0a87 [Relax][Transform] Canonicalize `let var = R.const`
bindings (#16601)
36ebcd0a87 is described below
commit 36ebcd0a8765e0b865b4dd6f35671434b074c5b1
Author: Eric Lunderberg <[email protected]>
AuthorDate: Sun Feb 18 21:03:55 2024 -0600
[Relax][Transform] Canonicalize `let var = R.const` bindings (#16601)
* [Relax][Transform] Canonicalize `let var = R.const` bindings
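Prior to this commit, accesses into known tuples could be unwrapped into
the variables they contain, but variables bound to constants would remain.
This commit updates `CanonicalizeBindings` to also replace uses of a
variable bound to an `R.const` with the constant itself, so that tuples
containing constants can be fully unwrapped.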
* Fix broken tests, removing test_unable_to_fold
The `test_unable_to_fold` test was ported from the
`FoldDataflowBlockOutput` tests, and has been updated enough that it
no longer serves a purpose.
---
src/relax/transform/canonicalize_bindings.cc | 13 +++++
.../relax/test_transform_canonicalize_bindings.py | 67 ++++++++++++----------
.../relax/test_transform_convert_dataflow.py | 2 +-
3 files changed, 50 insertions(+), 32 deletions(-)
diff --git a/src/relax/transform/canonicalize_bindings.cc
b/src/relax/transform/canonicalize_bindings.cc
index 38aebcb8fd..9aeb289e2a 100644
--- a/src/relax/transform/canonicalize_bindings.cc
+++ b/src/relax/transform/canonicalize_bindings.cc
@@ -39,6 +39,7 @@ struct CanonicalizationPlan {
Map<Id, Var> replace_usage;
Map<Id, Var> replace_binding;
std::unordered_set<Id, ObjectPtrHash, ObjectPtrEqual> bindings_to_remove;
+ Map<Id, Constant> inline_constant;
};
/*! \brief Utility class to identify usage location
@@ -69,6 +70,10 @@ class CanonicalizePlanner : public ExprVisitor {
}
}
+ for (const auto& [var, constant] : visitor.known_bound_to_constant_) {
+ plan.inline_constant.Set(var->vid, constant);
+ }
+
for (const auto& binding_iter : visitor.trivial_bindings_) {
Var bound_var = binding_iter.first;
Var bound_to = binding_iter.second;
@@ -180,6 +185,10 @@ class CanonicalizePlanner : public ExprVisitor {
trivial_bindings_.Set(binding->var, parent.value());
}
+ if (auto constant = value.as<Constant>()) {
+ known_bound_to_constant_.Set(binding->var, constant.value());
+ }
+
known_bindings_.Set(binding->var, value);
def_blocks_.Set(binding->var, current_block_.value());
@@ -213,6 +222,7 @@ class CanonicalizePlanner : public ExprVisitor {
Map<Var, Var> trivial_bindings_;
Map<Var, Expr> known_bindings_;
+ Map<Var, Constant> known_bound_to_constant_;
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>
defined_inside_dataflow_;
// Set of vars either used outside a dataflow block altogether or outside
their
// home dataflow block (the one where they were defined)
@@ -251,6 +261,9 @@ class BindingCanonicalizer : public ExprMutator {
while (auto opt = plan_.replace_usage.Get(new_var->vid)) {
new_var = opt.value();
}
+ if (auto opt = plan_.inline_constant.Get(new_var->vid)) {
+ return VisitExpr(opt.value());
+ }
return ExprMutator::VisitExpr_(new_var.get());
}
diff --git a/tests/python/relax/test_transform_canonicalize_bindings.py
b/tests/python/relax/test_transform_canonicalize_bindings.py
index 20f45f44dd..7d7b74bf59 100644
--- a/tests/python/relax/test_transform_canonicalize_bindings.py
+++ b/tests/python/relax/test_transform_canonicalize_bindings.py
@@ -289,33 +289,6 @@ def test_fold_match_cast():
verify(Input, Expected)
-def test_unable_to_fold():
- @I.ir_module
- class MultipleUse:
- @R.function
- def main() -> R.Tensor((), "int32"):
- with R.dataflow():
- n = R.const(1)
- # multiple uses -> cannot coalesce
- m = R.add(n, n)
- R.output(n)
- return n
-
- @I.ir_module
- class ComplexExpr:
- @R.function
- def main() -> R.Tensor((), "int32"):
- with R.dataflow():
- y = R.const(1)
- # y does not appear by itself -> cannot coalesce
- n = R.add(y, y)
- R.output(n)
- return n
-
- verify(MultipleUse, MultipleUse)
- verify(ComplexExpr, ComplexExpr)
-
-
def test_multiple_outputs():
@I.ir_module
class Input:
@@ -380,10 +353,9 @@ def test_single_output_multiple_nondataflow():
verify(Input, Expected)
-def test_multiply_used_in_outputs():
- # cannot fold output in this case
+def test_fold_const_to_output():
@I.ir_module
- class UsedInMultipleOutputs:
+ class Before:
@R.function
def main() -> R.Tensor((), "int32"):
with R.dataflow():
@@ -391,7 +363,16 @@ def test_multiply_used_in_outputs():
R.output(n)
return n
- verify(UsedInMultipleOutputs, UsedInMultipleOutputs)
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main() -> R.Tensor((), "int32"):
+ with R.dataflow():
+ n = R.const(1)
+ R.output(n)
+ return R.const(1)
+
+ verify(Before, Expected)
def test_canonicalize_var_to_dataflow_var_if_legal():
@@ -972,5 +953,29 @@ def test_canonicalization_causes_struct_info_update():
assert_structural_equal(Expected, after)
+def test_unwrap_tuple_of_constant():
+ @I.ir_module
+ class TestChainAssignments:
+ @R.function
+ def main():
+ tup = (R.const(0, "int64"), R.const(1, "int64"))
+ x = tup[0]
+ y = tup[1]
+ z = R.add(x, y)
+ return z
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main():
+ tup = (R.const(0, "int64"), R.const(1, "int64"))
+ x = tup[0]
+ y = tup[1]
+ z = R.add(R.const(0, "int64"), R.const(1, "int64"))
+ return z
+
+ verify(TestChainAssignments, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/relax/test_transform_convert_dataflow.py
b/tests/python/relax/test_transform_convert_dataflow.py
index 790465e958..8a926cd4ae 100644
--- a/tests/python/relax/test_transform_convert_dataflow.py
+++ b/tests/python/relax/test_transform_convert_dataflow.py
@@ -221,7 +221,7 @@ class TestTreatNonCallAsPure(ExtractCompare):
t2 = (y, y, x)
c = R.const([1, 2, 3], dtype="int32")
R.output(c)
- return c
+ return R.const([1, 2, 3], dtype="int32")
@R.function
def shapes() -> R.Shape: