This is an automated email from the ASF dual-hosted git repository.
lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 73a62f647f [TIR] Preserve AllocateNode::annotations (#15242)
73a62f647f is described below
commit 73a62f647f19d914763d26f9d25fca325134a8df
Author: Eric Lunderberg <[email protected]>
AuthorDate: Thu Jul 6 08:04:17 2023 -0500
[TIR] Preserve AllocateNode::annotations (#15242)
Prior to this commit, some lowering passes would erroneously strip out
the annotations from `Allocate` nodes. This commit updates these
passes to preserve the annotations where present.
---
src/tir/transforms/inject_double_buffer.cc | 4 ++--
src/tir/transforms/ir_utils.cc | 3 ++-
src/tir/transforms/lower_custom_datatypes.cc | 2 +-
src/tir/transforms/lower_thread_allreduce.cc | 2 +-
src/tir/transforms/lower_warp_memory.cc | 2 +-
src/tir/transforms/update_pointer_storage_scope.cc | 9 ++++++---
6 files changed, 13 insertions(+), 9 deletions(-)
diff --git a/src/tir/transforms/inject_double_buffer.cc b/src/tir/transforms/inject_double_buffer.cc
index 88188425a9..4e2e79db26 100644
--- a/src/tir/transforms/inject_double_buffer.cc
+++ b/src/tir/transforms/inject_double_buffer.cc
@@ -119,8 +119,8 @@ class DoubleBufferInjector : public StmtExprMutator {
Array<PrimExpr> new_extents = {op->extents[0] *
make_const(op->extents[0].dtype(), 2)};
ICHECK(entry.loop != nullptr);
auto& alloc_nest = loop_allocs_[entry.loop];
- alloc_nest.emplace_back(
- Allocate(op->buffer_var, op->dtype, new_extents, op->condition, Evaluate(0)));
+ alloc_nest.emplace_back(Allocate(op->buffer_var, op->dtype, new_extents, op->condition,
+ Evaluate(0), op->annotations));
Stmt body = op->body;
if (auto ptr = body.as<DeclBufferNode>()) {
auto new_buf = GetRemappedBuffer(ptr->buffer, entry.stride);
diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc
index 43bf6b983e..99ed437659 100644
--- a/src/tir/transforms/ir_utils.cc
+++ b/src/tir/transforms/ir_utils.cc
@@ -335,7 +335,8 @@ class IRConvertSSA final : public StmtExprMutator {
ScopedRedefine redefine(this, v);
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<AllocateNode>();
- return Allocate(redefine.new_var, op->dtype, op->extents, op->condition, op->body);
+ return Allocate(redefine.new_var, op->dtype, op->extents, op->condition, op->body,
+ op->annotations);
} else {
defined_.insert(v.get());
return StmtExprMutator::VisitStmt_(op);
diff --git a/src/tir/transforms/lower_custom_datatypes.cc b/src/tir/transforms/lower_custom_datatypes.cc
index c5bcda2eff..273d37829d 100644
--- a/src/tir/transforms/lower_custom_datatypes.cc
+++ b/src/tir/transforms/lower_custom_datatypes.cc
@@ -97,7 +97,7 @@ class CustomDatatypesLowerer : public StmtExprMutator {
allocate = stmt.as<AllocateNode>();
return Allocate(new_buffer_var, new_allocate_type, allocate->extents, allocate->condition,
- allocate->body);
+ allocate->body, allocate->annotations);
} else {
return StmtExprMutator::VisitStmt_(allocate);
}
diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc
index f6cda51f43..c1566936c5 100644
--- a/src/tir/transforms/lower_thread_allreduce.cc
+++ b/src/tir/transforms/lower_thread_allreduce.cc
@@ -53,7 +53,7 @@ class UpdatePointerStorageScopeAllReduce final : public UpdatePointerStorageScop
// use volatile access to shared buffer.
body = AttrStmt(remapped, attr::volatile_scope, 1, body);
}
- return Allocate(remapped, op->dtype, op->extents, op->condition, body);
+ return Allocate(remapped, op->dtype, op->extents, op->condition, body, op->annotations);
}
return StmtExprMutator::VisitStmt_(op);
}
diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc
index 571f512bfd..8702359546 100644
--- a/src/tir/transforms/lower_warp_memory.cc
+++ b/src/tir/transforms/lower_warp_memory.cc
@@ -249,7 +249,7 @@ class WarpAccessRewriter : protected StmtExprMutator {
alloc_size = warp_group_ * factor;
return Allocate(op->buffer_var, op->dtype, {make_const(DataType::Int(32), alloc_size / width_)},
- op->condition, this->VisitStmt(op->body));
+ op->condition, this->VisitStmt(op->body), op->annotations);
}
protected:
diff --git a/src/tir/transforms/update_pointer_storage_scope.cc b/src/tir/transforms/update_pointer_storage_scope.cc
index 18950bc199..2049487b4a 100644
--- a/src/tir/transforms/update_pointer_storage_scope.cc
+++ b/src/tir/transforms/update_pointer_storage_scope.cc
@@ -29,6 +29,7 @@
#include <tvm/tir/transform.h>
#include <unordered_map>
+#include <utility>
#include "../../runtime/thread_storage_scope.h"
#include "ir_utils.h"
@@ -59,9 +60,11 @@ PrimExpr UpdatePointerStorageScope::VisitExpr_(const VarNode* op) {
}
Stmt UpdatePointerStorageScope::VisitStmt_(const AllocateNode* op) {
- auto remapped = Downcast<Var>(StmtExprMutator::VisitExpr(op->buffer_var));
- return Allocate(remapped, op->dtype, op->extents, StmtExprMutator::VisitExpr(op->condition),
- StmtExprMutator::VisitStmt(op->body));
+ auto node = Downcast<Allocate>(StmtExprMutator::VisitStmt_(op));
+ if (auto it = new_var_remap_.find(node->buffer_var.get()); it != new_var_remap_.end()) {
+ node.CopyOnWrite()->buffer_var = it->second;
+ }
+ return std::move(node);
}
template <typename Node>