This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch unity-staging in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 170827f6346470702462e609a69557be853066c3 Author: Ruihang Lai <[email protected]> AuthorDate: Fri Feb 17 19:26:37 2023 -0500 [Unity][BlockBuilder] CallTE convert PrimValue args (#14028) Prior to this PR, BlockBuilder's `call_te` could not convert PrimValue arguments and rejected them outright. This PR adds PrimValue conversion support, along with a regression test. Co-authored-by: Siyuan Feng <[email protected]> --- python/tvm/relax/block_builder.py | 4 +++- tests/python/relax/test_blockbuilder.py | 24 ++++++++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/python/tvm/relax/block_builder.py b/python/tvm/relax/block_builder.py index 77b45fdf55..7837008479 100644 --- a/python/tvm/relax/block_builder.py +++ b/python/tvm/relax/block_builder.py @@ -34,7 +34,7 @@ from .expr import ( BaseFunc, Binding, ) -from .struct_info import ShapeStructInfo, StructInfo, TensorStructInfo +from .struct_info import PrimStructInfo, ShapeStructInfo, StructInfo, TensorStructInfo from .op.base import call_tir from . 
import _ffi_api @@ -256,6 +256,8 @@ class BlockBuilder(Object): arg, ShapeExpr ), "For Expr having ShapeStructInfo, emit_te now only supports ShapeExpr" return [_convert_te_arg_helper(val) for val in arg.values] + elif isinstance(arg.struct_info, PrimStructInfo): + return arg.value elif isinstance(arg, (list, tvm.ir.Array)): return [_convert_te_arg_helper(x) for x in arg] elif isinstance(arg, tuple): diff --git a/tests/python/relax/test_blockbuilder.py b/tests/python/relax/test_blockbuilder.py index 36a22f9712..e54e2b7bf9 100644 --- a/tests/python/relax/test_blockbuilder.py +++ b/tests/python/relax/test_blockbuilder.py @@ -23,6 +23,7 @@ from tvm import te, tir, topi from tvm import relax as rx, relay from tvm.ir.base import assert_structural_equal from tvm.relax import ExternFunc +from tvm.script import relax as R from tvm.tir.function import PrimFunc @@ -462,6 +463,29 @@ def test_emit_te_extern(): assert call_node.sinfo_args[0].shape[1] == n +def test_emit_te_prim_value(): + bb = rx.BlockBuilder() + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = rx.Var("x", R.Tensor([n, m], "float32")) + a_min = rx.PrimValue(0) + a_max = rx.PrimValue(6) + + with bb.function("rx_clip", [x]): + out = bb.emit_te(topi.clip, x, a_min, a_max) + bb.emit_func_output(out) + + rx_func = bb.get()["rx_clip"] + + # check Relax function calls TIR function with call_tir call + assert rx_func.params[0] == x + assert len(rx_func.body.blocks) == 1 + call_node = rx_func.body.blocks[0].bindings[0].value + assert isinstance(call_node, rx.Call) + assert call_node.op == relay.op.get("relax.call_tir") + assert len(call_node.args) == 2 + assert call_node.args[1][0] == x + + def test_nested_function_fail(): m = tir.Var("m", "int64") n = tir.Var("n", "int64")
