This is an automated email from the ASF dual-hosted git repository.

bohan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4f94890112 [NVSHMEM] Enable nvshmem memory allocation (#17415)
4f94890112 is described below

commit 4f948901124761ce27dba4f0e4b752480315893c
Author: Yaxing Cai <[email protected]>
AuthorDate: Mon Sep 30 08:47:36 2024 -0700

    [NVSHMEM] Enable nvshmem memory allocation (#17415)
    
    This PR adds support for NVSHMEM memory allocation and integrates it into Disco.
---
 .../contrib/nvshmem/{nvshmem.cc => init.cc}        |   2 +
 src/runtime/contrib/nvshmem/memory_allocator.cc    | 104 +++++++++++++++++++++
 tests/python/disco/test_nvshmem.py                 |  45 +++++++--
 3 files changed, 145 insertions(+), 6 deletions(-)
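
For context, a minimal usage sketch of the new allocation path, distilled from
the test changes below. The global function names (runtime.disco.nvshmem.*) are
taken from this commit; the session setup and imports mirror
tests/python/disco/test_nvshmem.py and assume a TVM build with NVSHMEM enabled.

    import tvm
    from tvm.runtime import ShapeTuple
    from tvm.runtime import disco as di

    num_workers = 2
    sess = di.ProcessSession(num_workers=num_workers)

    # Worker 0 generates a unique id; all workers then initialize NVSHMEM with it.
    uid = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid")()
    sess.get_global_func("runtime.disco.nvshmem.init_nvshmem")(uid, num_workers)
    sess.sync_worker_0()

    # Allocate a tensor on the NVSHMEM symmetric heap via the new allocator.
    empty = sess.get_global_func("runtime.disco.nvshmem.empty")
    a = empty(ShapeTuple((32, 64)), "float32", tvm.cuda())

    # Release pooled NVSHMEM buffers, then finalize NVSHMEM itself.
    sess.get_global_func("runtime.disco.nvshmem.finalize_nvshmem")()
    sess.sync_worker_0()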

diff --git a/src/runtime/contrib/nvshmem/nvshmem.cc b/src/runtime/contrib/nvshmem/init.cc
similarity index 96%
rename from src/runtime/contrib/nvshmem/nvshmem.cc
rename to src/runtime/contrib/nvshmem/init.cc
index 985ba55107..50fdde4c49 100644
--- a/src/runtime/contrib/nvshmem/nvshmem.cc
+++ b/src/runtime/contrib/nvshmem/init.cc
@@ -54,6 +54,8 @@ void InitNVSHMEM(ShapeTuple uid_64, int num_workers) {
   }
   nvshmemx_set_attr_uniqueid_args(worker->worker_id, num_workers, &uid, &attr);
   nvshmemx_init_attr(NVSHMEMX_INIT_WITH_UNIQUEID, &attr);
+  int mype_node = nvshmem_team_my_pe(NVSHMEMX_TEAM_NODE);
+  CUDA_CALL(cudaSetDevice(mype_node));
   LOG_INFO << "NVSHMEM init finished: mype=" << nvshmem_my_pe() << " "
            << ", npes=" << nvshmem_n_pes();
 }
diff --git a/src/runtime/contrib/nvshmem/memory_allocator.cc b/src/runtime/contrib/nvshmem/memory_allocator.cc
new file mode 100644
index 0000000000..89d56ed3dc
--- /dev/null
+++ b/src/runtime/contrib/nvshmem/memory_allocator.cc
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <nvshmem.h>
+#include <nvshmemx.h>
+#include <tvm/runtime/memory/memory_manager.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
+
+#include <thread>
+
+#include "../../cuda/cuda_common.h"
+#include "../../memory/pooled_allocator.h"
+
+namespace tvm {
+namespace runtime {
+
+using tvm::runtime::memory::Buffer;
+using tvm::runtime::memory::PooledAllocator;
+
+/*!
+ * \brief The memory allocator of NVSHMEM.
+ * Overriding PooledAllocator for efficient memory management.
+ */
+class NVSHMEMAllocator final : public PooledAllocator {
+ public:
+  explicit NVSHMEMAllocator() : PooledAllocator() {}
+
+  ~NVSHMEMAllocator() { PooledAllocator::ReleaseAll(); }
+
+  void Clear() final { PooledAllocator::ReleaseAll(); }
+
+  bool AllowMemoryScope(const std::string& mem_scope) const final {
+    // The only memory scope NVSHMEM allows is "nvshmem".
+    return mem_scope == "nvshmem";
+  }
+
+  /*! \brief Return the global NVSHMEM singleton allocator. */
+  static NVSHMEMAllocator* Global() {
+    static NVSHMEMAllocator* allocator = new NVSHMEMAllocator();
+    return allocator;
+  }
+
+  NDArray Empty(ShapeTuple shape, DataType dtype, Device device) {
+    NDArray::Container* container = new NDArray::Container(nullptr, shape, dtype, device);
+    container->SetDeleter([](Object* obj) {
+      auto* ptr = static_cast<NDArray::Container*>(obj);
+      ICHECK(ptr->manager_ctx != nullptr);
+      Buffer* buffer = reinterpret_cast<Buffer*>(ptr->manager_ctx);
+      NVSHMEMAllocator::Global()->Free(*(buffer));
+      delete buffer;
+      delete ptr;
+    });
+    Buffer* buffer = new Buffer;
+    *buffer = PooledAllocator::Alloc(device, shape, dtype, String("nvshmem"));
+    container->manager_ctx = reinterpret_cast<void*>(buffer);
+    container->dl_tensor.data = buffer->data;
+    return NDArray(GetObjectPtr<Object>(container));
+  }
+
+ private:
+  void* DeviceAllocDataSpace(Device dev, size_t size, size_t alignment,
+                             DLDataType type_hint) final {
+    ICHECK_EQ(dev.device_type, DLDeviceType::kDLCUDA)
+        << "nvshmem can only allocate cuda device memory space.";
+    ICHECK(type_hint.code == DLDataTypeCode::kDLInt || type_hint.code == DLDataTypeCode::kDLUInt ||
+           type_hint.code == DLDataTypeCode::kDLFloat)
+        << "nvshmem can only allocate tensors with int, unsigned int or float data types.";
+    return nvshmem_align(alignment, size);
+  }
+
+  void DeviceFreeDataSpace(Device dev, void* ptr) final { nvshmem_free(ptr); }
+};
+
+NDArray NVSHMEMEmpty(ShapeTuple shape, DataType dtype, Device device) {
+  return NVSHMEMAllocator::Global()->Empty(shape, dtype, device);
+}
+
+TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.empty").set_body_typed(NVSHMEMEmpty);
+
+void NVSHMEMFinalize() {
+  NVSHMEMAllocator::Global()->Clear();
+  nvshmem_finalize();
+}
+
+TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.finalize_nvshmem").set_body_typed(NVSHMEMFinalize);
+
+}  // namespace runtime
+}  // namespace tvm
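
A note on buffer lifetime implied by memory_allocator.cc above: NDArrays
returned by runtime.disco.nvshmem.empty carry a custom deleter that hands the
buffer back to the pooled allocator rather than calling nvshmem_free directly,
so dropping the last reference recycles the allocation for later requests;
only NVSHMEMFinalize's Clear() releases pool memory via nvshmem_free. A hedged
sketch of that behavior (reuse is a property of PooledAllocator here, not a
guaranteed public API; `sess` is an initialized Disco session as in the
earlier sketch):

    import tvm
    from tvm.runtime import ShapeTuple

    empty = sess.get_global_func("runtime.disco.nvshmem.empty")

    a = empty(ShapeTuple((32, 64)), "float32", tvm.cuda())
    del a  # deleter returns the buffer to the pool, not to nvshmem_free

    # A same-sized request may be served from the pool without a new nvshmem_align.
    b = empty(ShapeTuple((32, 64)), "float32", tvm.cuda())

    # Clearing the pool and shutting NVSHMEM down both happen in finalize_nvshmem.
    sess.get_global_func("runtime.disco.nvshmem.finalize_nvshmem")()
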
diff --git a/tests/python/disco/test_nvshmem.py b/tests/python/disco/test_nvshmem.py
index 0b16fe9361..b304d145aa 100644
--- a/tests/python/disco/test_nvshmem.py
+++ b/tests/python/disco/test_nvshmem.py
@@ -23,6 +23,9 @@ import pytest
 import subprocess
 import threading
 import sys
+from multiprocessing import Process
+from typing import Any, Callable, List
+
 
 import tvm
 import tvm.testing
@@ -82,8 +85,6 @@ class SocketSessionTester:
         thread.join()
 
     def __del__(self):
-        for node in self.remote_nodes:
-            node.kill()
         if self.sess is not None:
             self.sess.shutdown()
             del self.sess
@@ -98,17 +99,49 @@ def create_socket_session(num_workers):
     return _SOCKET_SESSION_TESTER.sess
 
 
[email protected]("num_workers", [2, 4])
-def test_nvshmem_init(num_workers):
+def test_nvshmem_init_finalize(session_kind: di.Session, num_workers: int):
     if tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid", True) is None:
         return
-    sess = create_socket_session(num_workers=num_workers)
+
+    sess = session_kind(num_workers=num_workers)
     f_init_nvshmem_uid = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid")
     uid = f_init_nvshmem_uid()
     init_dfunc = sess.get_global_func("runtime.disco.nvshmem.init_nvshmem")
     init_dfunc(uid, num_workers)
     sess.sync_worker_0()
+    finalize_dfunc = sess.get_global_func("runtime.disco.nvshmem.finalize_nvshmem")
+    finalize_dfunc()
+    sess.sync_worker_0()
+
+
+def test_nvshmem_empty(session_kind: di.Session, num_workers: int):
+    if tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid", True) is None:
+        return
+
+    device = tvm.cuda()
+    sess = session_kind(num_workers=num_workers)
+    f_init_nvshmem_uid = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid")
+    uid = f_init_nvshmem_uid()
+    init_dfunc = sess.get_global_func("runtime.disco.nvshmem.init_nvshmem")
+    init_dfunc(uid, num_workers)
+    sess.sync_worker_0()
+    empty_dfunc = sess.get_global_func("runtime.disco.nvshmem.empty")
+    a = empty_dfunc(ShapeTuple((32, 64)), "float32", device)
+    b = empty_dfunc(ShapeTuple((64, 32)), "float32", device)
+    sess.sync_worker_0()
+    finalize_dfunc = sess.get_global_func("runtime.disco.nvshmem.finalize_nvshmem")
+    finalize_dfunc()
+    sess.sync_worker_0()
 
 
 if __name__ == "__main__":
-    tvm.testing.main()
+    # After the first call to `nvshmem_init`, a subsequent call to `nvshmem_init`
+    # or `nvshmem_init_thread` in the same program results in undefined behavior.
+    # So we run each test in a fresh process; NVSHMEM is then never initialized
+    # twice in the same process (worker 0 may share the test driver's process).
+    for session_kind in [create_socket_session, di.ProcessSession]:
+        for num_workers in [2, 4]:
+            for test_func in [test_nvshmem_init_finalize, test_nvshmem_empty]:
+                p = Process(target=test_func, args=[session_kind, num_workers])
+                p.start()
+                p.join()
