This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new cb3763b84d [Disco] Improve RCCL Integration Logic (#15796)
cb3763b84d is described below
commit cb3763b84d6b17448e76541f622275e11f51ddf0
Author: Junru Shao <[email protected]>
AuthorDate: Wed Sep 20 22:23:39 2023 -0700
[Disco] Improve RCCL Integration Logic (#15796)
This PR makes the following changes:
- Move `runtime/disco/ccl` to `runtime/disco/nccl` as RCCL itself uses
NCCL as a prefix to its APIs;
- Consolidate `utils.h` into `nccl.cc` to avoid potential misuse of the
private header file;
- Reduce the number of macros used in `nccl.cc`.
---
CMakeLists.txt | 18 +--
src/runtime/disco/ccl/utils.h | 94 ------------
src/runtime/disco/{ccl/ccl.cc => nccl/nccl.cc} | 189 ++++++++++++++++---------
3 files changed, 135 insertions(+), 166 deletions(-)
diff --git a/CMakeLists.txt b/CMakeLists.txt
index bc09655923..eb622f3452 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -444,14 +444,16 @@ endif(USE_PROFILER)
if(USE_CUDA AND USE_NCCL)
message(STATUS "Build with NCCL...")
find_nccl(${USE_NCCL})
- tvm_file_glob(GLOB RUNTIME_NCCL_SRC src/runtime/disco/ccl/*.cc)
+ tvm_file_glob(GLOB RUNTIME_NCCL_SRC src/runtime/disco/nccl/*.cc)
+ set_source_files_properties(src/runtime/disco/nccl/nccl.cc PROPERTIES
COMPILE_DEFINITIONS "TVM_NCCL_RCCL_SWITCH=0")
list(APPEND RUNTIME_SRCS ${RUNTIME_NCCL_SRC})
endif()
if(USE_ROCM AND USE_RCCL)
message(STATUS "Build with RCCL...")
find_rccl(${USE_RCCL})
- tvm_file_glob(GLOB RUNTIME_RCCL_SRC src/runtime/disco/ccl/*.cc)
+ tvm_file_glob(GLOB RUNTIME_RCCL_SRC src/runtime/disco/nccl/*.cc)
+ set_source_files_properties(src/runtime/disco/nccl/nccl.cc PROPERTIES
COMPILE_DEFINITIONS "TVM_NCCL_RCCL_SWITCH=1")
list(APPEND RUNTIME_SRCS ${RUNTIME_RCCL_SRC})
endif()
@@ -888,18 +890,16 @@ if(USE_CUDA AND USE_CUTLASS)
target_link_libraries(tvm_runtime PRIVATE -Wl,--no-as-needed flash_attn)
endif()
-if(USE_CUDA AND USE_NCCL)
- target_link_libraries(tvm_runtime PRIVATE nccl)
- target_link_libraries(tvm PRIVATE nccl)
- set_source_files_properties(src/runtime/disco/ccl/ccl.cc PROPERTIES
COMPILE_DEFINITIONS "TVM_NCCL_RCCL_SWITCH=0")
-endif()
-
if(USE_CUDA AND USE_NVTX)
set_source_files_properties(src/runtime/nvtx.cc PROPERTIES
COMPILE_DEFINITIONS "TVM_NVTX_ENABLED=1")
endif()
+if(USE_CUDA AND USE_NCCL)
+ target_link_libraries(tvm PRIVATE nccl)
+ target_link_libraries(tvm_runtime PRIVATE nccl)
+endif()
+
if(USE_ROCM AND USE_RCCL)
target_link_libraries(tvm PRIVATE rccl)
target_link_libraries(tvm_runtime PRIVATE rccl)
- set_source_files_properties(src/runtime/disco/ccl/ccl.cc PROPERTIES
COMPILE_DEFINITIONS "TVM_NCCL_RCCL_SWITCH=1")
endif()
diff --git a/src/runtime/disco/ccl/utils.h b/src/runtime/disco/ccl/utils.h
deleted file mode 100644
index c5066796c0..0000000000
--- a/src/runtime/disco/ccl/utils.h
+++ /dev/null
@@ -1,94 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-#ifndef TVM_RUNTIME_DISCO_CCL_UTILS_H_
-#define TVM_RUNTIME_DISCO_CCL_UTILS_H_
-
-#include <tvm/runtime/data_type.h>
-#include <tvm/runtime/disco/session.h>
-
-#include "../utils.h"
-
-namespace tvm {
-namespace runtime {
-namespace ccl {
-
-#define NCCL_CALL(cmd) \
- do { \
- ncclResult_t r = cmd; \
- if (r != ncclSuccess) { \
- LOG(FATAL) << TVM_DISCO_CCL_NAME "Errror: " << ncclGetErrorString(r); \
- } \
- } while (0)
-
-inline ncclDataType_t AsNCCLDataType(runtime::DataType dtype) {
- if (dtype == DataType::Int(8)) {
- return ncclInt8;
- }
- if (dtype == DataType::UInt(8)) {
- return ncclUint8;
- }
- if (dtype == DataType::Int(32)) {
- return ncclInt32;
- }
- if (dtype == DataType::UInt(32)) {
- return ncclUint32;
- }
- if (dtype == DataType::Int(64)) {
- return ncclInt64;
- }
- if (dtype == DataType::UInt(64)) {
- return ncclUint64;
- }
- if (dtype == DataType::Float(16)) {
- return ncclFloat16;
- }
- if (dtype == DataType::Float(32)) {
- return ncclFloat32;
- }
- if (dtype == DataType::Float(64)) {
- return ncclFloat64;
- }
- if (dtype == DataType::BFloat(16)) {
- return ncclBfloat16;
- }
- LOG(FATAL) << "ValueError: Unsupported data type " << dtype;
- throw;
-}
-
-inline ncclRedOp_t AsNCCLRedOp(ReduceKind kind) {
- switch (kind) {
- case ReduceKind::kSum:
- return ncclSum;
- case ReduceKind::kProd:
- return ncclProd;
- case ReduceKind::kMin:
- return ncclMin;
- case ReduceKind::kMax:
- return ncclMax;
- case ReduceKind::kAvg:
- return ncclAvg;
- }
- LOG(FATAL) << "ValueError: Unknown ReduceKind: " << static_cast<int>(kind);
- throw;
-}
-
-} // namespace ccl
-} // namespace runtime
-} // namespace tvm
-#endif // TVM_RUNTIME_DISCO_CCL_UTILS_H_
diff --git a/src/runtime/disco/ccl/ccl.cc b/src/runtime/disco/nccl/nccl.cc
similarity index 68%
rename from src/runtime/disco/ccl/ccl.cc
rename to src/runtime/disco/nccl/nccl.cc
index 8e656f9ab2..8ce6c67d19 100644
--- a/src/runtime/disco/ccl/ccl.cc
+++ b/src/runtime/disco/nccl/nccl.cc
@@ -16,79 +16,140 @@
* specific language governing permissions and limitations
* under the License.
*/
+
+#include <dlpack/dlpack.h>
+#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/disco/session.h>
+#include <tvm/runtime/registry.h>
+
+#include <cstring>
+#include <mutex>
+#include <sstream>
+#include <vector>
+
+#include "../../../support/process_id.h"
+#include "../utils.h"
+
+/* `TVM_NCCL_RCCL_SWITCH` is set to 0 for NCCL, 1 for RCCL */
#ifndef TVM_NCCL_RCCL_SWITCH
-#define TVM_NCCL_RCCL_SWITCH 0 // 0: NCCL, 1: RCCL
+#define TVM_NCCL_RCCL_SWITCH 0
#endif
-
#if TVM_NCCL_RCCL_SWITCH == 0
-#include <cuda_runtime_api.h>
#include <nccl.h>
#include "../../cuda/cuda_common.h"
+#else
+#include <rccl/rccl.h>
+
+#include "../../rocm/rocm_common.h"
+#endif
+
+namespace tvm {
+namespace runtime {
+namespace nccl {
-using runtimeStream_t = cudaStream_t;
+#define NCCL_CALL(cmd) \
+ do { \
+ auto r = (cmd); \
+ if (r != ncclSuccess) { \
+ LOG(FATAL) << TVM_DISCO_CCL_NAME "Errror: " << ncclGetErrorString(r); \
+ } \
+ } while (0)
+
+#if TVM_NCCL_RCCL_SWITCH == 0
-#define TVM_DISCO_DEVICE_CALL CUDA_CALL
-#define TVM_DISCO_DEVICE_SET_DEVICE cudaSetDevice
-#define TVM_DISCO_DEVICE_STREAM_CREATE cudaStreamCreate
-#define TVM_DISCO_DEVICE_STREAM_SYNC cudaStreamSynchronize
-#define TVM_DISCO_DEVICE_STREAM_DESTROY cudaStreamDestroy
#define TVM_DISCO_DEVICE_NAME "cuda"
-#define TVM_DISCO_CCL_DESTROY ncclCommDestroy
#define TVM_DISCO_CCL_NAME "nccl"
-const constexpr DLDeviceType TVM_DISCO_DEVICE_TYPE = DLDeviceType::kDLCUDA;
-#else
-#include <hip/hip_runtime_api.h>
-#include <hip/hip_version.h>
-#include <rccl/rccl.h>
-#include "../../rocm/rocm_common.h"
+using deviceStream_t = cudaStream_t;
+const constexpr DLDeviceType TVM_DISCO_DEVICE_TYPE = DLDeviceType::kDLCUDA;
+inline void SetDevice(int device_id) { CUDA_CALL(cudaSetDevice(device_id)); }
+inline void StreamSynchronize(deviceStream_t stream) {
CUDA_CALL(cudaStreamSynchronize(stream)); }
+inline void StreamCreate(deviceStream_t* stream) {
CUDA_CALL(cudaStreamCreate(stream)); }
+inline void StreamDestroy(deviceStream_t stream) {
CUDA_CALL(cudaStreamDestroy(stream)); }
+
-using runtimeStream_t = hipStream_t;
+#else
-#define TVM_DISCO_DEVICE_CALL ROCM_CALL
-#define TVM_DISCO_DEVICE_SET_DEVICE hipSetDevice
-#define TVM_DISCO_DEVICE_STREAM_CREATE hipStreamCreate
-#define TVM_DISCO_DEVICE_STREAM_SYNC hipStreamSynchronize
-#define TVM_DISCO_DEVICE_STREAM_DESTROY hipStreamDestroy
#define TVM_DISCO_DEVICE_NAME "rocm"
-#define TVM_DISCO_CCL_DESTROY ncclCommDestroy
#define TVM_DISCO_CCL_NAME "rccl"
-const constexpr DLDeviceType TVM_DISCO_DEVICE_TYPE = DLDeviceType::kDLROCM;
-#endif
-#include <dlpack/dlpack.h>
-#include <tvm/runtime/c_runtime_api.h>
-#include <tvm/runtime/disco/session.h>
-#include <tvm/runtime/registry.h>
+using deviceStream_t = hipStream_t;
+const constexpr DLDeviceType TVM_DISCO_DEVICE_TYPE = DLDeviceType::kDLROCM;
+inline void SetDevice(int device_id) { ROCM_CALL(hipSetDevice(device_id)); }
+inline void StreamSynchronize(deviceStream_t stream) {
ROCM_CALL(hipStreamSynchronize(stream)); }
+inline void StreamCreate(deviceStream_t* stream) {
ROCM_CALL(hipStreamCreate(stream)); }
+inline void StreamDestroy(deviceStream_t stream) {
ROCM_CALL(hipStreamDestroy(stream)); }
-#include <cstring>
-#include <mutex>
-#include <sstream>
-#include <vector>
+#endif
-#include "../../../support/process_id.h"
-#include "./utils.h"
+inline ncclDataType_t AsNCCLDataType(runtime::DataType dtype) {
+ if (dtype == DataType::Int(8)) {
+ return ncclInt8;
+ }
+ if (dtype == DataType::UInt(8)) {
+ return ncclUint8;
+ }
+ if (dtype == DataType::Int(32)) {
+ return ncclInt32;
+ }
+ if (dtype == DataType::UInt(32)) {
+ return ncclUint32;
+ }
+ if (dtype == DataType::Int(64)) {
+ return ncclInt64;
+ }
+ if (dtype == DataType::UInt(64)) {
+ return ncclUint64;
+ }
+ if (dtype == DataType::Float(16)) {
+ return ncclFloat16;
+ }
+ if (dtype == DataType::Float(32)) {
+ return ncclFloat32;
+ }
+ if (dtype == DataType::Float(64)) {
+ return ncclFloat64;
+ }
+ if (dtype == DataType::BFloat(16)) {
+ return ncclBfloat16;
+ }
+ LOG(FATAL) << "ValueError: Unsupported data type " << dtype;
+ throw;
+}
-namespace tvm {
-namespace runtime {
-namespace ccl {
+inline ncclRedOp_t AsNCCLRedOp(ReduceKind kind) {
+ switch (kind) {
+ case ReduceKind::kSum:
+ return ncclSum;
+ case ReduceKind::kProd:
+ return ncclProd;
+ case ReduceKind::kMin:
+ return ncclMin;
+ case ReduceKind::kMax:
+ return ncclMax;
+ case ReduceKind::kAvg:
+ return ncclAvg;
+ }
+ LOG(FATAL) << "ValueError: Unknown ReduceKind: " << static_cast<int>(kind);
+ throw;
+}
struct CCLThreadLocalContext {
DiscoWorker* worker;
int device_id;
- runtimeStream_t default_stream;
+ deviceStream_t default_stream;
ncclComm_t comm;
void Clear() {
- NCCL_CALL(TVM_DISCO_CCL_DESTROY(comm));
- TVM_DISCO_DEVICE_CALL(TVM_DISCO_DEVICE_STREAM_DESTROY(default_stream));
+ NCCL_CALL(ncclCommDestroy(comm));
+ StreamDestroy(default_stream);
}
- runtimeStream_t GetDefaultStream() {
+ deviceStream_t GetDefaultStream() {
const auto* func = tvm::runtime::Registry::Get("runtime.get_"
TVM_DISCO_DEVICE_NAME "_stream");
ICHECK(func != nullptr);
- runtimeStream_t stream = static_cast<runtimeStream_t>((*func)().operator
void*());
+ deviceStream_t stream = static_cast<deviceStream_t>((*func)().operator
void*());
return stream == nullptr ? default_stream : stream;
}
@@ -118,8 +179,8 @@ void InitCCLPerWorker(ShapeTuple device_ids, std::string
unique_id_bytes) {
<< unique_id_bytes.size() << ".";
// Step up local context of NCCL
int device_id = device_ids[worker->worker_id];
- TVM_DISCO_DEVICE_CALL(TVM_DISCO_DEVICE_SET_DEVICE(device_id));
- TVM_DISCO_DEVICE_CALL(TVM_DISCO_DEVICE_STREAM_CREATE(&ctx->default_stream));
+ SetDevice(device_id);
+ StreamCreate(&ctx->default_stream);
Device device{TVM_DISCO_DEVICE_TYPE, device_id};
worker->default_device = device;
worker->ccl = TVM_DISCO_CCL_NAME;
@@ -135,7 +196,7 @@ void AllReduce(NDArray send, ReduceKind reduce_kind,
NDArray recv) {
CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
ShapeTuple shape = send.Shape();
int64_t numel = shape->Product();
- runtimeStream_t stream = ctx->GetDefaultStream();
+ deviceStream_t stream = ctx->GetDefaultStream();
NCCL_CALL(ncclAllReduce(send->data, recv->data, numel,
/*datatype=*/AsNCCLDataType(DataType(send->dtype)),
/*op=*/AsNCCLRedOp(reduce_kind), ctx->comm, stream));
@@ -146,7 +207,7 @@ void BroadcastFromWorker0(NDArray send, NDArray recv) {
ICHECK(send.Shape()->Product() == recv.Shape()->Product());
ShapeTuple shape = send.Shape();
int64_t numel = shape->Product();
- runtimeStream_t stream = ctx->GetDefaultStream();
+ deviceStream_t stream = ctx->GetDefaultStream();
NCCL_CALL(ncclBroadcast(send->data, recv->data, numel,
/*datatype=*/AsNCCLDataType(DataType(send->dtype)),
/*root=*/0, ctx->comm, stream));
@@ -157,20 +218,21 @@ void ScatterFromWorker0(Optional<NDArray> send, NDArray
recv) {
CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
int worker_id = ctx->worker->worker_id;
int num_workers = ctx->worker->num_workers;
- runtimeStream_t stream = ctx->GetDefaultStream();
+ deviceStream_t stream = ctx->GetDefaultStream();
if (worker_id == 0) {
CHECK(send.defined()) << "ValueError: buffer `send` must be provided when
worker_id == 0.";
NDArray buffer = send.value();
int64_t numel = buffer.Shape()->Product();
- CHECK_EQ(numel % num_workers, 0)
- << "ValueError: Scattering evenly requires that the number of elements
in the buffer to be "
- "divisible by the number of workers, but got numel = "
- << numel << " and " << num_workers << " workers.";
+ CHECK_EQ(numel % num_workers, 0) << "ValueError: Scattering evenly
requires that the number "
+ "of elements in the buffer to be "
+ "divisible by the number of workers,
but got numel = "
+ << numel << " and " << num_workers << "
workers.";
DataType dtype(buffer->dtype);
int64_t numel_per_shard = numel / num_workers;
int64_t bytes_per_shard = numel_per_shard * dtype.bytes();
CHECK_EQ(numel_per_shard, recv.Shape()->Product())
- << "ValueError: The number of elements in buffer `recv` must be the
same as each shard of "
+ << "ValueError: The number of elements in buffer `recv` must be the
same as each shard "
+ "of "
"buffer `send`. `send.size` is "
<< numel << ", but `recv.size` is " << recv.Shape()->Product() << ".";
NCCL_CALL(ncclGroupStart());
@@ -198,20 +260,21 @@ void GatherToWorker0(NDArray send, Optional<NDArray>
recv) {
CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
int worker_id = ctx->worker->worker_id;
int num_workers = ctx->worker->num_workers;
- runtimeStream_t stream = ctx->GetDefaultStream();
+ deviceStream_t stream = ctx->GetDefaultStream();
if (worker_id == 0) {
CHECK(recv.defined()) << "ValueError: buffer `recv` must be provided when
worker_id == 0.";
NDArray buffer = recv.value();
int64_t numel = buffer.Shape()->Product();
- CHECK_EQ(numel % num_workers, 0)
- << "ValueError: Gathering evenly requires that the number of elements
in the buffer to be "
- "divisible by the number of workers, but got numel = "
- << numel << " and " << num_workers << " workers.";
+ CHECK_EQ(numel % num_workers, 0) << "ValueError: Gathering evenly requires
that the number "
+ "of elements in the buffer to be "
+ "divisible by the number of workers,
but got numel = "
+ << numel << " and " << num_workers << "
workers.";
DataType dtype(buffer->dtype);
int64_t numel_per_shard = numel / num_workers;
int64_t bytes_per_shard = numel_per_shard * dtype.bytes();
CHECK_EQ(numel_per_shard, send.Shape()->Product())
- << "ValueError: The number of elements in buffer `send` must be the
same as each shard of "
+ << "ValueError: The number of elements in buffer `send` must be the
same as each shard "
+ "of "
"buffer `recv`. `recv.size` is "
<< numel << ", but `send.size` is " << send.Shape()->Product() << ".";
NCCL_CALL(ncclGroupStart());
@@ -236,7 +299,7 @@ void GatherToWorker0(NDArray send, Optional<NDArray> recv) {
void RecvFromWorker0(NDArray buffer) {
CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
- runtimeStream_t stream = ctx->GetDefaultStream();
+ deviceStream_t stream = ctx->GetDefaultStream();
CHECK_NE(ctx->worker->worker_id, 0)
<< "ValueError: Worker 0 is not allowed to call RecvFromWorker0.";
NCCL_CALL(ncclGroupStart());
@@ -248,8 +311,8 @@ void RecvFromWorker0(NDArray buffer) {
void SyncWorker() {
CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
ICHECK(ctx->worker != nullptr);
- runtimeStream_t stream = ctx->GetDefaultStream();
- TVM_DISCO_DEVICE_CALL(TVM_DISCO_DEVICE_STREAM_SYNC(stream));
+ deviceStream_t stream = ctx->GetDefaultStream();
+ StreamSynchronize(stream);
}
TVM_REGISTER_GLOBAL("runtime.disco.compiled_ccl").set_body_typed([]() ->
String {
@@ -273,6 +336,6 @@ TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME
".recv_from_worker0")
.set_body_typed(RecvFromWorker0);
TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME
".sync_worker").set_body_typed(SyncWorker);
-} // namespace ccl
+} // namespace nccl
} // namespace runtime
} // namespace tvm