This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b94119692e [TIR] Ignore Allocate/AllocateConst in BufferAllocationLocator (#10998)
b94119692e is described below
commit b94119692eaa7307201fbad3e3434f8721c50ede
Author: Eric Lunderberg <[email protected]>
AuthorDate: Thu Apr 14 15:18:09 2022 -0500
[TIR] Ignore Allocate/AllocateConst in BufferAllocationLocator (#10998)
* [TIR] Ignore Allocate/AllocateConst in BufferAllocationLocator
Prior to this commit, the BufferAllocationLocator mutator used in the
PlanAndUpdateBufferAllocationLocation pass would erroneously insert an
entry into `BlockNode::alloc_buffers` for buffers allocated using
`Allocate` or `AllocateConst` nodes. This error was introduced in
https://github.com/apache/tvm/pull/9727, which deprecated `Load` and
`Store` nodes, replacing them with `BufferLoad` and `BufferStore`
nodes. Because accesses to these allocations are now expressed through
`BufferLoad`/`BufferStore` nodes, which reference a `Buffer` object,
BufferAllocationLocator identified them as buffers whose allocations
should be moved to inner loops, rather than as unmanaged allocations
that should be ignored.
This commit restores the earlier behavior by operating only on buffers
allocated through `BlockNode::alloc_buffers`, and by explicitly ignoring
any buffers allocated by `Allocate` or `AllocateConst` nodes.
* Only inject an opaque block if managed buffers exist.
Previously, every buffer found was a managed buffer, so this check
was not needed.
---
.../plan_update_buffer_allocation_location.cc | 33 ++++++++++++++++------
.../test_tir_transform_extract_constants.py | 2 ++
2 files changed, 27 insertions(+), 8 deletions(-)
diff --git a/src/tir/transforms/plan_update_buffer_allocation_location.cc
b/src/tir/transforms/plan_update_buffer_allocation_location.cc
index 6b495b3bf4..81dfceb40d 100644
--- a/src/tir/transforms/plan_update_buffer_allocation_location.cc
+++ b/src/tir/transforms/plan_update_buffer_allocation_location.cc
@@ -61,16 +61,21 @@ class BufferAllocationLocator : public StmtExprMutator {
for (const Buffer& buf : it->second) {
buffer_data_to_buffer_.Set(buf->data, buf);
}
- Stmt stmt = StmtMutator::VisitStmt_(op);
- op = stmt.as<ForNode>();
- ICHECK(op != nullptr);
+ auto node = Downcast<For>(StmtMutator::VisitStmt_(op));
+
+ Array<Buffer> new_block_alloc_bufs;
for (const Buffer& buf : it->second) {
- buffer_data_to_buffer_.erase(buf->data);
+ if (!unmanaged_allocations_.count(buf->data.get())) {
+ buffer_data_to_buffer_.erase(buf->data);
+ new_block_alloc_bufs.push_back(buf);
+ }
}
- Stmt body = InjectOpaqueBlock(op->body, it->second);
- ObjectPtr<ForNode> n = CopyOnWrite(op);
- n->body = std::move(body);
- return Stmt(n);
+
+ if (new_block_alloc_bufs.size()) {
+ node.CopyOnWrite()->body = InjectOpaqueBlock(node->body,
new_block_alloc_bufs);
+ }
+
+ return std::move(node);
}
Stmt VisitStmt_(const BlockNode* op) final {
@@ -114,6 +119,16 @@ class BufferAllocationLocator : public StmtExprMutator {
return Stmt(n);
}
+ Stmt VisitStmt_(const AllocateNode* op) final {
+ unmanaged_allocations_.insert(op->buffer_var.get());
+ return StmtExprMutator::VisitStmt_(op);
+ }
+
+ Stmt VisitStmt_(const AllocateConstNode* op) final {
+ unmanaged_allocations_.insert(op->buffer_var.get());
+ return StmtExprMutator::VisitStmt_(op);
+ }
+
Stmt VisitStmt_(const BufferRealizeNode* op) final {
ICHECK(false) << "Internal Error: BufferRealizeNode is not allowed in
TensorIR.";
throw;
@@ -151,6 +166,8 @@ class BufferAllocationLocator : public StmtExprMutator {
std::unordered_map<const StmtNode*, Array<Buffer>> alloc_buffers_;
/*! \brief The buffer already allocated during recursive visiting. */
Map<Var, Buffer> buffer_data_to_buffer_;
+ /*! \brief Buffers that are allocated outside of the BlockNode, and should
not be moved. */
+ std::unordered_set<const VarNode*> unmanaged_allocations_;
};
PrimFunc PlanAndUpdateBufferAllocationLocation(PrimFunc func) {
diff --git a/tests/python/unittest/test_tir_transform_extract_constants.py
b/tests/python/unittest/test_tir_transform_extract_constants.py
index 9636a9bdde..cb49e7286f 100644
--- a/tests/python/unittest/test_tir_transform_extract_constants.py
+++ b/tests/python/unittest/test_tir_transform_extract_constants.py
@@ -59,6 +59,8 @@ def test_const_extraction():
for n, f in mod.functions.items():
tvm.tir.stmt_functor.post_order_visit(f.body, _visit)
+ tvm.lower(mod)
+
if __name__ == "__main__":
test_const_extraction()