This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 5dcc25a15c [Unity][Relax][Transform] Do not remove MatchCast for RemoveAllUnused (#15290)
5dcc25a15c is described below

commit 5dcc25a15ca9fef4ba6417c91243d60575c8e9f3
Author: l0phTg <[email protected]>
AuthorDate: Thu Jul 13 10:26:23 2023 +0800

    [Unity][Relax][Transform] Do not remove MatchCast for RemoveAllUnused (#15290)
    
    Match_cast can be used to capture symbolic shapes, remove MatchCast
    binding will cause `Prim Expr xxx has not been computed` in build stage.
---
 src/relax/ir/binding_rewrite.cc     |  2 +-
 tests/python/relax/test_analysis.py | 23 +++++++++++++++++++++++
 2 files changed, 24 insertions(+), 1 deletion(-)

diff --git a/src/relax/ir/binding_rewrite.cc b/src/relax/ir/binding_rewrite.cc
index 2b2fb23d1d..188ed39560 100644
--- a/src/relax/ir/binding_rewrite.cc
+++ b/src/relax/ir/binding_rewrite.cc
@@ -249,7 +249,7 @@ class RemoveUnusedVars : public ExprMutator {
     auto prev_dfb = GetRef<DataflowBlock>(block);
     builder_->BeginDataflowBlock();
     for (Binding binding : block->bindings) {
-      if (!unused_vars.count(binding->var)) {
+      if (!unused_vars.count(binding->var) || binding.as<MatchCastNode>()) {
         VisitBinding(binding);
       }
     }
diff --git a/tests/python/relax/test_analysis.py b/tests/python/relax/test_analysis.py
index fd47d25a47..500a57775e 100644
--- a/tests/python/relax/test_analysis.py
+++ b/tests/python/relax/test_analysis.py
@@ -156,6 +156,29 @@ def test_edge_binding_block_fake_unused_remove_all_unused():
     tvm.ir.assert_structural_equal(optimized, IdentityUnused["main"])
 
 
+def test_edge_binding_block_fake_unused_remove_all_unused2():
+    @tvm.script.ir_module
+    class IdentityUnused:
+        @R.function
+        def main(x: R.Tensor((3,), dtype="int64")) -> R.Tensor(dtype="int32", ndim=3):
+            m = T.int64()
+            n = T.int64()
+            k = T.int64()
+            with R.dataflow():
+                lv: R.Shape(ndim=3) = R.call_pure_packed(
+                    "vm.builtin.tensor_to_shape", x, sinfo_args=(R.Shape(ndim=3),)
+                )
+                lv1: R.Shape([m, n, k]) = R.match_cast(lv, R.Shape([m, n, k]))
+                gv: R.Tensor((m, n, k), dtype="int32") = R.full(
+                    R.shape([m, n, k]), R.const(1, "int32"), dtype="int32"
+                )
+                R.output(gv)
+            return gv
+
+    optimized = remove_all_unused(IdentityUnused["main"])
+    tvm.ir.assert_structural_equal(optimized, IdentityUnused["main"])
+
+
 def test_name_to_binding_var_shadowing():
     @R.function
     def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:

Reply via email to