You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/08/17 08:33:44 UTC
[tvm] branch main updated: [TVM PyTorch Integration] libstdc++ CXX11 ABI Compatibility & boolean tensor support (#12232)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 073304dadb [TVM PyTorch Integration] libstdc++ CXX11 ABI Compatibility & boolean tensor support (#12232)
073304dadb is described below
commit 073304dadb91ce70b3198cab8b3ae98ee4061b26
Author: Yaoda Zhou <ju...@sjtu.edu.cn>
AuthorDate: Wed Aug 17 16:33:37 2022 +0800
[TVM PyTorch Integration] libstdc++ CXX11 ABI Compatibility & boolean tensor support (#12232)
* first commit
* rename
* cmake
* deprecated
* newline
* config
* config
* typo
* skip tvm_class
* rename
* delete ptr
* delete ptr
* save progress
* boolean support
* cmake file
* polish code
* compile config
* improving the codes
* format
* doc&errormsg
* zero-cost copy
* one step
* to ndarray
* extra output
* delete extra codes
* update test
* boolean support
* strong test
* decrease memory copy
* polish
* reformat
* polish
* remove redundant import
Co-authored-by: juda <yz...@octoml.ai>
---
apps/pt_tvmdsoop/tests/test_as_torch.py | 7 +-
apps/pt_tvmdsoop/tests/test_boolean_tensor.py | 129 ++++++++++
cmake/modules/contrib/PT_TVMDSOOP.cmake | 68 ++++--
python/tvm/contrib/torch/__init__.py | 25 +-
python/tvm/contrib/torch/module.py | 17 ++
python/tvm/contrib/torch/pytorch_tvm.py | 21 ++
.../torch/pt_call_tvm/RuntimeModuleWrapper.cc | 259 --------------------
.../tvm_module_wrapper/RuntimeModuleWrapperTVM.cc | 266 +++++++++++++++++++++
.../RuntimeModuleWrapperTorch.cc | 215 +++++++++++++++++
.../torch/tvm_module_wrapper/runtime_bridge.h | 116 +++++++++
10 files changed, 844 insertions(+), 279 deletions(-)
diff --git a/apps/pt_tvmdsoop/tests/test_as_torch.py b/apps/pt_tvmdsoop/tests/test_as_torch.py
index 2c454e9454..a13d669e7f 100644
--- a/apps/pt_tvmdsoop/tests/test_as_torch.py
+++ b/apps/pt_tvmdsoop/tests/test_as_torch.py
@@ -17,6 +17,8 @@
# specific language governing permissions and limitations
# under the License.
"""Test script for tvm torch module"""
+import tempfile
+
import numpy as np
import torch
@@ -190,7 +192,10 @@ def test_tvmscript_torch_gpu():
q1 = torch.arange(8, device=cuda0).type(torch.float32)
q2 = torch.zeros((8,), dtype=torch.float32, device=cuda0)
- ModuleGPU(q1, q2)
+ with tempfile.NamedTemporaryFile(suffix=".pt") as tmp:
+ torch.save(ModuleGPU, tmp.name)
+ loaded_mod = torch.load(tmp.name)
+ loaded_mod(q1, q2)
tvm.testing.assert_allclose(q2.cpu().numpy(), (q1 + 1).cpu().numpy(), atol=1e-5, rtol=1e-5)
diff --git a/apps/pt_tvmdsoop/tests/test_boolean_tensor.py b/apps/pt_tvmdsoop/tests/test_boolean_tensor.py
new file mode 100644
index 0000000000..4718b40439
--- /dev/null
+++ b/apps/pt_tvmdsoop/tests/test_boolean_tensor.py
@@ -0,0 +1,129 @@
+#!/usr/bin/env python
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test script for boolean tensor support"""
+import tempfile
+
+import torch
+
+import tvm
+import tvm.testing
+from tvm.contrib.torch import as_torch, optimize_torch
+from tvm.script import tir as T
+
+
+def negate(x):
+ return x.logical_not()
+
+
+def sum_up_tensor(x):
+ return x.size(dim=0) - torch.sum(x.int())
+
+
+def tensor_boolean_operation(x):
+ arr1 = (x + 0.3).floor().bool()
+ arr2 = (~((x + 0.7).int().bool())).bool()
+ ret = ((arr1 & arr2).byte() + 0.5).half()
+ return ~(ret.bool())
+
+
+def test_bool_tensor_negate():
+ input = torch.ones(1, dtype=torch.bool)
+ optimized_negate = optimize_torch(
+ negate,
+ input,
+ )
+ with tempfile.NamedTemporaryFile(suffix=".pt") as tmp:
+ torch.save(optimized_negate, tmp.name)
+ loaded_mod = torch.load(tmp.name)
+ output = loaded_mod(negate(input))
+ tvm.testing.assert_allclose(input.numpy(), output.numpy(), atol=1e-5, rtol=1e-5)
+
+
+def test_sum_up_tensor():
+ x = torch.randint(0, 2, (16,))
+ y = x.bool()
+ optimized_func = optimize_torch(
+ sum_up_tensor,
+ (y,),
+ )
+ ret1 = (x[x == 0]).size(dim=0)
+ ret2 = optimized_func(y).numpy()
+ tvm.testing.assert_allclose(ret1, ret2, atol=1e-5, rtol=1e-5)
+
+
+def test_tensor_boolean_operation():
+ input = torch.rand(200)
+ model = optimize_torch(
+ tensor_boolean_operation,
+ input,
+ )
+ ret1 = tensor_boolean_operation(input)
+ ret2 = model(input)
+ tvm.testing.assert_allclose(ret1, ret2, atol=1e-5, rtol=1e-5)
+
+
+@as_torch
+@T.prim_func
+def negate_tvmscript(
+ X: T.Buffer[(8, 8), "bool"],
+ Y: T.Buffer[(8, 8), "float32"],
+ Z: T.Buffer[(8, 8), "bool"],
+ U: T.Buffer[(8, 8), "float32"],
+) -> None:
+ for i, j in T.grid(8, 8):
+ with T.block():
+ if Y[i, j] > 0.0:
+ Z[i, j] = X[i, j]
+ U[i, j] = Y[i, j]
+ else:
+ Z[i, j] = not X[i, j]
+ U[i, j] = 0.0 - Y[i, j]
+
+
+def negate_vanila(x, y):
+ z = torch.zeros(8, 8).bool()
+ for i in range(8):
+ for j in range(8):
+ if y[i, j] > 0:
+ z[i, j] = x[i, j]
+ else:
+ z[i, j] = ~x[i, j]
+ return z
+
+
+def test_tvmscript_torch_decorator():
+ q1 = (torch.rand(8, 8) + 0.5).int().bool()
+ q2 = torch.rand(8, 8) - 0.5
+ q3 = torch.zeros(8, 8).bool()
+ q4 = torch.zeros(8, 8)
+
+ std1 = negate_vanila(q1, q2)
+ std2 = torch.abs(q2)
+
+ negate_tvmscript(q1, q2, q3, q4)
+
+ tvm.testing.assert_allclose(std1.numpy(), q3.numpy(), atol=1e-5, rtol=1e-5)
+ tvm.testing.assert_allclose(std2.numpy(), q4.numpy(), atol=1e-5, rtol=1e-5)
+
+
+if __name__ == "__main__":
+ test_tvmscript_torch_decorator()
+ test_bool_tensor_negate()
+ test_sum_up_tensor()
+ test_tensor_boolean_operation()
diff --git a/cmake/modules/contrib/PT_TVMDSOOP.cmake b/cmake/modules/contrib/PT_TVMDSOOP.cmake
index 3bad3fd966..a73d3f38e9 100644
--- a/cmake/modules/contrib/PT_TVMDSOOP.cmake
+++ b/cmake/modules/contrib/PT_TVMDSOOP.cmake
@@ -6,7 +6,7 @@
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
-# http://www.apache.org/licenses/LICENSE-2.0
+# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
@@ -17,42 +17,80 @@
if(NOT USE_PT_TVMDSOOP STREQUAL "OFF")
find_package(PythonInterp REQUIRED)
-
execute_process(COMMAND ${PYTHON_EXECUTABLE} -c "import torch; print(torch.__path__[0].strip())"
OUTPUT_VARIABLE PT_PATH
RESULT_VARIABLE PT_STATUS)
- if (NOT ${PT_STATUS} EQUAL 0)
+
+ if(NOT ${PT_STATUS} EQUAL 0)
message(FATAL_ERROR "Fail to get pytorch path")
endif()
string(REGEX REPLACE "\n" "" PT_PATH "${PT_PATH}")
message(STATUS "PyTorch path: ${PT_PATH}")
- set(PT_COMPILE_FLAGS_STR "-I${PT_PATH}/include -D_GLIBCXX_USE_CXX11_ABI=0")
+ execute_process(COMMAND ${PYTHON_EXECUTABLE} -c "import torch;print(torch.compiled_with_cxx11_abi())"
+ OUTPUT_VARIABLE PT_CXX_FLAG
+ RESULT_VARIABLE PT_STATUS)
+
+ string(REGEX REPLACE "\n" "" PT_CXX_FLAG "${PT_CXX_FLAG}")
+ message(STATUS "Found TORCH_BUILT_WITH_CXX_ABI=${PT_CXX_FLAG} ")
+
+ if(${PT_CXX_FLAG} STREQUAL "False")
+ set(CXX_ABI_ENABLED 0)
+ else()
+ set(CXX_ABI_ENABLED 1)
+ endif()
+
+ set_property(
+ SOURCE
+ ${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc
+ APPEND PROPERTY
+ COMPILE_OPTIONS
+ "-D_GLIBCXX_USE_CXX11_ABI=${CXX_ABI_ENABLED}"
+ "-I${PT_PATH}/include"
+ )
+
+ set_property(
+ SOURCE
+ ${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/torch/pt_call_tvm/tvm_class.cc
+ APPEND PROPERTY
+ COMPILE_OPTIONS
+ "-I${PT_PATH}/include"
+ )
+
set(PT_LINK_FLAGS_STR "-L${PT_PATH}/lib -l:libtorch.so -l:libtorch_python.so")
if(NOT USE_CUDA STREQUAL "OFF")
add_definitions(-DPT_TVMDSOOP_ENABLE_GPU)
endif()
-
string(REGEX REPLACE "\n" " " PT_FLAGS "${PT_COMPILE_FLAGS} ${PT_LINK_FLAGS}")
- separate_arguments(PT_COMPILE_FLAGS UNIX_COMMAND ${PT_COMPILE_FLAGS_STR})
+ separate_arguments(PT_COMPILE_FLAGS UNIX_COMMAND)
separate_arguments(PT_LINK_FLAGS UNIX_COMMAND ${PT_LINK_FLAGS_STR})
+ # This old version is deprecated and will be removed after TVM 0.11
+ set(LIBRARY_OLD_NAME pt_tvmdsoop)
- set(LIBRARY_NAME pt_tvmdsoop)
- tvm_file_glob(GLOB_RECURSE PTTVM_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/torch/**/*.cc)
- add_library(${LIBRARY_NAME} SHARED ${PTTVM_SRCS})
+ # This new library is built for PyTorch integration; it solves the C++ ABI incompatibility issue
+ set(LIBRARY_NEW_NAME pt_tvmdsoop_new)
+ tvm_file_glob(GLOB_RECURSE PTTVM_TORCH ${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/torch/tvm_module_wrapper/*.cc)
+
+ tvm_file_glob(GLOB_RECURSE PTTVM_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/torch/pt_call_tvm/*.cc)
+
+ add_library(${LIBRARY_OLD_NAME} SHARED ${PTTVM_SRCS})
+ add_library(${LIBRARY_NEW_NAME} SHARED ${PTTVM_TORCH})
set(PTTVM_LINK_FLAGS -ltvm -L${CMAKE_CURRENT_BINARY_DIR})
- if (NOT BUILD_PT_TVMDSOOP_ONLY STREQUAL "ON")
- add_dependencies(${LIBRARY_NAME} tvm)
+ if(NOT BUILD_PT_TVMDSOOP_ONLY STREQUAL "ON")
+ add_dependencies(${LIBRARY_OLD_NAME} tvm)
+ add_dependencies(${LIBRARY_NEW_NAME} tvm)
endif()
- target_compile_options(${LIBRARY_NAME} PUBLIC ${PTTVM_COMPILE_FLAGS} ${PT_COMPILE_FLAGS})
- target_link_libraries(${LIBRARY_NAME} PUBLIC ${PTTVM_LINK_FLAGS} ${PT_LINK_FLAGS})
- target_compile_definitions(${LIBRARY_NAME} PUBLIC DMLC_USE_LOGGING_LIBRARY=<tvm/runtime/logging.h>)
+ target_compile_options(${LIBRARY_OLD_NAME} PUBLIC ${PTTVM_COMPILE_FLAGS} ${PT_COMPILE_FLAGS})
+ target_link_libraries(${LIBRARY_OLD_NAME} PUBLIC ${PTTVM_LINK_FLAGS} ${PT_LINK_FLAGS})
+ target_compile_definitions(${LIBRARY_OLD_NAME} PUBLIC DMLC_USE_LOGGING_LIBRARY=<tvm/runtime/logging.h>)
+ target_compile_options(${LIBRARY_NEW_NAME} PUBLIC ${PTTVM_COMPILE_FLAGS} ${PT_COMPILE_FLAGS})
+ target_link_libraries(${LIBRARY_NEW_NAME} PUBLIC ${PTTVM_LINK_FLAGS} ${PT_LINK_FLAGS})
+ target_compile_definitions(${LIBRARY_NEW_NAME} PUBLIC DMLC_USE_LOGGING_LIBRARY=<tvm/runtime/logging.h>)
endif()
-
diff --git a/python/tvm/contrib/torch/__init__.py b/python/tvm/contrib/torch/__init__.py
index 340f9cef9e..c3dd34d470 100644
--- a/python/tvm/contrib/torch/__init__.py
+++ b/python/tvm/contrib/torch/__init__.py
@@ -18,11 +18,12 @@
"""Module container of Pytorch custom class"""
import os
import platform
+import warnings
import torch
from tvm._ffi import libinfo
-def _load_platform_specific_library(lib_name="libpt_tvmdsoop"):
+def _load_platform_specific_library(lib_name):
system = platform.system()
if system == "Darwin":
lib_file_name = lib_name + ".dylib"
@@ -33,11 +34,27 @@ def _load_platform_specific_library(lib_name="libpt_tvmdsoop"):
lib_path = libinfo.find_lib_path()[0]
lib_dir = os.path.dirname(lib_path)
lib_file_path = os.path.join(lib_dir, lib_file_name)
- torch.classes.load_library(lib_file_path)
+ try:
+ torch.classes.load_library(lib_file_path)
+ except OSError as err:
+ errmsg = str(err)
+ if errmsg.find("undefined symbol") != -1:
+ reason = " ".join(
+ (
+ "Got undefined symbol error,",
+ "which might be due to the CXXABI incompatibility.",
+ )
+ )
+ else:
+ reason = errmsg
+ warnings.warn(
+ f"The library {lib_name} is not built successfully. {reason}",
+ RuntimeWarning,
+ )
-_load_platform_specific_library()
-
+_load_platform_specific_library("libpt_tvmdsoop")
+_load_platform_specific_library("libpt_tvmdsoop_new")
from . import module
diff --git a/python/tvm/contrib/torch/module.py b/python/tvm/contrib/torch/module.py
index 3da9c6f591..cfa3ad264c 100644
--- a/python/tvm/contrib/torch/module.py
+++ b/python/tvm/contrib/torch/module.py
@@ -16,7 +16,9 @@
# under the License.
# pylint: disable=invalid-name
"""Module container of PyTorch custom class"""
+import warnings
from typing import List
+
import torch
@@ -29,6 +31,11 @@ class GraphModule(torch.nn.Module):
return torch.ops.tvm_dsoop.tvm_shape_repr(input_shapes)
def __init__(self, num_inputs, num_outputs, device=None):
+ warnings.warn(
+ "This module will be removed at TVM version 0.11",
+ DeprecationWarning,
+ stacklevel=2,
+ )
super().__init__()
self.dummy_param = torch.nn.Parameter(torch.empty(0))
self.engine = None
@@ -67,6 +74,11 @@ class VMModule(torch.nn.Module):
return torch.ops.tvm_dsoop.tvm_shape_repr(input_shapes)
def __init__(self, num_inputs, num_outputs, device=None):
+ warnings.warn(
+ "This module will be removed at TVM version 0.11",
+ DeprecationWarning,
+ stacklevel=2,
+ )
super().__init__()
self.dummy_param = torch.nn.Parameter(torch.empty(0))
self.engine = None
@@ -113,6 +125,11 @@ class TraceTvmModule(torch.nn.Module):
"""
def __init__(self, tvm_module):
+ warnings.warn(
+ "This module will be removed at TVM version 0.11",
+ DeprecationWarning,
+ stacklevel=2,
+ )
super().__init__()
self.tvm_module = tvm_module
diff --git a/python/tvm/contrib/torch/pytorch_tvm.py b/python/tvm/contrib/torch/pytorch_tvm.py
index 1e50c98ab8..ffab4fa0d2 100644
--- a/python/tvm/contrib/torch/pytorch_tvm.py
+++ b/python/tvm/contrib/torch/pytorch_tvm.py
@@ -19,6 +19,7 @@
# pylint: disable=redefined-builtin
"""`compile` api that convert torch module to torch tvm module"""
import os
+import warnings
import tvm
import tvm.testing
from tvm import relay, autotvm
@@ -183,6 +184,16 @@ class PyTorchTVMModule:
def build_pytorch_module(self, num_inputs, num_outputs, input_infos=None):
"""Build pytorch module containing TVM Graph Module"""
+ warnings.warn(
+ " ".join(
+ (
+ "This function will be removed at TVM version 0.11,",
+ "we suggest users to use `optimized_torch` for tuning Torch modules instead.",
+ )
+ ),
+ DeprecationWarning,
+ stacklevel=2,
+ )
assert self.export_dir, "you must build_tvm or load_tvm before"
input_infos = input_infos or self.input_infos
assert input_infos
@@ -224,6 +235,16 @@ def compile(script_module, option):
pytorch_tvm_module = compile(script_module, option)
pytorch_tvm_module("model_tvm.pt")
"""
+ warnings.warn(
+ " ".join(
+ (
+ "This function will be removed at TVM version 0.11,",
+ "we suggest users to use `optimized_torch` for tuning Torch modules instead.",
+ )
+ ),
+ DeprecationWarning,
+ stacklevel=2,
+ )
input_infos = option["input_infos"]
default_dtype = option.get("default_dtype", "float32")
export_dir = option.get("export_dir", "pytorch_compiled")
diff --git a/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc b/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc
deleted file mode 100644
index 12c1017bea..0000000000
--- a/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc
+++ /dev/null
@@ -1,259 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-#include <ATen/DLConvertor.h>
-#include <dlpack/dlpack.h>
-#include <dmlc/memory_io.h>
-#include <torch/custom_class.h>
-#include <torch/script.h>
-#include <tvm/runtime/module.h>
-#include <tvm/runtime/registry.h>
-#include <tvm/target/codegen.h>
-#include <tvm/target/target.h>
-
-#include <cstdio>
-#include <map>
-#include <string>
-#include <vector>
-
-#include "../../../runtime/graph_executor/graph_executor_factory.h"
-#include "../base64.h"
-
-namespace tvm {
-namespace contrib {
-
-/**
- * We pass the TVM module by TVM's FFI because Torch's FFI cannot recognize such TVM objects
- */
-struct ThreadLocalStore {
- tvm::runtime::Module mod;
- static ThreadLocalStore* ThreadLocal() {
- thread_local ThreadLocalStore tls;
- return &tls;
- }
-};
-
-using SerializationType = std::string; // base64 stream
-
-SerializationType serialize(tvm::runtime::Module module) {
- static const runtime::PackedFunc* f_to_str =
- runtime::Registry::Get("script_torch.save_to_base64");
- ICHECK(f_to_str) << "IndexError: Cannot find the packed function "
- "`script_torch.save_to_base64` in the global registry";
- return (*f_to_str)(module);
-}
-
-struct Deleter { // deleter
- explicit Deleter(std::string file_name) { this->file_name = file_name; }
- void operator()(FILE* p) const {
- fclose(p);
- ICHECK(remove(file_name.c_str()) == 0)
- << "Failed to remove temporary file (" << file_name << ")";
- }
- std::string file_name;
-};
-
-tvm::runtime::Module deserialize(SerializationType state) {
- auto length = tvm::support::b64strlen(state);
-
- std::vector<u_char> bytes(length);
- tvm::support::b64decode(state, bytes.data());
-
- const std::string name = tmpnam(NULL);
- auto file_name = name + ".so";
- std::unique_ptr<FILE, Deleter> pFile(fopen(file_name.c_str(), "wb"), Deleter(file_name));
- fwrite(bytes.data(), sizeof(u_char), length, pFile.get());
- fflush(pFile.get());
-
- std::string load_f_name = "runtime.module.loadfile_so";
- const PackedFunc* f = runtime::Registry::Get(load_f_name);
- ICHECK(f != nullptr) << "Loader for `.so` files is not registered,"
- << " resolved to (" << load_f_name << ") in the global registry."
- << "Ensure that you have loaded the correct runtime code, and"
- << "that you are on the correct hardware architecture.";
-
- tvm::runtime::Module ret = (*f)(file_name, "");
-
- return ret;
-}
-
-/**
- * @brief A Torch's module which wraps TVM's OperatorModule Class.
- * The basic forward function calling TVM's runtime is provided.
- * The TVM module can be serialized/deserialized as a Torch module.
- */
-class OperatorModuleWrapper : public torch::jit::CustomClassHolder {
- public:
- OperatorModuleWrapper() { runtime_module = ThreadLocalStore::ThreadLocal()->mod; }
-
- void forward(const c10::List<at::Tensor>& inputs) {
- int input_length = inputs.size();
-
- std::vector<DLManagedTensor*> tensors;
-
- for (int i = 0; i < input_length; ++i) tensors.push_back(toDLPack(inputs[i]));
-
- tvm::runtime::PackedFunc run = runtime_module.GetFunction("__tvm_main__");
-
- std::vector<TVMValue> tvm_values(input_length);
- std::vector<int> tvm_type_codes(input_length);
- tvm::runtime::TVMArgsSetter setter(tvm_values.data(), tvm_type_codes.data());
- for (int k = 0; k < input_length; ++k) {
- setter(k, &tensors[k]->dl_tensor);
- }
-
- run.CallPacked(tvm::runtime::TVMArgs(tvm_values.data(), tvm_type_codes.data(), input_length),
- nullptr);
-
- for (int k = 0; k < input_length; ++k) {
- tensors[k]->deleter(tensors[k]);
- }
- }
-
- SerializationType Serialize() { return serialize(runtime_module); }
-
- explicit OperatorModuleWrapper(SerializationType state) { runtime_module = deserialize(state); }
-
- private:
- tvm::runtime::Module runtime_module;
-};
-
-tvm::Device getDevice(const at::Tensor& tensor) {
- tvm::Device dev;
- dev.device_id = tensor.get_device();
- switch (tensor.device().type()) {
- case at::DeviceType::CPU:
- dev.device_type = DLDeviceType::kDLCPU;
- if (dev.device_id == -1) {
- /*
- * In PyTorch the device ID for cpu is -1, sometimes causing error during tuning
- * Thus we manually set the device ID as 0 for avoiding potentially error of index out of
- * bounds
- */
- dev.device_id = 0;
- }
- break;
- case at::DeviceType::CUDA:
- dev.device_type = DLDeviceType::kDLCUDA;
- break;
- default:
- TORCH_CHECK(false, "PyTorch TVM integration doesn't support device " + tensor.device().str());
- }
- return dev;
-}
-
-/**
- * @brief A Torch's module which wraps TVM's GraphExecutorFactory Class.
- * The basic forward function calling TVM's runtime is provided.
- * The TVM module can be serialized/deserialized as a Torch module.
- */
-class GraphExecutorFactoryWrapper : public torch::jit::CustomClassHolder {
- public:
- explicit GraphExecutorFactoryWrapper(tvm::runtime::Module executor_factory)
- : executor_factory_(executor_factory) {
- CHECK(executor_factory_->IsInstance<runtime::GraphExecutorFactory>())
- << "module is not an instance of GraphExecutorFactory";
- }
-
- GraphExecutorFactoryWrapper()
- : GraphExecutorFactoryWrapper(ThreadLocalStore::ThreadLocal()->mod) {}
-
- c10::List<at::Tensor> forward(const c10::List<at::Tensor>& inputs) {
- int input_length = inputs.size();
-
- if (!executor_.defined()) {
- TORCH_CHECK(input_length > 0, "Receive empty list of input tensors");
- DLDevice input_device = getDevice(inputs.get(0));
-
- auto tmp = executor_factory_.GetFunction("default");
-
- executor_ = tmp(input_device);
- }
-
- std::vector<DLManagedTensor*> tensors;
-
- for (int i = 0; i < input_length; ++i) tensors.push_back(toDLPack(inputs[i]));
-
- tvm::runtime::PackedFunc run = executor_.GetFunction("run");
- tvm::runtime::PackedFunc set_input = executor_.GetFunction("set_input");
- tvm::runtime::PackedFunc get_output = executor_.GetFunction("get_output");
- tvm::runtime::PackedFunc get_num_outputs = executor_.GetFunction("get_num_outputs");
-
- for (int k = 0; k < input_length; ++k) {
- set_input(k, &tensors[k]->dl_tensor);
- }
-
- run();
-
- int64_t output_length = get_num_outputs();
-
- c10::List<at::Tensor> outputs;
- outputs.reserve(output_length);
-
- for (int k = 0; k < output_length; ++k) {
- tvm::runtime::NDArray results = get_output(k);
- at::Tensor atTensor = at::fromDLPack(results.ToDLPack());
- outputs.emplace_back(atTensor);
- }
-
- for (int k = 0; k < input_length; ++k) {
- tensors[k]->deleter(tensors[k]);
- }
- return outputs;
- }
-
- SerializationType Serialize() { return serialize(executor_factory_); }
-
- explicit GraphExecutorFactoryWrapper(SerializationType state) {
- executor_factory_ = deserialize(state);
- }
-
- private:
- tvm::runtime::Module executor_factory_;
- tvm::runtime::Module executor_;
-};
-
-TVM_REGISTER_GLOBAL("tvmtorch.save_runtime_mod").set_body_typed([](tvm::runtime::Module mod) {
- ThreadLocalStore::ThreadLocal()->mod = mod;
-});
-
-TORCH_LIBRARY(tvm_torch, m) {
- m.class_<OperatorModuleWrapper>("OperatorModuleWrapper")
- .def(torch::init<>())
- .def("forward", &OperatorModuleWrapper::forward)
- .def_pickle(
- [](const c10::intrusive_ptr<OperatorModuleWrapper>& self) -> SerializationType {
- return self->Serialize();
- },
- [](SerializationType state) {
- return c10::make_intrusive<OperatorModuleWrapper>(state);
- });
- m.class_<GraphExecutorFactoryWrapper>("GraphExecutorFactoryWrapper")
- .def(torch::init<>())
- .def("forward", &GraphExecutorFactoryWrapper::forward)
- .def_pickle(
- [](const c10::intrusive_ptr<GraphExecutorFactoryWrapper>& self) -> SerializationType {
- return self->Serialize();
- },
- [](SerializationType state) {
- return c10::make_intrusive<GraphExecutorFactoryWrapper>(state);
- });
-}
-
-} // namespace contrib
-} // namespace tvm
diff --git a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc
new file mode 100644
index 0000000000..fb570c163f
--- /dev/null
+++ b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc
@@ -0,0 +1,266 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <dlpack/dlpack.h>
+#include <dmlc/memory_io.h>
+#include <tvm/runtime/module.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/target/codegen.h>
+#include <tvm/target/target.h>
+
+#include <cstdio>
+#include <map>
+#include <string>
+#include <vector>
+
+#include "../../../runtime/graph_executor/graph_executor_factory.h"
+#include "../base64.h"
+#include "runtime_bridge.h"
+
+namespace tvm {
+namespace contrib {
+
+/*
+ * TVM's FFI for passing module from python to C++
+ */
+struct ThreadLocalStore {
+ tvm::runtime::Module mod;
+ static ThreadLocalStore* ThreadLocal() {
+ thread_local ThreadLocalStore tls;
+ return &tls;
+ }
+};
+
+/*
+ * Encode TVM runtime module to base64 stream
+ */
+std::string serialize(tvm::runtime::Module module) {
+ static const runtime::PackedFunc* f_to_str =
+ runtime::Registry::Get("script_torch.save_to_base64");
+ ICHECK(f_to_str) << "IndexError: Cannot find the packed function "
+ "`script_torch.save_to_base64` in the global registry";
+ return (*f_to_str)(module);
+}
+
+struct Deleter { // deleter
+ explicit Deleter(std::string file_name) { this->file_name = file_name; }
+ void operator()(FILE* p) const {
+ fclose(p);
+ ICHECK(remove(file_name.c_str()) == 0)
+ << "remove temporary file (" << file_name << ") unsuccessfully";
+ }
+ std::string file_name;
+};
+
+/*
+ * Decode TVM runtime module from base64 stream
+ */
+tvm::runtime::Module deserialize(std::string state) {
+ auto length = tvm::support::b64strlen(state);
+
+ std::vector<u_char> bytes(length); // bytes stream
+ tvm::support::b64decode(state, bytes.data());
+
+ const std::string name = tmpnam(NULL);
+ auto file_name = name + ".so";
+ std::unique_ptr<FILE, Deleter> pFile(fopen(file_name.c_str(), "wb"), Deleter(file_name));
+ fwrite(bytes.data(), sizeof(u_char), length, pFile.get());
+ fflush(pFile.get());
+
+ std::string load_f_name = "runtime.module.loadfile_so";
+ const PackedFunc* f = runtime::Registry::Get(load_f_name);
+ ICHECK(f != nullptr) << "Loader for `.so` files is not registered,"
+ << " resolved to (" << load_f_name << ") in the global registry."
+ << "Ensure that you have loaded the correct runtime code, and"
+ << "that you are on the correct hardware architecture.";
+
+ tvm::runtime::Module ret = (*f)(file_name, "");
+
+ return ret;
+}
+
+TVM_REGISTER_GLOBAL("tvmtorch.save_runtime_mod").set_body_typed([](tvm::runtime::Module mod) {
+ ThreadLocalStore::ThreadLocal()->mod = mod;
+});
+
+/*
+ * Convert NDArray to DLPack extended tensor. It should be zero-cost.
+ * @param src Pointer to NDArray
+ * @return DLPack extended tensor
+ */
+DLPackTensorExt CreateDLpackTensorExt(tvm::runtime::NDArray* src) {
+ auto is_bool = src->DataType().is_bool();
+ DLManagedTensor* tensor;
+ if (is_bool) {
+ // If we change DLDataType{kDLInt, 8, 1} to DataType::Bool()
+ // we will get `RuntimeError: Unsupported kUInt bits 1`
+ auto tmp = src->CreateView(src->Shape(), DLDataType{kDLInt, 8, 1});
+ tensor = tmp.ToDLPack();
+ } else {
+ tensor = src->ToDLPack();
+ }
+ DLPackTensorExt ret{tensor, is_bool};
+ return ret;
+}
+
+/*
+ * Create an NDArray with boolean type. (One memory copy)
+ * @param src DLpack extended tensor
+ * @return a new NDArray
+ */
+tvm::runtime::NDArray CreateBoolNDarray(DLPackTensorExt* src) {
+ auto& tensor = src->dl_managed_tensor->dl_tensor;
+ std::vector<int64_t> shape;
+ for (int64_t i = 0; i < tensor.ndim; i++) {
+ shape.push_back(tensor.shape[i]);
+ }
+ auto ret = tvm::runtime::NDArray::Empty(shape, DataType::Bool(), tensor.device);
+ ret.CopyFrom(&src->dl_managed_tensor->dl_tensor);
+ return std::move(ret);
+}
+
+bool IsZeroCopy(DLPackTensorExt* src) {
+ auto& dl_tensor = src->dl_managed_tensor->dl_tensor;
+ return tvm::runtime::NDArray::AbilityOfZeroCopyForDLTensor(&dl_tensor, dl_tensor.device);
+}
+
+/*
+ * Create an NDArray from DLpack extended tensor.
+ * @param src DLpack extended tensor
+ * @return a new NDArray
+ */
+tvm::runtime::NDArray NDarrayFromDLpack(DLPackTensorExt* src) {
+ using tvm::runtime::NDArray;
+
+ NDArray array;
+ auto& dl_tensor = src->dl_managed_tensor->dl_tensor;
+ if (src->is_bool) {
+ // one memory copy
+ // the code is similar to NewFromDLTensor except for the type
+ array = CreateBoolNDarray(src);
+ } else if (IsZeroCopy(src)) {
+ array = NDArray::FromExternalDLTensor(src->dl_managed_tensor->dl_tensor);
+ } else {
+ // one memory copy
+ array = NDArray::NewFromDLTensor(&dl_tensor, dl_tensor.device);
+ }
+ return array;
+}
+
+} // namespace contrib
+} // namespace tvm
+
+extern "C" {
+
+struct TVMContribTorchRuntimeModule {
+ tvm::runtime::Module mod;
+
+ explicit TVMContribTorchRuntimeModule(tvm::runtime::Module& mod) : mod(mod) {}
+};
+
+bool tvm_contrib_torch_tensor_ability_of_zero_copy(DLPackTensorExt* src) {
+ return (!src->is_bool) && (tvm::contrib::IsZeroCopy(src));
+}
+
+TVMContribTorchRuntimeModule* tvm_contrib_torch_get_last_saved_runtime_module() {
+ return new TVMContribTorchRuntimeModule(tvm::contrib::ThreadLocalStore::ThreadLocal()->mod);
+}
+
+void tvm_contrib_torch_operator_module_forward(TVMContribTorchRuntimeModule* runtime_module,
+ DLPackTensorExt* inputs, size_t input_size) {
+ tvm::runtime::PackedFunc run = runtime_module->mod.GetFunction("__tvm_main__");
+
+ std::vector<TVMValue> tvm_values(input_size);
+ std::vector<int> tvm_type_codes(input_size);
+ tvm::runtime::TVMArgsSetter setter(tvm_values.data(), tvm_type_codes.data());
+
+ std::vector<tvm::runtime::NDArray> input_cache(input_size);
+
+ for (size_t k = 0; k < input_size; ++k) {
+ auto datum = tvm::contrib::NDarrayFromDLpack(&inputs[k]); // could have one memory copy
+ input_cache[k] = datum; // we keep the datum in a vector for future use, otherwise the datum
+ // will be freed after the loop
+ setter(k, datum);
+ }
+
+ run.CallPacked(tvm::runtime::TVMArgs(tvm_values.data(), tvm_type_codes.data(), input_size),
+ nullptr);
+
+ for (size_t k = 0; k < input_size; ++k) {
+ if (!tvm_contrib_torch_tensor_ability_of_zero_copy(&inputs[k]))
+ input_cache[k].CopyTo(&inputs[k].dl_managed_tensor->dl_tensor);
+ }
+}
+
+TVMContribTorchRuntimeModule* tvm_contrib_torch_create_graph_runtime_module(
+ TVMContribTorchRuntimeModule* graph_executor_factory, DLManagedTensor* input_example) {
+ tvm::runtime::PackedFunc built_module = graph_executor_factory->mod.GetFunction("default");
+ tvm::Device device_info = input_example->dl_tensor.device;
+ tvm::runtime::Module runtime_module = built_module(device_info);
+ return new TVMContribTorchRuntimeModule(runtime_module);
+}
+
+size_t tvm_contrib_torch_graph_executor_module_forward(TVMContribTorchRuntimeModule* runtime_module,
+ DLPackTensorExt* inputs, size_t input_size,
+ DLPackTensorExt** outputs) {
+ tvm::runtime::PackedFunc run = runtime_module->mod.GetFunction("run");
+ tvm::runtime::PackedFunc set_input = runtime_module->mod.GetFunction("set_input");
+ tvm::runtime::PackedFunc get_output = runtime_module->mod.GetFunction("get_output");
+ tvm::runtime::PackedFunc get_num_outputs = runtime_module->mod.GetFunction("get_num_outputs");
+
+ for (size_t k = 0; k < input_size; ++k) {
+ set_input(k, &inputs[k].dl_managed_tensor->dl_tensor);
+ }
+
+ run();
+
+ int64_t output_length = get_num_outputs();
+
+ DLPackTensorExt* outputs_ptr = new DLPackTensorExt[output_length];
+ *outputs = outputs_ptr;
+
+ for (int64_t k = 0; k < output_length; ++k) {
+ tvm::runtime::NDArray results = get_output(k);
+ outputs_ptr[k] = tvm::contrib::CreateDLpackTensorExt(&results);
+ }
+
+ return output_length;
+}
+
+char* tvm_contrib_torch_encode(TVMContribTorchRuntimeModule* runtime_module) {
+ std::string std = tvm::contrib::serialize(runtime_module->mod);
+ char* ret = new char[std.length() + 1];
+ snprintf(ret, std.length() + 1, "%s", std.c_str());
+ return ret;
+}
+
+TVMContribTorchRuntimeModule* tvm_contrib_torch_decode(const char* state) {
+ tvm::runtime::Module ret = tvm::contrib::deserialize(state);
+ return new TVMContribTorchRuntimeModule(ret);
+}
+
+void tvm_contrib_torch_free_runtime_module(TVMContribTorchRuntimeModule* module_ptr) {
+ delete module_ptr;
+}
+
+void tvm_contrib_torch_free_dlpack_tensor_ext_array(DLPackTensorExt* dlpack_ptr) {
+ delete[] dlpack_ptr;
+}
+
+void tvm_contrib_torch_free_encoding(char* encoding) { delete[] encoding; }
+}
diff --git a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc
new file mode 100644
index 0000000000..3159438d72
--- /dev/null
+++ b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc
@@ -0,0 +1,215 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <ATen/DLConvertor.h>
+#include <torch/custom_class.h>
+#include <torch/script.h>
+
+#include <iostream>
+
+#include "runtime_bridge.h"
+
+namespace tvm {
+namespace contrib {
+
+/*
+ * Wrap a Torch tensor as a DLPack extended tensor.
+ * A non-contiguous tensor is made contiguous first. A boolean tensor is exported
+ * after conversion to UInt8 and flagged with `is_bool = true`.
+ * @param src Torch tensor
+ * @return DLPack extended tensor
+ */
+DLPackTensorExt ToDLPackExt(const at::Tensor& src) {
+  const at::Tensor dense = src.is_contiguous() ? src : src.contiguous();
+  const bool is_boolean = dense.dtype().isScalarType(torch::kBool);
+
+  DLPackTensorExt ext;
+  ext.is_bool = is_boolean;
+  if (is_boolean) {
+    // DLPack export goes through UInt8; `is_bool` remembers the original dtype.
+    const at::Tensor as_uint8 = dense.toType(torch::kUInt8);
+    ext.dl_managed_tensor = at::toDLPack(as_uint8);
+  } else {
+    ext.dl_managed_tensor = at::toDLPack(dense);
+  }
+  return ext;
+}
+
+/*
+ * Build a Torch tensor from a DLPack extended tensor.
+ * Entries flagged `is_bool` are converted to a boolean tensor after the import.
+ * @param src DLPack extended tensor
+ * @return Torch tensor
+ */
+at::Tensor FromDLPackExt(const DLPackTensorExt& src) {
+  const at::Tensor imported = at::fromDLPack(src.dl_managed_tensor);
+  // The flag marks UInt8 payloads that stand in for boolean data.
+  if (!src.is_bool) return imported;
+  return imported.toType(torch::kBool);
+}
+
+/**
+ * @brief A Torch's module which wraps TVM's OperatorModule Class.
+ * The basic forward function calling TVM's runtime is provided.
+ * The TVM module can be serialized/deserialized as a Torch module.
+ * The wrapper owns its runtime module handle, so copying is disabled.
+ */
+class OperatorModuleWrapper : public torch::jit::CustomClassHolder {
+ public:
+  OperatorModuleWrapper() { runtime_module_ = tvm_contrib_torch_get_last_saved_runtime_module(); }
+  explicit OperatorModuleWrapper(std::string state) {
+    runtime_module_ = tvm_contrib_torch_decode(state.c_str());
+  }
+  ~OperatorModuleWrapper() { tvm_contrib_torch_free_runtime_module(runtime_module_); }
+  // The destructor frees runtime_module_; a copied wrapper would free it a second time.
+  OperatorModuleWrapper(const OperatorModuleWrapper&) = delete;
+  OperatorModuleWrapper& operator=(const OperatorModuleWrapper&) = delete;
+
+  void forward(const c10::List<at::Tensor>& inputs) {
+    const size_t input_length = inputs.size();
+    std::vector<DLPackTensorExt> tensors;
+    tensors.reserve(input_length);
+    // Torch tensor supports boolean type while DLpack does not,
+    // we convert Torch tensor to an extension of DLPack tensor
+    for (size_t i = 0; i < input_length; ++i) tensors.push_back(ToDLPackExt(inputs[i]));
+    tvm_contrib_torch_operator_module_forward(this->runtime_module_, tensors.data(),
+                                              tensors.size());
+
+    for (size_t k = 0; k < input_length; ++k) {
+      if (tvm_contrib_torch_tensor_ability_of_zero_copy(&tensors[k])) {
+        // Nothing takes over the handle on this path, so it is released manually.
+        tensors[k].dl_managed_tensor->deleter(tensors[k].dl_managed_tensor);
+      } else {
+        // FromDLPackExt takes over the handle; the values are copied back into the input.
+        inputs[k].copy_(FromDLPackExt(tensors[k]));
+      }
+    }
+  }
+
+  std::string Serialize() {
+    char* encoding = tvm_contrib_torch_encode(runtime_module_);
+    std::string ret(encoding);
+    tvm_contrib_torch_free_encoding(encoding);
+    return ret;
+  }
+
+ private:
+  // TVM runtime module wrapper; owned by this object and freed in the destructor.
+  TVMContribTorchRuntimeModule* runtime_module_;
+};
+
+/**
+ * @brief A Torch's module which wraps TVM's GraphExecutorFactory Class.
+ * The basic forward function calling TVM's runtime is provided.
+ * The TVM module can be serialized/deserialized as a Torch module.
+ * The wrapper owns both module handles, so copying is disabled.
+ */
+class GraphExecutorFactoryWrapper : public torch::jit::CustomClassHolder {
+ public:
+  explicit GraphExecutorFactoryWrapper(TVMContribTorchRuntimeModule* executor_factory)
+      : executor_factory_(executor_factory), executor_factory_runtime_(nullptr) {}
+
+  GraphExecutorFactoryWrapper()
+      : GraphExecutorFactoryWrapper(tvm_contrib_torch_get_last_saved_runtime_module()) {}
+
+  explicit GraphExecutorFactoryWrapper(std::string state)
+      : GraphExecutorFactoryWrapper(tvm_contrib_torch_decode(state.c_str())) {}
+
+  // The destructor frees both handles; a copied wrapper would free them a second time.
+  GraphExecutorFactoryWrapper(const GraphExecutorFactoryWrapper&) = delete;
+  GraphExecutorFactoryWrapper& operator=(const GraphExecutorFactoryWrapper&) = delete;
+
+  ~GraphExecutorFactoryWrapper() {
+    tvm_contrib_torch_free_runtime_module(executor_factory_);
+    // Still nullptr if forward() never ran; freeing a null handle is a no-op delete.
+    tvm_contrib_torch_free_runtime_module(executor_factory_runtime_);
+  }
+
+  std::string Serialize() {
+    char* encoding = tvm_contrib_torch_encode(executor_factory_);
+    std::string ret(encoding);
+    tvm_contrib_torch_free_encoding(encoding);
+    return ret;
+  }
+
+  c10::List<at::Tensor> forward(const c10::List<at::Tensor>& inputs) {
+    const size_t input_length = inputs.size();
+    TORCH_CHECK(input_length > 0, "Receive empty list of input tensors");
+    // Torch tensor supports boolean type while DLpack does not,
+    // we convert Torch tensor to an extension of DLPack tensor
+    std::vector<DLPackTensorExt> tensors;
+    tensors.reserve(input_length);
+    for (size_t i = 0; i < input_length; ++i) tensors.push_back(ToDLPackExt(inputs[i]));
+
+    // The executor runtime is created on the first call, with the first input as example.
+    if (executor_factory_runtime_ == nullptr) {
+      executor_factory_runtime_ = tvm_contrib_torch_create_graph_runtime_module(
+          this->executor_factory_, tensors[0].dl_managed_tensor);
+    }
+    DLPackTensorExt* outputs = nullptr;
+    const size_t num_outputs = tvm_contrib_torch_graph_executor_module_forward(
+        executor_factory_runtime_, tensors.data(), tensors.size(), &outputs);
+
+    // FromDLPackExt takes over each output handle; the array holding them is freed below.
+    c10::List<at::Tensor> ret;
+    ret.reserve(num_outputs);
+    for (size_t k = 0; k < num_outputs; ++k) ret.emplace_back(FromDLPackExt(outputs[k]));
+
+    // The input handles are not taken over by anyone, so they are released manually.
+    for (size_t k = 0; k < input_length; ++k) {
+      tensors[k].dl_managed_tensor->deleter(tensors[k].dl_managed_tensor);
+    }
+    tvm_contrib_torch_free_dlpack_tensor_ext_array(outputs);
+    return ret;
+  }
+
+ private:
+  /*
+   * TVM Graph Executor Factory module wrapper
+   */
+  TVMContribTorchRuntimeModule* executor_factory_;
+
+  /*
+   * TVM runtime module wrapper, created by the first forward() call
+   */
+  TVMContribTorchRuntimeModule* executor_factory_runtime_;
+};
+
+TORCH_LIBRARY(tvm_torch, m) {  // registers both wrapper classes under the name "tvm_torch"
+ m.class_<OperatorModuleWrapper>("OperatorModuleWrapper")
+ .def(torch::init<>())  // default constructor: wraps the last saved runtime module
+ .def("forward", &OperatorModuleWrapper::forward)
+ .def_pickle(  // first lambda produces the state, second one rebuilds the object from it
+ [](const c10::intrusive_ptr<OperatorModuleWrapper>& self) -> std::string {
+ return self->Serialize();  // the pickled state is the encoded module string
+ },
+ [](std::string state) { return c10::make_intrusive<OperatorModuleWrapper>(state); });
+ m.class_<GraphExecutorFactoryWrapper>("GraphExecutorFactoryWrapper")
+ .def(torch::init<>())  // default constructor: wraps the last saved runtime module
+ .def("forward", &GraphExecutorFactoryWrapper::forward)
+ .def_pickle(  // first lambda produces the state, second one rebuilds the object from it
+ [](const c10::intrusive_ptr<GraphExecutorFactoryWrapper>& self) -> std::string {
+ return self->Serialize();  // the pickled state is the encoded factory module string
+ },
+ [](std::string state) {
+ return c10::make_intrusive<GraphExecutorFactoryWrapper>(state);
+ });
+}
+
+} // namespace contrib
+} // namespace tvm
diff --git a/src/contrib/torch/tvm_module_wrapper/runtime_bridge.h b/src/contrib/torch/tvm_module_wrapper/runtime_bridge.h
new file mode 100644
index 0000000000..58cd53a284
--- /dev/null
+++ b/src/contrib/torch/tvm_module_wrapper/runtime_bridge.h
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file runtime_bridge.h
+ * \brief Util functions for pytorch tvm interaction.
+ */
+#ifndef TVM_CONTRIB_TORCH_TVM_MODULE_WRAPPER_RUNTIME_BRIDGE_H_
+#define TVM_CONTRIB_TORCH_TVM_MODULE_WRAPPER_RUNTIME_BRIDGE_H_
+
+extern "C" {
+
+/*
+ * DLPack data structure extended with an `is_bool` flag.
+ * DLPack does not support boolean tensors yet
+ * (https://github.com/pytorch/pytorch/blob/4618371da56c887195e2e1d16dad2b9686302800/aten/src/ATen/DLConvertor.cpp#L42),
+ * thus a boolean tensor is passed around as a UInt8 tensor
+ * (https://github.com/apache/tvm/blob/de124862714e747764aa8b7f41a90bcb25f3c6a8/python/tvm/_ffi/runtime_ctypes.py#L91).
+ */
+struct DLPackTensorExt {
+ DLManagedTensor* dl_managed_tensor;  // the DLPack tensor handle
+ bool is_bool;  // true when the UInt8 data stands for a boolean tensor
+};
+
+/*
+ * Opaque wrapper holding a TVM runtime module; it is only forward-declared in this header.
+ */
+struct TVMContribTorchRuntimeModule;
+
+/*
+ * Obtain the saved runtime module that was passed in through the TVM FFI.
+ * @return A TVM runtime module wrapper; release it with tvm_contrib_torch_free_runtime_module.
+ */
+TVMContribTorchRuntimeModule* tvm_contrib_torch_get_last_saved_runtime_module();
+
+/*
+ * Delete a TVMContribTorchRuntimeModule wrapper (a null pointer is accepted).
+ */
+void tvm_contrib_torch_free_runtime_module(TVMContribTorchRuntimeModule* module_ptr);
+
+/*
+ * Obtain the ExecutorFactory runtime module from an ExecutorFactory module.
+ * @param graph_executor_factory ExecutorFactory module wrapper
+ * @param input_example Example input tensor, used for obtaining device information
+ * @return Runtime module wrapper; release it with tvm_contrib_torch_free_runtime_module
+ */
+TVMContribTorchRuntimeModule* tvm_contrib_torch_create_graph_runtime_module(
+ TVMContribTorchRuntimeModule* graph_executor_factory, DLManagedTensor* input_example);
+
+/*
+ * Forward method for OperatorModuleWrapper.
+ * @param runtime_module TVM runtime module wrapper
+ * @param inputs Pointer to the array of input tensors
+ * @param input_size The number of tensors in `inputs`
+ */
+void tvm_contrib_torch_operator_module_forward(TVMContribTorchRuntimeModule* runtime_module,
+ DLPackTensorExt* inputs, size_t input_size);
+
+/*
+ * Forward method for GraphExecutorFactoryWrapper.
+ * @param graph_executor_factory Module from tvm_contrib_torch_create_graph_runtime_module
+ * @param inputs Pointer to the array of input tensors
+ * @param input_size The number of tensors in `inputs`
+ * @param outputs Receives a pointer to the newly allocated array of output tensors
+ * @return The number of output tensors
+ */
+size_t tvm_contrib_torch_graph_executor_module_forward(
+ TVMContribTorchRuntimeModule* graph_executor_factory, DLPackTensorExt* inputs,
+ size_t input_size, DLPackTensorExt** outputs);
+
+/*
+ * Encode TVM runtime module.
+ * @param runtime_module TVM runtime module wrapper
+ * @return NUL-terminated char array; release it with tvm_contrib_torch_free_encoding
+ */
+char* tvm_contrib_torch_encode(TVMContribTorchRuntimeModule* runtime_module);
+
+/*
+ * Decode TVM runtime module.
+ * @param state The encoding stream (char array) of a TVM runtime module
+ * @return TVM runtime module wrapper; release it with tvm_contrib_torch_free_runtime_module
+ */
+TVMContribTorchRuntimeModule* tvm_contrib_torch_decode(const char* state);
+
+/*
+ * Delete the DLPackTensorExt array that the graph executor forward method returns in `outputs`.
+ */
+void tvm_contrib_torch_free_dlpack_tensor_ext_array(DLPackTensorExt*);
+
+/*
+ * Delete a char array returned by tvm_contrib_torch_encode.
+ */
+void tvm_contrib_torch_free_encoding(char* encoding);
+
+/*
+ * Whether TVM can use the tensor without copying; presumably false for boolean tensors (verify).
+ */
+bool tvm_contrib_torch_tensor_ability_of_zero_copy(DLPackTensorExt*);
+}
+
+#endif // TVM_CONTRIB_TORCH_TVM_MODULE_WRAPPER_RUNTIME_BRIDGE_H_