This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new ff21d66ab8 [Unity][TVMScript] emit_te sugar (#14123)
ff21d66ab8 is described below
commit ff21d66ab80dabb13a3cc43e26de56b3047cf8c4
Author: Yong Wu <[email protected]>
AuthorDate: Mon Feb 27 06:09:40 2023 -0800
[Unity][TVMScript] emit_te sugar (#14123)
This PR adds an R.emit_te meta-programming mechanism to emit a TOPI operator
from TVMScript.
---
python/tvm/script/ir_builder/relax/ir.py | 34 ++++++++++++++++++++++-
src/script/ir_builder/relax/ir.cc | 2 +-
tests/python/relax/test_tvmscript_ir_builder.py | 36 +++++++++++++++++--------
3 files changed, 59 insertions(+), 13 deletions(-)
diff --git a/python/tvm/script/ir_builder/relax/ir.py
b/python/tvm/script/ir_builder/relax/ir.py
index 63efea135c..045fe9ddd9 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -20,12 +20,13 @@
import builtins
import functools
import inspect
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union, Callable
import tvm
from tvm import DataType, relax
from tvm.ir import PrimExpr
from tvm.relax import Call, Expr, ExternFunc, TupleGetItem, Var, VarBinding,
const
+from tvm.relax.block_builder import BlockBuilder as rx_bb
############################### Operators ###############################
from tvm.relax.op import (
@@ -304,6 +305,7 @@ invoke_closure = _sinfo_arg_wrapper(invoke_closure) #
pylint: disable=invalid-n
call_builtin_with_ctx = _sinfo_arg_wrapper(call_builtin_with_ctx) # pylint:
disable=invalid-name
+
############################### Bindings ###############################
@@ -325,6 +327,35 @@ def emit(value: Expr, annotate_struct_info:
Optional[StructInfo] = None) -> Var:
return _ffi_api.Emit(value, annotate_struct_info) # type:
ignore[attr-defined] # pylint: disable=no-member
+def emit_te(func: Callable, *args: Any, **kwargs: Any) -> Var:
+ """Emit a call node built from a TE function and bind it to a new Var.
+ Relax expression arguments are converted to TE tensors before ``func`` is invoked;
+ the callback ``func`` must return a TE tensor or a list of TE tensors.
+
+ Parameters
+ ----------
+ func : Callable
+ A function that returns a TE tensor or a list of TE tensors.
+
+ args : Any, optional
+ The positional arguments passed to the function.
+
+ kwargs : Any, optional
+ The keyword arguments passed to the function.
+ Note that the key "primfunc_name_hint" is reserved for passing a name hint
+ to the PrimFunc that gets generated.
+
+ Returns
+ -------
+ var : Var
+ A newly created variable that gets bound to the emitted call node.
+ """

+ # Build the call via BlockBuilder.call_te; NOTE(review): rx_bb() is a fresh, throwaway builder - confirm the generated PrimFunc is not lost.
+ emit_expr = rx_bb().call_te(func, *args, **kwargs)
+ return emit(emit_expr)
+
+
def emit_match_cast(value: Expr, struct_info: StructInfo) -> Var:
"""Emit a match_cast binding to the last binding block frame.
Parameters
@@ -511,6 +542,7 @@ __all__ = [
"divide",
"dtype",
"emit",
+ "emit_te",
"emit_var_binding",
"emit_match_cast",
"equal",
diff --git a/src/script/ir_builder/relax/ir.cc
b/src/script/ir_builder/relax/ir.cc
index ddfb1ddfa3..71a0651de8 100644
--- a/src/script/ir_builder/relax/ir.cc
+++ b/src/script/ir_builder/relax/ir.cc
@@ -108,7 +108,7 @@ void FuncRetValue(const tvm::relax::Expr& value) {
if (block_frame.defined()) {
block_frame.value()->ExitWithScope();
ICHECK(!IRBuilder::Current()->FindFrame<BlockFrame>())
- << "All block frame are supposed to be popped out already";
+ << "ValueError: Relax functions don't support return in true/false
branch of If Node.";
}
// Step 2. Add the output value to the function frame.
FunctionFrame frame = FindFunctionFrame("return");
diff --git a/tests/python/relax/test_tvmscript_ir_builder.py
b/tests/python/relax/test_tvmscript_ir_builder.py
index eb0aaf5604..014b00af00 100644
--- a/tests/python/relax/test_tvmscript_ir_builder.py
+++ b/tests/python/relax/test_tvmscript_ir_builder.py
@@ -16,7 +16,7 @@
# under the License.
import tvm
import tvm.testing
-from tvm import relax, tir
+from tvm import relax, tir, topi
from tvm.script.ir_builder import relax as R
from tvm.script.ir_builder.base import IRBuilder
@@ -57,15 +57,19 @@ def test_function_simple():
assert func.body.body.name_hint == "out"
-def test_match_cast():
- """
+def test_emits():
+ """Tests for R.emit, R.emit_match_cast, R.emit_var_binding, R.emit_te
+
@R.function
- def foo(x: R.Tensor(None, "float32"), y: R.Tensor(None, "float32")):
+ def foo(x: R.Tensor(dtype="float32"), y: R.Tensor(dtype="float32")) ->
R.Shape(ndim=2):
m = T.int64()
n = T.int64()
- _ = R.match_cast(x, R.Tensor((m,), "float32"))
- y1 = R.match_cast(x, R.Tensor((n,), "float32"))
- return (m, n * 2)
+ gv: R.Tensor((m,), dtype="float32") = R.match_cast(x, R.Tensor((m,),
dtype="float32"))
+ gv1: R.Tensor((n,), dtype="float32") = R.match_cast(y, R.Tensor((n,),
dtype="float32"))
+ v: R.Tensor((n,), dtype="float32") = gv1
+ gv2 = R.call_tir(add, (v, v), out_sinfo=R.Tensor((n,),
dtype="float32"))
+ gv3: R.Tensor((n,), dtype="float32") = gv2
+ return R.shape([m, n * 2])
"""
# create with Script IRBuilder
with IRBuilder() as ir_builder:
@@ -77,23 +81,33 @@ def test_match_cast():
n = tir.Var("n", dtype="int64")
_ = R.emit_match_cast(x, relax.TensorStructInfo((m,), "float32"))
y1 = R.emit_match_cast(y, relax.TensorStructInfo((n,), "float32"))
- IRBuilder.name("y1", y1)
+ v = relax.Var("v", relax.TensorStructInfo((n,), "float32"))
+ vb = relax.VarBinding(v, y1)
+ v = R.emit_var_binding(vb)
+ v1 = R.emit_te(topi.add, v, v)
+ R.emit(v1)
+
+ IRBuilder.name("v", v)
R.func_ret_value(relax.ShapeExpr([m, n * 2]))
func = ir_builder.get()
# create with BlockBuilder
- x = relax.Var("x", relax.TensorStructInfo(dtype="float32", ndim=-1))
- y = relax.Var("y", relax.TensorStructInfo(dtype="float32", ndim=-1))
m = tir.Var("m", dtype="int64")
n = tir.Var("n", dtype="int64")
+ x = relax.Var("x", relax.TensorStructInfo(dtype="float32", ndim=-1))
+ y = relax.Var("y", relax.TensorStructInfo(dtype="float32", ndim=-1))
+ v = relax.Var("v", relax.TensorStructInfo((n,), "float32"))
bb = relax.BlockBuilder()
with bb.function("foo", (x, y)):
_ = bb.match_cast(x, relax.TensorStructInfo((m,), "float32"))
y1 = bb.match_cast(y, relax.TensorStructInfo((n,), "float32"))
+ bb.emit_normalized(relax.VarBinding(v, y1))
+ v1 = bb.emit_te(topi.add, v, v)
+ bb.emit(v1)
bb.emit_func_output(relax.ShapeExpr([m, n * 2]))
mod = bb.get()
- tvm.ir.assert_structural_equal(func, mod["foo"])
+ tvm.ir.assert_structural_equal(func, mod["foo"], map_free_vars=True)
def test_dataflow_block():