This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 3ec0ca5b0b [Disco] Expose functions to query the per-worker device/rank (#16639)
3ec0ca5b0b is described below

commit 3ec0ca5b0b3941d9314cfada23dac3101cc163f7
Author: Eric Lunderberg <lunderb...@users.noreply.github.com>
AuthorDate: Mon Feb 26 04:06:15 2024 -0600

    [Disco] Expose functions to query the per-worker device/rank (#16639)
    
    In addition to the PackedFunc `"runtime.disco.worker_id"`, which
    returns the worker ID wrapped in a `ShapeTuple`, this commit adds
    `"runtime.disco.worker_rank"`, which returns the worker ID without
    wrapping, and `"runtime.disco.device"`, which returns the device for
    each worker.
    
    The unit test added in this commit simulates loading of model weights
    through a parameter transformation function.
---
 python/tvm/exec/disco_worker.py     |  56 +++++++++++++---
 python/tvm/runtime/disco/session.py |   2 +-
 python/tvm/testing/utils.py         |   3 +
 src/runtime/disco/builtin.cc        |   6 ++
 tests/python/disco/test_callback.py | 130 ++++++++++++++++++++++++++++++++++++
 5 files changed, 188 insertions(+), 9 deletions(-)

diff --git a/python/tvm/exec/disco_worker.py b/python/tvm/exec/disco_worker.py
index b5eea6328d..76ce0ff993 100644
--- a/python/tvm/exec/disco_worker.py
+++ b/python/tvm/exec/disco_worker.py
@@ -19,44 +19,84 @@
 import os
 import sys
 
-from tvm import runtime as _  # pylint: disable=unused-import
+from typing import Callable
+
+import tvm
 from tvm._ffi import get_global_func, register_func
 from tvm.runtime import NDArray, ShapeTuple, String
 from tvm.runtime.ndarray import array
 
 
-@register_func("tests.disco.add_one")
-def _add_one(x: int) -> int:  # pylint: disable=invalid-name
+@register_func("tests.disco.add_one", override=True)
+def _add_one(x: int) -> int:
     return x + 1
 
 
 @register_func("tests.disco.add_one_float", override=True)
-def _add_one_float(x: float):  # pylint: disable=invalid-name
+def _add_one_float(x: float):
     return x + 0.5
 
 
 @register_func("tests.disco.add_one_ndarray", override=True)
-def _add_one_ndarray(x: NDArray) -> NDArray:  # pylint: disable=invalid-name
+def _add_one_ndarray(x: NDArray) -> NDArray:
     return array(x.numpy() + 1)
 
 
 @register_func("tests.disco.str", override=True)
-def _str_func(x: str):  # pylint: disable=invalid-name
+def _str_func(x: str):
     return x + "_suffix"
 
 
 @register_func("tests.disco.str_obj", override=True)
-def _str_obj_func(x: String):  # pylint: disable=invalid-name
+def _str_obj_func(x: String):
     assert isinstance(x, String)
     return String(x + "_suffix")
 
 
 @register_func("tests.disco.shape_tuple", override=True)
-def _shape_tuple_func(x: ShapeTuple):  # pylint: disable=invalid-name
+def _shape_tuple_func(x: ShapeTuple):
     assert isinstance(x, ShapeTuple)
     return ShapeTuple(list(x) + [4, 5])
 
 
+@register_func("tests.disco.test_callback", override=True)
+def _make_callback(device: tvm.runtime.Device) -> Callable[[str, int], NDArray]:
+    """For use in tests/python/disco/test_callback.py
+
+    This function simulates a callback to be used for lazy parameter
+    loading.
+
+    Parameters
+    ----------
+    device: tvm.runtime.Device
+
+        The device on which parameters should be located, when
+        returned by the callback function.
+
+    Returns
+    -------
+    fget_item: Callable[[str,int], NDArray]
+
+        A callback function that accepts a parameter's name and index,
+        and returns the specified parameter.
+
+    """
+    import numpy as np  # pylint: disable=import-outside-toplevel
+
+    def fget_item(param_name: str, param_index: int) -> NDArray:
+        if param_index == 0:
+            assert param_name == "A"
+            arr = np.arange(16).reshape([4, 4]).astype("int32")
+        elif param_index == 1:
+            assert param_name == "B"
+            arr = np.arange(4).reshape([2, 2]).astype("float32")
+        else:
+            raise ValueError(f"Unexpected index {param_index}")
+        return tvm.nd.array(arr, device=device)
+
+    return fget_item
+
+
 def main():
     """Main worker function"""
     if len(sys.argv) != 5:
diff --git a/python/tvm/runtime/disco/session.py b/python/tvm/runtime/disco/session.py
index c54f646e17..1013d14a89 100644
--- a/python/tvm/runtime/disco/session.py
+++ b/python/tvm/runtime/disco/session.py
@@ -377,7 +377,7 @@ class ThreadedSession(Session):
 class ProcessSession(Session):
     """A Disco session backed by pipe-based multi-processing."""
 
-    def __init__(self, num_workers: int, entrypoint: str) -> None:
+    def __init__(self, num_workers: int, entrypoint: str = "tvm.exec.disco_worker") -> None:
         self.__init_handle_by_constructor__(
             _ffi_api.SessionProcess,  # type: ignore # pylint: disable=no-member
             num_workers,
diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py
index d59aa964f9..6e23a84bc2 100644
--- a/python/tvm/testing/utils.py
+++ b/python/tvm/testing/utils.py
@@ -896,6 +896,9 @@ requires_cudnn = Feature("cudnn", "cuDNN", cmake_flag="USE_CUDNN", parent_featur
 # Mark a test as requiring the cuBLAS library.
 requires_cublas = Feature("cublas", "cuBLAS", cmake_flag="USE_CUBLAS", parent_features="cuda")
 
+# Mark a test as requiring NCCL support
+requires_nccl = Feature("nccl", "NCCL", cmake_flag="USE_NCCL", parent_features="cuda")
+
 # Mark a test as requiring the NVPTX compilation on the CUDA runtime
 requires_nvptx = Feature(
     "nvptx",
diff --git a/src/runtime/disco/builtin.cc b/src/runtime/disco/builtin.cc
index 911fdaae3d..05961df9d5 100644
--- a/src/runtime/disco/builtin.cc
+++ b/src/runtime/disco/builtin.cc
@@ -123,6 +123,12 @@ TVM_REGISTER_GLOBAL("runtime.disco.recv_from_worker0").set_body_typed(RecvFromWo
 TVM_REGISTER_GLOBAL("runtime.disco.worker_id").set_body_typed([]() -> ShapeTuple {
   return ShapeTuple({WorkerId()});
 });
+TVM_REGISTER_GLOBAL("runtime.disco.worker_rank").set_body_typed([]() -> int64_t {
+  return WorkerId();
+});
+TVM_REGISTER_GLOBAL("runtime.disco.device").set_body_typed([]() -> Device {
+  return DiscoWorker::ThreadLocal()->default_device;
+});
 
 }  // namespace runtime
 }  // namespace tvm
diff --git a/tests/python/disco/test_callback.py b/tests/python/disco/test_callback.py
new file mode 100644
index 0000000000..6e2dc9b747
--- /dev/null
+++ b/tests/python/disco/test_callback.py
@@ -0,0 +1,130 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test lazy parameter loading through a per-worker callback"""
+# pylint: disable=missing-docstring
+
+import pathlib
+import tempfile
+
+import numpy as np
+
+import tvm
+import tvm.testing
+
+from tvm.script import relax as R, tir as T
+
+
+@tvm.testing.requires_nccl
+def test_callback():
+    @R.function
+    def transform_params(
+        rank_arg: R.Prim(value="rank"),
+        fget_item: R.Callable([R.Object, R.Prim("int64")], R.Object),
+    ):
+        """Simulate lazy loading of parameters in a callback
+
+        This mimics the output of a lazy parameter-loading transform,
+        which accepts a callback that is used to load each parameter.
+        """
+        rank = T.int64()
+
+        A = fget_item(R.str("A"), R.prim_value(0))
+        A = R.match_cast(A, R.Tensor([4, 4], "int32"))
+        A = R.strided_slice(A, axes=[0], begin=[rank * 2], end=[(rank + 1) * 2])
+
+        B = fget_item(R.str("B"), R.prim_value(1))
+        B = R.match_cast(B, R.Tensor([2, 2], "float32"))
+        B = R.strided_slice(B, axes=[1], begin=[rank * 1], end=[(rank + 1) * 1])
+
+        return (A, B)
+
+    pipeline = tvm.ir.transform.Sequential(
+        [
+            tvm.relax.transform.LegalizeOps(),
+            tvm.dlight.ApplyDefaultSchedule(tvm.dlight.gpu.Fallback()),
+        ],
+        name="pipeline",
+    )
+
+    with tvm.target.Target("cuda"):
+        mod = tvm.IRModule.from_expr(transform_params)
+        mod = pipeline(mod)
+        built = tvm.relax.build(mod, "cuda")
+
+    num_shards = 2
+
+    session = tvm.runtime.disco.ProcessSession(num_workers=num_shards)
+    session.import_python_module("tvm.exec.disco_worker")
+    session.init_ccl("nccl", *range(num_shards))
+
+    worker_device = session.get_global_func("runtime.disco.device")()
+    worker_id = session.get_global_func("runtime.disco.worker_rank")()
+    callback_maker = session.get_global_func("tests.disco.test_callback")
+    fget_item = callback_maker(worker_device)
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        temp_dir = pathlib.Path(temp_dir)
+
+        # TODO(Lunderberg): Update `disco.Session.load_vm_module` to
+        # allow a `tvm.runtime.Module` argument.  This would avoid the
+        # need for a temporary file.
+        shlib_path = temp_dir.joinpath("libtemp.so")
+        built.export_library(shlib_path)
+        vm = session.load_vm_module(shlib_path.as_posix())
+        transform_params = vm["transform_params"]
+
+        params = transform_params(worker_id, fget_item)
+
+        # Worker 0 is the same PID as the controlling scope, so
+        # `debug_get_from_remote(0)` returns the NDArray containing
+        # the output.
+        params_gpu0 = params.debug_get_from_remote(0)
+        assert params_gpu0[0].device == tvm.cuda(0)
+        assert params_gpu0[1].device == tvm.cuda(0)
+        np.testing.assert_array_equal(
+            params_gpu0[0].numpy(),
+            [
+                [0, 1, 2, 3],
+                [4, 5, 6, 7],
+            ],
+        )
+        np.testing.assert_array_equal(
+            params_gpu0[1].numpy(),
+            [[0], [2]],
+        )
+
+        # Worker 1 is a different PID altogether, so
+        # `debug_get_from_remote(1)` returns a new NDArray within the
+        # calling scope's PID.
+        params_gpu1 = params.debug_get_from_remote(1)
+        assert params_gpu1[0].device == tvm.cpu()
+        assert params_gpu1[1].device == tvm.cpu()
+        np.testing.assert_array_equal(
+            params_gpu1[0].numpy(),
+            [
+                [8, 9, 10, 11],
+                [12, 13, 14, 15],
+            ],
+        )
+        np.testing.assert_array_equal(
+            params_gpu1[1].numpy(),
+            [[1], [3]],
+        )
+
+
+if __name__ == "__main__":
+    tvm.testing.main()

Reply via email to