Posted to commits@tvm.apache.org by ya...@apache.org on 2023/07/28 21:13:37 UTC

[tvm] branch unity updated: [Unity] nn.Module Torch Integration (#15424)

This is an automated email from the ASF dual-hosted git repository.

yaxingcai pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 7a4c85e5fb [Unity] nn.Module Torch Integration (#15424)
7a4c85e5fb is described below

commit 7a4c85e5fbedb0b4066b2c17465742e4e5afe378
Author: Junru Shao <ju...@apache.org>
AuthorDate: Fri Jul 28 14:13:32 2023 -0700

    [Unity] nn.Module Torch Integration (#15424)
    
    This PR introduces an `nn.Module.jit` method that converts an
    `nn.Module` into a collection of torch-in, torch-out callables that
    users can debug with.
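
    A minimal usage sketch (`model` stands for any `nn.Module` instance
    and `forward_spec` for a module spec matching its `forward` method;
    both are placeholders, not code from this PR):

        import torch

        # `model` and `forward_spec` are illustrative placeholders.
        # With out_format="torch", jit() compiles the module through the
        # Relax pipeline, loads it into a VirtualMachine, and wraps it in
        # a TorchModule whose methods take and return torch tensors.
        tm = model.jit(spec=forward_spec, target="llvm", out_format="torch")
        y = tm["forward"](torch.rand(1, 16))  # torch in, torch out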
---
 python/tvm/relax/frontend/nn/core.py  |  14 ++--
 python/tvm/relax/frontend/nn/spec.py  |   8 ++-
 python/tvm/relax/frontend/nn/torch.py | 124 ++++++++++++++++++++++++++++++++++
 3 files changed, 140 insertions(+), 6 deletions(-)

diff --git a/python/tvm/relax/frontend/nn/core.py b/python/tvm/relax/frontend/nn/core.py
index 1ac4a6fbdb..46197b7141 100644
--- a/python/tvm/relax/frontend/nn/core.py
+++ b/python/tvm/relax/frontend/nn/core.py
@@ -345,7 +345,10 @@ class Module:
             if hasattr(item, "to") and callable(item.to):
                 item.to(dtype=dtype)
 
-    def export_tvm(self, spec: "_spec.Module") -> Tuple[IRModule, List[Tuple[str, Parameter]]]:
+    def export_tvm(
+        self,
+        spec: "_spec.ModuleSpecType",
+    ) -> Tuple[IRModule, List[Tuple[str, Parameter]]]:
         """Export the module to TVM IRModule and parameters"""
         from . import spec as _spec  # pylint: disable=import-outside-toplevel
 
@@ -353,7 +356,7 @@ class Module:
         mod, params = _spec.SpecBuilder().build(spec)
         return mod, params
 
-    def jit(
+    def jit(  # pylint: disable=too-many-arguments
         self,
         spec: "_spec.Module",
         target: Union[str, Target] = "llvm",
@@ -377,10 +380,13 @@ class Module:
         # Compile mod and feed it to VM
         mod = relax.pipeline.get_pipeline(pipeline)(mod)  # pylint: disable=no-value-for-parameter
         mod = relax.build(mod, target=target)
-        VirtualMachine(mod, device)
+        vm = VirtualMachine(mod, device)  # pylint: disable=invalid-name
 
         if out_format == "torch":
-            raise NotImplementedError
+            from . import torch  # pylint: disable=import-outside-toplevel
+
+            return torch.TorchModule(spec=spec, params=params, vm=vm)
+
         raise ValueError(f"Unknown out_format: {out_format}")
 
 
diff --git a/python/tvm/relax/frontend/nn/spec.py b/python/tvm/relax/frontend/nn/spec.py
index 73d7c80638..a279616f31 100644
--- a/python/tvm/relax/frontend/nn/spec.py
+++ b/python/tvm/relax/frontend/nn/spec.py
@@ -117,9 +117,13 @@ class MethodSpec:
         return MethodSpec(method, arg_names, arg_specs)
 
     @staticmethod
-    def from_torch(torch_args: List[Any], method: Callable) -> "MethodSpec":
+    def from_torch(args: List[Any], method: Callable) -> "MethodSpec":
         """Converts a list of torch tensors to MethodSpec."""
-        raise NotImplementedError
+        from .torch import (  # pylint: disable=import-outside-toplevel
+            _method_spec_from_torch,
+        )
+
+        return _method_spec_from_torch(args, method)
 
     def as_inputs(self) -> List[Union[tir.Var, core.Tensor]]:
         """Convert the MethodSpec to a list of inputs to Module's method."""
diff --git a/python/tvm/relax/frontend/nn/torch.py b/python/tvm/relax/frontend/nn/torch.py
new file mode 100644
index 0000000000..5d8c890845
--- /dev/null
+++ b/python/tvm/relax/frontend/nn/torch.py
@@ -0,0 +1,124 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""PyTorch integration with nn.Module"""
+import inspect
+from typing import Any, Callable, List
+
+import torch
+
+from tvm.ir import Array
+from tvm.runtime import NDArray, ShapeTuple, ndarray
+from tvm.runtime.relax_vm import VirtualMachine
+
+from . import core
+from . import spec as _spec
+
+
+class TorchModule:  # pylint: disable=too-few-public-methods
+    """A wrapper on top of TVM VirtualMachine that takes torch tensors as inputs and returns torch
+    tensors as outputs"""
+
+    spec: _spec.ModuleSpec
+    vm: VirtualMachine  # pylint: disable=invalid-name
+    params: List[NDArray]
+    effects: List[Any]
+
+    def __init__(  # pylint: disable=invalid-name
+        self,
+        spec: _spec.ModuleSpec,
+        vm: VirtualMachine,
+        params: List[NDArray],
+    ):
+        effects = vm["_initialize_effect"]()
+        self.spec = spec
+        self.vm = vm
+        self.params = params
+        self.effects = effects
+
+    def __getitem__(self, method_name: str) -> Callable:
+        def _find_method(method_name):
+            for key, value in zip(self.spec.method_names, self.spec.method_specs):
+                if method_name == key:
+                    return value
+            raise ValueError(f"Method `{method_name}` is not found in the module spec. {self.spec}")
+
+        method_spec = _find_method(method_name)
+        method = self.vm[method_name]
+
+        def _closure(*args):
+            if len(args) != len(method_spec.arg_names):
+                raise TypeError(
+                    f"Argument length mismatch. Expected {len(method_spec.args)} arguments, "
+                    f"but got {len(args)} arguments. The spec is: {method_spec}"
+                )
+            args = [
+                _torch_to_tvm(arg_name, arg_spec, arg)
+                for arg_name, arg_spec, arg in zip(
+                    method_spec.arg_names, method_spec.arg_specs, args
+                )
+            ]
+            outputs, self.effects = method(*args, *self.params, *self.effects)
+            return _tvm_to_torch(outputs)
+
+        _closure.__name__ = method_name
+        return _closure
+
+
+def _tvm_to_torch(arg):
+    if isinstance(arg, (list, tuple, Array)):
+        return [_tvm_to_torch(i) for i in arg]
+    if isinstance(arg, ndarray.NDArray):
+        return torch.utils.dlpack.from_dlpack(arg)
+    if isinstance(arg, ShapeTuple):
+        return list(arg)
+    raise TypeError(f"Unsupported argument type: {type(arg)}")
+
+
+def _torch_to_tvm(arg_name, arg_spec, arg_torch):
+    if isinstance(arg_spec, _spec.Tensor):
+        if not isinstance(arg_torch, torch.Tensor):
+            raise TypeError(
+                f"Expected argument `{arg_name}` to be `torch.Tensor`, "
+                f"but got {type(arg_torch)}"
+            )
+        return core._from_dlpack(arg_torch)  # pylint: disable=protected-access
+    if isinstance(arg_spec, _spec.Int):
+        if not isinstance(arg_torch, int):
+            raise TypeError(
+                f"Expected argument `{arg_name}` to be `int`, but got {type(arg_torch)}"
+            )
+        return ShapeTuple([arg_torch])
+    raise TypeError(f"Unsupported spec item type: {type(arg_spec)}")
+
+
+def _method_spec_from_torch(
+    args_torch: List[Any],
+    method: Callable,
+):
+    def _as_spec(arg_torch):
+        if isinstance(arg_torch, torch.Tensor):
+            _, dtype = str(arg_torch.dtype).rsplit(".", maxsplit=1)
+            return _spec.Tensor(shape=list(arg_torch.shape), dtype=dtype)
+        if isinstance(arg_torch, int):
+            return _spec.Int()
+        raise TypeError(f"Unsupported argument type: {type(arg_torch)}")
+
+    arg_names = list(inspect.signature(method).parameters.keys())
+    if len(arg_names) != len(args_torch):
+        raise TypeError(f"Expected {len(arg_names)} arguments, but got {len(args_torch)} arguments")
+    arg_specs = [_as_spec(i) for i in args_torch]
+    return _spec.MethodSpec(method, arg_names, arg_specs)