This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 1e6482c808 [Relax] Expose name_hint field for BlockBuilder.match_cast (#16600)
1e6482c808 is described below

commit 1e6482c80893c7fb682006d8ca104125f6693616
Author: Eric Lunderberg <[email protected]>
AuthorDate: Sun Feb 18 21:06:13 2024 -0600

    [Relax] Expose name_hint field for BlockBuilder.match_cast (#16600)
    
    * [Relax] Expose name_hint field for BlockBuilder.match_cast
    
    Prior to this commit, while a `relax.VarBinding` created using
    `BlockBuilder.emit` could have its name explicitly specified by the
    user, a `relax.MatchCast` created using `BlockBuilder.match_cast`
    could not.  This commit updates `BlockBuilder.match_cast` to accept an
    optional `name_hint` parameter, which is then provided to the C++
    `BlockBuilder::EmitMatchCast` method.
    
    * Fix lint error
---
 python/tvm/relax/block_builder.py            | 12 ++++++++++--
 src/relax/ir/block_builder.cc                |  4 ++--
 tests/python/relax/test_blockbuilder_core.py |  3 ++-
 3 files changed, 14 insertions(+), 5 deletions(-)

diff --git a/python/tvm/relax/block_builder.py b/python/tvm/relax/block_builder.py
index b4206f76f4..330585599d 100644
--- a/python/tvm/relax/block_builder.py
+++ b/python/tvm/relax/block_builder.py
@@ -534,7 +534,7 @@ class BlockBuilder(Object):
         name_hint = kwargs.pop("name_hint", "")
         return self.emit(self.call_te(func, *args, **kwargs), 
name_hint=name_hint)
 
-    def match_cast(self, value: Expr, struct_info: StructInfo) -> Var:
+    def match_cast(self, value: Expr, struct_info: StructInfo, name_hint: str 
= "") -> Var:
         """Emit a MatchCast.
 
         Parameters
@@ -545,12 +545,20 @@ class BlockBuilder(Object):
         struct_info : StructInfo
             The struct info to be matched.
 
+        name_hint : str
+            The name of the match cast
+
         Returns
         -------
         ret : tvm.relax.Var
             A newly created variable that get bounds to be the casted result.
         """
-        return _ffi_api.BlockBuilderEmitMatchCast(self, value, struct_info)  # 
type: ignore
+        return _ffi_api.BlockBuilderEmitMatchCast(
+            self,
+            value,
+            struct_info,
+            name_hint,
+        )  # type: ignore
 
     def emit_output(self, output: Union[Expr, Tuple, List[Expr]], name_hint: 
str = "") -> Var:
         """Emit output for the current dataflow block or function.
diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc
index b39beae740..a1fac27e06 100644
--- a/src/relax/ir/block_builder.cc
+++ b/src/relax/ir/block_builder.cc
@@ -1015,8 +1015,8 @@ TVM_REGISTER_GLOBAL("relax.BlockBuilderEmit")
     });
 
 TVM_REGISTER_GLOBAL("relax.BlockBuilderEmitMatchCast")
-    .set_body_typed([](BlockBuilder builder, Expr value, StructInfo 
struct_info) {
-      return builder->EmitMatchCast(value, struct_info);
+    .set_body_typed([](BlockBuilder builder, Expr value, StructInfo 
struct_info, String name_hint) {
+      return builder->EmitMatchCast(value, struct_info, name_hint);
     });
 
 TVM_REGISTER_GLOBAL("relax.BlockBuilderEmitOutput")
diff --git a/tests/python/relax/test_blockbuilder_core.py b/tests/python/relax/test_blockbuilder_core.py
index 16023c9c91..19bbdf5854 100644
--- a/tests/python/relax/test_blockbuilder_core.py
+++ b/tests/python/relax/test_blockbuilder_core.py
@@ -226,7 +226,7 @@ def test_emit_match_cast():
             assert_structural_equal(lv0.struct_info, rx.TensorStructInfo([m, 
n], "float32"))
 
             # lv1: Shape = match_cast(shape, rx.ShapeStructInfo([m, n]))
-            lv1 = bb.match_cast(y, rx.ShapeStructInfo([m, n]))
+            lv1 = bb.match_cast(y, rx.ShapeStructInfo([m, n]), "var_name")
             assert lv1.struct_info == rx.ShapeStructInfo([m, n])
             gv0 = bb.emit_output(lv1)
 
@@ -244,6 +244,7 @@ def test_emit_match_cast():
     assert b1.value == y
     assert b1.struct_info == rx.ShapeStructInfo([m, n])
     assert b1.var == lv1
+    assert b1.var.name_hint == "var_name"
 
 
 def test_emit_match_cast_binding_in_dataflow_block():

Reply via email to