This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 82578c394c [Unity][BlockBuilder] Add `name_hint` argument for `emit`
and `emit_output` (#14126)
82578c394c is described below
commit 82578c394c17ef3aeced530fbcbd18ee66ce6425
Author: Chaofan Lin <[email protected]>
AuthorDate: Sat Feb 25 14:08:12 2023 +0800
[Unity][BlockBuilder] Add `name_hint` argument for `emit` and `emit_output`
(#14126)
This PR adds a `name_hint` argument to the `emit` and `emit_output` APIs of the
Relax BlockBuilder. The argument already exists on the C++ side but was not
exposed on the Python side (so users calling the Python `bb.emit` always got
the default `name_hint` of `""`).
Co-authored-by: Yixin Dong <[email protected]>
---
python/tvm/relax/block_builder.py | 19 ++++++++++---------
src/relax/ir/block_builder.cc | 11 ++++++-----
tests/python/relax/test_blockbuilder.py | 16 ++++++++++++++++
3 files changed, 32 insertions(+), 14 deletions(-)
diff --git a/python/tvm/relax/block_builder.py
b/python/tvm/relax/block_builder.py
index f219641c81..3421bd4d09 100644
--- a/python/tvm/relax/block_builder.py
+++ b/python/tvm/relax/block_builder.py
@@ -365,7 +365,7 @@ class BlockBuilder(Object):
"""
return DataflowScope(self)
- def emit(self, expr: Expr) -> Var:
+ def emit(self, expr: Expr, name_hint: str = "") -> Var:
"""Emit an expr.
This infers the shape and type of the expr, create a variable,
and bind the expr to the variable.
@@ -375,12 +375,15 @@ class BlockBuilder(Object):
expr : tvm.relax.Expr
The Expr to be emitted.
+ name_hint : str
+ Name hint for the bound variable.
+
Returns
-------
ret : tvm.relax.Var
A newly created variable that gets bound to the input expr.
"""
- return _ffi_api.BlockBuilderEmit(self, expr) # type: ignore
+ return _ffi_api.BlockBuilderEmit(self, expr, name_hint) # type: ignore
def call_te(self, func: Callable, *args: Any, **kwargs: Any) -> Expr:
"""Generate a call node according to the te function.
@@ -601,7 +604,7 @@ class BlockBuilder(Object):
"""
return _ffi_api.BlockBuilderEmitMatchCast(self, value, struct_info) #
type: ignore
- def emit_output(self, output: Union[Expr, Tuple, List[Expr]]) -> None:
+ def emit_output(self, output: Union[Expr, Tuple, List[Expr]], name_hint:
str = "") -> Var:
"""Emit output for the current dataflow block or function.
Parameters
@@ -609,6 +612,9 @@ class BlockBuilder(Object):
output : Expr | Tuple | List[Expr]
The output of the current block/function.
+ name_hint : str
+ Name hint for the bound variable.
+
Returns
-------
ret : tvm.relax.Var
@@ -616,7 +622,7 @@ class BlockBuilder(Object):
"""
if isinstance(output, (list, tuple)):
output = Tuple(output)
- return _ffi_api.BlockBuilderEmitOutput(self, output) # type: ignore
+ return _ffi_api.BlockBuilderEmitOutput(self, output, name_hint) #
type: ignore
def emit_func_output(
self,
@@ -633,11 +639,6 @@ class BlockBuilder(Object):
params : tvm.relax.Var | Tuple | List[tvm.relax.Var], optional
The parameters of the function to be built.
If params is None, it means the params have been initialized in
the function with scope.
-
- Returns
- -------
- ret : tvm.relax.Var
- The return variable which gets bound to the output.
"""
if self._is_emit_func_output_called:
raise RuntimeError("emit_func_output must be called exactly once
in a relax function.")
diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc
index 5976cbb3f4..ac92114ef9 100644
--- a/src/relax/ir/block_builder.cc
+++ b/src/relax/ir/block_builder.cc
@@ -899,9 +899,10 @@ TVM_REGISTER_GLOBAL("relax.BlockBuilderEndBlock")
TVM_REGISTER_GLOBAL("relax.BlockBuilderNormalize")
.set_body_method<BlockBuilder>(&BlockBuilderNode::Normalize);
-TVM_REGISTER_GLOBAL("relax.BlockBuilderEmit").set_body_typed([](BlockBuilder
builder, Expr expr) {
- return builder->Emit(expr);
-});
+TVM_REGISTER_GLOBAL("relax.BlockBuilderEmit")
+ .set_body_typed([](BlockBuilder builder, Expr expr, String name_hint) {
+ return builder->Emit(expr, name_hint);
+ });
TVM_REGISTER_GLOBAL("relax.BlockBuilderEmitMatchCast")
.set_body_typed([](BlockBuilder builder, Expr value, StructInfo
struct_info) {
@@ -909,8 +910,8 @@ TVM_REGISTER_GLOBAL("relax.BlockBuilderEmitMatchCast")
});
TVM_REGISTER_GLOBAL("relax.BlockBuilderEmitOutput")
- .set_body_typed([](BlockBuilder builder, const Expr& output) {
- return builder->EmitOutput(output);
+ .set_body_typed([](BlockBuilder builder, const Expr& output, String
name_hint) {
+ return builder->EmitOutput(output, name_hint);
});
TVM_REGISTER_GLOBAL("relax.BlockBuilderEmitNormalized")
diff --git a/tests/python/relax/test_blockbuilder.py
b/tests/python/relax/test_blockbuilder.py
index e54e2b7bf9..9d9d28d7d6 100644
--- a/tests/python/relax/test_blockbuilder.py
+++ b/tests/python/relax/test_blockbuilder.py
@@ -57,6 +57,22 @@ def test_block_builder():
assert not isinstance(b2, rx.DataflowBlock)
+def test_emit_with_name():
+ m = tir.Var("m", "int64")
+ n = tir.Var("n", "int64")
+ x = rx.Var("x", rx.TensorStructInfo([m, n], "float16"))
+ y = rx.Var("y", rx.TensorStructInfo([n], "float16"))
+ bb = rx.BlockBuilder()
+
+ bb._begin_dataflow_block()
+ lv0 = bb.emit(rx.op.add(x, y), "add")
+ gv0 = bb.emit_output(rx.op.multiply(lv0, y), "multi")
+ b0 = bb._end_block()
+
+ assert b0.bindings[0].var.name_hint == "add"
+ assert b0.bindings[1].var.name_hint == "multi"
+
+
def test_function_single_block():
m = tir.Var("m", "int64")
n = tir.Var("n", "int64")