You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2023/08/28 14:10:41 UTC

[tvm] branch main updated: [Runtime] Serialization/Deserialization of runtime module (#15244)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 8f60213cb5 [Runtime] Serialization/Deserialization of runtime module (#15244)
8f60213cb5 is described below

commit 8f60213cb5a75cc87bb39d2645a9594363cb70e1
Author: Sunghyun Park <su...@umich.edu>
AuthorDate: Mon Aug 28 07:10:33 2023 -0700

    [Runtime] Serialization/Deserialization of runtime module (#15244)
---
 include/tvm/runtime/module.h                       |   5 +
 include/tvm/target/codegen.h                       |  16 +++
 python/tvm/contrib/torch/optimize_torch.py         |  24 ++--
 python/tvm/runtime/module.py                       |   1 -
 src/contrib/torch/base64.h                         |  75 -----------
 .../tvm_module_wrapper/RuntimeModuleWrapperTVM.cc  | 143 +++++++++++++--------
 src/node/structural_hash.cc                        |  17 +++
 src/runtime/library_module.cc                      |   9 --
 src/runtime/library_module.h                       |  10 ++
 src/target/codegen.cc                              |  87 ++++++++++---
 .../unittest/test_roundtrip_runtime_module.py      | 121 +++++++++++++++++
 11 files changed, 344 insertions(+), 164 deletions(-)

diff --git a/include/tvm/runtime/module.h b/include/tvm/runtime/module.h
index 60e3535319..cb88bc53f9 100644
--- a/include/tvm/runtime/module.h
+++ b/include/tvm/runtime/module.h
@@ -232,6 +232,11 @@ class TVM_DLL ModuleNode : public Object {
     return (GetPropertyMask() & ModulePropertyMask::kDSOExportable) != 0;
   }
 
+  /*! \brief Returns true if this module is 'Binary Serializable'. */
+  bool IsBinarySerializable() const {
+    return (GetPropertyMask() & ModulePropertyMask::kBinarySerializable) != 0;
+  }
+
   /*!
    * \brief Returns true if this module has a definition for a function of \p name. If
    * \p query_imports is true, also search in any imported modules.
diff --git a/include/tvm/target/codegen.h b/include/tvm/target/codegen.h
index 46a19ad71b..0490e14e22 100644
--- a/include/tvm/target/codegen.h
+++ b/include/tvm/target/codegen.h
@@ -47,6 +47,21 @@ using runtime::TVMRetValue;
  */
 runtime::Module Build(IRModule mod, Target target);
 
+/*!
+ * \brief Serialize runtime module including its submodules
+ * \param mod The runtime module to serialize including its import tree.
+ * \param export_dso By default, include the info of DSOExportable modules. If disabled, an error
+ * will be raised when encountering DSO modules.
+ */
+std::string SerializeModuleToBytes(const runtime::Module& mod, bool export_dso = true);
+
+/*!
+ * \brief Deserialize runtime module including its submodules
+ * \param blob The byte stream generated by `SerializeModuleToBytes`.
+ * \return runtime::Module runtime module constructed from the given stream
+ */
+runtime::Module DeserializeModuleFromBytes(std::string blob);
+
 /*!
  * \brief Pack imported device library to a C file.
  *  Compile the C file and link with the host library
@@ -77,6 +92,7 @@ std::string PackImportsToC(const runtime::Module& m, bool system_lib,
 runtime::Module PackImportsToLLVM(const runtime::Module& m, bool system_lib,
                                   const std::string& target_triple,
                                   const std::string& c_symbol_prefix = "");
+
 }  // namespace codegen
 }  // namespace tvm
 #endif  // TVM_TARGET_CODEGEN_H_
diff --git a/python/tvm/contrib/torch/optimize_torch.py b/python/tvm/contrib/torch/optimize_torch.py
index cbba590e85..dfe35f2aae 100644
--- a/python/tvm/contrib/torch/optimize_torch.py
+++ b/python/tvm/contrib/torch/optimize_torch.py
@@ -25,17 +25,19 @@ optimize_torch: a function similar to `torch.jit.trace`,
 which is used to optimize the `torch.nn.module` by TVM metaSchedule,
 and returns a custom TorchScript operator
 """
-import base64
+
 import contextlib
 import tempfile
 from typing import Optional, Tuple, Union
-
+import base64
 import torch
 import torch.utils.dlpack
 import tvm
+import tvm._ffi
+from tvm._ffi import register_func
 from tvm import meta_schedule as ms
 from tvm import relay
-from tvm._ffi import get_global_func, register_func
+from tvm._ffi import get_global_func
 from tvm.target import Target
 
 
@@ -51,14 +53,6 @@ class GraphExecutorFactoryWrapper(torch.nn.Module):
         return ret
 
 
-@register_func("script_torch.save_to_base64")
-def save_to_base64(obj) -> bytes:
-    with tempfile.NamedTemporaryFile(suffix=".so") as tmpfile:
-        obj.export_library(tmpfile.name)
-        with open(tmpfile.name, "rb") as temp_file:
-            return base64.b64encode(temp_file.read())
-
-
 def optimize_torch(
     func,
     example_inputs,
@@ -173,3 +167,11 @@ def optimize_torch(
     save_runtime_mod(executor_factory.module)
 
     return GraphExecutorFactoryWrapper(torch.classes.tvm_torch.GraphExecutorFactoryWrapper())
+
+
+@register_func("export_runtime_module")
+def save_to_base64(obj) -> bytes:
+    with tempfile.NamedTemporaryFile(suffix=".so") as tmpfile:
+        obj.export_library(tmpfile.name)
+        with open(tmpfile.name, "rb") as temp_file:
+            return base64.b64encode(temp_file.read())
diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py
index c9e3eb6add..2a1db2cbb2 100644
--- a/python/tvm/runtime/module.py
+++ b/python/tvm/runtime/module.py
@@ -23,7 +23,6 @@ import struct
 from typing import Sequence
 import numpy as np
 
-import tvm._ffi
 from tvm._ffi.base import _LIB, check_call, c_str, string_types, _RUNTIME_ONLY
 from tvm._ffi.libinfo import find_include_path
 from .packed_func import PackedFunc, PackedFuncHandle, _set_class_module
diff --git a/src/contrib/torch/base64.h b/src/contrib/torch/base64.h
deleted file mode 100644
index d7dac4b86c..0000000000
--- a/src/contrib/torch/base64.h
+++ /dev/null
@@ -1,75 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file base64.h
- * \brief Util functions for converting plain bytes back to plain bytes
- */
-
-#ifndef TVM_CONTRIB_TORCH_BASE64_H_
-#define TVM_CONTRIB_TORCH_BASE64_H_
-
-#include <tvm/runtime/logging.h>
-
-#include <cctype>
-#include <cstdio>
-#include <string>
-
-#include "../../support/base64.h"
-
-namespace tvm {
-namespace support {
-
-inline size_t b64strlen(const std::string b64str) {
-  ICHECK(b64str.size() % 4 == 0) << "invalid base64 encoding";
-  size_t length = b64str.size() / 4 * 3;
-  if (b64str[b64str.size() - 2] == '=') {
-    length -= 2;
-  } else if (b64str[b64str.size() - 1] == '=') {
-    length -= 1;
-  }
-  return length;
-}
-
-inline void b64decode(const std::string b64str, u_char* ret) {
-  size_t index = 0;
-  const auto length = b64str.size();
-  for (size_t i = 0; i < length; i += 4) {
-    int8_t ch0 = base64::DecodeTable[(int32_t)b64str[i]];
-    int8_t ch1 = base64::DecodeTable[(int32_t)b64str[i + 1]];
-    int8_t ch2 = base64::DecodeTable[(int32_t)b64str[i + 2]];
-    int8_t ch3 = base64::DecodeTable[(int32_t)b64str[i + 3]];
-    u_char st1 = (ch0 << 2) + (ch1 >> 4);
-    ret[index++] = st1;
-    if (b64str[i + 2] != '=') {
-      u_char st2 = ((ch1 & 0b1111) << 4) + (ch2 >> 2);
-      ret[index++] = st2;
-      if (b64str[i + 3] != '=') {
-        u_char st3 = ((ch2 & 0b11) << 6) + ch3;
-        ret[index++] = st3;
-      }
-    }
-  }
-  ICHECK(b64strlen(b64str) == index) << "base64 decoding fails";
-}
-
-}  // namespace support
-}  // namespace tvm
-
-#endif  // TVM_CONTRIB_TORCH_BASE64_H_
diff --git a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc
index fb570c163f..c77996cf67 100644
--- a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc
+++ b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc
@@ -29,7 +29,7 @@
 #include <vector>
 
 #include "../../../runtime/graph_executor/graph_executor_factory.h"
-#include "../base64.h"
+#include "../../support/base64.h"
 #include "runtime_bridge.h"
 
 namespace tvm {
@@ -46,54 +46,6 @@ struct ThreadLocalStore {
   }
 };
 
-/*
- * Encode TVM runtime module to base64 stream
- */
-std::string serialize(tvm::runtime::Module module) {
-  static const runtime::PackedFunc* f_to_str =
-      runtime::Registry::Get("script_torch.save_to_base64");
-  ICHECK(f_to_str) << "IndexError: Cannot find the packed function "
-                      "`script_torch.save_to_base64` in the global registry";
-  return (*f_to_str)(module);
-}
-
-struct Deleter {  // deleter
-  explicit Deleter(std::string file_name) { this->file_name = file_name; }
-  void operator()(FILE* p) const {
-    fclose(p);
-    ICHECK(remove(file_name.c_str()) == 0)
-        << "remove temporary file (" << file_name << ") unsuccessfully";
-  }
-  std::string file_name;
-};
-
-/*
- * Decode TVM runtime module from base64 stream
- */
-tvm::runtime::Module deserialize(std::string state) {
-  auto length = tvm::support::b64strlen(state);
-
-  std::vector<u_char> bytes(length);  // bytes stream
-  tvm::support::b64decode(state, bytes.data());
-
-  const std::string name = tmpnam(NULL);
-  auto file_name = name + ".so";
-  std::unique_ptr<FILE, Deleter> pFile(fopen(file_name.c_str(), "wb"), Deleter(file_name));
-  fwrite(bytes.data(), sizeof(u_char), length, pFile.get());
-  fflush(pFile.get());
-
-  std::string load_f_name = "runtime.module.loadfile_so";
-  const PackedFunc* f = runtime::Registry::Get(load_f_name);
-  ICHECK(f != nullptr) << "Loader for `.so` files is not registered,"
-                       << " resolved to (" << load_f_name << ") in the global registry."
-                       << "Ensure that you have loaded the correct runtime code, and"
-                       << "that you are on the correct hardware architecture.";
-
-  tvm::runtime::Module ret = (*f)(file_name, "");
-
-  return ret;
-}
-
 TVM_REGISTER_GLOBAL("tvmtorch.save_runtime_mod").set_body_typed([](tvm::runtime::Module mod) {
   ThreadLocalStore::ThreadLocal()->mod = mod;
 });
@@ -242,15 +194,104 @@ size_t tvm_contrib_torch_graph_executor_module_forward(TVMContribTorchRuntimeMod
   return output_length;
 }
 
+inline size_t b64strlen(const std::string b64str) {
+  ICHECK(b64str.size() % 4 == 0) << "invalid base64 encoding";
+  size_t length = b64str.size() / 4 * 3;
+  if (b64str[b64str.size() - 2] == '=') {
+    length -= 2;
+  } else if (b64str[b64str.size() - 1] == '=') {
+    length -= 1;
+  }
+  return length;
+}
+
+inline void b64decode(const std::string b64str, uint8_t* ret) {
+  size_t index = 0;
+  const auto length = b64str.size();
+  for (size_t i = 0; i < length; i += 4) {
+    int8_t ch0 = base64::DecodeTable[(int32_t)b64str[i]];
+    int8_t ch1 = base64::DecodeTable[(int32_t)b64str[i + 1]];
+    int8_t ch2 = base64::DecodeTable[(int32_t)b64str[i + 2]];
+    int8_t ch3 = base64::DecodeTable[(int32_t)b64str[i + 3]];
+    uint8_t st1 = (ch0 << 2) + (ch1 >> 4);
+    ret[index++] = st1;
+    if (b64str[i + 2] != '=') {
+      uint8_t st2 = ((ch1 & 0b1111) << 4) + (ch2 >> 2);
+      ret[index++] = st2;
+      if (b64str[i + 3] != '=') {
+        uint8_t st3 = ((ch2 & 0b11) << 6) + ch3;
+        ret[index++] = st3;
+      }
+    }
+  }
+  ICHECK(b64strlen(b64str) == index) << "base64 decoding fails";
+}
+
+/*!
+ * \brief Export TVM runtime module to base64 stream including its submodules.
+ * Note that this targets modules that are binary serializable and DSOExportable.
+ * \param module The runtime module to export
+ * \return std::string The content of exported file
+ */
+std::string ExportModuleToBase64(tvm::runtime::Module module) {
+  static const tvm::runtime::PackedFunc* f_to_str =
+      tvm::runtime::Registry::Get("export_runtime_module");
+  ICHECK(f_to_str) << "IndexError: Cannot find the packed function "
+                      "`export_runtime_module` in the global registry";
+  return (*f_to_str)(module);
+}
+
+struct Deleter {  // deleter
+  explicit Deleter(std::string file_name) { this->file_name = file_name; }
+  void operator()(FILE* p) const {
+    fclose(p);
+    ICHECK(remove(file_name.c_str()) == 0)
+        << "remove temporary file (" << file_name << ") unsuccessfully";
+  }
+  std::string file_name;
+};
+
+/*!
+ * \brief Import TVM runtime module from base64 stream
+ * Note that this targets modules that are binary serializable and DSOExportable.
+ * \param base64str The base64 stream generated by `ExportModuleToBase64`.
+ * \return runtime::Module runtime module constructed from the given stream
+ */
+tvm::runtime::Module ImportModuleFromBase64(std::string base64str) {
+  auto length = b64strlen(base64str);
+
+  std::vector<uint8_t> bytes(length);  // bytes stream
+  b64decode(base64str, bytes.data());
+
+  auto now = std::chrono::system_clock::now();
+  auto in_time_t = std::chrono::system_clock::to_time_t(now);
+  std::stringstream datetime;
+  datetime << std::put_time(std::localtime(&in_time_t), "%Y-%m-%d-%X");
+  const std::string file_name = "tmp-module-" + datetime.str() + ".so";
+  LOG(INFO) << file_name;
+  std::unique_ptr<FILE, Deleter> pFile(fopen(file_name.c_str(), "wb"), Deleter(file_name));
+  fwrite(bytes.data(), sizeof(uint8_t), length, pFile.get());
+  fflush(pFile.get());
+
+  std::string load_f_name = "runtime.module.loadfile_so";
+  const tvm::runtime::PackedFunc* f = tvm::runtime::Registry::Get(load_f_name);
+  ICHECK(f != nullptr) << "Loader for `.so` files is not registered,"
+                       << " resolved to (" << load_f_name << ") in the global registry."
+                       << "Ensure that you have loaded the correct runtime code, and"
+                       << "that you are on the correct hardware architecture.";
+  tvm::runtime::Module ret = (*f)(file_name, "");
+  return ret;
+}
+
 char* tvm_contrib_torch_encode(TVMContribTorchRuntimeModule* runtime_module) {
-  std::string std = tvm::contrib::serialize(runtime_module->mod);
+  std::string std = ExportModuleToBase64(runtime_module->mod);
   char* ret = new char[std.length() + 1];
   snprintf(ret, std.length() + 1, "%s", std.c_str());
   return ret;
 }
 
 TVMContribTorchRuntimeModule* tvm_contrib_torch_decode(const char* state) {
-  tvm::runtime::Module ret = tvm::contrib::deserialize(state);
+  tvm::runtime::Module ret = ImportModuleFromBase64(state);
   return new TVMContribTorchRuntimeModule(ret);
 }
 
diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc
index 6cf796d344..9643480292 100644
--- a/src/node/structural_hash.cc
+++ b/src/node/structural_hash.cc
@@ -28,6 +28,7 @@
 #include <tvm/runtime/container/adt.h>
 #include <tvm/runtime/profiling.h>
 #include <tvm/runtime/registry.h>
+#include <tvm/target/codegen.h>
 
 #include <algorithm>
 #include <unordered_map>
@@ -360,6 +361,22 @@ struct ADTObjTrait {
 
 TVM_REGISTER_REFLECTION_VTABLE(runtime::ADTObj, ADTObjTrait);
 
+struct ModuleNodeTrait {
+  static constexpr const std::nullptr_t VisitAttrs = nullptr;
+  static constexpr const std::nullptr_t SHashReduce = nullptr;
+  static constexpr const std::nullptr_t SEqualReduce = nullptr;
+};
+
+TVM_REGISTER_REFLECTION_VTABLE(runtime::ModuleNode, ModuleNodeTrait)
+    .set_creator([](const std::string& blob) {
+      runtime::Module rtmod = codegen::DeserializeModuleFromBytes(blob);
+      return RefToObjectPtr::Get(rtmod);
+    })
+    .set_repr_bytes([](const Object* n) -> std::string {
+      const auto* rtmod = static_cast<const runtime::ModuleNode*>(n);
+      return codegen::SerializeModuleToBytes(GetRef<runtime::Module>(rtmod), /*export_dso*/ false);
+    });
+
 void NDArrayHash(const runtime::NDArray::Container* arr, SHashReducer* hash_reduce,
                  bool hash_data) {
   ICHECK_EQ(arr->dl_tensor.device.device_type, kDLCPU) << "can only compare CPU tensor";
diff --git a/src/runtime/library_module.cc b/src/runtime/library_module.cc
index eb5e85beb5..a1a86d0388 100644
--- a/src/runtime/library_module.cc
+++ b/src/runtime/library_module.cc
@@ -67,15 +67,6 @@ class LibraryModuleNode final : public ModuleNode {
   PackedFuncWrapper packed_func_wrapper_;
 };
 
-/*!
- * \brief Helper classes to get into internal of a module.
- */
-class ModuleInternal {
- public:
-  // Get mutable reference of imports.
-  static std::vector<Module>* GetImportsAddr(ModuleNode* node) { return &(node->imports_); }
-};
-
 PackedFunc WrapPackedFunc(TVMBackendPackedCFunc faddr, const ObjectPtr<Object>& sptr_to_self) {
   return PackedFunc([faddr, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
     TVMValue ret_value;
diff --git a/src/runtime/library_module.h b/src/runtime/library_module.h
index 167e819601..d4d32abe21 100644
--- a/src/runtime/library_module.h
+++ b/src/runtime/library_module.h
@@ -30,6 +30,7 @@
 
 #include <functional>
 #include <string>
+#include <vector>
 
 namespace tvm {
 namespace runtime {
@@ -78,6 +79,15 @@ PackedFunc WrapPackedFunc(TVMBackendPackedCFunc faddr, const ObjectPtr<Object>&
  */
 void InitContextFunctions(std::function<void*(const char*)> fgetsymbol);
 
+/*!
+ * \brief Helper class to access the internals of a module.
+ */
+class ModuleInternal {
+ public:
+  // Get mutable reference of imports.
+  static std::vector<Module>* GetImportsAddr(ModuleNode* node) { return &(node->imports_); }
+};
+
 /*!
  * \brief Type alias for function to wrap a TVMBackendPackedCFunc.
  * \param The function address imported from a module.
diff --git a/src/target/codegen.cc b/src/target/codegen.cc
index 6e31db4f60..d1f2d4a479 100644
--- a/src/target/codegen.cc
+++ b/src/target/codegen.cc
@@ -37,6 +37,9 @@
 #include <unordered_set>
 #include <vector>
 
+#include "../runtime/library_module.h"
+#include "../support/base64.h"
+
 namespace tvm {
 
 namespace codegen {
@@ -76,13 +79,16 @@ class ModuleSerializer {
  public:
   explicit ModuleSerializer(runtime::Module mod) : mod_(mod) { Init(); }
 
-  void SerializeModule(dmlc::Stream* stream) {
+  void SerializeModuleToBytes(dmlc::Stream* stream, bool export_dso) {
     // Only have one DSO module and it is in the root, then
     // we will not produce import_tree_.
     bool has_import_tree = true;
-    if (mod_->IsDSOExportable() && mod_->imports().empty()) {
-      has_import_tree = false;
+
+    if (mod_->IsDSOExportable()) {
+      ICHECK(export_dso) << "`export_dso` should be enabled for DSOExportable modules";
+      has_import_tree = !mod_->imports().empty();
     }
+
     uint64_t sz = 0;
     if (has_import_tree) {
       // we will append one key for _import_tree
@@ -96,17 +102,26 @@ class ModuleSerializer {
 
     for (const auto& group : mod_group_vec_) {
       ICHECK_NE(group.size(), 0) << "Every allocated group must have at least one module";
-      if (!group[0]->IsDSOExportable()) {
+      // we prioritize export dso when a module is both serializable and exportable
+      if (export_dso) {
+        if (group[0]->IsDSOExportable()) {
+          if (has_import_tree) {
+            std::string mod_type_key = "_lib";
+            stream->Write(mod_type_key);
+          }
+        } else if (group[0]->IsBinarySerializable()) {
+          ICHECK_EQ(group.size(), 1U) << "Non DSO module is never merged";
+          std::string mod_type_key = group[0]->type_key();
+          stream->Write(mod_type_key);
+          group[0]->SaveToBinary(stream);
+        }
+      } else {
+        ICHECK(group[0]->IsBinarySerializable())
+            << group[0]->type_key() << " is not binary serializable.";
         ICHECK_EQ(group.size(), 1U) << "Non DSO module is never merged";
         std::string mod_type_key = group[0]->type_key();
         stream->Write(mod_type_key);
         group[0]->SaveToBinary(stream);
-      } else {
-        // DSOExportable: do not need binary
-        if (has_import_tree) {
-          std::string mod_type_key = "_lib";
-          stream->Write(mod_type_key);
-        }
       }
     }
 
@@ -240,22 +255,60 @@ class ModuleSerializer {
   std::vector<uint64_t> import_tree_child_indices_;
 };
 
-namespace {
-std::string SerializeModule(const runtime::Module& mod) {
+std::string SerializeModuleToBytes(const runtime::Module& mod, bool export_dso) {
   std::string bin;
   dmlc::MemoryStringStream ms(&bin);
   dmlc::Stream* stream = &ms;
 
   ModuleSerializer module_serializer(mod);
-  module_serializer.SerializeModule(stream);
-
+  module_serializer.SerializeModuleToBytes(stream, export_dso);
   return bin;
 }
-}  // namespace
+
+runtime::Module DeserializeModuleFromBytes(std::string blob) {
+  dmlc::MemoryStringStream ms(&blob);
+  dmlc::Stream* stream = &ms;
+
+  uint64_t size;
+  ICHECK(stream->Read(&size));
+  std::vector<runtime::Module> modules;
+  std::vector<uint64_t> import_tree_row_ptr;
+  std::vector<uint64_t> import_tree_child_indices;
+
+  for (uint64_t i = 0; i < size; ++i) {
+    std::string tkey;
+    ICHECK(stream->Read(&tkey));
+    // "_lib" serves as a placeholder in the module import tree to indicate where
+    // to place the DSOModule
+    ICHECK(tkey != "_lib") << "Should not contain any placeholder for DSOModule.";
+    if (tkey == "_import_tree") {
+      ICHECK(stream->Read(&import_tree_row_ptr));
+      ICHECK(stream->Read(&import_tree_child_indices));
+    } else {
+      auto m = runtime::LoadModuleFromBinary(tkey, stream);
+      modules.emplace_back(m);
+    }
+  }
+
+  for (size_t i = 0; i < modules.size(); ++i) {
+    for (size_t j = import_tree_row_ptr[i]; j < import_tree_row_ptr[i + 1]; ++j) {
+      auto module_import_addr = runtime::ModuleInternal::GetImportsAddr(modules[i].operator->());
+      auto child_index = import_tree_child_indices[j];
+      ICHECK(child_index < modules.size());
+      module_import_addr->emplace_back(modules[child_index]);
+    }
+  }
+
+  ICHECK(!modules.empty()) << "modules cannot be empty when import tree is present";
+  // invariant: root module is always at location 0.
+  // The module order is collected via DFS
+  runtime::Module root_mod = modules[0];
+  return root_mod;
+}
 
 std::string PackImportsToC(const runtime::Module& mod, bool system_lib,
                            const std::string& c_symbol_prefix) {
-  std::string bin = SerializeModule(mod);
+  std::string bin = SerializeModuleToBytes(mod);
   std::string mdev_blob_name = c_symbol_prefix + runtime::symbol::tvm_dev_mblob;
 
   if (c_symbol_prefix.length() != 0) {
@@ -317,7 +370,7 @@ runtime::Module PackImportsToLLVM(const runtime::Module& mod, bool system_lib,
         << "c_symbol_prefix advanced option should be used in conjuction with system-lib";
   }
 
-  std::string bin = SerializeModule(mod);
+  std::string bin = SerializeModuleToBytes(mod);
 
   uint64_t nbytes = bin.length();
   std::string header;
diff --git a/tests/python/unittest/test_roundtrip_runtime_module.py b/tests/python/unittest/test_roundtrip_runtime_module.py
new file mode 100644
index 0000000000..6a1abeedd9
--- /dev/null
+++ b/tests/python/unittest/test_roundtrip_runtime_module.py
@@ -0,0 +1,121 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+""" Test roundtrip of runtime modules """
+# pylint: disable=missing-docstring
+
+import pytest
+import tvm
+import tvm.testing
+from tvm import TVMError
+from tvm import relay
+
+
+def test_csource_module():
+    mod = tvm.runtime._ffi_api.CSourceModuleCreate("", "cc", [], None)
+    # A C source module is not binary serializable,
+    # so the JSON roundtrip is expected to raise an error.
+    assert not mod.is_binary_serializable
+    with pytest.raises(TVMError):
+        tvm.ir.load_json(tvm.ir.save_json(mod))
+
+
+def test_aot_module():
+    mod = tvm.get_global_func("relay.build_module._AOTExecutorCodegen")()
+    # The AOT executor codegen module is not binary serializable,
+    # so the JSON roundtrip is expected to raise an error.
+    assert not mod.is_binary_serializable
+    with pytest.raises(TVMError):
+        tvm.ir.load_json(tvm.ir.save_json(mod))
+
+
+def get_test_mod():
+    x = relay.var("x", shape=(1, 10), dtype="float32")
+    y = relay.var("y", shape=(1, 10), dtype="float32")
+    z = relay.add(x, y)
+    func = relay.Function([x, y], z)
+    return relay.build_module._build_module_no_factory(func, target="cuda")
+
+
+def get_cuda_mod():
+    # Get Cuda module which is binary serializable
+    return get_test_mod().imported_modules[0].imported_modules[0]
+
+
+@tvm.testing.requires_cuda
+def test_cuda_module():
+    mod = get_cuda_mod()
+    assert mod.type_key == "cuda"
+    assert mod.is_binary_serializable
+    new_mod = tvm.ir.load_json(tvm.ir.save_json(mod))
+    assert new_mod.type_key == "cuda"
+    assert new_mod.is_binary_serializable
+
+
+@tvm.testing.requires_cuda
+def test_valid_submodules():
+    mod, mod2, mod3, mod4 = get_cuda_mod(), get_cuda_mod(), get_cuda_mod(), get_cuda_mod()
+
+    # Create the nested cuda module
+    mod.import_module(mod2)
+    mod2.import_module(mod3)
+    mod2.import_module(mod4)
+
+    # Root module and all submodules should be binary serializable since they are cuda modules
+    assert mod.type_key == "cuda"
+    assert mod.is_binary_serializable
+    assert mod.imported_modules[0].type_key == "cuda"
+    assert mod.imported_modules[0].is_binary_serializable
+    assert mod.imported_modules[0].imported_modules[0].type_key == "cuda"
+    assert mod.imported_modules[0].imported_modules[1].type_key == "cuda"
+    assert mod.imported_modules[0].imported_modules[0].is_binary_serializable
+    assert mod.imported_modules[0].imported_modules[1].is_binary_serializable
+
+    # The roundtripped mod should have the same structure
+    new_mod = tvm.ir.load_json(tvm.ir.save_json(mod))
+    assert new_mod.type_key == "cuda"
+    assert new_mod.is_binary_serializable
+    assert new_mod.imported_modules[0].type_key == "cuda"
+    assert new_mod.imported_modules[0].is_binary_serializable
+    assert new_mod.imported_modules[0].imported_modules[0].type_key == "cuda"
+    assert new_mod.imported_modules[0].imported_modules[1].type_key == "cuda"
+    assert new_mod.imported_modules[0].imported_modules[0].is_binary_serializable
+    assert new_mod.imported_modules[0].imported_modules[1].is_binary_serializable
+
+
+@tvm.testing.requires_cuda
+def test_invalid_submodules():
+    mod, mod2, mod3 = get_cuda_mod(), get_cuda_mod(), get_cuda_mod()
+    mod4 = tvm.get_global_func("relay.build_module._AOTExecutorCodegen")()
+
+    # Create the nested cuda module
+    mod.import_module(mod2)
+    mod2.import_module(mod3)
+    mod2.import_module(mod4)
+
+    # One of the submodules is not binary serializable.
+    assert mod.is_binary_serializable
+    assert mod.imported_modules[0].is_binary_serializable
+    assert mod.imported_modules[0].imported_modules[0].is_binary_serializable
+    assert not mod.imported_modules[0].imported_modules[1].is_binary_serializable
+
+    # Therefore, we cannot roundtrip.
+    with pytest.raises(TVMError):
+        tvm.ir.load_json(tvm.ir.save_json(mod))
+
+
+if __name__ == "__main__":
+    tvm.testing.main()