This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 1e6482c808 [Relax] Expose name_hint field for BlockBuilder.match_cast (#16600)
1e6482c808 is described below
commit 1e6482c80893c7fb682006d8ca104125f6693616
Author: Eric Lunderberg <[email protected]>
AuthorDate: Sun Feb 18 21:06:13 2024 -0600
[Relax] Expose name_hint field for BlockBuilder.match_cast (#16600)
* [Relax] Expose name_hint field for BlockBuilder.match_cast
Prior to this commit, while a `relax.VarBinding` created using
`BlockBuilder.emit` could have its name explicitly specified by the
user, a `relax.MatchCast` created using `BlockBuilder.match_cast`
could not. This commit updates `BlockBuilder.match_cast` to accept an
optional `name_hint` parameter, which is then provided to the C++
`BlockBuilder::EmitMatchCast` method.
* Fix lint error
---
python/tvm/relax/block_builder.py | 12 ++++++++++--
src/relax/ir/block_builder.cc | 4 ++--
tests/python/relax/test_blockbuilder_core.py | 3 ++-
3 files changed, 14 insertions(+), 5 deletions(-)
diff --git a/python/tvm/relax/block_builder.py
b/python/tvm/relax/block_builder.py
index b4206f76f4..330585599d 100644
--- a/python/tvm/relax/block_builder.py
+++ b/python/tvm/relax/block_builder.py
@@ -534,7 +534,7 @@ class BlockBuilder(Object):
name_hint = kwargs.pop("name_hint", "")
return self.emit(self.call_te(func, *args, **kwargs),
name_hint=name_hint)
- def match_cast(self, value: Expr, struct_info: StructInfo) -> Var:
+ def match_cast(self, value: Expr, struct_info: StructInfo, name_hint: str
= "") -> Var:
"""Emit a MatchCast.
Parameters
@@ -545,12 +545,20 @@ class BlockBuilder(Object):
struct_info : StructInfo
The struct info to be matched.
+ name_hint : str
+ The name of the match cast
+
Returns
-------
ret : tvm.relax.Var
A newly created variable that get bounds to be the casted result.
"""
- return _ffi_api.BlockBuilderEmitMatchCast(self, value, struct_info) #
type: ignore
+ return _ffi_api.BlockBuilderEmitMatchCast(
+ self,
+ value,
+ struct_info,
+ name_hint,
+ ) # type: ignore
def emit_output(self, output: Union[Expr, Tuple, List[Expr]], name_hint:
str = "") -> Var:
"""Emit output for the current dataflow block or function.
diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc
index b39beae740..a1fac27e06 100644
--- a/src/relax/ir/block_builder.cc
+++ b/src/relax/ir/block_builder.cc
@@ -1015,8 +1015,8 @@ TVM_REGISTER_GLOBAL("relax.BlockBuilderEmit")
});
TVM_REGISTER_GLOBAL("relax.BlockBuilderEmitMatchCast")
- .set_body_typed([](BlockBuilder builder, Expr value, StructInfo
struct_info) {
- return builder->EmitMatchCast(value, struct_info);
+ .set_body_typed([](BlockBuilder builder, Expr value, StructInfo
struct_info, String name_hint) {
+ return builder->EmitMatchCast(value, struct_info, name_hint);
});
TVM_REGISTER_GLOBAL("relax.BlockBuilderEmitOutput")
diff --git a/tests/python/relax/test_blockbuilder_core.py
b/tests/python/relax/test_blockbuilder_core.py
index 16023c9c91..19bbdf5854 100644
--- a/tests/python/relax/test_blockbuilder_core.py
+++ b/tests/python/relax/test_blockbuilder_core.py
@@ -226,7 +226,7 @@ def test_emit_match_cast():
assert_structural_equal(lv0.struct_info, rx.TensorStructInfo([m,
n], "float32"))
# lv1: Shape = match_cast(shape, rx.ShapeStructInfo([m, n]))
- lv1 = bb.match_cast(y, rx.ShapeStructInfo([m, n]))
+ lv1 = bb.match_cast(y, rx.ShapeStructInfo([m, n]), "var_name")
assert lv1.struct_info == rx.ShapeStructInfo([m, n])
gv0 = bb.emit_output(lv1)
@@ -244,6 +244,7 @@ def test_emit_match_cast():
assert b1.value == y
assert b1.struct_info == rx.ShapeStructInfo([m, n])
assert b1.var == lv1
+ assert b1.var.name_hint == "var_name"
def test_emit_match_cast_binding_in_dataflow_block():