Posted to commits@tvm.apache.org by ju...@apache.org on 2023/10/03 07:06:29 UTC

[tvm] branch unity updated: [Disco] Loading-time sharding support (#15826)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 88a08ae47a [Disco] Loading-time sharding support (#15826)
88a08ae47a is described below

commit 88a08ae47a10966f34a9218a92e20f66435f7540
Author: Lesheng Jin <34...@users.noreply.github.com>
AuthorDate: Tue Oct 3 00:06:21 2023 -0700

    [Disco] Loading-time sharding support (#15826)
    
    In our previous implementation, parameter sharding relied on pre-quantization weight processing,
    meaning each set of quantized weights was tied to a hardcoded constant `num_shards`, and
    re-quantization was required whenever the number of GPUs changed, e.g. when moving from a 4-GPU
    to an 8-GPU setting. This PR moves parameter sharding to post-quantization loading time: during
    loading, we iterate over all parameters and apply the sharding operations specified by the
    provided sharding information.
    
    To make this happen, this PR enhances the existing `shard_info.json` to include the sharding
    functions used at loading time. Each parameter is associated with a list of loading-time
    preprocessing functions that are applied in sequence to transform it into the desired shape,
    as shown in the example below:
    
    ```python
    shard_info = {
      "x_0": [ # name of the parameter
        [ # a list of preprocessing functions to be applied
          "tests.disco.shard_dim_1",  # name of the sharding function
          [(num_shards, 64, 64), "float16"],  # output shape/dtype of `tests.disco.shard_dim_1`
          num_shards,  # extra inputs to `tests.disco.shard_dim_1`
        ],
      ],
      "x_1": [...],
    }
    ```
    
    For parameter `x_0`, this means we call the method `tests.disco.shard_dim_1`, which has the signature:
    
    ```python
    def shard_dim_1(
      input: NDArray,
      num_shards, # extra inputs
      output: NDArray, # shape (num_shards, 64, 64), dtype "float16"
    ) -> None: ...
    ```
    
    This approach simplifies parameter sharding for users and allows the same quantized weights to be
    reused across different numbers of GPUs without re-quantization.
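    
    For illustration, here is a minimal end-to-end sketch based on the tests added in this PR
    (`tests/python/disco/test_loader.py`): it registers a numpy-backed implementation of
    `tests.disco.shard_dim_1` and builds the matching `shard_info` entry. The commented lines at the
    end show how the serialized JSON is handed to `runtime.disco.ShardLoader`; the `sess`,
    `path_ndarray_cache`, and `ndarray_cache` objects are assumed to be set up as in the tests.
    
    ```python
    from tvm._ffi import register_func
    
    
    @register_func("tests.disco.shard_dim_1", override=True)
    def _shard_dim_1(src, num_shards, tgt):
        # Split `src` of shape (s_0, s_1) along dim 1 into `num_shards` slices and
        # stack them along a new leading axis, writing the result into `tgt`.
        s_0, s_1 = src.shape
        tgt.copyfrom(src.numpy().reshape(s_0, num_shards, s_1 // num_shards).transpose(1, 0, 2))
    
    
    num_shards = 2
    shard_info = {
        "x_0": [  # a (64, 128) float16 weight is split into two (64, 64) shards
            [
                "tests.disco.shard_dim_1",
                [(num_shards, 64, 64), "float16"],
                num_shards,
            ],
        ],
    }
    
    # Worker 0 loads the full weight, applies the shard functions in order, and
    # scatters shard i to worker i:
    #   loader_create = sess.get_global_func("runtime.disco.ShardLoader")
    #   loader = loader_create(path_ndarray_cache, ndarray_cache, json.dumps(shard_info), None)
    ```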
---
 python/tvm/runtime/disco/process_pool.py      |   1 -
 src/runtime/disco/loader.cc                   | 198 ++++++++++++--------------
 src/runtime/relax_vm/ndarray_cache_support.cc |  85 ++++++++---
 src/runtime/relax_vm/ndarray_cache_support.h  |  19 ++-
 tests/python/disco/test_loader.py             | 193 ++++++++++++++++++++-----
 5 files changed, 331 insertions(+), 165 deletions(-)

diff --git a/python/tvm/runtime/disco/process_pool.py b/python/tvm/runtime/disco/process_pool.py
index 44348577f7..fd4ba7a165 100644
--- a/python/tvm/runtime/disco/process_pool.py
+++ b/python/tvm/runtime/disco/process_pool.py
@@ -173,7 +173,6 @@ def _create_process_pool(num_workers: int):
         if worker_id != 0:
             read_fd, write_fd = pool[worker_id - 1].start()
             return ShapeTuple([read_fd, write_fd])
-        print("Shutting down the process pool")
         del pool
         return None
 
diff --git a/src/runtime/disco/loader.cc b/src/runtime/disco/loader.cc
index 4125e0b259..7670cc5254 100644
--- a/src/runtime/disco/loader.cc
+++ b/src/runtime/disco/loader.cc
@@ -23,6 +23,7 @@
 #include <functional>
 #include <numeric>
 #include <string>
+#include <unordered_map>
 #include <vector>
 
 #include "../file_utils.h"
@@ -36,41 +37,39 @@ namespace runtime {
 using relax_vm::NDArrayCacheMetadata;
 using FileRecord = NDArrayCacheMetadata::FileRecord;
 using ParamRecord = NDArrayCacheMetadata::FileRecord::ParamRecord;
-using relax_vm::LoadShardInfoFromStr;
+using relax_vm::ShardInfo;
 
 /*! \brief An object that helps to load parameters in shards. */
 class ShardLoaderObj : public Object {
  public:
   /*! \brief Create a shard loader. */
   static ObjectRef Create(const std::string& path_to_metadata, const std::string& metadata,
-                          const std::string& shard_info, Module mod);
+                          std::string shard_info, Module mod);
   /*! \brief Load the i-th parameter */
   NDArray Load(int weight_index) const;
   /*! \brief Load all the parameters */
   Array<NDArray> LoadAll() const;
-  /*! \brief Slice the given tensor at a specific dimension */
-  NDArray Shard(NDArray source, int dim, int num_slices) const;
+
+  NDArray ApplyShardFunc(const ShardInfo::ShardFunc& shard_func, const NDArray& param) const;
 
   static constexpr const char* _type_key = "runtime.disco.ShardLoader";
   TVM_DECLARE_FINAL_OBJECT_INFO(ShardLoaderObj, Object);
 
  public:
   /*! \brief Information of how each weight is stored and sharded */
-  struct ShardInfo {
+  struct ParamInfo {
     const FileRecord* file;
     const ParamRecord* param;
-    int shard_dim;
+    ShardInfo shard_info;
   };
+  /*! \brief The PackedFuncs being used during sharding */
+  std::unordered_map<std::string, PackedFunc> shard_funcs_;
   /*! \brief The metadata loaded from `ndarray-cache.json` */
   NDArrayCacheMetadata metadata_;
   /*! \brief Sharding information for each weight */
-  std::vector<ShardInfo> shard_info_;
+  std::vector<ParamInfo> param_info_;
   /*! \brief Maps the name of a shard to its index */
   std::unordered_map<std::string, int> param_name_to_index_;
-  /*! \brief A method to slice a 3-D tensor */
-  TypedPackedFunc<void(DLTensor*, int, DLTensor*)> f_shard3d_fp16_;
-  TypedPackedFunc<void(DLTensor*, int, DLTensor*)> f_shard3d_fp32_;
-  TypedPackedFunc<void(DLTensor*, int, DLTensor*)> f_shard3d_uint32_;
   /*! \brief The current file opened to load weights in it */
   mutable const FileRecord* current_file_;
   /*! \brief The context of the current file to be loaded from */
@@ -79,50 +78,61 @@ class ShardLoaderObj : public Object {
 
 TVM_REGISTER_OBJECT_TYPE(ShardLoaderObj);
 
-/*!
- * \brief Get the shape of a result tensor if it is scattered along a given axis.
- * \param shape The shape of the input tensor.
- * \param dim The axis along which the tensor is scattered.
- * \param num_shards The number of shards.
- * \return The shape of the result tensor.
- */
-inline std::vector<ShapeTuple::index_type> ShardShape(const ShapeTuple& shape, int dim,
-                                                      int num_shards) {
-  CHECK(0 <= dim && dim < static_cast<int>(shape.size()))
-      << "ValueError: Cannot scatter at dim " << dim << ", because "
-      << "shape is " << shape << ".";
-  CHECK_EQ(shape[dim] % num_shards, 0)
-      << "ValueError: The shape " << shape << " cannot be scattered at dim " << dim << " into "
-      << num_shards << " shards.";
-  std::vector<ShapeTupleObj::index_type> result{shape.begin(), shape.end()};
-  result[dim] /= num_shards;
-  return result;
-}
-
 ObjectRef ShardLoaderObj::Create(const std::string& path_to_metadata, const std::string& metadata,
-                                 const std::string& shard_info, Module mod) {
+                                 std::string shard_info, Module mod) {
+  if (shard_info.empty() && mod.defined()) {
+    if (PackedFunc get_shard_info = mod->GetFunction("get_shard_info"); get_shard_info != nullptr) {
+      shard_info = get_shard_info().operator String();
+    }
+  }
   ObjectPtr<ShardLoaderObj> n = make_object<ShardLoaderObj>();
-  n->f_shard3d_fp16_ = mod->GetFunction("shard3d_fp16", true);
-  n->f_shard3d_fp32_ = mod->GetFunction("shard3d_fp32", true);
-  n->f_shard3d_uint32_ = mod->GetFunction("shard3d_uint32", true);
-  CHECK(n->f_shard3d_fp16_ != nullptr) << "ValueError: Cannot find the function: shard3d_fp16";
-  CHECK(n->f_shard3d_fp32_ != nullptr) << "ValueError: Cannot find the function: shard3d_fp32";
-  CHECK(n->f_shard3d_uint32_ != nullptr) << "ValueError: Cannot find the function: shard3d_uint32";
   n->metadata_ = NDArrayCacheMetadata::LoadFromStr(metadata, path_to_metadata);
   n->current_file_ = nullptr;
-  n->shard_info_.clear();
-  std::unordered_map<std::string, int> shards = LoadShardInfoFromStr(shard_info);
+  n->param_info_.clear();
+  std::unordered_map<std::string, ShardInfo> shards = relax_vm::LoadShardInfoFromStr(shard_info);
   for (const FileRecord& file_record : n->metadata_.records) {
     for (const ParamRecord& param_record : file_record.records) {
       const std::string& name = param_record.name;
-      int shard_id = shards.count(name) ? shards[name] : -1;
-      n->param_name_to_index_[name] = n->shard_info_.size();
-      n->shard_info_.push_back(ShardInfo{&file_record, &param_record, shard_id});
+      int index = n->param_info_.size();
+      n->param_name_to_index_[name] = index;
+      ShardInfo& shard_info = shards[name];
+      for (const ShardInfo::ShardFunc& shard_func : shard_info.funcs) {
+        const std::string& name = shard_func.name;
+        if (PackedFunc f = mod.defined() ? mod->GetFunction(name, true) : nullptr; f != nullptr) {
+          n->shard_funcs_[name] = f;
+        } else if (const PackedFunc* f = runtime::Registry::Get(name)) {
+          n->shard_funcs_[name] = *f;
+        } else {
+          LOG(FATAL) << "ValueError: Undefined function: " << name;
+        }
+      }
+      n->param_info_.emplace_back(ParamInfo{&file_record, &param_record, shard_info});
     }
   }
   return ObjectRef(std::move(n));
 }
 
+NDArray ShardLoaderObj::ApplyShardFunc(const ShardInfo::ShardFunc& shard_func,
+                                       const NDArray& param) const {
+  Device device = param->device;
+  NDArray o = NDArray::Empty(shard_func.output_info.shape, shard_func.output_info.dtype, device);
+  PackedFunc f = this->shard_funcs_.at(shard_func.name);
+  int n = static_cast<int>(shard_func.params.size());
+  std::vector<TVMValue> tvm_args(n + 2);
+  std::vector<int> type_codes(n + 2);
+  TVMArgsSetter setter(tvm_args.data(), type_codes.data());
+  const DLTensor* w_in = param.operator->();
+  const DLTensor* w_out = o.operator->();
+  setter(0, const_cast<DLTensor*>(w_in));
+  for (int i = 0; i < n; ++i) {
+    setter(i + 1, shard_func.params[i]);
+  }
+  setter(n + 1, const_cast<DLTensor*>(w_out));
+  TVMRetValue rv;
+  f.CallPacked(TVMArgs(tvm_args.data(), type_codes.data(), n + 2), &rv);
+  return o;
+}
+
 std::string GetSiblingPath(const std::string& path, const std::string& filename) {
   size_t found = path.find_last_of("/\\");
   if (found != std::string::npos) {
@@ -133,97 +143,69 @@ std::string GetSiblingPath(const std::string& path, const std::string& filename)
 
 NDArray ShardLoaderObj::Load(int weight_index) const {
   DiscoWorker* worker = DiscoWorker::ThreadLocal();
-  int shard_idx = worker->worker_id;
-  Device device = worker->default_device;
-  const auto& shard_info = shard_info_.at(weight_index);
-  const ParamRecord* param = shard_info.param;
-  const FileRecord* file = shard_info.file;
-  int shard_dim = shard_info.shard_dim;
+  int worker_id = worker->worker_id;
   int num_shards = worker->num_workers;
-  Optional<NDArray> send = NullOpt;
-  if (shard_idx == 0) {
+  Device device = worker->default_device;
+  const ParamInfo& param_info = param_info_.at(weight_index);
+  const ParamRecord* param = param_info.param;
+  const FileRecord* file = param_info.file;
+
+  auto load = [this, param, device, file]() {
     if (file != current_file_) {
       current_file_ = file;
       std::string file_name = GetSiblingPath(this->metadata_.path, file->data_path);
       LoadBinaryFromFile(file_name, &this->current_file_stream_);
     }
-    auto f_load = [](NDArray param, const void* data, size_t nbytes) {
-      param.CopyFromBytes(data, nbytes);
-    };
-    if (shard_dim != -1) {
-      send = this->Shard(param->Load(device, &this->current_file_stream_, f_load), shard_dim,
-                         num_shards);
+    return param->Load(
+        device, &this->current_file_stream_,
+        [](NDArray param, const void* data, size_t nbytes) { param.CopyFromBytes(data, nbytes); });
+  };
+
+  bool needs_sharding = !param_info.shard_info.funcs.empty();
+  if (needs_sharding) {
+    ShapeTuple shape = param_info.shard_info.funcs.back().output_info.shape;
+    DataType dtype = param_info.shard_info.funcs.back().output_info.dtype;
+    ICHECK(shape.size() >= 1 && shape[0] == num_shards)
+        << "ValueError: The first dimension of the "
+        << "output shape must be equal to the "
+        << "number of shards, but got: " << shape << " and num_shards = " << num_shards;
+    NDArray recv = NDArray::Empty(ShapeTuple(shape.begin() + 1, shape.end()), dtype, device);
+    if (worker_id == 0) {
+      NDArray w = load();
+      for (const ShardInfo::ShardFunc& shard_func : param_info.shard_info.funcs) {
+        w = this->ApplyShardFunc(shard_func, w);
+      }
+      ScatterFromWorker0(w, recv);
     } else {
-      send = param->Load(device, &this->current_file_stream_, f_load);
+      ScatterFromWorker0(NullOpt, recv);
     }
-  }
-  if (shard_dim != -1) {
-    NDArray recv =
-        NDArray::Empty(ShardShape(param->shape, shard_dim, num_shards), param->dtype, device);
-    ScatterFromWorker0(send, recv);
     return recv;
   } else {
-    NDArray recv;
-    if (send.defined()) {
-      recv = NDArray(send.value());
+    if (worker_id == 0) {
+      NDArray w = load();
+      BroadcastFromWorker0(w, w);
+      return w;
     } else {
-      recv = NDArray::Empty(param->shape, param->dtype, device);
+      NDArray w = NDArray::Empty(param->shape, param->dtype, device);
+      BroadcastFromWorker0(w, w);
+      return w;
     }
-    BroadcastFromWorker0(recv, recv);
-    return recv;
   }
 }
 
 Array<NDArray> ShardLoaderObj::LoadAll() const {
-  int n = static_cast<int>(shard_info_.size());
+  int n = static_cast<int>(param_info_.size());
   Array<NDArray> shards;
   shards.reserve(n);
   for (int i = 0; i < n; ++i) {
     std::string param_name = "param_" + std::to_string(i);
+    ICHECK(this->param_name_to_index_.count(param_name));
     int shard_id = this->param_name_to_index_.at(param_name);
     shards.push_back(this->Load(shard_id));
   }
   return shards;
 }
 
-NDArray ShardLoaderObj::Shard(NDArray source, int dim, int num_slices) const {
-  ICHECK(dim >= 0 && dim < source->ndim);
-  // Assemble a flattened 3d src tensor
-  int64_t src_flat[3] = {1, 1, 1};
-  {
-    const int64_t* s = source.Shape().data();
-    int ndim = source->ndim;
-    src_flat[0] = std::accumulate(&s[0], &s[dim], 1, std::multiplies<int64_t>());
-    src_flat[1] = s[dim];
-    src_flat[2] = std::accumulate(&s[dim + 1], &s[ndim], 1, std::multiplies<int64_t>());
-  }
-  DLTensor src_tensor = *source.operator->();
-  src_tensor.ndim = 3;
-  src_tensor.shape = src_flat;
-  // Assmeble a flattened 4d dst tensor
-  int64_t dst_flat[4] = {num_slices, src_flat[0], src_flat[1] / num_slices, src_flat[2]};
-  NDArray destination{nullptr};
-  {
-    std::vector<ShapeTuple::index_type> dst_shape = ShardShape(source.Shape(), dim, num_slices);
-    dst_shape.insert(dst_shape.begin(), static_cast<ShapeTuple::index_type>(num_slices));
-    destination = NDArray::Empty(dst_shape, source->dtype, source->device);
-  }
-  DLTensor dst_tensor = *destination.operator->();
-  dst_tensor.ndim = 4;
-  dst_tensor.shape = dst_flat;
-  // Copy slices using the API
-  if (source.DataType() == DataType::Float(32)) {
-    this->f_shard3d_fp32_(&src_tensor, num_slices, &dst_tensor);
-  } else if (source.DataType() == DataType::Float(16)) {
-    this->f_shard3d_fp16_(&src_tensor, num_slices, &dst_tensor);
-  } else if (source.DataType() == DataType::UInt(32)) {
-    this->f_shard3d_uint32_(&src_tensor, num_slices, &dst_tensor);
-  } else {
-    LOG(FATAL) << "ValueError: Unsupported data type: " << source.DataType();
-  }
-  return destination;
-}
-
 TVM_REGISTER_GLOBAL("runtime.disco.ShardLoader").set_body_typed(ShardLoaderObj::Create);
 TVM_REGISTER_GLOBAL("runtime.disco.ShardLoaderLoad")
     .set_body_typed([](ObjectRef loader_obj, ShapeTuple weight_index) {
diff --git a/src/runtime/relax_vm/ndarray_cache_support.cc b/src/runtime/relax_vm/ndarray_cache_support.cc
index 3a5e961fe1..b2f53bfe1e 100644
--- a/src/runtime/relax_vm/ndarray_cache_support.cc
+++ b/src/runtime/relax_vm/ndarray_cache_support.cc
@@ -54,44 +54,56 @@ namespace tvm {
 namespace runtime {
 namespace relax_vm {
 
+template <typename ExpectedType>
+ExpectedType AsType(const picojson::value& json) {
+  ICHECK(json.is<ExpectedType>());
+  return json.get<ExpectedType>();
+}
+
+template <typename ValueType>
+ValueType GetValue(const picojson::object& json, const std::string& key) {
+  return AsType<ValueType>(json.at(key));
+}
+
 NDArrayCacheMetadata::FileRecord::ParamRecord JSONAsParamRecord(const picojson::object& json) {
   std::vector<ShapeTuple::index_type> shape;
   {
-    picojson::array shape_json = json.at("shape").get<picojson::array>();
+    picojson::array shape_json = GetValue<picojson::array>(json, "shape");
     shape.reserve(shape_json.size());
     for (const picojson::value& d : shape_json) {
-      shape.push_back(d.get<int64_t>());
+      shape.push_back(AsType<int64_t>(d));
     }
   }
   NDArrayCacheMetadata::FileRecord::ParamRecord result;
-  result.name = json.at("name").get<std::string>();
-  result.dtype = DataType(String2DLDataType(json.at("dtype").get<std::string>()));
-  result.format = json.at("format").get<std::string>();
-  result.nbytes = json.at("nbytes").get<int64_t>();
-  result.byte_offset = json.at("byteOffset").get<int64_t>();
+  std::string dtype = GetValue<std::string>(json, "dtype");
+  result.name = GetValue<std::string>(json, "name");
+  result.dtype = DataType(String2DLDataType(dtype));
+  result.format = GetValue<std::string>(json, "format");
+  result.nbytes = GetValue<int64_t>(json, "nbytes");
+  result.byte_offset = GetValue<int64_t>(json, "byteOffset");
   result.shape = ShapeTuple(std::move(shape));
   return result;
 }
 
 NDArrayCacheMetadata::FileRecord JSONAsFileRecord(const picojson::object& json) {
-  picojson::array records = json.at("records").get<picojson::array>();
+  picojson::array records = GetValue<picojson::array>(json, "records");
   NDArrayCacheMetadata::FileRecord result;
-  result.data_path = json.at("dataPath").get<std::string>();
-  result.format = json.at("format").get<std::string>();
-  result.nbytes = json.at("nbytes").get<int64_t>();
+  result.data_path = GetValue<std::string>(json, "dataPath");
+  result.format = GetValue<std::string>(json, "format");
+  result.nbytes = GetValue<int64_t>(json, "nbytes");
   result.records.reserve(records.size());
   for (const picojson::value& item : records) {
-    result.records.push_back(JSONAsParamRecord(item.get<picojson::object>()));
+    result.records.push_back(JSONAsParamRecord(AsType<picojson::object>(item)));
   }
   return result;
 }
 
 NDArrayCacheMetadata JSONAsNDArrayCacheMetadata(const picojson::object& json) {
-  picojson::array records = json.at("records").get<picojson::array>();
+  picojson::array records = GetValue<picojson::array>(json, "records");
   NDArrayCacheMetadata result;
   result.records.reserve(records.size());
   for (const picojson::value& item : records) {
-    result.records.push_back(JSONAsFileRecord(item.get<picojson::object>()));
+    result.records.push_back(JSONAsFileRecord(AsType<picojson::object>(item)));
   }
   return result;
 }
@@ -100,20 +112,51 @@ NDArrayCacheMetadata NDArrayCacheMetadata::LoadFromStr(const std::string& json_s
                                                        const std::string& path) {
   picojson::value json_info;
   picojson::parse(json_info, json_str);
-  NDArrayCacheMetadata result = JSONAsNDArrayCacheMetadata(json_info.get<picojson::object>());
+  NDArrayCacheMetadata result = JSONAsNDArrayCacheMetadata(AsType<picojson::object>(json_info));
   result.path = path;
   return result;
 }
 
-std::unordered_map<std::string, int> LoadShardInfoFromStr(const std::string& json_str) {
+ShardInfo::TensorInfo LoadTensorInfoFromJSON(const picojson::array& json_tensor_info) {
+  CHECK_EQ(json_tensor_info.size(), 2) << "ValueError: Invalid tensor info JSON";
+  picojson::array shape_json = AsType<picojson::array>(json_tensor_info[0]);
+  int ndim = shape_json.size();
+  std::vector<int64_t> shape;
+  shape.reserve(ndim);
+  for (int i = 0; i < ndim; ++i) {
+    shape.push_back(AsType<int64_t>(shape_json[i]));
+  }
+  std::string dtype = AsType<std::string>(json_tensor_info[1]);
+  return ShardInfo::TensorInfo{ShapeTuple(std::move(shape)), DataType(String2DLDataType(dtype))};
+}
+
+ShardInfo::ShardFunc LoadShardFuncFromJSON(const picojson::array& json_shard_func) {
+  int n = json_shard_func.size();
+  ShardInfo::ShardFunc shard_info;
+  shard_info.name = AsType<std::string>(json_shard_func[0]);
+  shard_info.output_info = LoadTensorInfoFromJSON(AsType<picojson::array>(json_shard_func[1]));
+  shard_info.params.reserve(n - 2);
+  for (int i = 2; i < n; ++i) {
+    shard_info.params.push_back(AsType<int64_t>(json_shard_func[i]));
+  }
+  return shard_info;
+}
+
+std::unordered_map<std::string, ShardInfo> LoadShardInfoFromStr(const std::string& json_str) {
   picojson::value json_info;
   picojson::parse(json_info, json_str);
-  picojson::object json_obj = json_info.get<picojson::object>();
-  std::unordered_map<std::string, int> result;
-  for (const auto& kv : json_obj) {
+  picojson::object json_obj = AsType<picojson::object>(json_info);
+  std::unordered_map<std::string, ShardInfo> result;
+  for (auto kv : json_obj) {
     std::string name = kv.first;
-    int64_t shard_dim = kv.second.get<int64_t>();
-    result[name] = shard_dim;
+    picojson::array json_shard_funcs = AsType<picojson::array>(kv.second);
+    ShardInfo info;
+    std::vector<ShardInfo::ShardFunc>& shard_funcs = info.funcs;
+    shard_funcs.reserve(json_shard_funcs.size());
+    for (const picojson::value& json_shard_func : json_shard_funcs) {
+      shard_funcs.push_back(LoadShardFuncFromJSON(AsType<picojson::array>(json_shard_func)));
+    }
+    result[name] = info;
   }
   return result;
 }
diff --git a/src/runtime/relax_vm/ndarray_cache_support.h b/src/runtime/relax_vm/ndarray_cache_support.h
index 2bc638d6e3..c1beb5a946 100644
--- a/src/runtime/relax_vm/ndarray_cache_support.h
+++ b/src/runtime/relax_vm/ndarray_cache_support.h
@@ -79,12 +79,29 @@ struct NDArrayCacheMetadata {
   static NDArrayCacheMetadata LoadFromStr(const std::string& json_str, const std::string& path);
 };
 
+/*!
+ * \brief Information of sharding function,
+ * including the shard function name and extra parameters.
+ */
+struct ShardInfo {
+  struct TensorInfo {
+    ShapeTuple shape;
+    DataType dtype;
+  };
+  struct ShardFunc {
+    std::string name;
+    TensorInfo output_info;
+    std::vector<int64_t> params;
+  };
+  std::vector<ShardFunc> funcs;
+};
+
 /*!
  * \brief Load the shard information from dist
  * \param path Path to the file to be loaded
  * \return Mapping from parameter name to its shard dim
  */
-std::unordered_map<std::string, int> LoadShardInfoFromStr(const std::string& json_str);
+std::unordered_map<std::string, ShardInfo> LoadShardInfoFromStr(const std::string& json_str);
 
 }  // namespace relax_vm
 }  // namespace runtime
diff --git a/tests/python/disco/test_loader.py b/tests/python/disco/test_loader.py
index c92bac6b46..923afe4ac1 100644
--- a/tests/python/disco/test_loader.py
+++ b/tests/python/disco/test_loader.py
@@ -21,7 +21,6 @@ import tempfile
 
 import numpy as np
 
-from tvm import dlight as dl
 from tvm import relax as rx
 from tvm._ffi import register_func
 from tvm.contrib import tvmjs
@@ -32,10 +31,50 @@ from tvm.script import relax as R
 from tvm.target import Target
 
 
-@register_func("tests.disco.shard_with_numpy", override=True)
-def _shard_with_numpy(src, num_shards, tgt):
-    s_0, s_1, s_2 = src.shape
-    tgt.copyfrom(src.numpy().reshape(s_0, num_shards, s_1 // num_shards, s_2).transpose(1, 0, 2, 3))
+@register_func("tests.disco.shard_dim_0", override=True)
+def _shard_dim_0(src, num_shards, tgt):
+    s_0, s_1 = src.shape
+    tgt.copyfrom(src.numpy().reshape(num_shards, s_0 // num_shards, s_1))
+
+
+@register_func("tests.disco.shard_dim_1", override=True)
+def _shard_dim_1(src, num_shards, tgt):
+    s_0, s_1 = src.shape
+    tgt.copyfrom(src.numpy().reshape(s_0, num_shards, s_1 // num_shards).transpose(1, 0, 2))
+
+
+@register_func("tests.disco.shard_qkv_0", override=True)
+def _shard_qkv_0(src, num_shards, q_heads, kv_heads, tgt):
+    total_dim, hidden_size = src.shape
+    head_dim = total_dim // (q_heads + kv_heads + kv_heads)
+    q_dim = q_heads * head_dim
+    kv_dim = kv_heads * head_dim
+    w_q = src.numpy()[:q_dim, :].reshape(
+        num_shards,
+        q_heads // num_shards,
+        head_dim,
+        hidden_size,
+    )
+    w_k = src.numpy()[q_dim : q_dim + kv_dim, :].reshape(
+        num_shards,
+        kv_heads // num_shards,
+        head_dim,
+        hidden_size,
+    )
+    w_v = src.numpy()[q_dim + kv_dim :, :].reshape(
+        num_shards,
+        kv_heads // num_shards,
+        head_dim,
+        hidden_size,
+    )
+    w_qkv = np.concatenate([w_q, w_k, w_v], axis=1)
+    tgt.copyfrom(w_qkv)
+
+
+@register_func("tests.disco.shard_qkv_1", override=True)
+def _shard_qkv_1(src, tgt):
+    s, _, _, h = src.shape  # pylint: disable=invalid-name
+    tgt.copyfrom(src.numpy().reshape(s, -1, h))
 
 
 def _create_loader(sess, path, param_dict, shard_info):
@@ -44,23 +83,33 @@ def _create_loader(sess, path, param_dict, shard_info):
     with open(path_ndarray_cache, "r", encoding="utf-8") as i_f:
         ndarray_cache = i_f.read()
     loader_create = sess.get_global_func("runtime.disco.ShardLoader")
-    shard_with_numpy = sess.get_global_func("tests.disco.shard_with_numpy")
-    loader = loader_create(path_ndarray_cache, ndarray_cache, shard_info, shard_with_numpy)
+    loader = loader_create(path_ndarray_cache, ndarray_cache, json.dumps(shard_info), None)
     return loader
 
 
 def test_load_shard():
     devices = [0, 1]
+    num_shards = len(devices)
     param_dict = {
         "x_0": np.random.uniform(size=[64, 128]).astype("float16"),
         "x_1": np.random.uniform(size=[32, 128]).astype("float32"),
     }
-    shard_info = json.dumps(
-        {
-            "x_0": 1,
-            "x_1": 0,
-        }
-    )
+    shard_info = {
+        "x_0": [
+            [
+                "tests.disco.shard_dim_1",
+                [(num_shards, 64, 64), "float16"],
+                num_shards,
+            ],
+        ],
+        "x_1": [
+            [
+                "tests.disco.shard_dim_0",
+                [(num_shards, 16, 128), "float32"],
+                num_shards,
+            ]
+        ],
+    }
     with tempfile.TemporaryDirectory() as path:
         sess = di.ThreadedSession(num_workers=len(devices))
         sess.init_ccl("nccl", *devices)
@@ -88,16 +137,27 @@ def test_load_shard():
 
 def test_load_shard_in_relax():
     devices = [0, 1]
+    num_shards = len(devices)
     param_dict = {
         "x_0": np.random.uniform(size=[64, 128]).astype("float16"),
         "x_1": np.random.uniform(size=[32, 128]).astype("float32"),
     }
-    shard_info = json.dumps(
-        {
-            "x_0": 1,
-            "x_1": 0,
-        }
-    )
+    shard_info = {
+        "x_0": [
+            [
+                "tests.disco.shard_dim_1",
+                [(num_shards, 64, 64), "float16"],
+                num_shards,
+            ],
+        ],
+        "x_1": [
+            [
+                "tests.disco.shard_dim_0",
+                [(num_shards, 16, 128), "float32"],
+                num_shards,
+            ]
+        ],
+    }
 
     # pylint: disable=invalid-name
     @I.ir_module
@@ -128,13 +188,6 @@ def test_load_shard_in_relax():
     def relax_build(mod, target):
         with target:
             mod = rx.get_pipeline("zero")(mod)  # pylint: disable=no-value-for-parameter
-            mod = dl.ApplyDefaultSchedule(  # pylint: disable=not-callable
-                dl.gpu.Matmul(),
-                dl.gpu.GEMV(),
-                dl.gpu.Reduction(),
-                dl.gpu.GeneralReduction(),
-                dl.gpu.Fallback(),
-            )(mod)
             return rx.build(mod, target="cuda")
 
     target = Target(
@@ -149,9 +202,10 @@ def test_load_shard_in_relax():
     )
     with tempfile.TemporaryDirectory() as tmpdir:
         dso_path = tmpdir + "/test.so"
-        relax_build(Module, target).export_library(dso_path)
         sess = di.ThreadedSession(num_workers=len(devices))
         sess.init_ccl("nccl", *devices)
+        relax_build(Module, target).export_library(dso_path)
+
         mod = sess.load_vm_module(dso_path)
         loader = _create_loader(sess, tmpdir, param_dict, shard_info)
         result = mod["main"](loader)
@@ -175,16 +229,27 @@ def test_load_shard_in_relax():
 
 def test_load_shard_all():
     devices = [0, 1]
+    num_shards = len(devices)
     param_dict = {
         "param_0": np.random.uniform(size=[64, 128]).astype("float16"),
         "param_1": np.random.uniform(size=[32, 128]).astype("float32"),
     }
-    shard_info = json.dumps(
-        {
-            "param_0": 1,
-            "param_1": 0,
-        }
-    )
+    shard_info = {
+        "param_0": [
+            [
+                "tests.disco.shard_dim_1",
+                [(num_shards, 64, 64), "float16"],
+                num_shards,
+            ],
+        ],
+        "param_1": [
+            [
+                "tests.disco.shard_dim_0",
+                [(2, 16, 128), "float32"],
+                num_shards,
+            ]
+        ],
+    }
     with tempfile.TemporaryDirectory() as path:
         sess = di.ThreadedSession(num_workers=len(devices))
         sess.init_ccl("nccl", *devices)
@@ -205,7 +270,7 @@ def test_load_shard_broadcast():
         "param_0": np.random.uniform(size=[64, 128]).astype("float16"),
         "param_1": np.random.uniform(size=[32, 128]).astype("float32"),
     }
-    shard_info = "{}"
+    shard_info = {}
     with tempfile.TemporaryDirectory() as path:
         sess = di.ThreadedSession(num_workers=len(devices))
         sess.init_ccl("nccl", *devices)
@@ -220,8 +285,68 @@ def test_load_shard_broadcast():
         np.testing.assert_equal(param_dict["param_1"], p_1[1].numpy())
 
 
+def test_load_qkv_proj_shard():  # pylint: disable=too-many-locals
+    devices = [0, 1]
+    num_shards = len(devices)
+    q_heads = 8
+    kv_heads = 10
+    head_dim = 10
+    hidden_size = 20
+    w_q = np.random.uniform(size=[q_heads * head_dim, hidden_size]).astype("float16")
+    w_k = np.random.uniform(size=[kv_heads * head_dim, hidden_size]).astype("float16")
+    w_v = np.random.uniform(size=[kv_heads * head_dim, hidden_size]).astype("float16")
+    w_qkv = np.concatenate([w_q, w_k, w_v], axis=0)
+    param_dict = {"w_qkv": w_qkv}
+    np_qkv = np.concatenate(
+        [
+            w_q.reshape((num_shards, q_heads // num_shards, head_dim, hidden_size)),
+            w_k.reshape((num_shards, kv_heads // num_shards, head_dim, hidden_size)),
+            w_v.reshape((num_shards, kv_heads // num_shards, head_dim, hidden_size)),
+        ],
+        axis=1,
+    ).reshape((num_shards, -1, hidden_size))
+
+    shard_info = {
+        "w_qkv": [
+            [
+                "tests.disco.shard_qkv_0",
+                [
+                    (num_shards, (q_heads + kv_heads * 2) // num_shards, head_dim, hidden_size),
+                    "float16",
+                ],
+                num_shards,
+                q_heads,
+                kv_heads,
+            ],
+            [
+                "tests.disco.shard_qkv_1",
+                [
+                    (num_shards, (q_heads + kv_heads * 2) // num_shards * head_dim, hidden_size),
+                    "float16",
+                ],
+            ],
+        ],
+    }
+
+    with tempfile.TemporaryDirectory() as path:
+        sess = di.ThreadedSession(num_workers=len(devices))
+        sess.init_ccl("nccl", *devices)
+        loader = _create_loader(sess, path, param_dict, shard_info)
+        loader_load = sess.get_global_func("runtime.disco.ShardLoaderLoad")
+        d_0 = loader_load(loader, ShapeTuple([0]))
+        np.testing.assert_equal(
+            np_qkv[0],
+            d_0.debug_get_from_remote(0).numpy(),
+        )
+        np.testing.assert_equal(
+            np_qkv[1],
+            d_0.debug_get_from_remote(1).numpy(),
+        )
+
+
 if __name__ == "__main__":
     test_load_shard()
     test_load_shard_in_relax()
     test_load_shard_all()
     test_load_shard_broadcast()
+    test_load_qkv_proj_shard()