This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 5dcc25a15c [Unity][Relax][Transform] Do not remove MatchCast for RemoveAllUnused (#15290)
5dcc25a15c is described below
commit 5dcc25a15ca9fef4ba6417c91243d60575c8e9f3
Author: l0phTg <[email protected]>
AuthorDate: Thu Jul 13 10:26:23 2023 +0800
[Unity][Relax][Transform] Do not remove MatchCast for RemoveAllUnused (#15290)
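`MatchCast` can be used to capture symbolic shapes: even when the variable bound by a `MatchCast` is itself unused, the symbolic variables it defines (e.g. `m`, `n`, `k`) may still be referenced by later bindings. Removing such a `MatchCast` binding therefore causes a `Prim Expr xxx has not been computed` error in the build stage.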
---
src/relax/ir/binding_rewrite.cc | 2 +-
tests/python/relax/test_analysis.py | 23 +++++++++++++++++++++++
2 files changed, 24 insertions(+), 1 deletion(-)
diff --git a/src/relax/ir/binding_rewrite.cc b/src/relax/ir/binding_rewrite.cc
index 2b2fb23d1d..188ed39560 100644
--- a/src/relax/ir/binding_rewrite.cc
+++ b/src/relax/ir/binding_rewrite.cc
@@ -249,7 +249,7 @@ class RemoveUnusedVars : public ExprMutator {
auto prev_dfb = GetRef<DataflowBlock>(block);
builder_->BeginDataflowBlock();
for (Binding binding : block->bindings) {
- if (!unused_vars.count(binding->var)) {
+ if (!unused_vars.count(binding->var) || binding.as<MatchCastNode>()) {
VisitBinding(binding);
}
}
diff --git a/tests/python/relax/test_analysis.py b/tests/python/relax/test_analysis.py
index fd47d25a47..500a57775e 100644
--- a/tests/python/relax/test_analysis.py
+++ b/tests/python/relax/test_analysis.py
@@ -156,6 +156,29 @@ def test_edge_binding_block_fake_unused_remove_all_unused():
tvm.ir.assert_structural_equal(optimized, IdentityUnused["main"])
+def test_edge_binding_block_fake_unused_remove_all_unused2():
+ @tvm.script.ir_module
+ class IdentityUnused:
+ @R.function
+ def main(x: R.Tensor((3,), dtype="int64")) -> R.Tensor(dtype="int32", ndim=3):
+ m = T.int64()
+ n = T.int64()
+ k = T.int64()
+ with R.dataflow():
+ lv: R.Shape(ndim=3) = R.call_pure_packed(
+ "vm.builtin.tensor_to_shape", x, sinfo_args=(R.Shape(ndim=3),)
+ )
+ lv1: R.Shape([m, n, k]) = R.match_cast(lv, R.Shape([m, n, k]))
+ gv: R.Tensor((m, n, k), dtype="int32") = R.full(
+ R.shape([m, n, k]), R.const(1, "int32"), dtype="int32"
+ )
+ R.output(gv)
+ return gv
+
+ optimized = remove_all_unused(IdentityUnused["main"])
+ tvm.ir.assert_structural_equal(optimized, IdentityUnused["main"])
+
+
def test_name_to_binding_var_shadowing():
@R.function
def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: