This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 91b6cee634 [Unity][BYOC] Cache cuBlasLt handle with thread entry (#15030)
91b6cee634 is described below

commit 91b6cee6341eb9afa67d4832f477ea4478000cae
Author: Wuwei Lin <[email protected]>
AuthorDate: Thu Jun 8 02:39:04 2023 +0800

    [Unity][BYOC] Cache cuBlasLt handle with thread entry (#15030)
    
    * [BYOC] Cache cuBlasLt handle with thread entry
    
    * fix
    
    * fix
---
 src/runtime/contrib/cublas/cublas_json_runtime.cc |  9 +++------
 src/runtime/contrib/cublas/cublas_utils.cc        | 13 +++++++++++++
 src/runtime/contrib/cublas/cublas_utils.h         |  7 +++++++
 3 files changed, 23 insertions(+), 6 deletions(-)

diff --git a/src/runtime/contrib/cublas/cublas_json_runtime.cc b/src/runtime/contrib/cublas/cublas_json_runtime.cc
index b3931fb9fe..fc6b2f5e62 100644
--- a/src/runtime/contrib/cublas/cublas_json_runtime.cc
+++ b/src/runtime/contrib/cublas/cublas_json_runtime.cc
@@ -52,9 +52,7 @@ class CublasJSONRuntime : public JSONRuntimeBase {
   const char* type_key() const override { return "cublas_json"; }  // May be 
overridden
 
   void Run() override {
-    // TODO(masahi): Reuse the same handle across different subgraphs
-    cublasLtHandle_t handle;
-    cublasLtCreate(&handle);
+    auto* entry_ptr = tvm::contrib::CuBlasLtThreadEntry::ThreadLocal();
 
     for (size_t i = 0; i < nodes_.size(); ++i) {
       const auto& node = nodes_[i];
@@ -88,11 +86,10 @@ class CublasJSONRuntime : public JSONRuntimeBase {
 
         auto [a_ptr, b_ptr, bias_ptr] = get_inputs(node, epilogue != 
CUBLASLT_EPILOGUE_DEFAULT);
 
-        tvm::contrib::CallCublasLt(handle, a_ptr, b_ptr, bias_ptr, out_ptr, 
transa, transb,
-                                   epilogue);
+        tvm::contrib::CallCublasLt(entry_ptr->handle, a_ptr, b_ptr, bias_ptr, 
out_ptr, transa,
+                                   transb, epilogue);
       }
     }
-    cublasLtDestroy(handle);
   }
 
  private:
diff --git a/src/runtime/contrib/cublas/cublas_utils.cc b/src/runtime/contrib/cublas/cublas_utils.cc
index 4b4a1b755e..5cd07cf71d 100644
--- a/src/runtime/contrib/cublas/cublas_utils.cc
+++ b/src/runtime/contrib/cublas/cublas_utils.cc
@@ -48,5 +48,18 @@ CuBlasThreadEntry* CuBlasThreadEntry::ThreadLocal() {
   return retval;
 }
 
+CuBlasLtThreadEntry::CuBlasLtThreadEntry() { 
CHECK_CUBLAS_ERROR(cublasLtCreate(&handle)); }
+
+CuBlasLtThreadEntry::~CuBlasLtThreadEntry() {
+  if (handle) {
+    cublasLtDestroy(handle);
+    handle = nullptr;
+  }
+}
+
+typedef dmlc::ThreadLocalStore<CuBlasLtThreadEntry> CuBlasLtThreadStore;
+
+CuBlasLtThreadEntry* CuBlasLtThreadEntry::ThreadLocal() { return 
CuBlasLtThreadStore::Get(); }
+
 }  // namespace contrib
 }  // namespace tvm
diff --git a/src/runtime/contrib/cublas/cublas_utils.h b/src/runtime/contrib/cublas/cublas_utils.h
index ac03b12103..18ea8de8ef 100644
--- a/src/runtime/contrib/cublas/cublas_utils.h
+++ b/src/runtime/contrib/cublas/cublas_utils.h
@@ -77,6 +77,13 @@ struct CuBlasThreadEntry {
   static CuBlasThreadEntry* ThreadLocal();
 };  // CuBlasThreadEntry
 
+struct CuBlasLtThreadEntry {
+  CuBlasLtThreadEntry();
+  ~CuBlasLtThreadEntry();
+  cublasLtHandle_t handle{nullptr};
+  static CuBlasLtThreadEntry* ThreadLocal();
+};  // CuBlasLtThreadEntry
+
 inline cudaDataType_t GetCudaDataType(DLDataType type) {
   if (type.code == kDLInt) {
     switch (type.bits) {

Reply via email to