This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 53d163c968 [TIR, CUDA] Add pass to replace global to shared memory
copy with cp.async (#11658)
53d163c968 is described below
commit 53d163c96850c8476d479803c59344c6977ef9e8
Author: Masahiro Masuda <[email protected]>
AuthorDate: Fri Jun 10 11:05:18 2022 +0900
[TIR, CUDA] Add pass to replace global to shared memory copy with cp.async
(#11658)
* [TIR, CUDA] Add pass to replace global to shared memory copy with cp.async
* add missing doc
* black
* missing src
* clang format
* clang format
* check against nested async scope
---
include/tvm/tir/stmt.h | 5 +
include/tvm/tir/transform.h | 6 +
python/tvm/testing/utils.py | 7 +
python/tvm/tir/transform/transform.py | 11 ++
src/driver/driver_api.cc | 8 +
src/target/source/ptx.cc | 3 +-
src/tir/transforms/inject_ptx_async_copy.cc | 145 ++++++++++++++++
tests/python/unittest/test_tir_ptx_cp_async.py | 4 +-
.../test_tir_schedule_tensorize_ldmatrix_mma.py | 8 +-
.../test_tir_transform_inject_ptx_async_copy.py | 183 +++++++++++++++++++++
10 files changed, 370 insertions(+), 10 deletions(-)
diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index 48cac6d8d0..288ed9d609 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -1441,6 +1441,11 @@ constexpr const char* pipeline_exec_scope =
"pipeline_exec_scope";
*/
constexpr const char* device_scope = "device_scope";
+/*!
+ * \brief Mark that the attached statement runs asynchronously.
+ */
+constexpr const char* async_scope = "async_scope";
+
/*!
* \brief Mark that the shape of TensorCore fragment
*/
diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index 6393eeb943..39a6459048 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -644,6 +644,12 @@ TVM_DLL Pass AnnotateEntryFunc();
*/
TVM_DLL Pass Filter(runtime::TypedPackedFunc<bool(PrimFunc)> fcond);
+/*!
+ * \brief Pass to rewrite global to shared memory copy on CUDA with
asynchronous copy.
+ * \return The pass.
+ */
+TVM_DLL Pass InjectPTXAsyncCopy();
+
} // namespace transform
} // namespace tir
} // namespace tvm
diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py
index bf3cc94f5d..59ff93cfea 100644
--- a/python/tvm/testing/utils.py
+++ b/python/tvm/testing/utils.py
@@ -1599,6 +1599,13 @@ def terminate_self():
sys.exit(-1)
+def is_ampere_or_newer():
+ """Check if the target environment has an NVIDIA Ampere GPU or newer."""
+ arch = tvm.contrib.nvcc.get_target_compute_version()
+ major, _ = tvm.contrib.nvcc.parse_compute_version(arch)
+ return major >= 8
+
+
def main():
test_file = inspect.getsourcefile(sys._getframe(1))
sys.exit(pytest.main([test_file] + sys.argv[1:]))
diff --git a/python/tvm/tir/transform/transform.py
b/python/tvm/tir/transform/transform.py
index e0a7501ef9..e1ddfe439a 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -825,3 +825,14 @@ def Filter(fcond: Callable):
The result pass
"""
return _ffi_api.Filter(fcond) # type: ignore
+
+
+def InjectPTXAsyncCopy():
+    """Rewrite global to shared memory copy on CUDA with asynchronous copy.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.InjectPTXAsyncCopy() # type: ignore
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index ace31800de..7f015e7ca2 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -50,6 +50,7 @@
TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_storage_rewrite", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array<Array<ObjectRef>>);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.debug_keep_trivial_loop", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.use_ptx_async_copy", Bool);
using runtime::PackedFunc;
using runtime::TVMArgs;
@@ -559,6 +560,13 @@ transform::Sequential MixedModulePassManager(IRModule
mixed_mod, Target target)
mixed_pass_list.push_back(tir::transform::InferFragment());
mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce());
+ bool use_ptx_async_copy =
+ pass_ctx->GetConfig<Bool>("tir.use_ptx_async_copy", Bool(false)).value();
+
+ if (use_ptx_async_copy) {
+ mixed_pass_list.push_back(tir::transform::InjectPTXAsyncCopy());
+ }
+
bool unpacked_api = mixed_mod->GetAttr<relay::Executor>(tvm::attr::kExecutor)
.value_or(relay::Executor::Create("graph", {}))
->GetAttr<Bool>("unpacked-api")
diff --git a/src/target/source/ptx.cc b/src/target/source/ptx.cc
index 71c68baed6..c5e3bf98ec 100644
--- a/src/target/source/ptx.cc
+++ b/src/target/source/ptx.cc
@@ -651,7 +651,7 @@ std::string PrintCpAsyncAssembly(const std::string&
shared_ptr,
: "l"((void *)({smem_addr}))
);
__asm__ __volatile__(
- "cp.async.cg.shared.global [%0], [%1], %2;"
+ "cp.async.{cg_or_ca}.shared.global [%0], [%1], %2;"
:: "r"(addr), "l"((void*)({global_ptr})), "n"({bytes})
);
}
@@ -660,6 +660,7 @@ std::string PrintCpAsyncAssembly(const std::string&
shared_ptr,
replacer.register_rule("{smem_addr}", shared_ptr + " + " +
shared_elem_offset);
replacer.register_rule("{global_ptr}", global_ptr + " + " +
global_elem_offset);
replacer.register_rule("{bytes}", bytes);
+ replacer.register_rule("{cg_or_ca}", bytes == "16" ? "cg" : "ca");
asm_code = replacer.rewrite(asm_code);
return asm_code;
}
diff --git a/src/tir/transforms/inject_ptx_async_copy.cc
b/src/tir/transforms/inject_ptx_async_copy.cc
new file mode 100644
index 0000000000..c74ce9d3d2
--- /dev/null
+++ b/src/tir/transforms/inject_ptx_async_copy.cc
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \brief Replace copy from global to shared with async copy
+ * \file inject_ptx_async_copy.cc
+ */
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include "../ir/buffer_common.h"
+#include "storage_access.h"
+#include "tvm/tir/stmt.h"
+
+namespace tvm {
+namespace tir {
+
+class PTXAsyncCopyInjector : public StmtMutator {
+ public:
+ Stmt VisitStmt_(const AttrStmtNode* attr) {
+ if (attr->attr_key == tir::attr::async_scope) {
+ ICHECK(in_async == false) << "Nested async scopes not supported";
+ in_async = true;
+ auto body = this->VisitStmt(attr->body);
+ in_async = false;
+ return body;
+ }
+ return StmtMutator::VisitStmt_(attr);
+ }
+
+ Stmt VisitStmt_(const BufferStoreNode* store) {
+ if (in_async && (store->buffer.scope() == "shared" ||
store->buffer.scope() == "shared.dyn")) {
+ if (auto* load = store->value.as<BufferLoadNode>()) {
+ if (load->buffer.scope() == "global") {
+ ICHECK(load->indices.size() == 1 && store->indices.size() == 1);
+ ICHECK(load->indices[0]->dtype.lanes() ==
store->indices[0]->dtype.lanes());
+
+ const int indices_lanes = load->indices[0]->dtype.lanes();
+ const int bytes = indices_lanes * load->buffer->dtype.bytes();
+
+ if (bytes == 4 || bytes == 8 || bytes == 16) {
+ auto dst_elem_type =
GetPointerType(store->buffer->data->type_annotation);
+ auto src_elem_type =
GetPointerType(load->buffer->data->type_annotation);
+ ICHECK(dst_elem_type.first && src_elem_type.first)
+ << "Both store and load buffer should have a pointer type
annotation.";
+
+ int index_factor = 1;
+ if (dst_elem_type != src_elem_type) {
+ // The only case where src and dst have different dtypes is when
the dst shared memory
+ // is a byte buffer generated by merging dynamic shared memory.
+ ICHECK(store->buffer.scope() == "shared.dyn");
+ ICHECK(dst_elem_type.second == DataType::UInt(8));
+ // BufferStore/Load have the "pointer reinterpret" semantics
according to their
+ // "value" dtype. Their "indices" are supposed to be applied
after such pointer cast,
+ // for example: ((*float16)(byte_buffer))[buffer->indices] =
fp16_value;
+ // To replace BufferStore/Load with cp.async, we need to
multiply the store index by
+ // the byte size of the "value" dtype, to get the correct offset
into the byte buffer.
+ index_factor = src_elem_type.second.bytes();
+ }
+
+ if (indices_lanes == 1) {
+ auto src_offset = load->indices[0];
+ auto dst_offset = store->indices[0];
+ return Evaluate(
+ Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(),
+ {store->buffer->data, tir::Mul(dst_offset,
PrimExpr(index_factor)),
+ load->buffer->data, src_offset, PrimExpr(bytes)}));
+ }
+
+ // Only some vectorized indexing patterns are supported for now.
+ auto src_offset = [=]() -> PrimExpr {
+ if (load->indices[0]->IsInstance<RampNode>()) {
+ return load->indices[0].as<RampNode>()->base;
+ }
+ return PrimExpr();
+ }();
+
+ auto dst_offset = [=]() -> PrimExpr {
+ if (store->indices[0].as<RampNode>()) {
+ return store->indices[0].as<RampNode>()->base;
+ } else if (store->indices[0].as<AddNode>()) {
+ // The case where the dst buffer is a byte buffer generated by
merging dynamic
+ // shared memory.
+ // A_shared.dyn[(ramp(...), 1, 8) + x8(17408))] =
A_global[ramp(...),1, 8)]
+ auto* add = store->indices[0].as<AddNode>();
+ if (!add->a->IsInstance<RampNode>()) return PrimExpr();
+ if (!add->b->IsInstance<BroadcastNode>()) return PrimExpr();
+ return tir::Add(add->a.as<RampNode>()->base,
add->b.as<BroadcastNode>()->value);
+ }
+ return PrimExpr();
+ }();
+
+ if (src_offset.defined() && dst_offset.defined()) {
+ return Evaluate(
+ Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(),
+ {store->buffer->data, tir::Mul(dst_offset,
PrimExpr(index_factor)),
+ load->buffer->data, src_offset, PrimExpr(bytes)}));
+ }
+ }
+ }
+ }
+ }
+ return StmtMutator::VisitStmt_(store);
+ }
+
+ private:
+ bool in_async{false};
+};
+
+namespace transform {
+
+Pass InjectPTXAsyncCopy() {
+ auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+ auto* n = f.CopyOnWrite();
+ n->body = PTXAsyncCopyInjector()(n->body);
+ return f;
+ };
+ return CreatePrimFuncPass(pass_func, 0, "tir.InjectPTXAsyncCopy", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.InjectPTXAsyncCopy").set_body_typed(InjectPTXAsyncCopy);
+
+} // namespace transform
+
+} // namespace tir
+} // namespace tvm
diff --git a/tests/python/unittest/test_tir_ptx_cp_async.py
b/tests/python/unittest/test_tir_ptx_cp_async.py
index 17b6088550..5e6535f295 100644
--- a/tests/python/unittest/test_tir_ptx_cp_async.py
+++ b/tests/python/unittest/test_tir_ptx_cp_async.py
@@ -40,8 +40,8 @@ def ptx_cp_async(A: T.Buffer[(32, 128), "float16"], B:
T.Buffer[(32, 128), "floa
)
# TODO(masahi): Remove dtype requirement from TVMScript parser
- T.evaluate(T.ptx_commit_group(dtype="float16"))
- T.evaluate(T.ptx_wait_group(0, dtype="float16"))
+ T.evaluate(T.ptx_commit_group(dtype=""))
+ T.evaluate(T.ptx_wait_group(0, dtype=""))
for i in range(128):
B[tx, i] = A_shared[tx, i]
diff --git a/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py
b/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py
index 9feb994e71..32c1625653 100644
--- a/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py
+++ b/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py
@@ -76,12 +76,6 @@ def matmul(m, n, k, in_dtype, out_dtype, b_transposed):
return (a, b, c)
-def is_ampere_or_newer():
- arch = tvm.contrib.nvcc.get_target_compute_version()
- major, _ = tvm.contrib.nvcc.parse_compute_version(arch)
- return major >= 8
-
-
def run_test(
k_inner,
in_dtype,
@@ -117,7 +111,7 @@ def run_test(
mma_store_intrin,
)
- if not is_ampere_or_newer():
+ if not tvm.testing.is_ampere_or_newer():
return None
f = tvm.build(sch.mod["main"], target="cuda", name="dense")
diff --git a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
new file mode 100644
index 0000000000..d7e13f40aa
--- /dev/null
+++ b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
@@ -0,0 +1,183 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm.script import tir as T
+import numpy as np
+import tvm.testing
+
+
+def count_cp_async(stmt):
+ num_alloc = [0]
+
+ def verify(n):
+ if isinstance(n, tvm.tir.Call) and str(n.op) == "tir.ptx_cp_async":
+ num_alloc[0] += 1
+
+ tvm.tir.stmt_functor.post_order_visit(stmt, verify)
+ return num_alloc[0]
+
+
+def generate_global_to_shared_vectorized_copy(dtype, vector_size):
+ num_iters = 128 // vector_size
+ vector_size_expr = tvm.runtime.convert(vector_size)
+
+ @T.prim_func
+ def ptx_global_to_shared_copy(
+ A: T.Buffer[(32, 128), dtype], B: T.Buffer[(32, 128), dtype]
+ ) -> None:
+ T.func_attr({"global_symbol": "main", "tir.noalias": True})
+ bx = T.env_thread("blockIdx.x")
+ tx = T.env_thread("threadIdx.x")
+ T.launch_thread(bx, 1)
+ T.launch_thread(tx, 32)
+ with T.block():
+ A_shared = T.alloc_buffer([32, 128], dtype, scope="shared")
+ T.reads(A[0:32, 0:128])
+ T.writes(B[0:32, 0:128])
+
+ T.attr("default", "async_scope", 1)
+ for i in T.serial(num_iters):
+ for j in T.vectorized(vector_size):
+ A_shared[tx, i * vector_size_expr + j] = A[tx, i *
vector_size_expr + j]
+
+ T.evaluate(T.ptx_commit_group(dtype=""))
+ T.evaluate(T.ptx_wait_group(0, dtype=""))
+
+ for i in range(128):
+ B[tx, i] = A_shared[tx, i]
+
+ return ptx_global_to_shared_copy
+
+
[email protected]_func
+def ptx_global_to_shared_copy_fp32x1(
+ A: T.Buffer[(32, 128), "float32"], B: T.Buffer[(32, 128), "float32"]
+) -> None:
+ T.func_attr({"global_symbol": "main", "tir.noalias": True})
+ bx = T.env_thread("blockIdx.x")
+ tx = T.env_thread("threadIdx.x")
+ T.launch_thread(bx, 1)
+ T.launch_thread(tx, 32)
+ with T.block():
+ A_shared = T.alloc_buffer([32, 128], "float32", scope="shared")
+ T.reads(A[0:32, 0:128])
+ T.writes(B[0:32, 0:128])
+
+ T.attr("default", "async_scope", 1)
+ for i in T.serial(128):
+ A_shared[tx, i] = A[tx, i]
+
+ T.evaluate(T.ptx_commit_group(dtype=""))
+ T.evaluate(T.ptx_wait_group(0, dtype=""))
+
+ for i in range(128):
+ B[tx, i] = A_shared[tx, i]
+
+
[email protected]_func
+def ptx_global_to_shared_dyn_copy_fp16x8(
+ A: T.Buffer[(32, 128), "float16"],
+ B: T.Buffer[(32, 128), "float16"],
+ C: T.Buffer[(32, 128), "float16"],
+) -> None:
+ T.func_attr({"global_symbol": "main", "tir.noalias": True})
+ bx = T.env_thread("blockIdx.x")
+ tx = T.env_thread("threadIdx.x")
+ T.launch_thread(bx, 1)
+ T.launch_thread(tx, 32)
+ with T.block():
+ A_shared = T.alloc_buffer([32, 128], "float16", scope="shared.dyn")
+ B_shared = T.alloc_buffer([32, 128], "float16", scope="shared.dyn")
+ T.reads(A[0:32, 0:128], B[0:32, 0:128])
+ T.writes(C[0:32, 0:128])
+
+ T.attr("default", "async_scope", 1)
+ for i in T.serial(16):
+ for j in T.vectorized(8):
+ A_shared[tx, i * 8 + j] = A[tx, i * 8 + j]
+ B_shared[tx, i * 8 + j] = B[tx, i * 8 + j]
+
+ T.evaluate(T.ptx_commit_group(dtype=""))
+ T.evaluate(T.ptx_wait_group(0, dtype=""))
+
+ for i in range(128):
+ C[tx, i] = A_shared[tx, i] + B_shared[tx, i]
+
+
[email protected]_cuda
+def test_inject_async_copy():
+ for dtype, vec_size in [("float16", 8), ("float16", 4), ("float32", 4),
("float32", 1)]:
+ if vec_size == 1:
+ f = ptx_global_to_shared_copy_fp32x1
+ else:
+ f = generate_global_to_shared_vectorized_copy(dtype, vec_size)
+
+ mod = tvm.IRModule.from_expr(f)
+ mod = tvm.tir.transform.FlattenBuffer()(mod)
+ if vec_size > 1:
+ mod = tvm.tir.transform.VectorizeLoop()(mod)
+ mod = tvm.tir.transform.InjectPTXAsyncCopy()(mod)
+
+ assert count_cp_async(mod["main"].body) == 1
+
+ if not tvm.testing.is_ampere_or_newer():
+ continue
+
+ with tvm.transform.PassContext(config={"tir.use_ptx_async_copy": 1}):
+ mod = tvm.build(tvm.IRModule.from_expr(f), target="cuda")
+
+ A_np = np.random.rand(32, 128).astype(dtype)
+ B_np = np.zeros((32, 128)).astype(dtype)
+ dev = tvm.cuda(0)
+ A_nd = tvm.nd.array(A_np, device=dev)
+ B_nd = tvm.nd.array(B_np, device=dev)
+ mod(A_nd, B_nd)
+ tvm.testing.assert_allclose(B_nd.numpy(), A_np)
+
+
[email protected]_cuda
+def test_inject_async_copy_shared_dyn():
+ f = ptx_global_to_shared_dyn_copy_fp16x8
+
+ mod = tvm.IRModule.from_expr(f)
+ mod = tvm.tir.transform.FlattenBuffer()(mod)
+ mod = tvm.tir.transform.VectorizeLoop()(mod)
+ mod = tvm.tir.transform.MergeDynamicSharedMemoryAllocations()(mod)
+ mod = tvm.tir.transform.InjectPTXAsyncCopy()(mod)
+
+ assert count_cp_async(mod["main"].body) == 2
+
+ if not tvm.testing.is_ampere_or_newer():
+ return
+
+ with tvm.transform.PassContext(config={"tir.use_ptx_async_copy": 1}):
+ mod = tvm.build(tvm.IRModule.from_expr(f), target="cuda")
+
+ A_np = np.random.rand(32, 128).astype("float16")
+ B_np = np.random.rand(32, 128).astype("float16")
+ C_np = np.zeros((32, 128)).astype("float16")
+ dev = tvm.cuda(0)
+ A_nd = tvm.nd.array(A_np, device=dev)
+ B_nd = tvm.nd.array(B_np, device=dev)
+ C_nd = tvm.nd.array(C_np, device=dev)
+ mod(A_nd, B_nd, C_nd)
+ tvm.testing.assert_allclose(C_nd.numpy(), A_np + B_np)
+
+
+if __name__ == "__main__":
+ test_inject_async_copy()
+ test_inject_async_copy_shared_dyn()