You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/07/26 08:00:51 UTC
[tvm] branch main updated: TVM Vertical Integration with PyTorch (#11911)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new ea6ea42757 TVM Vertical Integration with PyTorch (#11911)
ea6ea42757 is described below
commit ea6ea42757f275573391eb3ff67034b2749948ae
Author: Yaoda Zhou <ju...@sjtu.edu.cn>
AuthorDate: Tue Jul 26 16:00:44 2022 +0800
TVM Vertical Integration with PyTorch (#11911)
* optimize_torch & as_torch
* split files
* code formatting
* optimizing optimized_torch
* scrap your boilerplate
* as_torch polished
* configuration fixed
* Apply suggestions from code review
Co-authored-by: Lite Ye <li...@gmail.com>
* more document
* file deleter
* optimize deleter
* drop how-to guides
* clang-format-10
* formatter changes
* reformat
* reformat
* reformat
* reformatting
* fixed
* auto setting
* fixed
* split long string
* tune_tir
* upgrade as_torch
* optimize as_torch
* as_torch
* fixed typo
Co-authored-by: juda <yz...@octoml.ai>
Co-authored-by: Lite Ye <li...@gmail.com>
---
apps/pt_tvmdsoop/tests/test_as_torch.py | 257 ++++++++++++++++++++
apps/pt_tvmdsoop/tests/test_optimize_torch.py | 161 +++++++++++++
python/tvm/contrib/torch/__init__.py | 12 +-
python/tvm/contrib/torch/as_torch.py | 124 ++++++++++
python/tvm/contrib/torch/optimize_torch.py | 198 ++++++++++++++++
python/tvm/script/parser.py | 16 +-
src/contrib/torch/base64.h | 75 ++++++
.../torch/pt_call_tvm/RuntimeModuleWrapper.cc | 259 +++++++++++++++++++++
8 files changed, 1099 insertions(+), 3 deletions(-)
diff --git a/apps/pt_tvmdsoop/tests/test_as_torch.py b/apps/pt_tvmdsoop/tests/test_as_torch.py
new file mode 100644
index 0000000000..2c454e9454
--- /dev/null
+++ b/apps/pt_tvmdsoop/tests/test_as_torch.py
@@ -0,0 +1,257 @@
+#!/usr/bin/env python
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test script for tvm torch module"""
+import numpy as np
+
+import torch
+import torch.nn
+
+import tvm
+from tvm.meta_schedule.tune import TuneConfig
+from tvm.target.target import Target
+import tvm.testing
+from tvm.contrib.torch import as_torch
+from tvm.script import tir as T
+
+
+@as_torch
+def matmul(M: int, N: int, K: int, dtype: str):
+ @T.prim_func
+ def main(a: T.handle, b: T.handle, c: T.handle) -> None:
+ A = T.match_buffer(a, [M, K], dtype=dtype)
+ B = T.match_buffer(b, [N, K], dtype=dtype)
+ C = T.match_buffer(c, [M, N], dtype=dtype)
+ for i, j, k in T.grid(M, N, K):
+ with T.block():
+ vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+ with T.init():
+ C[vi, vj] = T.float32(0)
+ C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+ return main
+
+
+@as_torch
+@tvm.script.ir_module
+class ModuleGPU:
+ @T.prim_func
+ def main(A: T.Buffer[8, "float32"], B: T.Buffer[8, "float32"]) -> None:
+ T.func_attr({"global_symbol": "main", "tir.noalias": True})
+ for i_0 in T.thread_binding(2, thread="blockIdx.x"):
+ for i_2 in T.thread_binding(2, thread="threadIdx.x"):
+ for i_1 in T.serial(2):
+ with T.block("B"):
+ vi = T.axis.spatial(8, i_0 * 4 + i_1 * 2 + i_2)
+ T.reads(A[vi])
+ T.writes(B[vi])
+ B[vi] = A[vi] + T.float32(1)
+
+
+@as_torch
+@T.prim_func
+def func_with_part_access_region(a: T.handle, b: T.handle, c: T.handle) -> None:
+ A = T.match_buffer(a, [128, 128])
+ B = T.match_buffer(b, [128, 128])
+ C = T.match_buffer(c, [128, 128])
+
+ with T.block():
+ for i, j in T.grid(128, 128):
+ with T.block("s1"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ T.reads(A[vi, vj])
+ B[vi, vj] = A[vi, vj] + T.float32(1)
+
+ for i, j in T.grid(128, 128):
+ with T.block("s2"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ T.writes(C[vi, vj])
+ C[vi, vj] = B[vi, vj] + T.float32(1)
+
+
+config = TuneConfig(
+ strategy="replay_trace",
+ num_trials_per_iter=128,
+ max_trials_per_task=128,
+ max_trials_global=128,
+)
+
+
+@as_torch
+@tvm.script.ir_module
+class MyModule:
+ @T.prim_func
+ def main(a: T.handle, b: T.handle):
+ # We exchange data between functions by handles, which are similar to pointers.
+ T.func_attr({"global_symbol": "main", "tir.noalias": True})
+ # Create buffer from handles.
+ A = T.match_buffer(a, (8,), dtype="float32")
+ B = T.match_buffer(b, (8,), dtype="float32")
+ for i in range(8):
+ # A block is an abstraction for computation.
+ with T.block("B"):
+ # Define a spatial block iterator and bind it to value i.
+ vi = T.axis.spatial(8, i)
+ B[vi] = A[vi] + 1.0
+
+
+@as_torch
+@T.prim_func
+def loop_split(a: T.handle, b: T.handle) -> None:
+ A = T.match_buffer(a, [128, 128], dtype="float32")
+ B = T.match_buffer(b, [128], dtype="float32")
+ for i, ko in T.grid(128, 4):
+ for ki in T.thread_binding(0, 32, thread="threadIdx.x"):
+ with T.block("B"):
+ vi = T.axis.S(128, i)
+ vk = T.axis.R(128, ko * 32 + ki)
+ T.reads([B[vi], A[vi, vk]])
+ T.writes([B[vi]])
+ with T.init():
+ B[vi] = T.float32(0)
+ B[vi] = B[vi] + A[vi, vk]
+
+
+@as_torch
+def elementwise_with_root(M: int, N: int, dtype: str):
+ @T.prim_func
+ def f(a: T.handle, b: T.handle, c: T.handle) -> None:
+ A = T.match_buffer(a, [M, N])
+ B = T.match_buffer(b, [M, N])
+ C = T.match_buffer(c, [M, N])
+
+ with T.block():
+ for i, j in T.grid(M, N):
+ with T.block("s1"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ B[vi, vj] = A[vi, vj] + T.float32(1)
+ for i, j in T.grid(M, N):
+ with T.block("s2"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ C[vi, vj] = B[vi, vj] + T.float32(1)
+
+ return f
+
+
+class MinuesOnes(torch.nn.Module):
+ def __init__(self):
+ super(MinuesOnes, self).__init__()
+ self.engine = MyModule
+
+ def forward(self, *input):
+ self.engine.forward(*input)
+ return input[-1] - 1
+
+
+def test_tvmscript_torch_matmul():
+ s1 = np.random.rand(128, 128).astype("float32")
+ s2 = np.random.rand(128, 128).astype("float32")
+ s3 = np.random.rand(128, 128).astype("float32")
+
+ q1 = torch.from_numpy(s1)
+ q2 = torch.from_numpy(s2)
+ q3 = torch.from_numpy(s3)
+
+ numpy_result = np.matmul(s1, np.transpose(s2))
+
+ nn_module = matmul(128, 128, 128, "float32")
+
+ nn_module(q1, q2, q3)
+
+ tvm.testing.assert_allclose(q3.numpy(), numpy_result, atol=1e-5, rtol=1e-5)
+
+
+def test_tvmscript_torch_decorator():
+ q1 = torch.arange(8).type(torch.float32)
+ q2 = torch.zeros((8,), dtype=torch.float32)
+
+ MyModule(q1, q2)
+
+ tvm.testing.assert_allclose(q2.numpy(), (q1 + 1).numpy(), atol=1e-5, rtol=1e-5)
+
+
+def test_tvmscript_torch_gpu():
+ cuda0 = torch.device("cuda:0")
+ q1 = torch.arange(8, device=cuda0).type(torch.float32)
+ q2 = torch.zeros((8,), dtype=torch.float32, device=cuda0)
+
+ ModuleGPU(q1, q2)
+
+ tvm.testing.assert_allclose(q2.cpu().numpy(), (q1 + 1).cpu().numpy(), atol=1e-5, rtol=1e-5)
+
+
+def test_torch_with_tvmscript():
+ ref_result = np.arange(8).astype("float32")
+
+ q1 = torch.arange(8).type(torch.float32)
+ q2 = torch.zeros((8,), dtype=torch.float32)
+
+ nn_module = MinuesOnes()
+
+ ret = nn_module.forward(q1, q2)
+
+ tvm.testing.assert_allclose(ret.numpy(), ref_result, atol=1e-5, rtol=1e-5)
+
+
+def test_tvmscript_torch_func_with_part_access_region():
+ a1 = torch.rand(128, 128)
+ a2 = torch.zeros(128, 128)
+ a3 = torch.zeros(128, 128)
+
+ result = a1 + 2
+
+ func_with_part_access_region.tune()
+ func_with_part_access_region(a1, a2, a3)
+
+ tvm.testing.assert_allclose(a3.numpy(), result.numpy(), atol=1e-5, rtol=1e-5)
+
+
+def test_tvmscript_torch_loop_split():
+ x = torch.rand(128, 128).cuda()
+ y = torch.zeros(128).cuda()
+
+ result = torch.sum(x.cpu(), dim=1).numpy()
+
+ loop_split.tune(config, Target("nvidia/geforce-rtx-3070"))
+ loop_split(x, y)
+
+ tvm.testing.assert_allclose(y.cpu().numpy(), result, atol=1e-5, rtol=1e-5)
+
+
+def test_tvmscript_torch_elementwise_with_root():
+ a1 = torch.rand(128, 128)
+ a2 = torch.zeros(128, 128)
+ a3 = torch.zeros(128, 128)
+
+ result = a1 + 2
+
+ func = elementwise_with_root(128, 128, "float32")
+ func.tune(config)
+ func(a1, a2, a3)
+
+ tvm.testing.assert_allclose(a3.numpy(), result.numpy(), atol=1e-5, rtol=1e-5)
+
+
+if __name__ == "__main__":
+ test_tvmscript_torch_matmul()
+ test_tvmscript_torch_decorator()
+ test_tvmscript_torch_gpu()
+ test_torch_with_tvmscript()
+ test_tvmscript_torch_func_with_part_access_region()
+ test_tvmscript_torch_loop_split()
+ test_tvmscript_torch_elementwise_with_root()
diff --git a/apps/pt_tvmdsoop/tests/test_optimize_torch.py b/apps/pt_tvmdsoop/tests/test_optimize_torch.py
new file mode 100644
index 0000000000..258dfe55c4
--- /dev/null
+++ b/apps/pt_tvmdsoop/tests/test_optimize_torch.py
@@ -0,0 +1,161 @@
+# pylint: disable=missing-class-docstring
+#!/usr/bin/env python
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test script for tvm torch module"""
+import tempfile
+
+import torch
+from torch.utils import benchmark
+from torchvision.models import resnet18
+
+import tvm
+import tvm.testing
+from tvm.contrib.torch import optimize_torch
+from tvm.meta_schedule import TuneConfig
+
+
+def test_matmul_tuning_relay():
+ def matmul(x, w):
+ return torch.matmul(x, w)
+
+ x = torch.randn(15, 20)
+ w = torch.randn(20, 30)
+ example_inputs = (x, w)
+
+ rt_mod = optimize_torch(matmul, example_inputs)
+ torch_answer = torch.matmul(x, w).numpy()
+ tvm_answer = rt_mod(x, w).numpy()
+
+ tvm.testing.assert_allclose(torch_answer, tvm_answer, atol=1e-5, rtol=1e-5)
+
+
+class InnerModel(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.conv = torch.nn.Conv2d(1, 20, 5)
+
+ def forward(self, x):
+ return torch.nn.functional.relu(self.conv(x))
+
+
+class SimpleModel(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.conv = torch.nn.Conv2d(20, 20, 5)
+ self.relu = InnerModel()
+
+ def forward(self, x):
+ x = self.relu(x)
+ return torch.nn.functional.relu(self.conv(x))
+
+
+def test_nested_module():
+ simple_module = SimpleModel()
+ example_input = torch.randn(20, 1, 10, 10)
+ optimized_module = optimize_torch(simple_module, example_input)
+ ret1 = simple_module(example_input).detach().numpy()
+ ret2 = optimized_module(example_input).detach().numpy()
+ tvm.testing.assert_allclose(ret1, ret2, atol=1e-5, rtol=1e-5)
+
+
+def test_save_load_function():
+ def foo(x):
+ return 2 * x + 1
+
+ example_input = torch.rand(3)
+ opt_foo = optimize_torch(foo, example_input)
+ ret1 = opt_foo(example_input)
+ with tempfile.NamedTemporaryFile(suffix=".pt") as tmp:
+ torch.save(opt_foo, tmp.name)
+ loaded_mod = torch.load(tmp.name)
+ ret2 = loaded_mod(example_input)
+ tvm.testing.assert_allclose(ret1.numpy(), ret2.numpy(), atol=1e-5, rtol=1e-5)
+
+
+class MyResNet18(torch.nn.Module):
+ def __init__(self, config, target=None):
+ super(MyResNet18, self).__init__()
+ self.means = torch.nn.Parameter(
+ torch.tensor([103.939, 116.779, 123.68]).resize_(1, 3, 1, 1)
+ ).cuda()
+ self.resnet = optimize_torch(resnet18(), [torch.rand(1, 3, 224, 224)], config, target)
+
+ def forward(self, input):
+ return self.resnet(input - self.means)
+
+
+class JitModule(torch.nn.Module):
+ def __init__(self):
+ super(JitModule, self).__init__()
+ self.means = torch.nn.Parameter(
+ torch.tensor([103.939, 116.779, 123.68]).resize_(1, 3, 1, 1)
+ ).cuda()
+ self.resnet = torch.jit.optimize_for_inference(torch.jit.script(resnet18().cuda().eval()))
+
+ def forward(self, input):
+ return self.resnet(input - self.means)
+
+
+# default config for testing
+config = TuneConfig(
+ strategy="evolutionary",
+ num_trials_per_iter=4,
+ max_trials_per_task=8,
+ max_trials_global=16,
+)
+
+if torch.cuda.is_available():
+ target_cuda = "nvidia/geforce-rtx-3070"
+ meta_module_resnet18 = MyResNet18(config, target_cuda)
+ jit_module_resnet18 = JitModule()
+
+
+def compare_optimize_resnet18_to_torchscript():
+ results = []
+ for i in range(20):
+ test_input = torch.rand(1, 3, 224, 224).half().cuda()
+ sub_label = f"[test {i}]"
+ results.append(
+ benchmark.Timer(
+ stmt="meta_module_resnet18(test_input)",
+ setup="from __main__ import meta_module_resnet18",
+ globals={"test_input": test_input},
+ sub_label=sub_label,
+ description="tuning by meta",
+ ).blocked_autorange()
+ )
+ results.append(
+ benchmark.Timer(
+ stmt="jit_module_resnet18(test_input)",
+ setup="from __main__ import jit_module_resnet18",
+ globals={"test_input": test_input},
+ sub_label=sub_label,
+ description="tuning by jit",
+ ).blocked_autorange()
+ )
+ compare = benchmark.Compare(results)
+ compare.print()
+
+
+if __name__ == "__main__":
+ test_matmul_tuning_relay()
+ test_nested_module()
+ test_save_load_function()
+ if torch.cuda.is_available():
+ compare_optimize_resnet18_to_torchscript()
diff --git a/python/tvm/contrib/torch/__init__.py b/python/tvm/contrib/torch/__init__.py
index 720ac29cc6..340f9cef9e 100644
--- a/python/tvm/contrib/torch/__init__.py
+++ b/python/tvm/contrib/torch/__init__.py
@@ -20,7 +20,6 @@ import os
import platform
import torch
from tvm._ffi import libinfo
-from tvm.relay.frontend import pytorch
def _load_platform_specific_library(lib_name="libpt_tvmdsoop"):
@@ -39,6 +38,7 @@ def _load_platform_specific_library(lib_name="libpt_tvmdsoop"):
_load_platform_specific_library()
+
from . import module
GraphModule = module.GraphModule
@@ -49,3 +49,13 @@ from . import pytorch_tvm
PyTorchTVMModule = pytorch_tvm.PyTorchTVMModule
compile = pytorch_tvm.compile
+
+from . import as_torch
+
+TVMScriptIRModule = as_torch.OperatorModuleWrapper
+as_torch = as_torch.as_torch
+
+from . import optimize_torch
+
+GraphExecutorFactoryWrapper = optimize_torch.GraphExecutorFactoryWrapper
+optimize_torch = optimize_torch.optimize_torch
diff --git a/python/tvm/contrib/torch/as_torch.py b/python/tvm/contrib/torch/as_torch.py
new file mode 100644
index 0000000000..3a2b4dda9e
--- /dev/null
+++ b/python/tvm/contrib/torch/as_torch.py
@@ -0,0 +1,124 @@
+# pylint: disable=inconsistent-return-statements
+#!/usr/bin/env python
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-module-docstring
+# pylint: disable=missing-class-docstring
+# pylint: disable=missing-function-docstring
+"""
+as_torch: a decorator, which is used to wrap TVMScript code into a `torch.nn.Module`.
+"""
+import tempfile
+from typing import Callable, List, Union
+
+import torch
+import torch.utils.dlpack
+
+import tvm
+from tvm.meta_schedule.tune import TuneConfig, tune_tir
+from tvm.target.target import Target
+from tvm.tir.schedule.schedule import Schedule
+
+
+# python wrapper for OperatorModule
+class OperatorModuleWrapper(torch.nn.Module):
+ def __init__(
+ self,
+ module: Union[
+ tvm.ir.module.IRModule,
+ tvm.tir.function.PrimFunc,
+ ],
+ ):
+ super().__init__()
+ self.rt_module = None # runtime module
+ self.ir_module = module # IR modules
+
+ def tune(self, config: TuneConfig = None, target: Union[str, Target] = None):
+ """
+ Tune the TVMscript code.
+
+ Parameters
+ ----------
+ config: Optional[TuneConfig]
+ The tuning configuration.
+
+ target : Optional[str, Target]
+ The target to tune for.
+ """
+ if config is None:
+ config = TuneConfig(
+ # Default setting
+ strategy="replay_trace",
+ num_trials_per_iter=32,
+ max_trials_per_task=32,
+ max_trials_global=32,
+ )
+ if target is None:
+ target = Target("llvm --num-cores=16")
+ with tempfile.TemporaryDirectory() as work_dir:
+ sch: Schedule = tune_tir(
+ mod=self.ir_module,
+ target=target,
+ config=config,
+ work_dir=work_dir,
+ )
+ self.ir_module = sch.mod
+ self.build(target)
+
+ def build(self, target=None):
+ runtime_module = tvm.build(self.ir_module, target=target)
+ func = tvm.get_global_func("tvmtorch.save_runtime_mod")
+ func(runtime_module)
+
+ self.rt_module = torch.classes.tvm_torch.OperatorModuleWrapper()
+
+ def forward(self, *torch_inputs: List[torch.Tensor]) -> List[torch.Tensor]:
+ if self.rt_module is None:
+ if torch_inputs[0].is_cuda:
+ self.build(target="cuda")
+ elif torch_inputs[0].device.type == "cpu":
+ self.build()
+ else:
+ raise Exception(f"the target {torch_inputs[0].device.type} is not supported yet")
+
+ return self.rt_module.forward(torch_inputs)
+
+
+def as_torch(func: Union[tvm.ir.module.IRModule, tvm.tir.function.PrimFunc, Callable]):
+ """A decorator of converting TensorIR to PyTorch nn.Module.
+
+ Parameters
+ ----------
+ func: Optional[tvm.ir.module.IRModule, tvm.tir.function.PrimFunc, Callable]
+ The function written by TVMscript.
+
+ Returns
+ -------
+ mod : Union[OperatorModuleWrapper, Callable]
+ It will return an object, or a templated function of OperatorModuleWrapper,
+ which is the subclass of the original nn.Module.
+
+ """
+ if isinstance(func, (tvm.ir.module.IRModule, tvm.tir.function.PrimFunc)):
+ return OperatorModuleWrapper(func)
+ if isinstance(func, Callable):
+
+ def func_get_param(*args, **kargs):
+ return OperatorModuleWrapper(func(*args, **kargs))
+
+ return func_get_param
diff --git a/python/tvm/contrib/torch/optimize_torch.py b/python/tvm/contrib/torch/optimize_torch.py
new file mode 100644
index 0000000000..282e6c5dc8
--- /dev/null
+++ b/python/tvm/contrib/torch/optimize_torch.py
@@ -0,0 +1,198 @@
+# pylint: disable=inconsistent-return-statements
+#!/usr/bin/env python
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-module-docstring
+# pylint: disable=missing-class-docstring
+# pylint: disable=missing-function-docstring
+"""
+optimize_torch: a function similar to `torch.jit.trace`,
+which is used to optimize a `torch.nn.Module` via TVM MetaSchedule,
+and returns a custom TorchScript operator
+"""
+import base64
+import contextlib
+import tempfile
+from typing import Dict, Optional, Tuple, Union
+import warnings
+
+import torch
+import torch.utils.dlpack
+
+import tvm
+from tvm import relay
+from tvm._ffi import get_global_func, register_func
+from tvm.ir.module import IRModule
+from tvm.ir.transform import PassContext
+from tvm.meta_schedule import TuneConfig, default_config
+from tvm.meta_schedule.apply_history_best import ApplyHistoryBest
+from tvm.meta_schedule.relay_integration import extract_task_from_relay
+from tvm.meta_schedule.tune import tune_extracted_tasks
+from tvm.meta_schedule.utils import autotvm_silencer
+from tvm.runtime import vm
+from tvm.runtime.module import Module
+from tvm.runtime.ndarray import NDArray
+from tvm.target.target import Target
+
+
+# The python wrapper for GraphExecutorFactory
+class GraphExecutorFactoryWrapper(torch.nn.Module):
+ def __init__(self, module: tvm.runtime.Module):
+ super().__init__()
+ self.inner_module = module
+
+ def forward(self, *torch_inputs: Tuple[torch.Tensor]):
+ ret = self.inner_module.forward(torch_inputs)
+ if len(ret) == 1:
+ return ret[0]
+ return ret
+
+
+def llvm_target():
+ return "llvm -num-cores"
+
+
+@register_func("script_torch.save_to_base64")
+def save_to_base64(obj) -> bytes:
+ with tempfile.NamedTemporaryFile(suffix=".so") as tmpfile:
+ obj.export_library(tmpfile.name)
+ with open(tmpfile.name, "rb") as tfile:
+ return base64.b64encode(tfile.read())
+
+
+def tune_relay_auto(
+ mod: IRModule,
+ target: Union[str, Target],
+ config: TuneConfig,
+ work_dir: str,
+ backend: str = "graph",
+ params: Optional[Dict[str, NDArray]] = None,
+) -> Union[Module, vm.Executable]:
+ """A wrapper of `tune_relay` but provide a default setting for the config.
+
+ Parameters
+ ----------
+ mod : IRModule
+ The module to tune.
+ target : Union[str, Target]
+ The target to tune for.
+ config : TuneConfig
+ The search strategy config.
+ params : Optional[Dict[str, tvm.runtime.NDArray]]
+ The associated parameters of the program
+ work_dir : Optional[str]
+ The working directory to save intermediate results.
+ backend : str = "graph"
+ The backend to use for relay compilation(graph / vm).
+
+ Returns
+ -------
+ lib : Union[Module, tvm.runtime.vm.Executable]
+ The built runtime module or vm Executable for the given relay workload.
+ """
+ target = default_config.target(target)
+ extracted_tasks = extract_task_from_relay(mod, target, params)
+ if config is None:
+ config = TuneConfig(
+ num_trials_per_iter=16,
+ max_trials_global=16 * len(extracted_tasks),
+ )
+ database = tune_extracted_tasks(extracted_tasks, config, work_dir)
+ relay_build = {"graph": relay.build, "vm": relay.vm.compile}[backend]
+ with target, autotvm_silencer(), ApplyHistoryBest(database):
+ with PassContext(
+ opt_level=3,
+ config={
+ "relay.backend.use_meta_schedule": True,
+ "relay.backend.use_meta_schedule_dispatch": target.kind.name != "cuda",
+ },
+ ):
+ return relay_build(mod, target=target, params=params)
+
+
+def optimize_torch(
+ func,
+ example_inputs,
+ tuning_config=None,
+ target=None,
+ work_dir=None,
+):
+ """Load PyTorch model that could be traced by TorchScript, then optimize it via MetaSchedule.
+
+ Parameters
+ ----------
+ func : callable or torch.nn.Module
+ A Python function or nn.Module that could run by TorchScript's trace.
+ (i.e., torch.jit.trace(model, input))
+
+ example_inputs : tuple or torch.Tensor
+ Inputs to `torch.jit.trace`.
+
+ tuning_config : tvm.meta_schedule.TuneConfig
+ The configuration for tuning by MetaSchedule.
+ If user doesn't set the config, the tuning will run with a default setting.
+ Here, the total number of trials is proportional
+ to the number of tunable tasks in the input module.
+
+ target : Optional[Union[str, Target]]
+ The target of the compilation.
+ If user doesn't set the target, the module will be built for the CPU target.
+
+ work_dir : Optional[str]
+ The working directory to save intermediate results.
+
+ Returns
+ -------
+ mod : GraphExecutorFactoryWrapper
+ It will return an object of GraphExecutorFactoryWrapper,
+ which is the subclass of the original nn.Module.
+ """
+
+ if target is None:
+ target = llvm_target()
+
+ if tuning_config is None:
+ warning_msg = (
+ "Using the default tuning parameters.",
+ "The default number of trials is set to a small value to let tuning finish quickly.",
+ "For optimal performance, it is recommended to provide",
+ "the `tuning_config` argument with a bigger number of trials.",
+ )
+ warnings.warn(" ".join(warning_msg), stacklevel=2)
+
+ # If `func` is already a traced module, this statement has no effect
+ jit_mod = torch.jit.trace(func, example_inputs)
+
+ if isinstance(example_inputs, torch.Tensor):
+ example_inputs = [example_inputs]
+
+ shape_list = [(f"inp_{idx}", i.shape) for idx, i in enumerate(example_inputs)]
+ mod, params = relay.frontend.from_pytorch(jit_mod, shape_list) # IRmodule
+ if work_dir:
+ context_manager = contextlib.nullcontext(work_dir)
+ else:
+ context_manager = tempfile.TemporaryDirectory()
+ with context_manager as work_dir_path:
+ executor_factory = tune_relay_auto(
+ mod=mod, params=params, config=tuning_config, target=target, work_dir=work_dir_path
+ )
+
+ save_runtime_mod = get_global_func("tvmtorch.save_runtime_mod")
+ save_runtime_mod(executor_factory.module)
+
+ return GraphExecutorFactoryWrapper(torch.classes.tvm_torch.GraphExecutorFactoryWrapper())
diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py
index 7f5b3e86f3..908af081c9 100644
--- a/python/tvm/script/parser.py
+++ b/python/tvm/script/parser.py
@@ -435,11 +435,23 @@ class TVMScriptParser(Transformer):
T.evaluate(0) # 4. This function returns 0
"""
+ def check_as_torch_decorator(decorator: Union[ast.Call, ast.Var]):
+ if isinstance(decorator, ast.Call):
+ if len(decorator.params) != 1:
+ return False
+ func_name = decorator.func_name
+ else:
+ func_name = decorator
+ if isinstance(func_name, ast.Var):
+ return func_name.id.name == "as_torch"
+
def check_decorator(decorators: List[ast.Expr]) -> bool:
"""Check the decorator is `T.prim_func"""
- if len(decorators) != 1:
+ if len(decorators) > 2 or len(decorators) == 0:
+ return False
+ if len(decorators) == 2 and not check_as_torch_decorator(decorators[0]):
return False
- d: ast.Expr = decorators[0]
+ d: ast.Expr = decorators[-1]
return (
isinstance(d, ast.Attr)
and isinstance(d.object, ast.Var)
diff --git a/src/contrib/torch/base64.h b/src/contrib/torch/base64.h
new file mode 100644
index 0000000000..859fd1abcf
--- /dev/null
+++ b/src/contrib/torch/base64.h
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file base64.h
+ * \brief Util functions for decoding a base64-encoded string back to plain bytes
+ */
+
+#ifndef TVM_CONTRIB_TORCH_BASE64_H_
+#define TVM_CONTRIB_TORCH_BASE64_H_
+
+#include <tvm/runtime/logging.h>
+
+#include <cctype>
+#include <cstdio>
+#include <string>
+
+#include "../../support/base64.h"
+
+namespace tvm {
+namespace support {
+
+size_t b64strlen(const std::string b64str) {
+ ICHECK(b64str.size() % 4 == 0) << "invalid base64 encoding";
+ size_t length = b64str.size() / 4 * 3;
+ if (b64str[b64str.size() - 2] == '=') {
+ length -= 2;
+ } else if (b64str[b64str.size() - 1] == '=') {
+ length -= 1;
+ }
+ return length;
+}
+
+void b64decode(const std::string b64str, u_char* ret) {
+ size_t index = 0;
+ const auto length = b64str.size();
+ for (size_t i = 0; i < length; i += 4) {
+ int8_t ch0 = base64::DecodeTable[(int32_t)b64str[i]];
+ int8_t ch1 = base64::DecodeTable[(int32_t)b64str[i + 1]];
+ int8_t ch2 = base64::DecodeTable[(int32_t)b64str[i + 2]];
+ int8_t ch3 = base64::DecodeTable[(int32_t)b64str[i + 3]];
+ u_char st1 = (ch0 << 2) + (ch1 >> 4);
+ ret[index++] = st1;
+ if (b64str[i + 2] != '=') {
+ u_char st2 = ((ch1 & 0b1111) << 4) + (ch2 >> 2);
+ ret[index++] = st2;
+ if (b64str[i + 3] != '=') {
+ u_char st3 = ((ch2 & 0b11) << 6) + ch3;
+ ret[index++] = st3;
+ }
+ }
+ }
+ ICHECK(b64strlen(b64str) == index) << "base64 decoding fails";
+}
+
+} // namespace support
+} // namespace tvm
+
+#endif // TVM_CONTRIB_TORCH_BASE64_H_
diff --git a/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc b/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc
new file mode 100644
index 0000000000..12c1017bea
--- /dev/null
+++ b/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc
@@ -0,0 +1,259 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <ATen/DLConvertor.h>
+#include <dlpack/dlpack.h>
+#include <dmlc/memory_io.h>
+#include <torch/custom_class.h>
+#include <torch/script.h>
+#include <tvm/runtime/module.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/target/codegen.h>
+#include <tvm/target/target.h>
+
#include <cstdio>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "../../../runtime/graph_executor/graph_executor_factory.h"
#include "../base64.h"
+
+namespace tvm {
+namespace contrib {
+
+/**
+ * We pass the TVM module by TVM's FFI because Torch's FFI cannot recognize such TVM objects
+ */
+struct ThreadLocalStore {
+ tvm::runtime::Module mod;
+ static ThreadLocalStore* ThreadLocal() {
+ thread_local ThreadLocalStore tls;
+ return &tls;
+ }
+};
+
+using SerializationType = std::string; // base64 stream
+
+SerializationType serialize(tvm::runtime::Module module) {
+ static const runtime::PackedFunc* f_to_str =
+ runtime::Registry::Get("script_torch.save_to_base64");
+ ICHECK(f_to_str) << "IndexError: Cannot find the packed function "
+ "`script_torch.save_to_base64` in the global registry";
+ return (*f_to_str)(module);
+}
+
+struct Deleter { // deleter
+ explicit Deleter(std::string file_name) { this->file_name = file_name; }
+ void operator()(FILE* p) const {
+ fclose(p);
+ ICHECK(remove(file_name.c_str()) == 0)
+ << "Failed to remove temporary file (" << file_name << ")";
+ }
+ std::string file_name;
+};
+
+tvm::runtime::Module deserialize(SerializationType state) {
+ auto length = tvm::support::b64strlen(state);
+
+ std::vector<u_char> bytes(length);
+ tvm::support::b64decode(state, bytes.data());
+
+ const std::string name = tmpnam(NULL);
+ auto file_name = name + ".so";
+ std::unique_ptr<FILE, Deleter> pFile(fopen(file_name.c_str(), "wb"), Deleter(file_name));
+ fwrite(bytes.data(), sizeof(u_char), length, pFile.get());
+ fflush(pFile.get());
+
+ std::string load_f_name = "runtime.module.loadfile_so";
+ const PackedFunc* f = runtime::Registry::Get(load_f_name);
+ ICHECK(f != nullptr) << "Loader for `.so` files is not registered,"
+ << " resolved to (" << load_f_name << ") in the global registry."
+ << "Ensure that you have loaded the correct runtime code, and"
+ << "that you are on the correct hardware architecture.";
+
+ tvm::runtime::Module ret = (*f)(file_name, "");
+
+ return ret;
+}
+
+/**
+ * @brief A Torch's module which wraps TVM's OperatorModule Class.
+ * The basic forward function calling TVM's runtime is provided.
+ * The TVM module can be serialized/deserialized as a Torch module.
+ */
+class OperatorModuleWrapper : public torch::jit::CustomClassHolder {
+ public:
+ OperatorModuleWrapper() { runtime_module = ThreadLocalStore::ThreadLocal()->mod; }
+
+ void forward(const c10::List<at::Tensor>& inputs) {
+ int input_length = inputs.size();
+
+ std::vector<DLManagedTensor*> tensors;
+
+ for (int i = 0; i < input_length; ++i) tensors.push_back(toDLPack(inputs[i]));
+
+ tvm::runtime::PackedFunc run = runtime_module.GetFunction("__tvm_main__");
+
+ std::vector<TVMValue> tvm_values(input_length);
+ std::vector<int> tvm_type_codes(input_length);
+ tvm::runtime::TVMArgsSetter setter(tvm_values.data(), tvm_type_codes.data());
+ for (int k = 0; k < input_length; ++k) {
+ setter(k, &tensors[k]->dl_tensor);
+ }
+
+ run.CallPacked(tvm::runtime::TVMArgs(tvm_values.data(), tvm_type_codes.data(), input_length),
+ nullptr);
+
+ for (int k = 0; k < input_length; ++k) {
+ tensors[k]->deleter(tensors[k]);
+ }
+ }
+
+ SerializationType Serialize() { return serialize(runtime_module); }
+
+ explicit OperatorModuleWrapper(SerializationType state) { runtime_module = deserialize(state); }
+
+ private:
+ tvm::runtime::Module runtime_module;
+};
+
+tvm::Device getDevice(const at::Tensor& tensor) {
+ tvm::Device dev;
+ dev.device_id = tensor.get_device();
+ switch (tensor.device().type()) {
+ case at::DeviceType::CPU:
+ dev.device_type = DLDeviceType::kDLCPU;
+ if (dev.device_id == -1) {
+ /*
+ * In PyTorch the device ID for cpu is -1, sometimes causing error during tuning
+ * Thus we manually set the device ID as 0 for avoiding potentially error of index out of
+ * bounds
+ */
+ dev.device_id = 0;
+ }
+ break;
+ case at::DeviceType::CUDA:
+ dev.device_type = DLDeviceType::kDLCUDA;
+ break;
+ default:
+ TORCH_CHECK(false, "PyTorch TVM integration doesn't support device " + tensor.device().str());
+ }
+ return dev;
+}
+
/**
 * @brief A Torch's module which wraps TVM's GraphExecutorFactory Class.
 * The basic forward function calling TVM's runtime is provided.
 * The TVM module can be serialized/deserialized as a Torch module.
 */
class GraphExecutorFactoryWrapper : public torch::jit::CustomClassHolder {
 public:
  explicit GraphExecutorFactoryWrapper(tvm::runtime::Module executor_factory)
      : executor_factory_(executor_factory) {
    CHECK(executor_factory_->IsInstance<runtime::GraphExecutorFactory>())
        << "module is not an instance of GraphExecutorFactory";
  }

  // Default-constructs from the module most recently stashed in the
  // thread-local store by `tvmtorch.save_runtime_mod`.
  GraphExecutorFactoryWrapper()
      : GraphExecutorFactoryWrapper(ThreadLocalStore::ThreadLocal()->mod) {}

  /**
   * @brief Run the TVM graph executor on the given inputs and return its outputs.
   * @param inputs Non-empty list of input tensors, bound positionally. The
   *  device of the first tensor decides where the executor is instantiated
   *  (first call only; later calls reuse the cached executor).
   * @return Tensors produced by the graph executor's outputs.
   */
  c10::List<at::Tensor> forward(const c10::List<at::Tensor>& inputs) {
    int input_length = inputs.size();

    // Lazily instantiate the executor on first use via the factory's
    // "default" function; subsequent calls reuse it (and its device choice).
    if (!executor_.defined()) {
      TORCH_CHECK(input_length > 0, "Receive empty list of input tensors");
      DLDevice input_device = getDevice(inputs.get(0));

      auto tmp = executor_factory_.GetFunction("default");

      executor_ = tmp(input_device);
    }

    std::vector<DLManagedTensor*> tensors;

    // Borrow each Torch input as a DLPack handle (shares storage, no copy).
    for (int i = 0; i < input_length; ++i) tensors.push_back(toDLPack(inputs[i]));

    tvm::runtime::PackedFunc run = executor_.GetFunction("run");
    tvm::runtime::PackedFunc set_input = executor_.GetFunction("set_input");
    tvm::runtime::PackedFunc get_output = executor_.GetFunction("get_output");
    tvm::runtime::PackedFunc get_num_outputs = executor_.GetFunction("get_num_outputs");

    // Bind inputs by position, then execute the whole graph once.
    for (int k = 0; k < input_length; ++k) {
      set_input(k, &tensors[k]->dl_tensor);
    }

    run();

    int64_t output_length = get_num_outputs();

    c10::List<at::Tensor> outputs;
    outputs.reserve(output_length);

    // Hand each output back to Torch through DLPack; fromDLPack adopts the
    // capsule, which keeps the NDArray's buffer alive for the Torch tensor.
    for (int k = 0; k < output_length; ++k) {
      tvm::runtime::NDArray results = get_output(k);
      at::Tensor atTensor = at::fromDLPack(results.ToDLPack());
      outputs.emplace_back(atTensor);
    }

    // Release the DLPack handles borrowed from the input tensors.
    // NOTE(review): these leak if an exception is thrown above — consider RAII.
    for (int k = 0; k < input_length; ++k) {
      tensors[k]->deleter(tensors[k]);
    }
    return outputs;
  }

  /** @brief Serialize the wrapped factory module to a base64 string (for pickling). */
  SerializationType Serialize() { return serialize(executor_factory_); }

  /** @brief Reconstruct a wrapper from a string produced by Serialize(). */
  explicit GraphExecutorFactoryWrapper(SerializationType state) {
    executor_factory_ = deserialize(state);
  }

 private:
  tvm::runtime::Module executor_factory_;  // compiled GraphExecutorFactory
  tvm::runtime::Module executor_;          // lazily-created per-device executor
};
+
// TVM global entry point used by the Python side to pass a runtime::Module
// into this translation unit: the module is stashed in the thread-local store
// and picked up by the next wrapper default-constructed on this thread.
TVM_REGISTER_GLOBAL("tvmtorch.save_runtime_mod").set_body_typed([](tvm::runtime::Module mod) {
  ThreadLocalStore::ThreadLocal()->mod = mod;
});
+
// Register both wrappers as Torch custom classes under the `tvm_torch`
// namespace. The def_pickle hooks make the wrapped TVM modules picklable
// (hence savable with torch.save / torch.jit.save) via base64 serialization.
TORCH_LIBRARY(tvm_torch, m) {
  m.class_<OperatorModuleWrapper>("OperatorModuleWrapper")
      .def(torch::init<>())
      .def("forward", &OperatorModuleWrapper::forward)
      .def_pickle(
          // __getstate__: serialize the TVM module to a base64 string.
          [](const c10::intrusive_ptr<OperatorModuleWrapper>& self) -> SerializationType {
            return self->Serialize();
          },
          // __setstate__: rebuild the wrapper from the serialized string.
          [](SerializationType state) {
            return c10::make_intrusive<OperatorModuleWrapper>(state);
          });
  m.class_<GraphExecutorFactoryWrapper>("GraphExecutorFactoryWrapper")
      .def(torch::init<>())
      .def("forward", &GraphExecutorFactoryWrapper::forward)
      .def_pickle(
          // __getstate__: serialize the factory module to a base64 string.
          [](const c10::intrusive_ptr<GraphExecutorFactoryWrapper>& self) -> SerializationType {
            return self->Serialize();
          },
          // __setstate__: rebuild the wrapper from the serialized string.
          [](SerializationType state) {
            return c10::make_intrusive<GraphExecutorFactoryWrapper>(state);
          });
}
+
+} // namespace contrib
+} // namespace tvm