This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 04c6863a25 [Unity][BYOC] Fix cuBLAS BYOC compatibility with Disco + `ThreadedSession` (#15977)
04c6863a25 is described below

commit 04c6863a25747d0fa745b01fcf47696ba85e1388
Author: masahi <[email protected]>
AuthorDate: Thu Oct 26 07:42:10 2023 +0900

    [Unity][BYOC] Fix cuBLAS BYOC compatibility with Disco + `ThreadedSession` (#15977)
    
    Fix cuBLAS BYOC compatibility with Disco when using ThreadedSession
---
 src/runtime/contrib/cublas/cublas_json_runtime.cc | 68 +++++++++++++++++------
 1 file changed, 51 insertions(+), 17 deletions(-)

diff --git a/src/runtime/contrib/cublas/cublas_json_runtime.cc b/src/runtime/contrib/cublas/cublas_json_runtime.cc
index 9617559d7e..c6916d4f86 100644
--- a/src/runtime/contrib/cublas/cublas_json_runtime.cc
+++ b/src/runtime/contrib/cublas/cublas_json_runtime.cc
@@ -49,21 +49,69 @@ class CublasJSONRuntime : public JSONRuntimeBase {
 
   void Init(const Array<NDArray>& consts) override {}
 
+  PackedFunc GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) override {
+    // JSONRuntimeBase::SetInputOutputBuffers(...) is not thread safe. Since CublasJSONRuntime
+    // can be used by multiple GPUs running on different threads, we avoid using that function
+    // and directly call cuBLAS on the inputs from TVMArgs.
+    if (this->symbol_name_ == name) {
+      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+        ICHECK(this->initialized_) << "The module has not been initialized";
+        this->Run(args);
+      });
+    } else {
+      return JSONRuntimeBase::GetFunction(name, sptr_to_self);
+    }
+  }
+
   const char* type_key() const override { return "cublas_json"; }  // May be overridden
 
-  void Run() override {
+  void Run(TVMArgs args) {
     auto* entry_ptr = tvm::contrib::CuBlasLtThreadEntry::ThreadLocal();
 
     auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
     ICHECK(func != nullptr);
     cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*());
 
+    std::vector<const DLTensor*> dl_tensors(NumEntries());
+
+    for (size_t i = 0; i < static_cast<size_t>(args.size()); i++) {
+      auto eid = i < input_var_eid_.size() ? input_var_eid_[i]
+                                           : EntryID(outputs_[i - input_var_eid_.size()]);
+      ICHECK(args[i].type_code() == kTVMNDArrayHandle || args[i].type_code() == kTVMDLTensorHandle)
+          << "Expect NDArray or DLTensor as inputs";
+
+      const DLTensor* arg;
+      if (args[i].IsObjectRef<NDArray>()) {
+        NDArray arr = args[i];
+        arg = arr.operator->();
+      } else {
+        arg = args[i].operator DLTensor*();
+      }
+
+      dl_tensors[eid] = arg;
+    }
+
+    auto get_input = [this, &dl_tensors](const JSONGraphNode& node, int idx) {
+      ICHECK_LT(idx, node.GetInputs().size());
+      auto eid = EntryID(node.GetInputs()[idx]);
+      ICHECK(eid < dl_tensors.size());
+      return dl_tensors[eid];
+    };
+
+    auto get_inputs = [=](const JSONGraphNode& node, bool has_bias) {
+      const DLTensor* bias = nullptr;
+      if (has_bias) {
+        bias = get_input(node, 2);
+      }
+      return std::make_tuple(get_input(node, 0), get_input(node, 1), bias);
+    };
+
     for (size_t i = 0; i < nodes_.size(); ++i) {
       const auto& node = nodes_[i];
       if (node.GetOpType() == "kernel") {
         auto op_name = node.GetOpName();
         uint32_t output_eid = EntryID(outputs_[0]);
-        auto out_ptr = data_entry_[output_eid];
+        auto out_ptr = dl_tensors[output_eid];
         bool transa = false;
         bool transb = false;
         cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;
@@ -80,14 +128,6 @@ class CublasJSONRuntime : public JSONRuntimeBase {
           epilogue = CUBLASLT_EPILOGUE_BIAS;
         }
 
-        auto get_inputs = [this](const JSONGraphNode& node, bool has_bias) {
-          const DLTensor* bias = nullptr;
-          if (has_bias) {
-            bias = GetInput(node, 2);
-          }
-          return std::make_tuple(GetInput(node, 0), GetInput(node, 1), bias);
-        };
-
         auto [a_ptr, b_ptr, bias_ptr] = get_inputs(node, epilogue != CUBLASLT_EPILOGUE_DEFAULT);
 
         tvm::contrib::CallCublasLt(entry_ptr->handle, stream, a_ptr, b_ptr, bias_ptr, out_ptr,
@@ -96,13 +136,7 @@ class CublasJSONRuntime : public JSONRuntimeBase {
     }
   }
 
- private:
-  const DLTensor* GetInput(const JSONGraphNode& node, const int idx) {
-    ICHECK_LT(idx, node.GetInputs().size());
-    auto eid = EntryID(node.GetInputs()[idx]);
-    ICHECK(eid < data_entry_.size());
-    return data_entry_[eid];
-  }
+  void Run() override { LOG(FATAL) << "Unreachable"; }
 };
 
 runtime::Module CublasJSONRuntimeCreate(String symbol_name, String graph_json,

Reply via email to