This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 9dad95df06 [TIR] Fix block access region detection for nested let
bindings (#18069)
9dad95df06 is described below
commit 9dad95df06f4826581ef12cceab341ab0bea57ca
Author: Siyuan Feng <[email protected]>
AuthorDate: Tue Jun 17 20:24:32 2025 +0800
[TIR] Fix block access region detection for nested let bindings (#18069)
Repeatedly substitute let bindings in buffer indices until no more
substitutions are possible, so that chains of nested let bindings are fully
resolved. Add a test case to verify handling of nested let bindings.
---
src/tir/analysis/block_access_region_detector.cc | 12 ++++++++--
.../test_tir_analysis_get_block_access_region.py | 26 ++++++++++++++++++++++
2 files changed, 36 insertions(+), 2 deletions(-)
diff --git a/src/tir/analysis/block_access_region_detector.cc b/src/tir/analysis/block_access_region_detector.cc
index ce13ac56c8..8b3598e356 100644
--- a/src/tir/analysis/block_access_region_detector.cc
+++ b/src/tir/analysis/block_access_region_detector.cc
@@ -153,8 +153,12 @@ void BlockReadWriteDetector::VisitExpr_(const VarNode* op) { UpdateOpaque(GetRef
void BlockReadWriteDetector::VisitExpr_(const BufferLoadNode* op) {
std::vector<arith::IntSet> relaxed_region;
- for (const PrimExpr& index : op->indices) {
+ for (PrimExpr index : op->indices) {
PrimExpr remapped_index = Substitute(index, let_bindings_);
+ while (!remapped_index.same_as(index)) {
+ index = remapped_index;
+ remapped_index = Substitute(index, let_bindings_);
+ }
relaxed_region.push_back(arith::EvalSet(arith::IntSet::Vector(remapped_index),
dom_map_));
}
Update(&read_buffers_, &read_regions_, op->buffer, relaxed_region);
@@ -236,8 +240,12 @@ void BlockReadWriteDetector::VisitExpr_(const CallNode* op) {
void BlockReadWriteDetector::VisitStmt_(const BufferStoreNode* op) {
std::vector<arith::IntSet> relaxed_region;
- for (const PrimExpr& index : op->indices) {
+ for (PrimExpr index : op->indices) {
PrimExpr remapped_index = Substitute(index, let_bindings_);
+ while (!remapped_index.same_as(index)) {
+ index = remapped_index;
+ remapped_index = Substitute(index, let_bindings_);
+ }
relaxed_region.push_back(arith::EvalSet(arith::IntSet::Vector(remapped_index),
dom_map_));
}
Update(&writes_buffers_, &write_regions_, op->buffer, relaxed_region);
diff --git a/tests/python/tir-analysis/test_tir_analysis_get_block_access_region.py b/tests/python/tir-analysis/test_tir_analysis_get_block_access_region.py
index a65277df61..1fa013399e 100644
--- a/tests/python/tir-analysis/test_tir_analysis_get_block_access_region.py
+++ b/tests/python/tir-analysis/test_tir_analysis_get_block_access_region.py
@@ -385,5 +385,31 @@ def test_buffer_access_with_let_binding():
tvm.ir.assert_structural_equal(block.writes, ret[1])
+def test_buffer_access_with_nested_let_binding():
+ @T.prim_func
+ def func(
+ A: T.Buffer((16, 16), "float32"),
+ B: T.Buffer((16, 16), "float32"),
+ C: T.Buffer((16, 16), "float32"),
+ ):
+ for i, s in T.grid(16, 16):
+ with T.block("copy"):
+ vi, vs = T.axis.remap("SS", [i, s])
+ T.reads(A[vi, vs], B[vi, vs])
+ T.writes(C[vi, vs])
+ vi1: T.int32 = vi
+ vi2: T.int32 = vi1
+ vs1: T.int32 = vs
+ vs2: T.int32 = vs1
+ vs3: T.int32 = vs2
+ C[vi, vs1] = A[vi1, vs2] + B[vi2, vs3]
+
+ block = func.body.block.body.body.body.block
+ buffer_var_map = {buf.data: buf for buf in func.buffer_map.values()}
+ ret = tir.analysis.get_block_access_region(block, buffer_var_map)
+ tvm.ir.assert_structural_equal(block.reads, ret[0])
+ tvm.ir.assert_structural_equal(block.writes, ret[1])
+
+
if __name__ == "__main__":
tvm.testing.main()