This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 40b9a926b9 [Disco] Pipe-based Multi-processing Session (#15727)
40b9a926b9 is described below
commit 40b9a926b9df375a6ca5ce51ea93e4d849b85517
Author: Junru Shao <[email protected]>
AuthorDate: Thu Sep 14 04:43:44 2023 -0700
[Disco] Pipe-based Multi-processing Session (#15727)
This PR introduces `ProcessSession`, a new session implementation based
on multi-processing.
`ProcessSession` shares exactly the same communication protocol with
`ThreadedSession`, but all workers except for worker 0 are launched in
separate processes rather than threads. Workers communicate with the
controller via pipes provided by the OS, rather than SPSC message queues
between threads.
In our implementation, Python's `subprocess.Popen` is used to create
subprocesses, and the Python executable, or more specifically,
`sys.executable`, calls into `tvm.exec.disco_worker` as the entrypoint.
Besides the launching logic that is only executed once in the very
beginning, the rest of the implementation resides in a C++-only
environment, including reads/writes to pipe file descriptors,
serialization and deserialization of messages, worker interpretation of
each message, etc.
Detailed engineering elements included in this PR:
- Refactors the MinRPC-based communication protocol out to be shared by
`ProcessSession` and `ThreadedSession` as `protocol.h`;
- Refactors a controller-side worker thread into `DiscoWorkerThread`,
which is shared by both session implementations to launch worker-0;
- Adds two instructions, `kDebugGetFromRemote` and `kDebugSetRegister`,
which are used to communicate with workers other than worker-0 in
debug mode;
- Introduces multi-processing infra including: `tvm.exec.disco_worker`
serving as the entrypoint that launches workers, and
`tvm/runtime/disco/process_pool.py` that exposes APIs to launch worker
processes. `tvm.exec.disco_worker` calls into a global function
`runtime.disco.WorkerProcess` that executes the worker main loop in
pure C++;
- Introduces `src/support/process_id.h` that provides cross-platform pid
and tid printing utilities;
- Refactors Disco's NCCL integration to get rid of the initialized-once
global NCCL context, switching to broadcasting `ncclUniqueId` from the
controller to all workers and then creating NCCL communicators in each
worker thread/process accordingly. This is a thread/process-agnostic
way of using NCCL.
---
include/tvm/runtime/disco/session.h | 50 +++++-
python/tvm/exec/disco_worker.py | 51 +++++++
python/tvm/runtime/disco/__init__.py | 9 +-
python/tvm/runtime/disco/process_pool.py | 180 ++++++++++++++++++++++
python/tvm/runtime/disco/session.py | 27 +++-
python/tvm/testing/__init__.py | 3 +-
python/tvm/testing/disco.py | 53 +++++++
src/runtime/disco/bcast_session.cc | 7 +
src/runtime/disco/bcast_session.h | 2 +
src/runtime/disco/builtin.cc | 6 +-
src/runtime/disco/nccl/nccl.cc | 91 ++++-------
src/runtime/disco/nccl/utils.h | 2 +
src/runtime/disco/process_session.cc | 213 ++++++++++++++++++++++++++
src/runtime/disco/protocol.h | 254 +++++++++++++++++++++++++++++++
src/runtime/disco/session.cc | 11 +-
src/runtime/disco/threaded_session.cc | 128 ++++------------
src/runtime/disco/worker.cc | 55 ++++++-
src/runtime/disco/worker.h | 43 ++++++
src/support/process_id.h | 67 ++++++++
tests/python/disco/test_nccl.py | 92 ++++++-----
tests/python/disco/test_session.py | 95 +++++-------
21 files changed, 1149 insertions(+), 290 deletions(-)
diff --git a/include/tvm/runtime/disco/session.h
b/include/tvm/runtime/disco/session.h
index e28fb7144c..984ea026d8 100644
--- a/include/tvm/runtime/disco/session.h
+++ b/include/tvm/runtime/disco/session.h
@@ -72,6 +72,7 @@
#ifndef TVM_RUNTIME_DISCO_SESSION_H_
#define TVM_RUNTIME_DISCO_SESSION_H_
+#include <tvm/runtime/container/shape_tuple.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/packed_func.h>
@@ -92,6 +93,8 @@ enum class DiscoAction : int32_t {
kSyncWorker = 4,
kCopyFromWorker0 = 5,
kCopyToWorker0 = 6,
+ kDebugGetFromRemote = 7,
+ kDebugSetRegister = 8,
};
/*! \brief Converts the enum class `DiscoAction` to string */
@@ -111,6 +114,10 @@ inline std::string DiscoAction2String(DiscoAction action) {
return "kCopyFromWorker0";
case DiscoAction::kCopyToWorker0:
return "kCopyToWorker0";
+ case DiscoAction::kDebugGetFromRemote:
+ return "kDebugGetFromRemote";
+ case DiscoAction::kDebugSetRegister:
+ return "kDebugSetRegister";
}
LOG(FATAL) << "ValueError: Unknown DiscoAction: " <<
static_cast<int>(action);
}
@@ -136,7 +143,7 @@ class DRefObj : public Object {
* \param worker_id The id of the worker to be copied to.
* \param source The NDArray to be copied.
*/
- void DebugCopyFrom(int worker_id, NDArray source);
+ inline void DebugCopyFrom(int worker_id, TVMArgValue source);
static constexpr const char* _type_key = "runtime.disco.DRef";
static constexpr const uint32_t _type_index = TypeIndex::kRuntimeDiscoDRef;
@@ -213,6 +220,12 @@ class SessionObj : public Object {
virtual void SyncWorker(int worker_id) = 0;
/*! \brief Signal all the workers to shutdown */
virtual void Shutdown() = 0;
+ /*!
+ * \brief Initialize the data plane between workers.
+ * \param ccl The name of the communication backend, e.g., nccl, rccl, mpi.
+ * \param device_ids The device ids of the workers.
+ */
+ virtual void InitCCL(String ccl, ShapeTuple device_ids) = 0;
/*!
* \brief Get the value of a register from a remote worker.
* \param reg_id The id of the register to be fetched.
@@ -220,13 +233,19 @@ class SessionObj : public Object {
* \return The value of the register.
*/
virtual TVMRetValue DebugGetFromRemote(int64_t reg_id, int worker_id) = 0;
-
- static constexpr const char* _type_key = "runtime.disco.Session";
- TVM_DECLARE_BASE_OBJECT_INFO(SessionObj, Object);
+ /*!
+ * \brief Set the value of a register on a remote worker.
+ * \param reg_id The id of the register to be set.
+ * \param value The value to be set.
+ * \param worker_id The id of the worker to be set.
+ */
+ virtual void DebugSetRegister(int64_t reg_id, TVMArgValue value, int
worker_id) = 0;
struct FFI;
friend struct SessionObj::FFI;
friend class DRefObj;
+ static constexpr const char* _type_key = "runtime.disco.Session";
+ TVM_DECLARE_BASE_OBJECT_INFO(SessionObj, Object);
protected:
/*! \brief Deallocate a register id, kill it on all workers, and append it
to `free_regs_`. */
@@ -239,8 +258,22 @@ class SessionObj : public Object {
*/
class Session : public ObjectRef {
public:
- /*! \brief Create a session backed by a thread pool of workers */
- static Session ThreadedSession(int num_workers);
+ /*!
+ * \brief Create a session backed by a thread pool of workers
+ * \param num_workers The number of workers.
+ */
+ TVM_DLL static Session ThreadedSession(int num_workers);
+ /*!
+ * \brief Create a session backed by pipe-based multiprocessing
+ * \param num_workers The number of workers.
+ * \param process_pool_creator The name of a global function that takes
`num_workers` as an input,
+ * and returns a PackedFunc, which takes an integer `worker_id` as the input
and returns None.
+ * When `worker-id` is 0, it shuts down the process pool; Otherwise, it
retursn a tuple
+ * (read_fd, writefd) used to communicate with the corresponding worker.
+ * \note Worker-0 is always co-located with the controler as a separate
thread, and therefore
+ * worker-0 does not exist in the process pool.
+ */
+ TVM_DLL static Session ProcessSession(int num_workers, String
process_pool_creator);
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Session, ObjectRef,
SessionObj);
};
@@ -250,6 +283,7 @@ class Session : public ObjectRef {
*/
class DiscoChannel {
public:
+ virtual ~DiscoChannel() = default;
/*! \brief Send a packed sequence to the receiver */
virtual void Send(const TVMArgs& args) = 0;
/*! \brief Receive a packed sequence from worker */
@@ -272,6 +306,10 @@ TVMRetValue DRefObj::DebugGetFromRemote(int worker_id) {
return Downcast<Session>(this->session)->DebugGetFromRemote(this->reg_id,
worker_id);
}
+void DRefObj::DebugCopyFrom(int worker_id, TVMArgValue value) {
+ return Downcast<Session>(this->session)->DebugSetRegister(this->reg_id,
value, worker_id);
+}
+
template <typename... Args>
DRef SessionObj::CallPacked(const DRef& func, Args&&... args) {
constexpr int offset = 3;
diff --git a/python/tvm/exec/disco_worker.py b/python/tvm/exec/disco_worker.py
new file mode 100644
index 0000000000..9faa5742ae
--- /dev/null
+++ b/python/tvm/exec/disco_worker.py
@@ -0,0 +1,51 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""Internal DiscoWorker for Disco ProcessSession."""
+import os
+import sys
+
+from tvm import runtime as _ # pylint: disable=unused-import
+from tvm._ffi import get_global_func
+from tvm.testing import disco as _ # pylint: disable=unused-import
+
+
+def main():
+ """Main worker function"""
+ if len(sys.argv) != 5:
+ print("Usage: <worker_id> <num_workers> <read_fd> <write_fd>")
+ return
+ worker_id = int(sys.argv[1])
+ num_workers = int(sys.argv[2])
+ if sys.platform == "win32":
+ import msvcrt # pylint: disable=import-outside-toplevel,import-error
+
+ reader = msvcrt.open_osfhandle(int(sys.argv[3]), os.O_BINARY)
+ writer = msvcrt.open_osfhandle(int(sys.argv[4]), os.O_BINARY)
+ else:
+ reader = int(sys.argv[3])
+ writer = int(sys.argv[4])
+
+ worker_func = get_global_func("runtime.disco.WorkerProcess")
+ worker_func(worker_id, num_workers, reader, writer)
+
+
+if __name__ == "__main__":
+ try:
+ main()
+ except (KeyboardInterrupt, IOError):
+ pass
diff --git a/python/tvm/runtime/disco/__init__.py
b/python/tvm/runtime/disco/__init__.py
index 57c0548e2e..856e69bc35 100644
--- a/python/tvm/runtime/disco/__init__.py
+++ b/python/tvm/runtime/disco/__init__.py
@@ -15,4 +15,11 @@
# specific language governing permissions and limitations
# under the License.
"""TVM distributed runtime API."""
-from .session import DModule, DPackedFunc, DRef, Session, ThreadedSession
+from .session import (
+ DModule,
+ DPackedFunc,
+ DRef,
+ ProcessSession,
+ Session,
+ ThreadedSession,
+)
diff --git a/python/tvm/runtime/disco/process_pool.py
b/python/tvm/runtime/disco/process_pool.py
new file mode 100644
index 0000000000..44348577f7
--- /dev/null
+++ b/python/tvm/runtime/disco/process_pool.py
@@ -0,0 +1,180 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""Pipe worker for multi-processing."""
+import os
+import subprocess
+import sys
+
+import psutil
+
+from tvm._ffi import register_func
+from tvm.runtime import ShapeTuple
+
+
+class DiscoPopenWorker:
+ """A subprocess worker via Popen.
+
+ PopenWorker provides a low-level
+ API to interact with a separate process via Popen.
+
+ Parameters
+ ----------
+ worker_id : int
+ The worker id of the current worker.
+
+ num_workers : int
+ The total number of workers.
+
+ stdout: Union[None, int, IO[Any]]
+ The standard output streams handler specified for the popen process.
+
+ stderr: Union[None, int, IO[Any]]
+ The standard error streams handler specified for the popen process.
+ """
+
+ def __init__(self, worker_id: int, num_workers: int, stdout=None,
stderr=None):
+ self.worker_id = worker_id
+ self.num_workers = num_workers
+ self._proc = None
+ self._stdout = stdout
+ self._stderr = stderr
+
+ def __del__(self):
+ try:
+ self.kill()
+ except ImportError:
+ pass
+
+ def kill(self):
+ """Kill the current running process and cleanup.
+
+ Note
+ ----
+ The worker can start a new process when send is called again.
+ """
+ if self._proc is not None:
+ # kill all child processes recursively
+ try:
+ _kill_child_processes(self._proc.pid)
+ except TypeError:
+ pass
+ try:
+ self._proc.kill()
+ except OSError:
+ pass
+
+ # Join the child process to avoid zombie processes
+ self.join(timeout=1.0)
+ self._proc = None
+
+ def join(self, timeout=None):
+ """Join the current process worker before it terminates.
+
+ Parameters
+ ----------
+ timeout: Optional[number]
+ Timeout value, block at most timeout seconds if it
+ is a positive number.
+ """
+ if self._proc:
+ try:
+ self._proc.wait(timeout)
+ except subprocess.TimeoutExpired:
+ pass
+
+ def start(self):
+ """Start a new subprocess if nothing is available"""
+ if self._proc is not None:
+ return None, None
+
+ # connect subprocess with a pair of pipes
+ main_read, worker_write = os.pipe()
+ worker_read, main_write = os.pipe()
+
+ cmd = [
+ sys.executable,
+ "-m",
+ "tvm.exec.disco_worker",
+ str(self.worker_id),
+ str(self.num_workers),
+ ]
+ if sys.platform == "win32":
+ import msvcrt # pylint:
disable=import-error,import-outside-toplevel
+
+ worker_read_handle = msvcrt.get_osfhandle(worker_read)
+ worker_write_handle = msvcrt.get_osfhandle(worker_write)
+ os.set_handle_inheritable(worker_read_handle, True)
+ os.set_handle_inheritable(worker_write_handle, True)
+ cmd += [str(worker_read_handle), str(worker_write_handle)]
+ self._proc = subprocess.Popen(
+ cmd,
+ close_fds=False,
+ stdout=self._stdout,
+ stderr=self._stderr,
+ )
+ else:
+ cmd += [str(worker_read), str(worker_write)]
+ self._proc = subprocess.Popen( # pylint:
disable=consider-using-with
+ cmd,
+ pass_fds=(worker_read, worker_write),
+ stdout=self._stdout,
+ stderr=self._stderr,
+ )
+
+ # close worker side of the pipe
+ os.close(worker_read)
+ os.close(worker_write)
+ return main_read, main_write
+
+
+def _kill_child_processes(pid):
+ """Kill all child processes recursively for a given pid.
+
+ Parameters
+ ----------
+ pid : int
+ The given parameter id.
+ """
+ try:
+ parent = psutil.Process(pid)
+ children = parent.children(recursive=True)
+ except psutil.NoSuchProcess:
+ return
+
+ for process in children:
+ try:
+ process.kill()
+ except psutil.NoSuchProcess:
+ pass
+
+
+@register_func("runtime.disco.create_process_pool")
+def _create_process_pool(num_workers: int):
+ """Create a process pool where the workers' are are [1, num_workers)."""
+ pool = [DiscoPopenWorker(i, num_workers) for i in range(1, num_workers)]
+
+ def result_func(worker_id: int):
+ nonlocal pool
+ if worker_id != 0:
+ read_fd, write_fd = pool[worker_id - 1].start()
+ return ShapeTuple([read_fd, write_fd])
+ print("Shutting down the process pool")
+ del pool
+ return None
+
+ return result_func
diff --git a/python/tvm/runtime/disco/session.py
b/python/tvm/runtime/disco/session.py
index eab5a5268d..d05561c2d1 100644
--- a/python/tvm/runtime/disco/session.py
+++ b/python/tvm/runtime/disco/session.py
@@ -27,7 +27,7 @@ from ..container import ShapeTuple
from ..ndarray import NDArray
from ..ndarray import array as _as_NDArray
from ..object import Object
-from . import _ffi_api
+from . import _ffi_api, process_pool # pylint: disable=unused-import
@register_object("runtime.disco.DRef")
@@ -250,22 +250,21 @@ class Session(Object):
func = self._get_cached_method("runtime.disco.load_vm_module")
return DModule(func(path, device))
- def init_ccl(self, api: str, *args):
+ def init_ccl(self, ccl: str, *device_ids):
"""Initialize the underlying communication collective library.
Parameters
----------
- api : str
+ ccl : str
The name of the communication collective library. Currently
supported libraries are:
- nccl
- rccl
- mpi
- *args : various types
- The arguments to be passed to the initialization function of the
communication
+ *device_ids : int
+ The device IDs to be used by the underlying communication library.
"""
- assert api in ("nccl", "rccl"), f"Unsupported CCL backend: {api}"
- func = self.get_global_func(f"runtime.disco.{api}.init_ccl")
- func(*args)
+ assert ccl in ("nccl", "rccl"), f"Unsupported CCL backend: {ccl}"
+ return _ffi_api.SessionInitCCL(self, ccl, ShapeTuple(device_ids)) #
type: ignore # pylint: disable=no-member
def broadcast_from_worker0(self, src: DRef, dst: DRef) -> DRef:
"""Broadcast an array from worker-0 to all other workers.
@@ -343,6 +342,18 @@ class ThreadedSession(Session):
)
+@register_object("runtime.disco.ProcessSession")
+class ProcessSession(Session):
+ """A Disco session backed by pipe-based multi-processing."""
+
+ def __init__(self, num_workers: int) -> None:
+ self.__init_handle_by_constructor__(
+ _ffi_api.SessionProcess, # type: ignore # pylint:
disable=no-member
+ num_workers,
+ "runtime.disco.create_process_pool",
+ )
+
+
REDUCE_OPS = {
"sum": 0,
"prod": 1,
diff --git a/python/tvm/testing/__init__.py b/python/tvm/testing/__init__.py
index 3e5f838a27..9aa1a31933 100644
--- a/python/tvm/testing/__init__.py
+++ b/python/tvm/testing/__init__.py
@@ -17,8 +17,7 @@
# pylint: disable=redefined-builtin, wildcard-import
"""Utility Python functions for TVM testing"""
-
-from . import auto_scheduler, autotvm
+from . import auto_scheduler, autotvm, disco
from ._ffi_api import (
ErrorTest,
FrontendTestModule,
diff --git a/python/tvm/testing/disco.py b/python/tvm/testing/disco.py
new file mode 100644
index 0000000000..c13e83b7c4
--- /dev/null
+++ b/python/tvm/testing/disco.py
@@ -0,0 +1,53 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, missing-function-docstring,
missing-class-docstring
+"""Common utilities for testing disco"""
+from tvm._ffi import register_func
+from tvm.runtime import NDArray, ShapeTuple, String
+from tvm.runtime.ndarray import array
+
+
+@register_func("tests.disco.add_one")
+def add_one(x: int) -> int: # pylint: disable=invalid-name
+ return x + 1
+
+
+@register_func("tests.disco.add_one_float", override=True)
+def add_one_float(x: float): # pylint: disable=invalid-name
+ return x + 0.5
+
+
+@register_func("tests.disco.add_one_ndarray", override=True)
+def add_one_ndarray(x: NDArray) -> NDArray: # pylint: disable=invalid-name
+ return array(x.numpy() + 1)
+
+
+@register_func("tests.disco.str", override=True)
+def str_func(x: str): # pylint: disable=invalid-name
+ return x + "_suffix"
+
+
+@register_func("tests.disco.str_obj", override=True)
+def str_obj_func(x: String): # pylint: disable=invalid-name
+ assert isinstance(x, String)
+ return String(x + "_suffix")
+
+
+@register_func("tests.disco.shape_tuple", override=True)
+def shape_tuple_func(x: ShapeTuple): # pylint: disable=invalid-name
+ assert isinstance(x, ShapeTuple)
+ return ShapeTuple(list(x) + [4, 5])
diff --git a/src/runtime/disco/bcast_session.cc
b/src/runtime/disco/bcast_session.cc
index 0625c1157e..9b553c319c 100644
--- a/src/runtime/disco/bcast_session.cc
+++ b/src/runtime/disco/bcast_session.cc
@@ -68,6 +68,13 @@ void BcastSessionObj::Shutdown() {
BcastSessionObj::Internal::BroadcastUnpacked(this, DiscoAction::kShutDown,
0);
}
+void BcastSessionObj::InitCCL(String ccl, ShapeTuple device_ids) {
+ const auto* pf = runtime::Registry::Get("runtime.disco." + ccl +
".init_ccl");
+ CHECK(pf) << "ValueError: Cannot initialize CCL `" << ccl
+ << "`, because cannot find function: runtime.disco." << ccl <<
".init_ccl";
+ (*pf)(GetRef<Session>(this), device_ids);
+}
+
void BcastSessionObj::SyncWorker(int worker_id) {
BcastSessionObj::Internal::BroadcastUnpacked(this, DiscoAction::kSyncWorker,
worker_id);
TVMArgs args = this->RecvReplyPacked(worker_id);
diff --git a/src/runtime/disco/bcast_session.h
b/src/runtime/disco/bcast_session.h
index 0221207b96..d064b30f5a 100644
--- a/src/runtime/disco/bcast_session.h
+++ b/src/runtime/disco/bcast_session.h
@@ -42,7 +42,9 @@ class BcastSessionObj : public SessionObj {
void CopyToWorker0(const NDArray& host_array, const DRef& remote_array)
override;
void SyncWorker(int worker_id) override;
void Shutdown() override;
+ void InitCCL(String ccl, ShapeTuple device_ids) override;
TVMRetValue DebugGetFromRemote(int64_t reg_id, int worker_id) override = 0;
+ void DebugSetRegister(int64_t reg_id, TVMArgValue value, int worker_id)
override = 0;
protected:
/*! \brief Deallocate a register id, kill it on all workers, and append it
to `free_regs_`. */
diff --git a/src/runtime/disco/builtin.cc b/src/runtime/disco/builtin.cc
index 64e3fd4b28..06408c723a 100644
--- a/src/runtime/disco/builtin.cc
+++ b/src/runtime/disco/builtin.cc
@@ -100,7 +100,11 @@ void RecvFromWorker0(NDArray buffer) {
GetCCLFunc("recv_from_worker0")(buffer);
int WorkerId() { return DiscoWorker::ThreadLocal()->worker_id; }
-void SyncWorker() { GetCCLFunc("sync_worker")(); }
+void SyncWorker() {
+ if (DiscoWorker::ThreadLocal()->ccl != "") {
+ GetCCLFunc("sync_worker")();
+ }
+}
TVM_REGISTER_GLOBAL("runtime.disco.load_vm_module").set_body_typed(LoadVMModule);
TVM_REGISTER_GLOBAL("runtime.disco.empty").set_body_typed(DiscoEmptyNDArray);
diff --git a/src/runtime/disco/nccl/nccl.cc b/src/runtime/disco/nccl/nccl.cc
index 0212923cef..e404e3c2bb 100644
--- a/src/runtime/disco/nccl/nccl.cc
+++ b/src/runtime/disco/nccl/nccl.cc
@@ -19,13 +19,16 @@
#include <cuda_runtime_api.h>
#include <dlpack/dlpack.h>
#include <nccl.h>
+#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/disco/session.h>
#include <tvm/runtime/registry.h>
+#include <cstring>
#include <mutex>
#include <sstream>
#include <vector>
+#include "../../../support/process_id.h"
#include "../../cuda/cuda_common.h"
#include "./utils.h"
@@ -33,48 +36,6 @@ namespace tvm {
namespace runtime {
namespace nccl {
-struct NCCLGlobalContext {
- std::vector<ncclComm_t> communicators;
-
- static NCCLGlobalContext* Get() {
- static NCCLGlobalContext ctx;
- return &ctx;
- }
-
- void Initialize(const std::vector<int>& device_ids) {
- {
- std::ostringstream os;
- bool is_first = true;
- for (int device_id : device_ids) {
- if (!is_first) {
- os << ",";
- } else {
- is_first = false;
- }
- os << device_id;
- }
- LOG(INFO) << "Initializing NCCL with devices: " << os.str() << ".";
- }
- // TODO(@junrushao): support more flexible communicator pattern for
generic SPMD usecases
- DiscoWorker* worker = DiscoWorker::ThreadLocal();
- int num_workers = worker->num_workers;
- CHECK_EQ(device_ids.size(), num_workers)
- << "ValueError: There are " << num_workers << " worker(s), but " <<
device_ids.size()
- << " device id(s) are provided.";
- ncclUniqueId id;
- NCCL_CALL(ncclGetUniqueId(&id));
- NCCL_CALL(ncclGroupStart());
- for (int worker_id = 0; worker_id < num_workers; ++worker_id) {
- int device_id = device_ids[worker_id];
- ncclComm_t comm;
- CUDA_CALL(cudaSetDevice(device_id));
- NCCL_CALL(ncclCommInitRank(&comm, num_workers, id, worker_id));
- this->communicators.push_back(comm);
- }
- NCCL_CALL(ncclGroupEnd());
- }
-};
-
struct NCCLThreadLocalContext {
DiscoWorker* worker;
int device_id;
@@ -92,23 +53,38 @@ struct NCCLThreadLocalContext {
}
};
-void InitCCL(const std::vector<int>& device_ids) {
- // Set up global context only once
- static std::once_flag flag;
- std::call_once(flag, [&]() {
NCCLGlobalContext::Get()->Initialize(device_ids); });
- // Set up thread-local context for each thread
- DiscoWorker* worker = DiscoWorker::ThreadLocal();
+void InitCCL(Session sess, ShapeTuple device_ids) {
+ DRef func = sess->GetGlobalFunc("runtime.disco.nccl.init_ccl_per_worker");
+ LOG(INFO) << "Initializing NCCL with devices: " << device_ids;
+ ncclUniqueId id;
+ TVMByteArray array;
+ NCCL_CALL(ncclGetUniqueId(&id));
+ array.data = id.internal;
+ array.size = NCCL_UNIQUE_ID_BYTES;
+ sess->CallPacked(func, device_ids, array);
+}
+
+void InitCCLPerWorker(ShapeTuple device_ids, std::string unique_id_bytes) {
NCCLThreadLocalContext* ctx = NCCLThreadLocalContext::Get();
+ DiscoWorker* worker = DiscoWorker::ThreadLocal();
+ ICHECK(worker != nullptr);
+ CHECK_EQ(unique_id_bytes.size(), NCCL_UNIQUE_ID_BYTES)
+ << "ValueError: The length of unique_id must be " <<
NCCL_UNIQUE_ID_BYTES << ", but got "
+ << unique_id_bytes.size() << ".";
+ // Step up local context of NCCL
int device_id = device_ids[worker->worker_id];
CUDA_CALL(cudaSetDevice(device_id));
+ CUDA_CALL(cudaStreamCreate(&ctx->stream));
Device device{DLDeviceType::kDLCUDA, device_id};
+ DeviceAPI::Get(device)->SetStream(device, ctx->stream);
worker->default_device = device;
worker->ccl = "nccl";
ctx->worker = worker;
ctx->device_id = device_id;
- ctx->comm = NCCLGlobalContext::Get()->communicators[worker->worker_id];
- CUDA_CALL(cudaStreamCreate(&ctx->stream));
- DeviceAPI::Get(device)->SetStream(device, ctx->stream);
+ // Initialize the communicator
+ ncclUniqueId id;
+ std::memcpy(id.internal, unique_id_bytes.data(), NCCL_UNIQUE_ID_BYTES);
+ NCCL_CALL(ncclCommInitRank(&ctx->comm, worker->num_workers, id,
worker->worker_id));
}
void AllReduce(NDArray send, ReduceKind reduce_kind, NDArray recv) {
@@ -158,7 +134,7 @@ void ScatterFromWorker0(Optional<NDArray> send, NDArray
recv) {
}
} else {
if (send.defined()) {
- LOG(WARNING) << "ValueError: buffer `send` must be None when worker_id
!= 0. However, got "
+ LOG(WARNING) << "Buffer `send` must be None when worker_id != 0, but got
"
"send = "
<< send.get() << ". This will be ignored.";
}
@@ -222,17 +198,12 @@ void RecvFromWorker0(NDArray buffer) {
void SyncWorker() {
NCCLThreadLocalContext* ctx = NCCLThreadLocalContext::Get();
+ ICHECK(ctx->worker != nullptr);
CUDA_CALL(cudaStreamSynchronize(ctx->stream));
}
-TVM_REGISTER_GLOBAL("runtime.disco.nccl.init_ccl")
- .set_body([](TVMArgs args, TVMRetValue* rv) -> void {
- std::vector<int> device_ids;
- for (int i = 0; i < args.num_args; ++i) {
- device_ids.push_back(args[i].operator int());
- }
- InitCCL(device_ids);
- });
+TVM_REGISTER_GLOBAL("runtime.disco.nccl.init_ccl").set_body_typed(InitCCL);
+TVM_REGISTER_GLOBAL("runtime.disco.nccl.init_ccl_per_worker").set_body_typed(InitCCLPerWorker);
TVM_REGISTER_GLOBAL("runtime.disco.nccl.allreduce")
.set_body_typed([](NDArray send, int kind, NDArray recv) {
CHECK(0 <= kind && kind <= 4) << "ValueError: Unknown ReduceKind: " <<
kind;
diff --git a/src/runtime/disco/nccl/utils.h b/src/runtime/disco/nccl/utils.h
index 4e5fb8cd74..7f40365136 100644
--- a/src/runtime/disco/nccl/utils.h
+++ b/src/runtime/disco/nccl/utils.h
@@ -69,6 +69,7 @@ inline ncclDataType_t AsNCCLDataType(runtime::DataType dtype)
{
return ncclBfloat16;
}
LOG(FATAL) << "ValueError: Unsupported data type " << dtype;
+ throw;
}
inline ncclRedOp_t AsNCCLRedOp(ReduceKind kind) {
@@ -85,6 +86,7 @@ inline ncclRedOp_t AsNCCLRedOp(ReduceKind kind) {
return ncclAvg;
}
LOG(FATAL) << "ValueError: Unknown ReduceKind: " << static_cast<int>(kind);
+ throw;
}
} // namespace nccl
diff --git a/src/runtime/disco/process_session.cc
b/src/runtime/disco/process_session.cc
new file mode 100644
index 0000000000..8ddfdce812
--- /dev/null
+++ b/src/runtime/disco/process_session.cc
@@ -0,0 +1,213 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/runtime/object.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
+
+#include <memory>
+#include <sstream>
+#include <utility>
+#include <vector>
+
+#include "../../support/pipe.h"
+#include "../minrpc/rpc_reference.h"
+#include "./bcast_session.h"
+#include "./protocol.h"
+#include "./worker.h"
+#include "tvm/runtime/c_runtime_api.h"
+
+namespace tvm {
+namespace runtime {
+
+/*!
+ * \brief Message queue that sends and receives packed-argument messages through a
+ * ::tvm::support::Pipe opened on the given handle.
+ *
+ * Message framing is delegated to RPCReference; the object (de)serialization hooks and the
+ * arenas backing a decoded message come from DiscoProtocol.
+ */
+class DiscoPipeMessageQueue : private ::tvm::support::Pipe,
+ private DiscoProtocol<DiscoPipeMessageQueue> {
+ public:
+ /*! \brief Construct the queue on top of the pipe identified by `handle`. */
+ explicit DiscoPipeMessageQueue(int64_t handle) :
::tvm::support::Pipe(handle) {}
+
+ ~DiscoPipeMessageQueue() = default;
+
+ /*! \brief Serialize `args` as a single message and write it to the pipe. */
+ void Send(const TVMArgs& args) {
+ RPCReference::ReturnPackedSeq(args.values, args.type_codes, args.num_args,
this);
+ }
+
+ /*!
+ * \brief Read one message from the pipe and decode it into packed arguments.
+ * \note The arenas are recycled at the start of every call, so the returned TVMArgs
+ * presumably stays valid only until the next Recv on this queue -- TODO confirm against
+ * RPCReference::RecvPackedSeq.
+ */
+ TVMArgs Recv() {
+ {
+ // Release what the previous message allocated, then consume the packet header.
+ // The packet length and the RPC code are read but not used any further here.
+ this->RecycleAll();
+ uint64_t packet_nbytes = 0;
+ RPCCode code = RPCCode::kReturn;
+ this->Read(&packet_nbytes);
+ this->Read(&code);
+ }
+ TVMValue* values = nullptr;
+ int* type_codes = nullptr;
+ int num_args = 0;
+ RPCReference::RecvPackedSeq(&values, &type_codes, &num_args, this);
+ return TVMArgs(values, type_codes, num_args);
+ }
+
+ // Bring the dmlc::Stream helpers into this class's scope; the friends declared below
+ // (RPCReference and DiscoProtocol) call them on this object.
+ using dmlc::Stream::Read;
+ using dmlc::Stream::ReadArray;
+ using dmlc::Stream::Write;
+ using dmlc::Stream::WriteArray;
+ friend struct RPCReference;
+ friend struct DiscoProtocol<DiscoPipeMessageQueue>;
+};
+
+/*!
+ * \brief A DiscoChannel made of two pipe-backed message queues, one per direction.
+ *
+ * The same class is used on both ends. The controler constructs it with
+ * (write_fd, read_fd) and uses Send/RecvReply; a worker process constructs it with
+ * (read_fd, write_fd) and uses Recv/Reply.
+ */
+class DiscoProcessChannel final : public DiscoChannel {
+ public:
+ /*!
+ * \brief Construct the channel from two pipe handles.
+ * \param controler_to_worker_fd Handle of the pipe that carries controler-to-worker messages.
+ * \param worker_to_controler_fd Handle of the pipe that carries worker-to-controler messages.
+ */
+ DiscoProcessChannel(int64_t controler_to_worker_fd, int64_t
worker_to_controler_fd)
+ : controler_to_worker_(controler_to_worker_fd),
+ worker_to_controler_(worker_to_controler_fd) {}
+
+ // The channel can be neither moved nor copied.
+ DiscoProcessChannel(DiscoProcessChannel&& other) = delete;
+ DiscoProcessChannel(const DiscoProcessChannel& other) = delete;
+
+ /*! \brief Controler side: send a message to the worker. */
+ void Send(const TVMArgs& args) { controler_to_worker_.Send(args); }
+ /*! \brief Worker side: receive a message sent by the controler. */
+ TVMArgs Recv() { return controler_to_worker_.Recv(); }
+ /*! \brief Worker side: send a reply to the controler. */
+ void Reply(const TVMArgs& args) { worker_to_controler_.Send(args); }
+ /*! \brief Controler side: receive a reply sent by the worker. */
+ TVMArgs RecvReply() { return worker_to_controler_.Recv(); }
+
+ /*! \brief Queue for messages flowing from the controler to the worker. */
+ DiscoPipeMessageQueue controler_to_worker_;
+ /*! \brief Queue for messages flowing from the worker to the controler. */
+ DiscoPipeMessageQueue worker_to_controler_;
+};
+
+/*!
+ * \brief A broadcast session in which worker-0 runs as a thread inside the controler process
+ * and every other worker is reached through a pair of pipe file descriptors obtained from
+ * `process_pool`.
+ */
+class ProcessSessionObj final : public BcastSessionObj {
+ public:
+  /*!
+   * \brief Start worker-0 and open a channel to each of the workers 1..num_workers-1.
+   * \param num_workers The total number of workers, including worker-0.
+   * \param process_pool A function that maps a worker id `i >= 1` to a tuple
+   * `(read_fd, write_fd)` of file descriptors used to talk to that worker.
+   */
+  explicit ProcessSessionObj(int num_workers, PackedFunc process_pool)
+      : process_pool_(process_pool),
+        worker_0_(std::make_unique<DiscoWorkerThread>(0, num_workers, &worker_zero_data_)) {
+    std::vector<int64_t> read_fds;
+    std::vector<int64_t> write_fds;
+    read_fds.reserve(num_workers - 1);
+    write_fds.reserve(num_workers - 1);
+    for (int i = 1; i < num_workers; ++i) {
+      ShapeTuple fds = process_pool(i);
+      CHECK_EQ(fds.size(), 2) << "ValueError: process_pool(" << i << ") should return a tuple of "
+                              << "size 2, but got a tuple of size " << fds.size() << ".";
+      read_fds.push_back(fds[0]);
+      write_fds.push_back(fds[1]);
+    }
+    // On the controler side requests are written to `write_fd` and replies read from `read_fd`.
+    for (int i = 0; i < num_workers - 1; ++i) {
+      workers_.emplace_back(std::make_unique<DiscoProcessChannel>(write_fds[i], read_fds[i]));
+    }
+  }
+
+  /*! \brief Shut the session down. Calling it again afterwards is a no-op. */
+  void Kill() {
+    if (this->worker_0_ != nullptr) {
+      this->Shutdown();
+      this->worker_0_.reset();
+      this->workers_.clear();
+      // NOTE(review): invoking the pool with id 0 presumably asks it to terminate the worker
+      // processes -- confirm against the process pool implementation.
+      this->process_pool_(0);
+    }
+  }
+
+  ~ProcessSessionObj() { Kill(); }
+
+  /*!
+   * \brief Fetch the content of a register from a given worker.
+   * \param reg_id The id of the register to read.
+   * \param worker_id The id of the worker to read from.
+   * \return The value held in the register.
+   */
+  TVMRetValue DebugGetFromRemote(int64_t reg_id, int worker_id) {
+    if (worker_id == 0) {
+      // Worker-0 lives in this process, so its register file can be read directly.
+      this->SyncWorker(worker_id);
+      return worker_0_->worker->register_file.at(reg_id);
+    }
+    {
+      TVMValue values[3];
+      int type_codes[3];
+      PackArgs(values, type_codes, static_cast<int>(DiscoAction::kDebugGetFromRemote), reg_id,
+               worker_id);
+      // Bounds-checked access, consistent with RecvReplyPacked.
+      workers_.at(worker_id - 1)->Send(TVMArgs(values, type_codes, 3));
+    }
+    TVMArgs args = this->RecvReplyPacked(worker_id);
+    ICHECK_EQ(args.size(), 2);
+    ICHECK(static_cast<DiscoAction>(args[0].operator int()) == DiscoAction::kDebugGetFromRemote);
+    TVMRetValue result;
+    result = args[1];
+    return result;
+  }
+
+  /*!
+   * \brief Set a register on a given worker to a given value.
+   * \param reg_id The id of the register to write.
+   * \param value The value to store.
+   * \param worker_id The id of the worker to write to.
+   */
+  void DebugSetRegister(int64_t reg_id, TVMArgValue value, int worker_id) {
+    if (worker_id == 0) {
+      this->SyncWorker(worker_id);
+      worker_0_->worker->SetRegister(reg_id, value);
+      return;
+    }
+    // NDArrays and objects are wrapped into a DiscoDebugObject before being sent.
+    // `wrapped` must stay alive until Send() returns, because `value` only holds a raw handle.
+    ObjectRef wrapped{nullptr};
+    if (value.type_code() == kTVMNDArrayHandle || value.type_code() == kTVMObjectHandle) {
+      wrapped = DiscoDebugObject::Wrap(value);
+      TVMValue tvm_value;
+      int type_code = kTVMObjectHandle;
+      tvm_value.v_handle = const_cast<Object*>(wrapped.get());
+      value = TVMArgValue(tvm_value, type_code);
+    }
+    {
+      TVMValue values[4];
+      int type_codes[4];
+      PackArgs(values, type_codes, static_cast<int>(DiscoAction::kDebugSetRegister), reg_id,
+               worker_id, value);
+      // Bounds-checked access, consistent with RecvReplyPacked.
+      workers_.at(worker_id - 1)->Send(TVMArgs(values, type_codes, 4));
+    }
+    // Wait for the worker's acknowledgement.
+    TVMArgs args = this->RecvReplyPacked(worker_id);
+    ICHECK_EQ(args.size(), 1);
+    ICHECK(static_cast<DiscoAction>(args[0].operator int()) == DiscoAction::kDebugSetRegister);
+  }
+
+  /*! \brief Send the same message to worker-0 and then to every other worker. */
+  void BroadcastPacked(const TVMArgs& args) final {
+    worker_0_->channel->Send(args);
+    for (std::unique_ptr<DiscoProcessChannel>& channel : workers_) {
+      channel->Send(args);
+    }
+  }
+
+  /*! \brief Receive a reply from the given worker. */
+  TVMArgs RecvReplyPacked(int worker_id) final {
+    if (worker_id == 0) {
+      return worker_0_->channel->RecvReply();
+    }
+    return this->workers_.at(worker_id - 1)->RecvReply();
+  }
+
+  /*! \brief The function used to obtain worker file descriptors; also invoked by Kill(). */
+  PackedFunc process_pool_;
+  /*! \brief Worker-0, which runs in a thread of the controler process. */
+  std::unique_ptr<DiscoWorkerThread> worker_0_;
+  /*! \brief Channels to the remaining workers; index `i` belongs to worker `i + 1`. */
+  std::vector<std::unique_ptr<DiscoProcessChannel>> workers_;
+
+  static constexpr const char* _type_key = "runtime.disco.ProcessSession";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ProcessSessionObj, SessionObj);
+};
+
+TVM_REGISTER_OBJECT_TYPE(DiscoDebugObject);
+TVM_REGISTER_OBJECT_TYPE(ProcessSessionObj);
+
+/*!
+ * \brief Create a session whose workers other than worker-0 run in separate processes.
+ * \param num_workers The total number of workers.
+ * \param process_pool_creator The name of a registered global function that is called with
+ * the number of workers and returns the process pool function.
+ * \return The newly created session.
+ */
+Session Session::ProcessSession(int num_workers, String process_pool_creator) {
+  const PackedFunc* creator = Registry::Get(process_pool_creator);
+  CHECK(creator != nullptr) << "ValueError: Cannot find function " << process_pool_creator
+                            << " in the registry. Please check if it is registered.";
+  PackedFunc pool = (*creator)(num_workers);
+  return Session(make_object<ProcessSessionObj>(num_workers, pool));
+}
+
+/*!
+ * \brief Run a Disco worker on the calling thread, talking to the controler through two pipes.
+ * \param worker_id The id of this worker.
+ * \param num_workers The total number of workers.
+ * \param read_fd Handle of the pipe this worker receives controler messages from.
+ * \param write_fd Handle of the pipe this worker sends its replies to.
+ */
+void WorkerProcess(int worker_id, int num_workers, int64_t read_fd, int64_t
write_fd) {
+ DiscoProcessChannel channel(read_fd, write_fd);
+ // No worker-zero data is passed (nullptr).
+ DiscoWorker worker(worker_id, num_workers, nullptr, &channel);
+ worker.MainLoop();
+}
+
+TVM_REGISTER_GLOBAL("runtime.disco.SessionProcess").set_body_typed(Session::ProcessSession);
+TVM_REGISTER_GLOBAL("runtime.disco.WorkerProcess").set_body_typed(WorkerProcess);
+
+} // namespace runtime
+} // namespace tvm
diff --git a/src/runtime/disco/protocol.h b/src/runtime/disco/protocol.h
new file mode 100644
index 0000000000..50a6b091af
--- /dev/null
+++ b/src/runtime/disco/protocol.h
@@ -0,0 +1,254 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_RUNTIME_DISCO_PROTOCOL_H_
+#define TVM_RUNTIME_DISCO_PROTOCOL_H_
+
+#include <dmlc/io.h>
+#include <dmlc/memory_io.h>
+#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/disco/session.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "../../support/arena.h"
+#include "../../support/base64.h"
+#include "../minrpc/rpc_reference.h"
+
+namespace tvm {
+namespace runtime {
+
+/*!
+ * \brief The communication protocol used by Disco message channel.
+ *
+ * The protocol moves bytes through the `Read`/`Write`/`ReadArray`/`WriteArray` members of the
+ * subclass and provides the object and memory hooks that RPCReference calls.
+ *
+ * \tparam SubClassType The subclass type that inherits this protocol.
+ */
+template <class SubClassType>
+struct DiscoProtocol {
+ protected:
+ /*! \brief Virtual destructor */
+ virtual ~DiscoProtocol() = default;
+
+ /*! \brief Recycle all the memory used in the arena */
+ inline void RecycleAll() {
+ this->object_arena_.clear();
+ this->arena_.RecycleAll();
+ }
+
+ /*! \brief Get the length of the object being serialized. Used by RPCReference. */
+ inline uint64_t GetObjectBytes(Object* obj);
+
+ /*! \brief Write the object to stream. Used by RPCReference. */
+ inline void WriteObject(Object* obj);
+
+ /*! \brief Read the object from stream. Used by RPCReference. */
+ inline void ReadObject(int* tcode, TVMValue* value);
+
+ /*! \brief Callback method used when starting a new message. Used by RPCReference. */
+ void MessageStart(uint64_t packet_nbytes) {}
+
+ /*! \brief Callback method used when a new message is complete. Used by RPCReference. */
+ void MessageDone() {}
+
+ /*! \brief Callback method when an error occurs in (de)-serialization. Used by RPCReference. */
+ void ThrowError(RPCServerStatus status) {
+ LOG(FATAL) << "InternalError: Unexpected error in RPC: " <<
RPCServerStatusToString(status);
+ }
+
+ /*! \brief Arena used by RPCReference to allocate POD memory */
+ template <typename T>
+ T* ArenaAlloc(int count) {
+ static_assert(std::is_pod<T>::value, "need to be trival");
+ return arena_.template allocate_<T>(count);
+ }
+
+ /*! \brief Backing storage for ArenaAlloc; recycled by RecycleAll(). */
+ support::Arena arena_;
+ /*! \brief References to the objects produced by ReadObject; dropped by RecycleAll(). */
+ std::vector<ObjectRef> object_arena_;
+ friend struct RPCReference;
+};
+
+/*!
+ * \brief The debug extension of the communication protocol that allows serialization and
+ * deserialization of NDArrays and reflection-capable TVM objects.
+ */
+struct DiscoDebugObject : public Object {
+ public:
+  /*! \brief The data to be serialized */
+  TVMRetValue data;
+
+  /*! \brief Wrap an NDArray or reflection-capable TVM object into the debug extension. */
+  static ObjectRef Wrap(const TVMRetValue& data) {
+    ObjectPtr<DiscoDebugObject> n = make_object<DiscoDebugObject>();
+    n->data = data;
+    return ObjectRef(n);
+  }
+
+  /*! \brief Wrap an NDArray or reflection-capable TVM object into the debug extension. */
+  static ObjectRef Wrap(const TVMArgValue& data) {
+    TVMRetValue rv;
+    rv = data;
+    // The overload above takes a const reference, so there is nothing to move.
+    return Wrap(rv);
+  }
+
+  /*! \brief Serialize the debug object to string */
+  inline std::string SaveToStr() const;
+  /*! \brief Deserialize the debug object from string */
+  static inline ObjectPtr<DiscoDebugObject> LoadFromStr(std::string json_str);
+  /*!
+   * \brief Get the size of the debug object in bytes
+   * \note Counts a uint64_t length prefix plus the serialized string; the string is produced
+   * by a full call to SaveToStr().
+   */
+  inline uint64_t GetObjectBytes() const { return sizeof(uint64_t) + this->SaveToStr().size(); }
+
+  static constexpr const char* _type_key = "runtime.disco.DiscoDebugObject";
+  // The declared parent must match the C++ base class, which is Object (not SessionObj).
+  TVM_DECLARE_FINAL_OBJECT_INFO(DiscoDebugObject, Object);
+};
+
+/*!
+ * \brief Compute how many bytes WriteObject emits for `obj`.
+ * \param obj The object about to be serialized.
+ * \return The encoded size in bytes, including the leading uint32_t type tag.
+ */
+template <class SubClassType>
+inline uint64_t DiscoProtocol<SubClassType>::GetObjectBytes(Object* obj) {
+  // Every encoding starts with a uint32_t type tag.
+  constexpr uint64_t kTagBytes = sizeof(uint32_t);
+  if (obj->IsInstance<DRefObj>()) {
+    // Tag followed by the int64_t register id.
+    return kTagBytes + sizeof(int64_t);
+  }
+  if (obj->IsInstance<StringObj>()) {
+    // Tag, uint64_t length, then the characters.
+    const uint64_t num_chars = static_cast<StringObj*>(obj)->size;
+    return kTagBytes + sizeof(uint64_t) + num_chars * sizeof(char);
+  }
+  if (obj->IsInstance<ShapeTupleObj>()) {
+    // Tag, uint64_t element count, then the elements.
+    const uint64_t num_dims = static_cast<ShapeTupleObj*>(obj)->size;
+    return kTagBytes + sizeof(uint64_t) + num_dims * sizeof(ShapeTupleObj::index_type);
+  }
+  if (obj->IsInstance<DiscoDebugObject>()) {
+    return kTagBytes + static_cast<DiscoDebugObject*>(obj)->GetObjectBytes();
+  }
+  LOG(FATAL) << "ValueError: Object type is not supported in Disco calling convention: "
+             << obj->GetTypeKey() << " (type_index = " << obj->type_index() << ")";
+}
+/*!
+ * \brief Serialize `obj` to the stream of the subclass.
+ *
+ * Each object is written as a uint32_t type tag followed by a type-specific payload.
+ * \param obj The object to write.
+ */
+template <class SubClassType>
+inline void DiscoProtocol<SubClassType>::WriteObject(Object* obj) {
+  SubClassType* stream = static_cast<SubClassType*>(this);
+  if (obj->IsInstance<DRefObj>()) {
+    // Payload: the int64_t register id.
+    const int64_t register_id = static_cast<DRefObj*>(obj)->reg_id;
+    stream->template Write<uint32_t>(TypeIndex::kRuntimeDiscoDRef);
+    stream->template Write<int64_t>(register_id);
+    return;
+  }
+  if (obj->IsInstance<StringObj>()) {
+    // Payload: uint64_t length followed by the characters.
+    StringObj* text = static_cast<StringObj*>(obj);
+    stream->template Write<uint32_t>(TypeIndex::kRuntimeString);
+    stream->template Write<uint64_t>(text->size);
+    stream->template WriteArray<char>(text->data, text->size);
+    return;
+  }
+  if (obj->IsInstance<ShapeTupleObj>()) {
+    // Payload: uint64_t element count followed by the elements.
+    ShapeTupleObj* dims = static_cast<ShapeTupleObj*>(obj);
+    stream->template Write<uint32_t>(TypeIndex::kRuntimeShapeTuple);
+    stream->template Write<uint64_t>(dims->size);
+    stream->template WriteArray<ShapeTupleObj::index_type>(dims->data, dims->size);
+    return;
+  }
+  if (obj->IsInstance<DiscoDebugObject>()) {
+    // Debug objects are tagged with kRoot; payload: uint64_t length followed by the string
+    // produced by SaveToStr().
+    stream->template Write<uint32_t>(TypeIndex::kRoot);
+    std::string payload = static_cast<DiscoDebugObject*>(obj)->SaveToStr();
+    stream->template Write<uint64_t>(payload.size());
+    stream->template WriteArray<char>(payload.data(), payload.size());
+    return;
+  }
+  LOG(FATAL) << "ValueError: Object type is not supported in Disco calling convention: "
+             << obj->GetTypeKey() << " (type_index = " << obj->type_index() << ")";
+}
+
+/*!
+ * \brief Deserialize one object from the stream of the subclass.
+ *
+ * Reads the uint32_t type tag, decodes the matching payload (the counterpart of WriteObject)
+ * and stores the result into the given argument slot.
+ * \param tcode Output: the type code of the decoded value.
+ * \param value Output: the decoded value.
+ */
+template <class SubClassType>
+inline void DiscoProtocol<SubClassType>::ReadObject(int* tcode, TVMValue*
value) {
+ SubClassType* self = static_cast<SubClassType*>(this);
+ ObjectRef result{nullptr};
+ uint32_t type_index;
+ self->template Read<uint32_t>(&type_index);
+ if (type_index == TypeIndex::kRuntimeDiscoDRef) {
+ // A DRef is rebuilt from its register id only; its session is left null.
+ ObjectPtr<DRefObj> dref = make_object<DRefObj>();
+ self->template Read<int64_t>(&dref->reg_id);
+ dref->session = Session{nullptr};
+ result = ObjectRef(std::move(dref));
+ } else if (type_index == TypeIndex::kRuntimeString) {
+ uint64_t size = 0;
+ self->template Read<uint64_t>(&size);
+ std::string data(size, '\0');
+ self->template ReadArray<char>(data.data(), size);
+ result = String(std::move(data));
+ } else if (type_index == TypeIndex::kRuntimeShapeTuple) {
+ uint64_t ndim = 0;
+ self->template Read<uint64_t>(&ndim);
+ std::vector<ShapeTupleObj::index_type> data(ndim);
+ self->template ReadArray<ShapeTupleObj::index_type>(data.data(), ndim);
+ result = ShapeTuple(std::move(data));
+ } else if (type_index == TypeIndex::kRoot) {
+ // kRoot is the tag WriteObject uses for DiscoDebugObject payloads. The wrapper is
+ // discarded here: the result is the wrapped NDArray or object itself.
+ uint64_t size = 0;
+ self->template Read<uint64_t>(&size);
+ std::string data(size, '\0');
+ self->template ReadArray<char>(data.data(), size);
+ result = DiscoDebugObject::LoadFromStr(std::move(data))->data;
+ } else {
+ LOG(FATAL) << "ValueError: Object type is not supported in Disco calling
convention: "
+ << Object::TypeIndex2Key(type_index) << " (type_index = " <<
type_index << ")";
+ }
+ TVMArgsSetter(value, tcode)(0, result);
+ // Hold a reference in the object arena until RecycleAll() clears it.
+ object_arena_.push_back(result);
+}
+
+/*!
+ * \brief Serialize the wrapped value to a string.
+ *
+ * Objects are serialized with the registered function "node.SaveJSON" and get the trailing
+ * control character '0'. NDArrays are serialized as base64 and get the trailing control
+ * character '1'. Any other type code is a fatal error.
+ * \return The serialized payload followed by its control character.
+ */
+inline std::string DiscoDebugObject::SaveToStr() const {
+  if (this->data.type_code() == kTVMObjectHandle) {
+    ObjectRef obj = this->data;
+    const PackedFunc* f = runtime::Registry::Get("node.SaveJSON");
+    CHECK(f) << "ValueError: Cannot serialize object in non-debugging mode: " << obj->GetTypeKey();
+    std::string result = (*f)(obj);
+    result.push_back('0');
+    return result;
+  } else if (this->data.type_code() == kTVMNDArrayHandle) {
+    NDArray array = this->data;
+    std::string result;
+    {
+      // The streams are scoped so that they are destroyed before the control character is
+      // appended to `result`.
+      dmlc::MemoryStringStream mstrm(&result);
+      support::Base64OutStream b64strm(&mstrm);
+      runtime::SaveDLTensor(&b64strm, array.operator->());
+      b64strm.Finish();
+    }
+    result.push_back('1');
+    return result;
+  }
+  // The parenthesis opened before the type-code name is now closed in the message.
+  LOG(FATAL) << "ValueError: Cannot serialize the following type code in non-debugging mode: "
+             << this->data.type_code() << "(" << ArgTypeCode2Str(this->data.type_code()) << ")";
+}
+
+/*!
+ * \brief Rebuild a debug object from a string produced by SaveToStr().
+ *
+ * The last character selects the decoder: '0' loads an object through the registered function
+ * "node.LoadJSON", '1' loads a base64-encoded NDArray.
+ * \param json_str The serialized payload followed by its control character.
+ * \return The debug object holding the decoded value.
+ */
+inline ObjectPtr<DiscoDebugObject> DiscoDebugObject::LoadFromStr(std::string json_str) {
+  ICHECK(!json_str.empty());
+  // Split off the trailing control character; `json_str` is the bare payload afterwards.
+  const char control_bit = json_str.back();
+  json_str.pop_back();
+  ObjectPtr<DiscoDebugObject> debug_obj = make_object<DiscoDebugObject>();
+  switch (control_bit) {
+    case '0': {
+      const PackedFunc* loader = runtime::Registry::Get("node.LoadJSON");
+      CHECK(loader) << "ValueError: Cannot deserialize object in non-debugging mode";
+      debug_obj->data = (*loader)(json_str);
+      break;
+    }
+    case '1': {
+      dmlc::MemoryStringStream raw_stream(&json_str);
+      support::Base64InStream decoded_stream(&raw_stream);
+      decoded_stream.InitPosition();
+      runtime::NDArray array;
+      ICHECK(array.Load(&decoded_stream));
+      debug_obj->data = std::move(array);
+      break;
+    }
+    default: {
+      LOG(FATAL) << "ValueError: Unsupported control bit: " << control_bit
+                 << ". Full string: " << json_str;
+    }
+  }
+  return debug_obj;
+}
+
+} // namespace runtime
+} // namespace tvm
+#endif // TVM_RUNTIME_DISCO_PROTOCOL_H_
diff --git a/src/runtime/disco/session.cc b/src/runtime/disco/session.cc
index e22b6c6d26..2cc027151a 100644
--- a/src/runtime/disco/session.cc
+++ b/src/runtime/disco/session.cc
@@ -31,15 +31,6 @@ struct SessionObj::FFI {
}
};
-void DRefObj::DebugCopyFrom(int worker_id, NDArray source) {
- TVMRetValue target_array = this->DebugGetFromRemote(worker_id);
- CHECK(target_array.type_code() == kTVMNDArrayHandle)
- << "ValueError: The DRef on the remote is not an NDArray, instead, its
type code is: "
- << ArgTypeCode2Str(target_array.type_code());
- NDArray target = target_array.operator NDArray();
- target.CopyFrom(source);
-}
-
TVM_REGISTER_OBJECT_TYPE(DRefObj);
TVM_REGISTER_OBJECT_TYPE(SessionObj);
TVM_REGISTER_GLOBAL("runtime.disco.SessionThreaded").set_body_typed(Session::ThreadedSession);
@@ -58,6 +49,8 @@ TVM_REGISTER_GLOBAL("runtime.disco.SessionCopyToWorker0")
.set_body_method<Session>(&SessionObj::CopyToWorker0);
TVM_REGISTER_GLOBAL("runtime.disco.SessionSyncWorker")
.set_body_method<Session>(&SessionObj::SyncWorker);
+TVM_REGISTER_GLOBAL("runtime.disco.SessionInitCCL") //
+ .set_body_method<Session>(&SessionObj::InitCCL);
TVM_REGISTER_GLOBAL("runtime.disco.SessionCallPacked").set_body([](TVMArgs
args, TVMRetValue* rv) {
Session self = args[0];
*rv = SessionObj::FFI::CallWithPacked(
diff --git a/src/runtime/disco/threaded_session.cc
b/src/runtime/disco/threaded_session.cc
index cb84918d2d..349601fd03 100644
--- a/src/runtime/disco/threaded_session.cc
+++ b/src/runtime/disco/threaded_session.cc
@@ -27,16 +27,17 @@
#include <thread>
#include <utility>
-#include "../../support/arena.h"
#include "../../support/ring_buffer.h"
#include "../minrpc/rpc_reference.h"
#include "./bcast_session.h"
+#include "./protocol.h"
#include "./worker.h"
namespace tvm {
namespace runtime {
-class DiscoThreadedMessageQueue : public dmlc::Stream {
+class DiscoThreadedMessageQueue : private dmlc::Stream,
+ private
DiscoProtocol<DiscoThreadedMessageQueue> {
public:
void Send(const TVMArgs& args) {
RPCReference::ReturnPackedSeq(args.values, args.type_codes, args.num_args,
this);
@@ -67,10 +68,7 @@ class DiscoThreadedMessageQueue : public dmlc::Stream {
condition_.wait(lock, [this] { return msg_cnt_.load() > 0; });
--msg_cnt_;
}
- {
- this->arena_.RecycleAll();
- this->object_arena_.clear();
- }
+ this->RecycleAll();
uint64_t packet_nbytes = 0;
RPCCode code = RPCCode::kReturn;
this->Read(&packet_nbytes);
@@ -84,18 +82,6 @@ class DiscoThreadedMessageQueue : public dmlc::Stream {
this->ring_buffer_.Reserve(n);
}
- void MessageDone() {}
-
- void ThrowError(RPCServerStatus status) {
- LOG(FATAL) << "InternalError: Unexpected error in RPC: " <<
RPCServerStatusToString(status);
- }
-
- template <typename T>
- T* ArenaAlloc(int count) {
- static_assert(std::is_pod<T>::value, "need to be trival");
- return arena_.template allocate_<T>(count);
- }
-
size_t Read(void* data, size_t size) final {
std::lock_guard<std::mutex> lock(mutex_);
ring_buffer_.Read(data, size);
@@ -107,85 +93,17 @@ class DiscoThreadedMessageQueue : public dmlc::Stream {
ring_buffer_.Write(data, size);
}
- uint64_t GetObjectBytes(Object* obj) {
- if (obj->IsInstance<DRefObj>()) {
- return sizeof(uint32_t) + sizeof(int64_t);
- } else if (obj->IsInstance<StringObj>()) {
- uint64_t size = static_cast<StringObj*>(obj)->size;
- return sizeof(uint32_t) + sizeof(uint64_t) + size * sizeof(char);
- } else if (obj->IsInstance<ShapeTupleObj>()) {
- uint64_t ndim = static_cast<ShapeTupleObj*>(obj)->size;
- return sizeof(uint32_t) + sizeof(uint64_t) + ndim *
sizeof(ShapeTupleObj::index_type);
- } else {
- LOG(FATAL) << "ValueError: Object type is not supported in Disco calling
convention: "
- << obj->GetTypeKey() << " (type_index = " <<
obj->type_index() << ")";
- }
- }
-
- void WriteObject(Object* obj) {
- if (obj->IsInstance<DRefObj>()) {
- int64_t reg_id = static_cast<DRefObj*>(obj)->reg_id;
- this->Write<uint32_t>(TypeIndex::kRuntimeDiscoDRef);
- this->Write<int64_t>(reg_id);
- } else if (obj->IsInstance<StringObj>()) {
- StringObj* str = static_cast<StringObj*>(obj);
- this->Write<uint32_t>(TypeIndex::kRuntimeString);
- this->Write<uint64_t>(str->size);
- this->WriteArray<char>(str->data, str->size);
- } else if (obj->IsInstance<ShapeTupleObj>()) {
- ShapeTupleObj* shape = static_cast<ShapeTupleObj*>(obj);
- this->Write<uint32_t>(TypeIndex::kRuntimeShapeTuple);
- this->Write<uint64_t>(shape->size);
- this->WriteArray<ShapeTupleObj::index_type>(shape->data, shape->size);
- } else {
- LOG(FATAL) << "ValueError: Object type is not supported in Disco calling
convention: "
- << obj->GetTypeKey() << " (type_index = " <<
obj->type_index() << ")";
- }
- }
-
- void ReadObject(int* tcode, TVMValue* value) {
- ObjectRef result{nullptr};
- uint32_t type_index;
- this->Read<uint32_t>(&type_index);
- if (type_index == TypeIndex::kRuntimeDiscoDRef) {
- ObjectPtr<DRefObj> dref = make_object<DRefObj>();
- this->Read<int64_t>(&dref->reg_id);
- dref->session = Session{nullptr};
- result = ObjectRef(std::move(dref));
- } else if (type_index == TypeIndex::kRuntimeString) {
- uint64_t size = 0;
- this->Read<uint64_t>(&size);
- std::string data(size, '\0');
- this->ReadArray<char>(data.data(), size);
- result = String(std::move(data));
- } else if (type_index == TypeIndex::kRuntimeShapeTuple) {
- uint64_t ndim = 0;
- this->Read<uint64_t>(&ndim);
- std::vector<ShapeTupleObj::index_type> data(ndim);
- this->ReadArray<ShapeTupleObj::index_type>(data.data(), ndim);
- result = ShapeTuple(std::move(data));
- } else {
- LOG(FATAL) << "ValueError: Object type is not supported in Disco calling
convention: "
- << Object::TypeIndex2Key(type_index) << " (type_index = " <<
type_index << ")";
- }
- *tcode = kTVMObjectHandle;
- value->v_handle = const_cast<Object*>(result.get());
- object_arena_.push_back(result);
- }
-
using dmlc::Stream::Read;
using dmlc::Stream::ReadArray;
using dmlc::Stream::Write;
using dmlc::Stream::WriteArray;
friend struct RPCReference;
+ friend struct DiscoProtocol<DiscoThreadedMessageQueue>;
std::mutex mutex_;
std::atomic<int> msg_cnt_{0};
std::condition_variable condition_;
-
support::RingBuffer ring_buffer_;
- support::Arena arena_;
- std::vector<ObjectRef> object_arena_;
};
class DiscoThreadChannel final : public DiscoChannel {
@@ -199,44 +117,52 @@ class DiscoThreadChannel final : public DiscoChannel {
DiscoThreadedMessageQueue worker_to_controler_;
};
+// Creates a thread-based channel, a worker attached to that channel, and a thread that runs
+// the worker's main loop. Each initializer uses the member initialized before it.
+DiscoWorkerThread::DiscoWorkerThread(int worker_id, int num_workers,
+ WorkerZeroData* worker_zero_data_)
+ : channel(std::make_unique<DiscoThreadChannel>()),
+ worker(
+ std::make_unique<DiscoWorker>(worker_id, num_workers,
worker_zero_data_, channel.get())),
+ // The lambda captures the raw DiscoWorker pointer rather than `this`.
+ thread(std::make_unique<std::thread>([worker = this->worker.get()] {
worker->MainLoop(); })) {
+}
+
class ThreadedSessionObj final : public BcastSessionObj {
public:
explicit ThreadedSessionObj(int num_workers) {
for (int i = 0; i < num_workers; ++i) {
- std::unique_ptr<DiscoThreadChannel> channel =
std::make_unique<DiscoThreadChannel>();
WorkerZeroData* data = (i == 0) ? &worker_zero_data_ : nullptr;
- workers_.emplace_back(std::make_unique<DiscoWorker>(i, num_workers,
data, channel.get()));
- channels_.emplace_back(std::move(channel));
- worker_threads_.emplace_back([worker = workers_.back().get()] {
worker->MainLoop(); });
+ workers_.emplace_back(i, num_workers, data);
}
}
~ThreadedSessionObj() {
this->Shutdown();
- for (std::thread& worker : this->worker_threads_) {
- worker.join();
- }
+ workers_.clear();
}
TVMRetValue DebugGetFromRemote(int64_t reg_id, int worker_id) {
this->SyncWorker(worker_id);
- return this->workers_.at(worker_id)->register_file.at(reg_id);
+ return this->workers_.at(worker_id).worker->register_file.at(reg_id);
+ }
+
+ void DebugSetRegister(int64_t reg_id, TVMArgValue value, int worker_id) {
+ this->SyncWorker(worker_id);
+ this->workers_.at(worker_id).worker->SetRegister(reg_id, value);
}
void BroadcastPacked(const TVMArgs& args) final {
- for (const std::unique_ptr<DiscoThreadChannel>& channel : this->channels_)
{
- channel->Send(args);
+ for (const DiscoWorkerThread& worker : this->workers_) {
+ worker.channel->Send(args);
}
}
- TVMArgs RecvReplyPacked(int worker_id) final { return
channels_[worker_id]->RecvReply(); }
+ TVMArgs RecvReplyPacked(int worker_id) final {
+ return this->workers_.at(worker_id).channel->RecvReply();
+ }
static constexpr const char* _type_key = "runtime.disco.ThreadedSession";
TVM_DECLARE_FINAL_OBJECT_INFO(ThreadedSessionObj, SessionObj);
- std::vector<std::unique_ptr<DiscoThreadChannel>> channels_;
- std::vector<std::unique_ptr<DiscoWorker>> workers_;
- std::vector<std::thread> worker_threads_;
+ std::vector<DiscoWorkerThread> workers_;
};
TVM_REGISTER_OBJECT_TYPE(ThreadedSessionObj);
diff --git a/src/runtime/disco/worker.cc b/src/runtime/disco/worker.cc
index 63e814a7e2..3100985f18 100644
--- a/src/runtime/disco/worker.cc
+++ b/src/runtime/disco/worker.cc
@@ -19,11 +19,15 @@
#include "./worker.h"
#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/disco/session.h>
+#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <thread>
+#include "../../support/process_id.h"
#include "./builtin.h"
+#include "./protocol.h"
namespace tvm {
namespace runtime {
@@ -43,11 +47,23 @@ DiscoWorker* DiscoWorker::ThreadLocal() {
return ret;
}
+// Stores `value` into register `reg_id`. When both the register and the new value hold an
+// NDArray, the data is copied into the existing array instead of rebinding the register.
+void DiscoWorker::SetRegister(int reg_id, TVMArgValue value) {
+  ICHECK(0 <= reg_id && reg_id < static_cast<int>(register_file.size()));
+  TVMRetValue& slot = register_file.at(reg_id);
+  const bool both_ndarray =
+      slot.type_code() == kTVMNDArrayHandle && value.type_code() == kTVMNDArrayHandle;
+  if (!both_ndarray) {
+    slot = value;
+    return;
+  }
+  NDArray target = slot;
+  NDArray source = value;
+  target.CopyFrom(source);
+}
+
struct DiscoWorker::Impl {
static void MainLoop(DiscoWorker* self) {
ThreadLocalDiscoWorker::Get()->worker = self;
- LOG(INFO) << "[Thread " << std::this_thread::get_id() << "] Worker #" <<
self->worker_id
- << " Launched";
+ LOG(INFO) << "[Worker #" << self->worker_id << "] " <<
support::GetProcessIdAndThreadIdHeader()
+ << " started";
while (true) {
TVMArgs args = self->channel->Recv();
DiscoAction action = static_cast<DiscoAction>(args[0].operator int());
@@ -84,6 +100,17 @@ struct DiscoWorker::Impl {
SyncWorker(self, reg_id);
break;
}
+ case DiscoAction::kDebugGetFromRemote: {
+ int worker_id = args[2];
+ DebugGetFromRemote(self, reg_id, worker_id);
+ break;
+ }
+ case DiscoAction::kDebugSetRegister: {
+ int worker_id = args[2];
+ TVMArgValue value = args[3];
+ DebugSetRegister(self, reg_id, worker_id, value);
+ break;
+ }
}
}
}
@@ -131,6 +158,30 @@ struct DiscoWorker::Impl {
}
}
+ // Replies to the controler with the content of register `reg_id`, but only when this worker
+ // is the one addressed by `worker_id`. NDArrays and objects are wrapped in a
+ // DiscoDebugObject before being sent.
+ static void DebugGetFromRemote(DiscoWorker* self, int reg_id, int worker_id) {
+   if (worker_id != self->worker_id) {
+     return;
+   }
+   TVMRetValue reg_value = GetReg(self, reg_id);
+   const int code = reg_value.type_code();
+   if (code == kTVMNDArrayHandle || code == kTVMObjectHandle) {
+     reg_value = DiscoDebugObject::Wrap(reg_value);
+   }
+   TVMValue reply_values[2];
+   int reply_type_codes[2];
+   PackArgs(reply_values, reply_type_codes, static_cast<int>(DiscoAction::kDebugGetFromRemote),
+            reg_value);
+   self->channel->Reply(TVMArgs(reply_values, reply_type_codes, 2));
+ }
+
+ // Sets register `reg_id` to `value` and acknowledges to the controler, but only when this
+ // worker is the one addressed by `worker_id`. The worker is synchronized first.
+ static void DebugSetRegister(DiscoWorker* self, int reg_id, int worker_id, TVMArgValue value) {
+   if (worker_id != self->worker_id) {
+     return;
+   }
+   ::tvm::runtime::SyncWorker();
+   self->SetRegister(reg_id, value);
+   TVMValue ack_values[1];
+   int ack_type_codes[1];
+   PackArgs(ack_values, ack_type_codes, static_cast<int>(DiscoAction::kDebugSetRegister));
+   self->channel->Reply(TVMArgs(ack_values, ack_type_codes, 1));
+ }
+
static void CallPacked(DiscoWorker* self, int64_t ret_reg_id, PackedFunc
func,
const TVMArgs& args) {
TVMValue* values = const_cast<TVMValue*>(args.values);
diff --git a/src/runtime/disco/worker.h b/src/runtime/disco/worker.h
index f10382b068..e948fa1668 100644
--- a/src/runtime/disco/worker.h
+++ b/src/runtime/disco/worker.h
@@ -33,6 +33,7 @@
#include <memory>
#include <mutex>
#include <queue>
+#include <thread>
#include <utility>
#include <vector>
@@ -81,6 +82,8 @@ class DiscoWorker {
void MainLoop();
/*! \brief Get the worker instance on the current thread */
static DiscoWorker* ThreadLocal();
+ /*! \brief Set the specific register to a specific value */
+ void SetRegister(int reg_id, TVMArgValue value);
/*! \brief The id of the worker.*/
int worker_id;
@@ -108,6 +111,46 @@ class DiscoWorker {
friend struct DiscoWorker::Impl;
};
+/*!
+ * \brief A worker thread in Disco, which upon creation, launches a new thread to run the
+ * DiscoWorker.
+ * \sa DiscoWorker
+ */
+class DiscoWorkerThread {
+ public:
+  /*!
+   * \brief Construct a worker thread.
+   * \param worker_id The id of the worker.
+   * \param num_workers The total number of workers.
+   * \param worker_zero_data_ The data shared between worker-0 and the controler. It's a nullptr
+   * if the worker is not worker-0.
+   */
+  explicit DiscoWorkerThread(int worker_id, int num_workers, WorkerZeroData* worker_zero_data_);
+
+  /*!
+   * \brief Move constructor.
+   * \note Not `explicit`, so the type can also be moved in copy-initialization contexts such
+   * as returning a local by value. It is `noexcept` because it only moves three unique_ptrs.
+   */
+  DiscoWorkerThread(DiscoWorkerThread&& other) noexcept
+      : channel(std::move(other.channel)),
+        worker(std::move(other.worker)),
+        thread(std::move(other.thread)) {}
+
+  /*! \brief Copy constructor is disabled */
+  DiscoWorkerThread(const DiscoWorkerThread& other) = delete;
+
+  /*! \brief Destructor that joins the thread before destruction */
+  ~DiscoWorkerThread() {
+    // A moved-from instance holds no thread and has nothing to join.
+    if (this->thread != nullptr) {
+      this->thread->join();
+    }
+  }
+
+  /*! \brief The communication channel between the controler and the worker */
+  std::unique_ptr<DiscoChannel> channel;
+  /*! \brief The worker whose internal state is visible to the controler */
+  std::unique_ptr<DiscoWorker> worker;
+  /*! \brief The thread that runs the worker's main loop. */
+  std::unique_ptr<std::thread> thread;
+};
+
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_DISCO_WORKER_H_
diff --git a/src/support/process_id.h b/src/support/process_id.h
new file mode 100644
index 0000000000..8462ae0dd2
--- /dev/null
+++ b/src/support/process_id.h
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file process_id.h
+ * \brief Platform independent helpers to query the current process id and thread id.
+ */
+#ifndef TVM_SUPPORT_PROCESS_ID_H_
+#define TVM_SUPPORT_PROCESS_ID_H_
+
+#include <iomanip>
+#include <ios>
+#include <sstream>
+#include <string>
+#include <thread>
+
+#ifdef _WIN32
+#include <windows.h>
+#else
+#include <sys/types.h>
+#include <unistd.h>
+#endif
+
+namespace tvm {
+namespace support {
+
+/*! \brief Returns the PID of the current process as a 64-bit signed integer.
*/
+inline int64_t GetProcessId() {
+ int64_t result;
+#ifdef _WIN32
+ DWORD pid = GetCurrentProcessId();
+ result = static_cast<int64_t>(pid);
+#else
+ pid_t pid = getpid();
+ result = static_cast<int64_t>(pid);
+#endif
+ return result;
+}
+
+/*! \brief Returns the PID and TID of the current process/thread as a
formatted string */
+inline std::string GetProcessIdAndThreadIdHeader() {
+ std::ostringstream os;
+ os << "[PID " << GetProcessId() << " TID 0x" << std::setw(16) <<
std::setfill('0') << std::hex
+ << std::this_thread::get_id() << "]";
+ return os.str();
+}
+
+} // namespace support
+} // namespace tvm
+
+#endif // TVM_SUPPORT_PROCESS_ID_H_
diff --git a/tests/python/disco/test_nccl.py b/tests/python/disco/test_nccl.py
index 6507af5699..f0f949ab80 100644
--- a/tests/python/disco/test_nccl.py
+++ b/tests/python/disco/test_nccl.py
@@ -19,30 +19,33 @@
import tempfile
import numpy as np
+import pytest
import tvm
-import tvm.testing
from tvm import dlight as dl
from tvm import relax as rx
from tvm.runtime import disco as di
from tvm.runtime.relax_vm import VirtualMachine
from tvm.script import relax as R
+_all_session_kinds = [di.ThreadedSession, di.ProcessSession]
-def test_init():
- devices = [0, 1]
- sess = di.ThreadedSession(num_workers=len(devices))
[email protected]("session_kind", _all_session_kinds)
+def test_init(session_kind):
+ devices = [0, 1]
+ sess = session_kind(num_workers=len(devices))
sess.init_ccl("nccl", *devices)
-def test_allreduce():
[email protected]("session_kind", _all_session_kinds)
+def test_allreduce(session_kind):
devices = [0, 1]
+ sess = session_kind(num_workers=len(devices))
+ sess.init_ccl("nccl", *devices)
+
array_1 = np.arange(12, dtype="float32").reshape(3, 4)
array_2 = np.arange(start=1, stop=-11, step=-1,
dtype="float32").reshape(3, 4)
-
- sess = di.ThreadedSession(num_workers=len(devices))
- sess.init_ccl("nccl", *devices)
d_array = sess.empty((3, 4), "float32")
d_array.debug_copy_from(0, array_1)
d_array.debug_copy_from(1, array_2)
@@ -60,12 +63,13 @@ def test_allreduce():
np.testing.assert_equal(result, expected)
-def test_broadcast_from_worker0():
[email protected]("session_kind", _all_session_kinds)
+def test_broadcast_from_worker0(session_kind):
devices = [0, 1]
- array = np.arange(12, dtype="float32").reshape(3, 4)
-
- sess = di.ThreadedSession(num_workers=len(devices))
+ sess = session_kind(num_workers=len(devices))
sess.init_ccl("nccl", *devices)
+
+ array = np.arange(12, dtype="float32").reshape(3, 4)
d_array = sess.empty((3, 4), "float32")
d_array.debug_copy_from(0, array)
dst_array = sess.empty((3, 4), "float32")
@@ -74,12 +78,13 @@ def test_broadcast_from_worker0():
np.testing.assert_equal(result, array)
-def test_scatter():
[email protected]("session_kind", _all_session_kinds)
+def test_scatter(session_kind):
devices = [0, 1]
- array = np.arange(36, dtype="float32").reshape(3, 4, 3)
-
- sess = di.ThreadedSession(num_workers=len(devices))
+ sess = session_kind(num_workers=len(devices))
sess.init_ccl("nccl", *devices)
+
+ array = np.arange(36, dtype="float32").reshape(3, 4, 3)
d_src = sess.empty((3, 4, 3), "float32")
d_dst = sess.empty((3, 3, 2), "float32")
@@ -96,29 +101,29 @@ def test_scatter():
)
-# def test_gather():
-# num_workers = 2
-# devices = [1, 2]
-# array = np.arange(36, dtype="float32")
-
-# sess = di.ThreadedSession(num_workers=num_workers)
-# sess.init_ccl("nccl", *devices)
-# d_src = sess.empty((3, 3, 2), "float32")
-# d_dst = sess.empty((3, 4, 3), "float32")
-
-# d_src.debug_copy_from(0, array[:18])
-# d_src.debug_copy_from(1, array[18:])
-
-# sess.gather_to_worker0(d_src, d_dst)
[email protected]("session_kind", _all_session_kinds)
+def test_gather(session_kind):
+ devices = [1, 2]
+ sess = session_kind(num_workers=len(devices))
+ sess.init_ccl("nccl", *devices)
-# np.testing.assert_equal(
-# d_dst.debug_get_from_remote(0).numpy(),
-# array.reshape(3, 4, 3),
-# )
+ array = np.arange(36, dtype="float32")
+ d_src = sess.empty((3, 3, 2), "float32")
+ d_dst = sess.empty((3, 4, 3), "float32")
+ d_src.debug_copy_from(0, array[:18])
+ d_src.debug_copy_from(1, array[18:])
+ sess.gather_to_worker0(d_src, d_dst)
+ np.testing.assert_equal(
+ d_dst.debug_get_from_remote(0).numpy(),
+ array.reshape(3, 4, 3),
+ )
-def test_mlp(): # pylint: disable=too-many-locals
[email protected]("session_kind", _all_session_kinds)
+def test_mlp(session_kind): # pylint: disable=too-many-locals
devices = [0, 1]
+ sess = session_kind(num_workers=len(devices))
+ sess.init_ccl("nccl", *devices)
# pylint: disable=invalid-name
@tvm.script.ir_module
@@ -193,8 +198,6 @@ def test_mlp(): # pylint: disable=too-many-locals
path = tmpdir + "/test.so"
relax_build(ShardedMLP, target).export_library(path)
- sess = di.ThreadedSession(num_workers=len(devices))
- sess.init_ccl("nccl", *devices)
mod = sess.load_vm_module(path)
d_X = sess.empty((128, 128), "float32")
@@ -215,8 +218,11 @@ def test_mlp(): # pylint: disable=too-many-locals
np.testing.assert_allclose(Y_result, Y_expected, rtol=1e-4, atol=1e-4)
-def test_attention(): # pylint: disable=too-many-locals,too-many-statements
[email protected]("session_kind", _all_session_kinds)
+def test_attention(session_kind): # pylint:
disable=too-many-locals,too-many-statements
devices = [0, 1]
+ sess = session_kind(num_workers=len(devices))
+ sess.init_ccl("nccl", *devices)
# pylint: disable=invalid-name
@tvm.script.ir_module
@@ -343,8 +349,6 @@ def test_attention(): # pylint:
disable=too-many-locals,too-many-statements
path = tmpdir + "/test.so"
relax_build(ShardedAttention, target).export_library(path)
- sess = di.ThreadedSession(num_workers=len(devices))
- sess.init_ccl("nccl", *devices)
mod = sess.load_vm_module(path)
d_X = sess.empty((1, 10, 128), "float32")
@@ -372,4 +376,10 @@ def test_attention(): # pylint:
disable=too-many-locals,too-many-statements
if __name__ == "__main__":
- tvm.testing.main()
+ test_init(di.ProcessSession)
+ test_allreduce(di.ProcessSession)
+ test_broadcast_from_worker0(di.ProcessSession)
+ test_scatter(di.ProcessSession)
+ test_gather(di.ProcessSession)
+ test_mlp(di.ProcessSession)
+ test_attention(di.ProcessSession)
diff --git a/tests/python/disco/test_session.py
b/tests/python/disco/test_session.py
index a2c0906f22..40dcb04911 100644
--- a/tests/python/disco/test_session.py
+++ b/tests/python/disco/test_session.py
@@ -19,15 +19,16 @@
import tempfile
import numpy as np
+import pytest
import tvm
from tvm import relax as rx
-from tvm._ffi import register_func
from tvm.runtime import ShapeTuple, String
from tvm.runtime import disco as di
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T
+from tvm.testing import disco as _
def _numpy_to_worker_0(sess: di.Session, np_array: np.array, device):
@@ -44,29 +45,23 @@ def _numpy_from_worker_0(sess: di.Session, remote_array,
shape, dtype):
return host_array.numpy()
-def test_int():
- num_workers = 4
+_all_session_kinds = [di.ThreadedSession, di.ProcessSession]
- @register_func("tests.disco.add_one", override=True)
- def add_one(x: int) -> int: # pylint: disable=invalid-name
- return x + 1
- sess = di.ThreadedSession(num_workers=num_workers)
[email protected]("session_kind", _all_session_kinds)
+def test_int(session_kind): # pylint: disable=invalid-name
+ num_workers = 4
+ sess = session_kind(num_workers=num_workers)
func: di.DPackedFunc = sess.get_global_func("tests.disco.add_one")
result: di.DRef = func(1)
-
for i in range(num_workers):
assert result.debug_get_from_remote(i) == 2
-def test_float():
[email protected]("session_kind", _all_session_kinds)
+def test_float(session_kind):
num_workers = 4
-
- @register_func("tests.disco.add_one_float", override=True)
- def add_one(x: float): # pylint: disable=invalid-name
- return x + 0.5
-
- sess = di.ThreadedSession(num_workers=num_workers)
+ sess = session_kind(num_workers=num_workers)
func: di.DPackedFunc = sess.get_global_func("tests.disco.add_one_float")
result: di.DRef = func(1.5)
@@ -74,32 +69,23 @@ def test_float():
assert result.debug_get_from_remote(i) == 2.0
-def test_ndarray():
[email protected]("session_kind", _all_session_kinds)
+def test_ndarray(session_kind):
num_workers = 4
-
- @register_func("tests.disco.add_one_ndarray", override=True)
- def add_one(x: tvm.runtime.NDArray) -> tvm.runtime.NDArray: # pylint:
disable=invalid-name
- return tvm.nd.array(x.numpy() + 1)
-
+ sess = session_kind(num_workers=num_workers)
device = tvm.cpu(0)
x_np = np.arange(6).astype("float32").reshape([2, 3])
y_np = np.arange(6).astype("float32").reshape([2, 3]) + 1
-
- sess = di.ThreadedSession(num_workers=num_workers)
x_disc = _numpy_to_worker_0(sess, x_np, device=device)
y_disc = sess.get_global_func("tests.disco.add_one_ndarray")(x_disc)
y_nd = _numpy_from_worker_0(sess, y_disc, shape=y_np.shape,
dtype=y_np.dtype)
np.testing.assert_equal(y_nd, y_np)
-def test_string():
[email protected]("session_kind", _all_session_kinds)
+def test_string(session_kind):
num_workers = 4
-
- @register_func("tests.disco.str", override=True)
- def my_str_func(x: str): # pylint: disable=invalid-name
- return x + "_suffix"
-
- sess = di.ThreadedSession(num_workers=num_workers)
+ sess = session_kind(num_workers=num_workers)
func: di.DPackedFunc = sess.get_global_func("tests.disco.str")
result: di.DRef = func("hello")
@@ -107,15 +93,10 @@ def test_string():
assert result.debug_get_from_remote(i) == "hello_suffix"
-def test_string_obj():
[email protected]("session_kind", _all_session_kinds)
+def test_string_obj(session_kind):
num_workers = 4
-
- @register_func("tests.disco.str_obj", override=True)
- def my_str_func(x: String): # pylint: disable=invalid-name
- assert isinstance(x, String)
- return String(x + "_suffix")
-
- sess = di.ThreadedSession(num_workers=num_workers)
+ sess = session_kind(num_workers=num_workers)
func: di.DPackedFunc = sess.get_global_func("tests.disco.str_obj")
result: di.DRef = func(String("hello"))
@@ -125,26 +106,22 @@ def test_string_obj():
assert value == "hello_suffix"
-def test_shape_tuple():
[email protected]("session_kind", _all_session_kinds)
+def test_shape_tuple(session_kind):
num_workers = 4
-
- @register_func("tests.disco.shape_tuple", override=True)
- def my_str_func(x: ShapeTuple): # pylint: disable=invalid-name
- assert isinstance(x, ShapeTuple)
- return ShapeTuple(list(x) + [4, 5])
-
- sess = di.ThreadedSession(num_workers=num_workers)
+ sess = session_kind(num_workers=num_workers)
func: di.DPackedFunc = sess.get_global_func("tests.disco.shape_tuple")
result: di.DRef = func(ShapeTuple([1, 2, 3]))
-
for i in range(num_workers):
value = result.debug_get_from_remote(i)
assert isinstance(value, ShapeTuple)
assert list(value) == [1, 2, 3, 4, 5]
-def test_vm_module():
[email protected]("session_kind", _all_session_kinds)
+def test_vm_module(session_kind):
num_workers = 4
+ sess = session_kind(num_workers=num_workers)
# pylint: disable=invalid-name
@I.ir_module
@@ -172,7 +149,6 @@ def test_vm_module():
y_np = x_np.transpose()
rx.build(TestMod, target="llvm").export_library(path)
- sess = di.ThreadedSession(num_workers=num_workers)
mod = sess.load_vm_module(path, device=device)
x_disc = _numpy_to_worker_0(sess, x_np, device=device)
@@ -181,8 +157,10 @@ def test_vm_module():
np.testing.assert_equal(y_nd, y_np)
-def test_vm_multi_func():
[email protected]("session_kind", _all_session_kinds)
+def test_vm_multi_func(session_kind):
num_workers = 4
+ sess = session_kind(num_workers=num_workers)
# pylint: disable=invalid-name
@I.ir_module
@@ -231,7 +209,6 @@ def test_vm_multi_func():
y_np = x_np.transpose()
rx.build(TestMod, target="llvm").export_library(path)
- sess = di.ThreadedSession(num_workers=num_workers)
mod = sess.load_vm_module(path, device=device)
x_disc = _numpy_to_worker_0(sess, x_np, device=device)
@@ -244,11 +221,11 @@ def test_vm_multi_func():
if __name__ == "__main__":
- test_int()
- test_float()
- test_string()
- test_string_obj()
- test_shape_tuple()
- test_ndarray()
- test_vm_module()
- test_vm_multi_func()
+ test_int(di.ProcessSession)
+ test_float(di.ProcessSession)
+ test_string(di.ProcessSession)
+ test_string_obj(di.ProcessSession)
+ test_shape_tuple(di.ProcessSession)
+ test_ndarray(di.ProcessSession)
+ test_vm_module(di.ProcessSession)
+ test_vm_multi_func(di.ProcessSession)