This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity-staging
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 170827f6346470702462e609a69557be853066c3
Author: Ruihang Lai <[email protected]>
AuthorDate: Fri Feb 17 19:26:37 2023 -0500

    [Unity][BlockBuilder] CallTE convert PrimValue args (#14028)
    
    Prior to this PR, the `call_te` method of BlockBuilder could not convert
    PrimValue arguments and rejected them outright. This PR adds PrimValue
    conversion support, together with one regression test.
    
    Co-authored-by: Siyuan Feng <[email protected]>
---
 python/tvm/relax/block_builder.py       |  4 +++-
 tests/python/relax/test_blockbuilder.py | 24 ++++++++++++++++++++++++
 2 files changed, 27 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relax/block_builder.py b/python/tvm/relax/block_builder.py
index 77b45fdf55..7837008479 100644
--- a/python/tvm/relax/block_builder.py
+++ b/python/tvm/relax/block_builder.py
@@ -34,7 +34,7 @@ from .expr import (
     BaseFunc,
     Binding,
 )
-from .struct_info import ShapeStructInfo, StructInfo, TensorStructInfo
+from .struct_info import PrimStructInfo, ShapeStructInfo, StructInfo, TensorStructInfo
 from .op.base import call_tir
 from . import _ffi_api
 
@@ -256,6 +256,8 @@ class BlockBuilder(Object):
                         arg, ShapeExpr
                     ), "For Expr having ShapeStructInfo, emit_te now only supports ShapeExpr"
                     return [_convert_te_arg_helper(val) for val in arg.values]
+                elif isinstance(arg.struct_info, PrimStructInfo):
+                    return arg.value
             elif isinstance(arg, (list, tvm.ir.Array)):
                 return [_convert_te_arg_helper(x) for x in arg]
             elif isinstance(arg, tuple):
diff --git a/tests/python/relax/test_blockbuilder.py b/tests/python/relax/test_blockbuilder.py
index 36a22f9712..e54e2b7bf9 100644
--- a/tests/python/relax/test_blockbuilder.py
+++ b/tests/python/relax/test_blockbuilder.py
@@ -23,6 +23,7 @@ from tvm import te, tir, topi
 from tvm import relax as rx, relay
 from tvm.ir.base import assert_structural_equal
 from tvm.relax import ExternFunc
+from tvm.script import relax as R
 from tvm.tir.function import PrimFunc
 
 
@@ -462,6 +463,29 @@ def test_emit_te_extern():
     assert call_node.sinfo_args[0].shape[1] == n
 
 
+def test_emit_te_prim_value():
+    bb = rx.BlockBuilder()
+    n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
+    x = rx.Var("x", R.Tensor([n, m], "float32"))
+    a_min = rx.PrimValue(0)
+    a_max = rx.PrimValue(6)
+
+    with bb.function("rx_clip", [x]):
+        out = bb.emit_te(topi.clip, x, a_min, a_max)
+        bb.emit_func_output(out)
+
+    rx_func = bb.get()["rx_clip"]
+
+    # check Relax function calls TIR function with call_tir call
+    assert rx_func.params[0] == x
+    assert len(rx_func.body.blocks) == 1
+    call_node = rx_func.body.blocks[0].bindings[0].value
+    assert isinstance(call_node, rx.Call)
+    assert call_node.op == relay.op.get("relax.call_tir")
+    assert len(call_node.args) == 2
+    assert call_node.args[1][0] == x
+
+
 def test_nested_function_fail():
     m = tir.Var("m", "int64")
     n = tir.Var("n", "int64")

Reply via email to