This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 73a62f647f [TIR] Preserve AllocateNode::annotations (#15242)
73a62f647f is described below

commit 73a62f647f19d914763d26f9d25fca325134a8df
Author: Eric Lunderberg <[email protected]>
AuthorDate: Thu Jul 6 08:04:17 2023 -0500

    [TIR] Preserve AllocateNode::annotations (#15242)
    
    Prior to this commit, some lowering passes would erroneously strip out
    the annotations from `Allocate` nodes.  This commit updates these
    passes to preserve the annotations where present.
---
 src/tir/transforms/inject_double_buffer.cc         | 4 ++--
 src/tir/transforms/ir_utils.cc                     | 3 ++-
 src/tir/transforms/lower_custom_datatypes.cc       | 2 +-
 src/tir/transforms/lower_thread_allreduce.cc       | 2 +-
 src/tir/transforms/lower_warp_memory.cc            | 2 +-
 src/tir/transforms/update_pointer_storage_scope.cc | 9 ++++++---
 6 files changed, 13 insertions(+), 9 deletions(-)

diff --git a/src/tir/transforms/inject_double_buffer.cc b/src/tir/transforms/inject_double_buffer.cc
index 88188425a9..4e2e79db26 100644
--- a/src/tir/transforms/inject_double_buffer.cc
+++ b/src/tir/transforms/inject_double_buffer.cc
@@ -119,8 +119,8 @@ class DoubleBufferInjector : public StmtExprMutator {
       Array<PrimExpr> new_extents = {op->extents[0] * make_const(op->extents[0].dtype(), 2)};
       ICHECK(entry.loop != nullptr);
       auto& alloc_nest = loop_allocs_[entry.loop];
-      alloc_nest.emplace_back(
-          Allocate(op->buffer_var, op->dtype, new_extents, op->condition, Evaluate(0)));
+      alloc_nest.emplace_back(Allocate(op->buffer_var, op->dtype, new_extents, op->condition,
+                                       Evaluate(0), op->annotations));
       Stmt body = op->body;
       if (auto ptr = body.as<DeclBufferNode>()) {
         auto new_buf = GetRemappedBuffer(ptr->buffer, entry.stride);
diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc
index 43bf6b983e..99ed437659 100644
--- a/src/tir/transforms/ir_utils.cc
+++ b/src/tir/transforms/ir_utils.cc
@@ -335,7 +335,8 @@ class IRConvertSSA final : public StmtExprMutator {
       ScopedRedefine redefine(this, v);
       Stmt stmt = StmtExprMutator::VisitStmt_(op);
       op = stmt.as<AllocateNode>();
-      return Allocate(redefine.new_var, op->dtype, op->extents, op->condition, op->body);
+      return Allocate(redefine.new_var, op->dtype, op->extents, op->condition, op->body,
+                      op->annotations);
     } else {
       defined_.insert(v.get());
       return StmtExprMutator::VisitStmt_(op);
diff --git a/src/tir/transforms/lower_custom_datatypes.cc b/src/tir/transforms/lower_custom_datatypes.cc
index c5bcda2eff..273d37829d 100644
--- a/src/tir/transforms/lower_custom_datatypes.cc
+++ b/src/tir/transforms/lower_custom_datatypes.cc
@@ -97,7 +97,7 @@ class CustomDatatypesLowerer : public StmtExprMutator {
       allocate = stmt.as<AllocateNode>();
 
       return Allocate(new_buffer_var, new_allocate_type, allocate->extents, allocate->condition,
-                      allocate->body);
+                      allocate->body, allocate->annotations);
     } else {
       return StmtExprMutator::VisitStmt_(allocate);
     }
diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc
index f6cda51f43..c1566936c5 100644
--- a/src/tir/transforms/lower_thread_allreduce.cc
+++ b/src/tir/transforms/lower_thread_allreduce.cc
@@ -53,7 +53,7 @@ class UpdatePointerStorageScopeAllReduce final : public UpdatePointerStorageScop
         // use volatile access to shared buffer.
         body = AttrStmt(remapped, attr::volatile_scope, 1, body);
       }
-      return Allocate(remapped, op->dtype, op->extents, op->condition, body);
+      return Allocate(remapped, op->dtype, op->extents, op->condition, body, op->annotations);
     }
     return StmtExprMutator::VisitStmt_(op);
   }
diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc
index 571f512bfd..8702359546 100644
--- a/src/tir/transforms/lower_warp_memory.cc
+++ b/src/tir/transforms/lower_warp_memory.cc
@@ -249,7 +249,7 @@ class WarpAccessRewriter : protected StmtExprMutator {
     alloc_size = warp_group_ * factor;
 
     return Allocate(op->buffer_var, op->dtype, {make_const(DataType::Int(32), alloc_size / width_)},
-                    op->condition, this->VisitStmt(op->body));
+                    op->condition, this->VisitStmt(op->body), op->annotations);
   }
 
  protected:
diff --git a/src/tir/transforms/update_pointer_storage_scope.cc b/src/tir/transforms/update_pointer_storage_scope.cc
index 18950bc199..2049487b4a 100644
--- a/src/tir/transforms/update_pointer_storage_scope.cc
+++ b/src/tir/transforms/update_pointer_storage_scope.cc
@@ -29,6 +29,7 @@
 #include <tvm/tir/transform.h>
 
 #include <unordered_map>
+#include <utility>
 
 #include "../../runtime/thread_storage_scope.h"
 #include "ir_utils.h"
@@ -59,9 +60,11 @@ PrimExpr UpdatePointerStorageScope::VisitExpr_(const VarNode* op) {
 }
 
 Stmt UpdatePointerStorageScope::VisitStmt_(const AllocateNode* op) {
-  auto remapped = Downcast<Var>(StmtExprMutator::VisitExpr(op->buffer_var));
-  return Allocate(remapped, op->dtype, op->extents, StmtExprMutator::VisitExpr(op->condition),
-                  StmtExprMutator::VisitStmt(op->body));
+  auto node = Downcast<Allocate>(StmtExprMutator::VisitStmt_(op));
+  if (auto it = new_var_remap_.find(node->buffer_var.get()); it != new_var_remap_.end()) {
+    node.CopyOnWrite()->buffer_var = it->second;
+  }
+  return std::move(node);
 }
 
 template <typename Node>

Reply via email to