This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 7214f52 [TIR] Fix opaque access in buffer locator pass and
match_buffer in region detector (#8855)
7214f52 is described below
commit 7214f5239dbb8da4585d4d10fbc8c65c8f155b12
Author: Siyuan Feng <[email protected]>
AuthorDate: Sat Aug 28 17:23:43 2021 +0800
[TIR] Fix opaque access in buffer locator pass and match_buffer in region
detector (#8855)
* init
* fix
* Update src/tir/transforms/plan_update_buffer_allocation_location.cc
Co-authored-by: Ruihang Lai <[email protected]>
* Update src/tir/transforms/plan_update_buffer_allocation_location.cc
Co-authored-by: Ruihang Lai <[email protected]>
* address
Co-authored-by: Junru Shao <[email protected]>
Co-authored-by: Ruihang Lai <[email protected]>
---
src/tir/analysis/block_access_region_detector.cc | 7 ++-
.../plan_update_buffer_allocation_location.cc | 39 +++++++++-----
.../test_tir_analysis_get_block_access_region.py | 21 +++++---
...sform_plan_update_buffer_allocation_location.py | 62 ++++++++++++++++++++++
4 files changed, 109 insertions(+), 20 deletions(-)
diff --git a/src/tir/analysis/block_access_region_detector.cc b/src/tir/analysis/block_access_region_detector.cc
index 8f87ef9..dd01aed 100644
--- a/src/tir/analysis/block_access_region_detector.cc
+++ b/src/tir/analysis/block_access_region_detector.cc
@@ -110,8 +110,11 @@ void BlockReadWriteDetector::operator()(const Stmt& stmt) {
ICHECK(block != nullptr) << "Only visiting Blocks is allowed, but got " <<
stmt->GetTypeKey();
for (const MatchBufferRegion& match_buffer : block->match_buffers) {
const Var& target_var = match_buffer->buffer->data;
- match_buffers_[target_var.get()] = match_buffer;
- buffer_var_map_.Set(target_var, match_buffer->buffer);
+ const Var& source_var = match_buffer->source->buffer->data;
+ if (buffer_var_map_.find(source_var) != buffer_var_map_.end()) {
+ match_buffers_[target_var.get()] = match_buffer;
+ buffer_var_map_.Set(target_var, match_buffer->buffer);
+ }
}
StmtExprVisitor::operator()(stmt);
}
diff --git a/src/tir/transforms/plan_update_buffer_allocation_location.cc b/src/tir/transforms/plan_update_buffer_allocation_location.cc
index bee11ad..59f9170 100644
--- a/src/tir/transforms/plan_update_buffer_allocation_location.cc
+++ b/src/tir/transforms/plan_update_buffer_allocation_location.cc
@@ -75,8 +75,6 @@ class BufferAllocationLocator : public StmtExprMutator {
Stmt VisitStmt_(const BlockNode* op) final {
ICHECK(!op->init.defined());
- bool is_root = is_root_;
- is_root_ = false;
Array<Buffer> alloc_buffers;
auto it = alloc_buffers_.find(op);
if (it != alloc_buffers_.end()) {
@@ -85,11 +83,23 @@ class BufferAllocationLocator : public StmtExprMutator {
buffer_data_to_buffer_.Set(buf->data, buf);
}
}
+ for (const MatchBufferRegion match_buffer : op->match_buffers) {
+ const Var& target_var = match_buffer->buffer->data;
+ const Var& source_var = match_buffer->source->buffer->data;
+ ICHECK(buffer_data_to_buffer_.count(source_var));
+ buffer_data_to_buffer_.Set(target_var, match_buffer->buffer);
+ }
Stmt stmt = StmtMutator::VisitStmt_(op);
op = stmt.as<BlockNode>();
ICHECK(op != nullptr);
- // Ignore buffer allocated inside the block when getting access region.
+ // No longer consider buffers created by match_buffer inside the block when updating access
+ // region.
+ for (const MatchBufferRegion match_buffer : op->match_buffers) {
+ const Var& target_var = match_buffer->buffer->data;
+ buffer_data_to_buffer_.erase(target_var);
+ }
+ // No longer consider buffers allocated inside the block when updating access region.
if (it != alloc_buffers_.end()) {
for (const Buffer& buf : it->second) {
buffer_data_to_buffer_.erase(buf->data);
@@ -98,12 +108,9 @@ class BufferAllocationLocator : public StmtExprMutator {
ObjectPtr<BlockNode> n = CopyOnWrite(op);
n->alloc_buffers = std::move(alloc_buffers);
- // The read/write regions of root block are always empty.
- if (!is_root) {
- // Recalculate block access region
- CollectReadWrite(GetRef<Block>(op), &n->reads, &n->writes);
- }
-
+ // Erase buffer allocated inside the block from access region.
+ n->reads = RemoveRedundantBufferRegion(n->reads);
+ n->writes = RemoveRedundantBufferRegion(n->writes);
return Stmt(n);
}
@@ -127,8 +134,18 @@ class BufferAllocationLocator : public StmtExprMutator {
return std::move(realize);
}
+ Array<BufferRegion> RemoveRedundantBufferRegion(const Array<BufferRegion>&
region) const {
+ Array<BufferRegion> result;
+ for (const BufferRegion& buffer_region : region) {
+ if (buffer_data_to_buffer_.count(buffer_region->buffer->data)) {
+ result.push_back(buffer_region);
+ }
+ }
+ return result;
+ }
+
void CollectReadWrite(const Block& block, Array<BufferRegion>* reads,
- Array<BufferRegion>* writes) {
+ Array<BufferRegion>* writes) const {
Array<Array<BufferRegion>> access = GetBlockAccessRegion(block,
buffer_data_to_buffer_);
*reads = access[0];
*writes = access[1];
@@ -142,8 +159,6 @@ class BufferAllocationLocator : public StmtExprMutator {
std::unordered_map<const StmtNode*, Array<Buffer>> alloc_buffers_;
/*! \brief The buffer already allocated during recursive visiting. */
Map<Var, Buffer> buffer_data_to_buffer_;
- /*! \brief indicate the whether the block is root. */
- bool is_root_{true};
};
PrimFunc PlanAndUpdateBufferAllocationLocation(PrimFunc func) {
diff --git a/tests/python/unittest/test_tir_analysis_get_block_access_region.py b/tests/python/unittest/test_tir_analysis_get_block_access_region.py
index 7641f0a..9c95b98 100644
--- a/tests/python/unittest/test_tir_analysis_get_block_access_region.py
+++ b/tests/python/unittest/test_tir_analysis_get_block_access_region.py
@@ -114,20 +114,29 @@ def test_match_buffer():
root_block = match_buffer_func.body.block
block = root_block.body.body.body.block
block_inner = block.body[0].body.body.block
- alloc_buffers = func.body.block.alloc_buffers
+ alloc_buffers = match_buffer_func.body.block.alloc_buffers
buffer_var_map = {buf.data: buf for buf in alloc_buffers}
- # Check inner block AAA
- ret = tir.analysis.get_block_access_region(block_inner, buffer_var_map)
- tvm.ir.assert_structural_equal(block_inner.reads, ret[0])
- tvm.ir.assert_structural_equal(block_inner.writes, ret[1])
-
# Check block
ret = tir.analysis.get_block_access_region(block, buffer_var_map)
tvm.ir.assert_structural_equal(block.writes, ret[1])
# B is opaque access
tvm.ir.assert_structural_equal(block.reads, ret[2])
+ # Check inner block AAA without updating buffer_var_map
+ ret = tir.analysis.get_block_access_region(block_inner, buffer_var_map)
+ # Since AA is not in the buffer_var_map, region of AA will not be collected.
+ tvm.ir.assert_structural_equal([], ret[1])
+
+ # Check inner block AAA
+ for match_buffer in block.match_buffers:
+ target_buffer = match_buffer.buffer
+ buffer_var_map[target_buffer.data] = target_buffer
+
+ ret = tir.analysis.get_block_access_region(block_inner, buffer_var_map)
+ tvm.ir.assert_structural_equal(block_inner.reads, ret[0])
+ tvm.ir.assert_structural_equal(block_inner.writes, ret[1])
+
if __name__ == "__main__":
test_block_access_region_detector()
diff --git a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py
index 8418e19..07140ab 100644
--- a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py
+++ b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py
@@ -137,6 +137,63 @@ def transformed_match_buffer_func() -> None:
C1[()] = 0
+@tvm.script.tir
+def opaque_access(a: ty.handle, b: ty.handle) -> None:
+ A = tir.match_buffer(a, [1024])
+ B = tir.match_buffer(b, [1024])
+ A_cache = tir.alloc_buffer([1024])
+ for i in tir.serial(0, 8):
+ with tir.block([8]) as [vi]:
+ with tir.block([8]) as [v]:
+ tir.bind(v, vi)
+ tir.reads([A[(v * 128) : ((v * 128) + 128)]])
+ tir.writes([A_cache[(v * 128) : ((v * 128) + 128)]])
+ tir.evaluate(
+ tir.call_extern(
+ "test",
+ A_cache.data,
+ (v * 128),
+ 128,
+ A.data,
+ (v * 128),
+ 128,
+ dtype="float32",
+ )
+ )
+ for j in tir.serial(0, 128):
+ with tir.block([1024]) as [v]:
+ tir.bind(v, ((vi * 128) + j))
+ tir.reads([A_cache[v]])
+ tir.writes([B[v]])
+ B[v] = A_cache[v]
+
+
+@tvm.script.tir
+def transformed_opaque_access(a: ty.handle, b: ty.handle) -> None:
+ A = tir.match_buffer(a, [1024])
+ B = tir.match_buffer(b, [1024])
+ for i in tir.serial(0, 8):
+ with tir.block([8]) as [vi]:
+ tir.reads(A[vi * 128 : vi * 128 + 128])
+ tir.writes(B[vi * 128 : vi * 128 + 128])
+ A_cache = tir.alloc_buffer([1024])
+ with tir.block([8]) as [v]:
+ tir.bind(v, vi)
+ tir.reads([A[v * 128 : v * 128 + 128]])
+ tir.writes([A_cache[v * 128 : v * 128 + 128]])
+ tir.evaluate(
+ tir.call_extern(
+ "test", A_cache.data, v * 128, 128, A.data, v * 128,
128, dtype="float32"
+ )
+ )
+ for j in tir.serial(0, 128):
+ with tir.block([1024]) as [v]:
+ tir.bind(v, ((vi * 128) + j))
+ tir.reads([A_cache[v]])
+ tir.writes([B[v]])
+ B[v] = A_cache[v]
+
+
def test_elementwise():
_check(element_func, transformed_element_func)
@@ -149,6 +206,10 @@ def test_match_buffer_allocation():
_check(match_buffer_func, transformed_match_buffer_func)
+def test_opaque_access():
+ _check(opaque_access, transformed_opaque_access)
+
+
def test_lower_te():
x = te.placeholder((1,))
y = te.compute((1,), lambda i: x[i] + 2)
@@ -164,4 +225,5 @@ if __name__ == "__main__":
test_elementwise()
test_locate_buffer_allocation()
test_match_buffer_allocation()
+ test_opaque_access()
test_lower_te()