This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 53d163c968 [TIR, CUDA] Add pass to replace global to shared memory 
copy with cp.async (#11658)
53d163c968 is described below

commit 53d163c96850c8476d479803c59344c6977ef9e8
Author: Masahiro Masuda <[email protected]>
AuthorDate: Fri Jun 10 11:05:18 2022 +0900

    [TIR, CUDA] Add pass to replace global to shared memory copy with cp.async 
(#11658)
    
    * [TIR, CUDA] Add pass to replace global to shared memory copy with cp.async
    
    * add missing doc
    
    * black
    
    * missing src
    
    * clang format
    
    * clang format
    
    * check against nested async scope
---
 include/tvm/tir/stmt.h                             |   5 +
 include/tvm/tir/transform.h                        |   6 +
 python/tvm/testing/utils.py                        |   7 +
 python/tvm/tir/transform/transform.py              |  11 ++
 src/driver/driver_api.cc                           |   8 +
 src/target/source/ptx.cc                           |   3 +-
 src/tir/transforms/inject_ptx_async_copy.cc        | 145 ++++++++++++++++
 tests/python/unittest/test_tir_ptx_cp_async.py     |   4 +-
 .../test_tir_schedule_tensorize_ldmatrix_mma.py    |   8 +-
 .../test_tir_transform_inject_ptx_async_copy.py    | 183 +++++++++++++++++++++
 10 files changed, 370 insertions(+), 10 deletions(-)

diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index 48cac6d8d0..288ed9d609 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -1441,6 +1441,11 @@ constexpr const char* pipeline_exec_scope = 
"pipeline_exec_scope";
  */
 constexpr const char* device_scope = "device_scope";
 
+/*!
+ * \brief Mark that the attached statement runs asynchronously.
+ */
+constexpr const char* async_scope = "async_scope";
+
 /*!
  * \brief Mark that the shape of TensorCore fragment
  */
diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index 6393eeb943..39a6459048 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -644,6 +644,12 @@ TVM_DLL Pass AnnotateEntryFunc();
  */
 TVM_DLL Pass Filter(runtime::TypedPackedFunc<bool(PrimFunc)> fcond);
 
+/*!
+ * \brief Pass to rewrite global to shared memory copy on CUDA with 
asynchronous copy.
+ * \return The pass.
+ */
+TVM_DLL Pass InjectPTXAsyncCopy();
+
 }  // namespace transform
 }  // namespace tir
 }  // namespace tvm
diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py
index bf3cc94f5d..59ff93cfea 100644
--- a/python/tvm/testing/utils.py
+++ b/python/tvm/testing/utils.py
@@ -1599,6 +1599,13 @@ def terminate_self():
     sys.exit(-1)
 
 
+def is_ampere_or_newer():
+    """Check if the target environment has an NVIDIA Ampere GPU or newer."""
+    arch = tvm.contrib.nvcc.get_target_compute_version()
+    major, _ = tvm.contrib.nvcc.parse_compute_version(arch)
+    return major >= 8
+
+
 def main():
     test_file = inspect.getsourcefile(sys._getframe(1))
     sys.exit(pytest.main([test_file] + sys.argv[1:]))
diff --git a/python/tvm/tir/transform/transform.py 
b/python/tvm/tir/transform/transform.py
index e0a7501ef9..e1ddfe439a 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -825,3 +825,14 @@ def Filter(fcond: Callable):
         The result pass
     """
     return _ffi_api.Filter(fcond)  # type: ignore
+
+
+def InjectPTXAsyncCopy():
+    """Rewrite global to shared memory copy on CUDA with asynchronous copy.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.InjectPTXAsyncCopy()  # type: ignore
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index ace31800de..7f015e7ca2 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -50,6 +50,7 @@ 
TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_storage_rewrite", Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array<Array<ObjectRef>>);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.debug_keep_trivial_loop", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.use_ptx_async_copy", Bool);
 
 using runtime::PackedFunc;
 using runtime::TVMArgs;
@@ -559,6 +560,13 @@ transform::Sequential MixedModulePassManager(IRModule 
mixed_mod, Target target)
   mixed_pass_list.push_back(tir::transform::InferFragment());
   mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce());
 
