This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch unity-staging in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 59527567ccf728b28a4fabc99f750c496a7d789c Author: Yong Wu <[email protected]> AuthorDate: Sat Feb 18 11:41:10 2023 -0800 [Unity] Disallow inline prim_func in relax IR (#14040) Disallow inline prim_func in relax IR --- python/tvm/script/parser/relax/parser.py | 10 ++++++ src/relax/analysis/well_formed.cc | 6 +++- src/relax/ir/block_builder.cc | 22 ------------ tests/python/relax/test_analysis_well_formed.py | 36 ++++++++++++++++++++ tests/python/relax/test_tvmscript_parser.py | 45 ++++++++++++------------- 5 files changed, 73 insertions(+), 46 deletions(-) diff --git a/python/tvm/script/parser/relax/parser.py b/python/tvm/script/parser/relax/parser.py index ef26ddd6e9..e5e5bb2743 100644 --- a/python/tvm/script/parser/relax/parser.py +++ b/python/tvm/script/parser/relax/parser.py @@ -139,6 +139,16 @@ def visit_function_def(self: Parser, node: doc.FunctionDef) -> None: R.func_ret_struct_info(ann_sinfo) self.visit(node.args) + + for stmt in node.body: + if isinstance(stmt, doc.FunctionDef): + if not stmt.decorator_list: + self.report_error(stmt, "Function must be decorated") + dec = self.eval_expr(stmt.decorator_list[-1]) + # inline prim_func was found + if dec.dispatch_token == "tir": + self.report_error(stmt, "inline prim_func is disallowed in Relax IR") + self.visit_body(node.body) diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc index e7ec237fd5..05ad0954bb 100644 --- a/src/relax/analysis/well_formed.cc +++ b/src/relax/analysis/well_formed.cc @@ -316,7 +316,11 @@ class WellFormedChecker : public relax::ExprVisitor, } void VisitBinding_(const VarBindingNode* binding) final { - this->VisitExpr(binding->value); + if (binding->value->IsInstance<tir::PrimFuncNode>()) { + Malformed(Diagnostic::Error(binding->value) << "Inline PrimFunc is disallowed in Relax IR."); + } else { + this->VisitExpr(binding->value); + } this->VisitVarDef(binding->var); } diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc index 6a2d7ea5c5..5976cbb3f4 100644 --- a/src/relax/ir/block_builder.cc +++ b/src/relax/ir/block_builder.cc @@ -469,12 +469,6 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor<Expr(const Expr& * \note This function create a new binding for non-leaf expressions except for tuple. */ Expr NormalizeArgument(const Expr& arg) final { - // Temp patch to ensure we handle inline PrimFunc case. - // TODO(relax-team) remove such cases from parser and testcases. - if (auto* prim_func = arg.as<tir::PrimFuncNode>()) { - return NormalizePrimFunc(GetRef<tir::PrimFunc>(prim_func)); - } - if (!block_stack_.empty()) { // cache lookup BlockFrame* cur_frame = CurrentBlockFrame(); @@ -520,23 +514,7 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor<Expr(const Expr& Expr VisitExpr_(const DataflowVarNode* var) final { return VisitVar_<DataflowVar>(var); } - // Temp patch to ensure we handle inline PrimFunc case. - // TODO(relax-team) remove such cases from parser and testcases. - Expr NormalizePrimFunc(tir::PrimFunc prim_func) { - if (!prim_func->struct_info_.defined()) { - auto finfo = FuncStructInfo::OpaqueFunc(StructInfoFromType(prim_func->ret_type)); - UpdateStructInfo(prim_func, finfo); - } - return prim_func; - } - Expr VisitExpr(const Expr& expr) final { - // Temp patch to ensure we handle inline PrimFunc case. - // TODO(relax-team) remove such cases from parser and testcases. - if (auto* prim_func = expr.as<tir::PrimFuncNode>()) { - return NormalizePrimFunc(GetRef<tir::PrimFunc>(prim_func)); - } - // lookup normalize map if (!block_stack_.empty()) { BlockFrame* cur_frame = CurrentBlockFrame(); diff --git a/tests/python/relax/test_analysis_well_formed.py b/tests/python/relax/test_analysis_well_formed.py index cc0de84d53..67da772741 100644 --- a/tests/python/relax/test_analysis_well_formed.py +++ b/tests/python/relax/test_analysis_well_formed.py @@ -357,6 +357,42 @@ def test_complex_seq_body(): assert rx.analysis.well_formed(normalized, check_struct_info=True) +def test_inline_prim_func(): + # Error: inline prim_func is disallowed in Relax IR + x = rx.Var("x", R.Tensor([], "int32")) + y = rx.Var("y", R.Tensor([], "int32")) + new_func = rx.Function( + [], + rx.SeqExpr( + [ + rx.BindingBlock( + [ + rx.VarBinding( + var=x, + value=tir.PrimFunc([], tir.Evaluate(0)), + ), + rx.VarBinding( + var=y, + value=rx.Call( + op=tvm.ir.Op.get("relax.call_tir"), + args=[ + rx.GlobalVar("GlobalVar0"), + rx.Tuple([x, tir.PrimFunc([], tir.Evaluate(0))]), + rx.ShapeExpr([]), + ], + ), + ), + ] + ) + ], + y, + ), + R.Tensor(ndim=0, dtype="int32"), + ).with_attr("global_symbol", "foo") + new_mod = tvm.IRModule.from_expr(new_func) + assert not rx.analysis.well_formed(new_mod, check_struct_info=False) + + def test_ANF(): # Error: Nested Call gv0 = rx.Var("gv0", R.Tensor([m, n], "float32")) diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index 6e9e14d3dc..507ce72c06 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -736,30 +736,29 @@ def test_local_function(): inner_func = outer_func_bindings[0].value assert isinstance(inner_func, relax.Function) - @I.ir_module - class TestModule: - @R.function - def f(x: R.Tensor((128, 128), "float32"), y: R.Tensor((128, 128), "float32")): - @T.prim_func - def my_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (128, 128)) - B = T.match_buffer(b, (128, 128)) - C = T.match_buffer(c, (128, 128)) - - for i, j, k in T.grid(128, 128, 128): - with T.block(): - vi, vj, vk = T.axis.remap("SSR", [i, j, k]) - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] += A[vi, vk] * B[vj, vk] - - z = relax.call_tir(my_matmul, (x, y), R.Tensor((128, 128), dtype="float32")) - return z - bindings = TestModule["f"].body.blocks[0].bindings - assert len(bindings) == 2 - tir_func = bindings[0].value - assert isinstance(tir_func, tir.PrimFunc) +def test_inline_prim_func(): + with pytest.raises(tvm.error.DiagnosticError): + + @I.ir_module + class TestModule: + @R.function + def f(x: R.Tensor((128, 128), "float32"), y: R.Tensor((128, 128), "float32")): + @T.prim_func + def my_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + C = T.match_buffer(c, (128, 128)) + + for i, j, k in T.grid(128, 128, 128): + with T.block(): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] += A[vi, vk] * B[vj, vk] + + z = relax.call_tir(my_matmul, (x, y), R.Tensor((128, 128), dtype="float32")) + return z def test_cross_function_call():
