This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit ded4be4b39636c145fde4eeeafdc6f95800027eb
Author: LeiWang <[email protected]>
AuthorDate: Wed Jan 3 12:25:58 2024 -0400

    enhance shared memory merge.
---
 include/tvm/tir/transform.h                        |   4 +-
 python/tvm/tir/transform/transform.py              |   4 +-
 src/driver/driver_api.cc                           |   3 +-
 src/meta_schedule/postproc/verify_gpu_code.cc      |   2 +-
 ...tions.cc => merge_shared_memory_allocations.cc} | 115 ++++++++----
 src/tir/transforms/storage_rewrite.cc              |  19 +-
 .../test_tir_transform_inject_ptx_async_copy.py    |   2 +-
 ...form_merge_dynamic_shared_memory_allocations.py |  15 +-
 ...sform_merge_static_shared_memory_allocations.py | 203 +++++++++++++++++++++
 9 files changed, 310 insertions(+), 57 deletions(-)

diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index f5fdd191af..45e44f96a6 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -614,9 +614,9 @@ TVM_DLL Pass InstallDebugSpans();
 TVM_DLL Pass UnifyThreadBinding();
 
 /*!
- *  A pass to merge multiple TIR-level dynamic shared memory allocations into 
one
+ *  A pass to merge multiple TIR-level shared memory allocations into one
  */
-TVM_DLL Pass MergeDynamicSharedMemoryAllocations();
+TVM_DLL Pass MergeSharedMemoryAllocations();
 
 /*!
  * \brief This pass is post-scheduling pass to convert all
diff --git a/python/tvm/tir/transform/transform.py 
b/python/tvm/tir/transform/transform.py
index c58062045c..de61823cdf 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -1000,7 +1000,7 @@ def UnifyThreadBinding():
     return _ffi_api.UnifyThreadBinding()  # type: ignore
 
 
-def MergeDynamicSharedMemoryAllocations():
-    """This pass merges multiple TIR-level dynamic shared memory allocations
+def MergeSharedMemoryAllocations():
+    """This pass merges multiple TIR-level shared memory allocations
     into one allocation.
 
@@ -1009,7 +1009,7 @@ def MergeDynamicSharedMemoryAllocations():
     fpass : tvm.transform.Pass
         The result pass
     """
-    return _ffi_api.MergeDynamicSharedMemoryAllocations()  # type: ignore
+    return _ffi_api.MergeSharedMemoryAllocations()  # type: ignore
 
 
 def ConvertForLoopsToSerial():
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index b7ba0ffe44..17cd5c49a1 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -52,6 +52,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array<Array<ObjectRef>>);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.debug_keep_trivial_loop", Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.use_async_copy", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.merge_static_smem", Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_lwp", Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.vtcm_capacity", Integer);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.ptx_ldg32", Bool);
@@ -584,7 +585,7 @@ transform::Sequential MixedModulePassManager(IRModule 
mixed_mod, Target target)
 
   mixed_pass_list.push_back(tir::transform::ThreadSync("shared"));
   mixed_pass_list.push_back(tir::transform::ThreadSync("shared.dyn"));
-  
mixed_pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations());
+  mixed_pass_list.push_back(tir::transform::MergeSharedMemoryAllocations());
   mixed_pass_list.push_back(tir::transform::ThreadSync("warp"));
   mixed_pass_list.push_back(tir::transform::InferFragment());
   mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce());
diff --git a/src/meta_schedule/postproc/verify_gpu_code.cc 
b/src/meta_schedule/postproc/verify_gpu_code.cc
index 2fb97d32eb..1728306354 100644
--- a/src/meta_schedule/postproc/verify_gpu_code.cc
+++ b/src/meta_schedule/postproc/verify_gpu_code.cc
@@ -179,7 +179,7 @@ class VerifyGPUCodeNode : public PostprocNode {
           pass_list.push_back(tir::transform::InjectVirtualThread());
           pass_list.push_back(tir::transform::InjectDoubleBuffer());
           pass_list.push_back(tir::transform::StorageRewrite());
-          
pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations());
+          pass_list.push_back(tir::transform::MergeSharedMemoryAllocations());
           pass_list.push_back(tir::transform::LowerIntrin());
           // Convert Function to IRModule
           transform::PassContext pass_ctx = transform::PassContext::Current();
diff --git a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc 
b/src/tir/transforms/merge_shared_memory_allocations.cc
similarity index 82%
rename from src/tir/transforms/merge_dynamic_shared_memory_allocations.cc
rename to src/tir/transforms/merge_shared_memory_allocations.cc
index 99055cebf2..c0e7cf1ebe 100644
--- a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc
+++ b/src/tir/transforms/merge_shared_memory_allocations.cc
@@ -18,9 +18,9 @@
  */
 
 /*!
- * \file merge_dynamic_shared_memory_allocations.cc
- * \brief Each GPU kernel is allowed to have only one dynamic shared memory 
allocation.
- * This pass merges multiple TIR-level dynamic shared memory allocations into 
one allocation.
+ * \file merge_shared_memory_allocations.cc
+ * \brief Each GPU kernel is allowed to have only one dynamic or static shared 
memory allocation.
+ * This pass merges multiple TIR-level dynamic or static shared memory allocations into one allocation.
  */
 #include <tvm/runtime/registry.h>
 #include <tvm/tir/expr.h>