+  bool use_ptx_async_copy =
+      pass_ctx->GetConfig<Bool>("tir.use_ptx_async_copy", Bool(false)).value();
+
+  if (use_ptx_async_copy) {
+    mixed_pass_list.push_back(tir::transform::InjectPTXAsyncCopy());
+  }
+
   bool unpacked_api = mixed_mod->GetAttr<relay::Executor>(tvm::attr::kExecutor)
                           .value_or(relay::Executor::Create("graph", {}))
                           ->GetAttr<Bool>("unpacked-api")
diff --git a/src/target/source/ptx.cc b/src/target/source/ptx.cc
index 71c68baed6..c5e3bf98ec 100644
--- a/src/target/source/ptx.cc
+++ b/src/target/source/ptx.cc
@@ -651,7 +651,7 @@ std::string PrintCpAsyncAssembly(const std::string& 
shared_ptr,
       : "l"((void *)({smem_addr}))
     );
     __asm__ __volatile__(
-      "cp.async.cg.shared.global [%0], [%1], %2;"
+      "cp.async.{cg_or_ca}.shared.global [%0], [%1], %2;"
        :: "r"(addr), "l"((void*)({global_ptr})), "n"({bytes})
     );
   }
@@ -660,6 +660,7 @@ std::string PrintCpAsyncAssembly(const std::string& 
shared_ptr,
   replacer.register_rule("{smem_addr}", shared_ptr + " + " + 
shared_elem_offset);
   replacer.register_rule("{global_ptr}", global_ptr + " + " + 
global_elem_offset);
   replacer.register_rule("{bytes}", bytes);
+  replacer.register_rule("{cg_or_ca}", bytes == "16" ? "cg" : "ca");
   asm_code = replacer.rewrite(asm_code);
   return asm_code;
 }
