This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 36ebcd0a87 [Relax][Transform] Canonicalize `let var = R.const` bindings (#16601)
36ebcd0a87 is described below

commit 36ebcd0a8765e0b865b4dd6f35671434b074c5b1
Author: Eric Lunderberg <[email protected]>
AuthorDate: Sun Feb 18 21:03:55 2024 -0600

    [Relax][Transform] Canonicalize `let var = R.const` bindings (#16601)
    
    * [Relax][Transform] Canonicalize `let var = R.const` bindings
    
    Prior to this commit, known tuples could be unwrapped into their
    element variables, but any elements that were constants remained in
    place.  This commit updates `CanonicalizeBindings` to also unwrap
    tuples containing constants.
    
    * Fix broken tests, removing test_unable_to_fold
    
    The `test_unable_to_fold` test was ported from the
    `FoldDataflowBlockOutput` tests and has since been updated so much
    that it no longer serves a purpose.
---
 src/relax/transform/canonicalize_bindings.cc       | 13 +++++
 .../relax/test_transform_canonicalize_bindings.py  | 67 ++++++++++++----------
 .../relax/test_transform_convert_dataflow.py       |  2 +-
 3 files changed, 50 insertions(+), 32 deletions(-)

diff --git a/src/relax/transform/canonicalize_bindings.cc b/src/relax/transform/canonicalize_bindings.cc
index 38aebcb8fd..9aeb289e2a 100644
--- a/src/relax/transform/canonicalize_bindings.cc
+++ b/src/relax/transform/canonicalize_bindings.cc
@@ -39,6 +39,7 @@ struct CanonicalizationPlan {
   Map<Id, Var> replace_usage;
   Map<Id, Var> replace_binding;
   std::unordered_set<Id, ObjectPtrHash, ObjectPtrEqual> bindings_to_remove;
+  Map<Id, Constant> inline_constant;
 };
 
 /*! \brief Utility class to identify usage location
@@ -69,6 +70,10 @@ class CanonicalizePlanner : public ExprVisitor {
       }
     }
 
+    for (const auto& [var, constant] : visitor.known_bound_to_constant_) {
+      plan.inline_constant.Set(var->vid, constant);
+    }
+
     for (const auto& binding_iter : visitor.trivial_bindings_) {
       Var bound_var = binding_iter.first;
       Var bound_to = binding_iter.second;
@@ -180,6 +185,10 @@ class CanonicalizePlanner : public ExprVisitor {
       trivial_bindings_.Set(binding->var, parent.value());
     }
 
+    if (auto constant = value.as<Constant>()) {
+      known_bound_to_constant_.Set(binding->var, constant.value());
+    }
+
     known_bindings_.Set(binding->var, value);
     def_blocks_.Set(binding->var, current_block_.value());
 
@@ -213,6 +222,7 @@ class CanonicalizePlanner : public ExprVisitor {
 
   Map<Var, Var> trivial_bindings_;
   Map<Var, Expr> known_bindings_;
+  Map<Var, Constant> known_bound_to_constant_;
   std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> 
defined_inside_dataflow_;
   // Set of vars either used outside a dataflow block altogether or outside 
their
   // home dataflow block (the one where they were defined)
@@ -251,6 +261,9 @@ class BindingCanonicalizer : public ExprMutator {
     while (auto opt = plan_.replace_usage.Get(new_var->vid)) {
       new_var = opt.value();
     }
+    if (auto opt = plan_.inline_constant.Get(new_var->vid)) {
+      return VisitExpr(opt.value());
+    }
 
     return ExprMutator::VisitExpr_(new_var.get());
   }
diff --git a/tests/python/relax/test_transform_canonicalize_bindings.py b/tests/python/relax/test_transform_canonicalize_bindings.py
index 20f45f44dd..7d7b74bf59 100644
--- a/tests/python/relax/test_transform_canonicalize_bindings.py
+++ b/tests/python/relax/test_transform_canonicalize_bindings.py
@@ -289,33 +289,6 @@ def test_fold_match_cast():
     verify(Input, Expected)
 
 
-def test_unable_to_fold():
-    @I.ir_module
-    class MultipleUse:
-        @R.function
-        def main() -> R.Tensor((), "int32"):
-            with R.dataflow():
-                n = R.const(1)
-                # multiple uses -> cannot coalesce
-                m = R.add(n, n)
-                R.output(n)
-            return n
-
-    @I.ir_module
-    class ComplexExpr:
-        @R.function
-        def main() -> R.Tensor((), "int32"):
-            with R.dataflow():
-                y = R.const(1)
-                # y does not appear by itself -> cannot coalesce
-                n = R.add(y, y)
-                R.output(n)
-            return n
-
-    verify(MultipleUse, MultipleUse)
-    verify(ComplexExpr, ComplexExpr)
-
-
 def test_multiple_outputs():
     @I.ir_module
     class Input:
@@ -380,10 +353,9 @@ def test_single_output_multiple_nondataflow():
     verify(Input, Expected)
 
 
-def test_multiply_used_in_outputs():
-    # cannot fold output in this case
+def test_fold_const_to_output():
     @I.ir_module
-    class UsedInMultipleOutputs:
+    class Before:
         @R.function
         def main() -> R.Tensor((), "int32"):
             with R.dataflow():
@@ -391,7 +363,16 @@ def test_multiply_used_in_outputs():
                 R.output(n)
             return n
 
-    verify(UsedInMultipleOutputs, UsedInMultipleOutputs)
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main() -> R.Tensor((), "int32"):
+            with R.dataflow():
+                n = R.const(1)
+                R.output(n)
+            return R.const(1)
+
+    verify(Before, Expected)
 
 
 def test_canonicalize_var_to_dataflow_var_if_legal():
@@ -972,5 +953,29 @@ def test_canonicalization_causes_struct_info_update():
     assert_structural_equal(Expected, after)
 
 
+def test_unwrap_tuple_of_constant():
+    @I.ir_module
+    class TestChainAssignments:
+        @R.function
+        def main():
+            tup = (R.const(0, "int64"), R.const(1, "int64"))
+            x = tup[0]
+            y = tup[1]
+            z = R.add(x, y)
+            return z
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main():
+            tup = (R.const(0, "int64"), R.const(1, "int64"))
+            x = tup[0]
+            y = tup[1]
+            z = R.add(R.const(0, "int64"), R.const(1, "int64"))
+            return z
+
+    verify(TestChainAssignments, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_transform_convert_dataflow.py b/tests/python/relax/test_transform_convert_dataflow.py
index 790465e958..8a926cd4ae 100644
--- a/tests/python/relax/test_transform_convert_dataflow.py
+++ b/tests/python/relax/test_transform_convert_dataflow.py
@@ -221,7 +221,7 @@ class TestTreatNonCallAsPure(ExtractCompare):
                 t2 = (y, y, x)
                 c = R.const([1, 2, 3], dtype="int32")
                 R.output(c)
-            return c
+            return R.const([1, 2, 3], dtype="int32")
 
         @R.function
         def shapes() -> R.Shape:

Reply via email to