@@ -45,6 +45,11 @@ bool IsDynamicSharedMemory(Var buffer_var) {
   return storage_scope.rank == runtime::StorageRank::kShared && 
storage_scope.tag == ".dyn";
 }
 
+bool IsStaticSharedMemory(Var buffer_var) {
+  StorageScope storage_scope = 
runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
+  return storage_scope.rank == runtime::StorageRank::kShared && 
storage_scope.tag == "";
+}
+
 /*!
  * \brief collect the mapping from the buffer var to its allocate
  */
@@ -53,11 +58,15 @@ class AllocateCollector : public StmtExprVisitor {
   void VisitStmt_(const AllocateNode* op) final {
     if (IsDynamicSharedMemory(op->buffer_var)) {
       dyn_shmem_allocs_[op->buffer_var.get()] = op;
+    } else if (IsStaticSharedMemory(op->buffer_var)) {
+      static_shmem_allocs_[op->buffer_var.get()] = op;
     }
     StmtExprVisitor::VisitStmt_(op);
   }
-  // The mapping from the original buffer var to its allocate
+  // The mapping from each dynamic shared memory buffer var to its allocate
   std::unordered_map<const VarNode*, const AllocateNode*> dyn_shmem_allocs_;
+  // The mapping from each static shared memory buffer var to its allocate
+  std::unordered_map<const VarNode*, const AllocateNode*> static_shmem_allocs_;
 };
 
 // Find a linear pattern of storage access