diff --git a/src/tir/transforms/inject_ptx_async_copy.cc 
b/src/tir/transforms/inject_ptx_async_copy.cc
new file mode 100644
index 0000000000..c74ce9d3d2
--- /dev/null
+++ b/src/tir/transforms/inject_ptx_async_copy.cc
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \brief Replace copy from global to shared with async copy
+ * \file inject_ptx_async_copy.cc
+ */
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include "../ir/buffer_common.h"
+#include "storage_access.h"
+#include "tvm/tir/stmt.h"
+
+namespace tvm {
+namespace tir {
+
+class PTXAsyncCopyInjector : public StmtMutator {
+ public:
+  Stmt VisitStmt_(const AttrStmtNode* attr) {
+    if (attr->attr_key == tir::attr::async_scope) {
+      ICHECK(in_async == false) << "Nested async scopes not supported";
+      in_async = true;
+      auto body = this->VisitStmt(attr->body);
+      in_async = false;
+      return body;
+    }
+    return StmtMutator::VisitStmt_(attr);
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* store) {
+    if (in_async && (store->buffer.scope() == "shared" || 
store->buffer.scope() == "shared.dyn")) {
+      if (auto* load = store->value.as<BufferLoadNode>()) {
+        if (load->buffer.scope() == "global") {
+          ICHECK(load->indices.size() == 1 && store->indices.size() == 1);
+          ICHECK(load->indices[0]->dtype.lanes() == 
store->indices[0]->dtype.lanes());
+
+          const int indices_lanes = load->indices[0]->dtype.lanes();
+          const int bytes = indices_lanes * load->buffer->dtype.bytes();
+
+          if (bytes == 4 || bytes == 8 || bytes == 16) {
+            auto dst_elem_type = 
GetPointerType(store->buffer->data->type_annotation);
+            auto src_elem_type = 
GetPointerType(load->buffer->data->type_annotation);
+            ICHECK(dst_elem_type.first && src_elem_type.first)
+                << "Both store and load buffer should have a pointer type 
annotation.";
+
+            int index_factor = 1;
+            if (dst_elem_type != src_elem_type) {
+              // The only case where src and dst have different dtypes is when 
the dst shared memory
+              // is a byte buffer generated by merging dynamic shared memory.
+              ICHECK(store->buffer.scope() == "shared.dyn");
+              ICHECK(dst_elem_type.second == DataType::UInt(8));
+              // BufferStore/Load have the "pointer reinterpret" semantics 
according to their
+              // "value" dtype. Their "indices" are supposed to be applied 
after such pointer cast,
+              // for example: ((float16*)(byte_buffer))[buffer->indices] = 
fp16_value;
+              // To replace BufferStore/Load with cp.async, we need to 
multiply the store index by
+              // the byte size of the "value" dtype, to get the correct offset 
into the byte buffer.
+              index_factor = src_elem_type.second.bytes();
+            }
+
+            if (indices_lanes == 1) {
+              auto src_offset = load->indices[0];
+              auto dst_offset = store->indices[0];
+              return Evaluate(
+                  Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(),
+                       {store->buffer->data, tir::Mul(dst_offset, 
PrimExpr(index_factor)),
+                        load->buffer->data, src_offset, PrimExpr(bytes)}));
+            }
+
+            // Only some vectorized indexing patterns are supported for now.
+            auto src_offset = [=]() -> PrimExpr {
+              if (load->indices[0]->IsInstance<RampNode>()) {
+                return load->indices[0].as<RampNode>()->base;
+              }
+              return PrimExpr();
+            }();
+
+            auto dst_offset = [=]() -> PrimExpr {
+              if (store->indices[0].as<RampNode>()) {
+                return store->indices[0].as<RampNode>()->base;
+              } else if (store->indices[0].as<AddNode>()) {
+                // The case where the dst buffer is a byte buffer generated by 
merging dynamic
+                // shared memory.
+                // A_shared.dyn[(ramp(...), 1, 8) + x8(17408))] = 
A_global[ramp(...),1, 8)]
+                auto* add = store->indices[0].as<AddNode>();
+                if (!add->a->IsInstance<RampNode>()) return PrimExpr();
+                if (!add->b->IsInstance<BroadcastNode>()) return PrimExpr();
+                return tir::Add(add->a.as<RampNode>()->base, 
add->b.as<BroadcastNode>()->value);
+              }
+              return PrimExpr();
+            }();
+
+            if (src_offset.defined() && dst_offset.defined()) {
+              return Evaluate(
+                  Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(),
+                       {store->buffer->data, tir::Mul(dst_offset, 
PrimExpr(index_factor)),
+                        load->buffer->data, src_offset, PrimExpr(bytes)}));
+            }
+          }
+        }
+      }
+    }
+    return StmtMutator::VisitStmt_(store);
+  }
+
+ private:
+  bool in_async{false};
+};
+
+namespace transform {
+
+Pass InjectPTXAsyncCopy() {
+  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+    auto* n = f.CopyOnWrite();
+    n->body = PTXAsyncCopyInjector()(n->body);
+    return f;
+  };
+  return CreatePrimFuncPass(pass_func, 0, "tir.InjectPTXAsyncCopy", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.InjectPTXAsyncCopy").set_body_typed(InjectPTXAsyncCopy);
+
+}  // namespace transform
+
+}  // namespace tir
+}  // namespace tvm
diff --git a/tests/python/unittest/test_tir_ptx_cp_async.py 
b/tests/python/unittest/test_tir_ptx_cp_async.py
index 17b6088550..5e6535f295 100644
--- a/tests/python/unittest/test_tir_ptx_cp_async.py
+++ b/tests/python/unittest/test_tir_ptx_cp_async.py
@@ -40,8 +40,8 @@ def ptx_cp_async(A: T.Buffer[(32, 128), "float16"], B: 
T.Buffer[(32, 128), "floa
             )
 
         # TODO(masahi): Remove dtype requirement from TVMScript parser
-        T.evaluate(T.ptx_commit_group(dtype="float16"))
-        T.evaluate(T.ptx_wait_group(0, dtype="float16"))
+        T.evaluate(T.ptx_commit_group(dtype=""))
+        T.evaluate(T.ptx_wait_group(0, dtype=""))
 
         for i in range(128):
             B[tx, i] = A_shared[tx, i]
diff --git a/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py 
b/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py
index 9feb994e71..32c1625653 100644
--- a/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py
+++ b/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py
@@ -76,12 +76,6 @@ def matmul(m, n, k, in_dtype, out_dtype, b_transposed):
     return (a, b, c)
 
 
-def is_ampere_or_newer():
-    arch = tvm.contrib.nvcc.get_target_compute_version()
-    major, _ = tvm.contrib.nvcc.parse_compute_version(arch)
-    return major >= 8
-
-
 def run_test(
     k_inner,
     in_dtype,
@@ -117,7 +111,7 @@ def run_test(
         mma_store_intrin,
     )
 
-    if not is_ampere_or_newer():
+    if not tvm.testing.is_ampere_or_newer():
         return None
 
     f = tvm.build(sch.mod["main"], target="cuda", name="dense")
diff --git a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py 
b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
new file mode 100644
index 0000000000..d7e13f40aa
--- /dev/null
+++ b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
@@ -0,0 +1,183 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm.script import tir as T
+import numpy as np
+import tvm.testing
+
+
+def count_cp_async(stmt):
+    num_alloc = [0]
+
+    def verify(n):
+        if isinstance(n, tvm.tir.Call) and str(n.op) == "tir.ptx_cp_async":
+            num_alloc[0] += 1
+
+    tvm.tir.stmt_functor.post_order_visit(stmt, verify)
+    return num_alloc[0]
+
+
+def generate_global_to_shared_vectorized_copy(dtype, vector_size):
+    num_iters = 128 // vector_size
+    vector_size_expr = tvm.runtime.convert(vector_size)
+
+    @T.prim_func
+    def ptx_global_to_shared_copy(
+        A: T.Buffer[(32, 128), dtype], B: T.Buffer[(32, 128), dtype]
+    ) -> None:
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        bx = T.env_thread("blockIdx.x")
+        tx = T.env_thread("threadIdx.x")
+        T.launch_thread(bx, 1)
+        T.launch_thread(tx, 32)
+        with T.block():
+            A_shared = T.alloc_buffer([32, 128], dtype, scope="shared")
+            T.reads(A[0:32, 0:128])
+            T.writes(B[0:32, 0:128])
+
+            T.attr("default", "async_scope", 1)
+            for i in T.serial(num_iters):
+                for j in T.vectorized(vector_size):
+                    A_shared[tx, i * vector_size_expr + j] = A[tx, i * 
vector_size_expr + j]
+
+            T.evaluate(T.ptx_commit_group(dtype=""))
+            T.evaluate(T.ptx_wait_group(0, dtype=""))
+
+            for i in range(128):
+                B[tx, i] = A_shared[tx, i]
+
+    return ptx_global_to_shared_copy
+
+
[email protected]_func
+def ptx_global_to_shared_copy_fp32x1(
+    A: T.Buffer[(32, 128), "float32"], B: T.Buffer[(32, 128), "float32"]
+) -> None:
+    T.func_attr({"global_symbol": "main", "tir.noalias": True})
+    bx = T.env_thread("blockIdx.x")
+    tx = T.env_thread("threadIdx.x")
+    T.launch_thread(bx, 1)
+    T.launch_thread(tx, 32)
+    with T.block():
+        A_shared = T.alloc_buffer([32, 128], "float32", scope="shared")
+        T.reads(A[0:32, 0:128])
+        T.writes(B[0:32, 0:128])
+
+        T.attr("default", "async_scope", 1)
+        for i in T.serial(128):
+            A_shared[tx, i] = A[tx, i]
+
+        T.evaluate(T.ptx_commit_group(dtype=""))
+        T.evaluate(T.ptx_wait_group(0, dtype=""))
+
+        for i in range(128):
+            B[tx, i] = A_shared[tx, i]
+
+
[email protected]_func
+def ptx_global_to_shared_dyn_copy_fp16x8(
+    A: T.Buffer[(32, 128), "float16"],
+    B: T.Buffer[(32, 128), "float16"],
+    C: T.Buffer[(32, 128), "float16"],
+) -> None:
+    T.func_attr({"global_symbol": "main", "tir.noalias": True})
+    bx = T.env_thread("blockIdx.x")
+    tx = T.env_thread("threadIdx.x")
+    T.launch_thread(bx, 1)
+    T.launch_thread(tx, 32)
+    with T.block():
+        A_shared = T.alloc_buffer([32, 128], "float16", scope="shared.dyn")
+        B_shared = T.alloc_buffer([32, 128], "float16", scope="shared.dyn")
+        T.reads(A[0:32, 0:128], B[0:32, 0:128])
+        T.writes(C[0:32, 0:128])
+
+        T.attr("default", "async_scope", 1)
+        for i in T.serial(16):
+            for j in T.vectorized(8):
+                A_shared[tx, i * 8 + j] = A[tx, i * 8 + j]
+                B_shared[tx, i * 8 + j] = B[tx, i * 8 + j]
+
+        T.evaluate(T.ptx_commit_group(dtype=""))
+        T.evaluate(T.ptx_wait_group(0, dtype=""))
+
+        for i in range(128):
+            C[tx, i] = A_shared[tx, i] + B_shared[tx, i]
+
+
[email protected]_cuda
+def test_inject_async_copy():
+    for dtype, vec_size in [("float16", 8), ("float16", 4), ("float32", 4), 
("float32", 1)]:
+        if vec_size == 1:
+            f = ptx_global_to_shared_copy_fp32x1
+        else:
+            f = generate_global_to_shared_vectorized_copy(dtype, vec_size)
+
+        mod = tvm.IRModule.from_expr(f)
+        mod = tvm.tir.transform.FlattenBuffer()(mod)
+        if vec_size > 1:
+            mod = tvm.tir.transform.VectorizeLoop()(mod)
+        mod = tvm.tir.transform.InjectPTXAsyncCopy()(mod)
+
+        assert count_cp_async(mod["main"].body) == 1
+
+        if not tvm.testing.is_ampere_or_newer():
+            continue
+
+        with tvm.transform.PassContext(config={"tir.use_ptx_async_copy": 1}):
+            mod = tvm.build(tvm.IRModule.from_expr(f), target="cuda")
+
+        A_np = np.random.rand(32, 128).astype(dtype)
+        B_np = np.zeros((32, 128)).astype(dtype)
+        dev = tvm.cuda(0)
+        A_nd = tvm.nd.array(A_np, device=dev)
+        B_nd = tvm.nd.array(B_np, device=dev)
+        mod(A_nd, B_nd)
+        tvm.testing.assert_allclose(B_nd.numpy(), A_np)
+
+
[email protected]_cuda
+def test_inject_async_copy_shared_dyn():
+    f = ptx_global_to_shared_dyn_copy_fp16x8
+
+    mod = tvm.IRModule.from_expr(f)
+    mod = tvm.tir.transform.FlattenBuffer()(mod)
+    mod = tvm.tir.transform.VectorizeLoop()(mod)
+    mod = tvm.tir.transform.MergeDynamicSharedMemoryAllocations()(mod)
+    mod = tvm.tir.transform.InjectPTXAsyncCopy()(mod)
+
+    assert count_cp_async(mod["main"].body) == 2
+
+    if not tvm.testing.is_ampere_or_newer():
+        return
+
+    with tvm.transform.PassContext(config={"tir.use_ptx_async_copy": 1}):
+        mod = tvm.build(tvm.IRModule.from_expr(f), target="cuda")
+
+    A_np = np.random.rand(32, 128).astype("float16")
+    B_np = np.random.rand(32, 128).astype("float16")
+    C_np = np.zeros((32, 128)).astype("float16")
+    dev = tvm.cuda(0)
+    A_nd = tvm.nd.array(A_np, device=dev)
+    B_nd = tvm.nd.array(B_np, device=dev)
+    C_nd = tvm.nd.array(C_np, device=dev)
+    mod(A_nd, B_nd, C_nd)
+    tvm.testing.assert_allclose(C_nd.numpy(), A_np + B_np)
+
+
+if __name__ == "__main__":
+    test_inject_async_copy()
+    test_inject_async_copy_shared_dyn()

Reply via email to