This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 4956e4f260 [Disco] Add LoadAll method to Disco Shard Loader (#15673)
4956e4f260 is described below
commit 4956e4f260793fc18a818829d5903d5e6c6be62f
Author: Hongyi Jin <[email protected]>
AuthorDate: Fri Sep 8 00:05:42 2023 -0700
[Disco] Add LoadAll method to Disco Shard Loader (#15673)
This PR adds a `LoadAll` method to the Disco shard loader to load
parameters all at once.
Co-authored-by: Junru Shao <[email protected]>
---
src/runtime/disco/loader.cc | 52 +++++++++++++++++++++++++++++++-----
tests/python/disco/test_loader.py | 55 ++++++++++++++++++++++++++++++++++++---
tests/python/disco/test_nccl.py | 36 +++++++++++--------------
3 files changed, 112 insertions(+), 31 deletions(-)
diff --git a/src/runtime/disco/loader.cc b/src/runtime/disco/loader.cc
index 6f9054296a..6aa6f23f8f 100644
--- a/src/runtime/disco/loader.cc
+++ b/src/runtime/disco/loader.cc
@@ -46,6 +46,8 @@ class ShardLoaderObj : public Object {
TypedPackedFunc<void(DLTensor*, int, DLTensor*)>
f_shard);
/*! \brief Load the i-th parameter */
NDArray Load(int weight_index) const;
+ /*! \brief Load all the parameters */
+ Array<NDArray> LoadAll() const;
/*! \brief Slice the given tensor at a specific dimension */
NDArray Shard(NDArray source, int dim, int num_slices) const;
@@ -63,6 +65,8 @@ class ShardLoaderObj : public Object {
NDArrayCacheMetadata metadata_;
/*! \brief Sharding information for each weight */
std::vector<ShardInfo> shard_info_;
+ /*! \brief Maps the name of a shard to its index */
+ std::unordered_map<std::string, int> param_name_to_index_;
/*! \brief A method to slice a 3-D tensor */
TypedPackedFunc<void(DLTensor*, int, DLTensor*)> f_shard_;
/*! \brief The current file opened to load weights in it */
@@ -106,6 +110,7 @@ ObjectRef ShardLoaderObj::Create(const std::string&
path_to_metadata, const std:
for (const ParamRecord& param_record : file_record.records) {
const std::string& name = param_record.name;
int shard_id = shards.count(name) ? shards[name] : -1;
+ n->param_name_to_index_[name] = n->shard_info_.size();
n->shard_info_.push_back(ShardInfo{&file_record, &param_record,
shard_id});
}
}
@@ -124,7 +129,7 @@ NDArray ShardLoaderObj::Load(int weight_index) const {
DiscoWorker* worker = DiscoWorker::ThreadLocal();
int shard_idx = worker->worker_id;
Device device = worker->default_device;
- const auto& shard_info = shard_info_[weight_index];
+ const auto& shard_info = shard_info_.at(weight_index);
const ParamRecord* param = shard_info.param;
const FileRecord* file = shard_info.file;
int shard_dim = shard_info.shard_dim;
@@ -139,13 +144,39 @@ NDArray ShardLoaderObj::Load(int weight_index) const {
auto f_load = [](NDArray param, const void* data, size_t nbytes) {
param.CopyFromBytes(data, nbytes);
};
- send = this->Shard(param->Load(device, &this->current_file_stream_,
f_load), shard_dim,
- num_shards);
+ if (shard_dim != -1) {
+ send = this->Shard(param->Load(device, &this->current_file_stream_,
f_load), shard_dim,
+ num_shards);
+ } else {
+ send = param->Load(device, &this->current_file_stream_, f_load);
+ }
+ }
+ if (shard_dim != -1) {
+ NDArray recv =
+ NDArray::Empty(ShardShape(param->shape, shard_dim, num_shards),
param->dtype, device);
+ ScatterFromWorker0(send, recv);
+ return recv;
+ } else {
+ NDArray recv;
+ if (send.defined()) {
+ recv = NDArray(send.value());
+ } else {
+ recv = NDArray::Empty(param->shape, param->dtype, device);
+ }
+ return BroadcastFromWorker0(recv);
}
- NDArray recv =
- NDArray::Empty(ShardShape(param->shape, shard_dim, num_shards),
param->dtype, device);
- ScatterFromWorker0(send, recv);
- return recv;
+}
+
+Array<NDArray> ShardLoaderObj::LoadAll() const {
+ int n = static_cast<int>(shard_info_.size());
+ Array<NDArray> shards;
+ shards.reserve(n);
+ for (int i = 0; i < n; ++i) {
+ std::string param_name = "param_" + std::to_string(i);
+ int shard_id = this->param_name_to_index_.at(param_name);
+ shards.push_back(this->Load(shard_id));
+ }
+ return shards;
}
NDArray ShardLoaderObj::Shard(NDArray source, int dim, int num_slices) const {
@@ -187,5 +218,12 @@ TVM_REGISTER_GLOBAL("runtime.disco.ShardLoaderLoad")
return loader->Load(IntegerFromShapeTuple(weight_index));
});
+TVM_REGISTER_GLOBAL("runtime.disco.ShardLoaderLoadAll").set_body_typed([](ObjectRef
loader_obj) {
+ const auto* loader = loader_obj.as<ShardLoaderObj>();
+ CHECK(loader != nullptr) << "TypeError: Expected ShardLoaderObj, but gets: "
+ << loader_obj->GetTypeKey();
+ return loader->LoadAll();
+});
+
} // namespace runtime
} // namespace tvm
diff --git a/tests/python/disco/test_loader.py
b/tests/python/disco/test_loader.py
index baa6745b00..c92bac6b46 100644
--- a/tests/python/disco/test_loader.py
+++ b/tests/python/disco/test_loader.py
@@ -43,14 +43,14 @@ def _create_loader(sess, path, param_dict, shard_info):
tvmjs.dump_ndarray_cache(param_dict, path, encode_format="raw")
with open(path_ndarray_cache, "r", encoding="utf-8") as i_f:
ndarray_cache = i_f.read()
- shard_with_numpy = sess.get_global_func("tests.disco.shard_with_numpy")
loader_create = sess.get_global_func("runtime.disco.ShardLoader")
+ shard_with_numpy = sess.get_global_func("tests.disco.shard_with_numpy")
loader = loader_create(path_ndarray_cache, ndarray_cache, shard_info,
shard_with_numpy)
return loader
def test_load_shard():
- devices = [1, 2]
+ devices = [0, 1]
param_dict = {
"x_0": np.random.uniform(size=[64, 128]).astype("float16"),
"x_1": np.random.uniform(size=[32, 128]).astype("float32"),
@@ -87,7 +87,7 @@ def test_load_shard():
def test_load_shard_in_relax():
- devices = [1, 2]
+ devices = [0, 1]
param_dict = {
"x_0": np.random.uniform(size=[64, 128]).astype("float16"),
"x_1": np.random.uniform(size=[32, 128]).astype("float32"),
@@ -173,6 +173,55 @@ def test_load_shard_in_relax():
)
+def test_load_shard_all():
+ devices = [0, 1]
+ param_dict = {
+ "param_0": np.random.uniform(size=[64, 128]).astype("float16"),
+ "param_1": np.random.uniform(size=[32, 128]).astype("float32"),
+ }
+ shard_info = json.dumps(
+ {
+ "param_0": 1,
+ "param_1": 0,
+ }
+ )
+ with tempfile.TemporaryDirectory() as path:
+ sess = di.ThreadedSession(num_workers=len(devices))
+ sess.init_ccl("nccl", *devices)
+ loader = _create_loader(sess, path, param_dict, shard_info)
+ loader_load = sess.get_global_func("runtime.disco.ShardLoaderLoadAll")
+ params = loader_load(loader)
+ p_0 = params.debug_get_from_remote(0)
+ p_1 = params.debug_get_from_remote(1)
+ np.testing.assert_equal(param_dict["param_0"][:, 0:64], p_0[0].numpy())
+ np.testing.assert_equal(param_dict["param_0"][:, 64:128],
p_1[0].numpy())
+ np.testing.assert_equal(param_dict["param_1"][0:16, :], p_0[1].numpy())
+ np.testing.assert_equal(param_dict["param_1"][16:32, :],
p_1[1].numpy())
+
+
+def test_load_shard_broadcast():
+ devices = [0, 1]
+ param_dict = {
+ "param_0": np.random.uniform(size=[64, 128]).astype("float16"),
+ "param_1": np.random.uniform(size=[32, 128]).astype("float32"),
+ }
+ shard_info = "{}"
+ with tempfile.TemporaryDirectory() as path:
+ sess = di.ThreadedSession(num_workers=len(devices))
+ sess.init_ccl("nccl", *devices)
+ loader = _create_loader(sess, path, param_dict, shard_info)
+ loader_load = sess.get_global_func("runtime.disco.ShardLoaderLoadAll")
+ params = loader_load(loader)
+ p_0 = params.debug_get_from_remote(0)
+ p_1 = params.debug_get_from_remote(1)
+ np.testing.assert_equal(param_dict["param_0"], p_0[0].numpy())
+ np.testing.assert_equal(param_dict["param_0"], p_1[0].numpy())
+ np.testing.assert_equal(param_dict["param_1"], p_0[1].numpy())
+ np.testing.assert_equal(param_dict["param_1"], p_1[1].numpy())
+
+
if __name__ == "__main__":
test_load_shard()
test_load_shard_in_relax()
+ test_load_shard_all()
+ test_load_shard_broadcast()
diff --git a/tests/python/disco/test_nccl.py b/tests/python/disco/test_nccl.py
index adf507b024..f0979c8e3c 100644
--- a/tests/python/disco/test_nccl.py
+++ b/tests/python/disco/test_nccl.py
@@ -29,20 +29,18 @@ from tvm.script import relax as R
def test_init():
- num_workers = 2
- devices = [1, 2]
+ devices = [0, 1]
- sess = di.ThreadedSession(num_workers=num_workers)
+ sess = di.ThreadedSession(num_workers=len(devices))
sess.init_ccl("nccl", *devices)
def test_allreduce():
- num_workers = 2
- devices = [1, 2]
+ devices = [0, 1]
array_1 = np.arange(12, dtype="float32").reshape(3, 4)
array_2 = np.arange(start=1, stop=-11, step=-1,
dtype="float32").reshape(3, 4)
- sess = di.ThreadedSession(num_workers=num_workers)
+ sess = di.ThreadedSession(num_workers=len(devices))
sess.init_ccl("nccl", *devices)
d_array = sess.empty((3, 4), "float32")
d_array.debug_copy_from(0, array_1)
@@ -61,11 +59,10 @@ def test_allreduce():
def test_broadcast_from_worker0():
- num_workers = 2
- devices = [1, 2]
+ devices = [0, 1]
array = np.arange(12, dtype="float32").reshape(3, 4)
- sess = di.ThreadedSession(num_workers=num_workers)
+ sess = di.ThreadedSession(num_workers=len(devices))
sess.init_ccl("nccl", *devices)
d_array = sess.empty((3, 4), "float32")
d_array.debug_copy_from(0, array)
@@ -75,11 +72,10 @@ def test_broadcast_from_worker0():
def test_scatter():
- num_workers = 2
- devices = [1, 2]
+ devices = [0, 1]
array = np.arange(36, dtype="float32").reshape(3, 4, 3)
- sess = di.ThreadedSession(num_workers=num_workers)
+ sess = di.ThreadedSession(num_workers=len(devices))
sess.init_ccl("nccl", *devices)
d_src = sess.empty((3, 4, 3), "float32")
d_dst = sess.empty((3, 3, 2), "float32")
@@ -119,8 +115,7 @@ def test_gather():
def test_mlp(): # pylint: disable=too-many-locals
- num_workers = 2
- devices = [1, 2]
+ devices = [0, 1]
# pylint: disable=invalid-name
@tvm.script.ir_module
@@ -195,7 +190,7 @@ def test_mlp(): # pylint: disable=too-many-locals
path = tmpdir + "/test.so"
relax_build(ShardedMLP, target).export_library(path)
- sess = di.ThreadedSession(num_workers=num_workers)
+ sess = di.ThreadedSession(num_workers=len(devices))
sess.init_ccl("nccl", *devices)
mod = sess.load_vm_module(path)
@@ -217,15 +212,14 @@ def test_mlp(): # pylint: disable=too-many-locals
np.testing.assert_allclose(Y_result, Y_expected, rtol=1e-4, atol=1e-4)
-def test_attention(): # pylint: disable=too-many-locals
- num_workers = 2
- devices = [1, 2]
+def test_attention(): # pylint: disable=too-many-locals,too-many-statements
+ devices = [0, 1]
# pylint: disable=invalid-name
@tvm.script.ir_module
class Attention: # pylint: disable=too-few-public-methods
@R.function
- def main(
+ def main( # pylint: disable=too-many-locals
x: R.Tensor((1, 10, 128), "float32"),
Wq: R.Tensor((128, 512), "float32"),
Wk: R.Tensor((128, 512), "float32"),
@@ -265,7 +259,7 @@ def test_attention(): # pylint: disable=too-many-locals
@tvm.script.ir_module
class ShardedAttention: # pylint: disable=too-few-public-methods
@R.function
- def main(
+ def main( # pylint: disable=too-many-locals
x: R.Tensor((1, 10, 128), "float32"),
Wq: R.Tensor((128, 256), "float32"), # shard along axis 1
Wk: R.Tensor((128, 256), "float32"), # shard along axis 1
@@ -346,7 +340,7 @@ def test_attention(): # pylint: disable=too-many-locals
path = tmpdir + "/test.so"
relax_build(ShardedAttention, target).export_library(path)
- sess = di.ThreadedSession(num_workers=num_workers)
+ sess = di.ThreadedSession(num_workers=len(devices))
sess.init_ccl("nccl", *devices)
mod = sess.load_vm_module(path)