Posted to commits@tvm.apache.org by ju...@apache.org on 2023/09/12 21:04:28 UTC

[tvm] branch unity updated: [Unity][Frontend][NN] Enable tuple/list input (#15670)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new cdf40be5ca [Unity][Frontend][NN] Enable tuple/list input (#15670)
cdf40be5ca is described below

commit cdf40be5ca7f4090b80b3ddfd09805d55e48a3fb
Author: Lesheng Jin <34...@users.noreply.github.com>
AuthorDate: Tue Sep 12 14:04:22 2023 -0700

    [Unity][Frontend][NN] Enable tuple/list input (#15670)
    
    - Enable tuple/list input for models, with nesting support.
    - Correct the order of inputs for JIT execution.
---
 python/tvm/relax/frontend/nn/spec.py           | 168 +++++++++++++++-----
 python/tvm/relax/frontend/nn/torch.py          |  20 ++-
 tests/python/relax/test_frontend_nn_jit.py     | 208 +++++++++++++++++++++++++
 tests/python/relax/test_frontend_nn_modules.py |  80 ++++++++++
 4 files changed, 429 insertions(+), 47 deletions(-)
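
For context, here is a minimal usage sketch distilled from the tests added below: a tuple of spec.Tensor entries in the method spec declares a tuple input, and the JIT-ed method is then called with a matching Python tuple of torch tensors.

    import torch
    from tvm.relax.frontend import nn
    from tvm.relax.frontend.nn import spec

    class Layer(nn.Module):
        def __init__(self):
            pass

        def forward(self, x):
            # x arrives as a plain Python tuple of nn.Tensor
            return (nn.add(x[0], x[1]), nn.subtract(x[0], x[1]))

    forward_spec = {
        "forward": {
            "x": (
                spec.Tensor([10, 5], dtype="float32"),
                spec.Tensor([10, 5], dtype="float32"),
            )
        }
    }
    model = Layer().jit(spec=forward_spec, debug=True)
    y0, y1 = model["forward"]((torch.rand(10, 5), torch.rand(10, 5)))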

diff --git a/python/tvm/relax/frontend/nn/spec.py b/python/tvm/relax/frontend/nn/spec.py
index aeecfee782..6c8b42e46c 100644
--- a/python/tvm/relax/frontend/nn/spec.py
+++ b/python/tvm/relax/frontend/nn/spec.py
@@ -18,7 +18,7 @@
 from collections import defaultdict
 import inspect
 import threading
-from typing import Any, Callable, Dict, List, Sequence, Tuple, Union, Optional
+import typing
 
 from tvm import tir
 from tvm.ir import IRModule
@@ -26,12 +26,13 @@ from tvm.runtime import load_static_library
 
 from ... import expr as rx
 from ...block_builder import BlockBuilder
-from ...struct_info import ShapeStructInfo
+from ...struct_info import ShapeStructInfo, TupleStructInfo
 from . import core
 
-ArgSpecType = Union["Int", "Tensor"]
-MethodSpecType = Union["MethodSpec", Dict[str, ArgSpecType]]
-ModuleSpecType = Union["ModuleSpec", Dict[str, MethodSpecType]]
+ArgSpecType = typing.Union["Int", "Tensor"]
+MethodSpecType = typing.Union["MethodSpec", typing.Dict[str, ArgSpecType]]
+ModuleSpecType = typing.Union["ModuleSpec", typing.Dict[str, MethodSpecType]]
+SpecAny = typing.Union["Int", "Tensor", "Tuple"]
 
 
 class Int:  # pylint: disable=too-few-public-methods
@@ -47,10 +48,10 @@ class Int:  # pylint: disable=too-few-public-methods
 class Tensor:  # pylint: disable=too-few-public-methods
     """A tensor input with static ndim and dtype, but can have symbolic shapes."""
 
-    shape: List[Union[int, str]]
+    shape: typing.List[typing.Union[int, str]]
     dtype: str
 
-    def __init__(self, shape: Sequence[Union[int, str]], dtype: str) -> None:
+    def __init__(self, shape: typing.Sequence[typing.Union[int, str]], dtype: str) -> None:
         self.shape = list(shape)
         self.dtype = dtype
 
@@ -59,14 +60,38 @@ class Tensor:  # pylint: disable=too-few-public-methods
         return f"Tensor([{shape}], '{self.dtype}')"
 
 
+class Tuple:
+    """A tuple input or a list input"""
+
+    name: str
+    elements: typing.Union[typing.List[SpecAny], typing.Tuple[SpecAny, ...]]
+
+    def __init__(
+        self,
+        name: str,
+        elements: typing.Union[typing.List[SpecAny], typing.Tuple[SpecAny, ...]],
+    ) -> None:
+        assert isinstance(elements, (tuple, list)), f"Unsupported container type: {type(elements)}"
+        self.name = name
+        self.elements = elements
+
+    def __repr__(self) -> str:
+        return self.elements.__repr__()
+
+
 class MethodSpec:
     """A spec for a compiled method"""
 
-    method: Callable
-    arg_names: List[str]
-    arg_specs: List[ArgSpecType]
+    method: typing.Callable
+    arg_names: typing.List[str]
+    arg_specs: typing.List[ArgSpecType]
 
-    def __init__(self, method: Callable, arg_names: List[str], arg_specs: List[ArgSpecType]):
+    def __init__(
+        self,
+        method: typing.Callable,
+        arg_names: typing.List[str],
+        arg_specs: typing.List[ArgSpecType],
+    ):
         self.method = method
         self.arg_names = arg_names
         self.arg_specs = arg_specs
@@ -85,7 +110,7 @@ class MethodSpec:
         return self._repr(name="MethodSpec")
 
     @staticmethod
-    def from_raw(spec: MethodSpecType, method: Callable) -> "MethodSpec":
+    def from_raw(spec: MethodSpecType, method: typing.Callable) -> "MethodSpec":
         """Create MethodSpec from raw python dictionaries.
 
         Examples
@@ -105,22 +130,37 @@ class MethodSpec:
         method_signature = inspect.signature(method)
         arg_names = list(method_signature.parameters.keys())
         arg_specs = []
+
+        def _convert_arg_spec(arg_spec, arg_name):
+            if arg_spec is Int or arg_spec is int:
+                return Int()
+            elif isinstance(arg_spec, str) and arg_spec == "int":
+                return Int()
+            elif isinstance(arg_spec, (Int, Tensor)):
+                return arg_spec
+            elif isinstance(arg_spec, (tuple, list, Tuple)):
+                return Tuple(
+                    arg_name,
+                    elements=type(arg_spec)(
+                        [
+                            _convert_arg_spec(arg_spec_i, f"{arg_name}_{i}")
+                            for i, arg_spec_i in enumerate(arg_spec)
+                        ]
+                    ),
+                )
+
+            else:
+                raise TypeError(f"Invalid spec for argument {arg_name}: {arg_spec}")
+
         for arg_name in arg_names:
             if arg_name in spec:
                 arg_spec = spec[arg_name]
-                if arg_spec is Int or arg_spec is int:
-                    arg_spec = Int()
-                elif isinstance(arg_spec, str) and arg_spec == "int":
-                    arg_spec = Int()
-                elif isinstance(arg_spec, (Int, Tensor)):
-                    pass
-                else:
-                    raise TypeError(f"Invalid spec for argument {arg_name}: {arg_spec}")
+                arg_spec = _convert_arg_spec(arg_spec, arg_name)
                 arg_specs.append(arg_spec)
         return MethodSpec(method, arg_names, arg_specs)
 
     @staticmethod
-    def from_torch(args: List[Any], method: Callable) -> "MethodSpec":
+    def from_torch(args: typing.List[typing.Any], method: typing.Callable) -> "MethodSpec":
         """Converts a list of torch tensors to MethodSpec."""
         from .torch import (  # pylint: disable=import-outside-toplevel
             _method_spec_from_torch,
@@ -128,9 +168,9 @@ class MethodSpec:
 
         return _method_spec_from_torch(args, method)
 
-    def as_inputs(self) -> List[Union[tir.Var, core.Tensor]]:
+    def as_inputs(self) -> typing.List[typing.Union[tir.Var, core.Tensor]]:
         """Convert the MethodSpec to a list of inputs to Module's method."""
-        str2var: Dict[str, tir.Var] = {}
+        str2var: typing.Dict[str, tir.Var] = {}
 
         def _get_var(name: str) -> tir.Var:
             if name in str2var:
@@ -139,8 +179,7 @@ class MethodSpec:
             str2var[name] = var
             return var
 
-        args = []
-        for arg_name, arg_spec in zip(self.arg_names, self.arg_specs):
+        def _convert_input(arg_name, arg_spec):
             if isinstance(arg_spec, Int):
                 arg = _get_var(arg_name)
             elif isinstance(arg_spec, Tensor):
@@ -149,8 +188,26 @@ class MethodSpec:
                     shape=[_get_var(x) if isinstance(x, str) else x for x in arg_spec.shape],
                     dtype=arg_spec.dtype,
                 )
+            elif isinstance(arg_spec, Tuple):
+                elements = type(arg_spec.elements)(
+                    [
+                        _convert_input(
+                            arg_name=arg_name + f"_tmp{i}", arg_spec=arg_spec.elements[i]
+                        )
+                        for i in range(len(arg_spec.elements))
+                    ]
+                )
+                arg = Tuple(
+                    name=arg_name,
+                    elements=elements,
+                )
             else:
                 raise TypeError(f"Invalid spec for argument {arg_name}: {arg_spec}")
+            return arg
+
+        args = []
+        for arg_name, arg_spec in zip(self.arg_names, self.arg_specs):
+            arg = _convert_input(arg_name=arg_name, arg_spec=arg_spec)
             args.append(arg)
         return args
 
@@ -159,14 +216,14 @@ class ModuleSpec:
     """A spec for a compiled nn.Module"""
 
     module: core.Module
-    method_names: List[str]
-    method_specs: List[MethodSpecType]
+    method_names: typing.List[str]
+    method_specs: typing.List[MethodSpecType]
 
     def __init__(
         self,
         module: core.Module,
-        method_names: List[str],
-        method_specs: List[MethodSpecType],
+        method_names: typing.List[str],
+        method_specs: typing.List[MethodSpecType],
     ) -> None:
         self.module = module
         self.method_names = method_names
@@ -226,10 +283,12 @@ class ExternFunctionSpec:
     """A spec for a compiled external function."""
 
     symbol: str
-    args: List[Tensor]
-    ret: Union[Tensor, List[Tensor]]
+    args: typing.List[Tensor]
+    ret: typing.Union[Tensor, typing.List[Tensor]]
 
-    def __init__(self, symbol: str, args: List[Tensor], ret: Union[Tensor, List[Tensor]]) -> None:
+    def __init__(
+        self, symbol: str, args: typing.List[Tensor], ret: typing.Union[Tensor, typing.List[Tensor]]
+    ) -> None:
         self.symbol = symbol
         self.args = args
         self.ret = ret
@@ -247,9 +306,9 @@ class ExternModuleSpec:
     """A spec for a compiled external Module."""
 
     filename: str
-    functions: List[ExternFunctionSpec]
+    functions: typing.List[ExternFunctionSpec]
 
-    def __init__(self, filename: str, functions: List[ExternFunctionSpec]) -> None:
+    def __init__(self, filename: str, functions: typing.List[ExternFunctionSpec]) -> None:
         self.filename = filename
         self.functions = functions
 
@@ -290,11 +349,11 @@ class SpecBuilder:
 
     def build(
         self, spec: ModuleSpec, debug: bool = False
-    ) -> Tuple[IRModule, List[Tuple[str, core.Parameter]]]:
+    ) -> typing.Tuple[IRModule, typing.List[typing.Tuple[str, core.Parameter]]]:
         """Build the ModuleSpec to TVM IRModule. Returns the IRModule and the parameters."""
 
         # pylint: disable=protected-access
-        def _params() -> List[Tuple[str, core.Parameter]]:
+        def _params() -> typing.List[typing.Tuple[str, core.Parameter]]:
             params = []
             for name, param in core._attribute_finder(
                 spec.module, prefix="", condition_yield=lambda x: isinstance(x, core.Parameter)
@@ -302,7 +361,7 @@ class SpecBuilder:
                 params.append((name, param))
             return params
 
-        def _effects() -> List[Tuple[str, core.Effect]]:
+        def _effects() -> typing.List[typing.Tuple[str, core.Effect]]:
             result = []
             if self.io_effect is not None:
                 result.append(("", self.io_effect))
@@ -312,7 +371,7 @@ class SpecBuilder:
                 result.append((name, effect))
             return result
 
-        def _extern_modules() -> List[Tuple[str, List[str]]]:
+        def _extern_modules() -> typing.List[typing.Tuple[str, typing.List[str]]]:
             mod2func = defaultdict(set)
             for _, extern_module in core._attribute_finder(
                 spec.module, "", condition_yield=lambda x: isinstance(x, core.ExternModule)
@@ -358,7 +417,7 @@ class SpecBuilder:
 
 def _emit_effect_init(
     builder: BlockBuilder,
-    effects: List[Tuple[str, core.Effect]],
+    effects: typing.List[typing.Tuple[str, core.Effect]],
 ):
     outputs = []
     for prefix, effect in effects:
@@ -372,11 +431,11 @@ def _emit_effect_init(
 def _emit_method(
     builder: BlockBuilder,
     spec: MethodSpec,
-    params: List[Tuple[str, core.Parameter]],
-    effects: Optional[List[Tuple[str, core.Effect]]],
+    params: typing.List[typing.Tuple[str, core.Parameter]],
+    effects: typing.Optional[typing.List[typing.Tuple[str, core.Effect]]],
 ):
     # pylint: disable=protected-access
-    def _unwrap_ret(expr: Any) -> Any:
+    def _unwrap_ret(expr: typing.Any) -> typing.Any:
         if isinstance(expr, core.Tensor):
             return expr._expr  # pylint: disable=protected-access
         if isinstance(expr, tuple):
@@ -390,6 +449,13 @@ def _emit_method(
             return rx.Var(arg.name, struct_info=ShapeStructInfo(values=[arg]))
         if isinstance(arg, core.Tensor):
             return arg._expr  # pylint: disable=protected-access
+        if isinstance(arg, Tuple):
+            return rx.Var(
+                arg.name,
+                struct_info=TupleStructInfo(
+                    [_convert_input(arg_i).struct_info for arg_i in arg.elements]
+                ),
+            )
         raise TypeError(f"Unsupported input type: {type(arg)}")
 
     explicit_inputs = spec.as_inputs()
@@ -403,6 +469,24 @@ def _emit_method(
         inputs.append(param._expr)
         # pylint: enable=protected-access
 
+    def _detuple(arg, var: rx.Var, builder: BlockBuilder):
+        if isinstance(arg, Tuple):
+            ret = []
+            for i in range(len(arg.elements)):
+                field = builder.emit(rx.TupleGetItem(var, i), name_hint=f"{arg.name}_{i}")
+                ret.append(_detuple(arg.elements[i], field, builder))
+            return type(arg.elements)(ret)
+        elif isinstance(arg, core.Tensor):
+            return core.Tensor(_expr=var)
+        elif isinstance(arg, tir.Var):
+            return arg
+        else:
+            raise TypeError(f"Unsupported input type: {type(arg)}")
+
+    for arg_idx, (arg, var) in enumerate(zip(explicit_inputs, inputs)):
+        if isinstance(arg, Tuple):
+            explicit_inputs[arg_idx] = _detuple(arg, var, builder)
+
     outputs = spec.method(*explicit_inputs)
     effect_outputs = []
     for _, effect in effects:
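
The recursion in the spec.py changes above has three cooperating pieces: _convert_arg_spec normalizes raw Python containers into Tuple spec nodes, _convert_input maps each Tuple to a single relax var carrying TupleStructInfo, and _detuple re-expands that var into per-field values with rx.TupleGetItem before the user's method runs. A hedged sketch of the first step (the nested structure in the comments is inferred from the code above, not printed output):

    from tvm.relax.frontend.nn import spec

    def forward(x):  # placeholder; only the signature is inspected
        return x

    raw = {"x": (spec.Tensor([10, 5], "float32"), [spec.Tensor([5], "float32"), int])}
    ms = spec.MethodSpec.from_raw(raw, forward)
    # ms.arg_specs[0] is normalized into roughly:
    #   Tuple(name="x", elements=(
    #       Tensor([10, 5], 'float32'),
    #       Tuple(name="x_1", elements=[Tensor([5], 'float32'), Int()]),
    #   ))
    # _emit_method then binds "x" to one relax var of TupleStructInfo, and
    # _detuple emits rx.TupleGetItem(x, i) per field so forward() still sees
    # nested Python tuples/lists of nn.Tensor and tir.Var.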
diff --git a/python/tvm/relax/frontend/nn/torch.py b/python/tvm/relax/frontend/nn/torch.py
index 5d8c890845..aa77a11ab0 100644
--- a/python/tvm/relax/frontend/nn/torch.py
+++ b/python/tvm/relax/frontend/nn/torch.py
@@ -43,11 +43,14 @@ class TorchModule:  # pylint: disable=too-few-public-methods
         vm: VirtualMachine,
         params: List[NDArray],
     ):
-        effects = vm["_initialize_effect"]()
+        try:
+            self.effects = vm["_initialize_effect"]()
+        except AttributeError:
+            self.effects = None
+
         self.spec = spec
         self.vm = vm
         self.params = params
-        self.effects = effects
 
     def __getitem__(self, method_name: str) -> Callable:
         def _find_method(method_name):
@@ -62,7 +65,7 @@ class TorchModule:  # pylint: disable=too-few-public-methods
         def _closure(*args):
             if len(args) != len(method_spec.arg_names):
                 raise TypeError(
-                    f"Argument length mismatch. Expected {len(method_spec.args)} arguments, "
+                    f"Argument length mismatch. Expected {len(method_spec.arg_names)} arguments, "
                     f"but got {len(args)} arguments. The spec is: {method_spec}"
                 )
             args = [
@@ -71,14 +74,16 @@ class TorchModule:  # pylint: disable=too-few-public-methods
                     method_spec.arg_names, method_spec.arg_specs, args
                 )
             ]
-            outputs, self.effects = method(*args, *self.params, *self.effects)
+            if self.effects is not None:
+                outputs, self.effects = method(*args, *self.effects, *self.params)
+            else:
+                outputs = method(*args, *self.params)
             return _tvm_to_torch(outputs)
 
         _closure.__name__ = method_name
         return _closure
 
 
-@staticmethod
 def _tvm_to_torch(arg):
     if isinstance(arg, (list, tuple, Array)):
         return [_tvm_to_torch(i) for i in arg]
@@ -103,6 +108,11 @@ def _torch_to_tvm(arg_name, arg_spec, arg_torch):
                 f"Expected argument `{arg_name}` to be `int`, but got {type(arg_torch)}"
             )
         return ShapeTuple([arg_torch])
+    if isinstance(arg_spec, _spec.Tuple):
+        return [
+            _torch_to_tvm(f"{arg_name}[{i}]", x, arg_torch[i])
+            for i, x in enumerate(arg_spec.elements)
+        ]
     raise TypeError(f"Unsupported spec item type: {type(arg_spec)}")
 
 
diff --git a/tests/python/relax/test_frontend_nn_jit.py b/tests/python/relax/test_frontend_nn_jit.py
new file mode 100644
index 0000000000..4feaaf9aaa
--- /dev/null
+++ b/tests/python/relax/test_frontend_nn_jit.py
@@ -0,0 +1,208 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+from typing import Tuple, List
+import torch
+
+import tvm
+import tvm.testing
+from tvm import tir
+from tvm.relax.frontend.nn import spec
+from tvm.relax.frontend import nn
+
+
+@pytest.mark.parametrize("debug", [True, False])
+def test_jit(debug):
+    class Layer(nn.Module):
+        def __init__(self):
+            pass
+
+        def forward(self, x: nn.Tensor):
+            y = nn.add(x, x)
+            return y
+
+    forward_spec = {"forward": {"x": spec.Tensor([10, 5], dtype="float32")}}
+    mod = Layer()
+
+    model = mod.jit(spec=forward_spec, debug=debug)
+
+    x = torch.rand((10, 5), dtype=torch.float32)
+    y = model["forward"](x)
+    assert isinstance(y, torch.Tensor)
+    assert torch.allclose(x + x, y)
+
+
+@pytest.mark.parametrize("debug", [True, False])
+def test_jit_int_input(debug):
+    class Layer(nn.Module):
+        def __init__(self):
+            pass
+
+        def forward(self, x: nn.Tensor, i: tir.Var):
+            y = nn.add(x, x)
+            y = nn.reshape(y, (i, 5, 5))
+            return y
+
+    forward_spec = {"forward": {"x": spec.Tensor([10, 5], dtype="float32"), "i": int}}
+    mod = Layer()
+
+    model = mod.jit(spec=forward_spec, debug=debug)
+
+    x = torch.rand((10, 5), dtype=torch.float32)
+    y = model["forward"](x, 2)
+    assert isinstance(y, torch.Tensor)
+    assert torch.allclose(torch.reshape(x + x, (2, 5, 5)), y)
+
+
+@pytest.mark.parametrize("debug", [True, False])
+def test_jit_with_effect(debug):
+    class Layer(nn.Module):
+        def __init__(self):
+            self.cache = nn.KVCache(10, [10, 5])
+
+        def forward(self, x: nn.Tensor, total_seq_len: tir.Var):
+            self.cache.append(x)
+            y = self.cache.view(total_seq_len)
+            return y
+
+    forward_spec = {
+        "forward": {"x": spec.Tensor([1, 10, 5], dtype="float32"), "total_seq_len": int}
+    }
+    mod = Layer()
+
+    with tvm.transform.PassContext(opt_level=3):
+        model = mod.jit(spec=forward_spec, debug=debug)
+
+    x0 = torch.rand((1, 10, 5), dtype=torch.float32)
+    y = model["forward"](x0, 1)
+    assert isinstance(y, torch.Tensor)
+    assert torch.allclose(x0, y)
+
+    x1 = torch.rand((1, 10, 5), dtype=torch.float32)
+    y = model["forward"](x1, 2)
+    assert torch.allclose(torch.concat([x0, x1], dim=0), y)
+
+    x2 = torch.rand((1, 10, 5), dtype=torch.float32)
+    y = model["forward"](x2, 3)
+    assert torch.allclose(torch.concat([x0, x1, x2], dim=0), y)
+
+
+@pytest.mark.parametrize("debug", [True, False])
+def test_jit_tuple_input(debug):
+    class Layer(nn.Module):
+        def __init__(self):
+            pass
+
+        def forward(self, x: Tuple[nn.Tensor, nn.Tensor]):
+            assert isinstance(x, tuple)
+            x0 = x[0]
+            x1 = x[1]
+            y0 = nn.add(x0, x1)
+            y1 = nn.subtract(x0, x1)
+            return (y0, y1)
+
+    forward_spec = {
+        "forward": {
+            "x": (
+                spec.Tensor([10, 5], dtype="float32"),
+                spec.Tensor([10, 5], dtype="float32"),
+            )
+        }
+    }
+    mod = Layer()
+
+    model = mod.jit(spec=forward_spec, debug=debug)
+
+    x0 = torch.rand((10, 5), dtype=torch.float32)
+    x1 = torch.rand((10, 5), dtype=torch.float32)
+    x = (x0, x1)
+    y = model["forward"](x)
+
+    assert torch.allclose(x0 + x1, y[0])
+    assert torch.allclose(x0 - x1, y[1])
+
+
+@pytest.mark.parametrize("debug", [True, False])
+def test_jit_list_input(debug):
+    class Layer(nn.Module):
+        def __init__(self):
+            pass
+
+        def forward(self, x: List[nn.Tensor]):
+            assert isinstance(x, list)
+            x0 = x[0]
+            x1 = x[1]
+            y0 = nn.add(x0, x1)
+            y1 = nn.subtract(x0, x1)
+            return (y0, y1)
+
+    forward_spec = {
+        "forward": {
+            "x": [
+                spec.Tensor([10, 5], dtype="float32"),
+                spec.Tensor([10, 5], dtype="float32"),
+            ]
+        }
+    }
+    mod = Layer()
+
+    model = mod.jit(spec=forward_spec, debug=debug)
+
+    x0 = torch.rand((10, 5), dtype=torch.float32)
+    x1 = torch.rand((10, 5), dtype=torch.float32)
+    x = (x0, x1)
+    y = model["forward"](x)
+
+    assert torch.allclose(x0 + x1, y[0])
+    assert torch.allclose(x0 - x1, y[1])
+
+
+@pytest.mark.parametrize("debug", [True, False])
+def test_jit_tuple_input_with_int(debug):
+    class Layer(nn.Module):
+        def __init__(self):
+            pass
+
+        def forward(self, x: Tuple[nn.Tensor, nn.Tensor, int]):
+            x0 = x[0]
+            x1 = x[1]
+            y0 = nn.add(x0, x1)
+            y1 = nn.subtract(x0, x1)
+            y2 = nn.reshape(x0, (5, x[2], 5))
+            return (y0, y1, y2)
+
+    forward_spec = {
+        "forward": {
+            "x": (spec.Tensor([10, 5], dtype="float32"), spec.Tensor([10, 5], dtype="float32"), int)
+        }
+    }
+    mod = Layer()
+
+    model = mod.jit(spec=forward_spec, debug=debug)
+
+    x0 = torch.rand((10, 5), dtype=torch.float32)
+    x1 = torch.rand((10, 5), dtype=torch.float32)
+    x = (x0, x1, 2)
+    y0, y1, y2 = model["forward"](x)
+
+    assert torch.allclose(x0 + x1, y0)
+    assert torch.allclose(x0 - x1, y1)
+    assert torch.allclose(torch.reshape(x0, (5, 2, 5)), y2)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
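
The commit message advertises nesting support, while the tests above stop at flat tuples and lists. A sketch of a nested case in the same style (hypothetical, not part of this commit):

    @pytest.mark.parametrize("debug", [True, False])
    def test_jit_nested_tuple_input(debug):
        class Layer(nn.Module):
            def __init__(self):
                pass

            def forward(self, x: Tuple[nn.Tensor, Tuple[nn.Tensor, nn.Tensor]]):
                y0 = nn.add(x[0], x[1][0])
                y1 = nn.subtract(x[0], x[1][1])
                return (y0, y1)

        forward_spec = {
            "forward": {
                "x": (
                    spec.Tensor([10, 5], dtype="float32"),
                    (
                        spec.Tensor([10, 5], dtype="float32"),
                        spec.Tensor([10, 5], dtype="float32"),
                    ),
                )
            }
        }
        model = Layer().jit(spec=forward_spec, debug=debug)

        x0, x1, x2 = [torch.rand((10, 5), dtype=torch.float32) for _ in range(3)]
        y0, y1 = model["forward"]((x0, (x1, x2)))
        assert torch.allclose(x0 + x1, y0)
        assert torch.allclose(x0 - x2, y1)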
diff --git a/tests/python/relax/test_frontend_nn_modules.py b/tests/python/relax/test_frontend_nn_modules.py
index 9d10a84f55..524472feb4 100644
--- a/tests/python/relax/test_frontend_nn_modules.py
+++ b/tests/python/relax/test_frontend_nn_modules.py
@@ -16,12 +16,14 @@
 # under the License.
 import numpy as np
 import pytest
+from typing import Tuple, List
 
 import tvm
 import tvm.testing
 from tvm import relax
 from tvm.ir import assert_structural_equal
 from tvm.relax.frontend.nn import core, modules, spec
+from tvm.relax.frontend import nn
 from tvm.script import ir as I
 from tvm.script import relax as R
 
@@ -529,5 +531,83 @@ def test_attention():
     assert_structural_equal(tvm_mod["forward"], forward, True)
 
 
+def test_nn_module_tuple_input():
+    class Layer(nn.Module):
+        def __init__(self):
+            pass
+
+        def forward(self, x: Tuple[nn.Tensor, nn.Tensor]):
+            x0 = x[0]
+            x1 = x[1]
+            y0 = nn.add(x0, x1)
+            y1 = nn.subtract(x0, x1)
+            return (y0, y1)
+
+    # fmt: off
+    @R.function
+    def forward(x: R.Tuple(R.Tensor((10, 5), dtype="float32"), R.Tensor((10, 5), dtype="float32")), _io: R.Object) -> R.Tuple(R.Tuple(R.Tensor((10, 5), dtype="float32"), R.Tensor((10, 5), dtype="float32")), R.Tuple(R.Object)):
+        R.func_attr({"num_input": 2})
+        with R.dataflow():
+            lv1: R.Tensor((10, 5), dtype="float32") = x[0]
+            lv2: R.Tensor((10, 5), dtype="float32") = x[1]
+            add: R.Tensor((10, 5), dtype="float32") = R.add(lv1, lv2)
+            subtract: R.Tensor((10, 5), dtype="float32") = R.subtract(lv1, lv2)
+            gv1: R.Tuple(R.Tuple(R.Tensor((10, 5), dtype="float32"), R.Tensor((10, 5), dtype="float32")), R.Tuple(R.Object)) = (add, subtract), (_io,)
+            R.output(gv1)
+        return gv1
+    # fmt: on
+
+    mod = Layer()
+    tvm_mod, _ = mod.export_tvm(
+        spec={
+            "forward": {
+                "x": (spec.Tensor([10, 5], dtype="float32"), spec.Tensor([10, 5], dtype="float32"))
+            }
+        },
+        debug=True,
+    )
+
+    assert_structural_equal(tvm_mod["forward"], forward)
+
+
+def test_nn_module_list_input():
+    class Layer(nn.Module):
+        def __init__(self):
+            pass
+
+        def forward(self, x: List[nn.Tensor]):
+            x0 = x[0]
+            x1 = x[1]
+            y0 = nn.add(x0, x1)
+            y1 = nn.subtract(x0, x1)
+            return [y0, y1]
+
+    # fmt: off
+    @R.function
+    def forward(x: R.Tuple(R.Tensor((10, 5), dtype="float32"), R.Tensor((10, 5), dtype="float32")), _io: R.Object) -> R.Tuple(R.Tuple(R.Tensor((10, 5), dtype="float32"), R.Tensor((10, 5), dtype="float32")), R.Tuple(R.Object)):
+        R.func_attr({"num_input": 2})
+        with R.dataflow():
+            lv1: R.Tensor((10, 5), dtype="float32") = x[0]
+            lv2: R.Tensor((10, 5), dtype="float32") = x[1]
+            add: R.Tensor((10, 5), dtype="float32") = R.add(lv1, lv2)
+            subtract: R.Tensor((10, 5), dtype="float32") = R.subtract(lv1, lv2)
+            gv1: R.Tuple(R.Tuple(R.Tensor((10, 5), dtype="float32"), R.Tensor((10, 5), dtype="float32")), R.Tuple(R.Object)) = (add, subtract), (_io,)
+            R.output(gv1)
+        return gv1
+    # fmt: on
+
+    mod = Layer()
+    tvm_mod, _ = mod.export_tvm(
+        spec={
+            "forward": {
+                "x": [spec.Tensor([10, 5], dtype="float32"), spec.Tensor([10, 5], dtype="float32")]
+            }
+        },
+        debug=True,
+    )
+
+    assert_structural_equal(tvm_mod["forward"], forward)
+
+
 if __name__ == "__main__":
     tvm.testing.main()