Posted to commits@tvm.apache.org by "Lunderberg (via GitHub)" <gi...@apache.org> on 2023/09/21 16:18:04 UTC

[GitHub] [tvm] Lunderberg commented on a diff in pull request #15676: [Disco] Add loader for presharded params.

Lunderberg commented on code in PR #15676:
URL: https://github.com/apache/tvm/pull/15676#discussion_r1331941078


##########
src/runtime/disco/loader.cc:
##########
@@ -178,6 +197,38 @@ NDArray ShardLoaderObj::Shard(NDArray source, int dim, int num_slices) const {
   return destination;
 }
 
+NDArray ShardLoaderObj::LoadPresharded(int weight_index) const {
+  DiscoWorker* worker = DiscoWorker::ThreadLocal();
+  int worker_id = worker->worker_id;
+  int num_shards = worker->num_workers;
+  Device device = worker->default_device;
+  size_t index = weight_index * num_shards + worker_id;

Review Comment:
   Does this line imply that a sharded parameter set contains only sharded parameters?  I don't think that assumption holds for our use case (e.g. sharded weights/bias for matmul, but an unsharded vocab embedding).
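
   To make the concern concrete, here is a minimal sketch rather than a proposed patch. `FlatIndex` mirrors the arithmetic in the diff. `BuildOffsets` and `MixedIndex` are hypothetical helpers that are not part of TVM, and the sketch assumes the stored arrays are laid out in weight order, with one entry for an unsharded weight and one entry per worker for a sharded weight.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Same arithmetic as the diff: correct only if every weight occupies
// exactly num_shards consecutive entries in the stored array list.
std::size_t FlatIndex(int weight_index, int num_shards, int worker_id) {
  return static_cast<std::size_t>(weight_index) * num_shards + worker_id;
}

// Hypothetical helper: given the number of stored entries per weight
// (1 for an unsharded weight, num_shards for a sharded one), return the
// position of each weight's first entry.
std::vector<std::size_t> BuildOffsets(const std::vector<int>& shards_per_weight) {
  std::vector<std::size_t> offsets;
  offsets.reserve(shards_per_weight.size());
  std::size_t next = 0;
  for (int n : shards_per_weight) {
    offsets.push_back(next);
    next += static_cast<std::size_t>(n);
  }
  return offsets;
}

// Hypothetical helper: an unsharded weight resolves to its single entry on
// every worker; a sharded weight resolves to the entry for this worker.
std::size_t MixedIndex(const std::vector<std::size_t>& offsets,
                       const std::vector<int>& shards_per_weight,
                       int weight_index, int worker_id) {
  assert(weight_index >= 0 &&
         static_cast<std::size_t>(weight_index) < offsets.size());
  std::size_t base = offsets[weight_index];
  return shards_per_weight[weight_index] == 1 ? base : base + worker_id;
}
```

   With two workers and weights {vocab embedding (unsharded), matmul weight (sharded), bias (sharded)}, the stored order under this assumed layout would be [emb, w_0, w_1, b_0, b_1]. For worker 1 and the matmul weight (weight index 1), `FlatIndex` gives 3, which is `b_0`, whereas `MixedIndex` gives 2, which is `w_1`.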


