This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity-staging
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 59527567ccf728b28a4fabc99f750c496a7d789c
Author: Yong Wu <[email protected]>
AuthorDate: Sat Feb 18 11:41:10 2023 -0800

    [Unity] Disallow inline prim_func in relax IR (#14040)
    
    Disallow inline prim_func in relax IR
---
 python/tvm/script/parser/relax/parser.py        | 10 ++++++
 src/relax/analysis/well_formed.cc               |  6 +++-
 src/relax/ir/block_builder.cc                   | 22 ------------
 tests/python/relax/test_analysis_well_formed.py | 36 ++++++++++++++++++++
 tests/python/relax/test_tvmscript_parser.py     | 45 ++++++++++++-------------
 5 files changed, 73 insertions(+), 46 deletions(-)

diff --git a/python/tvm/script/parser/relax/parser.py b/python/tvm/script/parser/relax/parser.py
index ef26ddd6e9..e5e5bb2743 100644
--- a/python/tvm/script/parser/relax/parser.py
+++ b/python/tvm/script/parser/relax/parser.py
@@ -139,6 +139,16 @@ def visit_function_def(self: Parser, node: doc.FunctionDef) -> None:
                     R.func_ret_struct_info(ann_sinfo)
 
                 self.visit(node.args)
+
+                for stmt in node.body:
+                    if isinstance(stmt, doc.FunctionDef):
+                        if not stmt.decorator_list:
+                            self.report_error(stmt, "Function must be decorated")
+                        dec = self.eval_expr(stmt.decorator_list[-1])
+                        # inline prim_func was found
+                        if dec.dispatch_token == "tir":
+                            self.report_error(stmt, "inline prim_func is disallowed in Relax IR")
+
                 self.visit_body(node.body)
 
 
diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc
index e7ec237fd5..05ad0954bb 100644
--- a/src/relax/analysis/well_formed.cc
+++ b/src/relax/analysis/well_formed.cc
@@ -316,7 +316,11 @@ class WellFormedChecker : public relax::ExprVisitor,
   }
 
   void VisitBinding_(const VarBindingNode* binding) final {
-    this->VisitExpr(binding->value);
+    if (binding->value->IsInstance<tir::PrimFuncNode>()) {
+      Malformed(Diagnostic::Error(binding->value) << "Inline PrimFunc is disallowed in Relax IR.");
+    } else {
+      this->VisitExpr(binding->value);
+    }
     this->VisitVarDef(binding->var);
   }
 
diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc
index 6a2d7ea5c5..5976cbb3f4 100644
--- a/src/relax/ir/block_builder.cc
+++ b/src/relax/ir/block_builder.cc
@@ -469,12 +469,6 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor<Expr(const Expr&
   * \note This function create a new binding for non-leaf expressions except for tuple.
    */
   Expr NormalizeArgument(const Expr& arg) final {
-    // Temp patch to ensure we handle inline PrimFunc case.
-    // TODO(relax-team) remove such cases from parser and testcases.
-    if (auto* prim_func = arg.as<tir::PrimFuncNode>()) {
-      return NormalizePrimFunc(GetRef<tir::PrimFunc>(prim_func));
-    }
-
     if (!block_stack_.empty()) {
       // cache lookup
       BlockFrame* cur_frame = CurrentBlockFrame();
@@ -520,23 +514,7 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor<Expr(const Expr&
 
  Expr VisitExpr_(const DataflowVarNode* var) final { return VisitVar_<DataflowVar>(var); }
 
-  // Temp patch to ensure we handle inline PrimFunc case.
-  // TODO(relax-team) remove such cases from parser and testcases.
-  Expr NormalizePrimFunc(tir::PrimFunc prim_func) {
-    if (!prim_func->struct_info_.defined()) {
-      auto finfo = FuncStructInfo::OpaqueFunc(StructInfoFromType(prim_func->ret_type));
-      UpdateStructInfo(prim_func, finfo);
-    }
-    return prim_func;
-  }
-
   Expr VisitExpr(const Expr& expr) final {
-    // Temp patch to ensure we handle inline PrimFunc case.
-    // TODO(relax-team) remove such cases from parser and testcases.
-    if (auto* prim_func = expr.as<tir::PrimFuncNode>()) {
-      return NormalizePrimFunc(GetRef<tir::PrimFunc>(prim_func));
-    }
-
     // lookup normalize map
     if (!block_stack_.empty()) {
       BlockFrame* cur_frame = CurrentBlockFrame();
diff --git a/tests/python/relax/test_analysis_well_formed.py b/tests/python/relax/test_analysis_well_formed.py
index cc0de84d53..67da772741 100644
--- a/tests/python/relax/test_analysis_well_formed.py
+++ b/tests/python/relax/test_analysis_well_formed.py
@@ -357,6 +357,42 @@ def test_complex_seq_body():
     assert rx.analysis.well_formed(normalized, check_struct_info=True)
 
 
+def test_inline_prim_func():
+    # Error: inline prim_func is disallowed in Relax IR
+    x = rx.Var("x", R.Tensor([], "int32"))
+    y = rx.Var("y", R.Tensor([], "int32"))
+    new_func = rx.Function(
+        [],
+        rx.SeqExpr(
+            [
+                rx.BindingBlock(
+                    [
+                        rx.VarBinding(
+                            var=x,
+                            value=tir.PrimFunc([], tir.Evaluate(0)),
+                        ),
+                        rx.VarBinding(
+                            var=y,
+                            value=rx.Call(
+                                op=tvm.ir.Op.get("relax.call_tir"),
+                                args=[
+                                    rx.GlobalVar("GlobalVar0"),
+                                    rx.Tuple([x, tir.PrimFunc([], tir.Evaluate(0))]),
+                                    rx.ShapeExpr([]),
+                                ],
+                            ),
+                        ),
+                    ]
+                )
+            ],
+            y,
+        ),
+        R.Tensor(ndim=0, dtype="int32"),
+    ).with_attr("global_symbol", "foo")
+    new_mod = tvm.IRModule.from_expr(new_func)
+    assert not rx.analysis.well_formed(new_mod, check_struct_info=False)
+
+
 def test_ANF():
     # Error: Nested Call
     gv0 = rx.Var("gv0", R.Tensor([m, n], "float32"))
diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py
index 6e9e14d3dc..507ce72c06 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -736,30 +736,29 @@ def test_local_function():
     inner_func = outer_func_bindings[0].value
     assert isinstance(inner_func, relax.Function)
 
-    @I.ir_module
-    class TestModule:
-        @R.function
-        def f(x: R.Tensor((128, 128), "float32"), y: R.Tensor((128, 128), "float32")):
-            @T.prim_func
-            def my_matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
-                A = T.match_buffer(a, (128, 128))
-                B = T.match_buffer(b, (128, 128))
-                C = T.match_buffer(c, (128, 128))
-
-                for i, j, k in T.grid(128, 128, 128):
-                    with T.block():
-                        vi, vj, vk = T.axis.remap("SSR", [i, j, k])
-                        with T.init():
-                            C[vi, vj] = 0.0
-                        C[vi, vj] += A[vi, vk] * B[vj, vk]
-
-            z = relax.call_tir(my_matmul, (x, y), R.Tensor((128, 128), dtype="float32"))
-            return z
 
-    bindings = TestModule["f"].body.blocks[0].bindings
-    assert len(bindings) == 2
-    tir_func = bindings[0].value
-    assert isinstance(tir_func, tir.PrimFunc)
+def test_inline_prim_func():
+    with pytest.raises(tvm.error.DiagnosticError):
+
+        @I.ir_module
+        class TestModule:
+            @R.function
+            def f(x: R.Tensor((128, 128), "float32"), y: R.Tensor((128, 128), "float32")):
+                @T.prim_func
+                def my_matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
+                    A = T.match_buffer(a, (128, 128))
+                    B = T.match_buffer(b, (128, 128))
+                    C = T.match_buffer(c, (128, 128))
+
+                    for i, j, k in T.grid(128, 128, 128):
+                        with T.block():
+                            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+                            with T.init():
+                                C[vi, vj] = 0.0
+                            C[vi, vj] += A[vi, vk] * B[vj, vk]
+
+                z = relax.call_tir(my_matmul, (x, y), R.Tensor((128, 128), dtype="float32"))
+                return z
 
 
 def test_cross_function_call():

Reply via email to