This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new ae8f9e373b [Unity][BlockBuilder] CallTE convert PrimValue args (#14028)
ae8f9e373b is described below
commit ae8f9e373bf9e05358ca512f640c8db97015ac40
Author: Ruihang Lai <[email protected]>
AuthorDate: Fri Feb 17 19:26:37 2023 -0500
[Unity][BlockBuilder] CallTE convert PrimValue args (#14028)
Prior to this PR, BlockBuilder's `call_te` could not convert PrimValue
arguments and rejected them outright. This PR adds PrimValue conversion
support, extends `topi.clip` to accept `PrimExpr` bounds, and adds
regression tests for both.
Co-authored-by: Siyuan Feng <[email protected]>
---
python/tvm/relax/block_builder.py | 4 +++-
python/tvm/topi/math.py | 18 ++++++++++++++----
tests/python/relax/test_blockbuilder.py | 24 ++++++++++++++++++++++++
tests/python/topi/python/test_topi_clip.py | 16 +++++++++++++---
4 files changed, 54 insertions(+), 8 deletions(-)
diff --git a/python/tvm/relax/block_builder.py b/python/tvm/relax/block_builder.py
index 77b45fdf55..7837008479 100644
--- a/python/tvm/relax/block_builder.py
+++ b/python/tvm/relax/block_builder.py
@@ -34,7 +34,7 @@ from .expr import (
BaseFunc,
Binding,
)
-from .struct_info import ShapeStructInfo, StructInfo, TensorStructInfo
+from .struct_info import PrimStructInfo, ShapeStructInfo, StructInfo, TensorStructInfo
from .op.base import call_tir
from . import _ffi_api
@@ -256,6 +256,8 @@ class BlockBuilder(Object):
arg, ShapeExpr
), "For Expr having ShapeStructInfo, emit_te now only supports ShapeExpr"
return [_convert_te_arg_helper(val) for val in arg.values]
+ elif isinstance(arg.struct_info, PrimStructInfo):
+ return arg.value
elif isinstance(arg, (list, tvm.ir.Array)):
return [_convert_te_arg_helper(x) for x in arg]
elif isinstance(arg, tuple):
diff --git a/python/tvm/topi/math.py b/python/tvm/topi/math.py
index dd191c49be..4d305d1e2f 100644
--- a/python/tvm/topi/math.py
+++ b/python/tvm/topi/math.py
@@ -18,6 +18,8 @@
# pylint: disable=redefined-builtin,unused-argument
import tvm
from tvm import te
+from tvm.tir import PrimExpr
+
from . import tag
from . import cpp
from .utils import get_const_tuple
@@ -620,9 +622,9 @@ def clip(x, a_min, a_max):
----------
x : tvm.te.Tensor
Input argument.
- a_min : int or float
+ a_min : tvm.tir.PrimExpr
Minimum value.
- a_max : int or float
+ a_max : tvm.tir.PrimExpr
Maximum value.
Returns
@@ -633,8 +635,16 @@ def clip(x, a_min, a_max):
def _compute(*indices):
value = x(*indices)
- const_min = tvm.tir.const(a_min, value.dtype)
- const_max = tvm.tir.const(a_max, value.dtype)
+ const_min = (
+ tvm.tir.Cast(value.dtype, a_min)
+ if isinstance(a_min, PrimExpr)
+ else tvm.tir.const(a_min, value.dtype)
+ )
+ const_max = (
+ tvm.tir.Cast(value.dtype, a_max)
+ if isinstance(a_max, PrimExpr)
+ else tvm.tir.const(a_max, value.dtype)
+ )
return tvm.te.max(tvm.te.min(value, const_max), const_min)
return te.compute(x.shape, _compute)
diff --git a/tests/python/relax/test_blockbuilder.py
b/tests/python/relax/test_blockbuilder.py
index 36a22f9712..e54e2b7bf9 100644
--- a/tests/python/relax/test_blockbuilder.py
+++ b/tests/python/relax/test_blockbuilder.py
@@ -23,6 +23,7 @@ from tvm import te, tir, topi
from tvm import relax as rx, relay
from tvm.ir.base import assert_structural_equal
from tvm.relax import ExternFunc
+from tvm.script import relax as R
from tvm.tir.function import PrimFunc
@@ -462,6 +463,29 @@ def test_emit_te_extern():
assert call_node.sinfo_args[0].shape[1] == n
+def test_emit_te_prim_value():
+ bb = rx.BlockBuilder()
+ n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
+ x = rx.Var("x", R.Tensor([n, m], "float32"))
+ a_min = rx.PrimValue(0)
+ a_max = rx.PrimValue(6)
+
+ with bb.function("rx_clip", [x]):
+ out = bb.emit_te(topi.clip, x, a_min, a_max)
+ bb.emit_func_output(out)
+
+ rx_func = bb.get()["rx_clip"]
+
+ # check Relax function calls TIR function with call_tir call
+ assert rx_func.params[0] == x
+ assert len(rx_func.body.blocks) == 1
+ call_node = rx_func.body.blocks[0].bindings[0].value
+ assert isinstance(call_node, rx.Call)
+ assert call_node.op == relay.op.get("relax.call_tir")
+ assert len(call_node.args) == 2
+ assert call_node.args[1][0] == x
+
+
def test_nested_function_fail():
m = tir.Var("m", "int64")
n = tir.Var("n", "int64")
diff --git a/tests/python/topi/python/test_topi_clip.py
b/tests/python/topi/python/test_topi_clip.py
index 21546e8b57..68bb45580f 100644
--- a/tests/python/topi/python/test_topi_clip.py
+++ b/tests/python/topi/python/test_topi_clip.py
@@ -17,7 +17,7 @@
"""Test code for clip operator"""
import numpy as np
import tvm
-from tvm import te
+from tvm import te, tir
from tvm import topi
import tvm.testing
import tvm.topi.testing
@@ -32,12 +32,14 @@ def verify_clip(N, a_min, a_max, dtype):
# use memoize to pickle the test data for next time use
@memoize("topi.tests.test_topi_clip")
- def get_ref_data():
+ def get_ref_data(a_min, a_max):
a_np = np.random.uniform(a_min * 2, a_max * 2, size=(N, N)).astype(dtype)
b_np = np.clip(a_np, a_min, a_max)
return a_np, b_np
- a_np, b_np = get_ref_data()
+ a_min = a_min.value if isinstance(a_min, (tir.FloatImm, tir.IntImm)) else a_min
+ a_max = a_max.value if isinstance(a_max, (tir.FloatImm, tir.IntImm)) else a_max
+ a_np, b_np = get_ref_data(a_min, a_max)
def check_target(target, dev):
print("Running on target: %s" % target)
@@ -61,5 +63,13 @@ def test_clip():
verify_clip(1024, -127, 127, "int8")
+@tvm.testing.uses_gpu
+def test_clip_floaimm_intimm():
+ verify_clip(1024, tir.FloatImm("float32", -127), tir.FloatImm("float32", 127), "float32")
+ verify_clip(1024, tir.IntImm("int32", -127), tir.IntImm("int32", 127), "int16")
+ verify_clip(1024, tir.IntImm("int32", -127), tir.IntImm("int32", 127), "int8")
+
+
if __name__ == "__main__":
test_clip()
+ test_clip_floaimm_intimm()