This is an automated email from the ASF dual-hosted git repository. masahi pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new 3ec0ca5b0b [Disco] Expose functions to query the per-worker device/rank (#16639) 3ec0ca5b0b is described below commit 3ec0ca5b0b3941d9314cfada23dac3101cc163f7 Author: Eric Lunderberg <lunderb...@users.noreply.github.com> AuthorDate: Mon Feb 26 04:06:15 2024 -0600 [Disco] Expose functions to query the per-worker device/rank (#16639) In addition to the PackedFunc `"runtime.disco.worker_id"`, which returns the worker ID wrapped in a `ShapeTuple`, this commit adds `"runtime.disco.worker_rank"`, which returns the worker ID without wrapping, and `"runtime.disco.device"`, which returns the device for each worker. The unit test added in this commit simulates loading of model weights through a parameter transformation function. --- python/tvm/exec/disco_worker.py | 56 +++++++++++++--- python/tvm/runtime/disco/session.py | 2 +- python/tvm/testing/utils.py | 3 + src/runtime/disco/builtin.cc | 6 ++ tests/python/disco/test_callback.py | 130 ++++++++++++++++++++++++++++++++++++ 5 files changed, 188 insertions(+), 9 deletions(-) diff --git a/python/tvm/exec/disco_worker.py b/python/tvm/exec/disco_worker.py index b5eea6328d..76ce0ff993 100644 --- a/python/tvm/exec/disco_worker.py +++ b/python/tvm/exec/disco_worker.py @@ -19,44 +19,84 @@ import os import sys -from tvm import runtime as _ # pylint: disable=unused-import +from typing import Callable + +import tvm from tvm._ffi import get_global_func, register_func from tvm.runtime import NDArray, ShapeTuple, String from tvm.runtime.ndarray import array -@register_func("tests.disco.add_one") -def _add_one(x: int) -> int: # pylint: disable=invalid-name +@register_func("tests.disco.add_one", override=True) +def _add_one(x: int) -> int: return x + 1 @register_func("tests.disco.add_one_float", override=True) -def _add_one_float(x: float): # pylint: disable=invalid-name +def _add_one_float(x: float): return x + 0.5 @register_func("tests.disco.add_one_ndarray", override=True) -def _add_one_ndarray(x: NDArray) -> NDArray: # pylint: disable=invalid-name +def _add_one_ndarray(x: NDArray) -> NDArray: return array(x.numpy() + 1) @register_func("tests.disco.str", override=True) -def _str_func(x: str): # pylint: disable=invalid-name +def _str_func(x: str): return x + "_suffix" @register_func("tests.disco.str_obj", override=True) -def _str_obj_func(x: String): # pylint: disable=invalid-name +def _str_obj_func(x: String): assert isinstance(x, String) return String(x + "_suffix") @register_func("tests.disco.shape_tuple", override=True) -def _shape_tuple_func(x: ShapeTuple): # pylint: disable=invalid-name +def _shape_tuple_func(x: ShapeTuple): assert isinstance(x, ShapeTuple) return ShapeTuple(list(x) + [4, 5]) +@register_func("tests.disco.test_callback", override=True) +def _make_callback(device: tvm.runtime.Device) -> Callable[[str, int], NDArray]: + """For use in tests/python/disco/test_callback.py + + This function simulates a callback to be used for lazy parameter + loading. + + Parameters + ---------- + device: tvm.runtime.Device + + The device on which parameters should be located, when + returned by the callback function. + + Returns + ------- + fget_item: Callable[[str,int], NDArray] + + A callback function that accepts a parameter's name and index, + and returns the specified parameter. + + """ + import numpy as np # pylint: disable=import-outside-toplevel + + def fget_item(param_name: str, param_index: int) -> NDArray: + if param_index == 0: + assert param_name == "A" + arr = np.arange(16).reshape([4, 4]).astype("int32") + elif param_index == 1: + assert param_name == "B" + arr = np.arange(4).reshape([2, 2]).astype("float32") + else: + raise ValueError(f"Unexpected index {param_index}") + return tvm.nd.array(arr, device=device) + + return fget_item + + def main(): """Main worker function""" if len(sys.argv) != 5: diff --git a/python/tvm/runtime/disco/session.py b/python/tvm/runtime/disco/session.py index c54f646e17..1013d14a89 100644 --- a/python/tvm/runtime/disco/session.py +++ b/python/tvm/runtime/disco/session.py @@ -377,7 +377,7 @@ class ThreadedSession(Session): class ProcessSession(Session): """A Disco session backed by pipe-based multi-processing.""" - def __init__(self, num_workers: int, entrypoint: str) -> None: + def __init__(self, num_workers: int, entrypoint: str = "tvm.exec.disco_worker") -> None: self.__init_handle_by_constructor__( _ffi_api.SessionProcess, # type: ignore # pylint: disable=no-member num_workers, diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index d59aa964f9..6e23a84bc2 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -896,6 +896,9 @@ requires_cudnn = Feature("cudnn", "cuDNN", cmake_flag="USE_CUDNN", parent_featur # Mark a test as requiring the cuBLAS library. requires_cublas = Feature("cublas", "cuBLAS", cmake_flag="USE_CUBLAS", parent_features="cuda") +# Mark a test as requiring NCCL support +requires_nccl = Feature("nccl", "NCCL", cmake_flag="USE_NCCL", parent_features="cuda") + # Mark a test as requiring the NVPTX compilation on the CUDA runtime requires_nvptx = Feature( "nvptx", diff --git a/src/runtime/disco/builtin.cc b/src/runtime/disco/builtin.cc index 911fdaae3d..05961df9d5 100644 --- a/src/runtime/disco/builtin.cc +++ b/src/runtime/disco/builtin.cc @@ -123,6 +123,12 @@ TVM_REGISTER_GLOBAL("runtime.disco.recv_from_worker0").set_body_typed(RecvFromWo TVM_REGISTER_GLOBAL("runtime.disco.worker_id").set_body_typed([]() -> ShapeTuple { return ShapeTuple({WorkerId()}); }); +TVM_REGISTER_GLOBAL("runtime.disco.worker_rank").set_body_typed([]() -> int64_t { + return WorkerId(); +}); +TVM_REGISTER_GLOBAL("runtime.disco.device").set_body_typed([]() -> Device { + return DiscoWorker::ThreadLocal()->default_device; +}); } // namespace runtime } // namespace tvm diff --git a/tests/python/disco/test_callback.py b/tests/python/disco/test_callback.py new file mode 100644 index 0000000000..6e2dc9b747 --- /dev/null +++ b/tests/python/disco/test_callback.py @@ -0,0 +1,130 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test sharded loader""" +# pylint: disable=missing-docstring + +import pathlib +import tempfile + +import numpy as np + +import tvm +import tvm.testing + +from tvm.script import relax as R, tir as T + + +@tvm.testing.requires_nccl +def test_callback(): + @R.function + def transform_params( + rank_arg: R.Prim(value="rank"), + fget_item: R.Callable([R.Object, R.Prim("int64")], R.Object), + ): + """Simulate lazy loading of parameters in a callback + + The output of a lazy parameter loading, which would accept a + callback to load the parameters. + """ + rank = T.int64() + + A = fget_item(R.str("A"), R.prim_value(0)) + A = R.match_cast(A, R.Tensor([4, 4], "int32")) + A = R.strided_slice(A, axes=[0], begin=[rank * 2], end=[(rank + 1) * 2]) + + B = fget_item(R.str("B"), R.prim_value(1)) + B = R.match_cast(B, R.Tensor([2, 2], "float32")) + B = R.strided_slice(B, axes=[1], begin=[rank * 1], end=[(rank + 1) * 1]) + + return (A, B) + + pipeline = tvm.ir.transform.Sequential( + [ + tvm.relax.transform.LegalizeOps(), + tvm.dlight.ApplyDefaultSchedule(tvm.dlight.gpu.Fallback()), + ], + name="pipeline", + ) + + with tvm.target.Target("cuda"): + mod = tvm.IRModule.from_expr(transform_params) + mod = pipeline(mod) + built = tvm.relax.build(mod, "cuda") + + num_shards = 2 + + session = tvm.runtime.disco.ProcessSession(num_workers=num_shards) + session.import_python_module("tvm.exec.disco_worker") + session.init_ccl("nccl", *range(num_shards)) + + worker_device = session.get_global_func("runtime.disco.device")() + worker_id = session.get_global_func("runtime.disco.worker_rank")() + callback_maker = session.get_global_func("tests.disco.test_callback") + fget_item = callback_maker(worker_device) + + with tempfile.TemporaryDirectory() as temp_dir: + temp_dir = pathlib.Path(temp_dir) + + # TODO(Lunderberg): Update `disco.Session.load_vm_module` to + # allow a `tvm.runtime.Module` argument. This would avoid the + # need for a temporary file. + shlib_path = temp_dir.joinpath("libtemp.so") + built.export_library(shlib_path) + vm = session.load_vm_module(shlib_path.as_posix()) + transform_params = vm["transform_params"] + + params = transform_params(worker_id, fget_item) + + # Worker 0 is the same PID as the controlling scope, so + # `debug_get_from_remote(0)` returns the NDArray containing + # the output. + params_gpu0 = params.debug_get_from_remote(0) + assert params_gpu0[0].device == tvm.cuda(0) + assert params_gpu0[1].device == tvm.cuda(0) + np.testing.assert_array_equal( + params_gpu0[0].numpy(), + [ + [0, 1, 2, 3], + [4, 5, 6, 7], + ], + ) + np.testing.assert_array_equal( + params_gpu0[1].numpy(), + [[0], [2]], + ) + + # Worker 1 is a different PID altogether, so + # `debug_get_from_remote(1)` returns a new NDArray within the + # calling scope's PID. + params_gpu1 = params.debug_get_from_remote(1) + assert params_gpu1[0].device == tvm.cpu() + assert params_gpu1[1].device == tvm.cpu() + np.testing.assert_array_equal( + params_gpu1[0].numpy(), + [ + [8, 9, 10, 11], + [12, 13, 14, 15], + ], + ) + np.testing.assert_array_equal( + params_gpu1[1].numpy(), + [[1], [3]], + ) + + +if __name__ == "__main__": + tvm.testing.main()