This is an automated email from the ASF dual-hosted git repository.

hongyij pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new ac645b2348 [Unity] Fix emit_te with symbolic input (#14552)
ac645b2348 is described below

commit ac645b2348a8b4d901e8a2f130fcecd67192de9f
Author: Tianqi Chen <[email protected]>
AuthorDate: Mon Apr 10 02:36:54 2023 -0400

    [Unity] Fix emit_te with symbolic input (#14552)
    
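    This PR fixes the use of emit_te with symbolic shape input:
    TIR variables passed directly as (non-tensor) arguments are
    now collected as unbound TIR vars of the generated PrimFunc.
    
    Test cases are added.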
---
 python/tvm/relax/block_builder.py                  |  1 +
 python/tvm/relax/utils.py                          | 22 +++++--
 ...t_blockbuilder.py => test_blockbuilder_core.py} |  3 +-
 tests/python/relax/test_blockbuilder_emit_te.py    | 71 ++++++++++++++++++++++
 4 files changed, 90 insertions(+), 7 deletions(-)

diff --git a/python/tvm/relax/block_builder.py b/python/tvm/relax/block_builder.py
index 07955f3d2d..2b9cdc5fe6 100644
--- a/python/tvm/relax/block_builder.py
+++ b/python/tvm/relax/block_builder.py
@@ -312,6 +312,7 @@ class BlockBuilder(Object):
 
         primfunc_name = kwargs.pop("primfunc_name_hint", None)
         tir_func, call_args, output_sinfo, tir_vars = gen_call_tir_inputs(func, *args, **kwargs)
+
         if not primfunc_name:
             primfunc_name = func.__name__
         gvar = self.add_func(tir_func, primfunc_name)
diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py
index e8ff144f95..02d9b1e0d4 100644
--- a/python/tvm/relax/utils.py
+++ b/python/tvm/relax/utils.py
@@ -336,6 +336,9 @@ def gen_call_tir_inputs(
             Relax expression
         """
         te_args_list = []
+        # extra list of tir expression arguments
+        # that are not covered by Tensor
+        extra_tir_args_list = []
 
         def _copy_undefined_var(expr: tir.PrimExpr):
             def _visit_expr(e: tir.PrimExpr):
@@ -376,15 +379,19 @@ def gen_call_tir_inputs(
                 return {k: _convert_te_arg_helper(arg[k]) for k in arg}
             elif isinstance(arg, tir.PrimExpr):
                 _copy_undefined_var(arg)
-                return tir.stmt_functor.substitute(arg, tir_var_map)
+                new_arg = tir.stmt_functor.substitute(arg, tir_var_map)
+                extra_tir_args_list.append(new_arg)
+                return new_arg
             elif isinstance(arg, (int, float, str, Type, Attrs)) or arg is None:
                 return arg
             raise TypeError("not supported type in emit_te: {}".format(type(arg)))
 
         new_arg = _convert_te_arg_helper(te_args)
-        return new_arg, te_args_list
+        return new_arg, te_args_list, extra_tir_args_list
 
-    def _get_unbound_tir_vars(args: List[te_Tensor]) -> List[tir.Var]:
+    def _get_unbound_tir_vars(
+        args: List[te_Tensor], extra_tir_args: List[PrimExpr]
+    ) -> List[tir.Var]:
         """get unbound TIR vars (i.e TIR vars used in the shape but is not
         itself a dimension of a shape)"""
         bound_vars = set()
@@ -394,6 +401,9 @@ def gen_call_tir_inputs(
             if isinstance(expr, tir.Var):
                 used_vars.add(expr)
 
+        for val in extra_tir_args:
+            tir.stmt_functor.post_order_visit(val, _populate_used_vars)
+
         for x in args:
             for s in x.shape:
                 tir.stmt_functor.post_order_visit(s, _populate_used_vars)
@@ -413,8 +423,8 @@ def gen_call_tir_inputs(
     primfunc_attrs = kwargs.pop("primfunc_attrs", None)
 
     tir_var_map: Dict[tir.Var, tir.PrimExpr] = {}
-    new_args, te_arg_list = _convert_te_arg(args, tir_var_map)
-    new_kwargs, te_kwarg_list = _convert_te_arg(kwargs, tir_var_map)
+    new_args, te_arg_list, tir_arg_list = _convert_te_arg(args, tir_var_map)
+    new_kwargs, te_kwarg_list, tir_kwarg_list = _convert_te_arg(kwargs, tir_var_map)
 
     te_args = te_arg_list + te_kwarg_list
 
@@ -424,7 +434,7 @@ def gen_call_tir_inputs(
     ), "only support te.tensor or tuple/list/Array of te.tensor as function output"
 
     outs = [te_out] if isinstance(te_out, te_Tensor) else list(te_out)
-    unbound_tir_vars = _get_unbound_tir_vars(te_args + outs)
+    unbound_tir_vars = _get_unbound_tir_vars(te_args + outs, tir_arg_list + tir_kwarg_list)
 
     inputs = [*te_args] + outs
     tir_func = create_relax_prim_func(inputs, unbound_tir_vars, "int64")
diff --git a/tests/python/relax/test_blockbuilder.py b/tests/python/relax/test_blockbuilder_core.py
similarity index 99%
rename from tests/python/relax/test_blockbuilder.py
rename to tests/python/relax/test_blockbuilder_core.py
index 9d9d28d7d6..9932227854 100644
--- a/tests/python/relax/test_blockbuilder.py
+++ b/tests/python/relax/test_blockbuilder_core.py
@@ -14,7 +14,8 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
+"""Block builder unit test"""
+# The tests here do not depend on tvmscript to cover most basic features
 import pytest
 import tvm
 import tvm.testing
diff --git a/tests/python/relax/test_blockbuilder_emit_te.py b/tests/python/relax/test_blockbuilder_emit_te.py
new file mode 100644
index 0000000000..7a519d1022
--- /dev/null
+++ b/tests/python/relax/test_blockbuilder_emit_te.py
@@ -0,0 +1,71 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+""" This file tests advanced emit_te features with help of TVMScript assertion"""
+# The tests here depend on tvmscript
+from tvm import te, tir
+from tvm import relax as rx
+from tvm.ir.base import assert_structural_equal
+from tvm.script.parser import ir as I
+from tvm.script.parser import relax as R
+from tvm.script.parser import tir as T
+
+
+def test_emit_te_with_symbolic_arg():
+    bb = rx.BlockBuilder()
+    m = tir.Var("m", "int64")
+    x = rx.Var("x", R.Tensor([10], "float32"))
+    y = rx.Var("y", R.Shape([m]))
+
+    def te_func(A, offset):
+        return te.compute(A.shape, lambda i: A[i + offset], name="B")
+
+    with bb.function("main", [x, y]):
+        out = bb.emit_te(te_func, x, m)
+        bb.emit_func_output(out)
+
+    after = bb.get()
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def te_func(
+            A: T.Buffer((T.int64(10),), "float32"),
+            B: T.Buffer((T.int64(10),), "float32"),
+            m: T.int64,
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            for i in range(T.int64(10)):
+                with T.block("B"):
+                    v_i = T.axis.spatial(T.int64(10), i)
+                    T.writes(B[v_i])
+                    B[v_i] = A[v_i + m]
+
+        @R.function
+        def main(
+            x: R.Tensor((10,), dtype="float32"), y: R.Shape(["m"])
+        ) -> R.Tensor((10,), dtype="float32"):
+            m = T.int64()
+            cls = Expected
+            gv = R.call_tir(
+                cls.te_func,
+                (x,),
+                out_sinfo=R.Tensor((10,), dtype="float32"),
+                tir_vars=R.shape([m]),
+            )
+            return gv
+
+    assert_structural_equal(after, Expected)

Reply via email to