This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 6365a302d1 [TIR] Fix reduce buffer allocation position (#17799)
6365a302d1 is described below
commit 6365a302d179f13109b01a25b640e0250523ad03
Author: wrongtest <[email protected]>
AuthorDate: Thu Apr 3 09:05:00 2025 +0800
[TIR] Fix reduce buffer allocation position (#17799)
* fix reduce buffer allocation position
* fix test_tir_analysis_detect_buffer_access_lca.py::test_buffer_load_store
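For context, the effect can be checked through the public APIs this patch
touches. A minimal sketch (not part of the patch; `reduce_func` stands in
for any PrimFunc with split reduction loops, as in the new test below):

    import tvm
    # Map each buffer to the lowest common ancestor statement of all its
    # accesses. After this fix, a reduction write buffer maps to a loop
    # dominating every reduction loop, not just to the reduction block.
    lca = tvm.tir.analysis.detect_buffer_access_lca(reduce_func)
    # The allocation-planning pass then places T.alloc_buffer at that LCA.
    mod = tvm.tir.transform.PlanAndUpdateBufferAllocationLocation()(
        tvm.IRModule.from_expr(reduce_func)
    )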
---
src/tir/analysis/buffer_access_lca_detector.cc | 91 +++++++++++++++-------
.../test_tir_analysis_detect_buffer_access_lca.py | 7 +-
...sform_plan_update_buffer_allocation_location.py | 50 ++++++++++++
3 files changed, 115 insertions(+), 33 deletions(-)
diff --git a/src/tir/analysis/buffer_access_lca_detector.cc b/src/tir/analysis/buffer_access_lca_detector.cc
index ff0b11a73c..dd1fce0fbe 100644
--- a/src/tir/analysis/buffer_access_lca_detector.cc
+++ b/src/tir/analysis/buffer_access_lca_detector.cc
@@ -117,10 +117,13 @@ class LCADetector : public StmtExprVisitor {
ancestor_scopes_.push_back(current_scope);
- // For each accessed buffer of the block, update the buffer's lca to
+ // For each accessed buffer of the block:
+ // if it accesses opaque block iter vars, update the buffer's lca to
// the lowest inclusive stmt position, which should dominate all loops
- // related to the accessed opaque block iter vars in buffer indices.
- UpdateDominateScopeOfOpaqueIter(op);
+ // related to the accessed opaque block iter vars.
+ // If it is the reduction block's write buffer, update the buffer's lca to
+ // dominate all loops related to the reduction iter vars.
+ UpdateDominateScopeOfNonDataParIter(op);
// Update match_buffers
for (const MatchBufferRegion& match_buffer : block->match_buffers) {
@@ -132,43 +135,70 @@ class LCADetector : public StmtExprVisitor {
ancestor_scopes_.pop_back();
}
- void UpdateDominateScopeOfOpaqueIter(const BlockRealizeNode* block_realize) {
- // map opaque iter var to the scope which dominate all loop carried dependencies.
- std::unordered_map<const VarNode*, const ScopeInfo*> itervar_to_dom_scope;
+ void UpdateDominateScopeOfNonDataParIter(const BlockRealizeNode* block_realize) {
+ // map each iter var to the scope which dominates all loop-carried dependencies.
+ std::unordered_map<const VarNode*, const ScopeInfo*> opaque_var_scope;
+ // maintain the highest scope which dominates all reduce loop iters. null denotes a non-reduce block.
+ const ScopeInfo* highest_reduce_scope = nullptr;
// function to collect `itervar_to_dom_scope`, the result scope for each block
// iter var should be above all loop scopes the opaque iter var binding relates to.
- auto do_collect_itervar_scope = [this, &itervar_to_dom_scope](const IterVar& itervar,
- const PrimExpr& binding) {
- PostOrderVisit(binding, [this, &itervar_to_dom_scope, &itervar](const ObjectRef& obj) {
+ auto do_collect_itervar_scope = [this](const IterVar& itervar,
+ const PrimExpr& binding) -> const ScopeInfo* {
+ const ScopeInfo* highest_scope = nullptr;
+ PostOrderVisit(binding, [this, &itervar, &highest_scope](const ObjectRef& obj) {
if (const VarNode* loop_var = obj.as<VarNode>()) {
auto it = loop_scope_map_.find(loop_var);
if (it == loop_scope_map_.end()) {
return;
}
const ScopeInfo* scope = it->second->parent_scope_info;
- // find the highest loop scope the iter var binding has related to.
- auto dom_scope_it = itervar_to_dom_scope.find(itervar->var.get());
- if (dom_scope_it == itervar_to_dom_scope.end()) {
- itervar_to_dom_scope.insert(dom_scope_it, {itervar->var.get(), scope});
- } else if (scope->depth < dom_scope_it->second->depth) {
- dom_scope_it->second = scope;
+ if (highest_scope == nullptr) {
+ highest_scope = scope;
+ } else if (scope->depth < highest_scope->depth) {
+ highest_scope = scope;
}
}
});
+ return highest_scope;
};
+ // collect the dominating scope of each non-data-parallel block iter.
+ // for the reduction iter type, we maintain the highest dominating scope over all reduce iters.
+ // for other iter types, we maintain a dict entry for each individual iter.
+ const Block& block = block_realize->block;
+ bool is_reduce_block = false;
+ for (size_t i = 0; i < block_realize->iter_values.size(); ++i) {
+ const IterVar& iter_var = block->iter_vars[i];
+ if (iter_var->iter_type != IterVarType::kDataPar) {
+ const auto* scope = do_collect_itervar_scope(iter_var, block_realize->iter_values[i]);
+ if (scope == nullptr) continue;
+ if (iter_var->iter_type == IterVarType::kCommReduce) {
+ is_reduce_block = true;
+ if (highest_reduce_scope == nullptr || scope->depth < highest_reduce_scope->depth) {
+ highest_reduce_scope = scope;
+ }
+ } else {
+ opaque_var_scope[iter_var->var.get()] = scope;
+ for (const auto& write : block->writes) {
+ UpdateBufferLCA(write->buffer.get(), scope);
+ }
+ }
+ }
+ }
+
// function to update lca scope of the buffer with loop carried dependent buffer accesses.
// the result scope should be above all loop scopes the accessed opaque block iter vars
// relate to, which is record in `itervar_to_dom_scope`.
- auto do_update = [this, &itervar_to_dom_scope](const BufferRegion& region) {
+ auto do_update = [this, &opaque_var_scope, highest_reduce_scope](const BufferRegion& region,
+ bool is_reduce_write = false) {
const Buffer& buffer = region->buffer;
const ScopeInfo* scope = ancestor_scopes_.back();
- auto handle_itervar = [&itervar_to_dom_scope, &scope](const ObjectRef& obj) {
+ auto handle_itervar = [&opaque_var_scope, &scope](const ObjectRef& obj) {
if (const VarNode* iter_var = obj.as<VarNode>()) {
- auto dom_scope_it = itervar_to_dom_scope.find(iter_var);
- if (dom_scope_it == itervar_to_dom_scope.end()) {
+ auto dom_scope_it = opaque_var_scope.find(iter_var);
+ if (dom_scope_it == opaque_var_scope.end()) {
return;
}
// find the highest loop scope the accessed buffer index has
@@ -184,24 +214,25 @@ class LCADetector : public StmtExprVisitor {
PostOrderVisit(range->min, handle_itervar);
PostOrderVisit(range->min + range->extent - 1, handle_itervar);
}
+
+ // the scope should be above `highest_reduce_scope` for a reduce output buffer.
+ if (is_reduce_write && highest_reduce_scope != nullptr &&
+ scope->depth > highest_reduce_scope->depth) {
+ scope = highest_reduce_scope;
+ }
UpdateBufferLCA(buffer.get(), scope);
};
- // do collect and update
- const Block& block = block_realize->block;
- for (size_t i = 0; i < block_realize->iter_values.size(); ++i) {
- const IterVar& iter_var = block->iter_vars[i];
- if (iter_var->iter_type != IterVarType::kDataPar &&
- iter_var->iter_type != IterVarType::kCommReduce) {
- do_collect_itervar_scope(iter_var, block_realize->iter_values[i]);
- }
- }
- if (!itervar_to_dom_scope.empty()) {
+ if (!opaque_var_scope.empty()) {
for (const auto& read : block->reads) {
do_update(read);
}
for (const auto& write : block->writes) {
- do_update(write);
+ do_update(write, /*is_reduce_write=*/is_reduce_block);
+ }
+ } else if (is_reduce_block && highest_reduce_scope != nullptr) {
+ for (const auto& write : block->writes) {
+ do_update(write, /*is_reduce_write=*/true);
}
}
}
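In terms of the new test below: the reduce iter v_k1 binds to loops k1_0 and
k1_1, the highest scope dominating both is the ax0_0 loop, so the write
buffer x_red_ is hoisted there. A minimal Python model of the clamping rule
(hypothetical names, scopes reduced to (name, depth) pairs where a smaller
depth means closer to the root):

    # Mini-model of the depth comparison in the C++ above.
    def plan_reduce_write_scope(current_scope, reduce_iter_scopes):
        highest = None
        for s in reduce_iter_scopes:
            # keep the highest (smallest-depth) scope any reduce iter touches
            if highest is None or s[1] < highest[1]:
                highest = s
        # the reduce write buffer may not live deeper than that scope
        if highest is not None and current_scope[1] > highest[1]:
            return highest
        return current_scope

    # e.g. plan_reduce_write_scope(("ax1_1", 3), [("ax0_0", 0), ("k1_0", 1)])
    # returns ("ax0_0", 0): the allocation moves above all reduce loops.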
diff --git a/tests/python/tir-analysis/test_tir_analysis_detect_buffer_access_lca.py b/tests/python/tir-analysis/test_tir_analysis_detect_buffer_access_lca.py
index a1808c8413..b3ce7efd05 100644
--- a/tests/python/tir-analysis/test_tir_analysis_detect_buffer_access_lca.py
+++ b/tests/python/tir-analysis/test_tir_analysis_detect_buffer_access_lca.py
@@ -116,9 +116,10 @@ def test_buffer_load_store():
root_block = func.body.block
assert lca[A] == func.body.block
- # LCA of Buffer B is reduction block
- reduce_block = root_block.body[1].body.body.body.block
- assert lca[B] == reduce_block
+ # LCA of Buffer B is the loop dominating all reduction loops
+ reduce_dom_loop = root_block.body[1].body
+ reduce_block = reduce_dom_loop.body.body.block
+ assert lca[B] == reduce_dom_loop
# LCA of Buffer C is the second loop kk
loop_jj = reduce_block.body.body
diff --git a/tests/python/tir-transform/test_tir_transform_plan_update_buffer_allocation_location.py b/tests/python/tir-transform/test_tir_transform_plan_update_buffer_allocation_location.py
index 8500f11461..ff3fa8cf70 100644
--- a/tests/python/tir-transform/test_tir_transform_plan_update_buffer_allocation_location.py
+++ b/tests/python/tir-transform/test_tir_transform_plan_update_buffer_allocation_location.py
@@ -402,5 +402,55 @@ def test_dltensor_buffer_is_unlowered():
_check(before, after)
+def test_reduce_buffer_dominate_reduce_loops():
+ """Reduction write buffer allocation should dominate all reduce loops"""
+
+ @T.prim_func
+ def before(x: T.Buffer((256, 256, 256), "float32"), x_red: T.Buffer((256, 256), "float32")):
+ x_red_ = T.alloc_buffer((256, 256))
+ for ax0_0, k1_0, ax1_0 in T.grid(4, 4, 4):
+ for ax0_1, k1_1, ax1_1 in T.grid(64, 64, 64):
+ with T.block("x_red"):
+ v_ax0 = T.axis.spatial(256, ax0_0 * 64 + ax0_1)
+ v_ax1 = T.axis.spatial(256, ax1_0 * 64 + ax1_1)
+ v_k1 = T.axis.reduce(256, k1_0 * 64 + k1_1)
+ if v_k1 == 0:
+ x_red_[v_ax0, v_ax1] = T.float32(0.0)
+ x_red_[v_ax0, v_ax1] = x_red_[v_ax0, v_ax1] + x[v_ax0, v_k1, v_ax1]
+ for ax0, ax1 in T.grid(64, 64):
+ with T.block("x_red_"):
+ v0 = T.axis.spatial(256, ax0_0 * 64 + ax0)
+ v1 = T.axis.spatial(256, ax1_0 * 64 + ax1)
+ x_red[v0, v1] = x_red_[v0, v1]
+
+ @T.prim_func
+ def after(x: T.Buffer((256, 256, 256), "float32"), x_red: T.Buffer((256, 256), "float32")):
+ for ax0_0 in range(4):
+ with T.block(""):
+ T.reads(x[ax0_0 * 64 : ax0_0 * 64 + 64, 0:256, 0:256])
+ T.writes(x_red[ax0_0 * 64 : ax0_0 * 64 + 64, 0:256])
+ x_red_ = T.alloc_buffer((256, 256))
+ for k1_0, ax1_0 in T.grid(4, 4):
+ for ax0_1, k1_1, ax1_1 in T.grid(64, 64, 64):
+ with T.block("x_red"):
+ v_ax0 = T.axis.spatial(256, ax0_0 * 64 + ax0_1)
+ v_ax1 = T.axis.spatial(256, ax1_0 * 64 + ax1_1)
+ v_k1 = T.axis.reduce(256, k1_0 * 64 + k1_1)
+ T.reads(x_red_[v_ax0, v_ax1], x[v_ax0, v_k1, v_ax1])
+ T.writes(x_red_[v_ax0, v_ax1])
+ if v_k1 == 0:
+ x_red_[v_ax0, v_ax1] = T.float32(0.0)
+ x_red_[v_ax0, v_ax1] = x_red_[v_ax0, v_ax1] + x[v_ax0, v_k1, v_ax1]
+ for ax0, ax1 in T.grid(64, 64):
+ with T.block("x_red_"):
+ v0 = T.axis.spatial(256, ax0_0 * 64 + ax0)
+ v1 = T.axis.spatial(256, ax1_0 * 64 + ax1)
+ T.reads(x_red_[v0, v1])
+ T.writes(x_red[v0, v1])
+ x_red[v0, v1] = x_red_[v0, v1]
+
+ _check(before, after)
+
+
if __name__ == "__main__":
tvm.testing.main()
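The new case can be run on its own with a standard pytest invocation (path
and test name as in the diff above):

    pytest "tests/python/tir-transform/test_tir_transform_plan_update_buffer_allocation_location.py::test_reduce_buffer_dominate_reduce_loops"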