This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c3281c04d4 [Script] Add support for merging block annotations (#18079)
c3281c04d4 is described below

commit c3281c04d45b8501c7743e328a44271675b6fc38
Author: Siyuan Feng <[email protected]>
AuthorDate: Sat Jun 21 01:45:49 2025 +0800

    [Script] Add support for merging block annotations (#18079)
    
    This commit introduces functionality to merge block annotations in TVMScript.
    The implementation includes:
    
    - MergeAnnotations function that recursively merges annotation dictionaries
    - Support for recursively merging nested dictionaries (new keys are added to the old ones)
    - Error handling for conflicting annotation values
    - BlockAttrs function that uses the merging logic to combine multiple
      T.block_attr() calls within the same block
    
    The feature allows users to specify block attributes incrementally using
    multiple T.block_attr() calls, which will be automatically merged together.
---
 src/script/ir_builder/tir/ir.cc                    | 41 ++++++++++++++++++--
 .../tvmscript/test_tvmscript_error_report.py       | 17 +++++----
 .../python/tvmscript/test_tvmscript_parser_tir.py  | 44 ++++++++++++++++++++++
 3 files changed, 91 insertions(+), 11 deletions(-)

diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc
index 6f73254ff2..831dbcdd4a 100644
--- a/src/script/ir_builder/tir/ir.cc
+++ b/src/script/ir_builder/tir/ir.cc
@@ -219,12 +219,47 @@ void Writes(Array<ObjectRef> buffer_slices) {
   frame->writes = writes;
 }
 
+/*! \brief Recursively merge two annotations, the new attrs will override the 
old ones */
+Map<String, Any> MergeAnnotations(const Map<String, Any>& new_attrs,
+                                  const Map<String, Any>& old_attrs) {
+  Map<String, Any> result = old_attrs;
+  for (const auto& [key, value] : new_attrs) {
+    auto old_value = old_attrs.Get(key);
+    // Case 1: the key is not in the old annotations, set the key to the new 
value
+    if (!old_value) {
+      result.Set(key, value);
+      continue;
+    }
+
+    // Case 2: the key is in the old annotations
+    // Case 2.1: both are dicts
+    auto old_dict = old_value->try_cast<Map<String, Any>>();
+    auto new_dict = value.try_cast<Map<String, Any>>();
+    if (old_dict && new_dict) {
+      // Recursively merge the two dicts
+      auto merged_dict = MergeAnnotations(*old_dict, *new_dict);
+      result.Set(key, merged_dict);
+      continue;
+    }
+    // Case 2.2: the values are not both dicts, check if the keys are the same
+    if (!ffi::AnyEqual()(old_value.value(), value)) {
+      LOG(FATAL) << "ValueError: Try to merge two annotations with different 
values for key `"
+                 << key << "`, previous one is " << 
old_value->cast<ObjectRef>() << ", new one is "
+                 << value.cast<ObjectRef>();
+    }
+  }
+  return result;
+}
+
 void BlockAttrs(Map<String, Any> attrs) {
   BlockFrame frame = FindBlockFrame("T.block_attr");
-  if (frame->annotations.defined()) {
-    LOG(FATAL) << "ValueError: Duplicate block annotations, previous one is " << frame->annotations;
+  // Case 1: first T.block_attr call on this block, store `attrs` as its annotations
+  if (!frame->annotations.defined()) {
+    frame->annotations = attrs;
+  } else {
+    // Case 2: annotations already set by an earlier call, merge `attrs` into them
+    frame->annotations = MergeAnnotations(attrs, frame->annotations.value());
   }
-  frame->annotations = attrs;
 }
 
 Buffer AllocBuffer(Array<PrimExpr> shape, DataType dtype, Optional<Var> data,
diff --git a/tests/python/tvmscript/test_tvmscript_error_report.py b/tests/python/tvmscript/test_tvmscript_error_report.py
index d8212d3885..1cbd6af961 100644
--- a/tests/python/tvmscript/test_tvmscript_error_report.py
+++ b/tests/python/tvmscript/test_tvmscript_error_report.py
@@ -280,13 +280,6 @@ def test_duplicate_block_signature():
                 T.where(1)
                 T.where(0)  # error
 
-    def duplicate_annotations() -> None:
-        for i, j in T.grid(16, 16):
-            with T.block():
-                vi, vj = T.axis.remap("SS", [i, j])
-                T.block_attr({})
-                T.block_attr({})  # error
-
     def duplicate_init() -> None:
         for i, j in T.grid(16, 16):
             with T.block():
@@ -303,12 +296,20 @@ def test_duplicate_block_signature():
                 vi = T.axis.S(i, 16)  # error
                 T.evaluate(1.0)
 
+    def duplicate_block_attrs_with_same_key_diff_value() -> None:
+        for i, j in T.grid(16, 16):
+            with T.block():
+                vi, vj = T.axis.remap("SS", [i, j])
+                T.block_attr({"key1": "block1"})
+                T.block_attr({"key1": "block2"})  # error
+                T.evaluate(1.0)
+
     check_error(duplicate_reads, 7)
     check_error(duplicate_writes, 7)
     check_error(duplicate_predicate, 6)
-    check_error(duplicate_annotations, 6)
     check_error(duplicate_init, 7)
     check_error(duplicate_axes, 5)
+    check_error(duplicate_block_attrs_with_same_key_diff_value, 6)
 
 
 def test_opaque_access_during_complete():
diff --git a/tests/python/tvmscript/test_tvmscript_parser_tir.py b/tests/python/tvmscript/test_tvmscript_parser_tir.py
index 16b2067514..d5ee2e0772 100644
--- a/tests/python/tvmscript/test_tvmscript_parser_tir.py
+++ b/tests/python/tvmscript/test_tvmscript_parser_tir.py
@@ -544,5 +544,49 @@ def test_deterministic_branch():
     tvm.ir.assert_structural_equal(create_func(False), create_expected(1))
 
 
+def test_block_annotation_merge():  # several T.block_attr calls in one block are merged
+    def _to_dict(anno: tvm.ffi.container.Map):  # recursively convert nested ffi Maps to dicts
+        result = {}
+        for k, v in anno.items():
+            result[k] = _to_dict(v) if isinstance(v, tvm.ffi.container.Map) else v
+        return result
+
+    @T.prim_func
+    def func0():
+        with T.block():
+            T.block_attr({"key1": "block1"})
+            T.block_attr({"key2": "block2"})  # disjoint keys: both are kept
+            T.evaluate(0)
+
+    assert _to_dict(func0.body.block.annotations) == {"key1": "block1", "key2": "block2"}
+
+    @T.prim_func
+    def func1():
+        with T.block():
+            T.block_attr({"key": {"key1": "block1"}})
+            T.block_attr({"key": {"key2": "block2"}})  # same key, dict values: merged recursively
+            T.evaluate(0)
+
+    assert _to_dict(func1.body.block.annotations) == {"key": {"key1": "block1", "key2": "block2"}}
+
+    @T.prim_func
+    def func2():
+        with T.block():
+            T.block_attr({"key1": "block1"})
+            T.block_attr({"key1": "block1"})  # same key, equal value: accepted
+            T.evaluate(0)
+
+    assert _to_dict(func2.body.block.annotations) == {"key1": "block1"}
+
+    with pytest.raises(tvm.TVMError):  # same key, different non-dict values: rejected
+
+        @T.prim_func
+        def func3():
+            with T.block():
+                T.block_attr({"key1": "block1"})
+                T.block_attr({"key1": "block2"})
+                T.evaluate(0)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to