You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2023/09/01 15:35:48 UTC

[tvm] branch unity updated: [Runtime] Support Loading Standalone `ndarray-cache.json` (#15654)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 6e2ca87cf8 [Runtime] Support Loading Standalone `ndarray-cache.json` (#15654)
6e2ca87cf8 is described below

commit 6e2ca87cf8fdce10227c6b4349fe77a3cda01537
Author: Junru Shao <ju...@apache.org>
AuthorDate: Fri Sep 1 08:35:41 2023 -0700

    [Runtime] Support Loading Standalone `ndarray-cache.json` (#15654)
    
    Prior to this PR, `ndarray-cache.json` could only be loaded and parsed
    together with the concrete weights. This PR adds support for parsing the
    JSON file on its own into a structured C++ class for later use.
---
 src/runtime/relax_vm/ndarray_cache_support.cc | 150 +++++++++++++++++---------
 src/runtime/relax_vm/ndarray_cache_support.h  |  92 ++++++++++++++++
 2 files changed, 193 insertions(+), 49 deletions(-)

diff --git a/src/runtime/relax_vm/ndarray_cache_support.cc b/src/runtime/relax_vm/ndarray_cache_support.cc
index 7d1d1fc59e..a19cee42d7 100644
--- a/src/runtime/relax_vm/ndarray_cache_support.cc
+++ b/src/runtime/relax_vm/ndarray_cache_support.cc
@@ -35,12 +35,12 @@
  * runtime builtin provide as in this file.
  */
 #define PICOJSON_USE_INT64
+#include "./ndarray_cache_support.h"
 
 #include <picojson.h>
 #include <tvm/runtime/ndarray.h>
 #include <tvm/runtime/registry.h>
 
-#include <sstream>
 #include <string>
 #include <vector>
 
@@ -51,6 +51,94 @@ namespace tvm {
 namespace runtime {
 namespace relax_vm {
 
+NDArrayCacheMetadata::FileRecord::ParamRecord JSONAsParamRecord(const picojson::object& json) {
+  std::vector<ShapeTuple::index_type> shape;
+  {
+    picojson::array shape_json = json.at("shape").get<picojson::array>();
+    shape.reserve(shape_json.size());
+    for (const picojson::value& d : shape_json) {
+      shape.push_back(d.get<int64_t>());
+    }
+  }
+  NDArrayCacheMetadata::FileRecord::ParamRecord result;
+  result.name = json.at("name").get<std::string>();
+  result.dtype = DataType(String2DLDataType(json.at("dtype").get<std::string>()));
+  result.format = json.at("format").get<std::string>();
+  result.nbytes = json.at("nbytes").get<int64_t>();
+  result.byte_offset = json.at("byteOffset").get<int64_t>();
+  result.shape = ShapeTuple(std::move(shape));
+  return result;
+}
+
+NDArrayCacheMetadata::FileRecord JSONAsFileRecord(const picojson::object& json) {
+  picojson::array records = json.at("records").get<picojson::array>();
+  NDArrayCacheMetadata::FileRecord result;
+  result.data_path = json.at("dataPath").get<std::string>();
+  result.format = json.at("format").get<std::string>();
+  result.nbytes = json.at("nbytes").get<int64_t>();
+  result.records.reserve(records.size());
+  for (const picojson::value& item : records) {
+    result.records.push_back(JSONAsParamRecord(item.get<picojson::object>()));
+  }
+  return result;
+}
+
+NDArrayCacheMetadata JSONAsNDArrayCacheMetadata(const picojson::object& json) {
+  picojson::array records = json.at("records").get<picojson::array>();
+  NDArrayCacheMetadata result;
+  result.records.reserve(records.size());
+  for (const picojson::value& item : records) {
+    result.records.push_back(JSONAsFileRecord(item.get<picojson::object>()));
+  }
+  return result;
+}
+
+NDArrayCacheMetadata NDArrayCacheMetadata::LoadFromFile(const std::filesystem::path& path) {
+  std::string json_str;
+  LoadBinaryFromFile(path, &json_str);
+  picojson::value json_info;
+  picojson::parse(json_info, json_str);
+  NDArrayCacheMetadata result = JSONAsNDArrayCacheMetadata(json_info.get<picojson::object>());
+  for (auto& file : result.records) {
+    file.data_path = path.parent_path() / file.data_path;
+  }
+  return result;
+}
+
+std::unordered_map<std::string, int> LoadShardInfoFromFile(const std::filesystem::path& path) {
+  std::string json_str;
+  LoadBinaryFromFile(path, &json_str);
+  picojson::value json_info;
+  picojson::parse(json_info, json_str);
+  std::unordered_map<std::string, int> result;
+  for (const auto& item : json_info.get<picojson::array>()) {
+    picojson::object kv_pair = item.get<picojson::object>();
+    std::string name = kv_pair["name"].get<std::string>();
+    int shard_dim = kv_pair["shard_dim"].get<int64_t>();
+    result[name] = shard_dim;
+  }
+  return result;
+}
+
+NDArray NDArrayCacheMetadata::FileRecord::ParamRecord::Load(
+    Device device, const std::string* raw_data,
+    std::function<void(NDArray, const void*, int64_t)> f_load) const {
+  NDArray arr = NDArray::Empty(shape, dtype, device);
+  if (dtype == DataType::Float(32) && format == "f32-to-bf16") {
+    // decode bf16 to f32
+    std::vector<uint16_t> buffer(nbytes / 2);
+    std::vector<uint32_t> decoded(nbytes / 2);
+    std::memcpy(buffer.data(), raw_data->data() + byte_offset, nbytes);
+    for (size_t i = 0; i < buffer.size(); ++i) {
+      decoded[i] = static_cast<uint32_t>(buffer[i]) << 16;
+    }
+    f_load(arr, decoded.data(), decoded.size() * sizeof(uint32_t));
+  } else {
+    f_load(arr, raw_data->data() + byte_offset, nbytes);
+  }
+  return arr;
+}
+
 /*!
  * A NDArray cache to store pre-loaded arrays in the system.
  */
@@ -95,19 +183,14 @@ class NDArrayCache {
    */
   static void Load(const std::string& cache_path, int device_type, int device_id) {
     DLDevice device{static_cast<DLDeviceType>(device_type), device_id};
-    std::string json_str;
-    LoadBinaryFromFile(cache_path + "/ndarray-cache.json", &json_str);
-    picojson::value json_info;
-    picojson::parse(json_info, json_str);
-    auto shard_records = json_info.get<picojson::object>()["records"].get<picojson::array>();
+    NDArrayCacheMetadata metadata =
+        NDArrayCacheMetadata::LoadFromFile(cache_path + "/ndarray-cache.json");
 
-    Map<String, NDArray> result;
-    std::string raw_data;
     Optional<NDArray> staging_buffer;
-
-    auto fcopy_param_from_bytes = [&](NDArray param, void* data, size_t nbytes) {
+    auto fcopy_param_from_bytes = [&](NDArray param, const void* data, size_t nbytes) {
       if (device_type != kDLOpenCL) {
         param.CopyFromBytes(data, nbytes);
+        return;
       }
       // special handle OpenCL
       // OpenCL runtime can create a host side memory mirror
@@ -130,47 +213,16 @@ class NDArrayCache {
       TVMSynchronize(device_type, device_id, nullptr);
     };
 
-    for (auto shard_item : shard_records) {
-      auto shard_rec = shard_item.get<picojson::object>();
-      ICHECK(shard_rec["dataPath"].is<std::string>());
-      std::string data_path = shard_rec["dataPath"].get<std::string>();
-
-      LoadBinaryFromFile(cache_path + "/" + data_path, &raw_data);
-      CHECK_EQ(shard_rec["format"].get<std::string>(), "raw-shard");
-      int64_t raw_nbytes = shard_rec["nbytes"].get<int64_t>();
-      CHECK_EQ(raw_nbytes, raw_data.length())
+    Map<String, NDArray> result;
+    std::string raw_data;
+    for (const auto& shard_rec : metadata.records) {
+      LoadBinaryFromFile(shard_rec.data_path, &raw_data);
+      CHECK_EQ(shard_rec.format, "raw-shard") << "ValueError: Only `raw-shard` format is supported";
+      CHECK_EQ(shard_rec.nbytes, raw_data.length())
           << "ValueError: Parameters are not loaded properly. Please check your parameter shards "
              "and git lfs installation";
-
-      for (auto nd_item : shard_rec["records"].get<picojson::array>()) {
-        auto nd_rec = nd_item.get<picojson::object>();
-        CHECK(nd_rec["name"].is<std::string>());
-        String name = nd_rec["name"].get<std::string>();
-
-        std::vector<int64_t> shape;
-        for (auto value : nd_rec["shape"].get<picojson::array>()) {
-          shape.push_back(value.get<int64_t>());
-        }
-
-        DataType dtype(String2DLDataType(nd_rec["dtype"].get<std::string>()));
-        std::string encode_format = nd_rec["format"].get<std::string>();
-        int64_t offset = nd_rec["byteOffset"].get<int64_t>();
-        int64_t nbytes = nd_rec["nbytes"].get<int64_t>();
-        NDArray arr = NDArray::Empty(ShapeTuple(shape.begin(), shape.end()), dtype, device);
-
-        if (dtype == DataType::Float(32) && encode_format == "f32-to-bf16") {
-          // decode bf16 to f32
-          std::vector<uint16_t> buffer(nbytes / 2);
-          std::vector<uint32_t> decoded(nbytes / 2);
-          std::memcpy(buffer.data(), raw_data.data() + offset, nbytes);
-          for (size_t i = 0; i < buffer.size(); ++i) {
-            decoded[i] = static_cast<uint32_t>(buffer[i]) << 16;
-          }
-          fcopy_param_from_bytes(arr, decoded.data(), decoded.size() * sizeof(uint32_t));
-        } else {
-          fcopy_param_from_bytes(arr, raw_data.data() + offset, nbytes);
-        }
-        Update(name, arr, true);
+      for (const auto& nd_rec : shard_rec.records) {
+        Update(nd_rec.name, nd_rec.Load(device, &raw_data, fcopy_param_from_bytes), true);
       }
     }
   }
diff --git a/src/runtime/relax_vm/ndarray_cache_support.h b/src/runtime/relax_vm/ndarray_cache_support.h
new file mode 100644
index 0000000000..8ae93ce249
--- /dev/null
+++ b/src/runtime/relax_vm/ndarray_cache_support.h
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_RUNTIME_RELAX_VM_NDARRAY_CACHE_SUPPORT_H_
+#define TVM_RUNTIME_RELAX_VM_NDARRAY_CACHE_SUPPORT_H_
+
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/registry.h>
+
+#include <filesystem>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+namespace tvm {
+namespace runtime {
+namespace relax_vm {
+
+/*!
+ * \brief Metadata for NDArray cache, which by default, is named as "ndarray-cache.json".
+ */
+struct NDArrayCacheMetadata {
+  /*! \brief Each shard of NDArray cache, which by default, is named as "params_shard_x.bin". */
+  struct FileRecord {
+    /*! \brief Metadata of each parameter */
+    struct ParamRecord {
+      /*!
+       * \brief Load the parameter from raw data.
+       * \param device The device to load the parameter onto.
+       * \param raw_data The raw data stream
+       * \param f_load The function to load the parameter from raw data.
+       */
+      NDArray Load(Device device, const std::string* raw_data,
+                   std::function<void(NDArray, const void*, int64_t)> f_load) const;
+
+      /*! \brief Name of the parameter */
+      std::string name;
+      /*! \brief Shape of the parameter */
+      ShapeTuple shape;
+      /*! \brief Data type of the parameter */
+      DataType dtype;
+      /*! \brief Format of the parameter */
+      std::string format;
+      /*! \brief Number of bytes */
+      int64_t nbytes;
+      /*! \brief Offset from the raw stream */
+      int64_t byte_offset;
+    };
+
+    /*! \brief Relative path to the bin file */
+    std::filesystem::path data_path;
+    /*! \brief Format of the file */
+    std::string format;
+    /*! \brief Size of the file */
+    int64_t nbytes;
+    /*! \brief The parameters in the file */
+    std::vector<ParamRecord> records;
+  };
+  /*! \brief The files in the NDArray cache */
+  std::vector<FileRecord> records;
+
+  /*! \brief Load the metadata from a specific path */
+  static NDArrayCacheMetadata LoadFromFile(const std::filesystem::path& path);
+};
+
+/*!
+ * \brief Load the shard information from dist
+ * \param path Path to the file to be loaded
+ * \return Mapping from parameter name to its shard dim
+ */
+std::unordered_map<std::string, int> LoadShardInfoFromFile(const std::filesystem::path& path);
+
+}  // namespace relax_vm
+}  // namespace runtime
+}  // namespace tvm
+
+#endif  // TVM_RUNTIME_RELAX_VM_NDARRAY_CACHE_SUPPORT_H_