This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e257fb8a41 [Runtime] CUDA IPC Memory support and custom allreduce
kernels (#16750)
e257fb8a41 is described below
commit e257fb8a41159a2558dc1fccb5e3dd3c45001820
Author: Ruihang Lai <[email protected]>
AuthorDate: Wed Mar 20 19:29:27 2024 -0400
[Runtime] CUDA IPC Memory support and custom allreduce kernels (#16750)
This PR introduces CUDA IPC memory support in the TVM runtime.
IPC memory allows multiple distributed workers to directly access
each other's GPU memory. This functionality is helpful for
implementing customized communication primitives across distributed
workers.
In this PR, we bring the customized all-reduce implementation
from TensorRT-LLM into 3rdparty. This all-reduce implementation
makes use of the CUDA IPC memory. We expose the all-reduce function
in global function under namespace `tvm::runtime::disco::cuda_ipc`.
One unit test for the customized all-reduce kernel over two workers
is added.
---
Co-authored-by: Hongyi Jin <[email protected]>
---
3rdparty/tensorrt_llm/custom_allreduce_kernels.cu | 400 ++++++++++++++++++++++
3rdparty/tensorrt_llm/custom_allreduce_kernels.h | 48 +++
CMakeLists.txt | 2 +-
LICENSE | 1 +
include/tvm/runtime/disco/cuda_ipc_memory.h | 102 ++++++
include/tvm/runtime/memory/memory_manager.h | 13 +-
LICENSE => licenses/LICENSE.tensorrt_llm.txt | 54 +--
python/tvm/runtime/disco/session.py | 13 +-
src/runtime/disco/cuda_ipc/cuda_ipc_memory.cc | 227 ++++++++++++
src/runtime/disco/cuda_ipc/custom_allreduce.cc | 112 ++++++
src/runtime/disco/nccl/nccl.cc | 117 +------
src/runtime/disco/nccl/nccl_context.h | 147 ++++++++
src/runtime/memory/memory_manager.cc | 9 +-
src/runtime/memory/naive_allocator.h | 2 +-
src/runtime/memory/pooled_allocator.h | 25 +-
src/runtime/relax_vm/builtin.cc | 1 +
src/runtime/vm/vm.cc | 2 +
tests/python/disco/test_custom_allreduce.py | 78 +++++
18 files changed, 1168 insertions(+), 185 deletions(-)
diff --git a/3rdparty/tensorrt_llm/custom_allreduce_kernels.cu
b/3rdparty/tensorrt_llm/custom_allreduce_kernels.cu
new file mode 100644
index 0000000000..6dec368b43
--- /dev/null
+++ b/3rdparty/tensorrt_llm/custom_allreduce_kernels.cu
@@ -0,0 +1,400 @@
+/*
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <cuda_fp16.h>
+#include <dmlc/logging.h>
+#include <stdint.h>
+
+#include "custom_allreduce_kernels.h"
+
+namespace tensorrt_llm {
+
+// Store `flag` to `flag_addr` with release semantics at system scope (visible across GPUs and
+// processes). A reader that loads the flag with acquire semantics (ld_flag_acquire) and sees
+// the value also sees this thread's memory writes issued before the store.
+// On sm_70+ this is a single `st.global.release.sys`; on older architectures a system-wide
+// fence (__threadfence_system) followed by a volatile store is used instead.
+static inline __device__ void st_flag_release(uint32_t& flag, uint32_t* flag_addr) {
+#if __CUDA_ARCH__ >= 700
+  asm volatile("st.global.release.sys.b32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
+#else
+  __threadfence_system();
+  asm volatile("st.global.volatile.b32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
+#endif
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Load the value at `flag_addr` into `flag` with acquire semantics at system scope; pairs with
+// st_flag_release so that reads issued after this load observe the writer's earlier writes.
+// On architectures below sm_70 a plain volatile load is used.
+// NOTE(review): unlike st_flag_release, the pre-sm_70 path has no fence here; confirm the
+// ordering is sufficient on those architectures.
+static inline __device__ void ld_flag_acquire(uint32_t& flag, uint32_t* flag_addr) {
+#if __CUDA_ARCH__ >= 700
+  asm volatile("ld.global.acquire.sys.b32 %0, [%1];" : "=r"(flag) : "l"(flag_addr));
+#else
+  asm volatile("ld.global.volatile.b32 %0, [%1];" : "=r"(flag) : "l"(flag_addr));
+#endif
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Type converters that view 128 bits (one int4, i.e. one 16-byte load/store) as four lanes of
+// the element type, so the kernels can move 16 bytes at a time and then add lane by lane
+// (see add128b).
+//
+using PackedFloat = union {
+  int4 packed;
+  float unpacked[4];
+};
+
+using PackedHalf = union {
+  int4 packed;
+  half2 unpacked[4];
+};
+
+// Maps an element type to its 16-byte packed union. Only float, half and (when ENABLE_BF16 is
+// defined) __nv_bfloat16 are specialized; other types have no nested `Type`.
+template <typename T>
+struct PackedOn16Bytes {};
+
+template <>
+struct PackedOn16Bytes<float> {
+  using Type = PackedFloat;
+};
+
+template <>
+struct PackedOn16Bytes<half> {
+  using Type = PackedHalf;
+};
+
+#ifdef ENABLE_BF16
+using PackedBFloat16 = union {
+  int4 packed;
+  __nv_bfloat162 unpacked[4];
+};
+
+template <>
+struct PackedOn16Bytes<__nv_bfloat16> {
+  using Type = PackedBFloat16;
+};
+#endif
+
+// Lane-wise sum of two 16-byte packed values.
+// T is one of the Packed* unions: each of its four `unpacked` lanes (float, half2 or
+// __nv_bfloat162) is added independently, and the result is returned as the packed int4.
+template <typename T>
+inline __device__ int4 add128b(T& a, T& b) {
+  T result;
+#pragma unroll
+  for (int lane = 0; lane < 4; ++lane) {
+    result.unpacked[lane] = a.unpacked[lane] + b.unpacked[lane];
+  }
+  return result.packed;
+}
+
+// Flag barrier across the ranks of one node.
+//
+// `signals[r]` is rank r's array of per-rank flag slots; `flag` is the value that marks arrival
+// for the current call. Only threads with tidx < world_size take part:
+//   - in block 0, thread tidx writes `flag` into slot `rank` of rank tidx's array;
+//   - in every block, thread tidx spins until slot tidx of this rank's own array equals `flag`,
+//     i.e. until block 0 of rank tidx has arrived.
+// The trailing __syncthreads() then releases the rest of the block.
+// Requires blockDim.x >= world_size; otherwise some ranks are never signalled or waited on.
+// NOTE(review): the notifying store is a plain store with no release fence, so this barrier
+// orders arrival only; confirm no caller relies on it to publish data written before it.
+__inline__ __device__ void multi_gpu_barrier(uint32_t** signals, const uint32_t flag,
+                                             const size_t rank, const size_t world_size,
+                                             int const tidx, int const bidx) {
+  // At the end of the function, we know that at least block 0 from all other GPUs has reached
+  // that point.
+  uint32_t volatile* my_signals = signals[rank];
+  if (tidx < world_size) {
+    // The 1st block notifies the other ranks.
+    if (bidx == 0) {
+      signals[tidx][rank] = flag;
+    }
+
+    // Busy-wait until all ranks are ready.
+    while (my_signals[tidx] != flag) {
+    }
+  }
+
+  // Make sure we can move on...
+  __syncthreads();
+}
+
+// Standalone barrier kernel: runs multi_gpu_barrier on the "out" barrier buffers of `params`.
+// invokeMultiGpuBarrier launches it with one block of `ranks_per_node` threads.
+__global__ void multiGpuBarrierKernel(AllReduceParams params) {
+  const size_t my_rank = params.local_rank;
+  const size_t num_ranks = params.ranks_per_node;
+  multi_gpu_barrier(params.peer_barrier_ptrs_out, params.barrier_flag, my_rank, num_ranks,
+                    threadIdx.x, blockIdx.x);
+}
+
+// One-shot all-reduce: every rank reads the whole input from the comm buffers of all ranks,
+// sums it and writes the result to params.local_output_buffer_ptr.
+// Launch layout (from kernelLaunchConfig): block b handles
+// [b * elts_per_block, min((b + 1) * elts_per_block, elts_per_rank)), and each thread moves
+// NUM_ELTS elements (16 bytes) per iteration, striding by blockDim.x * NUM_ELTS.
+// Requires blockDim.x >= RANKS_PER_NODE for the entry barrier.
+// NOTE(review): there is no barrier after the reads, so peers may still be reading this rank's
+// comm buffer when the kernel returns; confirm callers prevent overwriting it too early.
+template <typename T, int RANKS_PER_NODE>
+static __global__ void oneShotAllReduceKernel(AllReduceParams params) {
+  int const bidx = blockIdx.x;
+  int const tidx = threadIdx.x;
+
+  // The number of elements packed into one for comms
+  static constexpr int NUM_ELTS = 16 / sizeof(T);
+
+  // Packed data type for comms
+  using PackedStruct = typename PackedOn16Bytes<T>::Type;
+
+  // Wait until block 0 of every rank has arrived before reading peer buffers.
+  multi_gpu_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank,
+                    RANKS_PER_NODE, tidx, bidx);
+
+  // The source pointers, rotated by local_rank so that src_d[0] is this rank's own buffer and
+  // src_d[ii] is the buffer of rank (local_rank + ii) % RANKS_PER_NODE.
+  T const* src_d[RANKS_PER_NODE];
+#pragma unroll
+  for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
+    int rank = (params.local_rank + ii) % RANKS_PER_NODE;
+    src_d[ii] = reinterpret_cast<T*>(params.peer_comm_buffer_ptrs[rank]);
+  }
+
+  // The location in the destination array (load 8 fp16 or load 4 fp32 using LDG.128).
+  size_t offset = bidx * params.elts_per_block + tidx * NUM_ELTS;
+  // The end of the segment computed by that block.
+  size_t max_offset = min((bidx + 1) * params.elts_per_block, params.elts_per_rank);
+
+  // Each block accumulates the values from the different GPUs on the same node.
+  for (size_t iter_offset = offset; iter_offset < max_offset;
+       iter_offset += blockDim.x * NUM_ELTS) {
+    // Iterate over the different ranks/devices on the node to load the values.
+    PackedStruct vals[RANKS_PER_NODE];
+#pragma unroll
+    for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
+      vals[ii].packed = *reinterpret_cast<int4 const*>(&src_d[ii][iter_offset]);
+    }
+
+    // Sum the values from the different ranks (all-zero bits are 0.0 for every packed type).
+    PackedStruct sums;
+    sums.packed = {0, 0, 0, 0};
+#pragma unroll
+    for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
+      sums.packed = add128b(sums, vals[ii]);
+    }
+
+    // Store to the destination buffer.
+    *reinterpret_cast<int4*>(&reinterpret_cast<T*>(params.local_output_buffer_ptr)[iter_offset]) =
+        sums.packed;
+  }
+}
+
+/*!
+ * \brief Two-shot all-reduce (reduce-scatter, then all-gather) across the ranks of one node.
+ *
+ * Each rank's comm buffer (params.peer_comm_buffer_ptrs[rank]) holds that rank's input.
+ * Phase 1 (reduce-scatter): this rank sums its slice [rank_offset, rank_offset + elts_per_rank)
+ *   over the comm buffers of all ranks and stores the sum back into its OWN comm buffer.
+ * Barrier: block b of every rank signals block b of every other rank (release/acquire flags).
+ * Phase 2 (all-gather): this rank copies every rank's reduced slice into
+ *   params.local_output_buffer_ptr.
+ *
+ * Launch layout (from kernelLaunchConfig): block b owns the chunk
+ * [b * elts_per_block, (b + 1) * elts_per_block) of every rank's slice, and each thread moves
+ * NUM_ELTS elements (16 bytes) per iteration. Requires blockDim.x >= RANKS_PER_NODE (the first
+ * RANKS_PER_NODE threads drive the barriers). kernelLaunchConfig caps gridDim.x at
+ * MAX_ALL_REDUCE_BLOCKS, which is what the barrier flag buffers are sized for.
+ */
+template <typename T, int RANKS_PER_NODE>
+static __global__ void twoShotAllReduceKernel(AllReduceParams params) {
+  // The block index.
+  int const bidx = blockIdx.x;
+  // The thread index within the block.
+  int const tidx = threadIdx.x;
+
+  // The number of elements packed into one for comms
+  static constexpr int NUM_ELTS = 16 / sizeof(T);
+
+  // Packed data type for comms
+  using PackedType = typename PackedOn16Bytes<T>::Type;
+
+  // The chunk of a rank's slice owned by this block, relative to the slice start:
+  // [chunk_begin, chunk_end). The bound is per BLOCK on purpose. A per-thread bound
+  // (thread start + elts_per_block) lets the trailing threads of block b run into block b+1's
+  // chunk whenever elts_per_block is not a multiple of blockDim.x * NUM_ELTS. Two blocks then
+  // read-modify-write the same local elements concurrently (the local contribution can be summed
+  // twice), and the per-block barrier below no longer covers everything gathered in phase 2.
+  const size_t chunk_begin = bidx * params.elts_per_block;
+  const size_t chunk_end = min(chunk_begin + params.elts_per_block, params.elts_per_rank);
+  // This thread's first element in the chunk (load 8 fp16 or load 4 fp32 using LDG.128).
+  const size_t block_offset = chunk_begin + tidx * NUM_ELTS;
+  const size_t block_start = params.rank_offset + block_offset;
+  // The end of the segment computed by that block, in buffer coordinates.
+  const size_t max_offset = params.rank_offset + chunk_end;
+
+  multi_gpu_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank,
+                    RANKS_PER_NODE, tidx, bidx);
+
+  // The source pointers, rotated so that src_d[0] is this rank's own comm buffer.
+  T* src_d[RANKS_PER_NODE];
+  // The destination ranks for round-robin gathering
+  size_t dst_rank[RANKS_PER_NODE];
+#pragma unroll
+  for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
+    int rank = (params.local_rank + ii) % RANKS_PER_NODE;
+    src_d[ii] = reinterpret_cast<T*>(params.peer_comm_buffer_ptrs[rank]);
+    dst_rank[ii] = rank;
+  }
+
+  // Phase 1: each block accumulates the values from the different GPUs on the same node.
+  for (size_t local_offset = block_start; local_offset < max_offset;
+       local_offset += blockDim.x * NUM_ELTS) {
+    // Iterate over the different ranks/devices on the node to load the values.
+    PackedType vals[RANKS_PER_NODE];
+#pragma unroll
+    for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
+      vals[ii].packed = *reinterpret_cast<int4 const*>(&src_d[ii][local_offset]);
+    }
+
+    // Sum the values from the different ranks.
+    PackedType sums;
+    sums.packed = {0, 0, 0, 0};
+#pragma unroll
+    for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
+      sums.packed = add128b(sums, vals[ii]);
+    }
+
+    // Store to the local buffer.
+    *reinterpret_cast<int4*>(&src_d[0][local_offset]) = sums.packed;
+  }
+
+  // sync threads to make sure all block threads have the sums
+  __syncthreads();
+
+  // barriers among the blocks with the same idx (release-acquire semantics)
+  if (tidx < RANKS_PER_NODE) {
+    // Every block notifies its peers: write the flag into slot [bidx][local_rank] of rank tidx.
+    uint32_t flag_block_offset = RANKS_PER_NODE + bidx * RANKS_PER_NODE;
+    st_flag_release(params.barrier_flag,
+                    params.peer_barrier_ptrs_in[tidx] + flag_block_offset + params.local_rank);
+
+    // Busy-wait until block bidx of rank tidx has signalled this rank.
+    uint32_t rank_barrier = 0;
+    uint32_t* peer_barrier_d =
+        params.peer_barrier_ptrs_in[params.local_rank] + flag_block_offset + tidx;
+    do {
+      ld_flag_acquire(rank_barrier, peer_barrier_d);
+    } while (rank_barrier != params.barrier_flag);
+  }
+
+  // sync threads to make sure all other ranks have the final partial results
+  __syncthreads();
+
+  // Phase 2: gather this block's chunk of every rank's reduced slice. Each such chunk was
+  // produced by block bidx of its owner, which the barrier above has waited for.
+  for (size_t local_offset = block_offset; local_offset < chunk_end;
+       local_offset += blockDim.x * NUM_ELTS) {
+#pragma unroll
+    for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
+      // use round-robin gathering from other ranks
+      size_t offset_rank = dst_rank[ii] * params.elts_per_rank + local_offset;
+      if (offset_rank >= params.elts_total) {
+        continue;
+      }
+      *reinterpret_cast<int4*>(&reinterpret_cast<T*>(params.local_output_buffer_ptr)[offset_rank]) =
+          *reinterpret_cast<int4*>(&src_d[ii][offset_rank]);
+    }
+  }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/*! \brief Ceiling division for non-negative sizes (size_t avoids narrowing element counts). */
+inline size_t divUp(size_t a, size_t b) { return (a + b - 1) / b; }
+
+/*!
+ * \brief Compute the launch configuration and fill in the derived fields of `param`.
+ * \param algo ONESHOT or TWOSHOT; anything else is a fatal error.
+ * \param param In: elts_total, ranks_per_node, rank. Out: elts_per_rank, elts_per_block and,
+ *        for TWOSHOT, rank_offset.
+ * \param elts_per_thread Elements moved per thread per iteration (16 bytes / sizeof(T)).
+ * \return (blocks_per_grid, threads_per_block).
+ */
+std::tuple<int, int> kernelLaunchConfig(AllReduceStrategyType algo, AllReduceParams& param,
+                                        size_t elts_per_thread) {
+  ICHECK(param.elts_total % elts_per_thread == 0);
+
+  int blocks_per_grid = 1, threads_per_block = DEFAULT_BLOCK_SIZE;
+
+  switch (algo) {
+    case AllReduceStrategyType::ONESHOT: {  // one stage all reduce algo
+      const size_t total_threads = param.elts_total / elts_per_thread;
+      if (total_threads <= DEFAULT_BLOCK_SIZE) {
+        // Small input: one block with just enough whole warps.
+        threads_per_block = static_cast<int>(WARP_SIZE * divUp(total_threads, WARP_SIZE));
+        blocks_per_grid = 1;
+      } else {
+        // Large input: full blocks, capped at MAX_ALL_REDUCE_BLOCKS; the kernel loops over the rest.
+        threads_per_block = DEFAULT_BLOCK_SIZE;
+        blocks_per_grid = static_cast<int>(
+            std::min(MAX_ALL_REDUCE_BLOCKS, divUp(total_threads, DEFAULT_BLOCK_SIZE)));
+      }
+      param.elts_per_rank = param.elts_total;
+      param.elts_per_block =
+          elts_per_thread * divUp(param.elts_per_rank, elts_per_thread * blocks_per_grid);
+      break;
+    }
+    case AllReduceStrategyType::TWOSHOT: {  // two stage all reduce algo
+      // Every rank owns an equally sized slice; a remainder would never be reduced or gathered.
+      ICHECK(param.elts_total % param.ranks_per_node == 0)
+          << "Two-shot all-reduce requires the number of elements (" << param.elts_total
+          << ") to be divisible by the number of ranks (" << param.ranks_per_node << ")";
+      const size_t elts_per_rank = param.elts_total / param.ranks_per_node;
+      ICHECK(elts_per_rank % elts_per_thread == 0);
+
+      // Threads needed to cover one slice, rounded up to whole warps.
+      const size_t total_threads = WARP_SIZE * divUp(elts_per_rank / elts_per_thread, WARP_SIZE);
+      ICHECK(total_threads % WARP_SIZE == 0);
+
+      // Smallest block count that splits the threads evenly into blocks of <= DEFAULT_BLOCK_SIZE.
+      size_t num_blocks = 1;
+      while (total_threads % num_blocks != 0 || total_threads / num_blocks > DEFAULT_BLOCK_SIZE) {
+        num_blocks += 1;
+      }
+      threads_per_block = static_cast<int>(total_threads / num_blocks);
+
+      // Too many blocks: shrink the grid by the smallest factor that brings it under
+      // MAX_ALL_REDUCE_BLOCKS; the kernel then loops within each (larger) chunk.
+      if (num_blocks > MAX_ALL_REDUCE_BLOCKS) {
+        size_t iter_factor = 1;
+        while (num_blocks / iter_factor > MAX_ALL_REDUCE_BLOCKS || num_blocks % iter_factor) {
+          iter_factor += 1;
+        }
+        num_blocks /= iter_factor;
+      }
+      blocks_per_grid = static_cast<int>(num_blocks);
+
+      param.elts_per_rank = elts_per_rank;
+      // Round UP so that blocks_per_grid chunks always cover the whole slice: the kernel bounds
+      // every block to its own chunk, so a floored chunk size would drop the tail.
+      param.elts_per_block =
+          elts_per_thread * divUp(divUp(elts_per_rank, num_blocks), elts_per_thread);
+      param.rank_offset = param.rank * param.elts_per_rank;
+      break;
+    }
+    default:
+      LOG(FATAL) << ("Algorithm not supported here.");
+  }
+
+  return std::make_tuple(blocks_per_grid, threads_per_block);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Launch the all-reduce kernel selected by `algo` for a compile-time rank count.
+// ONESHOT selects the one-shot kernel; any other value selects the two-shot kernel.
+template <typename T, int RANKS_PER_NODE>
+void dispatchARKernels(AllReduceStrategyType algo, AllReduceParams& param, int blocks_per_grid,
+                       int threads_per_block, cudaStream_t stream) {
+  void (*kernel)(AllReduceParams) = twoShotAllReduceKernel<T, RANKS_PER_NODE>;
+  if (algo == AllReduceStrategyType::ONESHOT) {
+    kernel = oneShotAllReduceKernel<T, RANKS_PER_NODE>;
+  }
+  kernel<<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
+}
+
+/*!
+ * \brief Configure and launch the one-shot or two-shot all-reduce kernel for element type T.
+ * \param param All-reduce parameters; elts_total must already be set. The launch-configuration
+ *        fields are filled in by kernelLaunchConfig.
+ * \param strat ONESHOT or TWOSHOT.
+ * \param stream The CUDA stream to launch on.
+ * \note Only 2, 4, 6 or 8 ranks per node are supported; other values are a fatal error instead
+ *       of silently launching nothing. A failed launch is also fatal, since the output buffer
+ *       would otherwise be left unreduced without notice.
+ */
+template <typename T>
+void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, AllReduceStrategyType strat,
+                                       cudaStream_t stream) {
+  ICHECK(strat == AllReduceStrategyType::ONESHOT || strat == AllReduceStrategyType::TWOSHOT);
+  // Report and clear any error left by an earlier, unrelated CUDA call so that the post-launch
+  // check below only reflects this launch.
+  auto last_error = cudaGetLastError();
+  if (last_error != cudaSuccess) {
+    LOG(INFO) << "cuda error:" << cudaGetErrorString(last_error);
+  }
+  // Nothing to reduce; the computed launch would have zero threads per block, which is an
+  // invalid configuration.
+  if (param.elts_total == 0) {
+    return;
+  }
+
+  size_t elts_per_thread = 16 / sizeof(T);
+  auto [blocks_per_grid, threads_per_block] = kernelLaunchConfig(strat, param, elts_per_thread);
+  switch (param.ranks_per_node) {
+    case 2:
+      dispatchARKernels<T, 2>(strat, param, blocks_per_grid, threads_per_block, stream);
+      break;
+    case 4:
+      dispatchARKernels<T, 4>(strat, param, blocks_per_grid, threads_per_block, stream);
+      break;
+    case 6:
+      dispatchARKernels<T, 6>(strat, param, blocks_per_grid, threads_per_block, stream);
+      break;
+    case 8:
+      dispatchARKernels<T, 8>(strat, param, blocks_per_grid, threads_per_block, stream);
+      break;
+    default:
+      LOG(FATAL) << "Custom all-reduce only supports 2, 4, 6 or 8 ranks per node, but got "
+                 << param.ranks_per_node;
+  }
+  last_error = cudaGetLastError();
+  if (last_error != cudaSuccess) {
+    LOG(FATAL) << "Failed to launch the custom all-reduce kernel: "
+               << cudaGetErrorString(last_error);
+  }
+}
+
+// Launch the standalone cross-rank barrier on `stream`: a single block with one thread per rank.
+void invokeMultiGpuBarrier(AllReduceParams& param, cudaStream_t stream) {
+  const dim3 grid_dim(1);
+  const dim3 block_dim(static_cast<unsigned int>(param.ranks_per_node));
+  multiGpuBarrierKernel<<<grid_dim, block_dim, 0, stream>>>(param);
+}
+
+// Entry point: record the output buffer and element count in `params`, then dispatch on the
+// element type. float32 and float16 are supported, plus bfloat16 when built with ENABLE_BF16;
+// any other type is a fatal error.
+void customAllReduce(AllReduceParams& params, void* data, size_t elts, DLDataType dataType,
+                     AllReduceStrategyType strat, cudaStream_t stream) {
+  params.local_output_buffer_ptr = data;
+  params.elts_total = elts;
+
+  const bool is_float = dataType.code == kDLFloat;
+  if (is_float && dataType.bits == 32) {
+    invokeOneOrTwoShotAllReduceKernel<float>(params, strat, stream);
+    return;
+  }
+  if (is_float && dataType.bits == 16) {
+    invokeOneOrTwoShotAllReduceKernel<half>(params, strat, stream);
+    return;
+  }
+#ifdef ENABLE_BF16
+  if (dataType.code == kDLBfloat && dataType.bits == 16) {
+    invokeOneOrTwoShotAllReduceKernel<__nv_bfloat16>(params, strat, stream);
+    return;
+  }
+#endif
+  LOG(FATAL) << ("Unsupported dataType for customAllReduce");
+}
+
+} // namespace tensorrt_llm
diff --git a/3rdparty/tensorrt_llm/custom_allreduce_kernels.h
b/3rdparty/tensorrt_llm/custom_allreduce_kernels.h
new file mode 100644
index 0000000000..7fd66e5d10
--- /dev/null
+++ b/3rdparty/tensorrt_llm/custom_allreduce_kernels.h
@@ -0,0 +1,48 @@
+/*
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef TENSORRT_LLM_CUSTOM_ALLREDUCE_KERNELS_H_
+#define TENSORRT_LLM_CUSTOM_ALLREDUCE_KERNELS_H_
+
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+#include <dlpack/dlpack.h>
+#include <stddef.h>
+#include <stdint.h>
+
+namespace tensorrt_llm {
+
+/*! \brief Number of threads in a warp; block sizes are rounded to multiples of it. */
+constexpr size_t WARP_SIZE = 32;
+/*! \brief Upper bound on the grid size of the all-reduce kernels; the per-block barrier flag
+ *  buffers are sized for this many blocks. */
+constexpr size_t MAX_ALL_REDUCE_BLOCKS = 24;
+/*! \brief Maximum number of ranks on one node (length of the per-peer pointer arrays). */
+constexpr size_t MAX_RANKS_PER_NODE = 8;
+/*! \brief Threads per block used by the launch configuration (never exceeded). */
+constexpr size_t DEFAULT_BLOCK_SIZE = 1024;
+
+/*! \brief All-reduce algorithm. */
+enum class AllReduceStrategyType : int8_t {
+  /*! \brief Every rank reads all peers' full buffers and reduces everything itself. */
+  ONESHOT = 1,
+  /*! \brief Reduce-scatter (each rank reduces one slice) followed by an all-gather. */
+  TWOSHOT = 2,
+};
+
+/*! \brief Parameters passed by value to the all-reduce kernels. */
+struct AllReduceParams {
+  /*! \brief Total number of elements to reduce (set by customAllReduce). */
+  size_t elts_total;
+  /*! \brief Launch-configuration fields, filled in before each launch. */
+  size_t elts_per_rank;
+  size_t elts_per_block;
+  size_t rank_offset;
+  /*! \brief Number of participating ranks; `rank` locates this rank's slice (rank_offset) and
+   *  `local_rank` indexes the per-peer arrays below. */
+  size_t ranks_per_node, rank, local_rank;
+  /*! \brief Value written to the barrier slots to signal arrival for the current call. */
+  uint32_t barrier_flag;
+  /*! \brief Per-rank barrier flag buffers used by the all-reduce kernels. */
+  uint32_t* peer_barrier_ptrs_in[MAX_RANKS_PER_NODE];
+  /*! \brief Per-rank barrier flag buffers used by the standalone barrier kernel. */
+  uint32_t* peer_barrier_ptrs_out[MAX_RANKS_PER_NODE];
+  /*! \brief Entry i is rank i's communication buffer, holding that rank's input. */
+  void* peer_comm_buffer_ptrs[MAX_RANKS_PER_NODE];
+  /*! \brief Where this rank writes the reduced result. */
+  void* local_output_buffer_ptr;
+};
+
+/*!
+ * \brief Sum-all-reduce `elts` elements across `params.ranks_per_node` ranks.
+ * \param params Peer pointers, rank info and barrier flag; launch fields are updated in place.
+ * \param data Output buffer for the reduced result.
+ * \param elts Number of elements.
+ * \param dataType Element type: float32 and float16 (bfloat16 only when built with
+ *        ENABLE_BF16). Other types are a fatal error.
+ * \param strat ONESHOT or TWOSHOT.
+ * \param stream CUDA stream to launch on.
+ */
+void customAllReduce(AllReduceParams& params, void* data, size_t elts, DLDataType dataType,
+                     AllReduceStrategyType strat, cudaStream_t stream);
+
+}  // namespace tensorrt_llm
+
+#endif  // TENSORRT_LLM_CUSTOM_ALLREDUCE_KERNELS_H_
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 906509004a..a7db4b7b6e 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -454,7 +454,7 @@ endif(USE_PROFILER)
if(USE_CUDA AND USE_NCCL)
message(STATUS "Build with NCCL...")
find_nccl(${USE_NCCL})
- tvm_file_glob(GLOB RUNTIME_NCCL_SRC src/runtime/disco/nccl/*.cc)
+ tvm_file_glob(GLOB RUNTIME_NCCL_SRC src/runtime/disco/nccl/*.cc
src/runtime/disco/cuda_ipc/*.cc 3rdparty/tensorrt_llm/*.cu)
set_source_files_properties(src/runtime/disco/nccl/nccl.cc PROPERTIES
COMPILE_DEFINITIONS "TVM_NCCL_RCCL_SWITCH=0")
list(APPEND RUNTIME_SRCS ${RUNTIME_NCCL_SRC})
endif()
diff --git a/LICENSE b/LICENSE
index 1d26fab957..82c7871cc6 100644
--- a/LICENSE
+++ b/LICENSE
@@ -215,6 +215,7 @@ Apache Software Foundation License 2.0
3rdparty/mlperftiny
3rdparty/nvbench (with LLVM exception)
3rdparty/cutlass_fpA_intB_gemm
+3rdparty/tensorrt_llm
BSD 2-clause License
--------------------
diff --git a/include/tvm/runtime/disco/cuda_ipc_memory.h
b/include/tvm/runtime/disco/cuda_ipc_memory.h
new file mode 100644
index 0000000000..120e6a5431
--- /dev/null
+++ b/include/tvm/runtime/disco/cuda_ipc_memory.h
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef TVM_RUNTIME_DISCO_CUDA_IPC_MEMORY_H_
+#define TVM_RUNTIME_DISCO_CUDA_IPC_MEMORY_H_
+
+#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/memory/memory_manager.h>
+#include <tvm/runtime/object.h>
+
+#include <vector>
+
+namespace tvm {
+namespace runtime {
+namespace cuda_ipc {
+
+/*!
+ * \brief The CUDA IPC (interprocess communication) memory object,
+ * which internally contains data pointers to CUDA IPC memory.
+ * It is useful for an efficient all-reduce implementation.
+ * \note Right now the class members are closely tied to the customized
+ * all-reduce kernel. They may also be extended for other uses in
+ * the future.
+ */
+class CUDAIPCMemoryObj : public Object {
+ public:
+  /*! \brief The number of GPU workers. */
+  int num_workers;
+  /*! \brief The worker id corresponding to this IPC memory object. */
+  int worker_id;
+  /*!
+   * \brief The data pointers of all all-reduce inputs.
+   * It has "num_workers" pointers. The i-th pointer is the data pointer on worker i.
+   * If "i != worker_id", the pointer is an IPC data pointer.
+   * Otherwise, the pointer is a local CUDA data pointer.
+   */
+  std::vector<void*> remote_data;
+
+  // We introduce the barrier helper data below per CUDAIPCMemory object
+  // so that they can be used by custom collective operations and allow
+  // fine-grained synchronization on each buffer. These barriers have
+  // low overhead, and can potentially enable concurrent execution of
+  // kernels in future.
+  /*!
+   * \brief The pointers to input barrier signals of all workers for all-reduce.
+   * It has "num_workers" pointers, and the pointer arrangement is the same as "remote_data".
+   * The allocator zero-initializes each worker's barrier buffer.
+   */
+  std::vector<void*> barrier_in;
+  /*!
+   * \brief The pointers to output barrier signals of all workers for all-reduce.
+   * It has "num_workers" pointers, and the pointer arrangement is the same as "remote_data".
+   * The allocator zero-initializes each worker's barrier buffer.
+   */
+  std::vector<void*> barrier_out;
+  /*!
+   * \brief The flag value written to the barrier slots to signal arrival.
+   * The allocator starts it at 1 while zeroing the barrier buffers, so a fresh buffer never
+   * looks signalled. Presumably the caller advances it after each collective (TODO confirm).
+   */
+  int barrier_flag;
+
+  static constexpr const char* _type_key = "tvm.runtime.disco.cuda_ipc_memory";
+  static constexpr const bool _type_has_method_sequal_reduce = false;
+  static constexpr const bool _type_has_method_shash_reduce = false;
+  TVM_DECLARE_BASE_OBJECT_INFO(CUDAIPCMemoryObj, Object);
+};
+
+/*!
+ * \brief Managed reference to CUDAIPCMemoryObj.
+ * \sa CUDAIPCMemoryObj
+ */
+class CUDAIPCMemory : public ObjectRef {
+ public:
+  /*! \brief Get the global singleton CUDAIPCMemory allocator. */
+  TVM_DLL static memory::Allocator* GlobalAllocator();
+  /*!
+   * \brief Given a local CUDA data pointer, return the CUDAIPCMemory object of the pointer.
+   * \note The pointer's CUDAIPCMemory is expected to have been allocated
+   * through the global function "cuda_ipc.alloc_storage". Otherwise this
+   * function raises an exception.
+   */
+  TVM_DLL static CUDAIPCMemory GetIPCMemoryFromDevicePtr(void* ptr);
+
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(CUDAIPCMemory, ObjectRef, CUDAIPCMemoryObj);
+};
+
+} // namespace cuda_ipc
+} // namespace runtime
+} // namespace tvm
+
+#endif // TVM_RUNTIME_DISCO_CUDA_IPC_MEMORY_H_
diff --git a/include/tvm/runtime/memory/memory_manager.h
b/include/tvm/runtime/memory/memory_manager.h
index 6b8aa9e666..7ae7058896 100644
--- a/include/tvm/runtime/memory/memory_manager.h
+++ b/include/tvm/runtime/memory/memory_manager.h
@@ -99,6 +99,10 @@ class Allocator {
*/
TVM_DLL virtual size_t UsedMemory() const = 0;
+ protected:
+ /*! \brief Check if the given memory scope is allowed to allocate by the
allocator. */
+ TVM_DLL virtual bool AllowMemoryScope(const std::string& mem_scope) const;
+
private:
AllocatorType type_;
};
@@ -137,6 +141,8 @@ class StorageObj : public Object {
public:
/*! \brief The index into the VM function table. */
Buffer buffer;
+ /*! \brief The allocator where the storage buffer is allocated from. */
+ Allocator* allocator;
/*! \brief Allocate an NDArray from a given piece of storage. */
TVM_DLL NDArray AllocNDArray(int64_t offset, ShapeTuple shape, DLDataType
dtype);
@@ -144,10 +150,7 @@ class StorageObj : public Object {
/*! \brief The deleter for an NDArray when allocated from underlying
storage. */
static void Deleter(Object* ptr);
- ~StorageObj() {
- auto alloc = MemoryManager::Global()->GetAllocator(buffer.device,
buffer.alloc_type);
- alloc->Free(buffer);
- }
+ ~StorageObj() { allocator->Free(buffer); }
static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const char* _type_key = "vm.Storage";
@@ -157,7 +160,7 @@ class StorageObj : public Object {
/*! \brief reference to storage. */
class Storage : public ObjectRef {
public:
- TVM_DLL explicit Storage(Buffer buffer);
+ TVM_DLL explicit Storage(Buffer buffer, Allocator* allocator);
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Storage, ObjectRef, StorageObj);
};
diff --git a/LICENSE b/licenses/LICENSE.tensorrt_llm.txt
similarity index 91%
copy from LICENSE
copy to licenses/LICENSE.tensorrt_llm.txt
index 1d26fab957..d645695673 100644
--- a/LICENSE
+++ b/licenses/LICENSE.tensorrt_llm.txt
@@ -1,3 +1,4 @@
+
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
@@ -178,7 +179,7 @@
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
- boilerplate notice, with the fields enclosed by brackets "{}"
+ boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
@@ -186,7 +187,7 @@
same "printed page" as the copyright notice for easier
identification within third-party archives.
- Copyright {yyyy} {name of copyright owner}
+ Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -199,52 +200,3 @@
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-
-------------------------------------------------------------------------------------
-This product bundles various third-party components under other open source
licenses.
-This section summarizes those components and their licenses. See licenses/
-for text of these licenses.
-
-
-Apache Software Foundation License 2.0
---------------------------------------
-
-3rdparty/dlpack
-3rdparty/dmlc-core
-3rdparty/OpenCL-Headers
-3rdparty/mlperftiny
-3rdparty/nvbench (with LLVM exception)
-3rdparty/cutlass_fpA_intB_gemm
-
-BSD 2-clause License
---------------------
-
-3rdparty/picojson
-3rdparty/dmlc-core/include/dmlc/concurrentqueue.h
-
-
-BSD 2-clause License + zlib License
------------------------------------
-
-3rdparty/dmlc-core/include/dmlc/blockingconcurrentqueue.h
-
-
-MIT License
------------
-
-3rdparty/libcrc
-3rdparty/cma
-3rdparty/compiler-rt/builtin_fp16.h
-3rdparty/cnpy
-
-The Unlicense
--------------
-
-3rdparty/rang
-
-BSD 3-Clause "New" or "Revised" License
----------------------------------------
-
-3rdparty/cutlass
-3rdparty/libbacktrace
-3rdparty/libflash_attn
diff --git a/python/tvm/runtime/disco/session.py
b/python/tvm/runtime/disco/session.py
index 53b362f579..344212a2f6 100644
--- a/python/tvm/runtime/disco/session.py
+++ b/python/tvm/runtime/disco/session.py
@@ -20,13 +20,11 @@ with the distributed runtime.
import os
import pickle
-
-
from typing import Any, Callable, Optional, Sequence, Union
import numpy as np
-from ..._ffi import register_object, register_func
+from ..._ffi import get_global_func, register_func, register_object
from ..._ffi.runtime_ctypes import Device
from ..container import ShapeTuple
from ..ndarray import NDArray
@@ -283,7 +281,8 @@ class Session(Object):
The device IDs to be used by the underlying communication library.
"""
assert ccl in ("nccl", "rccl"), f"Unsupported CCL backend: {ccl}"
- return _ffi_api.SessionInitCCL(self, ccl, ShapeTuple(device_ids)) #
type: ignore # pylint: disable=no-member
+ _ffi_api.SessionInitCCL(self, ccl, ShapeTuple(device_ids)) # type:
ignore # pylint: disable=no-member
+ self._clear_ipc_memory_pool()
def broadcast_from_worker0(self, src: DRef, dst: DRef) -> DRef:
"""Broadcast an array from worker-0 to all other workers.
@@ -365,6 +364,12 @@ class Session(Object):
func = self._get_cached_method("runtime.disco.allgather")
func(src, dst)
+ def _clear_ipc_memory_pool(self):
+ # Clear the IPC memory allocator when the allocator exists.
+ name = "runtime.disco.cuda_ipc.cuda_ipc_memory_allocator_clear"
+ if get_global_func(name, allow_missing=True) is not None:
+ self.call_packed(self.get_global_func(name))
+
@register_object("runtime.disco.ThreadedSession")
class ThreadedSession(Session):
diff --git a/src/runtime/disco/cuda_ipc/cuda_ipc_memory.cc
b/src/runtime/disco/cuda_ipc/cuda_ipc_memory.cc
new file mode 100644
index 0000000000..451c3df0cb
--- /dev/null
+++ b/src/runtime/disco/cuda_ipc/cuda_ipc_memory.cc
@@ -0,0 +1,227 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <cuda_runtime.h>
+#include <tvm/runtime/disco/cuda_ipc_memory.h>
+#include <tvm/runtime/memory/memory_manager.h>
+#include <tvm/runtime/registry.h>
+
+#include "../../../../3rdparty/tensorrt_llm/custom_allreduce_kernels.h"
+#include "../../cuda/cuda_common.h"
+#include "../../memory/pooled_allocator.h"
+#include "../nccl/nccl_context.h"
+
+namespace tvm {
+namespace runtime {
+namespace cuda_ipc {
+
+using tensorrt_llm::MAX_ALL_REDUCE_BLOCKS;
+using tensorrt_llm::MAX_RANKS_PER_NODE;
+using tvm::runtime::memory::Buffer;
+
+/*!
+ * \brief All-gather the CUDA IPC memory handles of all distributed workers.
+ * The local handle is staged in GPU memory so that NCCL AllGather can exchange the raw
+ * handle bytes; the gathered bytes are then copied back to the host and split into one
+ * cudaIpcMemHandle_t per worker.
+ * \param ctx The thread-local CCL context (communicator and worker info).
+ * \param local_handle The IPC handle of this worker.
+ * \return The handles of all workers; slot i holds the handle contributed by NCCL rank i.
+ */
+std::vector<cudaIpcMemHandle_t> AllGatherIPCHandles(nccl::CCLThreadLocalContext* ctx,
+                                                    cudaIpcMemHandle_t local_handle) {
+  const auto num_workers = ctx->worker->num_workers;
+  // Device staging buffers: one handle in, `num_workers` handles out.
+  void* d_local = nullptr;
+  void* d_gathered = nullptr;
+  CUDA_CALL(cudaMalloc(&d_local, CUDA_IPC_HANDLE_SIZE));
+  CUDA_CALL(cudaMalloc(&d_gathered, CUDA_IPC_HANDLE_SIZE * num_workers));
+  CUDA_CALL(cudaMemcpy(d_local, &local_handle, CUDA_IPC_HANDLE_SIZE, cudaMemcpyHostToDevice));
+  // Exchange the raw handle bytes on the default stream; the blocking cudaMemcpy below is
+  // presumably ordered after it through that same stream.
+  NCCL_CALL(ncclAllGather(d_local, d_gathered, CUDA_IPC_HANDLE_SIZE, ncclChar, ctx->comm,
+                          /*stream=*/nullptr));
+  std::vector<char> host_bytes(CUDA_IPC_HANDLE_SIZE * num_workers, 0);
+  CUDA_CALL(cudaMemcpy(host_bytes.data(), d_gathered, CUDA_IPC_HANDLE_SIZE * num_workers,
+                       cudaMemcpyDefault));
+  // Split the gathered bytes back into one handle per worker.
+  std::vector<cudaIpcMemHandle_t> handles(num_workers);
+  for (int worker = 0; worker < num_workers; ++worker) {
+    memcpy(handles[worker].reserved, host_bytes.data() + worker * CUDA_IPC_HANDLE_SIZE,
+           CUDA_IPC_HANDLE_SIZE);
+  }
+  CUDA_CALL(cudaFree(d_local));
+  CUDA_CALL(cudaFree(d_gathered));
+  return handles;
+}
+
+/*!
+ * \brief The memory allocator of CUDAIPCMemory.
+ * Overriding PooledAllocator for efficient memory management.
+ * \note Each allocation performs NCCL all-gathers (collective calls), so all workers
+ *       must perform their allocations in lockstep.
+ */
+class CUDAIPCMemoryAllocator final : public memory::PooledAllocator {
+ public:
+  explicit CUDAIPCMemoryAllocator() : PooledAllocator() {}
+
+  /*! \brief Only the "ipc_memory" scope is served by this allocator. */
+  bool AllowMemoryScope(const std::string& mem_scope) const final {
+    return mem_scope == "ipc_memory";
+  }
+
+  /*!
+   * \brief Look up the CUDAIPCMemory object registered for a local data pointer.
+   * \param ptr The local device pointer returned by this allocator (the start of the buffer).
+   */
+  CUDAIPCMemory GetIPCMemoryFromDevicePtr(void* ptr) const {
+    auto it = ipc_memory_map_.find(ptr);
+    CHECK(it != ipc_memory_map_.end())
+        << "The given pointer's CUDAIPCMemory object does not exist. Please use global function "
+           "\"cuda_ipc.alloc_storage\" to allocate the CUDAIPCMemory object first.";
+    return it->second;
+  }
+
+  /*! \brief Return the global CUDAIPCMemory singleton allocator. */
+  static CUDAIPCMemoryAllocator* Global() {
+    static CUDAIPCMemoryAllocator* allocator = new CUDAIPCMemoryAllocator();
+    return allocator;
+  }
+
+ private:
+  /*!
+   * \brief Allocate the data buffer together with its two barrier buffers, and register
+   * the resulting CUDAIPCMemory object under the local data pointer.
+   */
+  void* DeviceAllocDataSpace(Device dev, size_t size, size_t alignment,
+                             DLDataType type_hint) final {
+    auto [data_ptr, data_comm_ptrs] = AllocIPCMemory(dev, size, alignment, type_hint);
+    const size_t barrier_ptr_size =
+        sizeof(uint32_t) * (MAX_ALL_REDUCE_BLOCKS + 2) * MAX_RANKS_PER_NODE;
+    // The barrier buffers must start zeroed. They are zeroed inside AllocIPCMemory *before*
+    // their IPC handles are exchanged. Zeroing them only after all exchanges finished would
+    // race with faster peers, which are handed these buffers through IPC (presumably the
+    // kernels signal peers by writing into them) and could have written flags already.
+    std::vector<void*> barrier_in_comm_ptrs =
+        AllocIPCMemory(dev, barrier_ptr_size, alignment, DataType::UInt(32), /*zero_init=*/true)
+            .second;
+    std::vector<void*> barrier_out_comm_ptrs =
+        AllocIPCMemory(dev, barrier_ptr_size, alignment, DataType::UInt(32), /*zero_init=*/true)
+            .second;
+
+    // Create the CUDAIPCMemory object.
+    ObjectPtr<CUDAIPCMemoryObj> ipc_memory = make_object<CUDAIPCMemoryObj>();
+    nccl::CCLThreadLocalContext* nccl_ctx = nccl::CCLThreadLocalContext::Get();
+    ipc_memory->remote_data = data_comm_ptrs;
+    ipc_memory->barrier_in = barrier_in_comm_ptrs;
+    ipc_memory->barrier_out = barrier_out_comm_ptrs;
+    ipc_memory->barrier_flag = 1;
+    ipc_memory->num_workers = nccl_ctx->worker->num_workers;
+    ipc_memory->worker_id = nccl_ctx->worker->worker_id;
+    ipc_memory_map_[data_ptr] = CUDAIPCMemory(std::move(ipc_memory));
+    return data_ptr;
+  }
+
+  /*! \brief Release the data and barrier buffers registered under `ptr`. */
+  void DeviceFreeDataSpace(Device dev, void* ptr) final {
+    ICHECK(dev.device_type == kDLCUDA);
+    CUDA_CALL(cudaSetDevice(dev.device_id));
+    nccl::CCLThreadLocalContext* ctx = nccl::CCLThreadLocalContext::Get();
+    auto it = ipc_memory_map_.find(ptr);
+    ICHECK(it != ipc_memory_map_.end());
+    FreeIPCMemory(it->second->remote_data, ctx->worker->worker_id);
+    FreeIPCMemory(it->second->barrier_in, ctx->worker->worker_id);
+    FreeIPCMemory(it->second->barrier_out, ctx->worker->worker_id);
+    ipc_memory_map_.erase(it);
+  }
+
+  /*!
+   * \brief Allocate CUDA memory with the required size, alignment and dtype,
+   * and return the IPC memory data pointers.
+   * \param zero_init When true, the local buffer is zero-filled before its IPC handle
+   *        is shared with the other workers.
+   * \returns The local data pointer of the allocated CUDA memory,
+   * and a list of pointers that contains the CUDA IPC memory pointer
+   * of the allocated memory on each worker.
+   * For the i-th pointer, if i is the worker id of the given device,
+   * then the returned i-th pointer points to the local CUDA memory,
+   * or otherwise it is an IPC memory pointer.
+   * \details This function first allocates local memory on every worker,
+   * and creates an IPC memory pointer for the local memory.
+   * Then it uses nccl all-gather to synchronize the IPC memory pointers
+   * across all workers, so that every worker know each other's IPC memory
+   * pointer. `alignment` and `type_hint` are not used; cudaMalloc's own
+   * alignment is relied upon.
+   */
+  std::pair<void*, std::vector<void*>> AllocIPCMemory(Device dev, size_t size, size_t alignment,
+                                                      DLDataType type_hint,
+                                                      bool zero_init = false) {
+    // Alloc local buffer
+    ICHECK(dev.device_type == kDLCUDA);
+    void* ptr;
+    CUDA_CALL(cudaSetDevice(dev.device_id));
+    CUDA_CALL(cudaMalloc(&ptr, size));
+    if (zero_init) {
+      // Enqueued on the default stream ahead of the NCCL all-gather below (which is also
+      // issued on the default stream), so the zeroing is ordered before peers receive
+      // this buffer's handle.
+      CUDA_CALL(cudaMemset(ptr, 0, size));
+    }
+    // Create ipc handle
+    cudaIpcMemHandle_t local_handle;
+    CUDA_CALL(cudaIpcGetMemHandle(&local_handle, ptr));
+    // All-gather IPC handles.
+    nccl::CCLThreadLocalContext* ctx = nccl::CCLThreadLocalContext::Get();
+    std::vector<cudaIpcMemHandle_t> handles = AllGatherIPCHandles(ctx, local_handle);
+    // Collect the all-gather results.
+    std::vector<void*> comm_ptrs(ctx->worker->num_workers);
+    for (size_t node_id = 0; node_id < handles.size(); ++node_id) {
+      if (static_cast<int>(node_id) == ctx->worker->worker_id) {
+        comm_ptrs[node_id] = ptr;
+      } else {
+        uint8_t* foreign_buffer;
+        CUDA_CALL(cudaIpcOpenMemHandle(reinterpret_cast<void**>(&foreign_buffer),
+                                       handles[node_id], cudaIpcMemLazyEnablePeerAccess));
+        comm_ptrs[node_id] = foreign_buffer;
+      }
+    }
+    return std::make_pair(ptr, comm_ptrs);
+  }
+
+  /*!
+   * \brief Free the IPC memory pointers: close every peer mapping and free the local buffer.
+   * \param comm_ptrs The per-worker pointers returned by AllocIPCMemory.
+   * \param worker_id The id of the calling worker (its entry is the local buffer).
+   */
+  void FreeIPCMemory(const std::vector<void*>& comm_ptrs, int worker_id) {
+    for (int i = 0; i < static_cast<int>(comm_ptrs.size()); ++i) {
+      if (i != worker_id) {
+        // Free ipc handle.
+        CUDA_CALL(cudaIpcCloseMemHandle(comm_ptrs[i]));
+      } else {
+        // Free local buffer.
+        CUDA_CALL(cudaFree(comm_ptrs[i]));
+      }
+    }
+  }
+
+  /*! \brief The mapping from local CUDA memory pointer to its allocated CUDAIPCMemory object. */
+  std::unordered_map<void*, CUDAIPCMemory> ipc_memory_map_;
+};
+
+/*!
+ * \brief Allocate a storage object with CUDA IPC memory.
+ * \param buffer_shape The shape of the storage to allocate.
+ * \param dtype_hint The dtype of the storage to allocate.
+ * \return The allocated storage object with internal CUDA IPC memory buffer.
+ * \note The buffer is placed on the CUDA device recorded in this thread's CCL context and
+ *       is served by the global CUDAIPCMemoryAllocator under the "ipc_memory" scope.
+ */
+memory::Storage IPCAllocStorage(ShapeTuple buffer_shape, DLDataType dtype_hint) {
+  auto storage_obj = runtime::SimpleObjAllocator().make_object<memory::StorageObj>();
+  const int device_id = nccl::CCLThreadLocalContext::Get()->device_id;
+  const Device device{DLDeviceType::kDLCUDA, device_id};
+  CUDAIPCMemoryAllocator* const allocator = CUDAIPCMemoryAllocator::Global();
+  storage_obj->buffer =
+      allocator->Alloc(device, std::move(buffer_shape), dtype_hint, /*mem_scope=*/"ipc_memory");
+  // Remember which allocator owns the buffer.
+  storage_obj->allocator = allocator;
+  return memory::Storage(storage_obj);
+}
+
+// Exposes IPCAllocStorage(buffer_shape, dtype_hint) to the disco runtime.
+TVM_REGISTER_GLOBAL("runtime.disco.cuda_ipc.alloc_storage").set_body_typed(IPCAllocStorage);
+
+// Invokes Clear() on the global CUDAIPCMemoryAllocator singleton.
+TVM_REGISTER_GLOBAL("runtime.disco.cuda_ipc.cuda_ipc_memory_allocator_clear")
+    .set_body_typed([]() -> void {
+      CUDAIPCMemoryAllocator* allocator = CUDAIPCMemoryAllocator::Global();
+      allocator->Clear();
+    });
+
+/******************** CUDAIPCMemoryObj ********************/
+
+TVM_REGISTER_OBJECT_TYPE(CUDAIPCMemoryObj);
+
+// Forwards to the process-wide CUDAIPCMemoryAllocator singleton.
+memory::Allocator* CUDAIPCMemory::GlobalAllocator() {
+  CUDAIPCMemoryAllocator* singleton = CUDAIPCMemoryAllocator::Global();
+  return singleton;
+}
+
+// Forwards to CUDAIPCMemoryAllocator::GetIPCMemoryFromDevicePtr on the global singleton.
+CUDAIPCMemory CUDAIPCMemory::GetIPCMemoryFromDevicePtr(void* ptr) {
+  const CUDAIPCMemoryAllocator* singleton = CUDAIPCMemoryAllocator::Global();
+  return singleton->GetIPCMemoryFromDevicePtr(ptr);
+}
+
+} // namespace cuda_ipc
+} // namespace runtime
+} // namespace tvm
diff --git a/src/runtime/disco/cuda_ipc/custom_allreduce.cc
b/src/runtime/disco/cuda_ipc/custom_allreduce.cc
new file mode 100644
index 0000000000..e9be5973e1
--- /dev/null
+++ b/src/runtime/disco/cuda_ipc/custom_allreduce.cc
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <cuda_runtime.h>
+#include <tvm/runtime/disco/cuda_ipc_memory.h>
+#include <tvm/runtime/memory/memory_manager.h>
+#include <tvm/runtime/registry.h>
+
+#include "../../../../3rdparty/tensorrt_llm/custom_allreduce_kernels.h"
+#include "../nccl/nccl_context.h"
+
+namespace tvm {
+namespace runtime {
+namespace nccl {
+namespace cuda_ipc {
+
+using tvm::runtime::cuda_ipc::CUDAIPCMemory;
+
+/*!
+ * \brief Compute the size (i.e., number of elements) of the input tensor.
+ * \note If the tensor carries explicit strides, they must describe a compact row-major
+ *       layout; otherwise the ICHECK below fails.
+ */
+inline int64_t TensorSize(const DLTensor* tensor) {
+  const bool has_strides = tensor->strides != nullptr;
+  // Walk from the innermost dimension outwards; `size` is the compact stride of dim i.
+  int64_t size = 1;
+  for (int i = tensor->ndim; i-- > 0;) {
+    if (has_strides) {
+      ICHECK_EQ(tensor->strides[i], size);
+    }
+    size *= tensor->shape[i];
+  }
+  return size;
+}
+
+/*!
+ * \brief Number of `dtype` elements that fit into one 16-byte load, or 0 when a single
+ *        element is wider than 16 bytes (or has zero width).
+ * \note The 16-byte width comes from the kernel requirement checks below; presumably it is
+ *       the vectorized load width of the customized kernels (TODO confirm in the kernel).
+ */
+inline int64_t CustomAllReduceEltsPerLoad(DLDataType dtype) {
+  constexpr int64_t kBytesPerLoad = 16;
+  const int64_t bytes_per_element = (static_cast<int64_t>(dtype.bits) * dtype.lanes + 7) / 8;
+  if (bytes_per_element <= 0 || bytes_per_element > kBytesPerLoad) {
+    return 0;
+  }
+  return kBytesPerLoad / bytes_per_element;
+}
+
+/*! \brief Check if customized all-reduce kernels can be applied. */
+inline bool CanApplyCustomAllReduce(int64_t num_elements, DLDataType dtype) {
+  // The customized all-reduce kernel requires the element count to be a multiple of the
+  // per-load element count. Elements wider than one load are unsupported (this also avoids
+  // the modulo-by-zero the previous formula hit for them).
+  const int64_t elts_per_load = CustomAllReduceEltsPerLoad(dtype);
+  return elts_per_load > 0 && num_elements % elts_per_load == 0;
+}
+
+/*! \brief Check if the two-shot customized all-reduce kernel can be applied. */
+inline bool CanApplyTwoShotAllReduce(int64_t num_elements, DLDataType dtype, int num_workers) {
+  // The two-shot kernel requires each worker's share to be a multiple of the per-load count.
+  const int64_t elts_per_load = CustomAllReduceEltsPerLoad(dtype);
+  if (elts_per_load == 0 || num_workers <= 0) {
+    return false;
+  }
+  return (num_elements / num_workers) % elts_per_load == 0;
+}
+
+/*!
+ * \brief Customized all-reduce kernel backed by CUDA IPC memory.
+ * \param send The input tensor of all-reduce. When the customized kernels are used, it must
+ *        start at the beginning of a buffer obtained from "runtime.disco.cuda_ipc.alloc_storage".
+ * \param strategy The all-reduce strategy. See AllReduceStrategyType for detail.
+ * \param recv The output tensor of all-reduce.
+ * \note Falls back to ncclAllReduce when the element count does not satisfy the customized
+ *       kernel's requirement, and to the one-shot strategy when two-shot is not applicable.
+ */
+void CustomAllReduce(DLTensor* send, int strategy, DLTensor* recv) {
+  int64_t num_elements = TensorSize(send);
+  nccl::CCLThreadLocalContext* ctx = nccl::CCLThreadLocalContext::Get();
+  // DLPack tensors address their first element at `data + byte_offset`.
+  void* recv_ptr = static_cast<char*>(recv->data) + recv->byte_offset;
+
+  if (!CanApplyCustomAllReduce(num_elements, send->dtype)) {
+    // Dispatch to nccl AllReduce if the customized all-reduce cannot apply.
+    void* send_ptr = static_cast<char*>(send->data) + send->byte_offset;
+    deviceStream_t stream = ctx->GetDefaultStream();
+    NCCL_CALL(ncclAllReduce(send_ptr, recv_ptr, num_elements,
+                            /*datatype=*/nccl::AsNCCLDataType(DataType(send->dtype)),
+                            /*op=*/ncclSum, ctx->comm, stream));
+    return;
+  }
+
+  // The customized kernel is not given `send` at all: it reads the input through
+  // params.peer_comm_buffer_ptrs, i.e. from the *start* of the IPC buffer registered for
+  // send->data. An offset view would silently reduce the wrong bytes, so reject it.
+  ICHECK_EQ(send->byte_offset, 0U)
+      << "The customized all-reduce requires the input tensor to start at the beginning of "
+         "its CUDA IPC memory buffer.";
+
+  // Initialize the all-reduce kernel arguments.
+  tensorrt_llm::AllReduceParams params;
+  params.ranks_per_node = ctx->worker->num_workers;
+  params.rank = ctx->worker->worker_id;
+  params.local_rank = ctx->worker->worker_id;
+  CUDAIPCMemory ipc_memory = CUDAIPCMemory::GetIPCMemoryFromDevicePtr(send->data);
+  // Post-increment: every call uses a fresh flag value for the barrier handshake.
+  params.barrier_flag = ipc_memory->barrier_flag++;
+  for (int i = 0; i < ctx->worker->num_workers; ++i) {
+    params.peer_comm_buffer_ptrs[i] = ipc_memory->remote_data[i];
+    params.peer_barrier_ptrs_in[i] = reinterpret_cast<uint32_t*>(ipc_memory->barrier_in[i]);
+    params.peer_barrier_ptrs_out[i] = reinterpret_cast<uint32_t*>(ipc_memory->barrier_out[i]);
+  }
+
+  tensorrt_llm::AllReduceStrategyType strategy_ =
+      static_cast<tensorrt_llm::AllReduceStrategyType>(strategy);
+  if (!CanApplyTwoShotAllReduce(num_elements, send->dtype, ctx->worker->num_workers)) {
+    // Two-shot all-reduce does not support this case.
+    // So we fallback to the one-shot strategy.
+    strategy_ = tensorrt_llm::AllReduceStrategyType::ONESHOT;
+  }
+
+  tensorrt_llm::customAllReduce(params, recv_ptr, num_elements, send->dtype, strategy_,
+                                ctx->GetDefaultStream());
+}
+
+// Exposes CustomAllReduce(send, strategy, recv) to the disco runtime.
+TVM_REGISTER_GLOBAL("runtime.disco.cuda_ipc.custom_allreduce").set_body_typed(CustomAllReduce);
+
+} // namespace cuda_ipc
+} // namespace nccl
+} // namespace runtime
+} // namespace tvm
diff --git a/src/runtime/disco/nccl/nccl.cc b/src/runtime/disco/nccl/nccl.cc
index 61c307c673..b5fc1053b2 100644
--- a/src/runtime/disco/nccl/nccl.cc
+++ b/src/runtime/disco/nccl/nccl.cc
@@ -17,12 +17,6 @@
* under the License.
*/
-#include <dlpack/dlpack.h>
-#include <tvm/runtime/c_runtime_api.h>
-#include <tvm/runtime/disco/builtin.h>
-#include <tvm/runtime/disco/session.h>
-#include <tvm/runtime/registry.h>
-
#include <cstring>
#include <mutex>
#include <sstream>
@@ -30,92 +24,15 @@
#include "../../../support/process_id.h"
#include "../utils.h"
-
-/* `TVM_NCCL_RCCL_SWITCH` is set to 0 for NCCL, 1 for RCCL */
-#ifndef TVM_NCCL_RCCL_SWITCH
-#define TVM_NCCL_RCCL_SWITCH 0
-#endif
-#if TVM_NCCL_RCCL_SWITCH == 0
-#include <nccl.h>
-
-#include "../../cuda/cuda_common.h"
-#else
-#include <rccl/rccl.h>
-
-#include "../../rocm/rocm_common.h"
-#endif
+#include "nccl_context.h"
namespace tvm {
namespace runtime {
namespace nccl {
-#define NCCL_CALL(cmd) \
- do { \
- auto r = (cmd); \
- if (r != ncclSuccess) { \
- LOG(FATAL) << TVM_DISCO_CCL_NAME "Errror: " << ncclGetErrorString(r); \
- } \
- } while (0)
-
-#if TVM_NCCL_RCCL_SWITCH == 0
-
-#define TVM_DISCO_DEVICE_NAME "cuda"
-#define TVM_DISCO_CCL_NAME "nccl"
-
-using deviceStream_t = cudaStream_t;
-const constexpr DLDeviceType TVM_DISCO_DEVICE_TYPE = DLDeviceType::kDLCUDA;
-inline void SetDevice(int device_id) { CUDA_CALL(cudaSetDevice(device_id)); }
-inline void StreamSynchronize(deviceStream_t stream) {
CUDA_CALL(cudaStreamSynchronize(stream)); }
-inline void StreamCreate(deviceStream_t* stream) {
CUDA_CALL(cudaStreamCreate(stream)); }
-inline void StreamDestroy(deviceStream_t stream) {
CUDA_CALL(cudaStreamDestroy(stream)); }
-
-#else
-
-#define TVM_DISCO_DEVICE_NAME "rocm"
-#define TVM_DISCO_CCL_NAME "rccl"
-
-using deviceStream_t = hipStream_t;
-const constexpr DLDeviceType TVM_DISCO_DEVICE_TYPE = DLDeviceType::kDLROCM;
-inline void SetDevice(int device_id) { ROCM_CALL(hipSetDevice(device_id)); }
-inline void StreamSynchronize(deviceStream_t stream) {
ROCM_CALL(hipStreamSynchronize(stream)); }
-inline void StreamCreate(deviceStream_t* stream) {
ROCM_CALL(hipStreamCreate(stream)); }
-inline void StreamDestroy(deviceStream_t stream) {
ROCM_CALL(hipStreamDestroy(stream)); }
-
-#endif
-
-inline ncclDataType_t AsNCCLDataType(runtime::DataType dtype) {
- if (dtype == DataType::Int(8)) {
- return ncclInt8;
- }
- if (dtype == DataType::UInt(8)) {
- return ncclUint8;
- }
- if (dtype == DataType::Int(32)) {
- return ncclInt32;
- }
- if (dtype == DataType::UInt(32)) {
- return ncclUint32;
- }
- if (dtype == DataType::Int(64)) {
- return ncclInt64;
- }
- if (dtype == DataType::UInt(64)) {
- return ncclUint64;
- }
- if (dtype == DataType::Float(16)) {
- return ncclFloat16;
- }
- if (dtype == DataType::Float(32)) {
- return ncclFloat32;
- }
- if (dtype == DataType::Float(64)) {
- return ncclFloat64;
- }
- if (dtype == DataType::BFloat(16)) {
- return ncclBfloat16;
- }
- LOG(FATAL) << "ValueError: Unsupported data type " << dtype;
- throw;
+CCLThreadLocalContext* CCLThreadLocalContext::Get() {
+ thread_local static CCLThreadLocalContext ctx;
+ return &ctx;
}
inline ncclRedOp_t AsNCCLRedOp(ReduceKind kind) {
@@ -135,32 +52,6 @@ inline ncclRedOp_t AsNCCLRedOp(ReduceKind kind) {
throw;
}
-struct CCLThreadLocalContext {
- DiscoWorker* worker;
- int device_id;
- deviceStream_t default_stream = nullptr;
- ncclComm_t comm;
-
- void Clear() {
- NCCL_CALL(ncclCommDestroy(comm));
- if (default_stream != nullptr) {
- StreamDestroy(default_stream);
- }
- }
-
- deviceStream_t GetDefaultStream() {
- const auto* func = tvm::runtime::Registry::Get("runtime.get_"
TVM_DISCO_DEVICE_NAME "_stream");
- ICHECK(func != nullptr);
- deviceStream_t stream = static_cast<deviceStream_t>((*func)().operator
void*());
- return stream == nullptr ? default_stream : stream;
- }
-
- static CCLThreadLocalContext* Get() {
- thread_local static CCLThreadLocalContext ctx;
- return &ctx;
- }
-};
-
void InitCCL(Session sess, IntTuple device_ids) {
DRef func = sess->GetGlobalFunc("runtime.disco." TVM_DISCO_CCL_NAME
".init_ccl_per_worker");
DLOG(INFO) << "Initializing " TVM_DISCO_CCL_NAME " with devices: " <<
device_ids;
diff --git a/src/runtime/disco/nccl/nccl_context.h
b/src/runtime/disco/nccl/nccl_context.h
new file mode 100644
index 0000000000..9d1b8b933a
--- /dev/null
+++ b/src/runtime/disco/nccl/nccl_context.h
@@ -0,0 +1,147 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef TVM_RUNTIME_DISCO_NCCL_NCCL_CONTEXT_H_
+#define TVM_RUNTIME_DISCO_NCCL_NCCL_CONTEXT_H_
+
+#include <dlpack/dlpack.h>
+#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/disco/builtin.h>
+#include <tvm/runtime/disco/session.h>
+#include <tvm/runtime/registry.h>
+
+#include "../../../support/process_id.h"
+#include "../utils.h"
+
+/* `TVM_NCCL_RCCL_SWITCH` is set to 0 for NCCL, 1 for RCCL */
+#ifndef TVM_NCCL_RCCL_SWITCH
+#define TVM_NCCL_RCCL_SWITCH 0
+#endif
+#if TVM_NCCL_RCCL_SWITCH == 0
+#include <nccl.h>
+
+#include "../../cuda/cuda_common.h"
+#else
+#include <rccl/rccl.h>
+
+#include "../../rocm/rocm_common.h"
+#endif
+
+namespace tvm {
+namespace runtime {
+namespace nccl {
+
+/*!
+ * \brief Evaluate an NCCL/RCCL call and raise a fatal error when it does not succeed.
+ * The status variable has an unusual name so that it cannot capture a caller variable
+ * referenced inside `cmd`.
+ */
+#define NCCL_CALL(cmd)                                                                   \
+  do {                                                                                   \
+    auto tvm_nccl_call_status_ = (cmd);                                                  \
+    if (tvm_nccl_call_status_ != ncclSuccess) {                                          \
+      LOG(FATAL) << TVM_DISCO_CCL_NAME " Error: " << ncclGetErrorString(tvm_nccl_call_status_); \
+    }                                                                                    \
+  } while (0)
+
+#if TVM_NCCL_RCCL_SWITCH == 0
+
+// ----- NCCL backend on CUDA devices -----
+#define TVM_DISCO_DEVICE_NAME "cuda"
+#define TVM_DISCO_CCL_NAME "nccl"
+
+using deviceStream_t = cudaStream_t;
+const constexpr DLDeviceType TVM_DISCO_DEVICE_TYPE = DLDeviceType::kDLCUDA;
+
+/*! \brief Make `device_id` the current device of the calling thread. */
+inline void SetDevice(int device_id) {
+  CUDA_CALL(cudaSetDevice(device_id));
+}
+/*! \brief Block the host until all work queued on `stream` has completed. */
+inline void StreamSynchronize(deviceStream_t stream) {
+  CUDA_CALL(cudaStreamSynchronize(stream));
+}
+/*! \brief Create a new stream and store it in `*stream`. */
+inline void StreamCreate(deviceStream_t* stream) {
+  CUDA_CALL(cudaStreamCreate(stream));
+}
+/*! \brief Destroy `stream`. */
+inline void StreamDestroy(deviceStream_t stream) {
+  CUDA_CALL(cudaStreamDestroy(stream));
+}
+
+#else
+
+// ----- RCCL backend on ROCm devices -----
+#define TVM_DISCO_DEVICE_NAME "rocm"
+#define TVM_DISCO_CCL_NAME "rccl"
+
+using deviceStream_t = hipStream_t;
+const constexpr DLDeviceType TVM_DISCO_DEVICE_TYPE = DLDeviceType::kDLROCM;
+
+/*! \brief Make `device_id` the current device of the calling thread. */
+inline void SetDevice(int device_id) {
+  ROCM_CALL(hipSetDevice(device_id));
+}
+/*! \brief Block the host until all work queued on `stream` has completed. */
+inline void StreamSynchronize(deviceStream_t stream) {
+  ROCM_CALL(hipStreamSynchronize(stream));
+}
+/*! \brief Create a new stream and store it in `*stream`. */
+inline void StreamCreate(deviceStream_t* stream) {
+  ROCM_CALL(hipStreamCreate(stream));
+}
+/*! \brief Destroy `stream`. */
+inline void StreamDestroy(deviceStream_t stream) {
+  ROCM_CALL(hipStreamDestroy(stream));
+}
+
+#endif
+
+/*! \brief Convert DataType to ncclDataType; unsupported types raise a fatal error. */
+inline ncclDataType_t AsNCCLDataType(runtime::DataType dtype) {
+  struct DTypePair {
+    runtime::DataType tvm_dtype;
+    ncclDataType_t nccl_dtype;
+  };
+  // Searched front to back; the first entry equal to `dtype` determines the result.
+  static const DTypePair kSupported[] = {
+      {DataType::Int(8), ncclInt8},       {DataType::UInt(8), ncclUint8},
+      {DataType::Int(32), ncclInt32},     {DataType::UInt(32), ncclUint32},
+      {DataType::Int(64), ncclInt64},     {DataType::UInt(64), ncclUint64},
+      {DataType::Float(16), ncclFloat16}, {DataType::Float(32), ncclFloat32},
+      {DataType::Float(64), ncclFloat64}, {DataType::BFloat(16), ncclBfloat16},
+  };
+  for (const DTypePair& entry : kSupported) {
+    if (dtype == entry.tvm_dtype) {
+      return entry.nccl_dtype;
+    }
+  }
+  LOG(FATAL) << "ValueError: Unsupported data type " << dtype;
+  throw;
+}
+
+/*!
+ * \brief Per-thread collective-communication state of a disco worker.
+ * \note The accessor Get() is only declared here; its definition is in nccl.cc.
+ */
+struct CCLThreadLocalContext {
+  /*! \brief The disco worker this context belongs to. */
+  DiscoWorker* worker = nullptr;
+  /*! \brief The device id used by this worker. */
+  int device_id = 0;
+  /*! \brief Fallback stream, used when the runtime reports no current stream. */
+  deviceStream_t default_stream = nullptr;
+  /*! \brief The NCCL/RCCL communicator. */
+  ncclComm_t comm = nullptr;
+
+  /*!
+   * \brief Destroy the communicator and the fallback stream.
+   * Resources are reset to nullptr, so calling Clear() again is a no-op
+   * instead of a double destroy.
+   */
+  void Clear() {
+    if (comm != nullptr) {
+      NCCL_CALL(ncclCommDestroy(comm));
+      comm = nullptr;
+    }
+    if (default_stream != nullptr) {
+      StreamDestroy(default_stream);
+      default_stream = nullptr;
+    }
+  }
+
+  /*!
+   * \brief Return the stream reported by "runtime.get_<device>_stream", or
+   * `default_stream` when that function returns null.
+   */
+  deviceStream_t GetDefaultStream() {
+    const auto* func = tvm::runtime::Registry::Get("runtime.get_" TVM_DISCO_DEVICE_NAME "_stream");
+    ICHECK(func != nullptr);
+    deviceStream_t stream = static_cast<deviceStream_t>((*func)().operator void*());
+    return stream == nullptr ? default_stream : stream;
+  }
+
+  static CCLThreadLocalContext* Get();
+};
+
+} // namespace nccl
+} // namespace runtime
+} // namespace tvm
+
+#endif // TVM_RUNTIME_DISCO_NCCL_NCCL_CONTEXT_H_
diff --git a/src/runtime/memory/memory_manager.cc
b/src/runtime/memory/memory_manager.cc
index 5c50fe08ae..0607697e6b 100644
--- a/src/runtime/memory/memory_manager.cc
+++ b/src/runtime/memory/memory_manager.cc
@@ -43,9 +43,10 @@ static void BufferDeleter(Object* obj) {
delete ptr;
}
-Storage::Storage(Buffer buffer) {
+Storage::Storage(Buffer buffer, Allocator* allocator) {
auto n = make_object<StorageObj>();
n->buffer = std::move(buffer);
+ n->allocator = allocator;
data_ = std::move(n);
}
@@ -203,9 +204,13 @@ NDArray Allocator::Empty(ShapeTuple shape, DLDataType
dtype, DLDevice dev,
return NDArray(GetObjectPtr<Object>(container));
}
+bool Allocator::AllowMemoryScope(const std::string& mem_scope) const {
+ return mem_scope.empty() || mem_scope == "global";
+}
+
Buffer Allocator::Alloc(Device dev, ShapeTuple shape, DLDataType type_hint,
const std::string& mem_scope) {
- if (mem_scope.empty() || mem_scope == "global") {
+ if (AllowMemoryScope(mem_scope)) {
// by default, we can always redirect to the flat memory allocations
NDArray::Container container(nullptr, shape, type_hint, dev);
size_t size = DeviceAPI::Get(dev)->GetDataSize(container.dl_tensor);
diff --git a/src/runtime/memory/naive_allocator.h
b/src/runtime/memory/naive_allocator.h
index 8d8d2e9d88..6d8e90fed9 100644
--- a/src/runtime/memory/naive_allocator.h
+++ b/src/runtime/memory/naive_allocator.h
@@ -57,7 +57,7 @@ class NaiveAllocator final : public Allocator {
}
nbytes *= (type_hint.bits * type_hint.lanes + 7) / 8;
buf.device = dev;
- if (mem_scope.empty() || mem_scope == "global") {
+ if (AllowMemoryScope(mem_scope)) {
auto tmp_buf = Allocator::Alloc(dev, shape, type_hint, mem_scope);
buf.size = tmp_buf.size;
buf.data = tmp_buf.data;
diff --git a/src/runtime/memory/pooled_allocator.h
b/src/runtime/memory/pooled_allocator.h
index 9ebe1939be..c96c87a73a 100644
--- a/src/runtime/memory/pooled_allocator.h
+++ b/src/runtime/memory/pooled_allocator.h
@@ -36,7 +36,7 @@ namespace tvm {
namespace runtime {
namespace memory {
-class PooledAllocator final : public Allocator {
+class PooledAllocator : public Allocator {
public:
static constexpr size_t kDefaultPageSize = 4096;
@@ -60,12 +60,12 @@ class PooledAllocator final : public Allocator {
buf.size = size;
buf.alloc_type = kPooled;
try {
- buf.data = DeviceAPI::Get(dev)->AllocDataSpace(dev, size, alignment,
type_hint);
+ buf.data = DeviceAllocDataSpace(dev, size, alignment, type_hint);
} catch (InternalError& err) {
LOG(WARNING) << "PooledAllocator got InternalError during allocation: "
<< err.message();
LOG(WARNING) << "Trying to release all unused memory and reallocate...";
ReleaseAll();
- buf.data = DeviceAPI::Get(dev)->AllocDataSpace(dev, size, alignment,
type_hint);
+ buf.data = DeviceAllocDataSpace(dev, size, alignment, type_hint);
}
used_memory_.fetch_add(size, std::memory_order_relaxed);
@@ -75,7 +75,7 @@ class PooledAllocator final : public Allocator {
Buffer Alloc(Device dev, ShapeTuple shape, DLDataType type_hint,
const std::string& mem_scope) override {
- if (mem_scope.empty() || mem_scope == "global") {
+ if (AllowMemoryScope(mem_scope)) {
return Allocator::Alloc(dev, shape, type_hint, mem_scope);
}
LOG(FATAL) << "This alloc should be implemented";
@@ -95,13 +95,22 @@ class PooledAllocator final : public Allocator {
size_t UsedMemory() const override { return
used_memory_.load(std::memory_order_relaxed); }
- private:
- void ReleaseAll() {
+ protected:
+ virtual void* DeviceAllocDataSpace(Device dev, size_t nbytes, size_t
alignment,
+ DLDataType type_hint) {
+ return DeviceAPI::Get(dev)->AllocDataSpace(dev, nbytes, alignment,
type_hint);
+ }
+
+ virtual void DeviceFreeDataSpace(Device dev, void* ptr) {
+ DeviceAPI::Get(dev)->FreeDataSpace(dev, ptr);
+ }
+
+ virtual void ReleaseAll() {
std::lock_guard<std::recursive_mutex> lock(mu_);
for (auto const& it : memory_pool_) {
auto const& pool = it.second;
for (auto const& buf : pool) {
- DeviceAPI::Get(buf.device)->FreeDataSpace(buf.device, buf.data);
+ DeviceFreeDataSpace(buf.device, buf.data);
}
}
memory_pool_.clear();
@@ -109,7 +118,7 @@ class PooledAllocator final : public Allocator {
VLOG(1) << "release all buffers";
}
- private:
+ protected:
size_t page_size_;
std::atomic<size_t> used_memory_;
std::unordered_map<size_t, std::vector<Buffer>> memory_pool_;
diff --git a/src/runtime/relax_vm/builtin.cc b/src/runtime/relax_vm/builtin.cc
index 15e3edf1cb..17061c3297 100644
--- a/src/runtime/relax_vm/builtin.cc
+++ b/src/runtime/relax_vm/builtin.cc
@@ -349,6 +349,7 @@ Storage VMAllocStorage(void* ctx_ptr, ShapeTuple
buffer_shape, Index device_inde
storage_obj->buffer =
alloc->Alloc(vm->devices[device_index], buffer_shape, dtype_hint,
mem_scope);
+ storage_obj->allocator = alloc;
Storage storage(storage_obj);
return storage;
}
diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc
index 75e1ec5636..dfde076bfc 100644
--- a/src/runtime/vm/vm.cc
+++ b/src/runtime/vm/vm.cc
@@ -847,6 +847,7 @@ void VirtualMachine::RunLoop(const std::vector<Index>&
output_tensor_reg_indices
instr.alloc_storage.shape + instr.alloc_storage.ndim);
storage_obj->buffer = allocator->Alloc(device, ShapeTuple(shape_),
instr.alloc_storage.dtype_hint, mem_scope);
+ storage_obj->allocator = allocator;
} else {
auto size = LoadScalarInt(instr.alloc_storage.allocation_size);
auto alignment = instr.alloc_storage.alignment;
@@ -855,6 +856,7 @@ void VirtualMachine::RunLoop(const std::vector<Index>&
output_tensor_reg_indices
<< ", device_index=" << instr.alloc_storage.device_index;
storage_obj->buffer =
allocator->Alloc(device, size, alignment,
instr.alloc_storage.dtype_hint);
+ storage_obj->allocator = allocator;
}
Storage storage(storage_obj);
WriteRegister(instr.dst, storage);
diff --git a/tests/python/disco/test_custom_allreduce.py
b/tests/python/disco/test_custom_allreduce.py
new file mode 100644
index 0000000000..47b5f9590a
--- /dev/null
+++ b/tests/python/disco/test_custom_allreduce.py
@@ -0,0 +1,78 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import enum
+from functools import reduce
+from itertools import product
+
+import numpy as np
+import pytest
+
+import tvm
+import tvm.testing
+from tvm.runtime import DataType, ShapeTuple, disco
+from tvm.runtime.disco import Session
+
+
+# Integer strategy codes passed to "runtime.disco.cuda_ipc.custom_allreduce", which casts
+# them to `tensorrt_llm::AllReduceStrategyType` (values presumably matching that C++ enum).
+AllReduceStrategyType = enum.IntEnum("AllReduceStrategyType", [("ONESHOT", 1), ("TWOSHOT", 2)])
+
+
+# Tensor shapes exercised by the test.
+_shapes = [(2, 3), (3, 4), (128, 128)]
+
+# Every strategy, in declaration order (ONESHOT, TWOSHOT).
+_strategies = list(AllReduceStrategyType)
+
+# The custom all-reduce is only tested with the NCCL backend (when it is compiled in).
+_ccl = [name for name in tvm.get_global_func("runtime.disco.compiled_ccl")() if name == "nccl"]
+
+
+@pytest.mark.parametrize("shape", _shapes)
+@pytest.mark.parametrize("ccl", _ccl)
+@pytest.mark.parametrize("strategy", _strategies)
+def test_allreduce(shape, ccl, strategy):
+    """Custom all-reduce over two workers must equal the element-wise sum of their inputs."""
+    devices = [0, 1]
+    num_workers = len(devices)
+    sess: Session = disco.ProcessSession(num_workers=num_workers)
+    sess.init_ccl(ccl, *devices)
+
+    num_elements = reduce(lambda acc, dim: acc * dim, shape)
+    dtype = "float32"
+    falloc_ipc_storage = sess.get_global_func("runtime.disco.cuda_ipc.alloc_storage")
+    falloc_tensor = sess.get_global_func("vm.builtin.alloc_tensor")
+    fallreduce = sess.get_global_func("runtime.disco.cuda_ipc.custom_allreduce")
+    # The input tensor is carved out of CUDA IPC memory (offset 0).
+    d_storage = sess.call_packed(falloc_ipc_storage, ShapeTuple(shape), DataType(dtype))
+    d_input = sess.call_packed(falloc_tensor, d_storage, 0, ShapeTuple(shape), DataType(dtype))
+
+    # Worker 0 holds 0, 1, 2, ...; worker 1 holds 1, 0, -1, ...; every sum is therefore 1.
+    host_inputs = [
+        np.arange(num_elements, dtype="float32").reshape(*shape),
+        np.arange(start=1, stop=-(num_elements - 1), step=-1, dtype="float32").reshape(*shape),
+    ]
+    for worker_id, host_input in enumerate(host_inputs):
+        d_input.debug_copy_from(worker_id, host_input)
+    d_output = sess.empty(shape, "float32")
+
+    sess.call_packed(fallreduce, d_input, strategy, d_output)
+    results = [d_output.debug_get_from_remote(wid).numpy() for wid in range(num_workers)]
+    expected = np.add(*host_inputs)
+    for result in results:
+        np.testing.assert_equal(result, expected)
+
+
+if __name__ == "__main__":
+    # Same iteration order as itertools.product(_shapes, _strategies).
+    for shape in _shapes:
+        for strategy in _strategies:
+            test_allreduce(shape, "nccl", strategy)