This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 4956e4f260 [Disco] Add LoadAll method to Disco Shard Loader (#15673)
4956e4f260 is described below

commit 4956e4f260793fc18a818829d5903d5e6c6be62f
Author: Hongyi Jin <[email protected]>
AuthorDate: Fri Sep 8 00:05:42 2023 -0700

    [Disco] Add LoadAll method to Disco Shard Loader (#15673)
    
    This PR adds a `LoadAll` method to the Disco shard loader to load
    parameters all at once.
    
    Co-authored-by: Junru Shao <[email protected]>
---
 src/runtime/disco/loader.cc       | 52 +++++++++++++++++++++++++++++++-----
 tests/python/disco/test_loader.py | 55 ++++++++++++++++++++++++++++++++++++---
 tests/python/disco/test_nccl.py   | 36 +++++++++++--------------
 3 files changed, 112 insertions(+), 31 deletions(-)

diff --git a/src/runtime/disco/loader.cc b/src/runtime/disco/loader.cc
index 6f9054296a..6aa6f23f8f 100644
--- a/src/runtime/disco/loader.cc
+++ b/src/runtime/disco/loader.cc
@@ -46,6 +46,8 @@ class ShardLoaderObj : public Object {
                           TypedPackedFunc<void(DLTensor*, int, DLTensor*)> f_shard);
   /*! \brief Load the i-th parameter */
   NDArray Load(int weight_index) const;
+  /*! \brief Load all the parameters */
+  Array<NDArray> LoadAll() const;
   /*! \brief Slice the given tensor at a specific dimension */
   NDArray Shard(NDArray source, int dim, int num_slices) const;
 
@@ -63,6 +65,8 @@ class ShardLoaderObj : public Object {
   NDArrayCacheMetadata metadata_;
   /*! \brief Sharding information for each weight */
   std::vector<ShardInfo> shard_info_;
+  /*! \brief Maps the name of a shard to its index */
+  std::unordered_map<std::string, int> param_name_to_index_;
   /*! \brief A method to slice a 3-D tensor */
   TypedPackedFunc<void(DLTensor*, int, DLTensor*)> f_shard_;
   /*! \brief The current file opened to load weights in it */
@@ -106,6 +110,7 @@ ObjectRef ShardLoaderObj::Create(const std::string& path_to_metadata, const std:
     for (const ParamRecord& param_record : file_record.records) {
       const std::string& name = param_record.name;
       int shard_id = shards.count(name) ? shards[name] : -1;
+      n->param_name_to_index_[name] = n->shard_info_.size();
       n->shard_info_.push_back(ShardInfo{&file_record, &param_record, shard_id});
     }
   }
@@ -124,7 +129,7 @@ NDArray ShardLoaderObj::Load(int weight_index) const {
   DiscoWorker* worker = DiscoWorker::ThreadLocal();
   int shard_idx = worker->worker_id;
   Device device = worker->default_device;
-  const auto& shard_info = shard_info_[weight_index];
+  const auto& shard_info = shard_info_.at(weight_index);
   const ParamRecord* param = shard_info.param;
   const FileRecord* file = shard_info.file;
   int shard_dim = shard_info.shard_dim;
@@ -139,13 +144,39 @@ NDArray ShardLoaderObj::Load(int weight_index) const {
     auto f_load = [](NDArray param, const void* data, size_t nbytes) {
       param.CopyFromBytes(data, nbytes);
     };
-    send = this->Shard(param->Load(device, &this->current_file_stream_, f_load), shard_dim,
-                       num_shards);
+    if (shard_dim != -1) {
+      send = this->Shard(param->Load(device, &this->current_file_stream_, f_load), shard_dim,
+                         num_shards);
+    } else {
+      send = param->Load(device, &this->current_file_stream_, f_load);
+    }
+  }
+  if (shard_dim != -1) {
+    NDArray recv =
+        NDArray::Empty(ShardShape(param->shape, shard_dim, num_shards), param->dtype, device);
+    ScatterFromWorker0(send, recv);
+    return recv;
+  } else {
+    NDArray recv;
+    if (send.defined()) {
+      recv = NDArray(send.value());
+    } else {
+      recv = NDArray::Empty(param->shape, param->dtype, device);
+    }
+    return BroadcastFromWorker0(recv);
   }
-  NDArray recv =
-      NDArray::Empty(ShardShape(param->shape, shard_dim, num_shards), param->dtype, device);
-  ScatterFromWorker0(send, recv);
-  return recv;
+}
+
+Array<NDArray> ShardLoaderObj::LoadAll() const {
+  int n = static_cast<int>(shard_info_.size());
+  Array<NDArray> shards;
+  shards.reserve(n);
+  for (int i = 0; i < n; ++i) {
+    std::string param_name = "param_" + std::to_string(i);
+    int shard_id = this->param_name_to_index_.at(param_name);
+    shards.push_back(this->Load(shard_id));
+  }
+  return shards;
 }
 
 NDArray ShardLoaderObj::Shard(NDArray source, int dim, int num_slices) const {
@@ -187,5 +218,12 @@ TVM_REGISTER_GLOBAL("runtime.disco.ShardLoaderLoad")
       return loader->Load(IntegerFromShapeTuple(weight_index));
     });
 
+TVM_REGISTER_GLOBAL("runtime.disco.ShardLoaderLoadAll").set_body_typed([](ObjectRef loader_obj) {
+  const auto* loader = loader_obj.as<ShardLoaderObj>();
+  CHECK(loader != nullptr) << "TypeError: Expected ShardLoaderObj, but gets: "
+                           << loader_obj->GetTypeKey();
+  return loader->LoadAll();
+});
+
 }  // namespace runtime
 }  // namespace tvm
