Hzfengsy commented on a change in pull request #8571:
URL: https://github.com/apache/tvm/pull/8571#discussion_r678814614



##########
File path: tests/python/unittest/test_tir_ir_builder.py
##########
@@ -554,6 +554,147 @@ def check_target(target):
         check_target(target)
 
 
+@tvm.testing.requires_gpu

Review comment:
       I don't think the dyn_shared_mem tests belong in 
`test_tir_ir_builder.py`. It would be great if we could create a separate file to 
test the pass `tir.transform.MergeDynamicSharedMemoryAllocations` by comparing 
the transformed IR with the expected IR.

##########
File path: src/tir/transforms/merge_dynamic_shared_memory_allocations.cc
##########
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file merge_dynamic_shared_memory_allocations.cc
+ * \brief Each GPU kernel is allowed to have only one dynamic shared memory 
allocation.
+ * This pass merges multiple TIR-level dynamic shared memory allocations into 
one allocation.
+ */
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../../runtime/thread_storage_scope.h"
+#include "ir_utils.h"
+
+namespace tvm {
+namespace tir {
+
+bool IsDynamicSharedMemory(Var buffer_var) {
+  auto storage_scope = 
runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
+  return storage_scope.rank == runtime::StorageRank::kShared && 
storage_scope.tag == ".dyn";
+}
+
+class AllocateCollector : public StmtExprVisitor {
+ public:
+  void VisitStmt_(const AllocateNode* op) final {
+    if (IsDynamicSharedMemory(op->buffer_var)) {
+      dyn_shmem_allocs_.insert(op);
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  std::unordered_set<const AllocateNode*> dyn_shmem_allocs_;
+};
+
+class DynamicSharedMemoryRewriter : public StmtExprMutator {
+ public:
+  explicit DynamicSharedMemoryRewriter(
+      const std::unordered_set<const AllocateNode*>& dyn_shmem_allocs)
+      : dyn_shmem_allocs_{dyn_shmem_allocs} {}
+
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
+    if (op->attr_key == attr::thread_extent && !allocated) {
+      // Allocate one dynamic shared memory allocation at the beginning of 
thread scope
+      int align = 1;
+      for (auto& alloc : dyn_shmem_allocs_) {

Review comment:
       ```suggestion
         for (const auto& alloc : dyn_shmem_allocs_) {
   ```

##########
File path: tests/python/unittest/test_tir_ir_builder.py
##########
@@ -554,6 +554,147 @@ def check_target(target):
         check_target(target)
 
 
+@tvm.testing.requires_gpu

Review comment:
       Also, is it possible to write the test using TVMScript? I think it's 
clearer than the `ir_builder`.

##########
File path: src/tir/transforms/merge_dynamic_shared_memory_allocations.cc
##########
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file merge_dynamic_shared_memory_allocations.cc
+ * \brief Each GPU kernel is allowed to have only one dynamic shared memory 
allocation.
+ * This pass merges multiple TIR-level dynamic shared memory allocations into 
one allocation.
+ */
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../../runtime/thread_storage_scope.h"
+#include "ir_utils.h"
+
+namespace tvm {
+namespace tir {
+
+bool IsDynamicSharedMemory(Var buffer_var) {
+  auto storage_scope = 
runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
+  return storage_scope.rank == runtime::StorageRank::kShared && 
storage_scope.tag == ".dyn";
+}
+
+class AllocateCollector : public StmtExprVisitor {
+ public:
+  void VisitStmt_(const AllocateNode* op) final {
+    if (IsDynamicSharedMemory(op->buffer_var)) {
+      dyn_shmem_allocs_.insert(op);
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  std::unordered_set<const AllocateNode*> dyn_shmem_allocs_;
+};
+
+class DynamicSharedMemoryRewriter : public StmtExprMutator {
+ public:
+  explicit DynamicSharedMemoryRewriter(
+      const std::unordered_set<const AllocateNode*>& dyn_shmem_allocs)
+      : dyn_shmem_allocs_{dyn_shmem_allocs} {}
+
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
+    if (op->attr_key == attr::thread_extent && !allocated) {
+      // Allocate one dynamic shared memory allocation at the beginning of 
thread scope
+      int align = 1;
+      for (auto& alloc : dyn_shmem_allocs_) {
+        ICHECK_EQ(alloc->dtype.lanes(), 1) << "vector dtype allocation not 
supported.";
+        align = std::max(align, alloc->dtype.bytes());
+      }
+      for (auto& alloc : dyn_shmem_allocs_) {

Review comment:
       ```suggestion
         for (const auto& alloc : dyn_shmem_allocs_) {
   ```

##########
File path: tests/python/unittest/test_tir_ir_builder.py
##########
@@ -554,6 +554,147 @@ def check_target(target):
         check_target(target)
 
 
+@tvm.testing.requires_gpu
+def test_matmul_dyn_shared():
+    n = 1024
+    A = te.placeholder((n, n), name="A", dtype="float16")
+    B = te.placeholder((n, n), name="B", dtype="float16")
+
+    def syncthread():
+        return tvm.tir.Call(None, "tir.tvm_storage_sync", 
tvm.runtime.convert(["shared"]))
+
+    def test_matmul_ir(A, B, C):
+        ib = tvm.tir.ir_builder.create()
+        block = 16
+
+        tx = te.thread_axis("threadIdx.x")
+        ty = te.thread_axis("threadIdx.y")
+        bx = te.thread_axis("blockIdx.x")
+        by = te.thread_axis("blockIdx.y")
+        ib.scope_attr(tx, "thread_extent", block)
+        ib.scope_attr(ty, "thread_extent", block)
+        ib.scope_attr(bx, "thread_extent", n / block)

Review comment:
       I'm not sure, but do we need `n // block` here? In Python 3, `n / block` is true division and yields a float rather than an integer extent.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to