This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new bc10f769d8 [Unity] Add LoadParamOnWorker0 function in shard loader (#16093)
bc10f769d8 is described below

commit bc10f769d8e1410acc07eb549429ed260e250c46
Author: Hongyi Jin <[email protected]>
AuthorDate: Fri Nov 10 08:48:20 2023 -0800

    [Unity] Add LoadParamOnWorker0 function in shard loader (#16093)
    
    In the DistIR compilation flow, shard loading is implemented in a Relax function, like:
    
    ```
    a = LoadParamOnWorker0(loader, index=0)
    b = broadcast(a)
    c = LoadParamOnWorker0(loader, index=1)
    d = scatter_from_worker0(c)
    ```
    
    LoadParamOnWorker0 loads the unsharded parameter on worker 0. Every other worker
    receives a freshly allocated array with the same shape and dtype, whose contents
    are not filled from the file.
    
    This PR implements LoadParamOnWorker0.
---
 src/runtime/disco/loader.cc | 39 +++++++++++++++++++++++++++++++++++++++
 1 file changed, 39 insertions(+)

diff --git a/src/runtime/disco/loader.cc b/src/runtime/disco/loader.cc
index c8d7eeb2a4..c931baa942 100644
--- a/src/runtime/disco/loader.cc
+++ b/src/runtime/disco/loader.cc
@@ -48,6 +48,8 @@ class ShardLoaderObj : public Object {
   /*! \brief Load the i-th parameter */
   NDArray Load(int weight_index) const;
 
+  /*!
+   * \brief Load the parameter named "param_<weight_index>" on worker 0 only.
+   *  Every other worker gets a newly allocated NDArray with the parameter's
+   *  shape and dtype on its default device, without reading the data file.
+   */
+  NDArray LoadParamOnWorker0(int weight_index) const;
+
   /*! \brief Load all the parameters */
   Array<NDArray> LoadAll() const;
 
@@ -164,6 +166,35 @@ std::string GetSiblingPath(const std::string& path, const std::string& filename)
   LOG(FATAL) << "ValueError: Cannot find the parent directory: " << path;
 }
 
+/*!
+ * \brief Load the parameter named "param_<weight_index>" for the worker-0-only path.
+ *
+ * Worker 0 reads the parameter from its backing data file. Every other worker skips
+ * file I/O and instead receives a newly allocated NDArray with the parameter's shape
+ * and dtype on its default device; this function does not fill that array.
+ *
+ * \param weight_index Numeric suffix used to look up "param_<weight_index>".
+ * \return The loaded parameter on worker 0, or an unfilled array of the same
+ *  shape/dtype on any other worker.
+ */
+NDArray ShardLoaderObj::LoadParamOnWorker0(int weight_index) const {
+  DiscoWorker* local_worker = DiscoWorker::ThreadLocal();
+  const Device target_device = local_worker->default_device;
+  const int record_index = param_name_to_index_.at("param_" + std::to_string(weight_index));
+  const ParamInfo& info = param_info_.at(record_index);
+  const ParamRecord* record = info.param;
+  const FileRecord* source_file = info.file;
+
+  // Non-zero workers never touch the file: they only need a buffer with the
+  // parameter's shape/dtype on their own device.
+  if (local_worker->worker_id != 0) {
+    return NDArray::Empty(record->shape, record->dtype, target_device);
+  }
+
+  // current_file_stream_ caches the most recently read data file; re-read it only
+  // when this parameter lives in a different file.
+  if (current_file_ != source_file) {
+    current_file_ = source_file;
+    const std::string data_file = GetSiblingPath(this->metadata_.path, source_file->data_path);
+    LoadBinaryFromFile(data_file, &this->current_file_stream_);
+  }
+  return record->Load(target_device, &this->current_file_stream_,
+                      [](NDArray dst, const void* bytes, size_t nbytes) {
+                        dst.CopyFromBytes(bytes, nbytes);
+                      });
+}
+
 std::tuple<int, int> ParseParamShardingInfo(const ParamRecord* param) {
   // Given a name "param_shard-X-of-Y", return the integer values
   // rank=(X-1) and world_size=Y.
@@ -337,5 +368,13 @@ TVM_REGISTER_GLOBAL("runtime.disco.ShardLoaderLoadAllPresharded")
       return loader->LoadAllPresharded();
     });
 
+// Packed-function entry point for ShardLoaderObj::LoadParamOnWorker0: worker 0
+// returns the loaded parameter, other workers return a same-shape placeholder.
+TVM_REGISTER_GLOBAL("runtime.disco.ShardLoaderLoadParamOnWorker0")
+    .set_body_typed([](ObjectRef loader_obj, int param_index) {
+      const auto* loader = loader_obj.as<ShardLoaderObj>();
+      // `loader_obj` may be a null reference (e.g. None from the frontend). In that
+      // case `as<>` yields nullptr, and the message must not dereference it.
+      CHECK(loader != nullptr) << "TypeError: Expected ShardLoaderObj, but gets: "
+                               << (loader_obj.defined() ? loader_obj->GetTypeKey()
+                                                        : std::string("None"));
+      return loader->LoadParamOnWorker0(param_index);
+    });
+
 }  // namespace runtime
 }  // namespace tvm

Reply via email to