You are viewing a plain-text version of this content; the canonical link to it is not included in this copy.
Posted to commits@tvm.apache.org by ju...@apache.org on 2023/09/12 21:04:28 UTC
[tvm] branch unity updated: [Unity][Frontend][NN] Enable tuple/list input (#15670)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new cdf40be5ca [Unity][Frontend][NN] Enable tuple/list input (#15670)
cdf40be5ca is described below
commit cdf40be5ca7f4090b80b3ddfd09805d55e48a3fb
Author: Lesheng Jin <34...@users.noreply.github.com>
AuthorDate: Tue Sep 12 14:04:22 2023 -0700
[Unity][Frontend][NN] Enable tuple/list input (#15670)
- Enable tuple/list input for models, with nesting support.
- Correct the order of inputs for JIT execution: effects are now passed before parameters.
---
python/tvm/relax/frontend/nn/spec.py | 168 +++++++++++++++-----
python/tvm/relax/frontend/nn/torch.py | 20 ++-
tests/python/relax/test_frontend_nn_jit.py | 208 +++++++++++++++++++++++++
tests/python/relax/test_frontend_nn_modules.py | 80 ++++++++++
4 files changed, 429 insertions(+), 47 deletions(-)
diff --git a/python/tvm/relax/frontend/nn/spec.py b/python/tvm/relax/frontend/nn/spec.py
index aeecfee782..6c8b42e46c 100644
--- a/python/tvm/relax/frontend/nn/spec.py
+++ b/python/tvm/relax/frontend/nn/spec.py
@@ -18,7 +18,7 @@
from collections import defaultdict
import inspect
import threading
-from typing import Any, Callable, Dict, List, Sequence, Tuple, Union, Optional
+import typing
from tvm import tir
from tvm.ir import IRModule
@@ -26,12 +26,13 @@ from tvm.runtime import load_static_library
from ... import expr as rx
from ...block_builder import BlockBuilder
-from ...struct_info import ShapeStructInfo
+from ...struct_info import ShapeStructInfo, TupleStructInfo
from . import core
-ArgSpecType = Union["Int", "Tensor"]
-MethodSpecType = Union["MethodSpec", Dict[str, ArgSpecType]]
-ModuleSpecType = Union["ModuleSpec", Dict[str, MethodSpecType]]
+ArgSpecType = typing.Union["Int", "Tensor"]
+MethodSpecType = typing.Union["MethodSpec", typing.Dict[str, ArgSpecType]]
+ModuleSpecType = typing.Union["ModuleSpec", typing.Dict[str, MethodSpecType]]
+SpecAny = typing.Union["Int", "Tensor", "Tuple"]
class Int: # pylint: disable=too-few-public-methods
@@ -47,10 +48,10 @@ class Int: # pylint: disable=too-few-public-methods
class Tensor: # pylint: disable=too-few-public-methods
"""A tensor input with static ndim and dtype, but can have symbolic shapes."""
- shape: List[Union[int, str]]
+ shape: typing.List[typing.Union[int, str]]
dtype: str
- def __init__(self, shape: Sequence[Union[int, str]], dtype: str) -> None:
+ def __init__(self, shape: typing.Sequence[typing.Union[int, str]], dtype: str) -> None:
self.shape = list(shape)
self.dtype = dtype
@@ -59,14 +60,38 @@ class Tensor: # pylint: disable=too-few-public-methods
return f"Tensor([{shape}], '{self.dtype}')"
+class Tuple:
+ """A tuple input or a list input"""
+
+ name: str
+ elements: typing.Union[typing.List[SpecAny], typing.Tuple[SpecAny, ...]]
+
+ def __init__(
+ self,
+ name: str,
+ elements: typing.Union[typing.List[SpecAny], typing.Tuple[SpecAny, ...]],
+ ) -> None:
+ assert isinstance(elements, (tuple, list)), f"Unsupported container type: {type(elements)}"
+ self.name = name
+ self.elements = elements
+
+ def __repr__(self) -> str:
+ return self.elements.__repr__()
+
+
class MethodSpec:
"""A spec for a compiled method"""
- method: Callable
- arg_names: List[str]
- arg_specs: List[ArgSpecType]
+ method: typing.Callable
+ arg_names: typing.List[str]
+ arg_specs: typing.List[ArgSpecType]
- def __init__(self, method: Callable, arg_names: List[str], arg_specs: List[ArgSpecType]):
+ def __init__(
+ self,
+ method: typing.Callable,
+ arg_names: typing.List[str],
+ arg_specs: typing.List[ArgSpecType],
+ ):
self.method = method
self.arg_names = arg_names
self.arg_specs = arg_specs
@@ -85,7 +110,7 @@ class MethodSpec:
return self._repr(name="MethodSpec")
@staticmethod
- def from_raw(spec: MethodSpecType, method: Callable) -> "MethodSpec":
+ def from_raw(spec: MethodSpecType, method: typing.Callable) -> "MethodSpec":
"""Create MethodSpec from raw python dictionaries.
Examples
@@ -105,22 +130,37 @@ class MethodSpec:
method_signature = inspect.signature(method)
arg_names = list(method_signature.parameters.keys())
arg_specs = []
+
+ def _convert_arg_spec(arg_spec, arg_name):
+ if arg_spec is Int or arg_spec is int:
+ return Int()
+ elif isinstance(arg_spec, str) and arg_spec == "int":
+ return Int()
+ elif isinstance(arg_spec, (Int, Tensor)):
+ return arg_spec
+ elif isinstance(arg_spec, (tuple, list, Tuple)):
+ return Tuple(
+ arg_name,
+ elements=type(arg_spec)(
+ [
+ _convert_arg_spec(arg_spec_i, f"{arg_name}_{i}")
+ for i, arg_spec_i in enumerate(arg_spec)
+ ]
+ ),
+ )
+
+ else:
+ raise TypeError(f"Invalid spec for argument {arg_name}: {arg_spec}")
+
for arg_name in arg_names:
if arg_name in spec:
arg_spec = spec[arg_name]
- if arg_spec is Int or arg_spec is int:
- arg_spec = Int()
- elif isinstance(arg_spec, str) and arg_spec == "int":
- arg_spec = Int()
- elif isinstance(arg_spec, (Int, Tensor)):
- pass
- else:
- raise TypeError(f"Invalid spec for argument {arg_name}: {arg_spec}")
+ arg_spec = _convert_arg_spec(arg_spec, arg_name)
arg_specs.append(arg_spec)
return MethodSpec(method, arg_names, arg_specs)
@staticmethod
- def from_torch(args: List[Any], method: Callable) -> "MethodSpec":
+ def from_torch(args: typing.List[typing.Any], method: typing.Callable) -> "MethodSpec":
"""Converts a list of torch tensors to MethodSpec."""
from .torch import ( # pylint: disable=import-outside-toplevel
_method_spec_from_torch,
@@ -128,9 +168,9 @@ class MethodSpec:
return _method_spec_from_torch(args, method)
- def as_inputs(self) -> List[Union[tir.Var, core.Tensor]]:
+ def as_inputs(self) -> typing.List[typing.Union[tir.Var, core.Tensor]]:
"""Convert the MethodSpec to a list of inputs to Module's method."""
- str2var: Dict[str, tir.Var] = {}
+ str2var: typing.Dict[str, tir.Var] = {}
def _get_var(name: str) -> tir.Var:
if name in str2var:
@@ -139,8 +179,7 @@ class MethodSpec:
str2var[name] = var
return var
- args = []
- for arg_name, arg_spec in zip(self.arg_names, self.arg_specs):
+ def _convert_input(arg_name, arg_spec):
if isinstance(arg_spec, Int):
arg = _get_var(arg_name)
elif isinstance(arg_spec, Tensor):
@@ -149,8 +188,26 @@ class MethodSpec:
shape=[_get_var(x) if isinstance(x, str) else x for x in arg_spec.shape],
dtype=arg_spec.dtype,
)
+ elif isinstance(arg_spec, Tuple):
+ elements = type(arg_spec.elements)(
+ [
+ _convert_input(
+ arg_name=arg_name + f"_tmp{i}", arg_spec=arg_spec.elements[i]
+ )
+ for i in range(len(arg_spec.elements))
+ ]
+ )
+ arg = Tuple(
+ name=arg_name,
+ elements=elements,
+ )
else:
raise TypeError(f"Invalid spec for argument {arg_name}: {arg_spec}")
+ return arg
+
+ args = []
+ for arg_name, arg_spec in zip(self.arg_names, self.arg_specs):
+ arg = _convert_input(arg_name=arg_name, arg_spec=arg_spec)
args.append(arg)
return args
@@ -159,14 +216,14 @@ class ModuleSpec:
"""A spec for a compiled nn.Module"""
module: core.Module
- method_names: List[str]
- method_specs: List[MethodSpecType]
+ method_names: typing.List[str]
+ method_specs: typing.List[MethodSpecType]
def __init__(
self,
module: core.Module,
- method_names: List[str],
- method_specs: List[MethodSpecType],
+ method_names: typing.List[str],
+ method_specs: typing.List[MethodSpecType],
) -> None:
self.module = module
self.method_names = method_names
@@ -226,10 +283,12 @@ class ExternFunctionSpec:
"""A spec for a compiled external function."""
symbol: str
- args: List[Tensor]
- ret: Union[Tensor, List[Tensor]]
+ args: typing.List[Tensor]
+ ret: typing.Union[Tensor, typing.List[Tensor]]
- def __init__(self, symbol: str, args: List[Tensor], ret: Union[Tensor, List[Tensor]]) -> None:
+ def __init__(
+ self, symbol: str, args: typing.List[Tensor], ret: typing.Union[Tensor, typing.List[Tensor]]
+ ) -> None:
self.symbol = symbol
self.args = args
self.ret = ret
@@ -247,9 +306,9 @@ class ExternModuleSpec:
"""A spec for a compiled external Module."""
filename: str
- functions: List[ExternFunctionSpec]
+ functions: typing.List[ExternFunctionSpec]
- def __init__(self, filename: str, functions: List[ExternFunctionSpec]) -> None:
+ def __init__(self, filename: str, functions: typing.List[ExternFunctionSpec]) -> None:
self.filename = filename
self.functions = functions
@@ -290,11 +349,11 @@ class SpecBuilder:
def build(
self, spec: ModuleSpec, debug: bool = False
- ) -> Tuple[IRModule, List[Tuple[str, core.Parameter]]]:
+ ) -> typing.Tuple[IRModule, typing.List[typing.Tuple[str, core.Parameter]]]:
"""Build the ModuleSpec to TVM IRModule. Returns the IRModule and the parameters."""
# pylint: disable=protected-access
- def _params() -> List[Tuple[str, core.Parameter]]:
+ def _params() -> typing.List[typing.Tuple[str, core.Parameter]]:
params = []
for name, param in core._attribute_finder(
spec.module, prefix="", condition_yield=lambda x: isinstance(x, core.Parameter)
@@ -302,7 +361,7 @@ class SpecBuilder:
params.append((name, param))
return params
- def _effects() -> List[Tuple[str, core.Effect]]:
+ def _effects() -> typing.List[typing.Tuple[str, core.Effect]]:
result = []
if self.io_effect is not None:
result.append(("", self.io_effect))
@@ -312,7 +371,7 @@ class SpecBuilder:
result.append((name, effect))
return result
- def _extern_modules() -> List[Tuple[str, List[str]]]:
+ def _extern_modules() -> typing.List[typing.Tuple[str, typing.List[str]]]:
mod2func = defaultdict(set)
for _, extern_module in core._attribute_finder(
spec.module, "", condition_yield=lambda x: isinstance(x, core.ExternModule)
@@ -358,7 +417,7 @@ class SpecBuilder:
def _emit_effect_init(
builder: BlockBuilder,
- effects: List[Tuple[str, core.Effect]],
+ effects: typing.List[typing.Tuple[str, core.Effect]],
):
outputs = []
for prefix, effect in effects:
@@ -372,11 +431,11 @@ def _emit_effect_init(
def _emit_method(
builder: BlockBuilder,
spec: MethodSpec,
- params: List[Tuple[str, core.Parameter]],
- effects: Optional[List[Tuple[str, core.Effect]]],
+ params: typing.List[typing.Tuple[str, core.Parameter]],
+ effects: typing.Optional[typing.List[typing.Tuple[str, core.Effect]]],
):
# pylint: disable=protected-access
- def _unwrap_ret(expr: Any) -> Any:
+ def _unwrap_ret(expr: typing.Any) -> typing.Any:
if isinstance(expr, core.Tensor):
return expr._expr # pylint: disable=protected-access
if isinstance(expr, tuple):
@@ -390,6 +449,13 @@ def _emit_method(
return rx.Var(arg.name, struct_info=ShapeStructInfo(values=[arg]))
if isinstance(arg, core.Tensor):
return arg._expr # pylint: disable=protected-access
+ if isinstance(arg, Tuple):
+ return rx.Var(
+ arg.name,
+ struct_info=TupleStructInfo(
+ [_convert_input(arg_i).struct_info for arg_i in arg.elements]
+ ),
+ )
raise TypeError(f"Unsupported input type: {type(arg)}")
explicit_inputs = spec.as_inputs()
@@ -403,6 +469,24 @@ def _emit_method(
inputs.append(param._expr)
# pylint: enable=protected-access
+ def _detuple(arg, var: rx.Var, builder: BlockBuilder):
+ if isinstance(arg, Tuple):
+ ret = []
+ for i in range(len(arg.elements)):
+ field = builder.emit(rx.TupleGetItem(var, i), name_hint=f"{arg.name}_{i}")
+ ret.append(_detuple(arg.elements[i], field, builder))
+ return type(arg.elements)(ret)
+ elif isinstance(arg, core.Tensor):
+ return core.Tensor(_expr=var)
+ elif isinstance(arg, tir.Var):
+ return arg
+ else:
+ raise TypeError(f"Unsupported input type: {type(arg)}")
+
+ for arg_idx, (arg, var) in enumerate(zip(explicit_inputs, inputs)):
+ if isinstance(arg, Tuple):
+ explicit_inputs[arg_idx] = _detuple(arg, var, builder)
+
outputs = spec.method(*explicit_inputs)
effect_outputs = []
for _, effect in effects:
diff --git a/python/tvm/relax/frontend/nn/torch.py b/python/tvm/relax/frontend/nn/torch.py
index 5d8c890845..aa77a11ab0 100644
--- a/python/tvm/relax/frontend/nn/torch.py
+++ b/python/tvm/relax/frontend/nn/torch.py
@@ -43,11 +43,14 @@ class TorchModule: # pylint: disable=too-few-public-methods
vm: VirtualMachine,
params: List[NDArray],
):
- effects = vm["_initialize_effect"]()
+ try:
+ self.effects = vm["_initialize_effect"]()
+ except AttributeError:
+ self.effects = None
+
self.spec = spec
self.vm = vm
self.params = params
- self.effects = effects
def __getitem__(self, method_name: str) -> Callable:
def _find_method(method_name):
@@ -62,7 +65,7 @@ class TorchModule: # pylint: disable=too-few-public-methods
def _closure(*args):
if len(args) != len(method_spec.arg_names):
raise TypeError(
- f"Argument length mismatch. Expected {len(method_spec.args)} arguments, "
+ f"Argument length mismatch. Expected {len(method_spec.arg_names)} arguments, "
f"but got {len(args)} arguments. The spec is: {method_spec}"
)
args = [
@@ -71,14 +74,16 @@ class TorchModule: # pylint: disable=too-few-public-methods
method_spec.arg_names, method_spec.arg_specs, args
)
]
- outputs, self.effects = method(*args, *self.params, *self.effects)
+ if self.effects is not None:
+ outputs, self.effects = method(*args, *self.effects, *self.params)
+ else:
+ outputs = method(*args, *self.params)
return _tvm_to_torch(outputs)
_closure.__name__ = method_name
return _closure
-@staticmethod
def _tvm_to_torch(arg):
if isinstance(arg, (list, tuple, Array)):
return [_tvm_to_torch(i) for i in arg]
@@ -103,6 +108,11 @@ def _torch_to_tvm(arg_name, arg_spec, arg_torch):
f"Expected argument `{arg_name}` to be `int`, but got {type(arg_torch)}"
)
return ShapeTuple([arg_torch])
+ if isinstance(arg_spec, _spec.Tuple):
+ return [
+ _torch_to_tvm(f"{arg_name}[{i}]", x, arg_torch[i])
+ for i, x in enumerate(arg_spec.elements)
+ ]
raise TypeError(f"Unsupported spec item type: {type(arg_spec)}")
diff --git a/tests/python/relax/test_frontend_nn_jit.py b/tests/python/relax/test_frontend_nn_jit.py
new file mode 100644
index 0000000000..4feaaf9aaa
--- /dev/null
+++ b/tests/python/relax/test_frontend_nn_jit.py
@@ -0,0 +1,208 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+from typing import Tuple, List
+import torch
+
+import tvm
+import tvm.testing
+from tvm import tir
+from tvm.relax.frontend.nn import spec
+from tvm.relax.frontend import nn
+
+
+@pytest.mark.parametrize("debug", [True, False])
+def test_jit(debug):
+ class Layer(nn.Module):
+ def __init__(self):
+ pass
+
+ def forward(self, x: nn.Tensor):
+ y = nn.add(x, x)
+ return y
+
+ forward_spec = {"forward": {"x": spec.Tensor([10, 5], dtype="float32")}}
+ mod = Layer()
+
+ model = mod.jit(spec=forward_spec, debug=debug)
+
+ x = torch.rand((10, 5), dtype=torch.float32)
+ y = model["forward"](x)
+ assert isinstance(y, torch.Tensor)
+ assert torch.allclose(x + x, y)
+
+
+@pytest.mark.parametrize("debug", [True, False])
+def test_jit_int_input(debug):
+ class Layer(nn.Module):
+ def __init__(self):
+ pass
+
+ def forward(self, x: nn.Tensor, i: tir.Var):
+ y = nn.add(x, x)
+ y = nn.reshape(y, (i, 5, 5))
+ return y
+
+ forward_spec = {"forward": {"x": spec.Tensor([10, 5], dtype="float32"), "i": int}}
+ mod = Layer()
+
+ model = mod.jit(spec=forward_spec, debug=debug)
+
+ x = torch.rand((10, 5), dtype=torch.float32)
+ y = model["forward"](x, 2)
+ assert isinstance(y, torch.Tensor)
+ assert torch.allclose(torch.reshape(x + x, (2, 5, 5)), y)
+
+
+@pytest.mark.parametrize("debug", [True, False])
+def test_jit_with_effect(debug):
+ class Layer(nn.Module):
+ def __init__(self):
+ self.cache = nn.KVCache(10, [10, 5])
+
+ def forward(self, x: nn.Tensor, total_seq_len: tir.Var):
+ self.cache.append(x)
+ y = self.cache.view(total_seq_len)
+ return y
+
+ forward_spec = {
+ "forward": {"x": spec.Tensor([1, 10, 5], dtype="float32"), "total_seq_len": int}
+ }
+ mod = Layer()
+
+ with tvm.transform.PassContext(opt_level=3):
+ model = mod.jit(spec=forward_spec, debug=debug)
+
+ x0 = torch.rand((1, 10, 5), dtype=torch.float32)
+ y = model["forward"](x0, 1)
+ assert isinstance(y, torch.Tensor)
+ assert torch.allclose(x0, y)
+
+ x1 = torch.rand((1, 10, 5), dtype=torch.float32)
+ y = model["forward"](x1, 2)
+ assert torch.allclose(torch.concat([x0, x1], dim=0), y)
+
+ x2 = torch.rand((1, 10, 5), dtype=torch.float32)
+ y = model["forward"](x2, 3)
+ assert torch.allclose(torch.concat([x0, x1, x2], dim=0), y)
+
+
+@pytest.mark.parametrize("debug", [True, False])
+def test_jit_tuple_input(debug):
+ class Layer(nn.Module):
+ def __init__(self):
+ pass
+
+ def forward(self, x: Tuple[nn.Tensor, nn.Tensor]):
+ assert isinstance(x, tuple)
+ x0 = x[0]
+ x1 = x[1]
+ y0 = nn.add(x0, x1)
+ y1 = nn.subtract(x0, x1)
+ return (y0, y1)
+
+ forward_spec = {
+ "forward": {
+ "x": (
+ spec.Tensor([10, 5], dtype="float32"),
+ spec.Tensor([10, 5], dtype="float32"),
+ )
+ }
+ }
+ mod = Layer()
+
+ model = mod.jit(spec=forward_spec, debug=debug)
+
+ x0 = torch.rand((10, 5), dtype=torch.float32)
+ x1 = torch.rand((10, 5), dtype=torch.float32)
+ x = (x0, x1)
+ y = model["forward"](x)
+
+ assert torch.allclose(x0 + x1, y[0])
+ assert torch.allclose(x0 - x1, y[1])
+
+
+@pytest.mark.parametrize("debug", [True, False])
+def test_jit_list_input(debug):
+ class Layer(nn.Module):
+ def __init__(self):
+ pass
+
+ def forward(self, x: List[nn.Tensor]):
+ assert isinstance(x, list)
+ x0 = x[0]
+ x1 = x[1]
+ y0 = nn.add(x0, x1)
+ y1 = nn.subtract(x0, x1)
+ return (y0, y1)
+
+ forward_spec = {
+ "forward": {
+ "x": [
+ spec.Tensor([10, 5], dtype="float32"),
+ spec.Tensor([10, 5], dtype="float32"),
+ ]
+ }
+ }
+ mod = Layer()
+
+ model = mod.jit(spec=forward_spec, debug=debug)
+
+ x0 = torch.rand((10, 5), dtype=torch.float32)
+ x1 = torch.rand((10, 5), dtype=torch.float32)
+ x = (x0, x1)
+ y = model["forward"](x)
+
+ assert torch.allclose(x0 + x1, y[0])
+ assert torch.allclose(x0 - x1, y[1])
+
+
+@pytest.mark.parametrize("debug", [True, False])
+def test_jit_tuple_input_with_int(debug):
+ class Layer(nn.Module):
+ def __init__(self):
+ pass
+
+ def forward(self, x: Tuple[nn.Tensor, nn.Tensor, int]):
+ x0 = x[0]
+ x1 = x[1]
+ y0 = nn.add(x0, x1)
+ y1 = nn.subtract(x0, x1)
+ y2 = nn.reshape(x0, (5, x[2], 5))
+ return (y0, y1, y2)
+
+ forward_spec = {
+ "forward": {
+ "x": (spec.Tensor([10, 5], dtype="float32"), spec.Tensor([10, 5], dtype="float32"), int)
+ }
+ }
+ mod = Layer()
+
+ model = mod.jit(spec=forward_spec, debug=debug)
+
+ x0 = torch.rand((10, 5), dtype=torch.float32)
+ x1 = torch.rand((10, 5), dtype=torch.float32)
+ x = (x0, x1, 2)
+ y0, y1, y2 = model["forward"](x)
+
+ assert torch.allclose(x0 + x1, y0)
+ assert torch.allclose(x0 - x1, y1)
+ assert torch.allclose(torch.reshape(x0, (5, 2, 5)), y2)
+
+
+if __name__ == "__main__":
+ tvm.testing.main()
diff --git a/tests/python/relax/test_frontend_nn_modules.py b/tests/python/relax/test_frontend_nn_modules.py
index 9d10a84f55..524472feb4 100644
--- a/tests/python/relax/test_frontend_nn_modules.py
+++ b/tests/python/relax/test_frontend_nn_modules.py
@@ -16,12 +16,14 @@
# under the License.
import numpy as np
import pytest
+from typing import Tuple, List
import tvm
import tvm.testing
from tvm import relax
from tvm.ir import assert_structural_equal
from tvm.relax.frontend.nn import core, modules, spec
+from tvm.relax.frontend import nn
from tvm.script import ir as I
from tvm.script import relax as R
@@ -529,5 +531,83 @@ def test_attention():
assert_structural_equal(tvm_mod["forward"], forward, True)
+def test_nn_module_tuple_input():
+ class Layer(nn.Module):
+ def __init__(self):
+ pass
+
+ def forward(self, x: Tuple[nn.Tensor, nn.Tensor]):
+ x0 = x[0]
+ x1 = x[1]
+ y0 = nn.add(x0, x1)
+ y1 = nn.subtract(x0, x1)
+ return (y0, y1)
+
+ # fmt: off
+ @R.function
+ def forward(x: R.Tuple(R.Tensor((10, 5), dtype="float32"), R.Tensor((10, 5), dtype="float32")), _io: R.Object) -> R.Tuple(R.Tuple(R.Tensor((10, 5), dtype="float32"), R.Tensor((10, 5), dtype="float32")), R.Tuple(R.Object)):
+ R.func_attr({"num_input": 2})
+ with R.dataflow():
+ lv1: R.Tensor((10, 5), dtype="float32") = x[0]
+ lv2: R.Tensor((10, 5), dtype="float32") = x[1]
+ add: R.Tensor((10, 5), dtype="float32") = R.add(lv1, lv2)
+ subtract: R.Tensor((10, 5), dtype="float32") = R.subtract(lv1, lv2)
+ gv1: R.Tuple(R.Tuple(R.Tensor((10, 5), dtype="float32"), R.Tensor((10, 5), dtype="float32")), R.Tuple(R.Object)) = (add, subtract), (_io,)
+ R.output(gv1)
+ return gv1
+ # fmt: on
+
+ mod = Layer()
+ tvm_mod, _ = mod.export_tvm(
+ spec={
+ "forward": {
+ "x": (spec.Tensor([10, 5], dtype="float32"), spec.Tensor([10, 5], dtype="float32"))
+ }
+ },
+ debug=True,
+ )
+
+ assert_structural_equal(tvm_mod["forward"], forward)
+
+
+def test_nn_module_list_input():
+ class Layer(nn.Module):
+ def __init__(self):
+ pass
+
+ def forward(self, x: List[nn.Tensor]):
+ x0 = x[0]
+ x1 = x[1]
+ y0 = nn.add(x0, x1)
+ y1 = nn.subtract(x0, x1)
+ return [y0, y1]
+
+ # fmt: off
+ @R.function
+ def forward(x: R.Tuple(R.Tensor((10, 5), dtype="float32"), R.Tensor((10, 5), dtype="float32")), _io: R.Object) -> R.Tuple(R.Tuple(R.Tensor((10, 5), dtype="float32"), R.Tensor((10, 5), dtype="float32")), R.Tuple(R.Object)):
+ R.func_attr({"num_input": 2})
+ with R.dataflow():
+ lv1: R.Tensor((10, 5), dtype="float32") = x[0]
+ lv2: R.Tensor((10, 5), dtype="float32") = x[1]
+ add: R.Tensor((10, 5), dtype="float32") = R.add(lv1, lv2)
+ subtract: R.Tensor((10, 5), dtype="float32") = R.subtract(lv1, lv2)
+ gv1: R.Tuple(R.Tuple(R.Tensor((10, 5), dtype="float32"), R.Tensor((10, 5), dtype="float32")), R.Tuple(R.Object)) = (add, subtract), (_io,)
+ R.output(gv1)
+ return gv1
+ # fmt: on
+
+ mod = Layer()
+ tvm_mod, _ = mod.export_tvm(
+ spec={
+ "forward": {
+ "x": [spec.Tensor([10, 5], dtype="float32"), spec.Tensor([10, 5], dtype="float32")]
+ }
+ },
+ debug=True,
+ )
+
+ assert_structural_equal(tvm_mod["forward"], forward)
+
+
if __name__ == "__main__":
tvm.testing.main()