diff --git a/tests/python/disco/test_loader.py b/tests/python/disco/test_loader.py
index baa6745b00..c92bac6b46 100644
--- a/tests/python/disco/test_loader.py
+++ b/tests/python/disco/test_loader.py
@@ -43,14 +43,14 @@ def _create_loader(sess, path, param_dict, shard_info):
     tvmjs.dump_ndarray_cache(param_dict, path, encode_format="raw")
     with open(path_ndarray_cache, "r", encoding="utf-8") as i_f:
         ndarray_cache = i_f.read()
-    shard_with_numpy = sess.get_global_func("tests.disco.shard_with_numpy")
     loader_create = sess.get_global_func("runtime.disco.ShardLoader")
+    shard_with_numpy = sess.get_global_func("tests.disco.shard_with_numpy")
     loader = loader_create(path_ndarray_cache, ndarray_cache, shard_info, shard_with_numpy)
     return loader
 
 
 def test_load_shard():
-    devices = [1, 2]
+    devices = [0, 1]
     param_dict = {
         "x_0": np.random.uniform(size=[64, 128]).astype("float16"),
         "x_1": np.random.uniform(size=[32, 128]).astype("float32"),
@@ -87,7 +87,7 @@ def test_load_shard():
 
 
 def test_load_shard_in_relax():
-    devices = [1, 2]
+    devices = [0, 1]
     param_dict = {
         "x_0": np.random.uniform(size=[64, 128]).astype("float16"),
         "x_1": np.random.uniform(size=[32, 128]).astype("float32"),
@@ -173,6 +173,55 @@ def test_load_shard_in_relax():
         )
 
 
+def test_load_shard_all():
+    devices = [0, 1]
+    param_dict = {
+        "param_0": np.random.uniform(size=[64, 128]).astype("float16"),
+        "param_1": np.random.uniform(size=[32, 128]).astype("float32"),
+    }
+    shard_info = json.dumps(
+        {
+            "param_0": 1,
+            "param_1": 0,
+        }
+    )
+    with tempfile.TemporaryDirectory() as path:
+        sess = di.ThreadedSession(num_workers=len(devices))
+        sess.init_ccl("nccl", *devices)
+        loader = _create_loader(sess, path, param_dict, shard_info)
+        loader_load = sess.get_global_func("runtime.disco.ShardLoaderLoadAll")
+        params = loader_load(loader)
+        p_0 = params.debug_get_from_remote(0)
+        p_1 = params.debug_get_from_remote(1)
+        np.testing.assert_equal(param_dict["param_0"][:, 0:64], p_0[0].numpy())
+        np.testing.assert_equal(param_dict["param_0"][:, 64:128], p_1[0].numpy())
+        np.testing.assert_equal(param_dict["param_1"][0:16, :], p_0[1].numpy())
+        np.testing.assert_equal(param_dict["param_1"][16:32, :], p_1[1].numpy())
+
+
+def test_load_shard_broadcast():
+    devices = [0, 1]
+    param_dict = {
+        "param_0": np.random.uniform(size=[64, 128]).astype("float16"),
+        "param_1": np.random.uniform(size=[32, 128]).astype("float32"),
+    }
+    shard_info = "{}"
+    with tempfile.TemporaryDirectory() as path:
+        sess = di.ThreadedSession(num_workers=len(devices))
+        sess.init_ccl("nccl", *devices)
+        loader = _create_loader(sess, path, param_dict, shard_info)
+        loader_load = sess.get_global_func("runtime.disco.ShardLoaderLoadAll")
+        params = loader_load(loader)
+        p_0 = params.debug_get_from_remote(0)
+        p_1 = params.debug_get_from_remote(1)
+        np.testing.assert_equal(param_dict["param_0"], p_0[0].numpy())
+        np.testing.assert_equal(param_dict["param_0"], p_1[0].numpy())
+        np.testing.assert_equal(param_dict["param_1"], p_0[1].numpy())
+        np.testing.assert_equal(param_dict["param_1"], p_1[1].numpy())
+
+
 if __name__ == "__main__":
     test_load_shard()
     test_load_shard_in_relax()
+    test_load_shard_all()
+    test_load_shard_broadcast()
diff --git a/tests/python/disco/test_nccl.py b/tests/python/disco/test_nccl.py
index adf507b024..f0979c8e3c 100644
--- a/tests/python/disco/test_nccl.py
+++ b/tests/python/disco/test_nccl.py
@@ -29,20 +29,18 @@ from tvm.script import relax as R
 
 
 def test_init():
-    num_workers = 2
-    devices = [1, 2]
+    devices = [0, 1]
 
-    sess = di.ThreadedSession(num_workers=num_workers)
+    sess = di.ThreadedSession(num_workers=len(devices))
     sess.init_ccl("nccl", *devices)
 
 
 def test_allreduce():
-    num_workers = 2
-    devices = [1, 2]
+    devices = [0, 1]
     array_1 = np.arange(12, dtype="float32").reshape(3, 4)
     array_2 = np.arange(start=1, stop=-11, step=-1, dtype="float32").reshape(3, 4)
 
-    sess = di.ThreadedSession(num_workers=num_workers)
+    sess = di.ThreadedSession(num_workers=len(devices))
     sess.init_ccl("nccl", *devices)
     d_array = sess.empty((3, 4), "float32")
     d_array.debug_copy_from(0, array_1)
@@ -61,11 +59,10 @@ def test_allreduce():
 
 
 def test_broadcast_from_worker0():
-    num_workers = 2
-    devices = [1, 2]
+    devices = [0, 1]
     array = np.arange(12, dtype="float32").reshape(3, 4)
 
-    sess = di.ThreadedSession(num_workers=num_workers)
+    sess = di.ThreadedSession(num_workers=len(devices))
     sess.init_ccl("nccl", *devices)
     d_array = sess.empty((3, 4), "float32")
     d_array.debug_copy_from(0, array)
@@ -75,11 +72,10 @@ def test_broadcast_from_worker0():
 
 
 def test_scatter():
-    num_workers = 2
-    devices = [1, 2]
+    devices = [0, 1]
     array = np.arange(36, dtype="float32").reshape(3, 4, 3)
 
-    sess = di.ThreadedSession(num_workers=num_workers)
+    sess = di.ThreadedSession(num_workers=len(devices))
     sess.init_ccl("nccl", *devices)
     d_src = sess.empty((3, 4, 3), "float32")
     d_dst = sess.empty((3, 3, 2), "float32")
@@ -119,8 +115,7 @@ def test_gather():
 
 
 def test_mlp():  # pylint: disable=too-many-locals
-    num_workers = 2
-    devices = [1, 2]
+    devices = [0, 1]
 
     # pylint: disable=invalid-name
     @tvm.script.ir_module
@@ -195,7 +190,7 @@ def test_mlp():  # pylint: disable=too-many-locals
         path = tmpdir + "/test.so"
         relax_build(ShardedMLP, target).export_library(path)
 
-        sess = di.ThreadedSession(num_workers=num_workers)
+        sess = di.ThreadedSession(num_workers=len(devices))
         sess.init_ccl("nccl", *devices)
         mod = sess.load_vm_module(path)
 
@@ -217,15 +212,14 @@ def test_mlp():  # pylint: disable=too-many-locals
     np.testing.assert_allclose(Y_result, Y_expected, rtol=1e-4, atol=1e-4)
 
 
-def test_attention():  # pylint: disable=too-many-locals
-    num_workers = 2
-    devices = [1, 2]
+def test_attention():  # pylint: disable=too-many-locals,too-many-statements
+    devices = [0, 1]
 
     # pylint: disable=invalid-name
     @tvm.script.ir_module
     class Attention:  # pylint: disable=too-few-public-methods
         @R.function
-        def main(
+        def main(  # pylint: disable=too-many-locals
             x: R.Tensor((1, 10, 128), "float32"),
             Wq: R.Tensor((128, 512), "float32"),
             Wk: R.Tensor((128, 512), "float32"),
@@ -265,7 +259,7 @@ def test_attention():  # pylint: disable=too-many-locals
     @tvm.script.ir_module
     class ShardedAttention:  # pylint: disable=too-few-public-methods
         @R.function
-        def main(
+        def main(  # pylint: disable=too-many-locals
             x: R.Tensor((1, 10, 128), "float32"),
             Wq: R.Tensor((128, 256), "float32"),  # shard along axis 1
             Wk: R.Tensor((128, 256), "float32"),  # shard along axis 1
@@ -346,7 +340,7 @@ def test_attention():  # pylint: disable=too-many-locals
         path = tmpdir + "/test.so"
         relax_build(ShardedAttention, target).export_library(path)
 
-        sess = di.ThreadedSession(num_workers=num_workers)
+        sess = di.ThreadedSession(num_workers=len(devices))
         sess.init_ccl("nccl", *devices)
         mod = sess.load_vm_module(path)
 

Reply via email to