This is an automated email from the ASF dual-hosted git repository.
lunderberg pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new e359e7a210 [Disco] Add loader for presharded params. (#15957)
e359e7a210 is described below
commit e359e7a210b144327eb3cb5e4f0c4d1968568407
Author: Eric Lunderberg <[email protected]>
AuthorDate: Thu Nov 9 08:05:45 2023 -0600
[Disco] Add loader for presharded params. (#15957)
* [Disco] Add loader for presharded params.
Prior to this commit, sharding of model weights was always performed
when initializing the model. This could cause slow initialization,
especially for larger numbers of GPUs, as all model weights are
initially transferred to GPU-0, before being scattered to all workers.
This commit updates the `tvm::runtime::ShardLoaderObj` to also allow
loading of pre-sharded model weights. With pre-sharded model weights,
the tensors are sharded while the model is being built, and each
worker independently loads the specific model weights that it
requires.
* Update based on review comments.
* Removed commented-out print statements
---------
Co-authored-by: Chris Sullivan <[email protected]>
---
src/runtime/disco/loader.cc | 143 ++++++++++++++++++++++++++++++++++----
src/runtime/disco/nccl/nccl.cc | 2 +-
src/runtime/disco/worker.cc | 2 -
tests/python/disco/test_loader.py | 116 +++++++++++++++++++++++++++++--
4 files changed, 241 insertions(+), 22 deletions(-)
diff --git a/src/runtime/disco/loader.cc b/src/runtime/disco/loader.cc
index 7670cc5254..c8d7eeb2a4 100644
--- a/src/runtime/disco/loader.cc
+++ b/src/runtime/disco/loader.cc
@@ -47,11 +47,21 @@ class ShardLoaderObj : public Object {
std::string shard_info, Module mod);
/*! \brief Load the i-th parameter */
NDArray Load(int weight_index) const;
+
/*! \brief Load all the parameters */
Array<NDArray> LoadAll() const;
NDArray ApplyShardFunc(const ShardInfo::ShardFunc& shard_func, const
NDArray& param) const;
+ /*! \brief Load all the pre-sharded parameters */
+ Array<NDArray> LoadAllPresharded() const;
+
+ /*! \brief Load the i-th parameter from presharded binaries */
+ NDArray LoadPresharded(int weight_index) const;
+
+ /*! \brief Slice the given tensor at a specific dimension */
+ NDArray Shard(NDArray source, int dim, int num_slices) const;
+
static constexpr const char* _type_key = "runtime.disco.ShardLoader";
TVM_DECLARE_FINAL_OBJECT_INFO(ShardLoaderObj, Object);
@@ -74,6 +84,19 @@ class ShardLoaderObj : public Object {
mutable const FileRecord* current_file_;
/*! \brief The context of the current file to be loaded from */
mutable std::string current_file_stream_;
+
+ private:
+ /*! \brief Load the i-th parameter without post-processing
+ *
+ * This function should not be called externally, as it does not
+ * check for post-processing that may be required. Instead, the
+ * public function `Load` or `LoadPresharded` should be called.
+ *
+ * \param weight_index The index of NDArray tensor to load
+ *
+ * \returns The full tensor at the specified index
+ */
+ NDArray LoadDirect(int weight_index) const;
};
TVM_REGISTER_OBJECT_TYPE(ShardLoaderObj);
@@ -141,6 +164,46 @@ std::string GetSiblingPath(const std::string& path, const
std::string& filename)
LOG(FATAL) << "ValueError: Cannot find the parent directory: " << path;
}
+std::tuple<int, int> ParseParamShardingInfo(const ParamRecord* param) {
+ // Given a name "param_shard-X-of-Y", return the integer values
+ // rank=(X-1) and world_size=Y.
+
+ std::string name = param->name;
+ size_t pos1 = name.rfind("-of-");
+ CHECK(pos1 != std::string::npos)
+ << "Attempt to read num_shards from unexpected param name: " << name;
+ size_t pos2 = name.rfind("_shard-", pos1 - 1);
+ CHECK(pos2 != std::string::npos)
+ << "Attempt to read sharded worker_id from unexpected param name: " <<
name;
+
+ int num_shards = std::stoi(name.substr(pos1 + 4));
+ int worker_id = std::stoi(name.substr(pos2 + 7, pos1 - pos2 - 7)) - 1;
+
+ CHECK_GT(num_shards, 1);
+ CHECK_GE(worker_id, 0);
+ CHECK_LT(worker_id, num_shards);
+
+ return {num_shards, worker_id};
+}
+
+NDArray ShardLoaderObj::LoadDirect(int weight_index) const {
+ const ParamInfo& param_info = param_info_.at(weight_index);
+ const ParamRecord* param = param_info.param;
+ const FileRecord* file = param_info.file;
+
+ DiscoWorker* worker = DiscoWorker::ThreadLocal();
+ Device device = worker->default_device;
+
+ if (file != current_file_) {
+ current_file_ = file;
+ std::string file_name = GetSiblingPath(this->metadata_.path,
file->data_path);
+ LoadBinaryFromFile(file_name, &this->current_file_stream_);
+ }
+ return param->Load(
+ device, &this->current_file_stream_,
+ [](NDArray param, const void* data, size_t nbytes) {
param.CopyFromBytes(data, nbytes); });
+}
+
NDArray ShardLoaderObj::Load(int weight_index) const {
DiscoWorker* worker = DiscoWorker::ThreadLocal();
int worker_id = worker->worker_id;
@@ -148,18 +211,6 @@ NDArray ShardLoaderObj::Load(int weight_index) const {
Device device = worker->default_device;
const ParamInfo& param_info = param_info_.at(weight_index);
const ParamRecord* param = param_info.param;
- const FileRecord* file = param_info.file;
-
- auto load = [this, param, device, file]() {
- if (file != current_file_) {
- current_file_ = file;
- std::string file_name = GetSiblingPath(this->metadata_.path,
file->data_path);
- LoadBinaryFromFile(file_name, &this->current_file_stream_);
- }
- return param->Load(
- device, &this->current_file_stream_,
- [](NDArray param, const void* data, size_t nbytes) {
param.CopyFromBytes(data, nbytes); });
- };
bool needs_sharding = !param_info.shard_info.funcs.empty();
if (needs_sharding) {
@@ -171,7 +222,7 @@ NDArray ShardLoaderObj::Load(int weight_index) const {
<< "number of shards, but got: " << shape << " and num_shards = " <<
num_shards;
NDArray recv = NDArray::Empty(ShapeTuple(shape.begin() + 1, shape.end()),
dtype, device);
if (worker_id == 0) {
- NDArray w = load();
+ NDArray w = LoadDirect(weight_index);
for (const ShardInfo::ShardFunc& shard_func :
param_info.shard_info.funcs) {
w = this->ApplyShardFunc(shard_func, w);
}
@@ -182,7 +233,7 @@ NDArray ShardLoaderObj::Load(int weight_index) const {
return recv;
} else {
if (worker_id == 0) {
- NDArray w = load();
+ NDArray w = LoadDirect(weight_index);
BroadcastFromWorker0(w, w);
return w;
} else {
@@ -206,6 +257,55 @@ Array<NDArray> ShardLoaderObj::LoadAll() const {
return shards;
}
+NDArray ShardLoaderObj::LoadPresharded(int weight_index) const {
+ DiscoWorker* worker = DiscoWorker::ThreadLocal();
+ int worker_id = worker->worker_id;
+ int num_shards = worker->num_workers;
+ size_t num_weights = param_info_.size() / num_shards;
+ size_t index = worker_id * num_weights + weight_index;
+ CHECK(index < param_info_.size())
+ << "Loading param " << weight_index << " for shard " << worker_id << "
at position " << index
+      << " is out of bounds for the provided ndarray cache.";
+
+ const auto& shard_info = param_info_[index];
+ const ParamRecord* param = shard_info.param;
+ const FileRecord* file = shard_info.file;
+
+ auto [p_num_shards, p_worker_id] = ParseParamShardingInfo(param);
+ CHECK_EQ(num_shards, p_num_shards)
+ << "Runtime number of shards (" << num_shards
+ << ") does not match number of compiled shards (" << p_num_shards << "):
" << param->name
+ << " loaded from " << file->data_path;
+ CHECK_EQ(worker_id, p_worker_id)
+ << "Runtime worker_id (" << worker_id << ") does not match worker_id of
compiled shard ("
+ << p_worker_id << "): " << param->name << " loaded from " <<
file->data_path;
+
+ return LoadDirect(index);
+}
+
+Array<NDArray> ShardLoaderObj::LoadAllPresharded() const {
+ DiscoWorker* worker = DiscoWorker::ThreadLocal();
+ size_t worker_id = static_cast<size_t>(worker->worker_id);
+ size_t num_workers = static_cast<size_t>(worker->num_workers);
+ size_t num_params = param_info_.size() / num_workers;
+
+ Array<NDArray> params;
+ params.reserve(num_params);
+ for (size_t i_param = 0; i_param < num_params; ++i_param) {
+ std::string param_name = static_cast<const std::stringstream&>(
+ std::stringstream() << "param_" << i_param <<
"_shard-"
+ << (worker_id + 1) <<
"-of-" << num_workers)
+ .str();
+
+ auto it = param_name_to_index_.find(param_name);
+ CHECK(it != param_name_to_index_.end())
+ << "Parameter " << param_name << " was not found in the parameter set";
+ int param_id = this->param_name_to_index_.at(param_name);
+ params.push_back(this->LoadDirect(param_id));
+ }
+ return params;
+}
+
TVM_REGISTER_GLOBAL("runtime.disco.ShardLoader").set_body_typed(ShardLoaderObj::Create);
TVM_REGISTER_GLOBAL("runtime.disco.ShardLoaderLoad")
.set_body_typed([](ObjectRef loader_obj, ShapeTuple weight_index) {
@@ -214,6 +314,13 @@ TVM_REGISTER_GLOBAL("runtime.disco.ShardLoaderLoad")
<< loader_obj->GetTypeKey();
return loader->Load(IntegerFromShapeTuple(weight_index));
});
+TVM_REGISTER_GLOBAL("runtime.disco.ShardLoaderLoadPresharded")
+ .set_body_typed([](ObjectRef loader_obj, ShapeTuple weight_index) {
+ const auto* loader = loader_obj.as<ShardLoaderObj>();
+ CHECK(loader != nullptr) << "TypeError: Expected ShardLoaderObj, but
gets: "
+ << loader_obj->GetTypeKey();
+ return loader->LoadPresharded(IntegerFromShapeTuple(weight_index));
+ });
TVM_REGISTER_GLOBAL("runtime.disco.ShardLoaderLoadAll").set_body_typed([](ObjectRef
loader_obj) {
const auto* loader = loader_obj.as<ShardLoaderObj>();
@@ -222,5 +329,13 @@
TVM_REGISTER_GLOBAL("runtime.disco.ShardLoaderLoadAll").set_body_typed([](Object
return loader->LoadAll();
});
+TVM_REGISTER_GLOBAL("runtime.disco.ShardLoaderLoadAllPresharded")
+ .set_body_typed([](ObjectRef loader_obj) {
+ const auto* loader = loader_obj.as<ShardLoaderObj>();
+ CHECK(loader != nullptr) << "TypeError: Expected ShardLoaderObj, but
gets: "
+ << loader_obj->GetTypeKey();
+ return loader->LoadAllPresharded();
+ });
+
} // namespace runtime
} // namespace tvm
diff --git a/src/runtime/disco/nccl/nccl.cc b/src/runtime/disco/nccl/nccl.cc
index 997765caf3..e61306377f 100644
--- a/src/runtime/disco/nccl/nccl.cc
+++ b/src/runtime/disco/nccl/nccl.cc
@@ -162,7 +162,7 @@ struct CCLThreadLocalContext {
void InitCCL(Session sess, IntTuple device_ids) {
DRef func = sess->GetGlobalFunc("runtime.disco." TVM_DISCO_CCL_NAME
".init_ccl_per_worker");
- LOG(INFO) << "Initializing " TVM_DISCO_CCL_NAME " with devices: " <<
device_ids;
+ DLOG(INFO) << "Initializing " TVM_DISCO_CCL_NAME " with devices: " <<
device_ids;
ncclUniqueId id;
TVMByteArray array;
NCCL_CALL(ncclGetUniqueId(&id));
diff --git a/src/runtime/disco/worker.cc b/src/runtime/disco/worker.cc
index 3100985f18..9192215dda 100644
--- a/src/runtime/disco/worker.cc
+++ b/src/runtime/disco/worker.cc
@@ -62,8 +62,6 @@ void DiscoWorker::SetRegister(int reg_id, TVMArgValue value) {
struct DiscoWorker::Impl {
static void MainLoop(DiscoWorker* self) {
ThreadLocalDiscoWorker::Get()->worker = self;
- LOG(INFO) << "[Worker #" << self->worker_id << "] " <<
support::GetProcessIdAndThreadIdHeader()
- << " started";
while (true) {
TVMArgs args = self->channel->Recv();
DiscoAction action = static_cast<DiscoAction>(args[0].operator int());
diff --git a/tests/python/disco/test_loader.py
b/tests/python/disco/test_loader.py
index 923afe4ac1..502cbe0b81 100644
--- a/tests/python/disco/test_loader.py
+++ b/tests/python/disco/test_loader.py
@@ -21,6 +21,8 @@ import tempfile
import numpy as np
+import tvm
+from tvm import dlight as dl
from tvm import relax as rx
from tvm._ffi import register_func
from tvm.contrib import tvmjs
@@ -29,6 +31,7 @@ from tvm.runtime import disco as di
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.target import Target
+from tvm.contrib import tvmjs
@register_func("tests.disco.shard_dim_0", override=True)
@@ -87,6 +90,35 @@ def _create_loader(sess, path, param_dict, shard_info):
return loader
+def _simulate_presharded_weights(base_path, param_dict, num_shards,
shard_info):
+ """Create fake weights to simulate those produced by MLC-LLM's pre-sharding"""
+
+ sharded_params = {}
+
+ for key, ndarray in param_dict.items():
+ assert key in shard_info, f"ShardInfo lacks shard info about param:
{key}"
+ shard_dim = shard_info[key]
+ sharded_params[key] = [
+ tvm.nd.array(np_shard) for np_shard in np.split(ndarray,
num_shards, axis=shard_dim)
+ ]
+
+ # Re-order so that the parameter order is sorted first by shard,
+ # then by parameter. This matches the ordering used by MLC-LLM,
+ # and avoids having *.bin files that must be accessed by more than
+ # one worker.
+ sharded_params = {
+ f"{key}_shard-{i+1}-of-{num_shards}": shards[i]
+ for i in range(num_shards)
+ for key, shards in sharded_params.items()
+ }
+
+ tvmjs.dump_ndarray_cache(
+ sharded_params,
+ base_path,
+ encode_format="raw",
+ )
+
+
def test_load_shard():
devices = [0, 1]
num_shards = len(devices)
@@ -135,6 +167,55 @@ def test_load_shard():
)
+def _create_presharded_loader(sess, path):
+ path_ndarray_cache = path + "/ndarray-cache.json"
+ with open(path_ndarray_cache, "r", encoding="utf-8") as i_f:
+ ndarray_cache = i_f.read()
+ loader_create = sess.get_global_func("runtime.disco.ShardLoader")
+ loader = loader_create(path_ndarray_cache, ndarray_cache, json.dumps({}),
None)
+ return loader
+
+
+def test_load_presharded():
+ devices = [0, 1]
+ param_dict = {
+ "x_0": np.random.uniform(size=[64, 128]).astype("float16"),
+ "x_1": np.random.uniform(size=[32, 128]).astype("float32"),
+ }
+ shard_info = {
+ "x_0": 1,
+ "x_1": 0,
+ }
+
+ with tempfile.TemporaryDirectory() as path:
+ _simulate_presharded_weights(path, param_dict, len(devices),
shard_info)
+ sess = di.ThreadedSession(num_workers=len(devices))
+ sess.init_ccl("nccl", *devices)
+
+ loader = _create_presharded_loader(sess, path)
+ loader_load =
sess.get_global_func("runtime.disco.ShardLoaderLoadPresharded")
+
+ d_0 = loader_load(loader, ShapeTuple([0]))
+ d_1 = loader_load(loader, ShapeTuple([1]))
+
+ np.testing.assert_equal(
+ param_dict["x_0"][:, 0:64],
+ d_0.debug_get_from_remote(0).numpy(),
+ )
+ np.testing.assert_equal(
+ param_dict["x_0"][:, 64:128],
+ d_0.debug_get_from_remote(1).numpy(),
+ )
+ np.testing.assert_equal(
+ param_dict["x_1"][0:16, :],
+ d_1.debug_get_from_remote(0).numpy(),
+ )
+ np.testing.assert_equal(
+ param_dict["x_1"][16:32, :],
+ d_1.debug_get_from_remote(1).numpy(),
+ )
+
+
def test_load_shard_in_relax():
devices = [0, 1]
num_shards = len(devices)
@@ -264,6 +345,35 @@ def test_load_shard_all():
np.testing.assert_equal(param_dict["param_1"][16:32, :],
p_1[1].numpy())
+def test_load_all_presharded():
+ devices = [0, 1]
+ num_shards = len(devices)
+ param_dict = {
+ "param_0": np.random.uniform(size=[64, 128]).astype("float16"),
+ "param_1": np.random.uniform(size=[32, 128]).astype("float32"),
+ }
+ shard_info = {
+ "param_0": 0,
+ "param_1": 1,
+ }
+ with tempfile.TemporaryDirectory() as path:
+ _simulate_presharded_weights(path, param_dict, len(devices),
shard_info)
+
+ sess = di.ThreadedSession(num_workers=len(devices))
+ sess.init_ccl("nccl", *devices)
+ loader = _create_presharded_loader(sess, path)
+ loader_load =
sess.get_global_func("runtime.disco.ShardLoaderLoadAllPresharded")
+ params = loader_load(loader)
+
+ p_0 = params.debug_get_from_remote(0)
+ p_1 = params.debug_get_from_remote(1)
+
+ np.testing.assert_equal(param_dict["param_0"][0:32, :], p_0[0].numpy())
+ np.testing.assert_equal(param_dict["param_0"][32:64, :],
p_1[0].numpy())
+ np.testing.assert_equal(param_dict["param_1"][:, 0:64], p_0[1].numpy())
+ np.testing.assert_equal(param_dict["param_1"][:, 64:128],
p_1[1].numpy())
+
+
def test_load_shard_broadcast():
devices = [0, 1]
param_dict = {
@@ -345,8 +455,4 @@ def test_load_qkv_proj_shard(): # pylint:
disable=too-many-locals
if __name__ == "__main__":
- test_load_shard()
- test_load_shard_in_relax()
- test_load_shard_all()
- test_load_shard_broadcast()
- test_load_qkv_proj_shard()
+ tvm.testing.main()