This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 31d68aa19b [Unity][Analysis] Checking function return struct info in well-formed check (#14155)
31d68aa19b is described below

commit 31d68aa19b802e6d142de21119a8e6ea0b371ec2
Author: Siyuan Feng <[email protected]>
AuthorDate: Wed Mar 1 21:49:11 2023 +0800

    [Unity][Analysis] Checking function return struct info in well-formed check (#14155)
    
    The current well-formed check does not check the function return struct
    info, so a function may mistakenly pass the check if there are undefined
    vars in its return struct info.
---
 src/relax/analysis/well_formed.cc               |  6 ++++++
 tests/python/relax/test_analysis_well_formed.py | 25 ++++++++++++++++++++++++-
 2 files changed, 30 insertions(+), 1 deletion(-)

diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc
index 25b9155d77..9a97931136 100644
--- a/src/relax/analysis/well_formed.cc
+++ b/src/relax/analysis/well_formed.cc
@@ -225,6 +225,12 @@ class WellFormedChecker : public relax::ExprVisitor,
       }
       param_var_func_map_.insert({param, GetRef<Function>(op)});
     }
+    // check function ret_struct_info
+    if (op->ret_struct_info.defined()) {
+      this->VisitStructInfo(op->ret_struct_info);
+    } else {
+      Malformed(Diagnostic::Error(op) << "Function must have defined ret_struct_info");
+    }
 
     if (auto seq = op->body.as<SeqExprNode>()) {
       this->VisitSeqExpr(seq);
diff --git a/tests/python/relax/test_analysis_well_formed.py b/tests/python/relax/test_analysis_well_formed.py
index ee5814eb7b..7b8035b17c 100644
--- a/tests/python/relax/test_analysis_well_formed.py
+++ b/tests/python/relax/test_analysis_well_formed.py
@@ -17,9 +17,10 @@
 import pytest
 import tvm
 import tvm.testing
-from tvm import tir
 from tvm import relax as rx
+from tvm import tir
 from tvm.script import relax as R
+from tvm.script import tir as T
 
 m = tir.Var("m", "int64")
 n = tir.Var("n", "int64")
@@ -497,5 +498,27 @@ def test_sinfo_args_tir_var_used_before_define_call_tir():
     assert not rx.analysis.well_formed(mod, check_struct_info=False)
 
 
+def test_sinfo_erase_to_well_formed():
+    # Error: The return sinfo contains undefined symbolic vars
+    """
+    @R.function
+    def foo(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m1", "n1"), dtype="float32"):
+        m = T.int64()
+        n = T.int64()
+        gv = R.call_tir("my_func", (x,), out_sinfo=R.Tensor((m, n), dtype="float32"))
+        return gv
+    """
+    m1 = tir.Var("m1", "int64")
+    n1 = tir.Var("n1", "int64")
+    call = R.call_tir("my_func", x, out_sinfo=R.Tensor((m, n), "float32"))
+    blocks = [rx.BindingBlock([rx.VarBinding(rx.Var("gv"), call)])]
+    seq_expr = rx.SeqExpr(blocks, blocks[-1].bindings[-1].var)
+    func = rx.Function([x], seq_expr, R.Tensor((m1, n1), "float32")).with_attr(
+        "global_symbol", "foo"
+    )
+    mod = rx.transform.Normalize()(tvm.IRModule.from_expr(func))
+    assert not rx.analysis.well_formed(mod)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to