This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 31d68aa19b [Unity][Analysis] Checking function return struct info in
well-formed check (#14155)
31d68aa19b is described below
commit 31d68aa19b802e6d142de21119a8e6ea0b371ec2
Author: Siyuan Feng <[email protected]>
AuthorDate: Wed Mar 1 21:49:11 2023 +0800
[Unity][Analysis] Checking function return struct info in well-formed check
(#14155)
The current well-formed check does not validate a function's return struct info,
so it may mistakenly accept a function whose return struct info contains
undefined vars.
---
src/relax/analysis/well_formed.cc | 6 ++++++
tests/python/relax/test_analysis_well_formed.py | 25 ++++++++++++++++++++++++-
2 files changed, 30 insertions(+), 1 deletion(-)
diff --git a/src/relax/analysis/well_formed.cc
b/src/relax/analysis/well_formed.cc
index 25b9155d77..9a97931136 100644
--- a/src/relax/analysis/well_formed.cc
+++ b/src/relax/analysis/well_formed.cc
@@ -225,6 +225,12 @@ class WellFormedChecker : public relax::ExprVisitor,
}
param_var_func_map_.insert({param, GetRef<Function>(op)});
}
+ // check function ret_struct_info
+ if (op->ret_struct_info.defined()) {
+ this->VisitStructInfo(op->ret_struct_info);
+ } else {
+ Malformed(Diagnostic::Error(op) << "Function must have defined
ret_struct_info");
+ }
if (auto seq = op->body.as<SeqExprNode>()) {
this->VisitSeqExpr(seq);
diff --git a/tests/python/relax/test_analysis_well_formed.py
b/tests/python/relax/test_analysis_well_formed.py
index ee5814eb7b..7b8035b17c 100644
--- a/tests/python/relax/test_analysis_well_formed.py
+++ b/tests/python/relax/test_analysis_well_formed.py
@@ -17,9 +17,10 @@
import pytest
import tvm
import tvm.testing
-from tvm import tir
from tvm import relax as rx
+from tvm import tir
from tvm.script import relax as R
+from tvm.script import tir as T
m = tir.Var("m", "int64")
n = tir.Var("n", "int64")
@@ -497,5 +498,27 @@ def test_sinfo_args_tir_var_used_before_define_call_tir():
assert not rx.analysis.well_formed(mod, check_struct_info=False)
+def test_sinfo_erase_to_well_formed():
+ # Expected malformed: ret_struct_info uses m1/n1, which no parameter's struct info defines
+ """
+ @R.function
+ def foo(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m1",
"n1"), dtype="float32"):
+ m = T.int64()
+ n = T.int64()
+ gv = R.call_tir("my_func", (x,), out_sinfo=R.Tensor((m, n),
dtype="float32"))
+ return gv
+ """
+ m1 = tir.Var("m1", "int64")
+ n1 = tir.Var("n1", "int64")
+ call = R.call_tir("my_func", x, out_sinfo=R.Tensor((m, n), "float32"))
+ blocks = [rx.BindingBlock([rx.VarBinding(rx.Var("gv"), call)])]
+ seq_expr = rx.SeqExpr(blocks, blocks[-1].bindings[-1].var)
+ func = rx.Function([x], seq_expr, R.Tensor((m1, n1), "float32")).with_attr(
+ "global_symbol", "foo"
+ )
+ mod = rx.transform.Normalize()(tvm.IRModule.from_expr(func))
+ assert not rx.analysis.well_formed(mod)
+
+
if __name__ == "__main__":
tvm.testing.main()