@@ -73,8 +82,9 @@ class AllocateCollector : public StmtExprVisitor {
 // The storage need to be kept alive between Allocate and last access.
 // The free point is only inserted at the same scope of Allocate.
 //
-class DynSharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
+class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
  public:
+  explicit SharedMemLinearAccessPatternFinder(bool is_dynamic = true) : 
is_dynamic_(is_dynamic) {}
   /*! \brief record the touch list of statement. */
   struct StmtEntry {
     // The statement
@@ -112,7 +122,7 @@ class DynSharedMemLinearAccessPatternFinder final : public 
StmtExprVisitor {
     auto it = alloc_info_.find(buf);
     if (it != alloc_info_.end() && it->second.alloc) {
       ICHECK_LT(it->second.level, scope_.size());
-      if (IsDynamicSharedMemory(GetRef<Var>(buf))) {
+      if (IsAppropriateSharedMemory(GetRef<Var>(buf))) {
         scope_[it->second.level].touched.push_back(buf);
       }
     }
@@ -143,7 +153,7 @@ class DynSharedMemLinearAccessPatternFinder final : public 
StmtExprVisitor {
     auto it = alloc_info_.find(buf);
     if (it != alloc_info_.end() && it->second.alloc) {
       ICHECK_LT(it->second.level, scope_.size()) << "Load memory in places 
other than store.";
-      if (IsDynamicSharedMemory(GetRef<Var>(buf))) {
+      if (IsAppropriateSharedMemory(GetRef<Var>(buf))) {
         scope_[it->second.level].touched.push_back(buf);
       }
     }
@@ -164,7 +174,7 @@ class DynSharedMemLinearAccessPatternFinder final : public 
StmtExprVisitor {
     auto it = alloc_info_.find(buf);
     if (it != alloc_info_.end() && it->second.alloc) {
       ICHECK_LT(it->second.level, scope_.size());
-      if (IsDynamicSharedMemory(GetRef<Var>(buf))) {
+      if (IsAppropriateSharedMemory(GetRef<Var>(buf))) {
         scope_[it->second.level].touched.push_back(buf);
       }
     }
@@ -217,6 +227,12 @@ class DynSharedMemLinearAccessPatternFinder final : public 
StmtExprVisitor {
   std::unordered_map<const VarNode*, AllocEntry> alloc_info_;
 
  private:
+  // Wrapper function to determine if the shared memory allocation for a 
variable is appropriate.
+  bool IsAppropriateSharedMemory(const Var& var) {
+    return is_dynamic_ ? IsDynamicSharedMemory(var) : 
IsStaticSharedMemory(var);
+  }
+  // Whether to analyze dynamic (true) or static (false) shared memory.
+  bool is_dynamic_{true};
   // Whether already in thread env.
   bool in_thread_env_{false};
   // The scope stack.
@@ -226,18 +242,23 @@ class DynSharedMemLinearAccessPatternFinder final : 
public StmtExprVisitor {
 /*!
  * \brief merge the buffers whose live range has no intersection and rewrite 
the body
  */
-class DynamicSharedMemoryRewriter : public StmtExprMutator {
+class SharedMemoryRewriter : public StmtExprMutator {
  public:
-  explicit DynamicSharedMemoryRewriter(
-      const std::unordered_map<const VarNode*, const AllocateNode*>& 
dyn_shmem_allocs)
-      : dyn_shmem_allocs_{dyn_shmem_allocs} {}
+  explicit SharedMemoryRewriter(
+      const std::unordered_map<const VarNode*, const AllocateNode*>& 
shmem_allocs,
+      bool is_dynamic = true)
+      : is_dynamic_{is_dynamic}, shmem_allocs_{shmem_allocs} {
+    if (!is_dynamic) {
+      merged_buf_var_ = Var("buf_shmem", 
PointerType(PrimType(DataType::UInt(8)), "shared"));
+    }
+  }
 
   /*!
    * \brief plan the memory reuse for all the buffer allocated in the statement
    * \param stmt the statement
    */
-  void PlanReuse(const Stmt& stmt) {
-    DynSharedMemLinearAccessPatternFinder finder;
+  void PlanReuse(const Stmt& stmt, bool is_dynamic = true) {
+    SharedMemLinearAccessPatternFinder finder(is_dynamic);
     finder(stmt);
     this->LivenessAnalysis(finder.linear_seq_);
     this->PlanMemory(finder.linear_seq_);
@@ -263,7 +284,7 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator {
       for (const StorageEntry* e : all_entry) {
         for (int i = 0; i < static_cast<int>(e->allocs.size()); i++) {
           for (const VarNode* buffer : e->allocs[i]) {
-            const AllocateNode* alloc = dyn_shmem_allocs_[buffer];
+            const AllocateNode* alloc = shmem_allocs_[buffer];
             align[i] = std::max(align[i], alloc->dtype.bytes());
           }
         }
@@ -274,7 +295,7 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator {
         for (int i = 0; i < static_cast<int>(e->allocs.size()); i++) {
           PrimExpr inner_offset = 0;
           for (const VarNode* buffer : e->allocs[i]) {
-            const AllocateNode* alloc = dyn_shmem_allocs_[buffer];
+            const AllocateNode* alloc = shmem_allocs_[buffer];
             buffer_byte_offsets_[buffer] = merged_alloc_size_ + inner_offset;
             inner_offset += alloc->extents[0] * alloc->dtype.bytes();
             inner_offset += indexmod(align[i] - indexmod(inner_offset, 
align[i]), align[i]);
@@ -293,7 +314,7 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator {
   }
 
   Stmt VisitStmt_(const AllocateNode* op) final {
-    if (IsDynamicSharedMemory(op->buffer_var)) {
+    if (IsAppropriateSharedMemory(op->buffer_var)) {
       return StmtExprMutator::VisitStmt(op->body);
     }
     return StmtExprMutator::VisitStmt_(op);
@@ -319,9 +340,9 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator {
 
   template <typename Node>
   Node VisitBufferAccess(Node node) {
-    if (IsDynamicSharedMemory(node->buffer->data)) {
+    if (IsAppropriateSharedMemory(node->buffer->data)) {
       ICHECK_EQ(node->indices.size(), 1)
-          << "MergeDynamicSharedMemoryAllocations expects flat memory buffers, 
"
+          << "MergeSharedMemoryAllocations expects flat memory buffers, "
           << "and is to be run after "
           << "StorageFlatten (TE schedules) or FlattenBuffer (TIR schedules)";
       Array<PrimExpr> indices = {node->indices[0] +
@@ -342,10 +363,10 @@ class DynamicSharedMemoryRewriter : public 
StmtExprMutator {
       return it->second;
     }
 
-    if (IsDynamicSharedMemory(buffer->data)) {
+    if (IsAppropriateSharedMemory(buffer->data)) {
       ICHECK_EQ(buffer->shape.size(), 1)
           << "Buffer " << buffer << " has shape " << buffer->shape << ".  "
-          << "MergeDynamicSharedMemoryAllocations expects flat memory buffers, 
"
+          << "MergeSharedMemoryAllocations expects flat memory buffers, "
           << "and is to be run after "
           << "StorageFlatten (TE schedules) or FlattenBuffer (TIR schedules)";
       auto writer = buffer.CopyOnWrite();
@@ -361,7 +382,7 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator {
       ICHECK_EQ(op->args.size(), 5U);
       DataType dtype = op->args[0].dtype();
       Var buffer = Downcast<Var>(op->args[1]);
-      if (!IsDynamicSharedMemory(buffer)) {
+      if (!IsAppropriateSharedMemory(buffer)) {
         return StmtExprMutator::VisitExpr_(op);
       }
       PrimExpr extra_offset = GetBufferOffset(buffer, dtype);
@@ -381,7 +402,12 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator 
{
     return indexdiv(it->second, dtype.bytes());
   }
 
-  using StmtEntry = DynSharedMemLinearAccessPatternFinder::StmtEntry;
+  // Wrapper function to determine if the shared memory allocation for a 
variable is appropriate.
+  bool IsAppropriateSharedMemory(const Var& var) {
+    return is_dynamic_ ? IsDynamicSharedMemory(var) : 
IsStaticSharedMemory(var);
+  }
+
+  using StmtEntry = SharedMemLinearAccessPatternFinder::StmtEntry;
   struct StorageEntry {
     // The constant size of the buffer in bits, only used if it is constant
     uint64_t const_nbits{0};
@@ -447,9 +473,13 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator 
{
       // - leaf stmt(offset = 0)
       // - end of scope(offset < 0)
       // In both cases, we need to handle the kill event correctly
+      auto is_leaf_alloc = [&](const VarNode* var) {
+        return seq[i].scope_pair_offset == 0 &&
+               std::find(it->second.gen.begin(), it->second.gen.end(), var) != 
it->second.gen.end();
+      };
       if (it != event_map_.end() && seq[i].scope_pair_offset <= 0) {
         for (const VarNode* var : it->second.kill) {
-          this->Free(var);
+          if (!is_leaf_alloc(var)) this->Free(var);
         }
       }
       // scope_pair_offset >= 0 means it is either
@@ -458,12 +488,17 @@ class DynamicSharedMemoryRewriter : public 
StmtExprMutator {
       // In both cases, we need to handle the gen event correctly
       if (it != event_map_.end() && seq[i].scope_pair_offset >= 0) {
         for (const VarNode* var : it->second.gen) {
-          ICHECK(dyn_shmem_allocs_.count(var));
-          const AllocateNode* alloc = dyn_shmem_allocs_[var];
+          ICHECK(shmem_allocs_.count(var));
+          const AllocateNode* alloc = shmem_allocs_[var];
           StorageEntry* dst_entry = FindAlloc(alloc);
           alloc_map_[var] = dst_entry;
         }
       }
+      if (it != event_map_.end() && seq[i].scope_pair_offset <= 0) {
+        for (const VarNode* var : it->second.kill) {
+          if (is_leaf_alloc(var)) this->Free(var);
+        }
+      }
     }
   }
   /*!
@@ -510,6 +545,7 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator {
         StorageEntry* e = it->second;
         e->const_nbits = std::max(const_nbits, e->const_nbits);
         const_free_map_.erase(it);
+        e->allocs.push_back({op->buffer_var.get()});  // NOTE: 'it' is invalidated by the erase above; use 'e'
         return e;
       }
       // Then start looking at smaller buffers.
@@ -578,10 +614,12 @@ class DynamicSharedMemoryRewriter : public 
StmtExprMutator {
       sym_free_list_.push_back(e);
     }
   }
+  // Whether the rewriter handles dynamic (true) or static (false) shared memory.
+  bool is_dynamic_{true};
   // The var for the merged buffer
   Var merged_buf_var_{"buf_dyn_shmem", 
PointerType(PrimType(DataType::UInt(8)), "shared.dyn")};
   // The mapping from the original buffer var to its allocate
-  std::unordered_map<const VarNode*, const AllocateNode*> dyn_shmem_allocs_;
+  std::unordered_map<const VarNode*, const AllocateNode*> shmem_allocs_;
   // The size of the merged buffer
   PrimExpr merged_alloc_size_{0};
   // The mapping from the original buffer var to its offset in the merged 
buffer
@@ -602,30 +640,37 @@ class DynamicSharedMemoryRewriter : public 
StmtExprMutator {
   support::Arena arena_;
 };
 
-Stmt MergeDynamicSharedMemoryAllocations(Stmt stmt) {
+Stmt MergeSharedMemoryAllocations(Stmt stmt, bool merge_static_smem) {
   AllocateCollector collector;
   collector(stmt);
   if (collector.dyn_shmem_allocs_.size() > 1) {
-    DynamicSharedMemoryRewriter rewriter(collector.dyn_shmem_allocs_);
+    SharedMemoryRewriter rewriter(collector.dyn_shmem_allocs_);
     rewriter.PlanReuse(stmt);
-    return rewriter(std::move(stmt));
+    stmt = rewriter(std::move(stmt));
+  }
+  if (merge_static_smem && collector.static_shmem_allocs_.size() > 1) {
+    SharedMemoryRewriter rewriter(collector.static_shmem_allocs_, false);
+    rewriter.PlanReuse(stmt, false);
+    stmt = rewriter(std::move(stmt));
   }
   return stmt;
 }
 
 namespace transform {
 
-Pass MergeDynamicSharedMemoryAllocations() {
+Pass MergeSharedMemoryAllocations() {
   auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+    bool merge_static_smem =
+        ctx->GetConfig<Bool>("tir.merge_static_smem", Bool(false)).value();
     auto* n = f.CopyOnWrite();
-    n->body = MergeDynamicSharedMemoryAllocations(std::move(n->body));
+    n->body = MergeSharedMemoryAllocations(std::move(n->body), 
merge_static_smem);
     return f;
   };
-  return CreatePrimFuncPass(pass_func, 0, 
"tir.MergeDynamicSharedMemoryAllocations", {});
+  return CreatePrimFuncPass(pass_func, 0, "tir.MergeSharedMemoryAllocations", 
{});
 }
 
-TVM_REGISTER_GLOBAL("tir.transform.MergeDynamicSharedMemoryAllocations")
-    .set_body_typed(MergeDynamicSharedMemoryAllocations);
+TVM_REGISTER_GLOBAL("tir.transform.MergeSharedMemoryAllocations")
+    .set_body_typed(MergeSharedMemoryAllocations);
 
 }  // namespace transform
 }  // namespace tir
diff --git a/src/tir/transforms/storage_rewrite.cc 
b/src/tir/transforms/storage_rewrite.cc
index f271769c80..70f325e4a2 100644
--- a/src/tir/transforms/storage_rewrite.cc
+++ b/src/tir/transforms/storage_rewrite.cc
@@ -380,13 +380,13 @@ class StoragePlanRewriter : public StmtExprMutator {
   using StmtEntry = LinearAccessPatternFinder::StmtEntry;
   using AllocEntry = LinearAccessPatternFinder::AllocEntry;
 
-  Stmt Rewrite(Stmt stmt, bool detect_inplace) {
+  Stmt Rewrite(Stmt stmt, bool detect_inplace, bool enable_reuse = true) {
     detect_inplace_ = detect_inplace;
     // plan the rewrite
     LinearAccessPatternFinder finder;
     finder(stmt);
     this->LivenessAnalysis(finder.linear_seq_);
-    this->PlanMemory(finder.linear_seq_, finder.alloc_info_);
+    this->PlanMemory(finder.linear_seq_, finder.alloc_info_, enable_reuse);
     all_buffers_accessed_ = finder.all_buffers_accessed_;
     this->PrepareNewAlloc();
     // start rewrite
@@ -816,7 +816,8 @@ class StoragePlanRewriter : public StmtExprMutator {
 
   // Memory plan algorithm
   void PlanMemory(const std::vector<StmtEntry>& seq,
-                  const std::unordered_map<const VarNode*, AllocEntry>& 
alloc_info) {
+                  const std::unordered_map<const VarNode*, AllocEntry>& 
alloc_info,
+                  bool enable_reuse = true) {
     std::unordered_set<const VarNode*> inplace_flag;
 
     for (size_t i = 0; i < seq.size(); ++i) {
@@ -863,8 +864,8 @@ class StoragePlanRewriter : public StmtExprMutator {
             }
           }
           if (dst_entry == nullptr) {
-            dst_entry =
-                FindAlloc(alloc, thread_scope_, storage_scope, 
entry.num_physical_dimensions);
+            dst_entry = FindAlloc(alloc, thread_scope_, storage_scope,
+                                  entry.num_physical_dimensions, enable_reuse);
           }
           dst_entry->allocs.emplace_back(alloc);
           alloc_map_[var] = dst_entry;
@@ -917,7 +918,8 @@ class StoragePlanRewriter : public StmtExprMutator {
   }
 
   StorageEntry* FindAlloc(const AllocateNode* op, const Object* attach_scope,
-                          const StorageScope& scope, size_t 
num_physical_dimensions) {
+                          const StorageScope& scope, size_t 
num_physical_dimensions,
+                          bool enable_reuse = true) {
     ICHECK(op != nullptr);
     // skip plan for local variable,
     // compiler can do a better job with register allocation.
@@ -940,7 +942,7 @@ class StoragePlanRewriter : public StmtExprMutator {
         (scope.tag.length() == 0) && (scope.rank >= StorageRank::kWarp || 
op->dtype.is_handle() ||
                                       (is_known_size && const_nbits <= 32));
 
-    if (is_small_array || !is_flat_memory_space) {
+    if (!enable_reuse || is_small_array || !is_flat_memory_space) {
       return NewAlloc(op, attach_scope, scope, const_nbits);
     }
 
@@ -1702,8 +1704,9 @@ namespace transform {
 
 Pass StorageRewrite() {
   auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+    bool merge_static_smem = ctx->GetConfig<Bool>("tir.merge_static_smem", 
Bool(false)).value();
     auto* n = f.CopyOnWrite();
-    n->body = StoragePlanRewriter().Rewrite(std::move(n->body), true);
+    n->body = StoragePlanRewriter().Rewrite(std::move(n->body), true, 
!merge_static_smem);
     // Parameters may not be rewritten, but internal allocations may.
     // Vectorization of AllocateConst is currently disabled, as it has
     // indexing issues for types that include padding (e.g. int8x3
diff --git 
a/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py 
b/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py
index bf68200944..c52aca7674 100644
--- a/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py
+++ b/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py
@@ -161,7 +161,7 @@ def test_inject_async_copy_shared_dyn():
     mod = tvm.tir.transform.LowerOpaqueBlock()(mod)
     mod = tvm.tir.transform.FlattenBuffer()(mod)
     mod = tvm.tir.transform.VectorizeLoop()(mod)
-    mod = tvm.tir.transform.MergeDynamicSharedMemoryAllocations()(mod)
+    mod = tvm.tir.transform.MergeSharedMemoryAllocations()(mod)
     mod = tvm.tir.transform.InjectPTXAsyncCopy()(mod)
 
     assert count_cp_async(mod["main"].body) == 2
diff --git 
a/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py
 
b/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py
index 37372059a2..343ed1c10f 100644
--- 
a/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py
+++ 
b/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py
@@ -32,7 +32,7 @@ def run_passes(sch, args):
             tvm.tir.transform.Simplify(),
             tvm.tir.transform.VectorizeLoop(),
             tvm.tir.transform.StorageRewrite(),
-            tvm.tir.transform.MergeDynamicSharedMemoryAllocations(),
+            tvm.tir.transform.MergeSharedMemoryAllocations(),
         ]
     )(mod)
 
@@ -136,7 +136,7 @@ def test_matmul_dyn_shared():
         np_ref = np.dot(a_np.astype("float32"), b_np.astype("float32"))
         tvm.testing.assert_allclose(c.numpy(), np_ref, 1e-4, 1e-4)
 
-    for target in ["cuda", "nvptx"]:
+    for target in ["cuda", "nvptx"]:
         check_target(target)
 
 
@@ -201,7 +201,7 @@ def test_dyn_shared_vectorized_store():
                 c.numpy(), a.numpy().astype("float32") + b.numpy(), 1e-4, 1e-4
             )
 
-    for target in ["cuda", "nvptx"]:
+    for target in ["cuda", "nvptx"]:
         check_target(target)
 
 
@@ -266,7 +266,7 @@ def test_dyn_shared_reuse_and_merge():
         fadd(a, b, c, d)
         tvm.testing.assert_allclose(d.numpy(), a.numpy() + b.numpy() + 
c.numpy(), 1e-4, 1e-4)
 
-    for target in ["cuda", "nvptx"]:
+    for target in ["cuda", "nvptx"]:
         check_target(target)
 
 
@@ -323,7 +323,7 @@ def test_dyn_shared_more_dtype():
         fadd(a, b, c)
         tvm.testing.assert_allclose(c.numpy(), a.numpy().astype("float32") + 
b.numpy(), 1e-4, 1e-4)
 
-    for target in ["cuda", "nvptx"]:
+    for target in ["cuda", "nvptx"]:
         check_target(target)
 
 
@@ -336,7 +336,7 @@ class TestMatmul(tvm.testing.CompareBeforeAfter):
     for the replaced allocations.
     """
 
-    transform = tvm.tir.transform.MergeDynamicSharedMemoryAllocations()
+    transform = tvm.tir.transform.MergeSharedMemoryAllocations()
 
     use_decl_buffer = tvm.testing.parameter(by_dict={"t_buffer": False, 
"decl_buffer": True})
 
@@ -455,4 +455,4 @@
 
 
 if __name__ == "__main__":
-    tvm.testing.main()
+    tvm.testing.main()
diff --git 
a/tests/python/tir-transform/test_tir_transform_merge_static_shared_memory_allocations.py
 
b/tests/python/tir-transform/test_tir_transform_merge_static_shared_memory_allocations.py
new file mode 100644
index 0000000000..be32514a72
--- /dev/null
+++ 
b/tests/python/tir-transform/test_tir_transform_merge_static_shared_memory_allocations.py
@@ -0,0 +1,203 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import numpy as np
+
+import tvm
+import tvm.testing
+from tvm import te
+from tvm.driver.build_module import schedule_to_module
+from tvm.topi.math import cast
+from tvm.script import tir as T
+
+
+def run_passes(sch, args):
+    mod = schedule_to_module(sch, args)
+    with tvm.transform.PassContext(config={"tir.merge_static_smem": True}):
+        return tvm.transform.Sequential(
+            [
+                tvm.tir.transform.StorageFlatten(64),
+                tvm.tir.transform.Simplify(),
+                tvm.tir.transform.VectorizeLoop(),
+                tvm.tir.transform.StorageRewrite(),
+                tvm.tir.transform.MergeSharedMemoryAllocations(),
+            ]
+        )(mod)
+
+
+def verify_single_allocation(stmt, alloc_size=None):
+    num_alloc = [0]
+    alloc_extents = []
+
+    def verify(n):
+        if (
+            isinstance(n, tvm.tir.Allocate)
+            and n.buffer_var.type_annotation.storage_scope == "shared"
+        ):
+            num_alloc[0] += 1
+            alloc_extents.append(n.extents[0])
+
+    tvm.tir.stmt_functor.post_order_visit(stmt, verify)
+    assert num_alloc[0] == 1
+
+    if alloc_size:
+        assert alloc_extents[0] == alloc_size
+
+
[email protected]_gpu
+def test_matmul_shared():
+    n = 1024
+    block = 16
+    A = te.placeholder((n, n), name="A", dtype="float16")
+    B = te.placeholder((n, n), name="B", dtype="float16")
+
+    def syncthread():
+        return tvm.tir.Call(None, "tir.tvm_storage_sync", 
tvm.runtime.convert(["shared"]))
+
+    def test_matmul_ir(A, B, C):
+        ib = tvm.tir.ir_builder.create()
+
+        tx = te.thread_axis("threadIdx.x")
+        ty = te.thread_axis("threadIdx.y")
+        bx = te.thread_axis("blockIdx.x")
+        by = te.thread_axis("blockIdx.y")
+        ib.scope_attr(tx, "thread_extent", block)
+        ib.scope_attr(ty, "thread_extent", block)
+        ib.scope_attr(bx, "thread_extent", n // block)
+        ib.scope_attr(by, "thread_extent", n // block)
+
+        A_sh = ib.allocate(A.dtype, (block, block), scope="shared", 
name="A_sh")  # fp16
+        B_sh = ib.allocate(B.dtype, (block, block), scope="shared", 
name="B_sh")  # fp16
+        # Create a shared memory for the accumulation.
+        # This is for testing merging shared memory allocations with different data types.
+        # In practice, there is no need to allocate a shared memory for C.
+        C_local = ib.allocate(C.dtype, (1,), scope="local", name="C_local")
+        C_sh = ib.allocate(C.dtype, (block, block), scope="shared", 
name="C_sh")  # fp32
+
+        A_ptr = ib.buffer_ptr(A)
+        B_ptr = ib.buffer_ptr(B)
+        C_ptr = ib.buffer_ptr(C)
+
+        C_local[0] = 0.0
+
+        with ib.for_range(0, n // block, name="i") as i:
+            A_sh[ty, tx] = A_ptr[by * block + ty, i * block + tx]
+            B_sh[ty, tx] = B_ptr[i * block + ty, bx * block + tx]
+            ib.emit(syncthread())
+
+            with ib.for_range(0, block, name="k") as k:
+                C_local[0] += cast(A_sh[ty, k] * B_sh[k, tx], "float32")
+            ib.emit(syncthread())
+
+        C_sh[ty, tx] = C_local[0]
+        C_ptr[by * block + ty, bx * block + tx] = C_sh[ty, tx]
+
+        return ib.get()
+
+    C = te.extern(
+        A.shape,
+        [A, B],
+        lambda ins, outs: test_matmul_ir(ins[0], ins[1], outs[0]),
+        name="matmul",
+        dtype="float32",
+    )
+    s = te.create_schedule(C.op)
+    mod = run_passes(s, [A, B, C])
+    # C_sh (fp32) can reuse the region of A_sh, so the merged buffer only needs 2 * block * block float16 elements, i.e. block * block * 4 bytes
+    expected_alloc_size = block * block * 4
+    verify_single_allocation(mod["main"].body, expected_alloc_size)
+
+    def check_target(target):
+        if not tvm.testing.device_enabled(target):
+            return
+
+        fmatmul = tvm.build(s, [A, B, C], target)
+        dev = tvm.device(target, 0)
+
+        size = (n, n)
+        a_np = np.random.uniform(size=size).astype(A.dtype)
+        b_np = np.random.uniform(size=size).astype(B.dtype)
+        a = tvm.nd.array(a_np, dev)
+        b = tvm.nd.array(b_np, dev)
+        c = tvm.nd.array(np.zeros(size, dtype=C.dtype), dev)
+        fmatmul(a, b, c)
+        np_ref = np.dot(a_np.astype("float32"), b_np.astype("float32"))
+        tvm.testing.assert_allclose(c.numpy(), np_ref, 1e-4, 1e-4)
+
+    for target in ["cuda"]:
+        check_target(target)
+
+
[email protected]_gpu
+def test_shared_more_dtype():
+    """Test vectorized store into shared memory"""
+    n = 512
+    A = te.placeholder((n,), name="A", dtype="int8")
+    B = te.placeholder((n,), name="B", dtype="int16")
+
+    def test_device_ir(A, B, C):
+        n = A.shape[0]
+        ib = tvm.tir.ir_builder.create()
+
+        tx = te.thread_axis("threadIdx.x")
+        ib.scope_attr(tx, "thread_extent", n)
+
+        A_sh = ib.allocate(A.dtype, (n,), scope="shared")  # i8
+        B_sh = ib.allocate(B.dtype, (n,), scope="shared")  # i16
+        C_sh = ib.allocate(C.dtype, (n,), scope="shared")  # i32
+
+        Aptr = ib.buffer_ptr(A)
+        Bptr = ib.buffer_ptr(B)
+        Cptr = ib.buffer_ptr(C)
+
+        A_sh[tx] = Aptr[tx]
+        B_sh[tx] = Bptr[tx]
+
+        C_sh[tx] = cast(A_sh[tx], "int32") + cast(B_sh[tx], "int32")
+        Cptr[tx] = C_sh[tx]
+        return ib.get()
+
+    C = te.extern(
+        (n,),
+        [A, B],
+        lambda ins, outs: test_device_ir(ins[0], ins[1], outs[0]),
+        name="vadd",
+        dtype="int32",
+    )
+    s = te.create_schedule(C.op)
+
+    mod = run_passes(s, [A, B, C])
+    verify_single_allocation(mod["main"].body, n * 4)
+
+    def check_target(target):
+        if not tvm.testing.device_enabled(target):
+            return
+
+        fadd = tvm.build(s, [A, B, C], target)
+        dev = tvm.device(target, 0)
+
+        a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev)
+        b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), dev)
+        c = tvm.nd.array(np.zeros((n,), dtype=C.dtype), dev)
+        fadd(a, b, c)
+        tvm.testing.assert_allclose(c.numpy(), a.numpy().astype("int32") + b.numpy(), 1e-4, 1e-4)
+
+    for target in ["cuda"]:
+        check_target(target)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()

Reply via email to