This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new cdc2303292 [TIR] Require exactly same-dtype matching for Vulkan smem 
reuse (#16515)
cdc2303292 is described below

commit cdc2303292565e2b330ca3eb3ad0622691e708ce
Author: Ruihang Lai <[email protected]>
AuthorDate: Sat Feb 3 13:23:14 2024 -0500

    [TIR] Require exactly same-dtype matching for Vulkan smem reuse (#16515)
    
    This PR fixes the StorageRewrite pass which failed to avoid
    shared memory reuse of different dtypes for Vulkan.
    
    Since the Vulkan target information is required at the time
    of lowering, the pass `BindTarget` needs to be applied before
    lowering, so that the functions have correct target information.
    Note that previously the pass checked `Target::Current`, while
    `tvm.build` does not set the current target.
    
    One regression test is added.
---
 .../transforms/merge_shared_memory_allocations.cc  |  5 --
 src/tir/transforms/storage_rewrite.cc              | 46 ++++++++----
 .../test_tir_transform_storage_rewrite.py          | 84 ++++++++++++++++++++++
 3 files changed, 115 insertions(+), 20 deletions(-)

diff --git a/src/tir/transforms/merge_shared_memory_allocations.cc 
b/src/tir/transforms/merge_shared_memory_allocations.cc
index 1598d409c5..c79b9c1f93 100644
--- a/src/tir/transforms/merge_shared_memory_allocations.cc
+++ b/src/tir/transforms/merge_shared_memory_allocations.cc
@@ -662,11 +662,6 @@ namespace transform {
 Pass MergeSharedMemoryAllocations() {
   auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
     bool merge_static_smem = ctx->GetConfig<Bool>("tir.merge_static_smem", 
Bool(false)).value();
-    // disable this pass for Vulkan
-    auto target = Target::Current(true);
-    if (target.defined() && target->kind->name == "vulkan") {
-      return f;
-    }
     auto* n = f.CopyOnWrite();
     n->body = MergeSharedMemoryAllocations(std::move(n->body), 
merge_static_smem);
     return f;
diff --git a/src/tir/transforms/storage_rewrite.cc 
b/src/tir/transforms/storage_rewrite.cc
index 6875523a95..991c48219b 100644
--- a/src/tir/transforms/storage_rewrite.cc
+++ b/src/tir/transforms/storage_rewrite.cc
@@ -380,13 +380,15 @@ class StoragePlanRewriter : public StmtExprMutator {
   using StmtEntry = LinearAccessPatternFinder::StmtEntry;
   using AllocEntry = LinearAccessPatternFinder::AllocEntry;
 
-  Stmt Rewrite(Stmt stmt, bool detect_inplace, bool enable_reuse = true) {
+  Stmt Rewrite(Stmt stmt, bool detect_inplace, bool enable_reuse,
+               bool reuse_require_exact_matched_dtype) {
     detect_inplace_ = detect_inplace;
     // plan the rewrite
     LinearAccessPatternFinder finder;
     finder(stmt);
     this->LivenessAnalysis(finder.linear_seq_);
-    this->PlanMemory(finder.linear_seq_, finder.alloc_info_, enable_reuse);
+    this->PlanMemory(finder.linear_seq_, finder.alloc_info_, enable_reuse,
+                     reuse_require_exact_matched_dtype);
     all_buffers_accessed_ = finder.all_buffers_accessed_;
     this->PrepareNewAlloc();
     // start rewrite
@@ -817,7 +819,7 @@ class StoragePlanRewriter : public StmtExprMutator {
   // Memory plan algorithm
   void PlanMemory(const std::vector<StmtEntry>& seq,
                   const std::unordered_map<const VarNode*, AllocEntry>& 
alloc_info,
-                  bool enable_reuse = true) {
+                  bool enable_reuse, bool reuse_require_exact_matched_dtype) {
     std::unordered_set<const VarNode*> inplace_flag;
 
     for (size_t i = 0; i < seq.size(); ++i) {
@@ -864,8 +866,9 @@ class StoragePlanRewriter : public StmtExprMutator {
             }
           }
           if (dst_entry == nullptr) {
-            dst_entry = FindAlloc(alloc, thread_scope_, storage_scope,
-                                  entry.num_physical_dimensions, enable_reuse);
+            dst_entry =
+                FindAlloc(alloc, thread_scope_, storage_scope, 
entry.num_physical_dimensions,
+                          enable_reuse, reuse_require_exact_matched_dtype);
           }
           dst_entry->allocs.emplace_back(alloc);
           alloc_map_[var] = dst_entry;
@@ -919,7 +922,7 @@ class StoragePlanRewriter : public StmtExprMutator {
 
   StorageEntry* FindAlloc(const AllocateNode* op, const Object* attach_scope,
                           const StorageScope& scope, size_t 
num_physical_dimensions,
-                          bool enable_reuse = true) {
+                          bool enable_reuse, bool 
reuse_require_exact_matched_dtype) {
     ICHECK(op != nullptr);
     // skip plan for local variable,
     // compiler can do a better job with register allocation.
@@ -958,6 +961,9 @@ class StoragePlanRewriter : public StmtExprMutator {
         if (e->scope != scope) continue;
         // when not divided, no reuse, eg, float4 vs float3
         if (e->bits_offset % op_elem_bits != 0) continue;
+        if (reuse_require_exact_matched_dtype && e->elem_type != op->dtype) {
+          continue;
+        }
         e->const_nbits = std::max(const_nbits, e->const_nbits);
         const_free_map_.erase(it);
         return e;
@@ -969,6 +975,9 @@ class StoragePlanRewriter : public StmtExprMutator {
         if (e->attach_scope_ != attach_scope) continue;
         if (e->scope != scope) continue;
         if (e->elem_type != op->dtype.element_of()) continue;
+        if (reuse_require_exact_matched_dtype && e->elem_type != op->dtype) {
+          continue;
+        }
         e->const_nbits = std::max(const_nbits, e->const_nbits);
         const_free_map_.erase(it);
         return e;
@@ -1704,17 +1713,24 @@ namespace transform {
 
 Pass StorageRewrite() {
   auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+    bool enable_reuse = true;
+    bool reuse_require_exact_matched_dtype = false;
     bool merge_static_smem = ctx->GetConfig<Bool>("tir.merge_static_smem", 
Bool(false)).value();
-    // disable merge_static_smem for Vulkan
-    auto target = Target::Current(true);
-    if (target.defined() && target->kind->name == "vulkan") {
-      merge_static_smem = false;
-    }
-    // Only enable reuse when we are not merging static shared memory.
-    // Otherwise we will do it in a separate stage
-    bool enable_reuse = merge_static_smem ? false : true;
+    if (merge_static_smem) {
+      // When `merge_static_smem` is true, we will reuse and merge shared
+      // memory in a dedicated pass `MergeSharedMemoryAllocations`.
+      // And so we don't enable reuse in this pass.
+      enable_reuse = false;
+    }
+
+    Optional<Target> target = f->GetAttr<Target>("target");
+    if (target.defined() && target.value()->kind->name == "vulkan") {
+      // Require exactly same-dtype matching in smem reuse for Vulkan
+      reuse_require_exact_matched_dtype = true;
+    }
     auto* n = f.CopyOnWrite();
-    n->body = StoragePlanRewriter().Rewrite(std::move(n->body), true, 
enable_reuse);
+    n->body = StoragePlanRewriter().Rewrite(std::move(n->body), true, 
enable_reuse,
+                                            reuse_require_exact_matched_dtype);
     // Parameters may not be rewritten, but internal allocations may.
     // Vectorization of AllocateConst is currently disabled, as it has
     // indexing issues for types that include padding (e.g. int8x3
diff --git a/tests/python/tir-transform/test_tir_transform_storage_rewrite.py 
b/tests/python/tir-transform/test_tir_transform_storage_rewrite.py
index 197e81818e..4b71eb8254 100644
--- a/tests/python/tir-transform/test_tir_transform_storage_rewrite.py
+++ b/tests/python/tir-transform/test_tir_transform_storage_rewrite.py
@@ -15,7 +15,9 @@
 # specific language governing permissions and limitations
 # under the License.
 import sys
+
 import pytest
+
 import tvm
 import tvm.testing
 from tvm import te
@@ -928,5 +930,87 @@ class TestNoOrphanedDeclBuffer(BaseCompare):
             D[i] = C[i]
 
 
+def test_vulkan_smem_reuse():
+    target = tvm.target.Target(
+        {
+            "keys": ["vulkan", "gpu"],
+            "kind": "vulkan",
+            "max_num_threads": 256,
+            "max_threads_per_block": 256,
+            "supports_float32": T.bool(True),
+            "supports_int32": T.bool(True),
+            "tag": "",
+            "thread_warp_size": 1,
+        }
+    )
+
+    @T.prim_func(private=True)
+    def func(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float16")):
+        T.func_attr({"tir.noalias": T.bool(True)})
+        A_shared = T.allocate([4], "float32", "shared")
+        A_local = T.allocate([4], "float32", "local")
+        B_shared = T.allocate([4], "float16", "shared")
+        A_shared_1 = T.Buffer((4,), data=A_shared, scope="shared")
+        with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
+            A_1 = T.Buffer((4,), data=A.data)
+            A_shared_1[threadIdx_x] = A_1[threadIdx_x]
+        A_local_1 = T.Buffer((4,), data=A_local, scope="local")
+        with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
+            A_local_1[threadIdx_x] = A_shared_1[threadIdx_x]
+        B_shared_1 = T.Buffer((4,), "float16", data=B_shared, scope="shared")
+        with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
+            B_shared_1[threadIdx_x] = T.Cast("float16", A_local_1[threadIdx_x])
+        threadIdx_x = T.launch_thread("threadIdx.x", 4)
+        B_1 = T.Buffer((4,), "float16", data=B.data)
+        B_1[threadIdx_x] = B_shared_1[threadIdx_x]
+
+    @T.prim_func(private=True)
+    def normal_lowering(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), 
"float16")):
+        T.func_attr({"tir.noalias": T.bool(True)})
+        A_shared = T.allocate([4], "float32", "shared")
+        A_local = T.allocate([4], "float32", "local")
+        A_shared_1 = T.Buffer((4,), data=A_shared, scope="shared")
+        with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
+            A_1 = T.Buffer((4,), data=A.data)
+            A_shared_1[threadIdx_x] = A_1[threadIdx_x]
+        A_local_1 = T.Buffer((4,), data=A_local, scope="local")
+        with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
+            A_local_1[threadIdx_x] = A_shared_1[threadIdx_x]
+        A_shared_2 = T.Buffer((4,), "float16", data=A_shared, scope="shared")
+        with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
+            A_shared_2[threadIdx_x] = T.Cast("float16", A_local_1[threadIdx_x])
+        threadIdx_x = T.launch_thread("threadIdx.x", 4)
+        B_1 = T.Buffer((4,), "float16", data=B.data)
+        B_1[threadIdx_x] = A_shared_2[threadIdx_x]
+
+    @T.prim_func(private=True)
+    def no_reuse_lowering(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), 
"float16")):
+        T.func_attr({"target": target, "tir.noalias": T.bool(True)})
+        A_shared_1 = T.allocate([4], "float32", "shared")
+        A_local_1 = T.allocate([4], "float32", "local")
+        B_shared_1 = T.allocate([4], "float16", "shared")
+        A_shared_1_1 = T.Buffer((4,), data=A_shared_1, scope="shared")
+        with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
+            A_1 = T.Buffer((4,), data=A.data)
+            A_shared_1_1[threadIdx_x] = A_1[threadIdx_x]
+        A_local_1_1 = T.Buffer((4,), data=A_local_1, scope="local")
+        with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
+            A_local_1_1[threadIdx_x] = A_shared_1_1[threadIdx_x]
+        B_shared_1_1 = T.Buffer((4,), "float16", data=B_shared_1, 
scope="shared")
+        with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
+            B_shared_1_1[threadIdx_x] = T.Cast("float16", 
A_local_1_1[threadIdx_x])
+        threadIdx_x = T.launch_thread("threadIdx.x", 4)
+        B_1 = T.Buffer((4,), "float16", data=B.data)
+        B_1[threadIdx_x] = B_shared_1_1[threadIdx_x]
+
+    # Reuse shared memory when lowering without target.
+    mod = tvm.IRModule({"main": func})
+    tvm.ir.assert_structural_equal(tvm.lower(mod)["main"], normal_lowering)
+
+    # No shared memory reuse when lowering with target Vulkan.
+    mod = tvm.tir.transform.BindTarget(target)(mod)
+    tvm.ir.assert_structural_equal(tvm.lower(mod)["main"], no_reuse_lowering)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to