You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2023/09/01 15:35:48 UTC
[tvm] branch unity updated: [Runtime] Support Loading Standalone `ndarray-cache.json` (#15654)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 6e2ca87cf8 [Runtime] Support Loading Standalone `ndarray-cache.json` (#15654)
6e2ca87cf8 is described below
commit 6e2ca87cf8fdce10227c6b4349fe77a3cda01537
Author: Junru Shao <ju...@apache.org>
AuthorDate: Fri Sep 1 08:35:41 2023 -0700
[Runtime] Support Loading Standalone `ndarray-cache.json` (#15654)
Prior to this PR, `ndarray-cache.json` was loaded and parsed together with the
concrete weights. This PR adds support for parsing this JSON file into a
structured C++ class so that the metadata can be used later on its own.
---
src/runtime/relax_vm/ndarray_cache_support.cc | 150 +++++++++++++++++---------
src/runtime/relax_vm/ndarray_cache_support.h | 92 ++++++++++++++++
2 files changed, 193 insertions(+), 49 deletions(-)
diff --git a/src/runtime/relax_vm/ndarray_cache_support.cc b/src/runtime/relax_vm/ndarray_cache_support.cc
index 7d1d1fc59e..a19cee42d7 100644
--- a/src/runtime/relax_vm/ndarray_cache_support.cc
+++ b/src/runtime/relax_vm/ndarray_cache_support.cc
@@ -35,12 +35,12 @@
* runtime builtin provide as in this file.
*/
#define PICOJSON_USE_INT64
+#include "./ndarray_cache_support.h"
#include <picojson.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/registry.h>
-#include <sstream>
#include <string>
#include <vector>
@@ -51,6 +51,94 @@ namespace tvm {
namespace runtime {
namespace relax_vm {
+/*!
+ * \brief Convert the JSON description of a single parameter into a ParamRecord.
+ * \param json JSON object holding the keys "name", "shape", "dtype", "format", "nbytes" and
+ *        "byteOffset".
+ * \return The parsed parameter record.
+ * \note Keys are looked up with `at`, so a missing key raises `std::out_of_range`.
+ */
+NDArrayCacheMetadata::FileRecord::ParamRecord JSONAsParamRecord(const picojson::object& json) {
+  std::vector<ShapeTuple::index_type> shape;
+  {
+    // `get` returns a reference into `json`; binding to it avoids copying the whole array.
+    const picojson::array& shape_json = json.at("shape").get<picojson::array>();
+    shape.reserve(shape_json.size());
+    for (const picojson::value& d : shape_json) {
+      shape.push_back(d.get<int64_t>());
+    }
+  }
+  NDArrayCacheMetadata::FileRecord::ParamRecord result;
+  result.name = json.at("name").get<std::string>();
+  result.dtype = DataType(String2DLDataType(json.at("dtype").get<std::string>()));
+  result.format = json.at("format").get<std::string>();
+  result.nbytes = json.at("nbytes").get<int64_t>();
+  result.byte_offset = json.at("byteOffset").get<int64_t>();
+  result.shape = ShapeTuple(std::move(shape));
+  return result;
+}
+
+/*!
+ * \brief Convert the JSON description of one shard file into a FileRecord.
+ * \param json JSON object holding the keys "dataPath", "format", "nbytes" and "records".
+ * \return The parsed file record, with one ParamRecord per entry of "records".
+ * \note Keys are looked up with `at`, so a missing key raises `std::out_of_range`.
+ */
+NDArrayCacheMetadata::FileRecord JSONAsFileRecord(const picojson::object& json) {
+  // `get` returns a reference into `json`; binding to it avoids deep-copying every
+  // parameter object stored in the array.
+  const picojson::array& records = json.at("records").get<picojson::array>();
+  NDArrayCacheMetadata::FileRecord result;
+  result.data_path = json.at("dataPath").get<std::string>();
+  result.format = json.at("format").get<std::string>();
+  result.nbytes = json.at("nbytes").get<int64_t>();
+  result.records.reserve(records.size());
+  for (const picojson::value& item : records) {
+    result.records.push_back(JSONAsParamRecord(item.get<picojson::object>()));
+  }
+  return result;
+}
+
+/*!
+ * \brief Convert the top-level JSON object of the cache metadata into NDArrayCacheMetadata.
+ * \param json JSON object holding the key "records", an array with one entry per shard file.
+ * \return The parsed metadata. File paths are stored exactly as they appear in the JSON.
+ * \note Keys are looked up with `at`, so a missing key raises `std::out_of_range`.
+ */
+NDArrayCacheMetadata JSONAsNDArrayCacheMetadata(const picojson::object& json) {
+  // `get` returns a reference into `json`; binding to it avoids deep-copying the entire
+  // metadata tree before it is converted.
+  const picojson::array& records = json.at("records").get<picojson::array>();
+  NDArrayCacheMetadata result;
+  result.records.reserve(records.size());
+  for (const picojson::value& item : records) {
+    result.records.push_back(JSONAsFileRecord(item.get<picojson::object>()));
+  }
+  return result;
+}
+
+/*!
+ * \brief Read and parse a metadata file such as "ndarray-cache.json".
+ * \param path Location of the JSON file.
+ * \return The parsed metadata, with the directory of `path` prepended to every `data_path`.
+ */
+NDArrayCacheMetadata NDArrayCacheMetadata::LoadFromFile(const std::filesystem::path& path) {
+  // Convert explicitly: the implicit path -> std::string conversion only exists on platforms
+  // whose native path character type is `char`.
+  const std::string path_str = path.string();
+  std::string json_str;
+  LoadBinaryFromFile(path_str, &json_str);
+  picojson::value json_info;
+  // picojson::parse hands back an error message (empty on success); surface it instead of
+  // continuing with a value that was never filled in.
+  const std::string err = picojson::parse(json_info, json_str);
+  CHECK(err.empty()) << "ValueError: Failed to parse JSON file `" << path_str << "`: " << err;
+  CHECK(json_info.is<picojson::object>())
+      << "ValueError: Expected a JSON object at the top level of `" << path_str << "`";
+  NDArrayCacheMetadata result = JSONAsNDArrayCacheMetadata(json_info.get<picojson::object>());
+  // The JSON stores shard paths relative to the metadata file; anchor them to its directory.
+  for (auto& file : result.records) {
+    file.data_path = path.parent_path() / file.data_path;
+  }
+  return result;
+}
+
+/*!
+ * \brief Read a JSON file that lists the shard dimension of each parameter.
+ * \param path Location of the JSON file. Its top level must be an array of objects, each
+ *        holding a string "name" and an integer "shard_dim".
+ * \return Mapping from parameter name to shard dimension. If a name occurs more than once,
+ *         the last entry wins.
+ */
+std::unordered_map<std::string, int> LoadShardInfoFromFile(const std::filesystem::path& path) {
+  // Convert explicitly: the implicit path -> std::string conversion only exists on platforms
+  // whose native path character type is `char`.
+  const std::string path_str = path.string();
+  std::string json_str;
+  LoadBinaryFromFile(path_str, &json_str);
+  picojson::value json_info;
+  // picojson::parse hands back an error message (empty on success); surface it instead of
+  // continuing with a value that was never filled in.
+  const std::string err = picojson::parse(json_info, json_str);
+  CHECK(err.empty()) << "ValueError: Failed to parse JSON file `" << path_str << "`: " << err;
+  CHECK(json_info.is<picojson::array>())
+      << "ValueError: Expected a JSON array at the top level of `" << path_str << "`";
+  std::unordered_map<std::string, int> result;
+  for (const picojson::value& item : json_info.get<picojson::array>()) {
+    // Read through a const reference with `at` rather than copying the object just to be able
+    // to call the non-const `operator[]`.
+    const picojson::object& kv_pair = item.get<picojson::object>();
+    const std::string& name = kv_pair.at("name").get<std::string>();
+    const int shard_dim = static_cast<int>(kv_pair.at("shard_dim").get<int64_t>());
+    result[name] = shard_dim;
+  }
+  return result;
+}
+
+/*!
+ * \brief Materialize this parameter on `device` from the raw bytes of its shard.
+ * \param device Device on which the array is allocated.
+ * \param raw_data Contents of the shard; the parameter occupies the byte range
+ *        [byte_offset, byte_offset + nbytes) of it.
+ * \param f_load Callback that copies the given host bytes into the freshly allocated array.
+ * \return The loaded array.
+ */
+NDArray NDArrayCacheMetadata::FileRecord::ParamRecord::Load(
+    Device device, const std::string* raw_data,
+    std::function<void(NDArray, const void*, int64_t)> f_load) const {
+  ICHECK(raw_data != nullptr);
+  // Reject records that point outside the shard before touching any memory. The comparison
+  // is arranged so that `byte_offset + nbytes` is never computed and thus cannot overflow.
+  const int64_t raw_size = static_cast<int64_t>(raw_data->size());
+  CHECK(byte_offset >= 0 && nbytes >= 0 && byte_offset <= raw_size &&
+        nbytes <= raw_size - byte_offset)
+      << "ValueError: Parameter `" << name << "` claims " << nbytes << " bytes at offset "
+      << byte_offset << ", but the raw data only holds " << raw_size << " bytes";
+  const char* src = raw_data->data() + byte_offset;
+  NDArray arr = NDArray::Empty(shape, dtype, device);
+  if (dtype == DataType::Float(32) && format == "f32-to-bf16") {
+    // The payload is a sequence of 16-bit values; an odd byte count would make the copy
+    // below run one byte past the end of `buffer`.
+    CHECK_EQ(nbytes % 2, 0) << "ValueError: Parameter `" << name
+                            << "` uses format `f32-to-bf16` but has an odd number of bytes";
+    const size_t num_elems = static_cast<size_t>(nbytes) / 2;
+    std::vector<uint16_t> buffer(num_elems);
+    std::vector<uint32_t> decoded(num_elems);
+    // Copy first so that the 16-bit reads do not depend on the alignment of `src`.
+    std::memcpy(buffer.data(), src, num_elems * sizeof(uint16_t));
+    // decode bf16 to f32: each stored value becomes the upper 16 bits of a 32-bit pattern
+    for (size_t i = 0; i < num_elems; ++i) {
+      decoded[i] = static_cast<uint32_t>(buffer[i]) << 16;
+    }
+    f_load(arr, decoded.data(), decoded.size() * sizeof(uint32_t));
+  } else {
+    f_load(arr, src, nbytes);
+  }
+  return arr;
+}
+
/*!
* A NDArray cache to store pre-loaded arrays in the system.
*/
@@ -95,19 +183,14 @@ class NDArrayCache {
*/
static void Load(const std::string& cache_path, int device_type, int device_id) {
DLDevice device{static_cast<DLDeviceType>(device_type), device_id};
- std::string json_str;
- LoadBinaryFromFile(cache_path + "/ndarray-cache.json", &json_str);
- picojson::value json_info;
- picojson::parse(json_info, json_str);
- auto shard_records = json_info.get<picojson::object>()["records"].get<picojson::array>();
+ NDArrayCacheMetadata metadata =
+ NDArrayCacheMetadata::LoadFromFile(cache_path + "/ndarray-cache.json");
- Map<String, NDArray> result;
- std::string raw_data;
Optional<NDArray> staging_buffer;
-
- auto fcopy_param_from_bytes = [&](NDArray param, void* data, size_t nbytes) {
+ auto fcopy_param_from_bytes = [&](NDArray param, const void* data, size_t nbytes) {
if (device_type != kDLOpenCL) {
param.CopyFromBytes(data, nbytes);
+ return;
}
// special handle OpenCL
// OpenCL runtime can create a host side memory mirror
@@ -130,47 +213,16 @@ class NDArrayCache {
TVMSynchronize(device_type, device_id, nullptr);
};
- for (auto shard_item : shard_records) {
- auto shard_rec = shard_item.get<picojson::object>();
- ICHECK(shard_rec["dataPath"].is<std::string>());
- std::string data_path = shard_rec["dataPath"].get<std::string>();
-
- LoadBinaryFromFile(cache_path + "/" + data_path, &raw_data);
- CHECK_EQ(shard_rec["format"].get<std::string>(), "raw-shard");
- int64_t raw_nbytes = shard_rec["nbytes"].get<int64_t>();
- CHECK_EQ(raw_nbytes, raw_data.length())
+ Map<String, NDArray> result;
+ std::string raw_data;
+ for (const auto& shard_rec : metadata.records) {
+ LoadBinaryFromFile(shard_rec.data_path, &raw_data);
+ CHECK_EQ(shard_rec.format, "raw-shard") << "ValueError: Only `raw-shard` format is supported";
+ CHECK_EQ(shard_rec.nbytes, raw_data.length())
<< "ValueError: Parameters are not loaded properly. Please check your parameter shards "
"and git lfs installation";
-
- for (auto nd_item : shard_rec["records"].get<picojson::array>()) {
- auto nd_rec = nd_item.get<picojson::object>();
- CHECK(nd_rec["name"].is<std::string>());
- String name = nd_rec["name"].get<std::string>();
-
- std::vector<int64_t> shape;
- for (auto value : nd_rec["shape"].get<picojson::array>()) {
- shape.push_back(value.get<int64_t>());
- }
-
- DataType dtype(String2DLDataType(nd_rec["dtype"].get<std::string>()));
- std::string encode_format = nd_rec["format"].get<std::string>();
- int64_t offset = nd_rec["byteOffset"].get<int64_t>();
- int64_t nbytes = nd_rec["nbytes"].get<int64_t>();
- NDArray arr = NDArray::Empty(ShapeTuple(shape.begin(), shape.end()), dtype, device);
-
- if (dtype == DataType::Float(32) && encode_format == "f32-to-bf16") {
- // decode bf16 to f32
- std::vector<uint16_t> buffer(nbytes / 2);
- std::vector<uint32_t> decoded(nbytes / 2);
- std::memcpy(buffer.data(), raw_data.data() + offset, nbytes);
- for (size_t i = 0; i < buffer.size(); ++i) {
- decoded[i] = static_cast<uint32_t>(buffer[i]) << 16;
- }
- fcopy_param_from_bytes(arr, decoded.data(), decoded.size() * sizeof(uint32_t));
- } else {
- fcopy_param_from_bytes(arr, raw_data.data() + offset, nbytes);
- }
- Update(name, arr, true);
+ for (const auto& nd_rec : shard_rec.records) {
+ Update(nd_rec.name, nd_rec.Load(device, &raw_data, fcopy_param_from_bytes), true);
}
}
}
diff --git a/src/runtime/relax_vm/ndarray_cache_support.h b/src/runtime/relax_vm/ndarray_cache_support.h
new file mode 100644
index 0000000000..8ae93ce249
--- /dev/null
+++ b/src/runtime/relax_vm/ndarray_cache_support.h
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_RUNTIME_RELAX_VM_NDARRAY_CACHE_SUPPORT_H_
+#define TVM_RUNTIME_RELAX_VM_NDARRAY_CACHE_SUPPORT_H_
+
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/registry.h>
+
+#include <filesystem>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+namespace tvm {
+namespace runtime {
+namespace relax_vm {
+
+/*!
+ * \brief Metadata for NDArray cache, which by default, is named as "ndarray-cache.json".
+ */
+struct NDArrayCacheMetadata {
+  /*! \brief Each shard of NDArray cache, which by default, is named as "params_shard_x.bin". */
+  struct FileRecord {
+    /*! \brief Metadata of each parameter */
+    struct ParamRecord {
+      /*!
+       * \brief Load the parameter from raw data.
+       * \param device The device to load the parameter onto.
+       * \param raw_data The raw data stream
+       * \param f_load The function to load the parameter from raw data. It receives the
+       *        destination array, a pointer to the source bytes and the number of bytes.
+       * \return The loaded parameter.
+       */
+      NDArray Load(Device device, const std::string* raw_data,
+                   std::function<void(NDArray, const void*, int64_t)> f_load) const;
+
+      /*! \brief Name of the parameter */
+      std::string name;
+      /*! \brief Shape of the parameter */
+      ShapeTuple shape;
+      /*! \brief Data type of the parameter */
+      DataType dtype;
+      /*! \brief Format of the parameter */
+      std::string format;
+      /*! \brief Number of bytes; zero until the record is filled in */
+      int64_t nbytes{0};
+      /*! \brief Offset from the raw stream; zero until the record is filled in */
+      int64_t byte_offset{0};
+    };
+
+    /*! \brief Relative path to the bin file */
+    std::filesystem::path data_path;
+    /*! \brief Format of the file */
+    std::string format;
+    /*! \brief Size of the file; zero until the record is filled in */
+    int64_t nbytes{0};
+    /*! \brief The parameters in the file */
+    std::vector<ParamRecord> records;
+  };
+  /*! \brief The files in the NDArray cache */
+  std::vector<FileRecord> records;
+
+  /*! \brief Load the metadata from a specific path */
+  static NDArrayCacheMetadata LoadFromFile(const std::filesystem::path& path);
+};
+
+/*!
+ * \brief Load the shard information from dist
+ * \param path Path to the file to be loaded
+ * \return Mapping from parameter name to its shard dim
+ * \note The file is read as a JSON array of objects, each providing a string "name" and an
+ *       integer "shard_dim".
+ */
+std::unordered_map<std::string, int> LoadShardInfoFromFile(const std::filesystem::path& path);
+
+} // namespace relax_vm
+} // namespace runtime
+} // namespace tvm
+
+#endif // TVM_RUNTIME_RELAX_VM_NDARRAY_CACHE_SUPPORT_H_