You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2023/10/03 07:06:29 UTC
[tvm] branch unity updated: [Disco] Loading-time sharding support (#15826)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 88a08ae47a [Disco] Loading-time sharding support (#15826)
88a08ae47a is described below
commit 88a08ae47a10966f34a9218a92e20f66435f7540
Author: Lesheng Jin <34...@users.noreply.github.com>
AuthorDate: Tue Oct 3 00:06:21 2023 -0700
[Disco] Loading-time sharding support (#15826)
In our previous implementation, parameter sharding relied on pre-quantization weight processing:
each set of quantized weights was tied to a hardcoded `num_shards`, so the weights had to be
re-quantized whenever the number of GPUs changed, e.g. when moving from a 4-GPU to an 8-GPU
setup. This PR moves parameter sharding to loading time, after quantization.
During loading, we iterate over all parameters and apply the sharding operation based on the
provided sharding information.
To support this, this PR extends the existing `shard_info.json` format to record the sharding
functions to apply at loading time. Each parameter maps to a list of loading-time preprocessing
functions, which are applied in order to transform the parameter into the desired shape, as shown
in the example below:
```python
shard_info = {
"x_0": [ # name of the parameter
[ # a list of preprocessing functions to be applied
"tests.disco.shard_dim_1", # name of the sharding function
[(num_shards, 64, 64), "float16"], # output shape/dtype of `tests.disco.shard_dim_1`
num_shards, # extra inputs to `tests.disco.shard_dim_1`
],
],
"x_1": [...],
}
```
For parameter `x_0`, this means the loader calls `tests.disco.shard_dim_1`, which has the signature:
```python
def shard_dim_1(
input: NDArray,
num_shards, # extra inputs
output: NDArray, # and its shape is (num_shards, 64, 64), and dtype is "float16"
) -> None: ...
```
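This approach simplifies parameter sharding for users and helps ensure correctness. Changing the number of GPUs no longer requires re-quantizing the weights, and the sharding applied to each parameter is stated explicitly in `shard_info.json`.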
---
python/tvm/runtime/disco/process_pool.py | 1 -
src/runtime/disco/loader.cc | 198 ++++++++++++--------------
src/runtime/relax_vm/ndarray_cache_support.cc | 85 ++++++++---
src/runtime/relax_vm/ndarray_cache_support.h | 19 ++-
tests/python/disco/test_loader.py | 193 ++++++++++++++++++++-----
5 files changed, 331 insertions(+), 165 deletions(-)
diff --git a/python/tvm/runtime/disco/process_pool.py b/python/tvm/runtime/disco/process_pool.py
index 44348577f7..fd4ba7a165 100644
--- a/python/tvm/runtime/disco/process_pool.py
+++ b/python/tvm/runtime/disco/process_pool.py
@@ -173,7 +173,6 @@ def _create_process_pool(num_workers: int):
if worker_id != 0:
read_fd, write_fd = pool[worker_id - 1].start()
return ShapeTuple([read_fd, write_fd])
- print("Shutting down the process pool")
del pool
return None
diff --git a/src/runtime/disco/loader.cc b/src/runtime/disco/loader.cc
index 4125e0b259..7670cc5254 100644
--- a/src/runtime/disco/loader.cc
+++ b/src/runtime/disco/loader.cc
@@ -23,6 +23,7 @@
#include <functional>
#include <numeric>
#include <string>
+#include <unordered_map>
#include <vector>
#include "../file_utils.h"
@@ -36,41 +37,39 @@ namespace runtime {
using relax_vm::NDArrayCacheMetadata;
using FileRecord = NDArrayCacheMetadata::FileRecord;
using ParamRecord = NDArrayCacheMetadata::FileRecord::ParamRecord;
-using relax_vm::LoadShardInfoFromStr;
+using relax_vm::ShardInfo;
/*! \brief An object that helps to load parameters in shards. */
class ShardLoaderObj : public Object {
public:
/*! \brief Create a shard loader. */
static ObjectRef Create(const std::string& path_to_metadata, const std::string& metadata,
- const std::string& shard_info, Module mod);
+ std::string shard_info, Module mod);
/*! \brief Load the i-th parameter */
NDArray Load(int weight_index) const;
/*! \brief Load all the parameters */
Array<NDArray> LoadAll() const;
- /*! \brief Slice the given tensor at a specific dimension */
- NDArray Shard(NDArray source, int dim, int num_slices) const;
+
+ NDArray ApplyShardFunc(const ShardInfo::ShardFunc& shard_func, const NDArray& param) const;
static constexpr const char* _type_key = "runtime.disco.ShardLoader";
TVM_DECLARE_FINAL_OBJECT_INFO(ShardLoaderObj, Object);
public:
/*! \brief Information of how each weight is stored and sharded */
- struct ShardInfo {
+ struct ParamInfo {
const FileRecord* file;
const ParamRecord* param;
- int shard_dim;
+ ShardInfo shard_info;
};
+ /*! \brief The PackedFuncs being used during sharding */
+ std::unordered_map<std::string, PackedFunc> shard_funcs_;
/*! \brief The metadata loaded from `ndarray-cache.json` */
NDArrayCacheMetadata metadata_;
/*! \brief Sharding information for each weight */
- std::vector<ShardInfo> shard_info_;
+ std::vector<ParamInfo> param_info_;
/*! \brief Maps the name of a shard to its index */
std::unordered_map<std::string, int> param_name_to_index_;
- /*! \brief A method to slice a 3-D tensor */
- TypedPackedFunc<void(DLTensor*, int, DLTensor*)> f_shard3d_fp16_;
- TypedPackedFunc<void(DLTensor*, int, DLTensor*)> f_shard3d_fp32_;
- TypedPackedFunc<void(DLTensor*, int, DLTensor*)> f_shard3d_uint32_;
/*! \brief The current file opened to load weights in it */
mutable const FileRecord* current_file_;
/*! \brief The context of the current file to be loaded from */
@@ -79,50 +78,61 @@ class ShardLoaderObj : public Object {
TVM_REGISTER_OBJECT_TYPE(ShardLoaderObj);
-/*!
- * \brief Get the shape of a result tensor if it is scattered along a given axis.
- * \param shape The shape of the input tensor.
- * \param dim The axis along which the tensor is scattered.
- * \param num_shards The number of shards.
- * \return The shape of the result tensor.
- */
-inline std::vector<ShapeTuple::index_type> ShardShape(const ShapeTuple& shape, int dim,
- int num_shards) {
- CHECK(0 <= dim && dim < static_cast<int>(shape.size()))
- << "ValueError: Cannot scatter at dim " << dim << ", because "
- << "shape is " << shape << ".";
- CHECK_EQ(shape[dim] % num_shards, 0)
- << "ValueError: The shape " << shape << " cannot be scattered at dim " << dim << " into "
- << num_shards << " shards.";
- std::vector<ShapeTupleObj::index_type> result{shape.begin(), shape.end()};
- result[dim] /= num_shards;
- return result;
-}
-
ObjectRef ShardLoaderObj::Create(const std::string& path_to_metadata, const std::string& metadata,
- const std::string& shard_info, Module mod) {
+ std::string shard_info, Module mod) {
+ if (shard_info.empty() && mod.defined()) {
+ if (PackedFunc get_shard_info = mod->GetFunction("get_shard_info"); get_shard_info != nullptr) {
+ shard_info = get_shard_info().operator String();
+ }
+ }
ObjectPtr<ShardLoaderObj> n = make_object<ShardLoaderObj>();
- n->f_shard3d_fp16_ = mod->GetFunction("shard3d_fp16", true);
- n->f_shard3d_fp32_ = mod->GetFunction("shard3d_fp32", true);
- n->f_shard3d_uint32_ = mod->GetFunction("shard3d_uint32", true);
- CHECK(n->f_shard3d_fp16_ != nullptr) << "ValueError: Cannot find the function: shard3d_fp16";
- CHECK(n->f_shard3d_fp32_ != nullptr) << "ValueError: Cannot find the function: shard3d_fp32";
- CHECK(n->f_shard3d_uint32_ != nullptr) << "ValueError: Cannot find the function: shard3d_uint32";
n->metadata_ = NDArrayCacheMetadata::LoadFromStr(metadata, path_to_metadata);
n->current_file_ = nullptr;
- n->shard_info_.clear();
- std::unordered_map<std::string, int> shards = LoadShardInfoFromStr(shard_info);
+ n->param_info_.clear();
+ std::unordered_map<std::string, ShardInfo> shards = relax_vm::LoadShardInfoFromStr(shard_info);
for (const FileRecord& file_record : n->metadata_.records) {
for (const ParamRecord& param_record : file_record.records) {
const std::string& name = param_record.name;
- int shard_id = shards.count(name) ? shards[name] : -1;
- n->param_name_to_index_[name] = n->shard_info_.size();
- n->shard_info_.push_back(ShardInfo{&file_record, &param_record, shard_id});
+ int index = n->param_info_.size();
+ n->param_name_to_index_[name] = index;
+ ShardInfo& shard_info = shards[name];
+ for (const ShardInfo::ShardFunc& shard_func : shard_info.funcs) {
+ const std::string& name = shard_func.name;
+ if (PackedFunc f = mod.defined() ? mod->GetFunction(name, true) : nullptr; f != nullptr) {
+ n->shard_funcs_[name] = f;
+ } else if (const PackedFunc* f = runtime::Registry::Get(name)) {
+ n->shard_funcs_[name] = *f;
+ } else {
+ LOG(FATAL) << "ValueError: Undefined function: " << name;
+ }
+ }
+ n->param_info_.emplace_back(ParamInfo{&file_record, &param_record, shard_info});
}
}
return ObjectRef(std::move(n));
}
+NDArray ShardLoaderObj::ApplyShardFunc(const ShardInfo::ShardFunc& shard_func,
+ const NDArray& param) const {
+ Device device = param->device;
+ NDArray o = NDArray::Empty(shard_func.output_info.shape, shard_func.output_info.dtype, device);
+ PackedFunc f = this->shard_funcs_.at(shard_func.name);
+ int n = static_cast<int>(shard_func.params.size());
+ std::vector<TVMValue> tvm_args(n + 2);
+ std::vector<int> type_codes(n + 2);
+ TVMArgsSetter setter(tvm_args.data(), type_codes.data());
+ const DLTensor* w_in = param.operator->();
+ const DLTensor* w_out = o.operator->();
+ setter(0, const_cast<DLTensor*>(w_in));
+ for (int i = 0; i < n; ++i) {
+ setter(i + 1, shard_func.params[i]);
+ }
+ setter(n + 1, const_cast<DLTensor*>(w_out));
+ TVMRetValue rv;
+ f.CallPacked(TVMArgs(tvm_args.data(), type_codes.data(), n + 2), &rv);
+ return o;
+}
+
std::string GetSiblingPath(const std::string& path, const std::string& filename) {
size_t found = path.find_last_of("/\\");
if (found != std::string::npos) {
@@ -133,97 +143,69 @@ std::string GetSiblingPath(const std::string& path, const std::string& filename)
NDArray ShardLoaderObj::Load(int weight_index) const {
DiscoWorker* worker = DiscoWorker::ThreadLocal();
- int shard_idx = worker->worker_id;
- Device device = worker->default_device;
- const auto& shard_info = shard_info_.at(weight_index);
- const ParamRecord* param = shard_info.param;
- const FileRecord* file = shard_info.file;
- int shard_dim = shard_info.shard_dim;
+ int worker_id = worker->worker_id;
int num_shards = worker->num_workers;
- Optional<NDArray> send = NullOpt;
- if (shard_idx == 0) {
+ Device device = worker->default_device;
+ const ParamInfo& param_info = param_info_.at(weight_index);
+ const ParamRecord* param = param_info.param;
+ const FileRecord* file = param_info.file;
+
+ auto load = [this, param, device, file]() {
if (file != current_file_) {
current_file_ = file;
std::string file_name = GetSiblingPath(this->metadata_.path, file->data_path);
LoadBinaryFromFile(file_name, &this->current_file_stream_);
}
- auto f_load = [](NDArray param, const void* data, size_t nbytes) {
- param.CopyFromBytes(data, nbytes);
- };
- if (shard_dim != -1) {
- send = this->Shard(param->Load(device, &this->current_file_stream_, f_load), shard_dim,
- num_shards);
+ return param->Load(
+ device, &this->current_file_stream_,
+ [](NDArray param, const void* data, size_t nbytes) { param.CopyFromBytes(data, nbytes); });
+ };
+
+ bool needs_sharding = !param_info.shard_info.funcs.empty();
+ if (needs_sharding) {
+ ShapeTuple shape = param_info.shard_info.funcs.back().output_info.shape;
+ DataType dtype = param_info.shard_info.funcs.back().output_info.dtype;
+ ICHECK(shape.size() >= 1 && shape[0] == num_shards)
+ << "ValueError: The first dimension of the "
+ << "output shape must be equal to the "
+ << "number of shards, but got: " << shape << " and num_shards = " << num_shards;
+ NDArray recv = NDArray::Empty(ShapeTuple(shape.begin() + 1, shape.end()), dtype, device);
+ if (worker_id == 0) {
+ NDArray w = load();
+ for (const ShardInfo::ShardFunc& shard_func : param_info.shard_info.funcs) {
+ w = this->ApplyShardFunc(shard_func, w);
+ }
+ ScatterFromWorker0(w, recv);
} else {
- send = param->Load(device, &this->current_file_stream_, f_load);
+ ScatterFromWorker0(NullOpt, recv);
}
- }
- if (shard_dim != -1) {
- NDArray recv =
- NDArray::Empty(ShardShape(param->shape, shard_dim, num_shards), param->dtype, device);
- ScatterFromWorker0(send, recv);
return recv;
} else {
- NDArray recv;
- if (send.defined()) {
- recv = NDArray(send.value());
+ if (worker_id == 0) {
+ NDArray w = load();
+ BroadcastFromWorker0(w, w);
+ return w;
} else {
- recv = NDArray::Empty(param->shape, param->dtype, device);
+ NDArray w = NDArray::Empty(param->shape, param->dtype, device);
+ BroadcastFromWorker0(w, w);
+ return w;
}
- BroadcastFromWorker0(recv, recv);
- return recv;
}
}
Array<NDArray> ShardLoaderObj::LoadAll() const {
- int n = static_cast<int>(shard_info_.size());
+ int n = static_cast<int>(param_info_.size());
Array<NDArray> shards;
shards.reserve(n);
for (int i = 0; i < n; ++i) {
std::string param_name = "param_" + std::to_string(i);
+ ICHECK(this->param_name_to_index_.count(param_name));
int shard_id = this->param_name_to_index_.at(param_name);
shards.push_back(this->Load(shard_id));
}
return shards;
}
-NDArray ShardLoaderObj::Shard(NDArray source, int dim, int num_slices) const {
- ICHECK(dim >= 0 && dim < source->ndim);
- // Assemble a flattened 3d src tensor
- int64_t src_flat[3] = {1, 1, 1};
- {
- const int64_t* s = source.Shape().data();
- int ndim = source->ndim;
- src_flat[0] = std::accumulate(&s[0], &s[dim], 1, std::multiplies<int64_t>());
- src_flat[1] = s[dim];
- src_flat[2] = std::accumulate(&s[dim + 1], &s[ndim], 1, std::multiplies<int64_t>());
- }
- DLTensor src_tensor = *source.operator->();
- src_tensor.ndim = 3;
- src_tensor.shape = src_flat;
- // Assmeble a flattened 4d dst tensor
- int64_t dst_flat[4] = {num_slices, src_flat[0], src_flat[1] / num_slices, src_flat[2]};
- NDArray destination{nullptr};
- {
- std::vector<ShapeTuple::index_type> dst_shape = ShardShape(source.Shape(), dim, num_slices);
- dst_shape.insert(dst_shape.begin(), static_cast<ShapeTuple::index_type>(num_slices));
- destination = NDArray::Empty(dst_shape, source->dtype, source->device);
- }
- DLTensor dst_tensor = *destination.operator->();
- dst_tensor.ndim = 4;
- dst_tensor.shape = dst_flat;
- // Copy slices using the API
- if (source.DataType() == DataType::Float(32)) {
- this->f_shard3d_fp32_(&src_tensor, num_slices, &dst_tensor);
- } else if (source.DataType() == DataType::Float(16)) {
- this->f_shard3d_fp16_(&src_tensor, num_slices, &dst_tensor);
- } else if (source.DataType() == DataType::UInt(32)) {
- this->f_shard3d_uint32_(&src_tensor, num_slices, &dst_tensor);
- } else {
- LOG(FATAL) << "ValueError: Unsupported data type: " << source.DataType();
- }
- return destination;
-}
-
TVM_REGISTER_GLOBAL("runtime.disco.ShardLoader").set_body_typed(ShardLoaderObj::Create);
TVM_REGISTER_GLOBAL("runtime.disco.ShardLoaderLoad")
.set_body_typed([](ObjectRef loader_obj, ShapeTuple weight_index) {
diff --git a/src/runtime/relax_vm/ndarray_cache_support.cc b/src/runtime/relax_vm/ndarray_cache_support.cc
index 3a5e961fe1..b2f53bfe1e 100644
--- a/src/runtime/relax_vm/ndarray_cache_support.cc
+++ b/src/runtime/relax_vm/ndarray_cache_support.cc
@@ -54,44 +54,56 @@ namespace tvm {
namespace runtime {
namespace relax_vm {
+template <typename ExpectedType>
+ExpectedType AsType(const picojson::value& json) {
+ ICHECK(json.is<ExpectedType>());
+ return json.get<ExpectedType>();
+}
+
+template <typename ValueType>
+ValueType GetValue(const picojson::object& json, const std::string& key) {
+ return AsType<ValueType>(json.at(key));
+}
+
NDArrayCacheMetadata::FileRecord::ParamRecord JSONAsParamRecord(const picojson::object& json) {
std::vector<ShapeTuple::index_type> shape;
{
- picojson::array shape_json = json.at("shape").get<picojson::array>();
+ picojson::array shape_json = GetValue<picojson::array>(json, "shape");
shape.reserve(shape_json.size());
for (const picojson::value& d : shape_json) {
- shape.push_back(d.get<int64_t>());
+ shape.push_back(AsType<int64_t>(d));
}
}
NDArrayCacheMetadata::FileRecord::ParamRecord result;
- result.name = json.at("name").get<std::string>();
- result.dtype = DataType(String2DLDataType(json.at("dtype").get<std::string>()));
- result.format = json.at("format").get<std::string>();
- result.nbytes = json.at("nbytes").get<int64_t>();
- result.byte_offset = json.at("byteOffset").get<int64_t>();
+ std::string dtype = GetValue<std::string>(json, "dtype");
+ result.name = GetValue<std::string>(json, "name");
+ result.dtype = DataType(String2DLDataType(dtype));
+ result.format = GetValue<std::string>(json, "format");
+ result.nbytes = GetValue<int64_t>(json, "nbytes");
+ result.byte_offset = GetValue<int64_t>(json, "byteOffset");
result.shape = ShapeTuple(std::move(shape));
return result;
}
NDArrayCacheMetadata::FileRecord JSONAsFileRecord(const picojson::object& json) {
- picojson::array records = json.at("records").get<picojson::array>();
+ picojson::array records = GetValue<picojson::array>(json, "records");
NDArrayCacheMetadata::FileRecord result;
- result.data_path = json.at("dataPath").get<std::string>();
- result.format = json.at("format").get<std::string>();
- result.nbytes = json.at("nbytes").get<int64_t>();
+ result.data_path = GetValue<std::string>(json, "dataPath");
+ result.format = GetValue<std::string>(json, "format");
+ result.nbytes = GetValue<int64_t>(json, "nbytes");
result.records.reserve(records.size());
for (const picojson::value& item : records) {
- result.records.push_back(JSONAsParamRecord(item.get<picojson::object>()));
+ result.records.push_back(JSONAsParamRecord(AsType<picojson::object>(item)));
}
return result;
}
NDArrayCacheMetadata JSONAsNDArrayCacheMetadata(const picojson::object& json) {
- picojson::array records = json.at("records").get<picojson::array>();
+ picojson::array records = GetValue<picojson::array>(json, "records");
NDArrayCacheMetadata result;
result.records.reserve(records.size());
for (const picojson::value& item : records) {
- result.records.push_back(JSONAsFileRecord(item.get<picojson::object>()));
+ result.records.push_back(JSONAsFileRecord(AsType<picojson::object>(item)));
}
return result;
}
@@ -100,20 +112,51 @@ NDArrayCacheMetadata NDArrayCacheMetadata::LoadFromStr(const std::string& json_s
const std::string& path) {
picojson::value json_info;
picojson::parse(json_info, json_str);
- NDArrayCacheMetadata result = JSONAsNDArrayCacheMetadata(json_info.get<picojson::object>());
+ NDArrayCacheMetadata result = JSONAsNDArrayCacheMetadata(AsType<picojson::object>(json_info));
result.path = path;
return result;
}
-std::unordered_map<std::string, int> LoadShardInfoFromStr(const std::string& json_str) {
+ShardInfo::TensorInfo LoadTensorInfoFromJSON(const picojson::array& json_tensor_info) {
+ CHECK_EQ(json_tensor_info.size(), 2) << "ValueError: Invalid tensor info JSON";
+ picojson::array shape_json = AsType<picojson::array>(json_tensor_info[0]);
+ int ndim = shape_json.size();
+ std::vector<int64_t> shape;
+ shape.reserve(ndim);
+ for (int i = 0; i < ndim; ++i) {
+ shape.push_back(AsType<int64_t>(shape_json[i]));
+ }
+ std::string dtype = AsType<std::string>(json_tensor_info[1]);
+ return ShardInfo::TensorInfo{ShapeTuple(std::move(shape)), DataType(String2DLDataType(dtype))};
+}
+
+ShardInfo::ShardFunc LoadShardFuncFromJSON(const picojson::array& json_shard_func) {
+ int n = json_shard_func.size();
+ ShardInfo::ShardFunc shard_info;
+ shard_info.name = AsType<std::string>(json_shard_func[0]);
+ shard_info.output_info = LoadTensorInfoFromJSON(AsType<picojson::array>(json_shard_func[1]));
+ shard_info.params.reserve(n - 2);
+ for (int i = 2; i < n; ++i) {
+ shard_info.params.push_back(AsType<int64_t>(json_shard_func[i]));
+ }
+ return shard_info;
+}
+
+std::unordered_map<std::string, ShardInfo> LoadShardInfoFromStr(const std::string& json_str) {
picojson::value json_info;
picojson::parse(json_info, json_str);
- picojson::object json_obj = json_info.get<picojson::object>();
- std::unordered_map<std::string, int> result;
- for (const auto& kv : json_obj) {
+ picojson::object json_obj = AsType<picojson::object>(json_info);
+ std::unordered_map<std::string, ShardInfo> result;
+ for (auto kv : json_obj) {
std::string name = kv.first;
- int64_t shard_dim = kv.second.get<int64_t>();
- result[name] = shard_dim;
+ picojson::array json_shard_funcs = AsType<picojson::array>(kv.second);
+ ShardInfo info;
+ std::vector<ShardInfo::ShardFunc>& shard_funcs = info.funcs;
+ shard_funcs.reserve(json_shard_funcs.size());
+ for (const picojson::value& json_shard_func : json_shard_funcs) {
+ shard_funcs.push_back(LoadShardFuncFromJSON(AsType<picojson::array>(json_shard_func)));
+ }
+ result[name] = info;
}
return result;
}
diff --git a/src/runtime/relax_vm/ndarray_cache_support.h b/src/runtime/relax_vm/ndarray_cache_support.h
index 2bc638d6e3..c1beb5a946 100644
--- a/src/runtime/relax_vm/ndarray_cache_support.h
+++ b/src/runtime/relax_vm/ndarray_cache_support.h
@@ -79,12 +79,29 @@ struct NDArrayCacheMetadata {
static NDArrayCacheMetadata LoadFromStr(const std::string& json_str, const std::string& path);
};
+/*!
+ * \brief Information of sharding function,
+ * including the shard function name and extra parameters.
+ */
+struct ShardInfo {
+ struct TensorInfo {
+ ShapeTuple shape;
+ DataType dtype;
+ };
+ struct ShardFunc {
+ std::string name;
+ TensorInfo output_info;
+ std::vector<int64_t> params;
+ };
+ std::vector<ShardFunc> funcs;
+};
+
/*!
* \brief Load the shard information from dist
* \param path Path to the file to be loaded
* \return Mapping from parameter name to its shard dim
*/
-std::unordered_map<std::string, int> LoadShardInfoFromStr(const std::string& json_str);
+std::unordered_map<std::string, ShardInfo> LoadShardInfoFromStr(const std::string& json_str);
} // namespace relax_vm
} // namespace runtime
diff --git a/tests/python/disco/test_loader.py b/tests/python/disco/test_loader.py
index c92bac6b46..923afe4ac1 100644
--- a/tests/python/disco/test_loader.py
+++ b/tests/python/disco/test_loader.py
@@ -21,7 +21,6 @@ import tempfile
import numpy as np
-from tvm import dlight as dl
from tvm import relax as rx
from tvm._ffi import register_func
from tvm.contrib import tvmjs
@@ -32,10 +31,50 @@ from tvm.script import relax as R
from tvm.target import Target
-@register_func("tests.disco.shard_with_numpy", override=True)
-def _shard_with_numpy(src, num_shards, tgt):
- s_0, s_1, s_2 = src.shape
- tgt.copyfrom(src.numpy().reshape(s_0, num_shards, s_1 // num_shards, s_2).transpose(1, 0, 2, 3))
+@register_func("tests.disco.shard_dim_0", override=True)
+def _shard_dim_0(src, num_shards, tgt):
+ s_0, s_1 = src.shape
+ tgt.copyfrom(src.numpy().reshape(num_shards, s_0 // num_shards, s_1))
+
+
+@register_func("tests.disco.shard_dim_1", override=True)
+def _shard_dim_1(src, num_shards, tgt):
+ s_0, s_1 = src.shape
+ tgt.copyfrom(src.numpy().reshape(s_0, num_shards, s_1 // num_shards).transpose(1, 0, 2))
+
+
+@register_func("tests.disco.shard_qkv_0", override=True)
+def _shard_qkv_0(src, num_shards, q_heads, kv_heads, tgt):
+ total_dim, hidden_size = src.shape
+ head_dim = total_dim // (q_heads + kv_heads + kv_heads)
+ q_dim = q_heads * head_dim
+ kv_dim = kv_heads * head_dim
+ w_q = src.numpy()[:q_dim, :].reshape(
+ num_shards,
+ q_heads // num_shards,
+ head_dim,
+ hidden_size,
+ )
+ w_k = src.numpy()[q_dim : q_dim + kv_dim, :].reshape(
+ num_shards,
+ kv_heads // num_shards,
+ head_dim,
+ hidden_size,
+ )
+ w_v = src.numpy()[q_dim + kv_dim :, :].reshape(
+ num_shards,
+ kv_heads // num_shards,
+ head_dim,
+ hidden_size,
+ )
+ w_qkv = np.concatenate([w_q, w_k, w_v], axis=1)
+ tgt.copyfrom(w_qkv)
+
+
+@register_func("tests.disco.shard_qkv_1", override=True)
+def _shard_qkv_1(src, tgt):
+ s, _, _, h = src.shape # pylint: disable=invalid-name
+ tgt.copyfrom(src.numpy().reshape(s, -1, h))
def _create_loader(sess, path, param_dict, shard_info):
@@ -44,23 +83,33 @@ def _create_loader(sess, path, param_dict, shard_info):
with open(path_ndarray_cache, "r", encoding="utf-8") as i_f:
ndarray_cache = i_f.read()
loader_create = sess.get_global_func("runtime.disco.ShardLoader")
- shard_with_numpy = sess.get_global_func("tests.disco.shard_with_numpy")
- loader = loader_create(path_ndarray_cache, ndarray_cache, shard_info, shard_with_numpy)
+ loader = loader_create(path_ndarray_cache, ndarray_cache, json.dumps(shard_info), None)
return loader
def test_load_shard():
devices = [0, 1]
+ num_shards = len(devices)
param_dict = {
"x_0": np.random.uniform(size=[64, 128]).astype("float16"),
"x_1": np.random.uniform(size=[32, 128]).astype("float32"),
}
- shard_info = json.dumps(
- {
- "x_0": 1,
- "x_1": 0,
- }
- )
+ shard_info = {
+ "x_0": [
+ [
+ "tests.disco.shard_dim_1",
+ [(num_shards, 64, 64), "float16"],
+ num_shards,
+ ],
+ ],
+ "x_1": [
+ [
+ "tests.disco.shard_dim_0",
+ [(num_shards, 16, 128), "float32"],
+ num_shards,
+ ]
+ ],
+ }
with tempfile.TemporaryDirectory() as path:
sess = di.ThreadedSession(num_workers=len(devices))
sess.init_ccl("nccl", *devices)
@@ -88,16 +137,27 @@ def test_load_shard():
def test_load_shard_in_relax():
devices = [0, 1]
+ num_shards = len(devices)
param_dict = {
"x_0": np.random.uniform(size=[64, 128]).astype("float16"),
"x_1": np.random.uniform(size=[32, 128]).astype("float32"),
}
- shard_info = json.dumps(
- {
- "x_0": 1,
- "x_1": 0,
- }
- )
+ shard_info = {
+ "x_0": [
+ [
+ "tests.disco.shard_dim_1",
+ [(num_shards, 64, 64), "float16"],
+ num_shards,
+ ],
+ ],
+ "x_1": [
+ [
+ "tests.disco.shard_dim_0",
+ [(num_shards, 16, 128), "float32"],
+ num_shards,
+ ]
+ ],
+ }
# pylint: disable=invalid-name
@I.ir_module
@@ -128,13 +188,6 @@ def test_load_shard_in_relax():
def relax_build(mod, target):
with target:
mod = rx.get_pipeline("zero")(mod) # pylint: disable=no-value-for-parameter
- mod = dl.ApplyDefaultSchedule( # pylint: disable=not-callable
- dl.gpu.Matmul(),
- dl.gpu.GEMV(),
- dl.gpu.Reduction(),
- dl.gpu.GeneralReduction(),
- dl.gpu.Fallback(),
- )(mod)
return rx.build(mod, target="cuda")
target = Target(
@@ -149,9 +202,10 @@ def test_load_shard_in_relax():
)
with tempfile.TemporaryDirectory() as tmpdir:
dso_path = tmpdir + "/test.so"
- relax_build(Module, target).export_library(dso_path)
sess = di.ThreadedSession(num_workers=len(devices))
sess.init_ccl("nccl", *devices)
+ relax_build(Module, target).export_library(dso_path)
+
mod = sess.load_vm_module(dso_path)
loader = _create_loader(sess, tmpdir, param_dict, shard_info)
result = mod["main"](loader)
@@ -175,16 +229,27 @@ def test_load_shard_in_relax():
def test_load_shard_all():
devices = [0, 1]
+ num_shards = len(devices)
param_dict = {
"param_0": np.random.uniform(size=[64, 128]).astype("float16"),
"param_1": np.random.uniform(size=[32, 128]).astype("float32"),
}
- shard_info = json.dumps(
- {
- "param_0": 1,
- "param_1": 0,
- }
- )
+ shard_info = {
+ "param_0": [
+ [
+ "tests.disco.shard_dim_1",
+ [(num_shards, 64, 64), "float16"],
+ num_shards,
+ ],
+ ],
+ "param_1": [
+ [
+ "tests.disco.shard_dim_0",
+ [(2, 16, 128), "float32"],
+ num_shards,
+ ]
+ ],
+ }
with tempfile.TemporaryDirectory() as path:
sess = di.ThreadedSession(num_workers=len(devices))
sess.init_ccl("nccl", *devices)
@@ -205,7 +270,7 @@ def test_load_shard_broadcast():
"param_0": np.random.uniform(size=[64, 128]).astype("float16"),
"param_1": np.random.uniform(size=[32, 128]).astype("float32"),
}
- shard_info = "{}"
+ shard_info = {}
with tempfile.TemporaryDirectory() as path:
sess = di.ThreadedSession(num_workers=len(devices))
sess.init_ccl("nccl", *devices)
@@ -220,8 +285,68 @@ def test_load_shard_broadcast():
np.testing.assert_equal(param_dict["param_1"], p_1[1].numpy())
+def test_load_qkv_proj_shard(): # pylint: disable=too-many-locals
+ devices = [0, 1]
+ num_shards = len(devices)
+ q_heads = 8
+ kv_heads = 10
+ head_dim = 10
+ hidden_size = 20
+ w_q = np.random.uniform(size=[q_heads * head_dim, hidden_size]).astype("float16")
+ w_k = np.random.uniform(size=[kv_heads * head_dim, hidden_size]).astype("float16")
+ w_v = np.random.uniform(size=[kv_heads * head_dim, hidden_size]).astype("float16")
+ w_qkv = np.concatenate([w_q, w_k, w_v], axis=0)
+ param_dict = {"w_qkv": w_qkv}
+ np_qkv = np.concatenate(
+ [
+ w_q.reshape((num_shards, q_heads // num_shards, head_dim, hidden_size)),
+ w_k.reshape((num_shards, kv_heads // num_shards, head_dim, hidden_size)),
+ w_v.reshape((num_shards, kv_heads // num_shards, head_dim, hidden_size)),
+ ],
+ axis=1,
+ ).reshape((num_shards, -1, hidden_size))
+
+ shard_info = {
+ "w_qkv": [
+ [
+ "tests.disco.shard_qkv_0",
+ [
+ (num_shards, (q_heads + kv_heads * 2) // num_shards, head_dim, hidden_size),
+ "float16",
+ ],
+ num_shards,
+ q_heads,
+ kv_heads,
+ ],
+ [
+ "tests.disco.shard_qkv_1",
+ [
+ (num_shards, (q_heads + kv_heads * 2) // num_shards * head_dim, hidden_size),
+ "float16",
+ ],
+ ],
+ ],
+ }
+
+ with tempfile.TemporaryDirectory() as path:
+ sess = di.ThreadedSession(num_workers=len(devices))
+ sess.init_ccl("nccl", *devices)
+ loader = _create_loader(sess, path, param_dict, shard_info)
+ loader_load = sess.get_global_func("runtime.disco.ShardLoaderLoad")
+ d_0 = loader_load(loader, ShapeTuple([0]))
+ np.testing.assert_equal(
+ np_qkv[0],
+ d_0.debug_get_from_remote(0).numpy(),
+ )
+ np.testing.assert_equal(
+ np_qkv[1],
+ d_0.debug_get_from_remote(1).numpy(),
+ )
+
+
if __name__ == "__main__":
test_load_shard()
test_load_shard_in_relax()
test_load_shard_all()
test_load_shard_broadcast()
+ test_load_qkv_proj_shard()