This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 9dad95df06 [TIR] Fix block access region detection for nested let bindings (#18069)
9dad95df06 is described below

commit 9dad95df06f4826581ef12cceab341ab0bea57ca
Author: Siyuan Feng <[email protected]>
AuthorDate: Tue Jun 17 20:24:32 2025 +0800

    [TIR] Fix block access region detection for nested let bindings (#18069)
    
    Recursively substitute let bindings in buffer indices until no more
    substitutions are possible. Add test case to verify handling of nested
    let bindings.
---
 src/tir/analysis/block_access_region_detector.cc   | 12 ++++++++--
 .../test_tir_analysis_get_block_access_region.py   | 26 ++++++++++++++++++++++
 2 files changed, 36 insertions(+), 2 deletions(-)

diff --git a/src/tir/analysis/block_access_region_detector.cc b/src/tir/analysis/block_access_region_detector.cc
index ce13ac56c8..8b3598e356 100644
--- a/src/tir/analysis/block_access_region_detector.cc
+++ b/src/tir/analysis/block_access_region_detector.cc
@@ -153,8 +153,12 @@ void BlockReadWriteDetector::VisitExpr_(const VarNode* op) { UpdateOpaque(GetRef
 
 void BlockReadWriteDetector::VisitExpr_(const BufferLoadNode* op) {
   std::vector<arith::IntSet> relaxed_region;
-  for (const PrimExpr& index : op->indices) {
+  for (PrimExpr index : op->indices) {
     PrimExpr remapped_index = Substitute(index, let_bindings_);
+    while (!remapped_index.same_as(index)) {
+      index = remapped_index;
+      remapped_index = Substitute(index, let_bindings_);
+    }
     relaxed_region.push_back(arith::EvalSet(arith::IntSet::Vector(remapped_index), dom_map_));
   }
   Update(&read_buffers_, &read_regions_, op->buffer, relaxed_region);
@@ -236,8 +240,12 @@ void BlockReadWriteDetector::VisitExpr_(const CallNode* op) {
 
 void BlockReadWriteDetector::VisitStmt_(const BufferStoreNode* op) {
   std::vector<arith::IntSet> relaxed_region;
-  for (const PrimExpr& index : op->indices) {
+  for (PrimExpr index : op->indices) {
     PrimExpr remapped_index = Substitute(index, let_bindings_);
+    while (!remapped_index.same_as(index)) {
+      index = remapped_index;
+      remapped_index = Substitute(index, let_bindings_);
+    }
     relaxed_region.push_back(arith::EvalSet(arith::IntSet::Vector(remapped_index), dom_map_));
   }
   Update(&writes_buffers_, &write_regions_, op->buffer, relaxed_region);
diff --git a/tests/python/tir-analysis/test_tir_analysis_get_block_access_region.py b/tests/python/tir-analysis/test_tir_analysis_get_block_access_region.py
index a65277df61..1fa013399e 100644
--- a/tests/python/tir-analysis/test_tir_analysis_get_block_access_region.py
+++ b/tests/python/tir-analysis/test_tir_analysis_get_block_access_region.py
@@ -385,5 +385,31 @@ def test_buffer_access_with_let_binding():
     tvm.ir.assert_structural_equal(block.writes, ret[1])
 
 
+def test_buffer_access_with_nested_let_binding():
+    @T.prim_func
+    def func(
+        A: T.Buffer((16, 16), "float32"),
+        B: T.Buffer((16, 16), "float32"),
+        C: T.Buffer((16, 16), "float32"),
+    ):
+        for i, s in T.grid(16, 16):
+            with T.block("copy"):
+                vi, vs = T.axis.remap("SS", [i, s])
+                T.reads(A[vi, vs], B[vi, vs])
+                T.writes(C[vi, vs])
+                vi1: T.int32 = vi
+                vi2: T.int32 = vi1
+                vs1: T.int32 = vs
+                vs2: T.int32 = vs1
+                vs3: T.int32 = vs2
+                C[vi, vs1] = A[vi1, vs2] + B[vi2, vs3]
+
+    block = func.body.block.body.body.body.block
+    buffer_var_map = {buf.data: buf for buf in func.buffer_map.values()}
+    ret = tir.analysis.get_block_access_region(block, buffer_var_map)
+    tvm.ir.assert_structural_equal(block.reads, ret[0])
+    tvm.ir.assert_structural_equal(block.writes, ret[1])
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to