This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 9879263c73 [Unity][TVMScript] Fix prim_func lost issue in
relax.emit_te (#14189)
9879263c73 is described below
commit 9879263c7399f806dc31ce24190636b14bfd599a
Author: Yong Wu <[email protected]>
AuthorDate: Wed Mar 8 05:53:21 2023 -0800
[Unity][TVMScript] Fix prim_func lost issue in relax.emit_te (#14189)
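R.emit_te was introduced in https://github.com/apache/tvm/pull/14123, but
the `prim_func`s it generated were not added to the enclosing `ir_module`.
This PR fixes that issue by moving the `call_tir` input-handling code from
`bb.call_te` into a shared utility, `gen_call_tir_inputs` in
`python/tvm/relax/utils.py`, so that it can be used by both `bb.emit_te`
and `R.emit_te`. `R.emit_te` now declares the generated PrimFunc in the
current IRModule frame and returns the resulting `call_tir` Call (instead
of a Var).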
---
python/tvm/relax/block_builder.py | 167 ++--------------------
python/tvm/relax/utils.py | 181 +++++++++++++++++++++++-
python/tvm/script/ir_builder/relax/ir.py | 22 +--
src/script/ir_builder/ir/ir.cc | 4 +-
src/script/ir_builder/ir/utils.h | 11 ++
tests/python/relax/test_tvmscript_ir_builder.py | 12 +-
tests/python/relax/test_tvmscript_parser.py | 19 +++
7 files changed, 235 insertions(+), 181 deletions(-)
diff --git a/python/tvm/relax/block_builder.py
b/python/tvm/relax/block_builder.py
index 3421bd4d09..ebf35b5765 100644
--- a/python/tvm/relax/block_builder.py
+++ b/python/tvm/relax/block_builder.py
@@ -16,7 +16,6 @@
# under the License.
# pylint: disable=no-else-return, invalid-name
"""Developer API of constructing Relax AST."""
-import typing
from typing import Dict, List, Optional, Union, Any, Callable
from tvm.ir.module import IRModule
@@ -25,18 +24,17 @@ from tvm import relax as rx, tir
import tvm
from .expr import (
Expr,
- te_tensor,
Var,
- ShapeExpr,
GlobalVar,
BindingBlock,
Tuple,
BaseFunc,
Binding,
)
-from .struct_info import PrimStructInfo, ShapeStructInfo, StructInfo,
TensorStructInfo
+from .struct_info import StructInfo
from .op.base import call_tir
from . import _ffi_api
+from .utils import gen_call_tir_inputs
class FunctionScope(object):
@@ -196,107 +194,6 @@ class BlockBuilder(Object):
if not is_emit_func_output_called:
raise RuntimeError("emit_func_output must be called in a relax
function.")
- def _convert_te_arg(
- self, te_args: Any, tir_var_map: Dict[tir.Var, tir.PrimExpr]
- ) -> typing.Tuple[Any, List[tvm.te.Tensor]]:
- """Helper function used by `call_te` to convert Relax expressions to
TE tensor.
-
- In the common case, the type of te_args is a Relax expression and is
converted
- into a TE tensor.
- If te_args is a nested or recursive datatype (i.e list, dict,
tvm.ir.Map, tvm.ir.Array),
- we recursive and convert any value of type Relax expression into a TE
tensor.
- Common values of type int, float, and str are preserved.
-
- In dynamic shape cases, the passed in arguments may contain TIR
variable.
- For example, the argument can be a Relax Var with TensorStructInfo,
which
- has symbolic shape, or the argument can be a ShapeExpr with symbolic
variables.
- To make the PrimFunc generated by `call_te` has independent variables
with
- the caller Relax function, we will substitute the TIR variables in the
input
- arguments with fresh ones, which is done by maintaining a TIR variable
mapping.
-
- Parameters
- ----------
- te_args : Any
- Argument to convert to TE
-
- tir_var_map : Dict[tir.Var, tir.PrimExpr]
- The TIR variable mapping, which maps TIR variables on the Relax
function
- side to the new set of variables used on the PrimFunc side.
-
- Returns
- -------
- ret : (Any, [tvm.te.Tensor])
- A tuple of the converted te_args, and a list of te tensors for
each converted
- Relax expression
- """
- te_args_list = []
-
- def _copy_undefined_var(expr: tir.PrimExpr):
- def _visit_expr(e: tir.PrimExpr):
- if isinstance(e, tir.Var) and e not in tir_var_map:
- new_var = tir.Var(e.name, e.dtype)
- tir_var_map[e] = new_var
-
- tir.stmt_functor.post_order_visit(expr, _visit_expr)
-
- def _convert_te_arg_helper(arg):
- if isinstance(arg, Expr): # type: ignore
- if isinstance(arg.struct_info, TensorStructInfo):
- assert isinstance(
- arg.struct_info.shape, ShapeExpr
- ), "emit_te now only supports Tensor that has ShapeExpr
shape"
- for shape_value in arg.struct_info.shape.values:
- _copy_undefined_var(shape_value)
-
- arg = te_tensor(arg, tir_var_map)
- te_args_list.append(arg)
- return arg
- elif isinstance(arg.struct_info, ShapeStructInfo):
- assert isinstance(
- arg, ShapeExpr
- ), "For Expr having ShapeStructInfo, emit_te now only
supports ShapeExpr"
- return [_convert_te_arg_helper(val) for val in arg.values]
- elif isinstance(arg.struct_info, PrimStructInfo):
- return arg.value
- elif isinstance(arg, (list, tvm.ir.Array)):
- return [_convert_te_arg_helper(x) for x in arg]
- elif isinstance(arg, tuple):
- return tuple([_convert_te_arg_helper(x) for x in arg])
- elif isinstance(arg, (dict, tvm.ir.Map)):
- for key in arg:
- assert isinstance(
- key, str
- ), "emit_te only supports dict with string as the key
currently"
- return {k: _convert_te_arg_helper(arg[k]) for k in arg}
- elif isinstance(arg, tir.PrimExpr):
- _copy_undefined_var(arg)
- return tir.stmt_functor.substitute(arg, tir_var_map)
- elif isinstance(arg, (int, float, str, tvm.ir.Type, tvm.ir.Attrs))
or arg is None:
- return arg
- raise TypeError("not supported type in emit_te:
{}".format(type(arg)))
-
- new_arg = _convert_te_arg_helper(te_args)
- return new_arg, te_args_list
-
- def _get_unbound_tir_vars(self, args: List[tvm.te.Tensor]) ->
List[tvm.tir.Var]:
- """get unbound TIR vars (i.e TIR vars used in the shape but is not
- itself a dimension of a shape)"""
- bound_vars = set()
- used_vars = set()
-
- def _populate_used_vars(expr):
- if isinstance(expr, tvm.tir.Var):
- used_vars.add(expr)
-
- for x in args:
- for s in x.shape:
- tvm.tir.stmt_functor.post_order_visit(s, _populate_used_vars)
- if isinstance(s, tir.Var):
- bound_vars.add(s)
-
- diff = used_vars - bound_vars
- return list(diff)
-
def function(
self,
name: str,
@@ -410,61 +307,13 @@ class BlockBuilder(Object):
A newly created call node
"""
- primfunc_name_hint = kwargs.pop("primfunc_name_hint", None)
- tir_var_map: Dict[tir.Var, tir.PrimExpr] = dict()
- new_args, te_arg_list = self._convert_te_arg(args, tir_var_map)
- new_kwargs, te_kwarg_list = self._convert_te_arg(kwargs, tir_var_map)
+ primfunc_name = kwargs.pop("primfunc_name_hint", None)
+ tir_func, call_args, output_sinfo, tir_vars =
gen_call_tir_inputs(func, *args, **kwargs)
+ if not primfunc_name:
+ primfunc_name = func.__name__
+ gvar = self.add_func(tir_func, primfunc_name)
- te_args = te_arg_list + te_kwarg_list
-
- te_out = func(*new_args, **new_kwargs)
- assert isinstance(te_out, tvm.te.tensor.Tensor) or (
- isinstance(te_out, (tuple, list, tvm.ir.Array))
- and all(isinstance(t, tvm.te.tensor.Tensor) for t in te_out)
- ), "only support te.tensor or tuple/list/Array of te.tensor as
function output"
-
- outs = [te_out] if isinstance(te_out, tvm.te.tensor.Tensor) else
list(te_out)
- unbound_tir_vars = self._get_unbound_tir_vars(te_args + outs)
-
- inputs = [*te_args] + outs
- tir_func = tvm.te.create_relax_prim_func(inputs, unbound_tir_vars,
"int64")
-
- tir_func = tir_func.without_attr("global_symbol")
-
- if primfunc_name_hint:
- gvar = self.add_func(tir_func, primfunc_name_hint)
- else:
- gvar = self.add_func(tir_func, func.__name__)
-
- call_args = [x.op.value for x in te_args]
-
- def _shape_with_old_tir_var(
- shape_values: List[tir.PrimExpr], tir_var_inverse_map:
Dict[tir.Var, tir.PrimExpr]
- ):
- return ShapeExpr(
- [tir.stmt_functor.substitute(value, tir_var_inverse_map) for
value in shape_values]
- )
-
- # Invert the TIR variable mapping, to convert the output shape back
- # with old set of variables.
- tir_var_inverse_map = {v: k for k, v in tir_var_map.items()}
-
- output_sinfo = [
- TensorStructInfo(_shape_with_old_tir_var(out.shape,
tir_var_inverse_map), out.dtype)
- for out in outs
- ]
-
- # add arguments for extra parameters from unbound var
- if len(unbound_tir_vars) > 0:
- call = call_tir(
- gvar,
- call_args,
- output_sinfo,
- tir_vars=_shape_with_old_tir_var(unbound_tir_vars,
tir_var_inverse_map),
- )
- else:
- call = call_tir(gvar, call_args, output_sinfo)
- return call
+ return call_tir(gvar, call_args, output_sinfo, tir_vars)
def emit_te(self, func: Callable, *args: Any, **kwargs: Any) -> Var:
"""Emit a call node according to the te function.
diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py
index d6b405f183..587097d689 100644
--- a/python/tvm/relax/utils.py
+++ b/python/tvm/relax/utils.py
@@ -14,17 +14,22 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+# pylint: disable=invalid-name,too-many-locals
"""Utility functions for Relax"""
import functools
import inspect
-from typing import Any, Callable, List, Optional, TypeVar
+from typing import Tuple as typing_Tuple
+from typing import Any, Callable, List, Dict, Optional, TypeVar
from .. import tir
-from ..runtime import String, convert_to_object
from ..tir import PrimExpr
+from ..runtime import String, convert_to_object
from . import _ffi_api
-from .expr import Expr, Function, PrimValue, StringImm
from .expr import Tuple as rx_Tuple
+from .expr import Expr, ShapeExpr, Function, PrimValue, StringImm, te_tensor
+from ..te import Tensor as te_Tensor, create_relax_prim_func
+from ..ir import Array, Attrs, Type, Map
+from .struct_info import PrimStructInfo, ShapeStructInfo, TensorStructInfo
def metadata_partitioner(rx_txt: str) -> List[str]:
@@ -268,3 +273,173 @@ def copy_with_new_vars(func: Function) -> Function:
The copied function.
"""
return _ffi_api.CopyWithNewVars(func) # type: ignore
+
+
+def gen_call_tir_inputs(
+ func: Callable, *args: Any, **kwargs: Any
+) -> typing_Tuple[tir.PrimFunc, Expr, List[TensorStructInfo],
Optional[ShapeExpr]]:
+ """Generate the inputs for call_tir according to the te function.
+ This function converts arguments from relax expression to te tensor,
+ The callback func should return a te tensor or a list of te tensors.
+
+ Parameters
+ ----------
+ func : Callable
+ A function that returns a te tensor or a list of te tensors.
+
+ args : Any, optional
+ arguments passed to the function.
+
+ kwargs : Any, optional
+ The keyword arguments passed to the function.
+
+ Returns
+ -------
+ ret : Tuple[tir.PrimFunc, Expr, List[TensorStructInfo],
Optional[ShapeExpr]]
+ ret contains the inputs for call_tir, including a tir prim_func, args,
+ out_sinfo, and tir_vars.
+ """
+
+ def _convert_te_arg(
+ te_args: Any, tir_var_map: Dict[tir.Var, tir.PrimExpr]
+ ) -> typing_Tuple[Any, List[te_Tensor]]:
+ """Helper function used to convert Relax expressions to TE tensor.
+
+ In the common case, the type of te_args is a Relax expression and is
converted
+ into a TE tensor.
+ If te_args is a nested or recursive datatype (i.e list, dict,
tvm.ir.Map, tvm.ir.Array),
+ we recursive and convert any value of type Relax expression into a TE
tensor.
+ Common values of type int, float, and str are preserved.
+
+ In dynamic shape cases, the passed in arguments may contain TIR
variable.
+ For example, the argument can be a Relax Var with TensorStructInfo,
which
+ has symbolic shape, or the argument can be a ShapeExpr with symbolic
variables.
+ To make the PrimFunc generated has independent variables with
+ the caller Relax function, we will substitute the TIR variables in the
input
+ arguments with fresh ones, which is done by maintaining a TIR variable
mapping.
+
+ Parameters
+ ----------
+ te_args : Any
+ Argument to convert to TE
+
+ tir_var_map : Dict[tir.Var, tir.PrimExpr]
+ The TIR variable mapping, which maps TIR variables on the Relax
function
+ side to the new set of variables used on the PrimFunc side.
+
+ Returns
+ -------
+ ret : (Any, [tvm.te.Tensor])
+ A tuple of the converted te_args, and a list of te tensors for
each converted
+ Relax expression
+ """
+ te_args_list = []
+
+ def _copy_undefined_var(expr: tir.PrimExpr):
+ def _visit_expr(e: tir.PrimExpr):
+ if isinstance(e, tir.Var) and e not in tir_var_map:
+ new_var = tir.Var(e.name, e.dtype)
+ tir_var_map[e] = new_var
+
+ tir.stmt_functor.post_order_visit(expr, _visit_expr)
+
+ def _convert_te_arg_helper(arg):
+ if isinstance(arg, Expr): # type: ignore
+ if isinstance(arg.struct_info, TensorStructInfo):
+ assert isinstance(
+ arg.struct_info.shape, ShapeExpr
+ ), "emit_te now only supports Tensor that has ShapeExpr
shape"
+ for shape_value in arg.struct_info.shape.values:
+ _copy_undefined_var(shape_value)
+
+ arg = te_tensor(arg, tir_var_map)
+ te_args_list.append(arg)
+ return arg
+ if isinstance(arg.struct_info, ShapeStructInfo):
+ assert isinstance(
+ arg, ShapeExpr
+ ), "For Expr having ShapeStructInfo, emit_te now only
supports ShapeExpr"
+ return [_convert_te_arg_helper(val) for val in arg.values]
+ if isinstance(arg.struct_info, PrimStructInfo):
+ return arg.value
+ elif isinstance(arg, (list, Array)):
+ return [_convert_te_arg_helper(x) for x in arg]
+ elif isinstance(arg, tuple):
+ return tuple(_convert_te_arg_helper(x) for x in arg)
+ elif isinstance(arg, (dict, Map)):
+ for key in arg:
+ assert isinstance(
+ key, str
+ ), "emit_te only supports dict with string as the key
currently"
+ return {k: _convert_te_arg_helper(arg[k]) for k in arg}
+ elif isinstance(arg, tir.PrimExpr):
+ _copy_undefined_var(arg)
+ return tir.stmt_functor.substitute(arg, tir_var_map)
+ elif isinstance(arg, (int, float, str, Type, Attrs)) or arg is
None:
+ return arg
+ raise TypeError("not supported type in emit_te:
{}".format(type(arg)))
+
+ new_arg = _convert_te_arg_helper(te_args)
+ return new_arg, te_args_list
+
+ def _get_unbound_tir_vars(args: List[te_Tensor]) -> List[tir.Var]:
+ """get unbound TIR vars (i.e TIR vars used in the shape but is not
+ itself a dimension of a shape)"""
+ bound_vars = set()
+ used_vars = set()
+
+ def _populate_used_vars(expr):
+ if isinstance(expr, tir.Var):
+ used_vars.add(expr)
+
+ for x in args:
+ for s in x.shape:
+ tir.stmt_functor.post_order_visit(s, _populate_used_vars)
+ if isinstance(s, tir.Var):
+ bound_vars.add(s)
+
+ diff = used_vars - bound_vars
+ return list(diff)
+
+ def _shape_with_old_tir_var(
+ shape_values: List[tir.PrimExpr], tir_var_inverse_map: Dict[tir.Var,
tir.PrimExpr]
+ ):
+ return ShapeExpr(
+ [tir.stmt_functor.substitute(value, tir_var_inverse_map) for value
in shape_values]
+ )
+
+ tir_var_map: Dict[tir.Var, tir.PrimExpr] = {}
+ new_args, te_arg_list = _convert_te_arg(args, tir_var_map)
+ new_kwargs, te_kwarg_list = _convert_te_arg(kwargs, tir_var_map)
+
+ te_args = te_arg_list + te_kwarg_list
+
+ te_out = func(*new_args, **new_kwargs)
+ assert isinstance(te_out, te_Tensor) or (
+ isinstance(te_out, (tuple, list, Array)) and all(isinstance(t,
te_Tensor) for t in te_out)
+ ), "only support te.tensor or tuple/list/Array of te.tensor as function
output"
+
+ outs = [te_out] if isinstance(te_out, te_Tensor) else list(te_out)
+ unbound_tir_vars = _get_unbound_tir_vars(te_args + outs)
+
+ inputs = [*te_args] + outs
+ tir_func = create_relax_prim_func(inputs, unbound_tir_vars, "int64")
+
+ tir_func = tir_func.without_attr("global_symbol")
+
+ call_tir_args = [x.op.value for x in te_args]
+
+ # Invert the TIR variable mapping, to convert the output shape back
+ # with old set of variables.
+ tir_var_inverse_map = {v: k for k, v in tir_var_map.items()}
+
+ output_sinfo = [
+ TensorStructInfo(_shape_with_old_tir_var(out.shape,
tir_var_inverse_map), out.dtype)
+ for out in outs
+ ]
+
+ tir_vars = None
+ if len(unbound_tir_vars) > 0:
+ tir_vars = _shape_with_old_tir_var(unbound_tir_vars,
tir_var_inverse_map)
+
+ return (tir_func, call_tir_args, output_sinfo, tir_vars)
diff --git a/python/tvm/script/ir_builder/relax/ir.py
b/python/tvm/script/ir_builder/relax/ir.py
index 03f1c1db46..9ef403181b 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -25,8 +25,10 @@ from typing import Any, Dict, List, Optional, Tuple, Union,
Callable
import tvm
from tvm import DataType, relax
from tvm.ir import PrimExpr
+from ..ir import decl_function
from tvm.relax import Call, Expr, ExternFunc, TupleGetItem, ShapeExpr, Var,
VarBinding, const
-from tvm.relax.block_builder import BlockBuilder as rx_bb
+from tvm.relax.utils import gen_call_tir_inputs
+
############################### Operators ###############################
from tvm.relax.op import (
@@ -310,7 +312,7 @@ invoke_closure = _sinfo_arg_wrapper(invoke_closure) #
pylint: disable=invalid-n
call_builtin_with_ctx = _sinfo_arg_wrapper(call_builtin_with_ctx) # pylint:
disable=invalid-name
-############################### Bindings ###############################
+############################### Emits ###############################
def emit(value: Expr, annotate_struct_info: Optional[StructInfo] = None) ->
Var:
@@ -331,7 +333,7 @@ def emit(value: Expr, annotate_struct_info:
Optional[StructInfo] = None) -> Var:
return _ffi_api.Emit(value, annotate_struct_info) # type:
ignore[attr-defined] # pylint: disable=no-member
-def emit_te(func: Callable, *args: Any, **kwargs: Any) -> Var:
+def emit_te(func: Callable, *args: Any, **kwargs: Any) -> Call:
"""Emit a call node according to the te function.
This function converts arguments from relax expression to te tensor,
The callback func should return a te tensor or a list of te tensors.
@@ -351,13 +353,15 @@ def emit_te(func: Callable, *args: Any, **kwargs: Any) ->
Var:
Returns
-------
- var : Var
- A newly created variable that gets bound to the call code.
+ call : Call
+ A newly created call that calls into a tir function.
"""
-
- # Levarage the util function call_te in Relax Block Blocker
- emit_expr = rx_bb().call_te(func, *args, **kwargs)
- return emit(emit_expr)
+ primfunc_name_hint = kwargs.pop("primfunc_name_hint", None)
+ tir_func, call_args, out_sinfo, tir_vars = gen_call_tir_inputs(func,
*args, **kwargs)
+ if not primfunc_name_hint:
+ primfunc_name_hint = func.__name__
+ gvar = decl_function(primfunc_name_hint, tir_func) # type: ignore
+ return call_tir(gvar, call_args, out_sinfo, tir_vars)
def emit_match_cast(value: Expr, struct_info: StructInfo) -> Var:
diff --git a/src/script/ir_builder/ir/ir.cc b/src/script/ir_builder/ir/ir.cc
index 148e90b28c..6521df265e 100644
--- a/src/script/ir_builder/ir/ir.cc
+++ b/src/script/ir_builder/ir/ir.cc
@@ -36,7 +36,7 @@ IRModuleFrame IRModule() {
}
GlobalVar DeclFunction(const String& func_name, const BaseFunc&
func_signature) {
- IRModuleFrame frame = FindModuleFrame("I.DeclFunction");
+ IRModuleFrame frame = FindModuleFrame();
CHECK(!frame->global_var_map.count(func_name))
<< "ValueError: function " << func_name << " already exists";
GlobalVar gv = GlobalVar(func_name);
@@ -58,7 +58,7 @@ GlobalVar DeclFunction(const String& func_name, const
BaseFunc& func_signature)
}
void DefFunction(const String& func_name, const BaseFunc& func) {
- IRModuleFrame frame = FindModuleFrame("I.DefFunction");
+ IRModuleFrame frame = FindModuleFrame();
auto it = frame->global_var_map.find(func_name);
CHECK(it != frame->global_var_map.end())
<< "ValueError: function " << func_name << " does not exist, please
declare it first.";
diff --git a/src/script/ir_builder/ir/utils.h b/src/script/ir_builder/ir/utils.h
index 58d5e53f70..b12e5e270d 100644
--- a/src/script/ir_builder/ir/utils.h
+++ b/src/script/ir_builder/ir/utils.h
@@ -41,6 +41,17 @@ inline IRModuleFrame FindModuleFrame(const String& method) {
throw;
}
+inline IRModuleFrame FindModuleFrame() {
+ IRBuilder builder = IRBuilder::Current();
+ if (Optional<IRModuleFrame> frame = builder->FindFrame<IRModuleFrame>()) {
+ return frame.value();
+ } else {
+ LOG(FATAL) << "ValueError: IRModule frame not find. Please ensure it"
+ << " is called under I.ir_module()";
+ }
+ throw;
+}
+
} // namespace ir
} // namespace ir_builder
} // namespace script
diff --git a/tests/python/relax/test_tvmscript_ir_builder.py
b/tests/python/relax/test_tvmscript_ir_builder.py
index 014b00af00..f7c29b8dbe 100644
--- a/tests/python/relax/test_tvmscript_ir_builder.py
+++ b/tests/python/relax/test_tvmscript_ir_builder.py
@@ -58,7 +58,7 @@ def test_function_simple():
def test_emits():
- """Tests for R.emit, R.emit_match_cast, R.emit_var_binding, R.emit_te
+ """Tests for R.emit, R.emit_match_cast, R.emit_var_binding
@R.function
def foo(x: R.Tensor(dtype="float32"), y: R.Tensor(dtype="float32")) ->
R.Shape(ndim=2):
@@ -67,8 +67,6 @@ def test_emits():
gv: R.Tensor((m,), dtype="float32") = R.match_cast(x, R.Tensor((m,),
dtype="float32"))
gv1: R.Tensor((n,), dtype="float32") = R.match_cast(y, R.Tensor((n,),
dtype="float32"))
v: R.Tensor((n,), dtype="float32") = gv1
- gv2 = R.call_tir(add, (v, v), out_sinfo=R.Tensor((n,),
dtype="float32"))
- gv3: R.Tensor((n,), dtype="float32") = gv2
return R.shape([m, n * 2])
"""
# create with Script IRBuilder
@@ -84,8 +82,7 @@ def test_emits():
v = relax.Var("v", relax.TensorStructInfo((n,), "float32"))
vb = relax.VarBinding(v, y1)
v = R.emit_var_binding(vb)
- v1 = R.emit_te(topi.add, v, v)
- R.emit(v1)
+ R.emit(v)
IRBuilder.name("v", v)
R.func_ret_value(relax.ShapeExpr([m, n * 2]))
@@ -102,12 +99,11 @@ def test_emits():
_ = bb.match_cast(x, relax.TensorStructInfo((m,), "float32"))
y1 = bb.match_cast(y, relax.TensorStructInfo((n,), "float32"))
bb.emit_normalized(relax.VarBinding(v, y1))
- v1 = bb.emit_te(topi.add, v, v)
- bb.emit(v1)
+ bb.emit(v)
bb.emit_func_output(relax.ShapeExpr([m, n * 2]))
mod = bb.get()
- tvm.ir.assert_structural_equal(func, mod["foo"], map_free_vars=True)
+ tvm.ir.assert_structural_equal(func, mod["foo"])
def test_dataflow_block():
diff --git a/tests/python/relax/test_tvmscript_parser.py
b/tests/python/relax/test_tvmscript_parser.py
index 65d0115915..c57460902f 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -184,6 +184,25 @@ def test_simple_module():
_check(TestModule, bb.get())
+def test_emit_te():
+ @I.ir_module
+ class EmitTE:
+ @R.function
+ def main(x: R.Tensor((10, 20), "float32")) -> R.Tensor((10, 20),
dtype="float32"):
+ lv1 = R.emit_te(topi.add, x, x)
+ out = R.emit_te(topi.multiply, lv1, lv1)
+ return out
+
+ bb = relax.BlockBuilder()
+ x = relax.Var("x", relax.TensorStructInfo([10, 20], "float32"))
+ with bb.function("main", [x]):
+ lv1 = bb.emit_te(topi.add, x, x)
+ out = bb.emit_te(topi.multiply, lv1, lv1)
+ bb.emit_func_output(out)
+
+ _check(EmitTE, bb.get())
+
+
def test_module_with_attr_and_global_info():
@I.ir_module
class TestModule: