This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new cb3763b84d [Disco] Improve RCCL Integration Logic (#15796)
cb3763b84d is described below

commit cb3763b84d6b17448e76541f622275e11f51ddf0
Author: Junru Shao <[email protected]>
AuthorDate: Wed Sep 20 22:23:39 2023 -0700

    [Disco] Improve RCCL Integration Logic (#15796)
    
    This PR makes the following changes:
    - Move `runtime/disco/ccl` to `runtime/disco/nccl` as RCCL itself uses
      NCCL as a prefix to its APIs;
    - Consolidate `utils.h` to `nccl.cc` to avoid potential misuse of the
      private header file;
    - Reduce the number of macros used in `nccl.cc`.
---
 CMakeLists.txt                                 |  18 +--
 src/runtime/disco/ccl/utils.h                  |  94 ------------
 src/runtime/disco/{ccl/ccl.cc => nccl/nccl.cc} | 189 ++++++++++++++++---------
 3 files changed, 135 insertions(+), 166 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index bc09655923..eb622f3452 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -444,14 +444,16 @@ endif(USE_PROFILER)
 if(USE_CUDA AND USE_NCCL)
   message(STATUS "Build with NCCL...")
   find_nccl(${USE_NCCL})
-  tvm_file_glob(GLOB RUNTIME_NCCL_SRC src/runtime/disco/ccl/*.cc)
+  tvm_file_glob(GLOB RUNTIME_NCCL_SRC src/runtime/disco/nccl/*.cc)
+  set_source_files_properties(src/runtime/disco/nccl/nccl.cc PROPERTIES COMPILE_DEFINITIONS "TVM_NCCL_RCCL_SWITCH=0")
   list(APPEND RUNTIME_SRCS ${RUNTIME_NCCL_SRC})
 endif()
 
 if(USE_ROCM AND USE_RCCL)
   message(STATUS "Build with RCCL...")
   find_rccl(${USE_RCCL})
-  tvm_file_glob(GLOB RUNTIME_RCCL_SRC src/runtime/disco/ccl/*.cc)
+  tvm_file_glob(GLOB RUNTIME_RCCL_SRC src/runtime/disco/nccl/*.cc)
+  set_source_files_properties(src/runtime/disco/nccl/nccl.cc PROPERTIES COMPILE_DEFINITIONS "TVM_NCCL_RCCL_SWITCH=1")
   list(APPEND RUNTIME_SRCS ${RUNTIME_RCCL_SRC})
 endif()
 
@@ -888,18 +890,16 @@ if(USE_CUDA AND USE_CUTLASS)
   target_link_libraries(tvm_runtime PRIVATE -Wl,--no-as-needed flash_attn)
 endif()
 
-if(USE_CUDA AND USE_NCCL)
-  target_link_libraries(tvm_runtime PRIVATE nccl)
-  target_link_libraries(tvm PRIVATE nccl)
-  set_source_files_properties(src/runtime/disco/ccl/ccl.cc PROPERTIES COMPILE_DEFINITIONS "TVM_NCCL_RCCL_SWITCH=0")
-endif()
-
 if(USE_CUDA AND USE_NVTX)
   set_source_files_properties(src/runtime/nvtx.cc PROPERTIES COMPILE_DEFINITIONS "TVM_NVTX_ENABLED=1")
 endif()
 
+if(USE_CUDA AND USE_NCCL)
+  target_link_libraries(tvm PRIVATE nccl)
+  target_link_libraries(tvm_runtime PRIVATE nccl)
+endif()
+
 if(USE_ROCM AND USE_RCCL)
   target_link_libraries(tvm PRIVATE rccl)
   target_link_libraries(tvm_runtime PRIVATE rccl)
-  set_source_files_properties(src/runtime/disco/ccl/ccl.cc PROPERTIES COMPILE_DEFINITIONS "TVM_NCCL_RCCL_SWITCH=1")
 endif()
diff --git a/src/runtime/disco/ccl/utils.h b/src/runtime/disco/ccl/utils.h
deleted file mode 100644
index c5066796c0..0000000000
--- a/src/runtime/disco/ccl/utils.h
+++ /dev/null
@@ -1,94 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-#ifndef TVM_RUNTIME_DISCO_CCL_UTILS_H_
-#define TVM_RUNTIME_DISCO_CCL_UTILS_H_
-
-#include <tvm/runtime/data_type.h>
-#include <tvm/runtime/disco/session.h>
-
-#include "../utils.h"
-
-namespace tvm {
-namespace runtime {
-namespace ccl {
-
-#define NCCL_CALL(cmd)                                                      \
-  do {                                                                      \
-    ncclResult_t r = cmd;                                                   \
-    if (r != ncclSuccess) {                                                 \
-      LOG(FATAL) << TVM_DISCO_CCL_NAME "Errror: " << ncclGetErrorString(r); \
-    }                                                                       \
-  } while (0)
-
-inline ncclDataType_t AsNCCLDataType(runtime::DataType dtype) {
-  if (dtype == DataType::Int(8)) {
-    return ncclInt8;
-  }
-  if (dtype == DataType::UInt(8)) {
-    return ncclUint8;
-  }
-  if (dtype == DataType::Int(32)) {
-    return ncclInt32;
-  }
-  if (dtype == DataType::UInt(32)) {
-    return ncclUint32;
-  }
-  if (dtype == DataType::Int(64)) {
-    return ncclInt64;
-  }
-  if (dtype == DataType::UInt(64)) {
-    return ncclUint64;
-  }
-  if (dtype == DataType::Float(16)) {
-    return ncclFloat16;
-  }
-  if (dtype == DataType::Float(32)) {
-    return ncclFloat32;
-  }
-  if (dtype == DataType::Float(64)) {
-    return ncclFloat64;
-  }
-  if (dtype == DataType::BFloat(16)) {
-    return ncclBfloat16;
-  }
-  LOG(FATAL) << "ValueError: Unsupported data type " << dtype;
-  throw;
-}
-
-inline ncclRedOp_t AsNCCLRedOp(ReduceKind kind) {
-  switch (kind) {
-    case ReduceKind::kSum:
-      return ncclSum;
-    case ReduceKind::kProd:
-      return ncclProd;
-    case ReduceKind::kMin:
-      return ncclMin;
-    case ReduceKind::kMax:
-      return ncclMax;
-    case ReduceKind::kAvg:
-      return ncclAvg;
-  }
-  LOG(FATAL) << "ValueError: Unknown ReduceKind: " << static_cast<int>(kind);
-  throw;
-}
-
-}  // namespace ccl
-}  // namespace runtime
-}  // namespace tvm
-#endif  // TVM_RUNTIME_DISCO_CCL_UTILS_H_
diff --git a/src/runtime/disco/ccl/ccl.cc b/src/runtime/disco/nccl/nccl.cc
similarity index 68%
rename from src/runtime/disco/ccl/ccl.cc
rename to src/runtime/disco/nccl/nccl.cc
index 8e656f9ab2..8ce6c67d19 100644
--- a/src/runtime/disco/ccl/ccl.cc
+++ b/src/runtime/disco/nccl/nccl.cc
@@ -16,79 +16,140 @@
  * specific language governing permissions and limitations
  * under the License.
  */
+
+#include <dlpack/dlpack.h>
+#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/disco/session.h>
+#include <tvm/runtime/registry.h>
+
+#include <cstring>
+#include <mutex>
+#include <sstream>
+#include <vector>
+
+#include "../../../support/process_id.h"
+#include "../utils.h"
+
+/* `TVM_NCCL_RCCL_SWITCH` is set to 0 for NCCL, 1 for RCCL */
 #ifndef TVM_NCCL_RCCL_SWITCH
-#define TVM_NCCL_RCCL_SWITCH 0  // 0: NCCL, 1: RCCL
+#define TVM_NCCL_RCCL_SWITCH 0
 #endif
-
 #if TVM_NCCL_RCCL_SWITCH == 0
-#include <cuda_runtime_api.h>
 #include <nccl.h>
 
 #include "../../cuda/cuda_common.h"
+#else
+#include <rccl/rccl.h>
+
+#include "../../rocm/rocm_common.h"
+#endif
+
+namespace tvm {
+namespace runtime {
+namespace nccl {
 
-using runtimeStream_t = cudaStream_t;
+#define NCCL_CALL(cmd)                                                      \
+  do {                                                                      \
+    auto r = (cmd);                                                         \
+    if (r != ncclSuccess) {                                                 \
+      LOG(FATAL) << TVM_DISCO_CCL_NAME "Errror: " << ncclGetErrorString(r); \
+    }                                                                       \
+  } while (0)
+
+#if TVM_NCCL_RCCL_SWITCH == 0
 
-#define TVM_DISCO_DEVICE_CALL CUDA_CALL
-#define TVM_DISCO_DEVICE_SET_DEVICE cudaSetDevice
-#define TVM_DISCO_DEVICE_STREAM_CREATE cudaStreamCreate
-#define TVM_DISCO_DEVICE_STREAM_SYNC cudaStreamSynchronize
-#define TVM_DISCO_DEVICE_STREAM_DESTROY cudaStreamDestroy
 #define TVM_DISCO_DEVICE_NAME "cuda"
-#define TVM_DISCO_CCL_DESTROY ncclCommDestroy
 #define TVM_DISCO_CCL_NAME "nccl"
-const constexpr DLDeviceType TVM_DISCO_DEVICE_TYPE = DLDeviceType::kDLCUDA;
-#else
-#include <hip/hip_runtime_api.h>
-#include <hip/hip_version.h>
-#include <rccl/rccl.h>
 
-#include "../../rocm/rocm_common.h"
+using deviceStream_t = cudaStream_t;
+const constexpr DLDeviceType TVM_DISCO_DEVICE_TYPE = DLDeviceType::kDLCUDA;
+inline void SetDevice(int device_id) { CUDA_CALL(cudaSetDevice(device_id)); }
+inline void StreamSynchronize(deviceStream_t stream) { CUDA_CALL(cudaStreamSynchronize(stream)); }
+inline void StreamCreate(deviceStream_t* stream) { CUDA_CALL(cudaStreamCreate(stream)); }
+inline void StreamDestroy(deviceStream_t stream) { CUDA_CALL(cudaStreamDestroy(stream)); }
+inline void SetDevice(int device_id) { CUDA_CALL(cudaSetDevice(device_id)); }
 
-using runtimeStream_t = hipStream_t;
+#else
 
-#define TVM_DISCO_DEVICE_CALL ROCM_CALL
-#define TVM_DISCO_DEVICE_SET_DEVICE hipSetDevice
-#define TVM_DISCO_DEVICE_STREAM_CREATE hipStreamCreate
-#define TVM_DISCO_DEVICE_STREAM_SYNC hipStreamSynchronize
-#define TVM_DISCO_DEVICE_STREAM_DESTROY hipStreamDestroy
 #define TVM_DISCO_DEVICE_NAME "rocm"
-#define TVM_DISCO_CCL_DESTROY ncclCommDestroy
 #define TVM_DISCO_CCL_NAME "rccl"
-const constexpr DLDeviceType TVM_DISCO_DEVICE_TYPE = DLDeviceType::kDLROCM;
-#endif
 
-#include <dlpack/dlpack.h>
-#include <tvm/runtime/c_runtime_api.h>
-#include <tvm/runtime/disco/session.h>
-#include <tvm/runtime/registry.h>
+using deviceStream_t = hipStream_t;
+const constexpr DLDeviceType TVM_DISCO_DEVICE_TYPE = DLDeviceType::kDLROCM;
+inline void SetDevice(int device_id) { ROCM_CALL(hipSetDevice(device_id)); }
+inline void StreamSynchronize(deviceStream_t stream) { ROCM_CALL(hipStreamSynchronize(stream)); }
+inline void StreamCreate(deviceStream_t* stream) { ROCM_CALL(hipStreamCreate(stream)); }
+inline void StreamDestroy(deviceStream_t stream) { ROCM_CALL(hipStreamDestroy(stream)); }
 
-#include <cstring>
-#include <mutex>
-#include <sstream>
-#include <vector>
+#endif
 
-#include "../../../support/process_id.h"
-#include "./utils.h"
+inline ncclDataType_t AsNCCLDataType(runtime::DataType dtype) {
+  if (dtype == DataType::Int(8)) {
+    return ncclInt8;
+  }
+  if (dtype == DataType::UInt(8)) {
+    return ncclUint8;
+  }
+  if (dtype == DataType::Int(32)) {
+    return ncclInt32;
+  }
+  if (dtype == DataType::UInt(32)) {
+    return ncclUint32;
+  }
+  if (dtype == DataType::Int(64)) {
+    return ncclInt64;
+  }
+  if (dtype == DataType::UInt(64)) {
+    return ncclUint64;
+  }
+  if (dtype == DataType::Float(16)) {
+    return ncclFloat16;
+  }
+  if (dtype == DataType::Float(32)) {
+    return ncclFloat32;
+  }
+  if (dtype == DataType::Float(64)) {
+    return ncclFloat64;
+  }
+  if (dtype == DataType::BFloat(16)) {
+    return ncclBfloat16;
+  }
+  LOG(FATAL) << "ValueError: Unsupported data type " << dtype;
+  throw;
+}
 
-namespace tvm {
-namespace runtime {
-namespace ccl {
+inline ncclRedOp_t AsNCCLRedOp(ReduceKind kind) {
+  switch (kind) {
+    case ReduceKind::kSum:
+      return ncclSum;
+    case ReduceKind::kProd:
+      return ncclProd;
+    case ReduceKind::kMin:
+      return ncclMin;
+    case ReduceKind::kMax:
+      return ncclMax;
+    case ReduceKind::kAvg:
+      return ncclAvg;
+  }
+  LOG(FATAL) << "ValueError: Unknown ReduceKind: " << static_cast<int>(kind);
+  throw;
+}
 
 struct CCLThreadLocalContext {
   DiscoWorker* worker;
   int device_id;
-  runtimeStream_t default_stream;
+  deviceStream_t default_stream;
   ncclComm_t comm;
 
   void Clear() {
-    NCCL_CALL(TVM_DISCO_CCL_DESTROY(comm));
-    TVM_DISCO_DEVICE_CALL(TVM_DISCO_DEVICE_STREAM_DESTROY(default_stream));
+    NCCL_CALL(ncclCommDestroy(comm));
+    StreamDestroy(default_stream);
   }
 
-  runtimeStream_t GetDefaultStream() {
+  deviceStream_t GetDefaultStream() {
     const auto* func = tvm::runtime::Registry::Get("runtime.get_" TVM_DISCO_DEVICE_NAME "_stream");
     ICHECK(func != nullptr);
-    runtimeStream_t stream = static_cast<runtimeStream_t>((*func)().operator void*());
+    deviceStream_t stream = static_cast<deviceStream_t>((*func)().operator void*());
     return stream == nullptr ? default_stream : stream;
   }
 
@@ -118,8 +179,8 @@ void InitCCLPerWorker(ShapeTuple device_ids, std::string unique_id_bytes) {
       << unique_id_bytes.size() << ".";
   // Step up local context of NCCL
   int device_id = device_ids[worker->worker_id];
-  TVM_DISCO_DEVICE_CALL(TVM_DISCO_DEVICE_SET_DEVICE(device_id));
-  TVM_DISCO_DEVICE_CALL(TVM_DISCO_DEVICE_STREAM_CREATE(&ctx->default_stream));
+  SetDevice(device_id);
+  StreamCreate(&ctx->default_stream);
   Device device{TVM_DISCO_DEVICE_TYPE, device_id};
   worker->default_device = device;
   worker->ccl = TVM_DISCO_CCL_NAME;
@@ -135,7 +196,7 @@ void AllReduce(NDArray send, ReduceKind reduce_kind, NDArray recv) {
   CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
   ShapeTuple shape = send.Shape();
   int64_t numel = shape->Product();
-  runtimeStream_t stream = ctx->GetDefaultStream();
+  deviceStream_t stream = ctx->GetDefaultStream();
   NCCL_CALL(ncclAllReduce(send->data, recv->data, numel,
                           /*datatype=*/AsNCCLDataType(DataType(send->dtype)),
                           /*op=*/AsNCCLRedOp(reduce_kind), ctx->comm, stream));
@@ -146,7 +207,7 @@ void BroadcastFromWorker0(NDArray send, NDArray recv) {
   ICHECK(send.Shape()->Product() == recv.Shape()->Product());
   ShapeTuple shape = send.Shape();
   int64_t numel = shape->Product();
-  runtimeStream_t stream = ctx->GetDefaultStream();
+  deviceStream_t stream = ctx->GetDefaultStream();
   NCCL_CALL(ncclBroadcast(send->data, recv->data, numel,
                           /*datatype=*/AsNCCLDataType(DataType(send->dtype)),
                           /*root=*/0, ctx->comm, stream));
@@ -157,20 +218,21 @@ void ScatterFromWorker0(Optional<NDArray> send, NDArray recv) {
   CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
   int worker_id = ctx->worker->worker_id;
   int num_workers = ctx->worker->num_workers;
-  runtimeStream_t stream = ctx->GetDefaultStream();
+  deviceStream_t stream = ctx->GetDefaultStream();
   if (worker_id == 0) {
     CHECK(send.defined()) << "ValueError: buffer `send` must be provided when worker_id == 0.";
     NDArray buffer = send.value();
     int64_t numel = buffer.Shape()->Product();
-    CHECK_EQ(numel % num_workers, 0)
-        << "ValueError: Scattering evenly requires that the number of elements in the buffer to be "
-           "divisible by the number of workers, but got numel = "
-        << numel << " and " << num_workers << " workers.";
+    CHECK_EQ(numel % num_workers, 0) << "ValueError: Scattering evenly requires that the number "
+                                        "of elements in the buffer to be "
+                                        "divisible by the number of workers, but got numel = "
+                                     << numel << " and " << num_workers << " workers.";
     DataType dtype(buffer->dtype);
     int64_t numel_per_shard = numel / num_workers;
     int64_t bytes_per_shard = numel_per_shard * dtype.bytes();
     CHECK_EQ(numel_per_shard, recv.Shape()->Product())
-        << "ValueError: The number of elements in buffer `recv` must be the same as each shard of "
+        << "ValueError: The number of elements in buffer `recv` must be the same as each shard "
+           "of "
            "buffer `send`. `send.size` is "
         << numel << ", but `recv.size` is " << recv.Shape()->Product() << ".";
     NCCL_CALL(ncclGroupStart());
@@ -198,20 +260,21 @@ void GatherToWorker0(NDArray send, Optional<NDArray> recv) {
   CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
   int worker_id = ctx->worker->worker_id;
   int num_workers = ctx->worker->num_workers;
-  runtimeStream_t stream = ctx->GetDefaultStream();
+  deviceStream_t stream = ctx->GetDefaultStream();
   if (worker_id == 0) {
     CHECK(recv.defined()) << "ValueError: buffer `recv` must be provided when worker_id == 0.";
     NDArray buffer = recv.value();
     int64_t numel = buffer.Shape()->Product();
-    CHECK_EQ(numel % num_workers, 0)
-        << "ValueError: Gathering evenly requires that the number of elements in the buffer to be "
-           "divisible by the number of workers, but got numel = "
-        << numel << " and " << num_workers << " workers.";
+    CHECK_EQ(numel % num_workers, 0) << "ValueError: Gathering evenly requires that the number "
+                                        "of elements in the buffer to be "
+                                        "divisible by the number of workers, but got numel = "
+                                     << numel << " and " << num_workers << " workers.";
     DataType dtype(buffer->dtype);
     int64_t numel_per_shard = numel / num_workers;
     int64_t bytes_per_shard = numel_per_shard * dtype.bytes();
     CHECK_EQ(numel_per_shard, send.Shape()->Product())
-        << "ValueError: The number of elements in buffer `send` must be the same as each shard of "
+        << "ValueError: The number of elements in buffer `send` must be the same as each shard "
+           "of "
            "buffer `recv`. `recv.size` is "
         << numel << ", but `send.size` is " << send.Shape()->Product() << ".";
     NCCL_CALL(ncclGroupStart());
@@ -236,7 +299,7 @@ void GatherToWorker0(NDArray send, Optional<NDArray> recv) {
 
 void RecvFromWorker0(NDArray buffer) {
   CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
-  runtimeStream_t stream = ctx->GetDefaultStream();
+  deviceStream_t stream = ctx->GetDefaultStream();
   CHECK_NE(ctx->worker->worker_id, 0)
       << "ValueError: Worker 0 is not allowed to call RecvFromWorker0.";
   NCCL_CALL(ncclGroupStart());
@@ -248,8 +311,8 @@ void RecvFromWorker0(NDArray buffer) {
 void SyncWorker() {
   CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
   ICHECK(ctx->worker != nullptr);
-  runtimeStream_t stream = ctx->GetDefaultStream();
-  TVM_DISCO_DEVICE_CALL(TVM_DISCO_DEVICE_STREAM_SYNC(stream));
+  deviceStream_t stream = ctx->GetDefaultStream();
+  StreamSynchronize(stream);
 }
 
 TVM_REGISTER_GLOBAL("runtime.disco.compiled_ccl").set_body_typed([]() -> String {
@@ -273,6 +336,6 @@ TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".recv_from_worker0")
     .set_body_typed(RecvFromWorker0);
 TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".sync_worker").set_body_typed(SyncWorker);
 
-}  // namespace ccl
+}  // namespace nccl
 }  // namespace runtime
 }  // namespace tvm

Reply via email to