You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ya...@apache.org on 2023/07/28 21:13:37 UTC
[tvm] branch unity updated: [Unity] nn.Module Torch Integration (#15424)
This is an automated email from the ASF dual-hosted git repository.
yaxingcai pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 7a4c85e5fb [Unity] nn.Module Torch Integration (#15424)
7a4c85e5fb is described below
commit 7a4c85e5fbedb0b4066b2c17465742e4e5afe378
Author: Junru Shao <ju...@apache.org>
AuthorDate: Fri Jul 28 14:13:32 2023 -0700
[Unity] nn.Module Torch Integration (#15424)
This PR implements the `out_format="torch"` path of `nn.Module.jit`, which
converts an `nn.Module` into a collection of torch-in, torch-out callables
that users can debug with.
---
python/tvm/relax/frontend/nn/core.py | 14 ++--
python/tvm/relax/frontend/nn/spec.py | 8 ++-
python/tvm/relax/frontend/nn/torch.py | 125 ++++++++++++++++++++++++++++++++++
3 files changed, 141 insertions(+), 6 deletions(-)
diff --git a/python/tvm/relax/frontend/nn/core.py b/python/tvm/relax/frontend/nn/core.py
index 1ac4a6fbdb..46197b7141 100644
--- a/python/tvm/relax/frontend/nn/core.py
+++ b/python/tvm/relax/frontend/nn/core.py
@@ -345,7 +345,10 @@ class Module:
if hasattr(item, "to") and callable(item.to):
item.to(dtype=dtype)
- def export_tvm(self, spec: "_spec.Module") -> Tuple[IRModule, List[Tuple[str, Parameter]]]:
+ def export_tvm(
+ self,
+ spec: "_spec.ModuleSpecType",
+ ) -> Tuple[IRModule, List[Tuple[str, Parameter]]]:
"""Export the module to TVM IRModule and parameters"""
from . import spec as _spec # pylint: disable=import-outside-toplevel
@@ -353,7 +356,7 @@ class Module:
mod, params = _spec.SpecBuilder().build(spec)
return mod, params
- def jit(
+ def jit( # pylint: disable=too-many-arguments
self,
spec: "_spec.Module",
target: Union[str, Target] = "llvm",
@@ -377,10 +380,13 @@ class Module:
# Compile mod and feed it to VM
mod = relax.pipeline.get_pipeline(pipeline)(mod) # pylint: disable=no-value-for-parameter
mod = relax.build(mod, target=target)
- VirtualMachine(mod, device)
+ vm = VirtualMachine(mod, device) # pylint: disable=invalid-name
if out_format == "torch":
- raise NotImplementedError
+ from . import torch # pylint: disable=import-outside-toplevel
+
+ return torch.TorchModule(spec=spec, params=params, vm=vm)
+
raise ValueError(f"Unknown out_format: {out_format}")
diff --git a/python/tvm/relax/frontend/nn/spec.py b/python/tvm/relax/frontend/nn/spec.py
index 73d7c80638..a279616f31 100644
--- a/python/tvm/relax/frontend/nn/spec.py
+++ b/python/tvm/relax/frontend/nn/spec.py
@@ -117,9 +117,13 @@ class MethodSpec:
return MethodSpec(method, arg_names, arg_specs)
@staticmethod
- def from_torch(torch_args: List[Any], method: Callable) -> "MethodSpec":
+ def from_torch(args: List[Any], method: Callable) -> "MethodSpec":
"""Converts a list of torch tensors to MethodSpec."""
- raise NotImplementedError
+ from .torch import ( # pylint: disable=import-outside-toplevel
+ _method_spec_from_torch,
+ )
+
+ return _method_spec_from_torch(args, method)
def as_inputs(self) -> List[Union[tir.Var, core.Tensor]]:
"""Convert the MethodSpec to a list of inputs to Module's method."""
diff --git a/python/tvm/relax/frontend/nn/torch.py b/python/tvm/relax/frontend/nn/torch.py
new file mode 100644
index 0000000000..5d8c890845
--- /dev/null
+++ b/python/tvm/relax/frontend/nn/torch.py
@@ -0,0 +1,125 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""PyTorch integration with nn.Module"""
+import inspect
+from typing import Any, Callable, List
+
+import torch
+
+from tvm.ir import Array
+from tvm.runtime import NDArray, ShapeTuple, ndarray
+from tvm.runtime.relax_vm import VirtualMachine
+
+from . import core
+from . import spec as _spec
+
+
+class TorchModule:  # pylint: disable=too-few-public-methods
+    """A wrapper on top of TVM VirtualMachine that takes torch tensors as inputs and returns torch
+    tensors as outputs. Call a compiled method via ``module[method_name](*args)``."""
+
+    spec: _spec.ModuleSpec  # method names/specs used to validate and convert call arguments
+    vm: VirtualMachine  # pylint: disable=invalid-name
+    params: List[NDArray]  # appended after the user arguments on every call
+    effects: List[Any]  # replaced by the effects the VM returns after each call
+
+    def __init__(  # pylint: disable=invalid-name
+        self,
+        spec: _spec.ModuleSpec,
+        vm: VirtualMachine,
+        params: List[NDArray],
+    ):
+        effects = vm["_initialize_effect"]()  # initial effect state, threaded through each call
+        self.spec = spec
+        self.vm = vm
+        self.params = params
+        self.effects = effects
+
+    def __getitem__(self, method_name: str) -> Callable:
+        def _find_method(method_name):
+            for key, value in zip(self.spec.method_names, self.spec.method_specs):
+                if method_name == key:
+                    return value
+            raise ValueError(f"Method `{method_name}` is not found in the module spec. {self.spec}")
+
+        method_spec = _find_method(method_name)
+        method = self.vm[method_name]
+
+        def _closure(*args):  # torch args in -> VM call with params/effects -> torch outputs
+            if len(args) != len(method_spec.arg_names):
+                raise TypeError(
+                    f"Argument length mismatch. Expected {len(method_spec.arg_names)} arguments, "
+                    f"but got {len(args)} arguments. The spec is: {method_spec}"
+                )
+            args = [
+                _torch_to_tvm(arg_name, arg_spec, arg)
+                for arg_name, arg_spec, arg in zip(
+                    method_spec.arg_names, method_spec.arg_specs, args
+                )
+            ]
+            outputs, self.effects = method(*args, *self.params, *self.effects)
+            return _tvm_to_torch(outputs)
+
+        _closure.__name__ = method_name
+        return _closure
+
+
+# Recursively convert VM outputs: NDArray -> torch.Tensor (DLPack), ShapeTuple/containers -> list
+def _tvm_to_torch(arg):
+    if isinstance(arg, (list, tuple, Array)):
+        return [_tvm_to_torch(i) for i in arg]
+    if isinstance(arg, ndarray.NDArray):
+        return torch.utils.dlpack.from_dlpack(arg)
+    if isinstance(arg, ShapeTuple):
+        return list(arg)
+    raise TypeError(f"Unsupported argument type: {type(arg)}")
+
+
+def _torch_to_tvm(arg_name, arg_spec, arg_torch):
+    """Convert one torch-side argument into the TVM value that ``arg_spec`` expects."""
+    if isinstance(arg_spec, _spec.Tensor):
+        if isinstance(arg_torch, torch.Tensor):
+            return core._from_dlpack(arg_torch)  # pylint: disable=protected-access
+        expected = "torch.Tensor"
+    elif isinstance(arg_spec, _spec.Int):
+        if isinstance(arg_torch, int):
+            return ShapeTuple([arg_torch])  # ints reach the VM as a 1-element ShapeTuple
+        expected = "int"
+    else:
+        raise TypeError(f"Unsupported spec item type: {type(arg_spec)}")
+    raise TypeError(
+        f"Expected argument `{arg_name}` to be `{expected}`, but got {type(arg_torch)}"
+    )
+
+
+def _method_spec_from_torch(args_torch: List[Any], method: Callable):
+    """Build a ``MethodSpec`` for ``method`` from concrete torch tensors / ints."""
+    names = list(inspect.signature(method).parameters)
+    if len(args_torch) != len(names):
+        raise TypeError(f"Expected {len(names)} arguments, but got {len(args_torch)} arguments")
+
+    def _as_spec(value):
+        # Tensors keep their concrete shape; dtype "torch.float32" becomes "float32".
+        if isinstance(value, torch.Tensor):
+            _, dtype = str(value.dtype).rsplit(".", maxsplit=1)
+            return _spec.Tensor(shape=list(value.shape), dtype=dtype)
+        if isinstance(value, int):
+            return _spec.Int()
+        raise TypeError(f"Unsupported argument type: {type(value)}")
+
+    specs = [_as_spec(value) for value in args_torch]
+    return _spec.MethodSpec(method, names, specs)