This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 91b6cee634 [Unity][BYOC] Cache cuBlasLt handle with thread entry (#15030)
91b6cee634 is described below
commit 91b6cee6341eb9afa67d4832f477ea4478000cae
Author: Wuwei Lin <[email protected]>
AuthorDate: Thu Jun 8 02:39:04 2023 +0800
[Unity][BYOC] Cache cuBlasLt handle with thread entry (#15030)
* [BYOC] Cache cuBlasLt handle with thread entry
* fix
* fix
---
src/runtime/contrib/cublas/cublas_json_runtime.cc | 9 +++------
src/runtime/contrib/cublas/cublas_utils.cc | 13 +++++++++++++
src/runtime/contrib/cublas/cublas_utils.h | 7 +++++++
3 files changed, 23 insertions(+), 6 deletions(-)
diff --git a/src/runtime/contrib/cublas/cublas_json_runtime.cc b/src/runtime/contrib/cublas/cublas_json_runtime.cc
index b3931fb9fe..fc6b2f5e62 100644
--- a/src/runtime/contrib/cublas/cublas_json_runtime.cc
+++ b/src/runtime/contrib/cublas/cublas_json_runtime.cc
@@ -52,9 +52,7 @@ class CublasJSONRuntime : public JSONRuntimeBase {
const char* type_key() const override { return "cublas_json"; } // May be overridden
void Run() override {
- // TODO(masahi): Reuse the same handle across different subgraphs
- cublasLtHandle_t handle;
- cublasLtCreate(&handle);
+ auto* entry_ptr = tvm::contrib::CuBlasLtThreadEntry::ThreadLocal();
for (size_t i = 0; i < nodes_.size(); ++i) {
const auto& node = nodes_[i];
@@ -88,11 +86,10 @@ class CublasJSONRuntime : public JSONRuntimeBase {
auto [a_ptr, b_ptr, bias_ptr] = get_inputs(node, epilogue != CUBLASLT_EPILOGUE_DEFAULT);
- tvm::contrib::CallCublasLt(handle, a_ptr, b_ptr, bias_ptr, out_ptr, transa, transb,
- epilogue);
+ tvm::contrib::CallCublasLt(entry_ptr->handle, a_ptr, b_ptr, bias_ptr, out_ptr, transa,
+ transb, epilogue);
}
}
- cublasLtDestroy(handle);
}
private:
diff --git a/src/runtime/contrib/cublas/cublas_utils.cc
b/src/runtime/contrib/cublas/cublas_utils.cc
index 4b4a1b755e..5cd07cf71d 100644
--- a/src/runtime/contrib/cublas/cublas_utils.cc
+++ b/src/runtime/contrib/cublas/cublas_utils.cc
@@ -48,5 +48,18 @@ CuBlasThreadEntry* CuBlasThreadEntry::ThreadLocal() {
return retval;
}
+CuBlasLtThreadEntry::CuBlasLtThreadEntry() { CHECK_CUBLAS_ERROR(cublasLtCreate(&handle)); }
+
+CuBlasLtThreadEntry::~CuBlasLtThreadEntry() {
+ if (handle) {
+ cublasLtDestroy(handle);
+ handle = nullptr;
+ }
+}
+
+typedef dmlc::ThreadLocalStore<CuBlasLtThreadEntry> CuBlasLtThreadStore;
+
+CuBlasLtThreadEntry* CuBlasLtThreadEntry::ThreadLocal() { return CuBlasLtThreadStore::Get(); }
+
} // namespace contrib
} // namespace tvm
diff --git a/src/runtime/contrib/cublas/cublas_utils.h
b/src/runtime/contrib/cublas/cublas_utils.h
index ac03b12103..18ea8de8ef 100644
--- a/src/runtime/contrib/cublas/cublas_utils.h
+++ b/src/runtime/contrib/cublas/cublas_utils.h
@@ -77,6 +77,13 @@ struct CuBlasThreadEntry {
static CuBlasThreadEntry* ThreadLocal();
}; // CuBlasThreadEntry
+struct CuBlasLtThreadEntry {
+ CuBlasLtThreadEntry();
+ ~CuBlasLtThreadEntry();
+ cublasLtHandle_t handle{nullptr};
+ static CuBlasLtThreadEntry* ThreadLocal();
+}; // CuBlasLtThreadEntry
+
inline cudaDataType_t GetCudaDataType(DLDataType type) {
if (type.code == kDLInt) {
switch (type.bits) {