This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6365a302d1 [TIR] Fix reduce buffer allocation position (#17799)
6365a302d1 is described below

commit 6365a302d179f13109b01a25b640e0250523ad03
Author: wrongtest <[email protected]>
AuthorDate: Thu Apr 3 09:05:00 2025 +0800

    [TIR] Fix reduce buffer allocation position (#17799)
    
    * fix reduce buffer allocation position
    
    * fix test_tir_analysis_detect_buffer_access_lca.py::test_buffer_load_store
---
 src/tir/analysis/buffer_access_lca_detector.cc     | 91 +++++++++++++++-------
 .../test_tir_analysis_detect_buffer_access_lca.py  |  7 +-
 ...sform_plan_update_buffer_allocation_location.py | 50 ++++++++++++
 3 files changed, 115 insertions(+), 33 deletions(-)

diff --git a/src/tir/analysis/buffer_access_lca_detector.cc b/src/tir/analysis/buffer_access_lca_detector.cc
index ff0b11a73c..dd1fce0fbe 100644
--- a/src/tir/analysis/buffer_access_lca_detector.cc
+++ b/src/tir/analysis/buffer_access_lca_detector.cc
@@ -117,10 +117,13 @@ class LCADetector : public StmtExprVisitor {
 
     ancestor_scopes_.push_back(current_scope);
 
-    // For each accessed buffer of the block, update the buffer's lca to
+    // For each accessed buffer of the block
+    // If it accesses the opaque block iter vars, update the buffer's lca to
     // the lowest inclusive stmt position, which should dominate all loops
-    // related to the accessed opaque block iter vars in buffer indices.
-    UpdateDominateScopeOfOpaqueIter(op);
+    // related to the accessed opaque block iter vars.
+    // If it is the reduction block write buffer, update the buffer's lca to
+    // dominate all reduction iter var related loops.
+    UpdateDominateScopeOfNonDataParIter(op);
 
     // Update match_buffers
     for (const MatchBufferRegion& match_buffer : block->match_buffers) {
@@ -132,43 +135,70 @@ class LCADetector : public StmtExprVisitor {
     ancestor_scopes_.pop_back();
   }
 
-  void UpdateDominateScopeOfOpaqueIter(const BlockRealizeNode* block_realize) {
-    // map opaque iter var to the scope which dominate all loop carried dependencies.
-    std::unordered_map<const VarNode*, const ScopeInfo*> itervar_to_dom_scope;
+  void UpdateDominateScopeOfNonDataParIter(const BlockRealizeNode* block_realize) {
+    // map iter var to the scope which dominate all loop carried dependencies.
+    std::unordered_map<const VarNode*, const ScopeInfo*> opaque_var_scope;
+    // maintain highest scope which dominate all reduce loop iters. null denotes non-reduce block.
+    const ScopeInfo* highest_reduce_scope = nullptr;
 
     // function to collect `itervar_to_dom_scope`, the result scope for each block
     // iter var should be above all loop scopes the opaque iter var binding relates to.
-    auto do_collect_itervar_scope = [this, &itervar_to_dom_scope](const IterVar& itervar,
-                                                                  const PrimExpr& binding) {
-      PostOrderVisit(binding, [this, &itervar_to_dom_scope, &itervar](const ObjectRef& obj) {
+    auto do_collect_itervar_scope = [this](const IterVar& itervar,
+                                           const PrimExpr& binding) -> const ScopeInfo* {
+      const ScopeInfo* highest_scope = nullptr;
+      PostOrderVisit(binding, [this, &itervar, &highest_scope](const ObjectRef& obj) {
         if (const VarNode* loop_var = obj.as<VarNode>()) {
           auto it = loop_scope_map_.find(loop_var);
           if (it == loop_scope_map_.end()) {
             return;
           }
           const ScopeInfo* scope = it->second->parent_scope_info;
-          // find the highest loop scope the iter var binding has related to.
-          auto dom_scope_it = itervar_to_dom_scope.find(itervar->var.get());
-          if (dom_scope_it == itervar_to_dom_scope.end()) {
-            itervar_to_dom_scope.insert(dom_scope_it, {itervar->var.get(), scope});
-          } else if (scope->depth < dom_scope_it->second->depth) {
-            dom_scope_it->second = scope;
+          if (highest_scope == nullptr) {
+            highest_scope = scope;
+          } else if (scope->depth < highest_scope->depth) {
+            highest_scope = scope;
           }
         }
       });
+      return highest_scope;
     };
 
+    // collect non-data-parallel block iteration's dominate scope.
+    // for reduction iter type, we maintain the highest dominate scope for all reduce iters.
+    // for other iter type, we maintain the dict for each individual iter.
+    const Block& block = block_realize->block;
+    bool is_reduce_block = false;
+    for (size_t i = 0; i < block_realize->iter_values.size(); ++i) {
+      const IterVar& iter_var = block->iter_vars[i];
+      if (iter_var->iter_type != IterVarType::kDataPar) {
+        const auto* scope = do_collect_itervar_scope(iter_var, block_realize->iter_values[i]);
+        if (scope == nullptr) continue;
+        if (iter_var->iter_type == IterVarType::kCommReduce) {
+          is_reduce_block = true;
+          if (highest_reduce_scope == nullptr || scope->depth < highest_reduce_scope->depth) {
+            highest_reduce_scope = scope;
+          }
+        } else {
+          opaque_var_scope[iter_var->var.get()] = scope;
+          for (const auto& write : block->writes) {
+            UpdateBufferLCA(write->buffer.get(), scope);
+          }
+        }
+      }
+    }
+
     // function to update lca scope of the buffer with loop carried dependent buffer accesses.
     // the result scope should be above all loop scopes the accessed opaque block iter vars
     // relate to, which is record in `itervar_to_dom_scope`.
-    auto do_update = [this, &itervar_to_dom_scope](const BufferRegion& region) {
+    auto do_update = [this, &opaque_var_scope, highest_reduce_scope](const BufferRegion& region,
+                                                                     bool is_reduce_write = false) {
       const Buffer& buffer = region->buffer;
       const ScopeInfo* scope = ancestor_scopes_.back();
 
-      auto handle_itervar = [&itervar_to_dom_scope, &scope](const ObjectRef& obj) {
+      auto handle_itervar = [&opaque_var_scope, &scope](const ObjectRef& obj) {
         if (const VarNode* iter_var = obj.as<VarNode>()) {
-          auto dom_scope_it = itervar_to_dom_scope.find(iter_var);
-          if (dom_scope_it == itervar_to_dom_scope.end()) {
+          auto dom_scope_it = opaque_var_scope.find(iter_var);
+          if (dom_scope_it == opaque_var_scope.end()) {
             return;
           }
           // find the highest loop scope the accessed buffer index has
@@ -184,24 +214,25 @@ class LCADetector : public StmtExprVisitor {
         PostOrderVisit(range->min, handle_itervar);
         PostOrderVisit(range->min + range->extent - 1, handle_itervar);
       }
+
+      // the scope should be above `highest_reduce_scope` for reduce output buffer.
+      if (is_reduce_write && highest_reduce_scope != nullptr &&
+          scope->depth > highest_reduce_scope->depth) {
+        scope = highest_reduce_scope;
+      }
       UpdateBufferLCA(buffer.get(), scope);
     };
 
-    // do collect and update
-    const Block& block = block_realize->block;
-    for (size_t i = 0; i < block_realize->iter_values.size(); ++i) {
-      const IterVar& iter_var = block->iter_vars[i];
-      if (iter_var->iter_type != IterVarType::kDataPar &&
-          iter_var->iter_type != IterVarType::kCommReduce) {
-        do_collect_itervar_scope(iter_var, block_realize->iter_values[i]);
-      }
-    }
-    if (!itervar_to_dom_scope.empty()) {
+    if (!opaque_var_scope.empty()) {
       for (const auto& read : block->reads) {
         do_update(read);
       }
       for (const auto& write : block->writes) {
-        do_update(write);
+        do_update(write, /*is_reduce_write=*/is_reduce_block);
+      }
+    } else if (is_reduce_block && highest_reduce_scope != nullptr) {
+      for (const auto& write : block->writes) {
+        do_update(write, /*is_reduce_write=*/true);
       }
     }
   }
diff --git a/tests/python/tir-analysis/test_tir_analysis_detect_buffer_access_lca.py b/tests/python/tir-analysis/test_tir_analysis_detect_buffer_access_lca.py
index a1808c8413..b3ce7efd05 100644
--- a/tests/python/tir-analysis/test_tir_analysis_detect_buffer_access_lca.py
+++ b/tests/python/tir-analysis/test_tir_analysis_detect_buffer_access_lca.py
@@ -116,9 +116,10 @@ def test_buffer_load_store():
     root_block = func.body.block
     assert lca[A] == func.body.block
 
-    # LCA of Buffer B is reduction block
-    reduce_block = root_block.body[1].body.body.body.block
-    assert lca[B] == reduce_block
+    # LCA of Buffer B is the loop dominate all reduction loop
+    reduce_dom_loop = root_block.body[1].body
+    reduce_block = reduce_dom_loop.body.body.block
+    assert lca[B] == reduce_dom_loop
 
     # LCA of Buffer C is the second loop kk
     loop_jj = reduce_block.body.body
diff --git a/tests/python/tir-transform/test_tir_transform_plan_update_buffer_allocation_location.py b/tests/python/tir-transform/test_tir_transform_plan_update_buffer_allocation_location.py
index 8500f11461..ff3fa8cf70 100644
--- a/tests/python/tir-transform/test_tir_transform_plan_update_buffer_allocation_location.py
+++ b/tests/python/tir-transform/test_tir_transform_plan_update_buffer_allocation_location.py
@@ -402,5 +402,55 @@ def test_dltensor_buffer_is_unlowered():
     _check(before, after)
 
 
+def test_reduce_buffer_dominate_reduce_loops():
+    """Reduction write buffer allocation should dominate all reduce loops"""
+
+    @T.prim_func
+    def before(x: T.Buffer((256, 256, 256), "float32"), x_red: T.Buffer((256, 256), "float32")):
+        x_red_ = T.alloc_buffer((256, 256))
+        for ax0_0, k1_0, ax1_0 in T.grid(4, 4, 4):
+            for ax0_1, k1_1, ax1_1 in T.grid(64, 64, 64):
+                with T.block("x_red"):
+                    v_ax0 = T.axis.spatial(256, ax0_0 * 64 + ax0_1)
+                    v_ax1 = T.axis.spatial(256, ax1_0 * 64 + ax1_1)
+                    v_k1 = T.axis.reduce(256, k1_0 * 64 + k1_1)
+                    if v_k1 == 0:
+                        x_red_[v_ax0, v_ax1] = T.float32(0.0)
+                    x_red_[v_ax0, v_ax1] = x_red_[v_ax0, v_ax1] + x[v_ax0, v_k1, v_ax1]
+            for ax0, ax1 in T.grid(64, 64):
+                with T.block("x_red_"):
+                    v0 = T.axis.spatial(256, ax0_0 * 64 + ax0)
+                    v1 = T.axis.spatial(256, ax1_0 * 64 + ax1)
+                    x_red[v0, v1] = x_red_[v0, v1]
+
+    @T.prim_func
+    def after(x: T.Buffer((256, 256, 256), "float32"), x_red: T.Buffer((256, 256), "float32")):
+        for ax0_0 in range(4):
+            with T.block(""):
+                T.reads(x[ax0_0 * 64 : ax0_0 * 64 + 64, 0:256, 0:256])
+                T.writes(x_red[ax0_0 * 64 : ax0_0 * 64 + 64, 0:256])
+                x_red_ = T.alloc_buffer((256, 256))
+                for k1_0, ax1_0 in T.grid(4, 4):
+                    for ax0_1, k1_1, ax1_1 in T.grid(64, 64, 64):
+                        with T.block("x_red"):
+                            v_ax0 = T.axis.spatial(256, ax0_0 * 64 + ax0_1)
+                            v_ax1 = T.axis.spatial(256, ax1_0 * 64 + ax1_1)
+                            v_k1 = T.axis.reduce(256, k1_0 * 64 + k1_1)
+                            T.reads(x_red_[v_ax0, v_ax1], x[v_ax0, v_k1, v_ax1])
+                            T.writes(x_red_[v_ax0, v_ax1])
+                            if v_k1 == 0:
+                                x_red_[v_ax0, v_ax1] = T.float32(0.0)
+                            x_red_[v_ax0, v_ax1] = x_red_[v_ax0, v_ax1] + x[v_ax0, v_k1, v_ax1]
+                    for ax0, ax1 in T.grid(64, 64):
+                        with T.block("x_red_"):
+                            v0 = T.axis.spatial(256, ax0_0 * 64 + ax0)
+                            v1 = T.axis.spatial(256, ax1_0 * 64 + ax1)
+                            T.reads(x_red_[v0, v1])
+                            T.writes(x_red[v0, v1])
+                            x_red[v0, v1] = x_red_[v0, v1]
+
+    _check(before, after)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to