You are viewing a plain-text version of this content. The canonical link to it ("here") was a hyperlink that did not survive the plain-text conversion.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/07/26 08:00:51 UTC

[tvm] branch main updated: TVM Vertical Integration with PyTorch (#11911)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ea6ea42757 TVM Vertical Integration with PyTorch (#11911)
ea6ea42757 is described below

commit ea6ea42757f275573391eb3ff67034b2749948ae
Author: Yaoda Zhou <ju...@sjtu.edu.cn>
AuthorDate: Tue Jul 26 16:00:44 2022 +0800

    TVM Vertical Integration with PyTorch (#11911)
    
    * optimize_torch & as_torch
    
    * split files
    
    * code formatting
    
    * optimizing optimize_torch
    
    * scrap your boilerplate
    
    * as_torch polished
    
    * configuration fixed
    
    * Apply suggestions from code review
    
    Co-authored-by: Lite Ye <li...@gmail.com>
    
    * more document
    
    * file deleter
    
    * optimize deleter
    
    * drop how-to guides
    
    * clang-format-10
    
    * formatter changes
    
    * reformat
    
    * reformat
    
    * reformat
    
    * reformatting
    
    * fixed
    
    * auto setting
    
    * fixed
    
    * split long string
    
    * tune_tir
    
    * upgrade as_torch
    
    * optimize as_torch
    
    * as_torch
    
    * fixed typo
    
    Co-authored-by: juda <yz...@octoml.ai>
    Co-authored-by: Lite Ye <li...@gmail.com>
---
 apps/pt_tvmdsoop/tests/test_as_torch.py            | 257 ++++++++++++++++++++
 apps/pt_tvmdsoop/tests/test_optimize_torch.py      | 161 +++++++++++++
 python/tvm/contrib/torch/__init__.py               |  12 +-
 python/tvm/contrib/torch/as_torch.py               | 124 ++++++++++
 python/tvm/contrib/torch/optimize_torch.py         | 198 ++++++++++++++++
 python/tvm/script/parser.py                        |  16 +-
 src/contrib/torch/base64.h                         |  75 ++++++
 .../torch/pt_call_tvm/RuntimeModuleWrapper.cc      | 259 +++++++++++++++++++++
 8 files changed, 1099 insertions(+), 3 deletions(-)

diff --git a/apps/pt_tvmdsoop/tests/test_as_torch.py b/apps/pt_tvmdsoop/tests/test_as_torch.py
new file mode 100644
index 0000000000..2c454e9454
--- /dev/null
+++ b/apps/pt_tvmdsoop/tests/test_as_torch.py
@@ -0,0 +1,257 @@
+#!/usr/bin/env python
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test script for tvm torch module"""
+import numpy as np
+
+import torch
+import torch.nn
+
+import tvm
+from tvm.meta_schedule.tune import TuneConfig
+from tvm.target.target import Target
+import tvm.testing
+from tvm.contrib.torch import as_torch
+from tvm.script import tir as T
+
+
+@as_torch
+def matmul(M: int, N: int, K: int, dtype: str):  # returns a PrimFunc computing C = A @ B^T
+    @T.prim_func
+    def main(a: T.handle, b: T.handle, c: T.handle) -> None:
+        A = T.match_buffer(a, [M, K], dtype=dtype)
+        B = T.match_buffer(b, [N, K], dtype=dtype)  # B is laid out as (N, K)
+        C = T.match_buffer(c, [M, N], dtype=dtype)
+        for i, j, k in T.grid(M, N, K):
+            with T.block():
+                vi, vj, vk = T.axis.remap("SSR", [i, j, k])  # two spatial axes, one reduction
+                with T.init():
+                    C[vi, vj] = T.float32(0)  # NOTE(review): float32 init regardless of `dtype`
+                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+    return main
+
+
+@as_torch
+@tvm.script.ir_module
+class ModuleGPU:  # B = A + 1 over 8 elements: 2 blocks x 2 threads x 2 serial steps
+    @T.prim_func
+    def main(A: T.Buffer[8, "float32"], B: T.Buffer[8, "float32"]) -> None:
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        for i_0 in T.thread_binding(2, thread="blockIdx.x"):
+            for i_2 in T.thread_binding(2, thread="threadIdx.x"):
+                for i_1 in T.serial(2):
+                    with T.block("B"):
+                        vi = T.axis.spatial(8, i_0 * 4 + i_1 * 2 + i_2)  # covers 0..7 once
+                        T.reads(A[vi])
+                        T.writes(B[vi])
+                        B[vi] = A[vi] + T.float32(1)
+
+
+@as_torch
+@T.prim_func
+def func_with_part_access_region(a: T.handle, b: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, [128, 128])  # computes C = (A + 1) + 1 through B
+    B = T.match_buffer(b, [128, 128])
+    C = T.match_buffer(c, [128, 128])
+
+    with T.block():
+        for i, j in T.grid(128, 128):
+            with T.block("s1"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                T.reads(A[vi, vj])  # only the read region is annotated
+                B[vi, vj] = A[vi, vj] + T.float32(1)
+
+        for i, j in T.grid(128, 128):
+            with T.block("s2"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                T.writes(C[vi, vj])  # only the write region is annotated
+                C[vi, vj] = B[vi, vj] + T.float32(1)
+
+
+config = TuneConfig(  # small replay-trace budget used by the tuning tests
+    strategy="replay_trace",
+    num_trials_per_iter=128,
+    max_trials_per_task=128,
+    max_trials_global=128,
+)
+
+
+@as_torch
+@tvm.script.ir_module
+class MyModule:  # B = A + 1 over 8 float32 elements
+    @T.prim_func
+    def main(a: T.handle, b: T.handle):
+        # Data is exchanged between functions through handles, which are similar to pointers.
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # Create buffers from the handles.
+        A = T.match_buffer(a, (8,), dtype="float32")
+        B = T.match_buffer(b, (8,), dtype="float32")
+        for i in range(8):
+            # A block is an abstraction for computation.
+            with T.block("B"):
+                # Define a spatial block iterator and bind it to value i.
+                vi = T.axis.spatial(8, i)
+                B[vi] = A[vi] + 1.0
+
+
+@as_torch
+@T.prim_func
+def loop_split(a: T.handle, b: T.handle) -> None:  # B[i] = sum over k of A[i, k]
+    A = T.match_buffer(a, [128, 128], dtype="float32")
+    B = T.match_buffer(b, [128], dtype="float32")
+    for i, ko in T.grid(128, 4):
+        for ki in T.thread_binding(0, 32, thread="threadIdx.x"):  # GPU-only binding
+            with T.block("B"):
+                vi = T.axis.S(128, i)
+                vk = T.axis.R(128, ko * 32 + ki)  # reduction index split as ko * 32 + ki
+                T.reads([B[vi], A[vi, vk]])
+                T.writes([B[vi]])
+                with T.init():
+                    B[vi] = T.float32(0)
+                B[vi] = B[vi] + A[vi, vk]
+
+
+@as_torch
+def elementwise_with_root(M: int, N: int, dtype: str):  # returns a PrimFunc: C = (A + 1) + 1
+    @T.prim_func
+    def f(a: T.handle, b: T.handle, c: T.handle) -> None:
+        A = T.match_buffer(a, [M, N])  # NOTE(review): the `dtype` argument is unused
+        B = T.match_buffer(b, [M, N])
+        C = T.match_buffer(c, [M, N])
+
+        with T.block():
+            for i, j in T.grid(M, N):
+                with T.block("s1"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    B[vi, vj] = A[vi, vj] + T.float32(1)
+            for i, j in T.grid(M, N):
+                with T.block("s2"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    C[vi, vj] = B[vi, vj] + T.float32(1)
+
+    return f
+
+
+class MinuesOnes(torch.nn.Module):  # NOTE(review): name is a typo of "MinusOnes"
+    def __init__(self):
+        super(MinuesOnes, self).__init__()
+        self.engine = MyModule  # the module-level as_torch wrapper, shared by all instances
+
+    def forward(self, *input):
+        self.engine.forward(*input)  # MyModule writes input[0] + 1 into input[1] in place
+        return input[-1] - 1
+
+
+def test_tvmscript_torch_matmul():
+    s1 = np.random.rand(128, 128).astype("float32")
+    s2 = np.random.rand(128, 128).astype("float32")
+    s3 = np.random.rand(128, 128).astype("float32")
+
+    q1 = torch.from_numpy(s1)
+    q2 = torch.from_numpy(s2)
+    q3 = torch.from_numpy(s3)
+
+    numpy_result = np.matmul(s1, np.transpose(s2))  # the kernel computes A @ B^T
+
+    nn_module = matmul(128, 128, 128, "float32")
+
+    nn_module(q1, q2, q3)  # writes the product into q3 in place
+
+    tvm.testing.assert_allclose(q3.numpy(), numpy_result, atol=1e-5, rtol=1e-5)
+
+
+def test_tvmscript_torch_decorator():
+    q1 = torch.arange(8).type(torch.float32)
+    q2 = torch.zeros((8,), dtype=torch.float32)
+
+    MyModule(q1, q2)  # q2 <- q1 + 1; built for CPU on the first call
+
+    tvm.testing.assert_allclose(q2.numpy(), (q1 + 1).numpy(), atol=1e-5, rtol=1e-5)
+
+
+def test_tvmscript_torch_gpu():
+    cuda0 = torch.device("cuda:0")  # NOTE(review): not skipped when CUDA is unavailable
+    q1 = torch.arange(8, device=cuda0).type(torch.float32)
+    q2 = torch.zeros((8,), dtype=torch.float32, device=cuda0)
+
+    ModuleGPU(q1, q2)  # CUDA inputs select the "cuda" build target
+
+    tvm.testing.assert_allclose(q2.cpu().numpy(), (q1 + 1).cpu().numpy(), atol=1e-5, rtol=1e-5)
+
+
+def test_torch_with_tvmscript():
+    ref_result = np.arange(8).astype("float32")
+
+    q1 = torch.arange(8).type(torch.float32)
+    q2 = torch.zeros((8,), dtype=torch.float32)
+
+    nn_module = MinuesOnes()
+
+    ret = nn_module.forward(q1, q2)  # (q1 + 1) - 1, i.e. q1
+
+    tvm.testing.assert_allclose(ret.numpy(), ref_result, atol=1e-5, rtol=1e-5)
+
+
+def test_tvmscript_torch_func_with_part_access_region():
+    a1 = torch.rand(128, 128)
+    a2 = torch.zeros(128, 128)
+    a3 = torch.zeros(128, 128)
+
+    result = a1 + 2
+
+    func_with_part_access_region.tune()  # default config and default llvm target
+    func_with_part_access_region(a1, a2, a3)
+
+    tvm.testing.assert_allclose(a3.numpy(), result.numpy(), atol=1e-5, rtol=1e-5)
+
+
+def test_tvmscript_torch_loop_split():
+    x = torch.rand(128, 128).cuda()  # NOTE(review): requires CUDA; no skip marker
+    y = torch.zeros(128).cuda()
+
+    result = torch.sum(x.cpu(), dim=1).numpy()
+
+    loop_split.tune(config, Target("nvidia/geforce-rtx-3070"))  # tag pins one GPU model
+    loop_split(x, y)
+
+    tvm.testing.assert_allclose(y.cpu().numpy(), result, atol=1e-5, rtol=1e-5)
+
+
+def test_tvmscript_torch_elementwise_with_root():
+    a1 = torch.rand(128, 128)
+    a2 = torch.zeros(128, 128)
+    a3 = torch.zeros(128, 128)
+
+    result = a1 + 2
+
+    func = elementwise_with_root(128, 128, "float32")  # a fresh wrapper per call
+    func.tune(config)  # default llvm target
+    func(a1, a2, a3)
+
+    tvm.testing.assert_allclose(a3.numpy(), result.numpy(), atol=1e-5, rtol=1e-5)
+
+
+if __name__ == "__main__":  # runs every test, including the CUDA-only ones
+    test_tvmscript_torch_matmul()
+    test_tvmscript_torch_decorator()
+    test_tvmscript_torch_gpu()
+    test_torch_with_tvmscript()
+    test_tvmscript_torch_func_with_part_access_region()
+    test_tvmscript_torch_loop_split()
+    test_tvmscript_torch_elementwise_with_root()
diff --git a/apps/pt_tvmdsoop/tests/test_optimize_torch.py b/apps/pt_tvmdsoop/tests/test_optimize_torch.py
new file mode 100644
index 0000000000..258dfe55c4
--- /dev/null
+++ b/apps/pt_tvmdsoop/tests/test_optimize_torch.py
@@ -0,0 +1,161 @@
+# pylint: disable=missing-class-docstring
+#!/usr/bin/env python
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test script for tvm torch module"""
+import tempfile
+
+import torch
+from torch.utils import benchmark
+from torchvision.models import resnet18
+
+import tvm
+import tvm.testing
+from tvm.contrib.torch import optimize_torch
+from tvm.meta_schedule import TuneConfig
+
+
+def test_matmul_tuning_relay():
+    def matmul(x, w):
+        return torch.matmul(x, w)
+
+    x = torch.randn(15, 20)
+    w = torch.randn(20, 30)
+    example_inputs = (x, w)
+
+    rt_mod = optimize_torch(matmul, example_inputs)  # no tuning_config: default budget + warning
+    torch_answer = torch.matmul(x, w).numpy()
+    tvm_answer = rt_mod(x, w).numpy()
+
+    tvm.testing.assert_allclose(torch_answer, tvm_answer, atol=1e-5, rtol=1e-5)
+
+
+class InnerModel(torch.nn.Module):  # Conv2d(1 -> 20 channels, 5x5) followed by ReLU
+    def __init__(self):
+        super().__init__()
+        self.conv = torch.nn.Conv2d(1, 20, 5)
+
+    def forward(self, x):
+        return torch.nn.functional.relu(self.conv(x))
+
+
+class SimpleModel(torch.nn.Module):  # nests InnerModel to exercise tracing of submodules
+    def __init__(self):
+        super().__init__()
+        self.conv = torch.nn.Conv2d(20, 20, 5)
+        self.relu = InnerModel()  # NOTE(review): named "relu" but is conv + relu
+
+    def forward(self, x):
+        x = self.relu(x)
+        return torch.nn.functional.relu(self.conv(x))
+
+
+def test_nested_module():
+    simple_module = SimpleModel()
+    example_input = torch.randn(20, 1, 10, 10)  # passed as a bare tensor, not a tuple
+    optimized_module = optimize_torch(simple_module, example_input)
+    ret1 = simple_module(example_input).detach().numpy()
+    ret2 = optimized_module(example_input).detach().numpy()
+    tvm.testing.assert_allclose(ret1, ret2, atol=1e-5, rtol=1e-5)
+
+
+def test_save_load_function():
+    def foo(x):
+        return 2 * x + 1
+
+    example_input = torch.rand(3)
+    opt_foo = optimize_torch(foo, example_input)
+    ret1 = opt_foo(example_input)
+    with tempfile.NamedTemporaryFile(suffix=".pt") as tmp:
+        torch.save(opt_foo, tmp.name)  # round-trip through torch.save / torch.load
+        loaded_mod = torch.load(tmp.name)
+        ret2 = loaded_mod(example_input)
+    tvm.testing.assert_allclose(ret1.numpy(), ret2.numpy(), atol=1e-5, rtol=1e-5)
+
+
+class MyResNet18(torch.nn.Module):
+    def __init__(self, config, target=None):
+        super(MyResNet18, self).__init__()
+        self.means = torch.nn.Parameter(
+            torch.tensor([103.939, 116.779, 123.68]).resize_(1, 3, 1, 1)
+        ).cuda()
+        self.resnet = optimize_torch(resnet18(), [torch.rand(1, 3, 224, 224)], config, target)
+
+    def forward(self, input):
+        return self.resnet(input - self.means)
+
+
+class JitModule(torch.nn.Module):
+    def __init__(self):
+        super(JitModule, self).__init__()
+        self.means = torch.nn.Parameter(
+            torch.tensor([103.939, 116.779, 123.68]).resize_(1, 3, 1, 1)
+        ).cuda()
+        self.resnet = torch.jit.optimize_for_inference(torch.jit.script(resnet18().cuda().eval()))
+
+    def forward(self, input):
+        return self.resnet(input - self.means)
+
+
+# default config for testing: a tiny budget so tuning finishes quickly
+config = TuneConfig(
+    strategy="evolutionary",
+    num_trials_per_iter=4,
+    max_trials_per_task=8,
+    max_trials_global=16,
+)
+
+if torch.cuda.is_available():  # NOTE(review): tunes ResNet-18 at import time
+    target_cuda = "nvidia/geforce-rtx-3070"  # target tag pins one GPU model
+    meta_module_resnet18 = MyResNet18(config, target_cuda)
+    jit_module_resnet18 = JitModule()
+
+
+def compare_optimize_resnet18_to_torchscript():
+    results = []
+    for i in range(20):  # 20 random inputs, each timed on both modules
+        test_input = torch.rand(1, 3, 224, 224).half().cuda()
+        sub_label = f"[test {i}]"
+        results.append(
+            benchmark.Timer(
+                stmt="meta_module_resnet18(test_input)",
+                setup="from __main__ import meta_module_resnet18",  # requires running as a script
+                globals={"test_input": test_input},
+                sub_label=sub_label,
+                description="tuning by meta",
+            ).blocked_autorange()
+        )
+        results.append(
+            benchmark.Timer(
+                stmt="jit_module_resnet18(test_input)",
+                setup="from __main__ import jit_module_resnet18",
+                globals={"test_input": test_input},
+                sub_label=sub_label,
+                description="tuning by jit",
+            ).blocked_autorange()
+        )
+    compare = benchmark.Compare(results)
+    compare.print()
+
+
+if __name__ == "__main__":
+    test_matmul_tuning_relay()
+    test_nested_module()
+    test_save_load_function()
+    if torch.cuda.is_available():  # the benchmark needs the module-level CUDA models
+        compare_optimize_resnet18_to_torchscript()
diff --git a/python/tvm/contrib/torch/__init__.py b/python/tvm/contrib/torch/__init__.py
index 720ac29cc6..340f9cef9e 100644
--- a/python/tvm/contrib/torch/__init__.py
+++ b/python/tvm/contrib/torch/__init__.py
@@ -20,7 +20,6 @@ import os
 import platform
 import torch
 from tvm._ffi import libinfo
-from tvm.relay.frontend import pytorch
 
 
 def _load_platform_specific_library(lib_name="libpt_tvmdsoop"):
@@ -39,6 +38,7 @@ def _load_platform_specific_library(lib_name="libpt_tvmdsoop"):
 
 _load_platform_specific_library()
 
+
 from . import module
 
 GraphModule = module.GraphModule
@@ -49,3 +49,13 @@ from . import pytorch_tvm
 
 PyTorchTVMModule = pytorch_tvm.PyTorchTVMModule
 compile = pytorch_tvm.compile
+
+from . import as_torch
+
+TVMScriptIRModule = as_torch.OperatorModuleWrapper
+as_torch = as_torch.as_torch
+
+from . import optimize_torch
+
+GraphExecutorFactoryWrapper = optimize_torch.GraphExecutorFactoryWrapper
+optimize_torch = optimize_torch.optimize_torch
diff --git a/python/tvm/contrib/torch/as_torch.py b/python/tvm/contrib/torch/as_torch.py
new file mode 100644
index 0000000000..3a2b4dda9e
--- /dev/null
+++ b/python/tvm/contrib/torch/as_torch.py
@@ -0,0 +1,124 @@
+# pylint: disable=inconsistent-return-statements
+#!/usr/bin/env python
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-module-docstring
+# pylint: disable=missing-class-docstring
+# pylint: disable=missing-function-docstring
+"""
+as_torch: a decorator, which is used to wrap the TVMscript code to `torch.nn.module`.
+"""
+import tempfile
+from typing import Callable, List, Union
+
+import torch
+import torch.utils.dlpack
+
+import tvm
+from tvm.meta_schedule.tune import TuneConfig, tune_tir
+from tvm.target.target import Target
+from tvm.tir.schedule.schedule import Schedule
+
+
+# nn.Module wrapper that compiles a TVMScript IRModule/PrimFunc and runs it on torch tensors
+class OperatorModuleWrapper(torch.nn.Module):
+    def __init__(
+        self,
+        module: Union[
+            tvm.ir.module.IRModule,
+            tvm.tir.function.PrimFunc,
+        ],
+    ):
+        super().__init__()
+        self.rt_module = None  # compiled module; built by tune() or lazily on first forward()
+        self.ir_module = module  # TVMScript IRModule/PrimFunc; replaced by the tuned one in tune()
+
+    def tune(self, config: TuneConfig = None, target: Union[str, Target] = None):
+        """
+        Tune the TVMScript code with MetaSchedule, then build the tuned module.
+
+        Parameters
+        ----------
+        config: Optional[TuneConfig]
+            The tuning configuration; defaults to 32 replay-trace trials.
+
+        target : Optional[Union[str, Target]]
+            The target to tune for; defaults to `llvm --num-cores=16`.
+        """
+        if config is None:
+            config = TuneConfig(
+                # Default setting
+                strategy="replay_trace",
+                num_trials_per_iter=32,
+                max_trials_per_task=32,
+                max_trials_global=32,
+            )
+        if target is None:
+            target = Target("llvm --num-cores=16")  # NOTE(review): fixed core count
+        with tempfile.TemporaryDirectory() as work_dir:
+            sch: Schedule = tune_tir(
+                mod=self.ir_module,
+                target=target,
+                config=config,
+                work_dir=work_dir,
+            )
+            self.ir_module = sch.mod
+            self.build(target)
+
+    def build(self, target=None):
+        runtime_module = tvm.build(self.ir_module, target=target)
+        func = tvm.get_global_func("tvmtorch.save_runtime_mod")  # hand-off via TVM's FFI
+        func(runtime_module)  # presumably read by the C++ wrapper constructed below
+
+        self.rt_module = torch.classes.tvm_torch.OperatorModuleWrapper()
+
+    def forward(self, *torch_inputs: torch.Tensor) -> List[torch.Tensor]:
+        if self.rt_module is None:
+            if torch_inputs[0].is_cuda:  # the first input's device picks the build target
+                self.build(target="cuda")
+            elif torch_inputs[0].device.type == "cpu":
+                self.build()
+            else:
+                raise Exception(f"the target {torch_inputs[0].device.type} is not supported yet")
+
+        return self.rt_module.forward(torch_inputs)
+
+
+def as_torch(func: Union[tvm.ir.module.IRModule, tvm.tir.function.PrimFunc, Callable]):
+    """A decorator of converting TensorIR to PyTorch nn.Module.
+
+    Parameters
+    ----------
+    func: Optional[tvm.ir.module.IRModule, tvm.tir.function.PrimFunc, Callable]
+        The function written by TVMscript.
+
+    Returns
+    -------
+    mod : Union[OperatorModuleWrapper, Callable]
+        It will return an object, or a templated function of OperatorModuleWrapper,
+        which is the subclass of the original nn.Module.
+
+    """
+    if isinstance(func, (tvm.ir.module.IRModule, tvm.tir.function.PrimFunc)):
+        return OperatorModuleWrapper(func)
+    if isinstance(func, Callable):
+
+        def func_get_param(*args, **kargs):
+            return OperatorModuleWrapper(func(*args, **kargs))
+
+        return func_get_param
diff --git a/python/tvm/contrib/torch/optimize_torch.py b/python/tvm/contrib/torch/optimize_torch.py
new file mode 100644
index 0000000000..282e6c5dc8
--- /dev/null
+++ b/python/tvm/contrib/torch/optimize_torch.py
@@ -0,0 +1,198 @@
+# pylint: disable=inconsistent-return-statements
+#!/usr/bin/env python
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-module-docstring
+# pylint: disable=missing-class-docstring
+# pylint: disable=missing-function-docstring
+"""
+optimize_torch: a function similar to `torch.jit.trace`,
+which is used to optimize the `torch.nn.module` by TVM metaSchedule,
+and returns a custom TorchScript operator
+"""
+import base64
+import contextlib
+import tempfile
+from typing import Dict, Optional, Tuple, Union
+import warnings
+
+import torch
+import torch.utils.dlpack
+
+import tvm
+from tvm import relay
+from tvm._ffi import get_global_func, register_func
+from tvm.ir.module import IRModule
+from tvm.ir.transform import PassContext
+from tvm.meta_schedule import TuneConfig, default_config
+from tvm.meta_schedule.apply_history_best import ApplyHistoryBest
+from tvm.meta_schedule.relay_integration import extract_task_from_relay
+from tvm.meta_schedule.tune import tune_extracted_tasks
+from tvm.meta_schedule.utils import autotvm_silencer
+from tvm.runtime import vm
+from tvm.runtime.module import Module
+from tvm.runtime.ndarray import NDArray
+from tvm.target.target import Target
+
+
+# nn.Module facade over torch.classes.tvm_torch.GraphExecutorFactoryWrapper
+class GraphExecutorFactoryWrapper(torch.nn.Module):
+    def __init__(self, module: tvm.runtime.Module):  # NOTE(review): callers pass a torch class
+        super().__init__()
+        self.inner_module = module
+
+    def forward(self, *torch_inputs: torch.Tensor):
+        ret = self.inner_module.forward(torch_inputs)
+        if len(ret) == 1:  # unwrap a single output
+            return ret[0]
+        return ret
+
+
+def llvm_target():  # default CPU target used by optimize_torch
+    return "llvm -num-cores"  # NOTE(review): `-num-cores` has no value; confirm TVM fills it in
+
+
+@register_func("script_torch.save_to_base64")  # looked up by the C++ serialize()
+def save_to_base64(obj) -> bytes:  # base64 of `obj` exported as a shared library
+    with tempfile.NamedTemporaryFile(suffix=".so") as tmpfile:
+        obj.export_library(tmpfile.name)  # NOTE(review): reopening by name is POSIX-only
+        with open(tmpfile.name, "rb") as tfile:
+            return base64.b64encode(tfile.read())
+
+
+def tune_relay_auto(
+    mod: IRModule,
+    target: Union[str, Target],
+    config: Optional[TuneConfig],
+    work_dir: str,
+    backend: str = "graph",
+    params: Optional[Dict[str, NDArray]] = None,
+) -> Union[Module, vm.Executable]:
+    """A wrapper of `tune_relay` that provides a default setting for the config.
+
+    Parameters
+    ----------
+    mod : IRModule
+        The module to tune.
+    target : Union[str, Target]
+        The target to tune for.
+    config : Optional[TuneConfig]
+        The search strategy config. If None, a budget of 16 trials per extracted task is used.
+    work_dir : str
+        The working directory to save intermediate results.
+    backend : str = "graph"
+        The backend to use for relay compilation ("graph" or "vm").
+    params : Optional[Dict[str, tvm.runtime.NDArray]]
+        The associated parameters of the program.
+
+    Returns
+    -------
+    lib : Union[Module, tvm.runtime.vm.Executable]
+        The built runtime module or vm Executable for the given relay workload.
+    """
+    target = default_config.target(target)
+    extracted_tasks = extract_task_from_relay(mod, target, params)
+    if config is None:  # default budget scales with the task count
+        config = TuneConfig(
+            num_trials_per_iter=16,
+            max_trials_global=16 * len(extracted_tasks),
+        )
+    database = tune_extracted_tasks(extracted_tasks, config, work_dir)
+    relay_build = {"graph": relay.build, "vm": relay.vm.compile}[backend]
+    with target, autotvm_silencer(), ApplyHistoryBest(database):
+        with PassContext(
+            opt_level=3,
+            config={
+                "relay.backend.use_meta_schedule": True,
+                "relay.backend.use_meta_schedule_dispatch": target.kind.name != "cuda",
+            },
+        ):
+            return relay_build(mod, target=target, params=params)
+
+
+def optimize_torch(
+    func,
+    example_inputs,
+    tuning_config: Optional[TuneConfig] = None,
+    target: Optional[Union[str, Target]] = None,
+    work_dir: Optional[str] = None,
+):
+    """Trace a PyTorch model with TorchScript, then optimize and build it with MetaSchedule.
+
+    Parameters
+    ----------
+    func : callable or torch.nn.Module
+        A Python function or nn.Module that can be traced by TorchScript
+        (i.e. accepted by torch.jit.trace(model, input)).
+
+    example_inputs : tuple or torch.Tensor
+        Inputs to `torch.jit.trace`; their shapes become the input shapes of the Relay module.
+
+    tuning_config : Optional[tvm.meta_schedule.TuneConfig]
+        The configuration for tuning by MetaSchedule.
+        If None, a warning is emitted and tuning runs with a small default setting,
+        where the total number of trials is proportional
+        to the number of tunable tasks in the input module.
+
+    target : Optional[Union[str, Target]]
+        The target of the compilation.
+        If None, `llvm_target()` is used, i.e. the module is built for the CPU.
+
+    work_dir : Optional[str]
+        The working directory to save intermediate results; if None, a temporary one is used.
+
+    Returns
+    -------
+    mod : GraphExecutorFactoryWrapper
+        A GraphExecutorFactoryWrapper (a torch.nn.Module subclass) wrapping
+        a `torch.classes.tvm_torch.GraphExecutorFactoryWrapper` instance.
+    """
+
+    if target is None:
+        target = llvm_target()
+
+    if tuning_config is None:
+        warning_msg = (
+            "Using the default tuning parameters.",
+            "The default number of trials is set to a small value to let tuning finish quickly.",
+            "For optimal performance, it is recommended to provide",
+            "the `tuning_config` argument with a bigger number of trials.",
+        )
+        warnings.warn(" ".join(warning_msg), stacklevel=2)
+
+    # If `func` is already a traced module this statement has no effect
+    jit_mod = torch.jit.trace(func, example_inputs)
+
+    if isinstance(example_inputs, torch.Tensor):
+        example_inputs = [example_inputs]
+
+    shape_list = [(f"inp_{idx}", i.shape) for idx, i in enumerate(example_inputs)]
+    mod, params = relay.frontend.from_pytorch(jit_mod, shape_list)  # Relay IRModule + params
+    if work_dir:  # a user-provided directory is kept; otherwise a temporary one is deleted
+        context_manager = contextlib.nullcontext(work_dir)
+    else:
+        context_manager = tempfile.TemporaryDirectory()
+    with context_manager as work_dir_path:
+        executor_factory = tune_relay_auto(
+            mod=mod, params=params, config=tuning_config, target=target, work_dir=work_dir_path
+        )
+
+    save_runtime_mod = get_global_func("tvmtorch.save_runtime_mod")
+    save_runtime_mod(executor_factory.module)
+
+    return GraphExecutorFactoryWrapper(torch.classes.tvm_torch.GraphExecutorFactoryWrapper())
diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py
index 7f5b3e86f3..908af081c9 100644
--- a/python/tvm/script/parser.py
+++ b/python/tvm/script/parser.py
@@ -435,11 +435,23 @@ class TVMScriptParser(Transformer):
                 T.evaluate(0)  # 4. This function returns 0
         """
 
+        def check_as_torch_decorator(decorator: Union[ast.Call, ast.Var]):
+            if isinstance(decorator, ast.Call):
+                if len(decorator.params) != 1:
+                    return False
+                func_name = decorator.func_name
+            else:
+                func_name = decorator
+            if isinstance(func_name, ast.Var):
+                return func_name.id.name == "as_torch"
+
         def check_decorator(decorators: List[ast.Expr]) -> bool:
             """Check the decorator is `T.prim_func"""
-            if len(decorators) != 1:
+            if len(decorators) > 2 or len(decorators) == 0:
+                return False
+            if len(decorators) == 2 and not check_as_torch_decorator(decorators[0]):
                 return False
-            d: ast.Expr = decorators[0]
+            d: ast.Expr = decorators[-1]
             return (
                 isinstance(d, ast.Attr)
                 and isinstance(d.object, ast.Var)
diff --git a/src/contrib/torch/base64.h b/src/contrib/torch/base64.h
new file mode 100644
index 0000000000..859fd1abcf
--- /dev/null
+++ b/src/contrib/torch/base64.h
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file base64.h
+ * \brief Util functions for converting base64-encoded strings back to plain bytes
+ */
+
+#ifndef TVM_CONTRIB_TORCH_BASE64_H_
+#define TVM_CONTRIB_TORCH_BASE64_H_
+
+#include <tvm/runtime/logging.h>
+
+#include <cctype>
+#include <cstdio>
+#include <string>
+
+#include "../../support/base64.h"
+
+namespace tvm {
+namespace support {
+
+inline size_t b64strlen(const std::string& b64str) {  // decoded size of padded base64 text
+  ICHECK(b64str.size() % 4 == 0) << "invalid base64 encoding";
+  size_t length = b64str.size() / 4 * 3;  // 3 bytes per 4-character group
+  if (length != 0 && b64str[b64str.size() - 2] == '=') {  // empty input has no padding
+    length -= 2;
+  } else if (length != 0 && b64str[b64str.size() - 1] == '=') {
+    length -= 1;
+  }
+  return length;
+}
+
+inline void b64decode(const std::string& b64str, unsigned char* ret) {
+  auto decode = [](char c) -> int {  // validate before indexing: `char` may be signed
+    ICHECK(('A' <= c && c <= 'Z') || ('a' <= c && c <= 'z') || ('0' <= c && c <= '9') ||
+           c == '+' || c == '/' || c == '=')
+        << "invalid base64 character";
+    return base64::DecodeTable[static_cast<unsigned char>(c)];
+  };
+  size_t index = 0;  // bytes written; `ret` must hold b64strlen(b64str) bytes
+  for (size_t i = 0; i + 3 < b64str.size(); i += 4) {  // a partial tail fails the check below
+    const int ch0 = decode(b64str[i]), ch1 = decode(b64str[i + 1]);
+    const int ch2 = decode(b64str[i + 2]), ch3 = decode(b64str[i + 3]);
+    ret[index++] = static_cast<unsigned char>((ch0 << 2) + (ch1 >> 4));
+    if (b64str[i + 2] != '=') {
+      ret[index++] = static_cast<unsigned char>(((ch1 & 0b1111) << 4) + (ch2 >> 2));
+      if (b64str[i + 3] != '=') {
+        ret[index++] = static_cast<unsigned char>(((ch2 & 0b11) << 6) + ch3);
+      }
+    }
+  }
+  ICHECK(b64strlen(b64str) == index) << "base64 decoding fails";  // also rejects size % 4 != 0
+}
+
+}  // namespace support
+}  // namespace tvm
+
+#endif  // TVM_CONTRIB_TORCH_BASE64_H_
diff --git a/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc b/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc
new file mode 100644
index 0000000000..12c1017bea
--- /dev/null
+++ b/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc
@@ -0,0 +1,259 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <ATen/DLConvertor.h>
+#include <dlpack/dlpack.h>
+#include <dmlc/memory_io.h>
+#include <torch/custom_class.h>
+#include <torch/script.h>
+#include <tvm/runtime/module.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/target/codegen.h>
+#include <tvm/target/target.h>
+#include <unistd.h>
+
+#include <cstdio>
+#include <cstdlib>
+#include <map>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "../../../runtime/graph_executor/graph_executor_factory.h"
+#include "../base64.h"
+
+namespace tvm {
+namespace contrib {
+
+/**
+ * Per-thread slot through which a TVM runtime module is handed to the default constructors
+ * of the Torch wrapper classes. The module travels through TVM's FFI because Torch's FFI
+ * cannot recognize such TVM objects.
+ */
+struct ThreadLocalStore {
+  /*! \brief Returns the calling thread's store instance. */
+  static ThreadLocalStore* ThreadLocal() {
+    thread_local ThreadLocalStore store;
+    return &store;
+  }
+
+  /*! \brief The module most recently stored on this thread. */
+  tvm::runtime::Module mod;
+};
+
+using SerializationType = std::string;  // base64 stream
+
+/*!
+ * \brief Encode a runtime module as base64 text.
+ *
+ * Delegates to the `script_torch.save_to_base64` packed function, which must be present in
+ * the global registry (the lookup result is cached on first use).
+ */
+SerializationType serialize(tvm::runtime::Module module) {
+  static const runtime::PackedFunc* const save_to_base64 =
+      runtime::Registry::Get("script_torch.save_to_base64");
+  ICHECK(save_to_base64 != nullptr) << "IndexError: Cannot find the packed function "
+                                       "`script_torch.save_to_base64` in the global registry";
+  SerializationType encoded = (*save_to_base64)(module);
+  return encoded;
+}
+
+/*!
+ * \brief unique_ptr deleter that closes a temporary FILE and then removes it from disk.
+ */
+struct Deleter {  // deleter
+  explicit Deleter(std::string file_name) : file_name(std::move(file_name)) {}
+  void operator()(FILE* p) const {
+    fclose(p);
+    // This runs inside ~unique_ptr, which is noexcept. ICHECK throws on failure, and a throw
+    // here would call std::terminate, so report the leftover file instead of crashing.
+    if (remove(file_name.c_str()) != 0) {
+      LOG(WARNING) << "Failed to  remove temporary file (" << file_name << ")";
+    }
+  }
+  /*! \brief Path of the file to delete after closing. */
+  std::string file_name;
+};
+
+/*!
+ * \brief Rebuild a runtime module from the base64 text produced by serialize().
+ *
+ * The decoded bytes are written to a newly created temporary `.so` file. That file is then
+ * loaded with the `runtime.module.loadfile_so` packed function. Deleter closes and removes
+ * the file when this function returns or unwinds.
+ *
+ * \param state Base64-encoded shared library.
+ * \return The loaded runtime module.
+ */
+tvm::runtime::Module deserialize(const SerializationType& state) {
+  // Resolve the loader first so a missing loader fails before any file is written.
+  const std::string load_f_name = "runtime.module.loadfile_so";
+  const PackedFunc* f = runtime::Registry::Get(load_f_name);
+  ICHECK(f != nullptr) << "Loader for `.so` files is not registered,"
+                       << " resolved to (" << load_f_name << ") in the global registry. "
+                       << "Ensure that you have loaded the correct runtime code, and "
+                       << "that you are on the correct hardware architecture.";
+
+  const auto length = tvm::support::b64strlen(state);
+  std::vector<u_char> bytes(length);
+  tvm::support::b64decode(state, bytes.data());
+
+  // mkstemps creates and opens a uniquely named file in one step. tmpnam only produced a
+  // predictable name, which left a window in which another process could create that path
+  // first and substitute the library that is about to be loaded. The trailing ".so"
+  // (3 characters) is excluded from the randomized part of the template.
+  const char* tmp_dir = std::getenv("TMPDIR");
+  const std::string name_template =
+      std::string(tmp_dir != nullptr && tmp_dir[0] != '\0' ? tmp_dir : "/tmp") +
+      "/tvm_torch_XXXXXX.so";
+  std::vector<char> name_buffer(name_template.begin(), name_template.end());
+  name_buffer.push_back('\0');
+  const int fd = ::mkstemps(name_buffer.data(), 3);
+  ICHECK(fd != -1) << "Failed to create temporary file (" << name_template << ")";
+  const std::string file_name(name_buffer.data());
+
+  FILE* raw_file = ::fdopen(fd, "wb");
+  if (raw_file == nullptr) {
+    ::close(fd);
+    remove(file_name.c_str());
+  }
+  ICHECK(raw_file != nullptr) << "Failed to open temporary file (" << file_name << ")";
+  std::unique_ptr<FILE, Deleter> pFile(raw_file, Deleter(file_name));
+
+  // A short write or failed flush would hand a truncated library to the loader.
+  const size_t written = fwrite(bytes.data(), sizeof(u_char), length, pFile.get());
+  ICHECK(written == length && fflush(pFile.get()) == 0)
+      << "Failed to write temporary file (" << file_name << ")";
+
+  tvm::runtime::Module ret = (*f)(file_name, "");
+
+  return ret;
+}
+
+/**
+ * @brief A Torch's module which wraps TVM's OperatorModule Class.
+ * The basic forward function calling TVM's runtime is provided.
+ * The TVM module can be serialized/deserialized as a Torch module.
+ */
+class OperatorModuleWrapper : public torch::jit::CustomClassHolder {
+ public:
+  /*! \brief Wrap the module most recently stored in this thread's ThreadLocalStore. */
+  OperatorModuleWrapper() : runtime_module(ThreadLocalStore::ThreadLocal()->mod) {}
+
+  /*!
+   * \brief Call the module's `__tvm_main__` function with every tensor passed as a DLTensor.
+   * \param inputs Tensors passed positionally. Nothing is returned, so results presumably
+   *        land in buffers among `inputs` — NOTE(review): confirm with callers.
+   */
+  void forward(const c10::List<at::Tensor>& inputs) {
+    const int input_length = static_cast<int>(inputs.size());
+
+    // Own each DLPack view so its deleter runs even if the call below throws. Otherwise the
+    // views, which presumably keep the source tensors alive, would leak.
+    std::vector<std::unique_ptr<DLManagedTensor, DLManagedTensorDeleter>> tensors;
+    tensors.reserve(input_length);
+    for (int i = 0; i < input_length; ++i) tensors.emplace_back(toDLPack(inputs[i]));
+
+    tvm::runtime::PackedFunc run = runtime_module.GetFunction("__tvm_main__");
+    // GetFunction yields a null function when the name is absent; calling it would crash.
+    ICHECK(run != nullptr) << "Cannot find function `__tvm_main__` in the runtime module";
+
+    std::vector<TVMValue> tvm_values(input_length);
+    std::vector<int> tvm_type_codes(input_length);
+    tvm::runtime::TVMArgsSetter setter(tvm_values.data(), tvm_type_codes.data());
+    for (int k = 0; k < input_length; ++k) {
+      setter(k, &tensors[k]->dl_tensor);
+    }
+
+    run.CallPacked(tvm::runtime::TVMArgs(tvm_values.data(), tvm_type_codes.data(), input_length),
+                   nullptr);
+  }
+
+  /*! \brief Base64-encode the wrapped module (used as the pickle state). */
+  SerializationType Serialize() { return serialize(runtime_module); }
+
+  /*! \brief Rebuild the wrapper from a state produced by Serialize(). */
+  explicit OperatorModuleWrapper(SerializationType state)
+      : runtime_module(deserialize(state)) {}
+
+ private:
+  /*! \brief Releases a DLManagedTensor through its own deleter callback, if it has one. */
+  struct DLManagedTensorDeleter {
+    void operator()(DLManagedTensor* tensor) const {
+      if (tensor->deleter != nullptr) tensor->deleter(tensor);
+    }
+  };
+
+  tvm::runtime::Module runtime_module;
+};
+
+/*!
+ * \brief Translate the device of a PyTorch tensor into the matching TVM device.
+ *
+ * Only CPU and CUDA tensors are accepted; any other device type fails through TORCH_CHECK.
+ */
+tvm::Device getDevice(const at::Tensor& tensor) {
+  tvm::Device dev;
+  dev.device_id = tensor.get_device();
+  const auto torch_device_type = tensor.device().type();
+  if (torch_device_type == at::DeviceType::CPU) {
+    dev.device_type = DLDeviceType::kDLCPU;
+    // PyTorch reports device ID -1 for CPU, which sometimes caused index-out-of-bounds
+    // errors during tuning, so normalize it to 0.
+    if (dev.device_id == -1) dev.device_id = 0;
+  } else if (torch_device_type == at::DeviceType::CUDA) {
+    dev.device_type = DLDeviceType::kDLCUDA;
+  } else {
+    TORCH_CHECK(false, "PyTorch TVM integration doesn't support device " + tensor.device().str());
+  }
+  return dev;
+}
+
+/**
+ * @brief A Torch's module which wraps TVM's GraphExecutorFactory Class.
+ * The basic forward function calling TVM's runtime is provided.
+ * The TVM module can be serialized/deserialized as a Torch module.
+ */
+class GraphExecutorFactoryWrapper : public torch::jit::CustomClassHolder {
+ public:
+  /*!
+   * \brief Wrap a graph executor factory module.
+   * \param executor_factory Must be a defined GraphExecutorFactory module.
+   */
+  explicit GraphExecutorFactoryWrapper(tvm::runtime::Module executor_factory)
+      : executor_factory_(executor_factory) {
+    // Check definedness first: IsInstance dereferences the underlying object.
+    CHECK(executor_factory_.defined()) << "module is not defined";
+    CHECK(executor_factory_->IsInstance<runtime::GraphExecutorFactory>())
+        << "module is not an instance of GraphExecutorFactory";
+  }
+
+  /*! \brief Wrap the module most recently stored in this thread's ThreadLocalStore. */
+  GraphExecutorFactoryWrapper()
+      : GraphExecutorFactoryWrapper(ThreadLocalStore::ThreadLocal()->mod) {}
+
+  /*!
+   * \brief Run the graph on `inputs` and return every graph output.
+   *
+   * The executor is created lazily on the first call, on the device of the first input.
+   * \param inputs Graph inputs, bound positionally with `set_input`.
+   * \return One tensor per graph output.
+   */
+  c10::List<at::Tensor> forward(const c10::List<at::Tensor>& inputs) {
+    const int input_length = static_cast<int>(inputs.size());
+
+    if (!executor_.defined()) {
+      TORCH_CHECK(input_length > 0, "Receive empty list of input tensors");
+      DLDevice input_device = getDevice(inputs.get(0));
+
+      auto tmp = executor_factory_.GetFunction("default");
+
+      executor_ = tmp(input_device);
+    }
+
+    // Owning the DLPack views ensures their deleters run even if set_input/run throws.
+    std::vector<std::unique_ptr<DLManagedTensor, DLManagedTensorDeleter>> tensors;
+    tensors.reserve(input_length);
+    for (int i = 0; i < input_length; ++i) tensors.emplace_back(toDLPack(inputs[i]));
+
+    tvm::runtime::PackedFunc run = executor_.GetFunction("run");
+    tvm::runtime::PackedFunc set_input = executor_.GetFunction("set_input");
+    tvm::runtime::PackedFunc get_output = executor_.GetFunction("get_output");
+    tvm::runtime::PackedFunc get_num_outputs = executor_.GetFunction("get_num_outputs");
+
+    for (int k = 0; k < input_length; ++k) {
+      set_input(k, &tensors[k]->dl_tensor);
+    }
+
+    run();
+
+    const int64_t output_length = get_num_outputs();
+
+    c10::List<at::Tensor> outputs;
+    outputs.reserve(output_length);
+
+    for (int64_t k = 0; k < output_length; ++k) {
+      tvm::runtime::NDArray results = get_output(k);
+      outputs.emplace_back(at::fromDLPack(results.ToDLPack()));
+    }
+
+    return outputs;
+  }
+
+  /*! \brief Base64-encode the wrapped factory module (used as the pickle state). */
+  SerializationType Serialize() { return serialize(executor_factory_); }
+
+  /*!
+   * \brief Rebuild the wrapper from a state produced by Serialize().
+   *
+   * Delegates to the module constructor so deserialized modules get the same type check.
+   */
+  explicit GraphExecutorFactoryWrapper(SerializationType state)
+      : GraphExecutorFactoryWrapper(deserialize(state)) {}
+
+ private:
+  /*! \brief Releases a DLManagedTensor through its own deleter callback, if it has one. */
+  struct DLManagedTensorDeleter {
+    void operator()(DLManagedTensor* tensor) const {
+      if (tensor->deleter != nullptr) tensor->deleter(tensor);
+    }
+  };
+
+  tvm::runtime::Module executor_factory_;
+  /*! \brief Graph executor created from executor_factory_ on the first forward() call. */
+  tvm::runtime::Module executor_;
+};
+
+// Store a module in the calling thread's ThreadLocalStore. The default constructors of
+// OperatorModuleWrapper and GraphExecutorFactoryWrapper read it from there.
+TVM_REGISTER_GLOBAL("tvmtorch.save_runtime_mod")
+    .set_body_typed([](tvm::runtime::Module runtime_mod) {
+      ThreadLocalStore* const store = ThreadLocalStore::ThreadLocal();
+      store->mod = runtime_mod;
+    });
+
+// Expose both wrappers to TorchScript under the `tvm_torch` namespace. Each exposes a
+// no-argument constructor, `forward`, and pickling through Serialize() and the
+// state-taking constructor.
+TORCH_LIBRARY(tvm_torch, m) {
+  m.class_<OperatorModuleWrapper>("OperatorModuleWrapper")
+      .def(torch::init<>())
+      .def("forward", &OperatorModuleWrapper::forward)
+      .def_pickle(
+          // __getstate__: encode the wrapped module.
+          [](const c10::intrusive_ptr<OperatorModuleWrapper>& wrapper) -> SerializationType {
+            return wrapper->Serialize();
+          },
+          // __setstate__: rebuild a wrapper from the encoded module.
+          [](SerializationType serialized) -> c10::intrusive_ptr<OperatorModuleWrapper> {
+            return c10::make_intrusive<OperatorModuleWrapper>(serialized);
+          });
+  m.class_<GraphExecutorFactoryWrapper>("GraphExecutorFactoryWrapper")
+      .def(torch::init<>())
+      .def("forward", &GraphExecutorFactoryWrapper::forward)
+      .def_pickle(
+          // __getstate__: encode the wrapped factory module.
+          [](const c10::intrusive_ptr<GraphExecutorFactoryWrapper>& wrapper)
+              -> SerializationType { return wrapper->Serialize(); },
+          // __setstate__: rebuild a wrapper from the encoded factory module.
+          [](SerializationType serialized) -> c10::intrusive_ptr<GraphExecutorFactoryWrapper> {
+            return c10::make_intrusive<GraphExecutorFactoryWrapper>(serialized);
+          });
+}
+
+}  // namespace contrib
+}  // namespace tvm