This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 9879263c73 [Unity][TVMScript] Fix prim_func lost issue in 
relax.emit_te (#14189)
9879263c73 is described below

commit 9879263c7399f806dc31ce24190636b14bfd599a
Author: Yong Wu <[email protected]>
AuthorDate: Wed Mar 8 05:53:21 2023 -0800

    [Unity][TVMScript] Fix prim_func lost issue in relax.emit_te (#14189)
    
    Previously R.emit_te was introduced in
    https://github.com/apache/tvm/pull/14123, but the `prim_func`s were not
    added into the same `ir_module`. This PR fixes the issue by moving some
    `call_tir` input handling code from `bb.call_te` into utils, so that it can
    be leveraged by both `bb.emit_te` and `R.emit_te`.
---
 python/tvm/relax/block_builder.py               | 167 ++--------------------
 python/tvm/relax/utils.py                       | 181 +++++++++++++++++++++++-
 python/tvm/script/ir_builder/relax/ir.py        |  22 +--
 src/script/ir_builder/ir/ir.cc                  |   4 +-
 src/script/ir_builder/ir/utils.h                |  11 ++
 tests/python/relax/test_tvmscript_ir_builder.py |  12 +-
 tests/python/relax/test_tvmscript_parser.py     |  19 +++
 7 files changed, 235 insertions(+), 181 deletions(-)

diff --git a/python/tvm/relax/block_builder.py 
b/python/tvm/relax/block_builder.py
index 3421bd4d09..ebf35b5765 100644
--- a/python/tvm/relax/block_builder.py
+++ b/python/tvm/relax/block_builder.py
@@ -16,7 +16,6 @@
 # under the License.
 # pylint: disable=no-else-return, invalid-name
 """Developer API of constructing Relax AST."""
-import typing
 
 from typing import Dict, List, Optional, Union, Any, Callable
 from tvm.ir.module import IRModule
@@ -25,18 +24,17 @@ from tvm import relax as rx, tir
 import tvm
 from .expr import (
     Expr,
-    te_tensor,
     Var,
-    ShapeExpr,
     GlobalVar,
     BindingBlock,
     Tuple,
     BaseFunc,
     Binding,
 )
-from .struct_info import PrimStructInfo, ShapeStructInfo, StructInfo, 
TensorStructInfo
+from .struct_info import StructInfo
 from .op.base import call_tir
 from . import _ffi_api
+from .utils import gen_call_tir_inputs
 
 
 class FunctionScope(object):
@@ -196,107 +194,6 @@ class BlockBuilder(Object):
             if not is_emit_func_output_called:
                 raise RuntimeError("emit_func_output must be called in a relax 
function.")
 
-    def _convert_te_arg(
-        self, te_args: Any, tir_var_map: Dict[tir.Var, tir.PrimExpr]
-    ) -> typing.Tuple[Any, List[tvm.te.Tensor]]:
-        """Helper function used by `call_te` to convert Relax expressions to 
TE tensor.
-
-        In the common case, the type of te_args is a Relax expression and is 
converted
-        into a TE tensor.
-        If te_args is a nested or recursive datatype (i.e list, dict, 
tvm.ir.Map, tvm.ir.Array),
-        we recursive and convert any value of type Relax expression into a TE 
tensor.
-        Common values of type int, float, and str are preserved.
-
-        In dynamic shape cases, the passed in arguments may contain TIR 
variable.
-        For example, the argument can be a Relax Var with TensorStructInfo, 
which
-        has symbolic shape, or the argument can be a ShapeExpr with symbolic 
variables.
-        To make the PrimFunc generated by `call_te` has independent variables 
with
-        the caller Relax function, we will substitute the TIR variables in the 
input
-        arguments with fresh ones, which is done by maintaining a TIR variable 
mapping.
-
-        Parameters
-        ----------
-        te_args : Any
-            Argument to convert to TE
-
-        tir_var_map : Dict[tir.Var, tir.PrimExpr]
-            The TIR variable mapping, which maps TIR variables on the Relax 
function
-            side to the new set of variables used on the PrimFunc side.
-
-        Returns
-        -------
-        ret : (Any, [tvm.te.Tensor])
-            A tuple of the converted te_args, and a list of te tensors for 
each converted
-            Relax expression
-        """
-        te_args_list = []
-
-        def _copy_undefined_var(expr: tir.PrimExpr):
-            def _visit_expr(e: tir.PrimExpr):
-                if isinstance(e, tir.Var) and e not in tir_var_map:
-                    new_var = tir.Var(e.name, e.dtype)
-                    tir_var_map[e] = new_var
-
-            tir.stmt_functor.post_order_visit(expr, _visit_expr)
-
-        def _convert_te_arg_helper(arg):
-            if isinstance(arg, Expr):  # type: ignore
-                if isinstance(arg.struct_info, TensorStructInfo):
-                    assert isinstance(
-                        arg.struct_info.shape, ShapeExpr
-                    ), "emit_te now only supports Tensor that has ShapeExpr 
shape"
-                    for shape_value in arg.struct_info.shape.values:
-                        _copy_undefined_var(shape_value)
-
-                    arg = te_tensor(arg, tir_var_map)
-                    te_args_list.append(arg)
-                    return arg
-                elif isinstance(arg.struct_info, ShapeStructInfo):
-                    assert isinstance(
-                        arg, ShapeExpr
-                    ), "For Expr having ShapeStructInfo, emit_te now only 
supports ShapeExpr"
-                    return [_convert_te_arg_helper(val) for val in arg.values]
-                elif isinstance(arg.struct_info, PrimStructInfo):
-                    return arg.value
-            elif isinstance(arg, (list, tvm.ir.Array)):
-                return [_convert_te_arg_helper(x) for x in arg]
-            elif isinstance(arg, tuple):
-                return tuple([_convert_te_arg_helper(x) for x in arg])
-            elif isinstance(arg, (dict, tvm.ir.Map)):
-                for key in arg:
-                    assert isinstance(
-                        key, str
-                    ), "emit_te only supports dict with string as the key 
currently"
-                return {k: _convert_te_arg_helper(arg[k]) for k in arg}
-            elif isinstance(arg, tir.PrimExpr):
-                _copy_undefined_var(arg)
-                return tir.stmt_functor.substitute(arg, tir_var_map)
-            elif isinstance(arg, (int, float, str, tvm.ir.Type, tvm.ir.Attrs)) 
or arg is None:
-                return arg
-            raise TypeError("not supported type in emit_te: 
{}".format(type(arg)))
-
-        new_arg = _convert_te_arg_helper(te_args)
-        return new_arg, te_args_list
-
-    def _get_unbound_tir_vars(self, args: List[tvm.te.Tensor]) -> 
List[tvm.tir.Var]:
-        """get unbound TIR vars (i.e TIR vars used in the shape but is not
-        itself a dimension of a shape)"""
-        bound_vars = set()
-        used_vars = set()
-
-        def _populate_used_vars(expr):
-            if isinstance(expr, tvm.tir.Var):
-                used_vars.add(expr)
-
-        for x in args:
-            for s in x.shape:
-                tvm.tir.stmt_functor.post_order_visit(s, _populate_used_vars)
-                if isinstance(s, tir.Var):
-                    bound_vars.add(s)
-
-        diff = used_vars - bound_vars
-        return list(diff)
-
     def function(
         self,
         name: str,
@@ -410,61 +307,13 @@ class BlockBuilder(Object):
             A newly created call node
         """
 
-        primfunc_name_hint = kwargs.pop("primfunc_name_hint", None)
-        tir_var_map: Dict[tir.Var, tir.PrimExpr] = dict()
-        new_args, te_arg_list = self._convert_te_arg(args, tir_var_map)
-        new_kwargs, te_kwarg_list = self._convert_te_arg(kwargs, tir_var_map)
+        primfunc_name = kwargs.pop("primfunc_name_hint", None)
+        tir_func, call_args, output_sinfo, tir_vars = 
gen_call_tir_inputs(func, *args, **kwargs)
+        if not primfunc_name:
+            primfunc_name = func.__name__
+        gvar = self.add_func(tir_func, primfunc_name)
 
-        te_args = te_arg_list + te_kwarg_list
-
-        te_out = func(*new_args, **new_kwargs)
-        assert isinstance(te_out, tvm.te.tensor.Tensor) or (
-            isinstance(te_out, (tuple, list, tvm.ir.Array))
-            and all(isinstance(t, tvm.te.tensor.Tensor) for t in te_out)
-        ), "only support te.tensor or tuple/list/Array of te.tensor as 
function output"
-
-        outs = [te_out] if isinstance(te_out, tvm.te.tensor.Tensor) else 
list(te_out)
-        unbound_tir_vars = self._get_unbound_tir_vars(te_args + outs)
-
-        inputs = [*te_args] + outs
-        tir_func = tvm.te.create_relax_prim_func(inputs, unbound_tir_vars, 
"int64")
-
-        tir_func = tir_func.without_attr("global_symbol")
-
-        if primfunc_name_hint:
-            gvar = self.add_func(tir_func, primfunc_name_hint)
-        else:
-            gvar = self.add_func(tir_func, func.__name__)
-
-        call_args = [x.op.value for x in te_args]
-
-        def _shape_with_old_tir_var(
-            shape_values: List[tir.PrimExpr], tir_var_inverse_map: 
Dict[tir.Var, tir.PrimExpr]
-        ):
-            return ShapeExpr(
-                [tir.stmt_functor.substitute(value, tir_var_inverse_map) for 
value in shape_values]
-            )
-
-        # Invert the TIR variable mapping, to convert the output shape back
-        # with old set of variables.
-        tir_var_inverse_map = {v: k for k, v in tir_var_map.items()}
-
-        output_sinfo = [
-            TensorStructInfo(_shape_with_old_tir_var(out.shape, 
tir_var_inverse_map), out.dtype)
-            for out in outs
-        ]
-
-        # add arguments for extra parameters from unbound var
-        if len(unbound_tir_vars) > 0:
-            call = call_tir(
-                gvar,
-                call_args,
-                output_sinfo,
-                tir_vars=_shape_with_old_tir_var(unbound_tir_vars, 
tir_var_inverse_map),
-            )
-        else:
-            call = call_tir(gvar, call_args, output_sinfo)
-        return call
+        return call_tir(gvar, call_args, output_sinfo, tir_vars)
 
     def emit_te(self, func: Callable, *args: Any, **kwargs: Any) -> Var:
         """Emit a call node according to the te function.
diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py
index d6b405f183..587097d689 100644
--- a/python/tvm/relax/utils.py
+++ b/python/tvm/relax/utils.py
@@ -14,17 +14,22 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+# pylint: disable=invalid-name,too-many-locals
 """Utility functions for Relax"""
 import functools
 import inspect
-from typing import Any, Callable, List, Optional, TypeVar
+from typing import Tuple as typing_Tuple
+from typing import Any, Callable, List, Dict, Optional, TypeVar
 
 from .. import tir
-from ..runtime import String, convert_to_object
 from ..tir import PrimExpr
+from ..runtime import String, convert_to_object
 from . import _ffi_api
-from .expr import Expr, Function, PrimValue, StringImm
 from .expr import Tuple as rx_Tuple
+from .expr import Expr, ShapeExpr, Function, PrimValue, StringImm, te_tensor
+from ..te import Tensor as te_Tensor, create_relax_prim_func
+from ..ir import Array, Attrs, Type, Map
+from .struct_info import PrimStructInfo, ShapeStructInfo, TensorStructInfo
 
 
 def metadata_partitioner(rx_txt: str) -> List[str]:
@@ -268,3 +273,173 @@ def copy_with_new_vars(func: Function) -> Function:
         The copied function.
     """
     return _ffi_api.CopyWithNewVars(func)  # type: ignore
+
+
+def gen_call_tir_inputs(
+    func: Callable, *args: Any, **kwargs: Any
+) -> typing_Tuple[tir.PrimFunc, Expr, List[TensorStructInfo], 
Optional[ShapeExpr]]:
+    """Generate the inputs for call_tir according to the te function.
+    This function converts arguments from relax expression to te tensor,
+    The callback func should return a te tensor or a list of te tensors.
+
+    Parameters
+    ----------
+    func : Callable
+        A function that returns a te tensor or a list of te tensors.
+
+    args : Any, optional
+        arguments passed to the function.
+
+    kwargs : Any, optional
+        The keyword arguments passed to the function.
+
+    Returns
+    -------
+    ret : Tuple[tir.PrimFunc, Expr, List[TensorStructInfo], 
Optional[ShapeExpr]]
+        ret contains the inputs for call_tir, including a tir prim_func, args,
+        out_sinfo, and tir_vars.
+    """
+
+    def _convert_te_arg(
+        te_args: Any, tir_var_map: Dict[tir.Var, tir.PrimExpr]
+    ) -> typing_Tuple[Any, List[te_Tensor]]:
+        """Helper function used to convert Relax expressions to TE tensor.
+
+        In the common case, the type of te_args is a Relax expression and is 
converted
+        into a TE tensor.
+        If te_args is a nested or recursive datatype (i.e list, dict, 
tvm.ir.Map, tvm.ir.Array),
+        we recursive and convert any value of type Relax expression into a TE 
tensor.
+        Common values of type int, float, and str are preserved.
+
+        In dynamic shape cases, the passed in arguments may contain TIR 
variable.
+        For example, the argument can be a Relax Var with TensorStructInfo, 
which
+        has symbolic shape, or the argument can be a ShapeExpr with symbolic 
variables.
+        To make the PrimFunc generated has independent variables with
+        the caller Relax function, we will substitute the TIR variables in the 
input
+        arguments with fresh ones, which is done by maintaining a TIR variable 
mapping.
+
+        Parameters
+        ----------
+        te_args : Any
+            Argument to convert to TE
+
+        tir_var_map : Dict[tir.Var, tir.PrimExpr]
+            The TIR variable mapping, which maps TIR variables on the Relax 
function
+            side to the new set of variables used on the PrimFunc side.
+
+        Returns
+        -------
+        ret : (Any, [tvm.te.Tensor])
+            A tuple of the converted te_args, and a list of te tensors for 
each converted
+            Relax expression
+        """
+        te_args_list = []
+
+        def _copy_undefined_var(expr: tir.PrimExpr):
+            def _visit_expr(e: tir.PrimExpr):
+                if isinstance(e, tir.Var) and e not in tir_var_map:
+                    new_var = tir.Var(e.name, e.dtype)
+                    tir_var_map[e] = new_var
+
+            tir.stmt_functor.post_order_visit(expr, _visit_expr)
+
+        def _convert_te_arg_helper(arg):
+            if isinstance(arg, Expr):  # type: ignore
+                if isinstance(arg.struct_info, TensorStructInfo):
+                    assert isinstance(
+                        arg.struct_info.shape, ShapeExpr
+                    ), "emit_te now only supports Tensor that has ShapeExpr 
shape"
+                    for shape_value in arg.struct_info.shape.values:
+                        _copy_undefined_var(shape_value)
+
+                    arg = te_tensor(arg, tir_var_map)
+                    te_args_list.append(arg)
+                    return arg
+                if isinstance(arg.struct_info, ShapeStructInfo):
+                    assert isinstance(
+                        arg, ShapeExpr
+                    ), "For Expr having ShapeStructInfo, emit_te now only 
supports ShapeExpr"
+                    return [_convert_te_arg_helper(val) for val in arg.values]
+                if isinstance(arg.struct_info, PrimStructInfo):
+                    return arg.value
+            elif isinstance(arg, (list, Array)):
+                return [_convert_te_arg_helper(x) for x in arg]
+            elif isinstance(arg, tuple):
+                return tuple(_convert_te_arg_helper(x) for x in arg)
+            elif isinstance(arg, (dict, Map)):
+                for key in arg:
+                    assert isinstance(
+                        key, str
+                    ), "emit_te only supports dict with string as the key 
currently"
+                return {k: _convert_te_arg_helper(arg[k]) for k in arg}
+            elif isinstance(arg, tir.PrimExpr):
+                _copy_undefined_var(arg)
+                return tir.stmt_functor.substitute(arg, tir_var_map)
+            elif isinstance(arg, (int, float, str, Type, Attrs)) or arg is 
None:
+                return arg
+            raise TypeError("not supported type in emit_te: 
{}".format(type(arg)))
+
+        new_arg = _convert_te_arg_helper(te_args)
+        return new_arg, te_args_list
+
+    def _get_unbound_tir_vars(args: List[te_Tensor]) -> List[tir.Var]:
+        """get unbound TIR vars (i.e TIR vars used in the shape but is not
+        itself a dimension of a shape)"""
+        bound_vars = set()
+        used_vars = set()
+
+        def _populate_used_vars(expr):
+            if isinstance(expr, tir.Var):
+                used_vars.add(expr)
+
+        for x in args:
+            for s in x.shape:
+                tir.stmt_functor.post_order_visit(s, _populate_used_vars)
+                if isinstance(s, tir.Var):
+                    bound_vars.add(s)
+
+        diff = used_vars - bound_vars
+        return list(diff)
+
+    def _shape_with_old_tir_var(
+        shape_values: List[tir.PrimExpr], tir_var_inverse_map: Dict[tir.Var, 
tir.PrimExpr]
+    ):
+        return ShapeExpr(
+            [tir.stmt_functor.substitute(value, tir_var_inverse_map) for value 
in shape_values]
+        )
+
+    tir_var_map: Dict[tir.Var, tir.PrimExpr] = {}
+    new_args, te_arg_list = _convert_te_arg(args, tir_var_map)
+    new_kwargs, te_kwarg_list = _convert_te_arg(kwargs, tir_var_map)
+
+    te_args = te_arg_list + te_kwarg_list
+
+    te_out = func(*new_args, **new_kwargs)
+    assert isinstance(te_out, te_Tensor) or (
+        isinstance(te_out, (tuple, list, Array)) and all(isinstance(t, 
te_Tensor) for t in te_out)
+    ), "only support te.tensor or tuple/list/Array of te.tensor as function 
output"
+
+    outs = [te_out] if isinstance(te_out, te_Tensor) else list(te_out)
+    unbound_tir_vars = _get_unbound_tir_vars(te_args + outs)
+
+    inputs = [*te_args] + outs
+    tir_func = create_relax_prim_func(inputs, unbound_tir_vars, "int64")
+
+    tir_func = tir_func.without_attr("global_symbol")
+
+    call_tir_args = [x.op.value for x in te_args]
+
+    # Invert the TIR variable mapping, to convert the output shape back
+    # with old set of variables.
+    tir_var_inverse_map = {v: k for k, v in tir_var_map.items()}
+
+    output_sinfo = [
+        TensorStructInfo(_shape_with_old_tir_var(out.shape, 
tir_var_inverse_map), out.dtype)
+        for out in outs
+    ]
+
+    tir_vars = None
+    if len(unbound_tir_vars) > 0:
+        tir_vars = _shape_with_old_tir_var(unbound_tir_vars, 
tir_var_inverse_map)
+
+    return (tir_func, call_tir_args, output_sinfo, tir_vars)
diff --git a/python/tvm/script/ir_builder/relax/ir.py 
b/python/tvm/script/ir_builder/relax/ir.py
index 03f1c1db46..9ef403181b 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -25,8 +25,10 @@ from typing import Any, Dict, List, Optional, Tuple, Union, 
Callable
 import tvm
 from tvm import DataType, relax
 from tvm.ir import PrimExpr
+from ..ir import decl_function
 from tvm.relax import Call, Expr, ExternFunc, TupleGetItem, ShapeExpr, Var, 
VarBinding, const
-from tvm.relax.block_builder import BlockBuilder as rx_bb
+from tvm.relax.utils import gen_call_tir_inputs
+
 
 ############################### Operators ###############################
 from tvm.relax.op import (
@@ -310,7 +312,7 @@ invoke_closure = _sinfo_arg_wrapper(invoke_closure)  # 
pylint: disable=invalid-n
 call_builtin_with_ctx = _sinfo_arg_wrapper(call_builtin_with_ctx)  # pylint: 
disable=invalid-name
 
 
-############################### Bindings ###############################
+############################### Emits ###############################
 
 
 def emit(value: Expr, annotate_struct_info: Optional[StructInfo] = None) -> 
Var:
@@ -331,7 +333,7 @@ def emit(value: Expr, annotate_struct_info: 
Optional[StructInfo] = None) -> Var:
     return _ffi_api.Emit(value, annotate_struct_info)  # type: 
ignore[attr-defined] # pylint: disable=no-member
 
 
-def emit_te(func: Callable, *args: Any, **kwargs: Any) -> Var:
+def emit_te(func: Callable, *args: Any, **kwargs: Any) -> Call:
     """Emit a call node according to the te function.
     This function converts arguments from relax expression to te tensor,
     The callback func should return a te tensor or a list of te tensors.
@@ -351,13 +353,15 @@ def emit_te(func: Callable, *args: Any, **kwargs: Any) -> 
Var:
 
     Returns
     -------
-    var : Var
-        A newly created variable that gets bound to the call code.
+    call : Call
+        A newly created call that calls into a tir function.
     """
-
-    # Levarage the util function call_te in Relax Block Blocker
-    emit_expr = rx_bb().call_te(func, *args, **kwargs)
-    return emit(emit_expr)
+    primfunc_name_hint = kwargs.pop("primfunc_name_hint", None)
+    tir_func, call_args, out_sinfo, tir_vars = gen_call_tir_inputs(func, 
*args, **kwargs)
+    if not primfunc_name_hint:
+        primfunc_name_hint = func.__name__
+    gvar = decl_function(primfunc_name_hint, tir_func)  # type: ignore
+    return call_tir(gvar, call_args, out_sinfo, tir_vars)
 
 
 def emit_match_cast(value: Expr, struct_info: StructInfo) -> Var:
diff --git a/src/script/ir_builder/ir/ir.cc b/src/script/ir_builder/ir/ir.cc
index 148e90b28c..6521df265e 100644
--- a/src/script/ir_builder/ir/ir.cc
+++ b/src/script/ir_builder/ir/ir.cc
@@ -36,7 +36,7 @@ IRModuleFrame IRModule() {
 }
 
 GlobalVar DeclFunction(const String& func_name, const BaseFunc& 
func_signature) {
-  IRModuleFrame frame = FindModuleFrame("I.DeclFunction");
+  IRModuleFrame frame = FindModuleFrame();
   CHECK(!frame->global_var_map.count(func_name))
       << "ValueError: function " << func_name << " already exists";
   GlobalVar gv = GlobalVar(func_name);
@@ -58,7 +58,7 @@ GlobalVar DeclFunction(const String& func_name, const 
BaseFunc& func_signature)
 }
 
 void DefFunction(const String& func_name, const BaseFunc& func) {
-  IRModuleFrame frame = FindModuleFrame("I.DefFunction");
+  IRModuleFrame frame = FindModuleFrame();
   auto it = frame->global_var_map.find(func_name);
   CHECK(it != frame->global_var_map.end())
       << "ValueError: function " << func_name << " does not exist, please 
declare it first.";
diff --git a/src/script/ir_builder/ir/utils.h b/src/script/ir_builder/ir/utils.h
index 58d5e53f70..b12e5e270d 100644
--- a/src/script/ir_builder/ir/utils.h
+++ b/src/script/ir_builder/ir/utils.h
@@ -41,6 +41,17 @@ inline IRModuleFrame FindModuleFrame(const String& method) {
   throw;
 }
 
+inline IRModuleFrame FindModuleFrame() {
+  IRBuilder builder = IRBuilder::Current();
+  if (Optional<IRModuleFrame> frame = builder->FindFrame<IRModuleFrame>()) {
+    return frame.value();
+  } else {
+    LOG(FATAL) << "ValueError: IRModule frame not find. Please ensure it"
+               << " is called under I.ir_module()";
+  }
+  throw;
+}
+
 }  // namespace ir
 }  // namespace ir_builder
 }  // namespace script
diff --git a/tests/python/relax/test_tvmscript_ir_builder.py 
b/tests/python/relax/test_tvmscript_ir_builder.py
index 014b00af00..f7c29b8dbe 100644
--- a/tests/python/relax/test_tvmscript_ir_builder.py
+++ b/tests/python/relax/test_tvmscript_ir_builder.py
@@ -58,7 +58,7 @@ def test_function_simple():
 
 
 def test_emits():
-    """Tests for R.emit, R.emit_match_cast, R.emit_var_binding, R.emit_te
+    """Tests for R.emit, R.emit_match_cast, R.emit_var_binding
 
     @R.function
     def foo(x: R.Tensor(dtype="float32"), y: R.Tensor(dtype="float32")) -> 
R.Shape(ndim=2):
@@ -67,8 +67,6 @@ def test_emits():
         gv: R.Tensor((m,), dtype="float32") = R.match_cast(x, R.Tensor((m,), 
dtype="float32"))
         gv1: R.Tensor((n,), dtype="float32") = R.match_cast(y, R.Tensor((n,), 
dtype="float32"))
         v: R.Tensor((n,), dtype="float32") = gv1
-        gv2 = R.call_tir(add, (v, v), out_sinfo=R.Tensor((n,), 
dtype="float32"))
-        gv3: R.Tensor((n,), dtype="float32") = gv2
         return R.shape([m, n * 2])
     """
     # create with Script IRBuilder
@@ -84,8 +82,7 @@ def test_emits():
             v = relax.Var("v", relax.TensorStructInfo((n,), "float32"))
             vb = relax.VarBinding(v, y1)
             v = R.emit_var_binding(vb)
-            v1 = R.emit_te(topi.add, v, v)
-            R.emit(v1)
+            R.emit(v)
 
             IRBuilder.name("v", v)
             R.func_ret_value(relax.ShapeExpr([m, n * 2]))
@@ -102,12 +99,11 @@ def test_emits():
         _ = bb.match_cast(x, relax.TensorStructInfo((m,), "float32"))
         y1 = bb.match_cast(y, relax.TensorStructInfo((n,), "float32"))
         bb.emit_normalized(relax.VarBinding(v, y1))
-        v1 = bb.emit_te(topi.add, v, v)
-        bb.emit(v1)
+        bb.emit(v)
         bb.emit_func_output(relax.ShapeExpr([m, n * 2]))
     mod = bb.get()
 
-    tvm.ir.assert_structural_equal(func, mod["foo"], map_free_vars=True)
+    tvm.ir.assert_structural_equal(func, mod["foo"])
 
 
 def test_dataflow_block():
diff --git a/tests/python/relax/test_tvmscript_parser.py 
b/tests/python/relax/test_tvmscript_parser.py
index 65d0115915..c57460902f 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -184,6 +184,25 @@ def test_simple_module():
     _check(TestModule, bb.get())
 
 
+def test_emit_te():
+    @I.ir_module
+    class EmitTE:
+        @R.function
+        def main(x: R.Tensor((10, 20), "float32")) -> R.Tensor((10, 20), 
dtype="float32"):
+            lv1 = R.emit_te(topi.add, x, x)
+            out = R.emit_te(topi.multiply, lv1, lv1)
+            return out
+
+    bb = relax.BlockBuilder()
+    x = relax.Var("x", relax.TensorStructInfo([10, 20], "float32"))
+    with bb.function("main", [x]):
+        lv1 = bb.emit_te(topi.add, x, x)
+        out = bb.emit_te(topi.multiply, lv1, lv1)
+        bb.emit_func_output(out)
+
+    _check(EmitTE, bb.get())
+
+
 def test_module_with_attr_and_global_info():
     @I.ir_module
     class TestModule:

Reply via email to