This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new bc10f769d8 [Unity] Add LoadParamOnWorker0 function in shard loader
(#16093)
bc10f769d8 is described below
commit bc10f769d8e1410acc07eb549429ed260e250c46
Author: Hongyi Jin <[email protected]>
AuthorDate: Fri Nov 10 08:48:20 2023 -0800
[Unity] Add LoadParamOnWorker0 function in shard loader (#16093)
In the DistIR compilation flow, shard loading is implemented in a Relax
function, for example:
```
a = LoadParamOnWorker0(loader, index=0)
b = broadcast(a)
c = LoadParamOnWorker0(loader, index=1)
d = scatter_from_worker0(c)
```
LoadParamOnWorker0 loads the whole (unsharded) param on worker 0; on every
other worker it returns an uninitialized array of the same shape and dtype.
This PR implements LoadParamOnWorker0.
---
src/runtime/disco/loader.cc | 39 +++++++++++++++++++++++++++++++++++++++
1 file changed, 39 insertions(+)
diff --git a/src/runtime/disco/loader.cc b/src/runtime/disco/loader.cc
index c8d7eeb2a4..c931baa942 100644
--- a/src/runtime/disco/loader.cc
+++ b/src/runtime/disco/loader.cc
@@ -48,6 +48,8 @@ class ShardLoaderObj : public Object {
/*! \brief Load the i-th parameter */
NDArray Load(int weight_index) const;
+ NDArray LoadParamOnWorker0(int weight_index) const;
+
/*! \brief Load all the parameters */
Array<NDArray> LoadAll() const;
@@ -164,6 +166,35 @@ std::string GetSiblingPath(const std::string& path, const
std::string& filename)
LOG(FATAL) << "ValueError: Cannot find the parent directory: " << path;
}
+NDArray ShardLoaderObj::LoadParamOnWorker0(int weight_index) const {
+ DiscoWorker* worker = DiscoWorker::ThreadLocal();
+ int worker_id = worker->worker_id;
+ Device device = worker->default_device;
+ int param_index = param_name_to_index_.at("param_" +
std::to_string(weight_index));
+ const ParamInfo& param_info = param_info_.at(param_index);
+ const ParamRecord* param = param_info.param;
+ const FileRecord* file = param_info.file;
+
+ auto load = [this, param, device, file]() {
+ if (file != current_file_) {
+ current_file_ = file;
+ std::string file_name = GetSiblingPath(this->metadata_.path,
file->data_path);
+ LoadBinaryFromFile(file_name, &this->current_file_stream_);
+ }
+ return param->Load(
+ device, &this->current_file_stream_,
+ [](NDArray param, const void* data, size_t nbytes) {
param.CopyFromBytes(data, nbytes); });
+ };
+
+ if (worker_id == 0) {
+ NDArray w = load();
+ return w;
+ } else {
+ NDArray w = NDArray::Empty(param->shape, param->dtype, device);
+ return w;
+ }
+}
+
std::tuple<int, int> ParseParamShardingInfo(const ParamRecord* param) {
// Given a name "param_shard-X-of-Y", return the integer values
// rank=(X-1) and world_size=Y.
@@ -337,5 +368,13 @@
TVM_REGISTER_GLOBAL("runtime.disco.ShardLoaderLoadAllPresharded")
return loader->LoadAllPresharded();
});
+TVM_REGISTER_GLOBAL("runtime.disco.ShardLoaderLoadParamOnWorker0")
+ .set_body_typed([](ObjectRef loader_obj, int param_index) {
+ const auto* loader = loader_obj.as<ShardLoaderObj>();
+ CHECK(loader != nullptr) << "TypeError: Expected ShardLoaderObj, but
gets: "
+ << loader_obj->GetTypeKey();
+ return loader->LoadParamOnWorker0(param_index);
+ });
+
} // namespace runtime
} // namespace tvm