This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c3281c04d4 [Script] Add support for merging block annotations (#18079)
c3281c04d4 is described below
commit c3281c04d45b8501c7743e328a44271675b6fc38
Author: Siyuan Feng <[email protected]>
AuthorDate: Sat Jun 21 01:45:49 2025 +0800
[Script] Add support for merging block annotations (#18079)
This commit introduces functionality to merge block annotations in
TVMScript.
The implementation includes:
- MergeAnnotations function that recursively merges annotation dictionaries
- Support for recursively merging nested dictionaries, combining the keys of both sides
- An error when the same non-dictionary key is given two different values
- BlockAttrs function that uses the merging logic to combine multiple
T.block_attr() calls within the same block
The feature allows users to specify block attributes incrementally using
multiple T.block_attr() calls, which will be automatically merged together.
---
src/script/ir_builder/tir/ir.cc | 41 ++++++++++++++++++--
.../tvmscript/test_tvmscript_error_report.py | 17 +++++----
.../python/tvmscript/test_tvmscript_parser_tir.py | 44 ++++++++++++++++++++++
3 files changed, 91 insertions(+), 11 deletions(-)
diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc
index 6f73254ff2..831dbcdd4a 100644
--- a/src/script/ir_builder/tir/ir.cc
+++ b/src/script/ir_builder/tir/ir.cc
@@ -219,12 +219,47 @@ void Writes(Array<ObjectRef> buffer_slices) {
frame->writes = writes;
}
+/*! \brief Recursively merge two annotation maps; a key in both must hold equal values or dicts */
+Map<String, Any> MergeAnnotations(const Map<String, Any>& new_attrs,
+ const Map<String, Any>& old_attrs) {
+ Map<String, Any> result = old_attrs;
+ for (const auto& [key, value] : new_attrs) {
+ auto old_value = old_attrs.Get(key);
+ // Case 1: the key is not in the old annotations, set the key to the new value
+ if (!old_value) {
+ result.Set(key, value);
+ continue;
+ }
+
+ // Case 2: the key is in the old annotations
+ // Case 2.1: both are dicts
+ auto old_dict = old_value->try_cast<Map<String, Any>>();
+ auto new_dict = value.try_cast<Map<String, Any>>();
+ if (old_dict && new_dict) {
+ // Recursively merge the two dicts, keeping the (new, old) argument order of this function
+ auto merged_dict = MergeAnnotations(*new_dict, *old_dict);
+ result.Set(key, merged_dict);
+ continue;
+ }
+ // Case 2.2: the values are not both dicts, so they must be equal
+ if (!ffi::AnyEqual()(old_value.value(), value)) {
+ LOG(FATAL) << "ValueError: Try to merge two annotations with different values for key `"
+ << key << "`, previous one is " << old_value->cast<ObjectRef>() << ", new one is "
+ << value.cast<ObjectRef>();
+ }
+ }
+ return result;
+}
+
void BlockAttrs(Map<String, Any> attrs) {
BlockFrame frame = FindBlockFrame("T.block_attr");
- if (frame->annotations.defined()) {
- LOG(FATAL) << "ValueError: Duplicate block annotations, previous one is " << frame->annotations;
+ // Case 1: the block has no annotations, set the new annotations
+ if (!frame->annotations.defined()) {
+ frame->annotations = attrs;
+ } else {
+ // Case 2: the block has annotations, merge the new annotations with the old ones
+ frame->annotations = MergeAnnotations(attrs, frame->annotations.value());
}
- frame->annotations = attrs;
}
Buffer AllocBuffer(Array<PrimExpr> shape, DataType dtype, Optional<Var> data,
diff --git a/tests/python/tvmscript/test_tvmscript_error_report.py b/tests/python/tvmscript/test_tvmscript_error_report.py
index d8212d3885..1cbd6af961 100644
--- a/tests/python/tvmscript/test_tvmscript_error_report.py
+++ b/tests/python/tvmscript/test_tvmscript_error_report.py
@@ -280,13 +280,6 @@ def test_duplicate_block_signature():
T.where(1)
T.where(0) # error
- def duplicate_annotations() -> None:
- for i, j in T.grid(16, 16):
- with T.block():
- vi, vj = T.axis.remap("SS", [i, j])
- T.block_attr({})
- T.block_attr({}) # error
-
def duplicate_init() -> None:
for i, j in T.grid(16, 16):
with T.block():
@@ -303,12 +296,20 @@ def test_duplicate_block_signature():
vi = T.axis.S(i, 16) # error
T.evaluate(1.0)
+ def duplicate_block_attrs_with_same_key_diff_value() -> None:
+ for i, j in T.grid(16, 16):
+ with T.block():
+ vi, vj = T.axis.remap("SS", [i, j])
+ T.block_attr({"key1": "block1"})
+ T.block_attr({"key1": "block2"}) # error
+ T.evaluate(1.0)
+
check_error(duplicate_reads, 7)
check_error(duplicate_writes, 7)
check_error(duplicate_predicate, 6)
- check_error(duplicate_annotations, 6)
check_error(duplicate_init, 7)
check_error(duplicate_axes, 5)
+ check_error(duplicate_block_attrs_with_same_key_diff_value, 6)
def test_opaque_access_during_complete():
diff --git a/tests/python/tvmscript/test_tvmscript_parser_tir.py b/tests/python/tvmscript/test_tvmscript_parser_tir.py
index 16b2067514..d5ee2e0772 100644
--- a/tests/python/tvmscript/test_tvmscript_parser_tir.py
+++ b/tests/python/tvmscript/test_tvmscript_parser_tir.py
@@ -544,5 +544,49 @@ def test_deterministic_branch():
tvm.ir.assert_structural_equal(create_func(False), create_expected(1))
+def test_block_annotation_merge():
+ def _to_dict(anno: tvm.ffi.container.Map):
+ result = {}
+ for k, v in anno.items():
+ result[k] = _to_dict(v) if isinstance(v, tvm.ffi.container.Map) else v
+ return result
+
+ @T.prim_func
+ def func0():
+ with T.block():
+ T.block_attr({"key1": "block1"})
+ T.block_attr({"key2": "block2"})
+ T.evaluate(0)
+
+ assert _to_dict(func0.body.block.annotations) == {"key1": "block1", "key2": "block2"}
+
+ @T.prim_func
+ def func1():
+ with T.block():
+ T.block_attr({"key": {"key1": "block1"}})
+ T.block_attr({"key": {"key2": "block2"}})
+ T.evaluate(0)
+
+ assert _to_dict(func1.body.block.annotations) == {"key": {"key1": "block1", "key2": "block2"}}
+
+ @T.prim_func
+ def func2():
+ with T.block():
+ T.block_attr({"key1": "block1"})
+ T.block_attr({"key1": "block1"})
+ T.evaluate(0)
+
+ assert _to_dict(func2.body.block.annotations) == {"key1": "block1"}
+
+ with pytest.raises(tvm.TVMError):
+
+ @T.prim_func
+ def func3():
+ with T.block():
+ T.block_attr({"key1": "block1"})
+ T.block_attr({"key1": "block2"})
+ T.evaluate(0)
+
+
if __name__ == "__main__":
tvm.testing.main()