This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 82578c394c [Unity][BlockBuilder] Add `name_hint` argument for `emit` and `emit_output` (#14126)
82578c394c is described below

commit 82578c394c17ef3aeced530fbcbd18ee66ce6425
Author: Chaofan Lin <[email protected]>
AuthorDate: Sat Feb 25 14:08:12 2023 +0800

    [Unity][BlockBuilder] Add `name_hint` argument for `emit` and `emit_output` (#14126)
    
    This PR adds a `name_hint` argument to the `emit` and `emit_output` APIs
    of the Relax BlockBuilder. The argument already exists on the C++ side but
    was not exposed to the Python side, so users calling `bb.emit` from Python
    always got the default `name_hint` of `""`.
    
    Co-authored-by: Yixin Dong <[email protected]>
---
 python/tvm/relax/block_builder.py       | 19 ++++++++++---------
 src/relax/ir/block_builder.cc           | 11 ++++++-----
 tests/python/relax/test_blockbuilder.py | 16 ++++++++++++++++
 3 files changed, 32 insertions(+), 14 deletions(-)

diff --git a/python/tvm/relax/block_builder.py 
b/python/tvm/relax/block_builder.py
index f219641c81..3421bd4d09 100644
--- a/python/tvm/relax/block_builder.py
+++ b/python/tvm/relax/block_builder.py
@@ -365,7 +365,7 @@ class BlockBuilder(Object):
         """
         return DataflowScope(self)
 
-    def emit(self, expr: Expr) -> Var:
+    def emit(self, expr: Expr, name_hint: str = "") -> Var:
         """Emit an expr.
         This infers the shape and type of the expr, create a variable,
         and bind the expr to the variable.
@@ -375,12 +375,15 @@ class BlockBuilder(Object):
         expr : tvm.relax.Expr
             The Expr to be emitted.
 
+        name_hint : str
+            Name hint for the bound variable.
+
         Returns
         -------
         ret : tvm.relax.Var
             A newly created variable that gets bound to the input expr.
         """
-        return _ffi_api.BlockBuilderEmit(self, expr)  # type: ignore
+        return _ffi_api.BlockBuilderEmit(self, expr, name_hint)  # type: ignore
 
     def call_te(self, func: Callable, *args: Any, **kwargs: Any) -> Expr:
         """Generate a call node according to the te function.
@@ -601,7 +604,7 @@ class BlockBuilder(Object):
         """
         return _ffi_api.BlockBuilderEmitMatchCast(self, value, struct_info)  # 
type: ignore
 
-    def emit_output(self, output: Union[Expr, Tuple, List[Expr]]) -> None:
+    def emit_output(self, output: Union[Expr, Tuple, List[Expr]], name_hint: 
str = "") -> Var:
         """Emit output for the current dataflow block or function.
 
         Parameters
@@ -609,6 +612,9 @@ class BlockBuilder(Object):
         output : Expr | Tuple | List[Expr]
             The output of the current block/function.
 
+        name_hint : str
+            Name hint for the bound variable.
+
         Returns
         -------
         ret : tvm.relax.Var
@@ -616,7 +622,7 @@ class BlockBuilder(Object):
         """
         if isinstance(output, (list, tuple)):
             output = Tuple(output)
-        return _ffi_api.BlockBuilderEmitOutput(self, output)  # type: ignore
+        return _ffi_api.BlockBuilderEmitOutput(self, output, name_hint)  # 
type: ignore
 
     def emit_func_output(
         self,
@@ -633,11 +639,6 @@ class BlockBuilder(Object):
         params : tvm.relax.Var | Tuple | List[tvm.relax.Var], optional
             The parameters of the function to be built.
             If params is None, it means the params have been initialized in 
the function with scope.
-
-        Returns
-        -------
-        ret : tvm.relax.Var
-            The return variable which gets bound to the output.
         """
         if self._is_emit_func_output_called:
             raise RuntimeError("emit_func_output must be called exactly once 
in a relax function.")
diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc
index 5976cbb3f4..ac92114ef9 100644
--- a/src/relax/ir/block_builder.cc
+++ b/src/relax/ir/block_builder.cc
@@ -899,9 +899,10 @@ TVM_REGISTER_GLOBAL("relax.BlockBuilderEndBlock")
 TVM_REGISTER_GLOBAL("relax.BlockBuilderNormalize")
     .set_body_method<BlockBuilder>(&BlockBuilderNode::Normalize);
 
-TVM_REGISTER_GLOBAL("relax.BlockBuilderEmit").set_body_typed([](BlockBuilder 
builder, Expr expr) {
-  return builder->Emit(expr);
-});
+TVM_REGISTER_GLOBAL("relax.BlockBuilderEmit")
+    .set_body_typed([](BlockBuilder builder, Expr expr, String name_hint) {
+      return builder->Emit(expr, name_hint);
+    });
 
 TVM_REGISTER_GLOBAL("relax.BlockBuilderEmitMatchCast")
     .set_body_typed([](BlockBuilder builder, Expr value, StructInfo 
struct_info) {
@@ -909,8 +910,8 @@ TVM_REGISTER_GLOBAL("relax.BlockBuilderEmitMatchCast")
     });
 
 TVM_REGISTER_GLOBAL("relax.BlockBuilderEmitOutput")
-    .set_body_typed([](BlockBuilder builder, const Expr& output) {
-      return builder->EmitOutput(output);
+    .set_body_typed([](BlockBuilder builder, const Expr& output, String 
name_hint) {
+      return builder->EmitOutput(output, name_hint);
     });
 
 TVM_REGISTER_GLOBAL("relax.BlockBuilderEmitNormalized")
diff --git a/tests/python/relax/test_blockbuilder.py 
b/tests/python/relax/test_blockbuilder.py
index e54e2b7bf9..9d9d28d7d6 100644
--- a/tests/python/relax/test_blockbuilder.py
+++ b/tests/python/relax/test_blockbuilder.py
@@ -57,6 +57,22 @@ def test_block_builder():
     assert not isinstance(b2, rx.DataflowBlock)
 
 
+def test_emit_with_name():
+    m = tir.Var("m", "int64")
+    n = tir.Var("n", "int64")
+    x = rx.Var("x", rx.TensorStructInfo([m, n], "float16"))
+    y = rx.Var("y", rx.TensorStructInfo([n], "float16"))
+    bb = rx.BlockBuilder()
+
+    bb._begin_dataflow_block()
+    lv0 = bb.emit(rx.op.add(x, y), "add")
+    gv0 = bb.emit_output(rx.op.multiply(lv0, y), "multi")
+    b0 = bb._end_block()
+
+    assert b0.bindings[0].var.name_hint == "add"
+    assert b0.bindings[1].var.name_hint == "multi"
+
+
 def test_function_single_block():
     m = tir.Var("m", "int64")
     n = tir.Var("n", "int64")

Reply via email to