This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 3ec0ca5b0b [Disco] Expose functions to query the per-worker
device/rank (#16639)
3ec0ca5b0b is described below
commit 3ec0ca5b0b3941d9314cfada23dac3101cc163f7
Author: Eric Lunderberg <[email protected]>
AuthorDate: Mon Feb 26 04:06:15 2024 -0600
[Disco] Expose functions to query the per-worker device/rank (#16639)
In addition to the PackedFunc `"runtime.disco.worker_id"`, which
returns the worker ID wrapped in a `ShapeTuple`, this commit adds
`"runtime.disco.worker_rank"`, which returns the worker ID without
wrapping, and `"runtime.disco.device"`, which returns the device for
each worker.
The unit test added in this commit simulates loading of model weights
through a parameter transformation function.
---
python/tvm/exec/disco_worker.py | 56 +++++++++++++---
python/tvm/runtime/disco/session.py | 2 +-
python/tvm/testing/utils.py | 3 +
src/runtime/disco/builtin.cc | 6 ++
tests/python/disco/test_callback.py | 130 ++++++++++++++++++++++++++++++++++++
5 files changed, 188 insertions(+), 9 deletions(-)
diff --git a/python/tvm/exec/disco_worker.py b/python/tvm/exec/disco_worker.py
index b5eea6328d..76ce0ff993 100644
--- a/python/tvm/exec/disco_worker.py
+++ b/python/tvm/exec/disco_worker.py
@@ -19,44 +19,84 @@
import os
import sys
-from tvm import runtime as _ # pylint: disable=unused-import
+from typing import Callable
+
+import tvm
from tvm._ffi import get_global_func, register_func
from tvm.runtime import NDArray, ShapeTuple, String
from tvm.runtime.ndarray import array
-@register_func("tests.disco.add_one")
-def _add_one(x: int) -> int: # pylint: disable=invalid-name
+@register_func("tests.disco.add_one", override=True)
+def _add_one(x: int) -> int:
return x + 1
@register_func("tests.disco.add_one_float", override=True)
-def _add_one_float(x: float): # pylint: disable=invalid-name
+def _add_one_float(x: float):
return x + 0.5
@register_func("tests.disco.add_one_ndarray", override=True)
-def _add_one_ndarray(x: NDArray) -> NDArray: # pylint: disable=invalid-name
+def _add_one_ndarray(x: NDArray) -> NDArray:
return array(x.numpy() + 1)
@register_func("tests.disco.str", override=True)
-def _str_func(x: str): # pylint: disable=invalid-name
+def _str_func(x: str):
return x + "_suffix"
@register_func("tests.disco.str_obj", override=True)
-def _str_obj_func(x: String): # pylint: disable=invalid-name
+def _str_obj_func(x: String):
assert isinstance(x, String)
return String(x + "_suffix")
@register_func("tests.disco.shape_tuple", override=True)
-def _shape_tuple_func(x: ShapeTuple): # pylint: disable=invalid-name
+def _shape_tuple_func(x: ShapeTuple):
assert isinstance(x, ShapeTuple)
return ShapeTuple(list(x) + [4, 5])
+@register_func("tests.disco.test_callback", override=True)
+def _make_callback(device: tvm.runtime.Device) -> Callable[[str, int], NDArray]:
+ """For use in tests/python/disco/test_callback.py
+
+ This function simulates a callback to be used for lazy parameter
+ loading.
+
+ Parameters
+ ----------
+ device: tvm.runtime.Device
+
+ The device on which parameters should be located, when
+ returned by the callback function.
+
+ Returns
+ -------
+ fget_item: Callable[[str,int], NDArray]
+
+ A callback function that accepts a parameter's name and index,
+ and returns the specified parameter.
+
+ """
+ import numpy as np # pylint: disable=import-outside-toplevel
+
+ def fget_item(param_name: str, param_index: int) -> NDArray:
+ if param_index == 0:
+ assert param_name == "A"
+ arr = np.arange(16).reshape([4, 4]).astype("int32")
+ elif param_index == 1:
+ assert param_name == "B"
+ arr = np.arange(4).reshape([2, 2]).astype("float32")
+ else:
+ raise ValueError(f"Unexpected index {param_index}")
+ return tvm.nd.array(arr, device=device)
+
+ return fget_item
+
+
def main():
"""Main worker function"""
if len(sys.argv) != 5:
diff --git a/python/tvm/runtime/disco/session.py b/python/tvm/runtime/disco/session.py
index c54f646e17..1013d14a89 100644
--- a/python/tvm/runtime/disco/session.py
+++ b/python/tvm/runtime/disco/session.py
@@ -377,7 +377,7 @@ class ThreadedSession(Session):
class ProcessSession(Session):
"""A Disco session backed by pipe-based multi-processing."""
- def __init__(self, num_workers: int, entrypoint: str) -> None:
+ def __init__(self, num_workers: int, entrypoint: str = "tvm.exec.disco_worker") -> None:
self.__init_handle_by_constructor__(
_ffi_api.SessionProcess, # type: ignore # pylint: disable=no-member
num_workers,
diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py
index d59aa964f9..6e23a84bc2 100644
--- a/python/tvm/testing/utils.py
+++ b/python/tvm/testing/utils.py
@@ -896,6 +896,9 @@ requires_cudnn = Feature("cudnn", "cuDNN", cmake_flag="USE_CUDNN", parent_featur
# Mark a test as requiring the cuBLAS library.
requires_cublas = Feature("cublas", "cuBLAS", cmake_flag="USE_CUBLAS", parent_features="cuda")
+# Mark a test as requiring NCCL support
+requires_nccl = Feature("nccl", "NCCL", cmake_flag="USE_NCCL", parent_features="cuda")
+
# Mark a test as requiring the NVPTX compilation on the CUDA runtime
requires_nvptx = Feature(
"nvptx",
diff --git a/src/runtime/disco/builtin.cc b/src/runtime/disco/builtin.cc
index 911fdaae3d..05961df9d5 100644
--- a/src/runtime/disco/builtin.cc
+++ b/src/runtime/disco/builtin.cc
@@ -123,6 +123,12 @@ TVM_REGISTER_GLOBAL("runtime.disco.recv_from_worker0").set_body_typed(RecvFromWo
TVM_REGISTER_GLOBAL("runtime.disco.worker_id").set_body_typed([]() -> ShapeTuple {
return ShapeTuple({WorkerId()});
});
+TVM_REGISTER_GLOBAL("runtime.disco.worker_rank").set_body_typed([]() -> int64_t {
+ return WorkerId();
+});
+TVM_REGISTER_GLOBAL("runtime.disco.device").set_body_typed([]() -> Device {
+ return DiscoWorker::ThreadLocal()->default_device;
+});
} // namespace runtime
} // namespace tvm
diff --git a/tests/python/disco/test_callback.py b/tests/python/disco/test_callback.py
new file mode 100644
index 0000000000..6e2dc9b747
--- /dev/null
+++ b/tests/python/disco/test_callback.py
@@ -0,0 +1,130 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test sharded loader"""
+# pylint: disable=missing-docstring
+
+import pathlib
+import tempfile
+
+import numpy as np
+
+import tvm
+import tvm.testing
+
+from tvm.script import relax as R, tir as T
+
+
+@tvm.testing.requires_nccl
+def test_callback():
+ @R.function
+ def transform_params(
+ rank_arg: R.Prim(value="rank"),
+ fget_item: R.Callable([R.Object, R.Prim("int64")], R.Object),
+ ):
+ """Simulate lazy loading of parameters in a callback
+
+ The output of a lazy parameter loading, which would accept a
+ callback to load the parameters.
+ """
+ rank = T.int64()
+
+ A = fget_item(R.str("A"), R.prim_value(0))
+ A = R.match_cast(A, R.Tensor([4, 4], "int32"))
+ A = R.strided_slice(A, axes=[0], begin=[rank * 2], end=[(rank + 1) * 2])
+
+ B = fget_item(R.str("B"), R.prim_value(1))
+ B = R.match_cast(B, R.Tensor([2, 2], "float32"))
+ B = R.strided_slice(B, axes=[1], begin=[rank * 1], end=[(rank + 1) * 1])
+
+ return (A, B)
+
+ pipeline = tvm.ir.transform.Sequential(
+ [
+ tvm.relax.transform.LegalizeOps(),
+ tvm.dlight.ApplyDefaultSchedule(tvm.dlight.gpu.Fallback()),
+ ],
+ name="pipeline",
+ )
+
+ with tvm.target.Target("cuda"):
+ mod = tvm.IRModule.from_expr(transform_params)
+ mod = pipeline(mod)
+ built = tvm.relax.build(mod, "cuda")
+
+ num_shards = 2
+
+ session = tvm.runtime.disco.ProcessSession(num_workers=num_shards)
+ session.import_python_module("tvm.exec.disco_worker")
+ session.init_ccl("nccl", *range(num_shards))
+
+ worker_device = session.get_global_func("runtime.disco.device")()
+ worker_id = session.get_global_func("runtime.disco.worker_rank")()
+ callback_maker = session.get_global_func("tests.disco.test_callback")
+ fget_item = callback_maker(worker_device)
+
+ with tempfile.TemporaryDirectory() as temp_dir:
+ temp_dir = pathlib.Path(temp_dir)
+
+ # TODO(Lunderberg): Update `disco.Session.load_vm_module` to
+ # allow a `tvm.runtime.Module` argument. This would avoid the
+ # need for a temporary file.
+ shlib_path = temp_dir.joinpath("libtemp.so")
+ built.export_library(shlib_path)
+ vm = session.load_vm_module(shlib_path.as_posix())
+ transform_params = vm["transform_params"]
+
+ params = transform_params(worker_id, fget_item)
+
+ # Worker 0 is the same PID as the controlling scope, so
+ # `debug_get_from_remote(0)` returns the NDArray containing
+ # the output.
+ params_gpu0 = params.debug_get_from_remote(0)
+ assert params_gpu0[0].device == tvm.cuda(0)
+ assert params_gpu0[1].device == tvm.cuda(0)
+ np.testing.assert_array_equal(
+ params_gpu0[0].numpy(),
+ [
+ [0, 1, 2, 3],
+ [4, 5, 6, 7],
+ ],
+ )
+ np.testing.assert_array_equal(
+ params_gpu0[1].numpy(),
+ [[0], [2]],
+ )
+
+ # Worker 1 is a different PID altogether, so
+ # `debug_get_from_remote(1)` returns a new NDArray within the
+ # calling scope's PID.
+ params_gpu1 = params.debug_get_from_remote(1)
+ assert params_gpu1[0].device == tvm.cpu()
+ assert params_gpu1[1].device == tvm.cpu()
+ np.testing.assert_array_equal(
+ params_gpu1[0].numpy(),
+ [
+ [8, 9, 10, 11],
+ [12, 13, 14, 15],
+ ],
+ )
+ np.testing.assert_array_equal(
+ params_gpu1[1].numpy(),
+ [[1], [3]],
+ )
+
+
+if __name__ == "__main__":
+ tvm.testing.main()