This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 85a8770857 [Relax] Enhance unique block name generation with numeric suffixes (#18554)
85a8770857 is described below
commit 85a877085714b4d10d65e2c267dab3937915e8a1
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Wed Dec 10 13:58:38 2025 +0800
[Relax] Enhance unique block name generation with numeric suffixes (#18554)
## Why
Resolve the TODO in `fuse_tir.cc` by making unique block name generation aware of
numeric suffixes, so that a duplicate of `name1` is renamed to `name2` instead of `name1_1`.
---
src/relax/transform/fuse_tir.cc | 51 ++++++++++++++++---
tests/python/relax/test_transform_fuse_tir.py | 72 +++++++++++++++++++++++++++
2 files changed, 115 insertions(+), 8 deletions(-)
diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc
index ba4515faf3..549cd2197b 100644
--- a/src/relax/transform/fuse_tir.cc
+++ b/src/relax/transform/fuse_tir.cc
@@ -357,17 +357,52 @@ class BlockNameDeduplicator : public tir::StmtMutator {
}
ffi::String GetUniqueName(const ffi::String& prefix) {
- ffi::String unique_prefix = prefix;
- auto it = name_count_.find(prefix);
- while (name_count_.count(unique_prefix)) {
- unique_prefix = prefix + "_" + std::to_string(++it->second);
+ std::string str_prefix = std::string(prefix);
+
+ // Find where the trailing digits start
+ size_t base_len = str_prefix.length();
+ while (base_len > 0 && std::isdigit(str_prefix[base_len - 1])) {
+ --base_len;
+ }
+
+ std::string base_name;
+ int64_t start_num = 0;
+ bool has_suffix = base_len < str_prefix.length();
+
+ if (has_suffix) {
+ base_name = str_prefix.substr(0, base_len);
+ try {
+ start_num = std::stoll(str_prefix.substr(base_len));
+ } catch (const std::out_of_range&) {
+ // Fallback: if the number is too large, treat the whole string as a
base name.
+ has_suffix = false;
+ base_name = str_prefix;
+ }
+ } else {
+ base_name = str_prefix;
+ }
+
+ // Check if the original name is available
+ ffi::String candidate = prefix;
+ if (!name_count_.count(candidate)) {
+ name_count_[candidate] = 0;
+ return candidate;
+ }
+
+ // Generate unique name by incrementing the numeric suffix
+ int64_t counter = has_suffix ? start_num + 1 : 1;
+ while (true) {
+ candidate = ffi::String(base_name + std::to_string(counter));
+ if (!name_count_.count(candidate)) {
+ name_count_[candidate] = 0;
+ return candidate;
+ }
+ ++counter;
+ ICHECK_GT(counter, 0) << "Counter overflow when generating unique block
name for prefix: "
+ << prefix;
}
- name_count_[unique_prefix] = 0;
- return unique_prefix;
}
- // TODO(relax-team): It should detects the number suffix and do renaming
properly
- // e.g. GetUniqueName("name1") should return "name2" instead of "name10".
/*! \brief The count map to make block name unique. */
std::unordered_map<ffi::String, int> name_count_;
};
diff --git a/tests/python/relax/test_transform_fuse_tir.py b/tests/python/relax/test_transform_fuse_tir.py
index 8e583b3dd4..a67bc63f9b 100644
--- a/tests/python/relax/test_transform_fuse_tir.py
+++ b/tests/python/relax/test_transform_fuse_tir.py
@@ -2444,5 +2444,77 @@ def test_fuse_with_axis_separators_inconsistent_buffer_mapping():
relax.transform.FuseTIR()(Before)
+def test_block_name_numeric_suffix_deduplication():
+    @I.ir_module
+    class Before:
+        @T.prim_func(private=True)
+        def add1(x: T.Buffer((10,), "float32"), y: T.Buffer((10,), "float32")):
+            T.func_attr({"tir.noalias": True})
+            for i in range(10):
+                with T.block("compute1"):
+                    vi = T.axis.spatial(10, i)
+                    y[vi] = x[vi] + T.float32(1.0)
+
+        @T.prim_func(private=True)
+        def mul1(x: T.Buffer((10,), "float32"), y: T.Buffer((10,), "float32")):
+            T.func_attr({"tir.noalias": True})
+            for i in range(10):
+                with T.block("compute1"):  # same block name as in add1
+                    vi = T.axis.spatial(10, i)
+                    y[vi] = x[vi] * T.float32(2.0)
+
+        @R.function(private=True)
+        def fused_add_mul(x: R.Tensor((10,), "float32")) -> R.Tensor((10,), dtype="float32"):
+            R.func_attr({"Primitive": True})
+            cls = Before
+            with R.dataflow():
+                lv1 = R.call_tir(cls.add1, (x,), out_sinfo=R.Tensor((10,), dtype="float32"))
+                lv2 = R.call_tir(cls.mul1, (lv1,), out_sinfo=R.Tensor((10,), dtype="float32"))
+                R.output(lv2)
+            return lv2
+
+        @R.function
+        def main(x: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"):
+            cls = Before
+            with R.dataflow():
+                gv = cls.fused_add_mul(x)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func(private=True)
+        def fused_add_mul(p_x: T.handle, p_output0: T.handle):
+            T.func_attr({"tir.noalias": True})
+            x = T.match_buffer(p_x, (T.int64(10),))
+            y_intermediate_1 = T.match_buffer(p_output0, (T.int64(10),), elem_offset=T.int32(0))
+            with T.block("root"):
+                T.reads()
+                T.writes()
+                y_intermediate = T.alloc_buffer((T.int64(10),), elem_offset=T.int32(0))
+                for i in range(10):
+                    with T.block("compute1"):
+                        vi = T.axis.spatial(10, i)
+                        T.reads(x[vi])
+                        T.writes(y_intermediate[vi])
+                        y_intermediate[vi] = x[vi] + T.float32(1.0)
+                for i in range(10):
+                    with T.block("compute2"):  # duplicate "compute1" gets the next suffix
+                        vi = T.axis.spatial(10, i)
+                        T.reads(y_intermediate[vi])
+                        T.writes(y_intermediate_1[vi])
+                        y_intermediate_1[vi] = y_intermediate[vi] * T.float32(2.0)
+
+        @R.function
+        def main(x: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"):
+            cls = Expected
+            with R.dataflow():
+                gv = R.call_tir(cls.fused_add_mul, (x,), out_sinfo=R.Tensor((10,), dtype="float32"))
+                R.output(gv)
+            return gv
+
+    _check(Before, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()