Lunderberg commented on code in PR #15676:
URL: https://github.com/apache/tvm/pull/15676#discussion_r1331941078
##########
src/runtime/disco/loader.cc:
##########
@@ -178,6 +197,38 @@ NDArray ShardLoaderObj::Shard(NDArray source, int dim, int num_slices) const {
return destination;
}
+NDArray ShardLoaderObj::LoadPresharded(int weight_index) const {
+ DiscoWorker* worker = DiscoWorker::ThreadLocal();
+ int worker_id = worker->worker_id;
+ int num_shards = worker->num_workers;
+ Device device = worker->default_device;
+ size_t index = weight_index * num_shards + worker_id;
Review Comment:
Does this line imply that a sharded parameter set contains only sharded
parameters? I don't think that assumption holds for our use case (e.g.
sharded weights/bias for matmul, but an unsharded vocab embedding).
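
To make the concern concrete, here is a minimal sketch of an index computation that tolerates a mix of sharded and unsharded parameters. It is only an illustration: `FlatIndex` and `shard_counts` are hypothetical names that do not exist in the PR or in TVM, and it assumes each unsharded parameter is stored once and each sharded parameter is stored once per worker, in worker order.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper, not part of the PR or of TVM.
// shard_counts[i] is the number of stored entries for parameter i:
// num_shards for a sharded parameter, 1 for an unsharded one (assumption).
inline std::size_t FlatIndex(const std::vector<int>& shard_counts,
                             std::size_t weight_index, int worker_id) {
  // Sum the entries occupied by all preceding parameters.
  std::size_t offset = 0;
  for (std::size_t i = 0; i < weight_index; ++i) {
    offset += static_cast<std::size_t>(shard_counts[i]);
  }
  // An unsharded parameter has a single entry shared by every worker.
  bool sharded = shard_counts[weight_index] > 1;
  return offset + (sharded ? static_cast<std::size_t>(worker_id) : 0);
}
```

With `shard_counts = {1, 2, 2}` (an unsharded embedding followed by two parameters sharded two ways), worker 1 would read entry 4 for parameter 2, whereas `weight_index * num_shards + worker_id` would give 5. The linear scan could be replaced by a precomputed prefix sum if lookups are frequent.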