You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2023/08/28 14:10:41 UTC
[tvm] branch main updated: [Runtime] Serialization/Deserialization of runtime module (#15244)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 8f60213cb5 [Runtime] Serialization/Deserialization of runtime module (#15244)
8f60213cb5 is described below
commit 8f60213cb5a75cc87bb39d2645a9594363cb70e1
Author: Sunghyun Park <su...@umich.edu>
AuthorDate: Mon Aug 28 07:10:33 2023 -0700
[Runtime] Serialization/Deserialization of runtime module (#15244)
---
include/tvm/runtime/module.h | 5 +
include/tvm/target/codegen.h | 16 +++
python/tvm/contrib/torch/optimize_torch.py | 24 ++--
python/tvm/runtime/module.py | 1 -
src/contrib/torch/base64.h | 75 -----------
.../tvm_module_wrapper/RuntimeModuleWrapperTVM.cc | 143 +++++++++++++--------
src/node/structural_hash.cc | 17 +++
src/runtime/library_module.cc | 9 --
src/runtime/library_module.h | 10 ++
src/target/codegen.cc | 87 ++++++++++---
.../unittest/test_roundtrip_runtime_module.py | 121 +++++++++++++++++
11 files changed, 344 insertions(+), 164 deletions(-)
diff --git a/include/tvm/runtime/module.h b/include/tvm/runtime/module.h
index 60e3535319..cb88bc53f9 100644
--- a/include/tvm/runtime/module.h
+++ b/include/tvm/runtime/module.h
@@ -232,6 +232,11 @@ class TVM_DLL ModuleNode : public Object {
return (GetPropertyMask() & ModulePropertyMask::kDSOExportable) != 0;
}
+ /*! \brief Returns true if this module is 'Binary Serializable'. */
+ bool IsBinarySerializable() const {
+ return (GetPropertyMask() & ModulePropertyMask::kBinarySerializable) != 0;
+ }
+
/*!
* \brief Returns true if this module has a definition for a function of \p name. If
* \p query_imports is true, also search in any imported modules.
diff --git a/include/tvm/target/codegen.h b/include/tvm/target/codegen.h
index 46a19ad71b..0490e14e22 100644
--- a/include/tvm/target/codegen.h
+++ b/include/tvm/target/codegen.h
@@ -47,6 +47,21 @@ using runtime::TVMRetValue;
*/
runtime::Module Build(IRModule mod, Target target);
+/*!
+ * \brief Serialize runtime module including its submodules
+ * \param mod The runtime module to serialize including its import tree.
+ * \param export_dso By default, include the info of DSOExportable modules. If disabled, an error
+ * will be raised when encountering DSO modules.
+ */
+std::string SerializeModuleToBytes(const runtime::Module& mod, bool export_dso = true);
+
+/*!
+ * \brief Deserialize runtime module including its submodules
+ * \param blob The byte stream generated by `SerializeModuleToBytes`.
+ * \return runtime::Module runtime module constructed from the given stream
+ */
+runtime::Module DeserializeModuleFromBytes(std::string blob);
+
/*!
* \brief Pack imported device library to a C file.
* Compile the C file and link with the host library
@@ -77,6 +92,7 @@ std::string PackImportsToC(const runtime::Module& m, bool system_lib,
runtime::Module PackImportsToLLVM(const runtime::Module& m, bool system_lib,
const std::string& target_triple,
const std::string& c_symbol_prefix = "");
+
} // namespace codegen
} // namespace tvm
#endif // TVM_TARGET_CODEGEN_H_
diff --git a/python/tvm/contrib/torch/optimize_torch.py b/python/tvm/contrib/torch/optimize_torch.py
index cbba590e85..dfe35f2aae 100644
--- a/python/tvm/contrib/torch/optimize_torch.py
+++ b/python/tvm/contrib/torch/optimize_torch.py
@@ -25,17 +25,19 @@ optimize_torch: a function similar to `torch.jit.trace`,
which is used to optimize the `torch.nn.module` by TVM metaSchedule,
and returns a custom TorchScript operator
"""
-import base64
+
import contextlib
import tempfile
from typing import Optional, Tuple, Union
-
+import base64
import torch
import torch.utils.dlpack
import tvm
+import tvm._ffi
+from tvm._ffi import register_func
from tvm import meta_schedule as ms
from tvm import relay
-from tvm._ffi import get_global_func, register_func
+from tvm._ffi import get_global_func
from tvm.target import Target
@@ -51,14 +53,6 @@ class GraphExecutorFactoryWrapper(torch.nn.Module):
return ret
-@register_func("script_torch.save_to_base64")
-def save_to_base64(obj) -> bytes:
- with tempfile.NamedTemporaryFile(suffix=".so") as tmpfile:
- obj.export_library(tmpfile.name)
- with open(tmpfile.name, "rb") as temp_file:
- return base64.b64encode(temp_file.read())
-
-
def optimize_torch(
func,
example_inputs,
@@ -173,3 +167,11 @@ def optimize_torch(
save_runtime_mod(executor_factory.module)
return GraphExecutorFactoryWrapper(torch.classes.tvm_torch.GraphExecutorFactoryWrapper())
+
+
+@register_func("export_runtime_module")
+def save_to_base64(obj) -> bytes:
+ with tempfile.NamedTemporaryFile(suffix=".so") as tmpfile:
+ obj.export_library(tmpfile.name)
+ with open(tmpfile.name, "rb") as temp_file:
+ return base64.b64encode(temp_file.read())
diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py
index c9e3eb6add..2a1db2cbb2 100644
--- a/python/tvm/runtime/module.py
+++ b/python/tvm/runtime/module.py
@@ -23,7 +23,6 @@ import struct
from typing import Sequence
import numpy as np
-import tvm._ffi
from tvm._ffi.base import _LIB, check_call, c_str, string_types, _RUNTIME_ONLY
from tvm._ffi.libinfo import find_include_path
from .packed_func import PackedFunc, PackedFuncHandle, _set_class_module
diff --git a/src/contrib/torch/base64.h b/src/contrib/torch/base64.h
deleted file mode 100644
index d7dac4b86c..0000000000
--- a/src/contrib/torch/base64.h
+++ /dev/null
@@ -1,75 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file base64.h
- * \brief Util functions for converting plain bytes back to plain bytes
- */
-
-#ifndef TVM_CONTRIB_TORCH_BASE64_H_
-#define TVM_CONTRIB_TORCH_BASE64_H_
-
-#include <tvm/runtime/logging.h>
-
-#include <cctype>
-#include <cstdio>
-#include <string>
-
-#include "../../support/base64.h"
-
-namespace tvm {
-namespace support {
-
-inline size_t b64strlen(const std::string b64str) {
- ICHECK(b64str.size() % 4 == 0) << "invalid base64 encoding";
- size_t length = b64str.size() / 4 * 3;
- if (b64str[b64str.size() - 2] == '=') {
- length -= 2;
- } else if (b64str[b64str.size() - 1] == '=') {
- length -= 1;
- }
- return length;
-}
-
-inline void b64decode(const std::string b64str, u_char* ret) {
- size_t index = 0;
- const auto length = b64str.size();
- for (size_t i = 0; i < length; i += 4) {
- int8_t ch0 = base64::DecodeTable[(int32_t)b64str[i]];
- int8_t ch1 = base64::DecodeTable[(int32_t)b64str[i + 1]];
- int8_t ch2 = base64::DecodeTable[(int32_t)b64str[i + 2]];
- int8_t ch3 = base64::DecodeTable[(int32_t)b64str[i + 3]];
- u_char st1 = (ch0 << 2) + (ch1 >> 4);
- ret[index++] = st1;
- if (b64str[i + 2] != '=') {
- u_char st2 = ((ch1 & 0b1111) << 4) + (ch2 >> 2);
- ret[index++] = st2;
- if (b64str[i + 3] != '=') {
- u_char st3 = ((ch2 & 0b11) << 6) + ch3;
- ret[index++] = st3;
- }
- }
- }
- ICHECK(b64strlen(b64str) == index) << "base64 decoding fails";
-}
-
-} // namespace support
-} // namespace tvm
-
-#endif // TVM_CONTRIB_TORCH_BASE64_H_
diff --git a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc
index fb570c163f..c77996cf67 100644
--- a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc
+++ b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc
@@ -29,7 +29,7 @@
#include <vector>
#include "../../../runtime/graph_executor/graph_executor_factory.h"
-#include "../base64.h"
+#include "../../support/base64.h"
#include "runtime_bridge.h"
namespace tvm {
@@ -46,54 +46,6 @@ struct ThreadLocalStore {
}
};
-/*
- * Encode TVM runtime module to base64 stream
- */
-std::string serialize(tvm::runtime::Module module) {
- static const runtime::PackedFunc* f_to_str =
- runtime::Registry::Get("script_torch.save_to_base64");
- ICHECK(f_to_str) << "IndexError: Cannot find the packed function "
- "`script_torch.save_to_base64` in the global registry";
- return (*f_to_str)(module);
-}
-
-struct Deleter { // deleter
- explicit Deleter(std::string file_name) { this->file_name = file_name; }
- void operator()(FILE* p) const {
- fclose(p);
- ICHECK(remove(file_name.c_str()) == 0)
- << "remove temporary file (" << file_name << ") unsuccessfully";
- }
- std::string file_name;
-};
-
-/*
- * Decode TVM runtime module from base64 stream
- */
-tvm::runtime::Module deserialize(std::string state) {
- auto length = tvm::support::b64strlen(state);
-
- std::vector<u_char> bytes(length); // bytes stream
- tvm::support::b64decode(state, bytes.data());
-
- const std::string name = tmpnam(NULL);
- auto file_name = name + ".so";
- std::unique_ptr<FILE, Deleter> pFile(fopen(file_name.c_str(), "wb"), Deleter(file_name));
- fwrite(bytes.data(), sizeof(u_char), length, pFile.get());
- fflush(pFile.get());
-
- std::string load_f_name = "runtime.module.loadfile_so";
- const PackedFunc* f = runtime::Registry::Get(load_f_name);
- ICHECK(f != nullptr) << "Loader for `.so` files is not registered,"
- << " resolved to (" << load_f_name << ") in the global registry."
- << "Ensure that you have loaded the correct runtime code, and"
- << "that you are on the correct hardware architecture.";
-
- tvm::runtime::Module ret = (*f)(file_name, "");
-
- return ret;
-}
-
TVM_REGISTER_GLOBAL("tvmtorch.save_runtime_mod").set_body_typed([](tvm::runtime::Module mod) {
ThreadLocalStore::ThreadLocal()->mod = mod;
});
@@ -242,15 +194,104 @@ size_t tvm_contrib_torch_graph_executor_module_forward(TVMContribTorchRuntimeMod
return output_length;
}
+inline size_t b64strlen(const std::string b64str) {
+ ICHECK(b64str.size() % 4 == 0) << "invalid base64 encoding";
+ size_t length = b64str.size() / 4 * 3;
+ if (b64str[b64str.size() - 2] == '=') {
+ length -= 2;
+ } else if (b64str[b64str.size() - 1] == '=') {
+ length -= 1;
+ }
+ return length;
+}
+
+inline void b64decode(const std::string b64str, uint8_t* ret) {
+ size_t index = 0;
+ const auto length = b64str.size();
+ for (size_t i = 0; i < length; i += 4) {
+ int8_t ch0 = base64::DecodeTable[(int32_t)b64str[i]];
+ int8_t ch1 = base64::DecodeTable[(int32_t)b64str[i + 1]];
+ int8_t ch2 = base64::DecodeTable[(int32_t)b64str[i + 2]];
+ int8_t ch3 = base64::DecodeTable[(int32_t)b64str[i + 3]];
+ uint8_t st1 = (ch0 << 2) + (ch1 >> 4);
+ ret[index++] = st1;
+ if (b64str[i + 2] != '=') {
+ uint8_t st2 = ((ch1 & 0b1111) << 4) + (ch2 >> 2);
+ ret[index++] = st2;
+ if (b64str[i + 3] != '=') {
+ uint8_t st3 = ((ch2 & 0b11) << 6) + ch3;
+ ret[index++] = st3;
+ }
+ }
+ }
+ ICHECK(b64strlen(b64str) == index) << "base64 decoding fails";
+}
+
+/*!
+ * \brief Export TVM runtime module to base64 stream including its submodules.
+ * Note that this targets modules that are binary serializable and DSOExportable.
+ * \param module The runtime module to export
+ * \return std::string The content of exported file
+ */
+std::string ExportModuleToBase64(tvm::runtime::Module module) {
+ static const tvm::runtime::PackedFunc* f_to_str =
+ tvm::runtime::Registry::Get("export_runtime_module");
+ ICHECK(f_to_str) << "IndexError: Cannot find the packed function "
+ "`export_runtime_module` in the global registry";
+ return (*f_to_str)(module);
+}
+
+struct Deleter { // deleter
+ explicit Deleter(std::string file_name) { this->file_name = file_name; }
+ void operator()(FILE* p) const {
+ fclose(p);
+ ICHECK(remove(file_name.c_str()) == 0)
+ << "remove temporary file (" << file_name << ") unsuccessfully";
+ }
+ std::string file_name;
+};
+
+/*!
+ * \brief Import TVM runtime module from base64 stream
+ * Note that this targets modules that are binary serializable and DSOExportable.
+ * \param base64str The base64 stream generated by `ExportModuleToBase64`.
+ * \return runtime::Module runtime module constructed from the given stream
+ */
+tvm::runtime::Module ImportModuleFromBase64(std::string base64str) {
+ auto length = b64strlen(base64str);
+
+ std::vector<uint8_t> bytes(length); // bytes stream
+ b64decode(base64str, bytes.data());
+
+ auto now = std::chrono::system_clock::now();
+ auto in_time_t = std::chrono::system_clock::to_time_t(now);
+ std::stringstream datetime;
+ datetime << std::put_time(std::localtime(&in_time_t), "%Y-%m-%d-%X");
+ const std::string file_name = "tmp-module-" + datetime.str() + ".so";
+ LOG(INFO) << file_name;
+ std::unique_ptr<FILE, Deleter> pFile(fopen(file_name.c_str(), "wb"), Deleter(file_name));
+ fwrite(bytes.data(), sizeof(uint8_t), length, pFile.get());
+ fflush(pFile.get());
+
+ std::string load_f_name = "runtime.module.loadfile_so";
+ const tvm::runtime::PackedFunc* f = tvm::runtime::Registry::Get(load_f_name);
+ ICHECK(f != nullptr) << "Loader for `.so` files is not registered,"
+ << " resolved to (" << load_f_name << ") in the global registry."
+ << "Ensure that you have loaded the correct runtime code, and"
+ << "that you are on the correct hardware architecture.";
+ tvm::runtime::Module ret = (*f)(file_name, "");
+ return ret;
+}
+
char* tvm_contrib_torch_encode(TVMContribTorchRuntimeModule* runtime_module) {
- std::string std = tvm::contrib::serialize(runtime_module->mod);
+ std::string std = ExportModuleToBase64(runtime_module->mod);
char* ret = new char[std.length() + 1];
snprintf(ret, std.length() + 1, "%s", std.c_str());
return ret;
}
TVMContribTorchRuntimeModule* tvm_contrib_torch_decode(const char* state) {
- tvm::runtime::Module ret = tvm::contrib::deserialize(state);
+ tvm::runtime::Module ret = ImportModuleFromBase64(state);
return new TVMContribTorchRuntimeModule(ret);
}
diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc
index 6cf796d344..9643480292 100644
--- a/src/node/structural_hash.cc
+++ b/src/node/structural_hash.cc
@@ -28,6 +28,7 @@
#include <tvm/runtime/container/adt.h>
#include <tvm/runtime/profiling.h>
#include <tvm/runtime/registry.h>
+#include <tvm/target/codegen.h>
#include <algorithm>
#include <unordered_map>
@@ -360,6 +361,22 @@ struct ADTObjTrait {
TVM_REGISTER_REFLECTION_VTABLE(runtime::ADTObj, ADTObjTrait);
+struct ModuleNodeTrait {
+ static constexpr const std::nullptr_t VisitAttrs = nullptr;
+ static constexpr const std::nullptr_t SHashReduce = nullptr;
+ static constexpr const std::nullptr_t SEqualReduce = nullptr;
+};
+
+TVM_REGISTER_REFLECTION_VTABLE(runtime::ModuleNode, ModuleNodeTrait)
+ .set_creator([](const std::string& blob) {
+ runtime::Module rtmod = codegen::DeserializeModuleFromBytes(blob);
+ return RefToObjectPtr::Get(rtmod);
+ })
+ .set_repr_bytes([](const Object* n) -> std::string {
+ const auto* rtmod = static_cast<const runtime::ModuleNode*>(n);
+ return codegen::SerializeModuleToBytes(GetRef<runtime::Module>(rtmod), /*export_dso*/ false);
+ });
+
void NDArrayHash(const runtime::NDArray::Container* arr, SHashReducer* hash_reduce,
bool hash_data) {
ICHECK_EQ(arr->dl_tensor.device.device_type, kDLCPU) << "can only compare CPU tensor";
diff --git a/src/runtime/library_module.cc b/src/runtime/library_module.cc
index eb5e85beb5..a1a86d0388 100644
--- a/src/runtime/library_module.cc
+++ b/src/runtime/library_module.cc
@@ -67,15 +67,6 @@ class LibraryModuleNode final : public ModuleNode {
PackedFuncWrapper packed_func_wrapper_;
};
-/*!
- * \brief Helper classes to get into internal of a module.
- */
-class ModuleInternal {
- public:
- // Get mutable reference of imports.
- static std::vector<Module>* GetImportsAddr(ModuleNode* node) { return &(node->imports_); }
-};
-
PackedFunc WrapPackedFunc(TVMBackendPackedCFunc faddr, const ObjectPtr<Object>& sptr_to_self) {
return PackedFunc([faddr, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
TVMValue ret_value;
diff --git a/src/runtime/library_module.h b/src/runtime/library_module.h
index 167e819601..d4d32abe21 100644
--- a/src/runtime/library_module.h
+++ b/src/runtime/library_module.h
@@ -30,6 +30,7 @@
#include <functional>
#include <string>
+#include <vector>
namespace tvm {
namespace runtime {
@@ -78,6 +79,15 @@ PackedFunc WrapPackedFunc(TVMBackendPackedCFunc faddr, const ObjectPtr<Object>&
*/
void InitContextFunctions(std::function<void*(const char*)> fgetsymbol);
+/*!
+ * \brief Helper class for accessing the internals of a module.
+ */
+class ModuleInternal {
+ public:
+ // Get mutable reference of imports.
+ static std::vector<Module>* GetImportsAddr(ModuleNode* node) { return &(node->imports_); }
+};
+
/*!
* \brief Type alias for function to wrap a TVMBackendPackedCFunc.
* \param The function address imported from a module.
diff --git a/src/target/codegen.cc b/src/target/codegen.cc
index 6e31db4f60..d1f2d4a479 100644
--- a/src/target/codegen.cc
+++ b/src/target/codegen.cc
@@ -37,6 +37,9 @@
#include <unordered_set>
#include <vector>
+#include "../runtime/library_module.h"
+#include "../support/base64.h"
+
namespace tvm {
namespace codegen {
@@ -76,13 +79,16 @@ class ModuleSerializer {
public:
explicit ModuleSerializer(runtime::Module mod) : mod_(mod) { Init(); }
- void SerializeModule(dmlc::Stream* stream) {
+ void SerializeModuleToBytes(dmlc::Stream* stream, bool export_dso) {
// Only have one DSO module and it is in the root, then
// we will not produce import_tree_.
bool has_import_tree = true;
- if (mod_->IsDSOExportable() && mod_->imports().empty()) {
- has_import_tree = false;
+
+ if (mod_->IsDSOExportable()) {
+ ICHECK(export_dso) << "`export_dso` should be enabled for DSOExportable modules";
+ has_import_tree = !mod_->imports().empty();
}
+
uint64_t sz = 0;
if (has_import_tree) {
// we will append one key for _import_tree
@@ -96,17 +102,26 @@ class ModuleSerializer {
for (const auto& group : mod_group_vec_) {
ICHECK_NE(group.size(), 0) << "Every allocated group must have at least one module";
- if (!group[0]->IsDSOExportable()) {
+ // we prioritize export dso when a module is both serializable and exportable
+ if (export_dso) {
+ if (group[0]->IsDSOExportable()) {
+ if (has_import_tree) {
+ std::string mod_type_key = "_lib";
+ stream->Write(mod_type_key);
+ }
+ } else if (group[0]->IsBinarySerializable()) {
+ ICHECK_EQ(group.size(), 1U) << "Non DSO module is never merged";
+ std::string mod_type_key = group[0]->type_key();
+ stream->Write(mod_type_key);
+ group[0]->SaveToBinary(stream);
+ }
+ } else {
+ ICHECK(group[0]->IsBinarySerializable())
+ << group[0]->type_key() << " is not binary serializable.";
ICHECK_EQ(group.size(), 1U) << "Non DSO module is never merged";
std::string mod_type_key = group[0]->type_key();
stream->Write(mod_type_key);
group[0]->SaveToBinary(stream);
- } else {
- // DSOExportable: do not need binary
- if (has_import_tree) {
- std::string mod_type_key = "_lib";
- stream->Write(mod_type_key);
- }
}
}
@@ -240,22 +255,60 @@ class ModuleSerializer {
std::vector<uint64_t> import_tree_child_indices_;
};
-namespace {
-std::string SerializeModule(const runtime::Module& mod) {
+std::string SerializeModuleToBytes(const runtime::Module& mod, bool export_dso) {
std::string bin;
dmlc::MemoryStringStream ms(&bin);
dmlc::Stream* stream = &ms;
ModuleSerializer module_serializer(mod);
- module_serializer.SerializeModule(stream);
-
+ module_serializer.SerializeModuleToBytes(stream, export_dso);
return bin;
}
-} // namespace
+
+runtime::Module DeserializeModuleFromBytes(std::string blob) {
+ dmlc::MemoryStringStream ms(&blob);
+ dmlc::Stream* stream = &ms;
+
+ uint64_t size;
+ ICHECK(stream->Read(&size));
+ std::vector<runtime::Module> modules;
+ std::vector<uint64_t> import_tree_row_ptr;
+ std::vector<uint64_t> import_tree_child_indices;
+
+ for (uint64_t i = 0; i < size; ++i) {
+ std::string tkey;
+ ICHECK(stream->Read(&tkey));
+ // "_lib" serves as a placeholder in the module import tree to indicate where
+ // to place the DSOModule
+ ICHECK(tkey != "_lib") << "Should not contain any placeholder for DSOModule.";
+ if (tkey == "_import_tree") {
+ ICHECK(stream->Read(&import_tree_row_ptr));
+ ICHECK(stream->Read(&import_tree_child_indices));
+ } else {
+ auto m = runtime::LoadModuleFromBinary(tkey, stream);
+ modules.emplace_back(m);
+ }
+ }
+
+ for (size_t i = 0; i < modules.size(); ++i) {
+ for (size_t j = import_tree_row_ptr[i]; j < import_tree_row_ptr[i + 1]; ++j) {
+ auto module_import_addr = runtime::ModuleInternal::GetImportsAddr(modules[i].operator->());
+ auto child_index = import_tree_child_indices[j];
+ ICHECK(child_index < modules.size());
+ module_import_addr->emplace_back(modules[child_index]);
+ }
+ }
+
+ ICHECK(!modules.empty()) << "modules cannot be empty when import tree is present";
+ // Invariant: the root module is always at index 0.
+ // The module order is collected via DFS
+ runtime::Module root_mod = modules[0];
+ return root_mod;
+}
std::string PackImportsToC(const runtime::Module& mod, bool system_lib,
const std::string& c_symbol_prefix) {
- std::string bin = SerializeModule(mod);
+ std::string bin = SerializeModuleToBytes(mod);
std::string mdev_blob_name = c_symbol_prefix + runtime::symbol::tvm_dev_mblob;
if (c_symbol_prefix.length() != 0) {
@@ -317,7 +370,7 @@ runtime::Module PackImportsToLLVM(const runtime::Module& mod, bool system_lib,
<< "c_symbol_prefix advanced option should be used in conjuction with system-lib";
}
- std::string bin = SerializeModule(mod);
+ std::string bin = SerializeModuleToBytes(mod);
uint64_t nbytes = bin.length();
std::string header;
diff --git a/tests/python/unittest/test_roundtrip_runtime_module.py b/tests/python/unittest/test_roundtrip_runtime_module.py
new file mode 100644
index 0000000000..6a1abeedd9
--- /dev/null
+++ b/tests/python/unittest/test_roundtrip_runtime_module.py
@@ -0,0 +1,121 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+""" Test roundtrip of runtime modules """
+# pylint: disable=missing-docstring
+
+import pytest
+import tvm
+import tvm.testing
+from tvm import TVMError
+from tvm import relay
+
+
+def test_csource_module():
+ mod = tvm.runtime._ffi_api.CSourceModuleCreate("", "cc", [], None)
+ # The C source module is not binary serializable,
+ # so the JSON roundtrip is expected to raise an error.
+ assert not mod.is_binary_serializable
+ with pytest.raises(TVMError):
+ tvm.ir.load_json(tvm.ir.save_json(mod))
+
+
+def test_aot_module():
+ mod = tvm.get_global_func("relay.build_module._AOTExecutorCodegen")()
+ # The AOT module is not binary serializable,
+ # so the JSON roundtrip is expected to raise an error.
+ assert not mod.is_binary_serializable
+ with pytest.raises(TVMError):
+ tvm.ir.load_json(tvm.ir.save_json(mod))
+
+
+def get_test_mod():
+ x = relay.var("x", shape=(1, 10), dtype="float32")
+ y = relay.var("y", shape=(1, 10), dtype="float32")
+ z = relay.add(x, y)
+ func = relay.Function([x, y], z)
+ return relay.build_module._build_module_no_factory(func, target="cuda")
+
+
+def get_cuda_mod():
+ # Get Cuda module which is binary serializable
+ return get_test_mod().imported_modules[0].imported_modules[0]
+
+
+@tvm.testing.requires_cuda
+def test_cuda_module():
+ mod = get_cuda_mod()
+ assert mod.type_key == "cuda"
+ assert mod.is_binary_serializable
+ new_mod = tvm.ir.load_json(tvm.ir.save_json(mod))
+ assert new_mod.type_key == "cuda"
+ assert new_mod.is_binary_serializable
+
+
+@tvm.testing.requires_cuda
+def test_valid_submodules():
+ mod, mod2, mod3, mod4 = get_cuda_mod(), get_cuda_mod(), get_cuda_mod(), get_cuda_mod()
+
+ # Create the nested cuda module
+ mod.import_module(mod2)
+ mod2.import_module(mod3)
+ mod2.import_module(mod4)
+
+ # Root module and all submodules should be binary serializable since they are cuda modules
+ assert mod.type_key == "cuda"
+ assert mod.is_binary_serializable
+ assert mod.imported_modules[0].type_key == "cuda"
+ assert mod.imported_modules[0].is_binary_serializable
+ assert mod.imported_modules[0].imported_modules[0].type_key == "cuda"
+ assert mod.imported_modules[0].imported_modules[1].type_key == "cuda"
+ assert mod.imported_modules[0].imported_modules[0].is_binary_serializable
+ assert mod.imported_modules[0].imported_modules[1].is_binary_serializable
+
+ # The roundtripped mod should have the same structure
+ new_mod = tvm.ir.load_json(tvm.ir.save_json(mod))
+ assert new_mod.type_key == "cuda"
+ assert new_mod.is_binary_serializable
+ assert new_mod.imported_modules[0].type_key == "cuda"
+ assert new_mod.imported_modules[0].is_binary_serializable
+ assert new_mod.imported_modules[0].imported_modules[0].type_key == "cuda"
+ assert new_mod.imported_modules[0].imported_modules[1].type_key == "cuda"
+ assert new_mod.imported_modules[0].imported_modules[0].is_binary_serializable
+ assert new_mod.imported_modules[0].imported_modules[1].is_binary_serializable
+
+
+@tvm.testing.requires_cuda
+def test_invalid_submodules():
+ mod, mod2, mod3 = get_cuda_mod(), get_cuda_mod(), get_cuda_mod()
+ mod4 = tvm.get_global_func("relay.build_module._AOTExecutorCodegen")()
+
+ # Create the nested cuda module
+ mod.import_module(mod2)
+ mod2.import_module(mod3)
+ mod2.import_module(mod4)
+
+ # One of the submodules is not binary serializable.
+ assert mod.is_binary_serializable
+ assert mod.imported_modules[0].is_binary_serializable
+ assert mod.imported_modules[0].imported_modules[0].is_binary_serializable
+ assert not mod.imported_modules[0].imported_modules[1].is_binary_serializable
+
+ # Therefore, we cannot roundtrip.
+ with pytest.raises(TVMError):
+ tvm.ir.load_json(tvm.ir.save_json(mod))
+
+
+if __name__ == "__main__":
+ tvm.testing.main()