This is an automated email from the ASF dual-hosted git repository.
hongyij pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new ac645b2348 [Unity] Fix emit_te with symbolic input (#14552)
ac645b2348 is described below
commit ac645b2348a8b4d901e8a2f130fcecd67192de9f
Author: Tianqi Chen <[email protected]>
AuthorDate: Mon Apr 10 02:36:54 2023 -0400
[Unity] Fix emit_te with symbolic input (#14552)
This PR fixes the use of emit_te with
symbolic shape input.
Test cases have been added.
---
python/tvm/relax/block_builder.py | 1 +
python/tvm/relax/utils.py | 22 +++++--
...t_blockbuilder.py => test_blockbuilder_core.py} | 3 +-
tests/python/relax/test_blockbuilder_emit_te.py | 71 ++++++++++++++++++++++
4 files changed, 90 insertions(+), 7 deletions(-)
diff --git a/python/tvm/relax/block_builder.py
b/python/tvm/relax/block_builder.py
index 07955f3d2d..2b9cdc5fe6 100644
--- a/python/tvm/relax/block_builder.py
+++ b/python/tvm/relax/block_builder.py
@@ -312,6 +312,7 @@ class BlockBuilder(Object):
primfunc_name = kwargs.pop("primfunc_name_hint", None)
tir_func, call_args, output_sinfo, tir_vars =
gen_call_tir_inputs(func, *args, **kwargs)
+
if not primfunc_name:
primfunc_name = func.__name__
gvar = self.add_func(tir_func, primfunc_name)
diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py
index e8ff144f95..02d9b1e0d4 100644
--- a/python/tvm/relax/utils.py
+++ b/python/tvm/relax/utils.py
@@ -336,6 +336,9 @@ def gen_call_tir_inputs(
Relax expression
"""
te_args_list = []
+ # extra list of tir expression arguments
+ # that are not covered by Tensor
+ extra_tir_args_list = []
def _copy_undefined_var(expr: tir.PrimExpr):
def _visit_expr(e: tir.PrimExpr):
@@ -376,15 +379,19 @@ def gen_call_tir_inputs(
return {k: _convert_te_arg_helper(arg[k]) for k in arg}
elif isinstance(arg, tir.PrimExpr):
_copy_undefined_var(arg)
- return tir.stmt_functor.substitute(arg, tir_var_map)
+ new_arg = tir.stmt_functor.substitute(arg, tir_var_map)
+ extra_tir_args_list.append(new_arg)
+ return new_arg
elif isinstance(arg, (int, float, str, Type, Attrs)) or arg is
None:
return arg
raise TypeError("not supported type in emit_te:
{}".format(type(arg)))
new_arg = _convert_te_arg_helper(te_args)
- return new_arg, te_args_list
+ return new_arg, te_args_list, extra_tir_args_list
- def _get_unbound_tir_vars(args: List[te_Tensor]) -> List[tir.Var]:
+ def _get_unbound_tir_vars(
+ args: List[te_Tensor], extra_tir_args: List[PrimExpr]
+ ) -> List[tir.Var]:
"""get unbound TIR vars (i.e TIR vars used in the shape but is not
itself a dimension of a shape)"""
bound_vars = set()
@@ -394,6 +401,9 @@ def gen_call_tir_inputs(
if isinstance(expr, tir.Var):
used_vars.add(expr)
+ for val in extra_tir_args:
+ tir.stmt_functor.post_order_visit(val, _populate_used_vars)
+
for x in args:
for s in x.shape:
tir.stmt_functor.post_order_visit(s, _populate_used_vars)
@@ -413,8 +423,8 @@ def gen_call_tir_inputs(
primfunc_attrs = kwargs.pop("primfunc_attrs", None)
tir_var_map: Dict[tir.Var, tir.PrimExpr] = {}
- new_args, te_arg_list = _convert_te_arg(args, tir_var_map)
- new_kwargs, te_kwarg_list = _convert_te_arg(kwargs, tir_var_map)
+ new_args, te_arg_list, tir_arg_list = _convert_te_arg(args, tir_var_map)
+ new_kwargs, te_kwarg_list, tir_kwarg_list = _convert_te_arg(kwargs,
tir_var_map)
te_args = te_arg_list + te_kwarg_list
@@ -424,7 +434,7 @@ def gen_call_tir_inputs(
), "only support te.tensor or tuple/list/Array of te.tensor as function
output"
outs = [te_out] if isinstance(te_out, te_Tensor) else list(te_out)
- unbound_tir_vars = _get_unbound_tir_vars(te_args + outs)
+ unbound_tir_vars = _get_unbound_tir_vars(te_args + outs, tir_arg_list +
tir_kwarg_list)
inputs = [*te_args] + outs
tir_func = create_relax_prim_func(inputs, unbound_tir_vars, "int64")
diff --git a/tests/python/relax/test_blockbuilder.py
b/tests/python/relax/test_blockbuilder_core.py
similarity index 99%
rename from tests/python/relax/test_blockbuilder.py
rename to tests/python/relax/test_blockbuilder_core.py
index 9d9d28d7d6..9932227854 100644
--- a/tests/python/relax/test_blockbuilder.py
+++ b/tests/python/relax/test_blockbuilder_core.py
@@ -14,7 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
+"""Block builder unit test"""
+# The test here do not depend on tvmscript to cover most basic features
import pytest
import tvm
import tvm.testing
diff --git a/tests/python/relax/test_blockbuilder_emit_te.py
b/tests/python/relax/test_blockbuilder_emit_te.py
new file mode 100644
index 0000000000..7a519d1022
--- /dev/null
+++ b/tests/python/relax/test_blockbuilder_emit_te.py
@@ -0,0 +1,71 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+""" This file tests advanced emit_te features with help of TVMScript
assertion"""
+# The tests here depend on tvmscript
+from tvm import te, tir
+from tvm import relax as rx
+from tvm.ir.base import assert_structural_equal
+from tvm.script.parser import ir as I
+from tvm.script.parser import relax as R
+from tvm.script.parser import tir as T
+
+
+def test_emit_te_with_symbolic_arg():
+ bb = rx.BlockBuilder()
+ m = tir.Var("m", "int64")
+ x = rx.Var("x", R.Tensor([10], "float32"))
+ y = rx.Var("y", R.Shape([m]))
+
+ def te_func(A, offset):
+ return te.compute(A.shape, lambda i: A[i + offset], name="B")
+
+ with bb.function("main", [x, y]):
+ out = bb.emit_te(te_func, x, m)
+ bb.emit_func_output(out)
+
+ after = bb.get()
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func
+ def te_func(
+ A: T.Buffer((T.int64(10),), "float32"),
+ B: T.Buffer((T.int64(10),), "float32"),
+ m: T.int64,
+ ):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ for i in range(T.int64(10)):
+ with T.block("B"):
+ v_i = T.axis.spatial(T.int64(10), i)
+ T.writes(B[v_i])
+ B[v_i] = A[v_i + m]
+
+ @R.function
+ def main(
+ x: R.Tensor((10,), dtype="float32"), y: R.Shape(["m"])
+ ) -> R.Tensor((10,), dtype="float32"):
+ m = T.int64()
+ cls = Expected
+ gv = R.call_tir(
+ cls.te_func,
+ (x,),
+ out_sinfo=R.Tensor((10,), dtype="float32"),
+ tir_vars=R.shape([m]),
+ )
+ return gv
+
+ assert_structural_equal(after, Expected)