This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 04c6863a25 [Unity][BYOC] Fix cuBLAS BYOC compatibility with Disco +
`ThreadedSession` (#15977)
04c6863a25 is described below
commit 04c6863a25747d0fa745b01fcf47696ba85e1388
Author: masahi <[email protected]>
AuthorDate: Thu Oct 26 07:42:10 2023 +0900
[Unity][BYOC] Fix cuBLAS BYOC compatibility with Disco + `ThreadedSession`
(#15977)
Fix cuBLAS BYOC compatibility with Disco with ThreadedSession
---
src/runtime/contrib/cublas/cublas_json_runtime.cc | 68 +++++++++++++++++------
1 file changed, 51 insertions(+), 17 deletions(-)
diff --git a/src/runtime/contrib/cublas/cublas_json_runtime.cc
b/src/runtime/contrib/cublas/cublas_json_runtime.cc
index 9617559d7e..c6916d4f86 100644
--- a/src/runtime/contrib/cublas/cublas_json_runtime.cc
+++ b/src/runtime/contrib/cublas/cublas_json_runtime.cc
@@ -49,21 +49,69 @@ class CublasJSONRuntime : public JSONRuntimeBase {
void Init(const Array<NDArray>& consts) override {}
+ PackedFunc GetFunction(const String& name, const ObjectPtr<Object>&
sptr_to_self) override {
+ // JSONRuntimeBase::SetInputOutputBuffers(...) is not thread safe. Since
CublasJSONRuntime
+ // can be used by multiple GPUs running on different threads, we avoid
using that function
+ // and directly call cuBLAS on the inputs from TVMArgs.
+ if (this->symbol_name_ == name) {
+ return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+ ICHECK(this->initialized_) << "The module has not been initialized";
+ this->Run(args);
+ });
+ } else {
+ return JSONRuntimeBase::GetFunction(name, sptr_to_self);
+ }
+ }
+
const char* type_key() const override { return "cublas_json"; } // May be
overridden
- void Run() override {
+ void Run(TVMArgs args) {
auto* entry_ptr = tvm::contrib::CuBlasLtThreadEntry::ThreadLocal();
auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
ICHECK(func != nullptr);
cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator
void*());
+ std::vector<const DLTensor*> dl_tensors(NumEntries());
+
+ for (size_t i = 0; i < static_cast<size_t>(args.size()); i++) {
+ auto eid = i < input_var_eid_.size() ? input_var_eid_[i]
+ : EntryID(outputs_[i -
input_var_eid_.size()]);
+ ICHECK(args[i].type_code() == kTVMNDArrayHandle || args[i].type_code()
== kTVMDLTensorHandle)
+ << "Expect NDArray or DLTensor as inputs";
+
+ const DLTensor* arg;
+ if (args[i].IsObjectRef<NDArray>()) {
+ NDArray arr = args[i];
+ arg = arr.operator->();
+ } else {
+ arg = args[i].operator DLTensor*();
+ }
+
+ dl_tensors[eid] = arg;
+ }
+
+ auto get_input = [this, &dl_tensors](const JSONGraphNode& node, int idx) {
+ ICHECK_LT(idx, node.GetInputs().size());
+ auto eid = EntryID(node.GetInputs()[idx]);
+ ICHECK(eid < dl_tensors.size());
+ return dl_tensors[eid];
+ };
+
+ auto get_inputs = [=](const JSONGraphNode& node, bool has_bias) {
+ const DLTensor* bias = nullptr;
+ if (has_bias) {
+ bias = get_input(node, 2);
+ }
+ return std::make_tuple(get_input(node, 0), get_input(node, 1), bias);
+ };
+
for (size_t i = 0; i < nodes_.size(); ++i) {
const auto& node = nodes_[i];
if (node.GetOpType() == "kernel") {
auto op_name = node.GetOpName();
uint32_t output_eid = EntryID(outputs_[0]);
- auto out_ptr = data_entry_[output_eid];
+ auto out_ptr = dl_tensors[output_eid];
bool transa = false;
bool transb = false;
cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;
@@ -80,14 +128,6 @@ class CublasJSONRuntime : public JSONRuntimeBase {
epilogue = CUBLASLT_EPILOGUE_BIAS;
}
- auto get_inputs = [this](const JSONGraphNode& node, bool has_bias) {
- const DLTensor* bias = nullptr;
- if (has_bias) {
- bias = GetInput(node, 2);
- }
- return std::make_tuple(GetInput(node, 0), GetInput(node, 1), bias);
- };
-
auto [a_ptr, b_ptr, bias_ptr] = get_inputs(node, epilogue !=
CUBLASLT_EPILOGUE_DEFAULT);
tvm::contrib::CallCublasLt(entry_ptr->handle, stream, a_ptr, b_ptr,
bias_ptr, out_ptr,
@@ -96,13 +136,7 @@ class CublasJSONRuntime : public JSONRuntimeBase {
}
}
- private:
- const DLTensor* GetInput(const JSONGraphNode& node, const int idx) {
- ICHECK_LT(idx, node.GetInputs().size());
- auto eid = EntryID(node.GetInputs()[idx]);
- ICHECK(eid < data_entry_.size());
- return data_entry_[eid];
- }
+ void Run() override { LOG(FATAL) << "Unreachable"; }
};
runtime::Module CublasJSONRuntimeCreate(String symbol_name, String graph_json,