This is an automated email from the ASF dual-hosted git repository.
bohan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 4f94890112 [NVSHMEM] Enable nvshmem memory allocation (#17415)
4f94890112 is described below
commit 4f948901124761ce27dba4f0e4b752480315893c
Author: Yaxing Cai <[email protected]>
AuthorDate: Mon Sep 30 08:47:36 2024 -0700
[NVSHMEM] Enable nvshmem memory allocation (#17415)
This PR adds support for NVSHMEM memory allocation and integrates it
into disco.
---
.../contrib/nvshmem/{nvshmem.cc => init.cc} | 2 +
src/runtime/contrib/nvshmem/memory_allocator.cc | 104 +++++++++++++++++++++
tests/python/disco/test_nvshmem.py | 45 +++++++--
3 files changed, 145 insertions(+), 6 deletions(-)
diff --git a/src/runtime/contrib/nvshmem/nvshmem.cc
b/src/runtime/contrib/nvshmem/init.cc
similarity index 96%
rename from src/runtime/contrib/nvshmem/nvshmem.cc
rename to src/runtime/contrib/nvshmem/init.cc
index 985ba55107..50fdde4c49 100644
--- a/src/runtime/contrib/nvshmem/nvshmem.cc
+++ b/src/runtime/contrib/nvshmem/init.cc
@@ -54,6 +54,8 @@ void InitNVSHMEM(ShapeTuple uid_64, int num_workers) {
}
nvshmemx_set_attr_uniqueid_args(worker->worker_id, num_workers, &uid, &attr);
nvshmemx_init_attr(NVSHMEMX_INIT_WITH_UNIQUEID, &attr);
+ int mype_node = nvshmem_team_my_pe(NVSHMEMX_TEAM_NODE);
+ CUDA_CALL(cudaSetDevice(mype_node));
LOG_INFO << "NVSHMEM init finished: mype=" << nvshmem_my_pe() << " "
<< ", npes=" << nvshmem_n_pes();
}
diff --git a/src/runtime/contrib/nvshmem/memory_allocator.cc
b/src/runtime/contrib/nvshmem/memory_allocator.cc
new file mode 100644
index 0000000000..89d56ed3dc
--- /dev/null
+++ b/src/runtime/contrib/nvshmem/memory_allocator.cc
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <nvshmem.h>
+#include <nvshmemx.h>
+#include <tvm/runtime/memory/memory_manager.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
+
+#include <thread>
+
+#include "../../cuda/cuda_common.h"
+#include "../../memory/pooled_allocator.h"
+
+namespace tvm {
+namespace runtime {
+
+using tvm::runtime::memory::Buffer;
+using tvm::runtime::memory::PooledAllocator;
+
+/*!
+ * \brief The memory allocator of NVSHMEM.
+ * Overriding PooledAllocator for efficient memory management.
+ */
+class NVSHMEMAllocator final : public PooledAllocator {
+ public:
+ explicit NVSHMEMAllocator() : PooledAllocator() {}
+
+ ~NVSHMEMAllocator() { PooledAllocator::ReleaseAll(); }
+
+ void Clear() final { PooledAllocator::ReleaseAll(); }
+
+ bool AllowMemoryScope(const std::string& mem_scope) const final {
+ // The allowed memory scope of NVSHMEM is "nvshmem";
+ return mem_scope == "nvshmem";
+ }
+
+ /*! \brief Return the global NVSHMEM singleton allocator. */
+ static NVSHMEMAllocator* Global() {
+ static NVSHMEMAllocator* allocator = new NVSHMEMAllocator();
+ return allocator;
+ }
+
+ NDArray Empty(ShapeTuple shape, DataType dtype, Device device) {
+ NDArray::Container* container = new NDArray::Container(nullptr, shape,
dtype, device);
+ container->SetDeleter([](Object* obj) {
+ auto* ptr = static_cast<NDArray::Container*>(obj);
+ ICHECK(ptr->manager_ctx != nullptr);
+ Buffer* buffer = reinterpret_cast<Buffer*>(ptr->manager_ctx);
+ NVSHMEMAllocator::Global()->Free(*(buffer));
+ delete buffer;
+ delete ptr;
+ });
+ Buffer* buffer = new Buffer;
+ *buffer = PooledAllocator::Alloc(device, shape, dtype, String("nvshmem"));
+ container->manager_ctx = reinterpret_cast<void*>(buffer);
+ container->dl_tensor.data = buffer->data;
+ return NDArray(GetObjectPtr<Object>(container));
+ }
+
+ private:
+ void* DeviceAllocDataSpace(Device dev, size_t size, size_t alignment,
+ DLDataType type_hint) final {
+ ICHECK_EQ(dev.device_type, DLDeviceType::kDLCUDA)
+ << "nvshmem can only allocate cuda device memory space.";
+ ICHECK(type_hint.code == DLDataTypeCode::kDLInt || type_hint.code ==
DLDataTypeCode::kDLUInt ||
+ type_hint.code == DLDataTypeCode::kDLFloat)
+        << "nvshmem can only allocate tensor with int, unsigned int or float
data types.";
+ return nvshmem_align(alignment, size);
+ }
+
+ void DeviceFreeDataSpace(Device dev, void* ptr) final { nvshmem_free(ptr); }
+};
+
+NDArray NVSHMEMEmpty(ShapeTuple shape, DataType dtype, Device device) {
+ return NVSHMEMAllocator::Global()->Empty(shape, dtype, device);
+}
+
+TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.empty").set_body_typed(NVSHMEMEmpty);
+
+void NVSHMEMFinalize() {
+ NVSHMEMAllocator::Global()->Clear();
+ nvshmem_finalize();
+}
+
+TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.finalize_nvshmem").set_body_typed(NVSHMEMFinalize);
+
+} // namespace runtime
+} // namespace tvm
diff --git a/tests/python/disco/test_nvshmem.py
b/tests/python/disco/test_nvshmem.py
index 0b16fe9361..b304d145aa 100644
--- a/tests/python/disco/test_nvshmem.py
+++ b/tests/python/disco/test_nvshmem.py
@@ -23,6 +23,9 @@ import pytest
import subprocess
import threading
import sys
+from multiprocessing import Process
+from typing import Any, Callable, List
+
import tvm
import tvm.testing
@@ -82,8 +85,6 @@ class SocketSessionTester:
thread.join()
def __del__(self):
- for node in self.remote_nodes:
- node.kill()
if self.sess is not None:
self.sess.shutdown()
del self.sess
@@ -98,17 +99,49 @@ def create_socket_session(num_workers):
return _SOCKET_SESSION_TESTER.sess
[email protected]("num_workers", [2, 4])
-def test_nvshmem_init(num_workers):
+def test_nvshmem_init_finalize(session_kind: di.Session, num_workers: int):
if tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid", True) is
None:
return
- sess = create_socket_session(num_workers=num_workers)
+
+ sess = session_kind(num_workers=num_workers)
f_init_nvshmem_uid =
tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid")
uid = f_init_nvshmem_uid()
init_dfunc = sess.get_global_func("runtime.disco.nvshmem.init_nvshmem")
init_dfunc(uid, num_workers)
sess.sync_worker_0()
+ finalize_dfunc =
sess.get_global_func("runtime.disco.nvshmem.finalize_nvshmem")
+ finalize_dfunc()
+ sess.sync_worker_0()
+
+
+def test_nvshmem_empty(session_kind: di.Session, num_workers: int):
+ if tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid", True) is
None:
+ return
+
+ device = tvm.cuda()
+ sess = session_kind(num_workers=num_workers)
+ f_init_nvshmem_uid =
tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid")
+ uid = f_init_nvshmem_uid()
+ init_dfunc = sess.get_global_func("runtime.disco.nvshmem.init_nvshmem")
+ init_dfunc(uid, num_workers)
+ sess.sync_worker_0()
+ empty_dfunc = sess.get_global_func("runtime.disco.nvshmem.empty")
+ a = empty_dfunc(ShapeTuple((32, 64)), "float32", device)
+ b = empty_dfunc(ShapeTuple((64, 32)), "float32", device)
+ sess.sync_worker_0()
+ finalize_dfunc =
sess.get_global_func("runtime.disco.nvshmem.finalize_nvshmem")
+ finalize_dfunc()
+ sess.sync_worker_0()
if __name__ == "__main__":
- tvm.testing.main()
+ # After the first call to `nvshmem_init`, a subsequent call to
`nvshmem_init`
+ # or `nvshmem_init_thread` in the same program results in undefined
behavior.
+ # So we always create a new process to run the test. Then no repeated
nvshmem
+ # init happens in the same process, since the worker0 may share the same
process.
+ for session_kind in [create_socket_session, di.ProcessSession]:
+ for num_workers in [2, 4]:
+ for test_func in [test_nvshmem_init_finalize, test_nvshmem_empty]:
+ p = Process(target=test_func, args=[session_kind, num_workers])
+ p.start()
+ p.join()