Posted to commits@tvm.apache.org by "junrushao (via GitHub)" <gi...@apache.org> on 2023/09/05 19:01:10 UTC

[GitHub] [tvm] junrushao commented on a diff in pull request #15670: [Unity][Frontend][NN] Enable tuple/list input

junrushao commented on code in PR #15670:
URL: https://github.com/apache/tvm/pull/15670#discussion_r1316271852


##########
python/tvm/relax/frontend/nn/core.py:
##########
@@ -39,6 +39,7 @@
 
 import numpy as np
 
+import tvm

Review Comment:
   Don't import `tvm` directly inside the `tvm` package, to avoid unexpected behavior such as circular imports. Instead, import the specific item you need, for example `from tvm.transform import PassContext`.



##########
python/tvm/relax/frontend/nn/core.py:
##########
@@ -411,6 +412,10 @@ def jit(  # pylint: disable=too-many-arguments
 
         # Compile mod and feed it to VM
         mod = relax.pipeline.get_pipeline(pipeline)(mod)  # pylint: disable=no-value-for-parameter
+
+        if device != "cpu":
+            with target, tvm.transform.PassContext(opt_level=3):
+                mod = tvm.tir.transform.DefaultGPUSchedule()(mod)

Review Comment:
   To be clear, the system is designed to work through `relax.pipeline` (see line 414 above). Users are expected to provide the name of a pipeline that covers this set of transformations, instead of having a `DefaultGPUSchedule` pass hardcoded here.
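
A minimal pure-Python sketch of the design described above, assuming a hypothetical name-keyed registry. `register_pipeline`, `lookup_pipeline`, `jit` and the toy passes are illustrative stand-ins, not TVM's API, and modules are modeled as plain dicts. The device-specific pass lives inside a named pipeline, so `jit` needs no `if device != "cpu"` branch:

```python
from typing import Callable, Dict, List

# Hypothetical stand-ins: a "module" is a plain dict and a pass maps dict -> dict.
Module = Dict[str, str]
Pass = Callable[[Module], Module]

_PIPELINES: Dict[str, List[Pass]] = {}


def register_pipeline(name: str, passes: List[Pass]) -> None:
    """Record an ordered list of passes under a user-visible name."""
    _PIPELINES[name] = list(passes)


def lookup_pipeline(name: str) -> Pass:
    """Return one callable that applies every pass registered under `name` in order."""
    if name not in _PIPELINES:
        raise ValueError(f"Unknown pipeline: {name!r}")
    passes = _PIPELINES[name]

    def run(mod: Module) -> Module:
        for p in passes:
            mod = p(mod)
        return mod

    return run


def legalize(mod: Module) -> Module:  # toy pass shared by both pipelines
    return {**mod, "legalized": "yes"}


def gpu_schedule(mod: Module) -> Module:  # toy stand-in for a GPU scheduling pass
    return {**mod, "schedule": "gpu"}


register_pipeline("cpu_default", [legalize])
register_pipeline("gpu_default", [legalize, gpu_schedule])


def jit(mod: Module, pipeline: str) -> Module:
    # The caller chooses the pipeline by name; no device check is hardcoded here.
    return lookup_pipeline(pipeline)(mod)
```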



##########
python/tvm/relax/frontend/nn/spec.py:
##########
@@ -59,6 +59,32 @@ def __repr__(self) -> str:
         return f"Tensor([{shape}], '{self.dtype}')"
 
 
+class TupleList:

Review Comment:
   Technically, tuples and lists are two different things. Intuitively, you can see a `Tuple` as a fixed-length array whose elements each have a known type; for example, `Tuple[int, str, float]` means a tuple of 3 elements of types `int`, `str`, and `float`. A homogeneous `List` is a variable-length array whose elements all share one type; for example, `List[int]` means a list of integers.
   
   Therefore, you may want to use a more indicative name, i.e. `Tuple` instead of `TupleList`. If the name conflicts with `typing.Tuple`, just avoid importing `Tuple` from the `typing` package.
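
To make the distinction concrete, here is a small checker built on the standard `typing.get_origin` and `typing.get_args` helpers (Python 3.8+). `conforms` is a hypothetical helper, not part of TVM. A `Tuple[...]` hint fixes both the length and the type at each position, while a `List[...]` hint only constrains the shared element type:

```python
from typing import Any, List, Tuple, get_args, get_origin


def conforms(value: Any, hint: Any) -> bool:
    """Hypothetical helper: does `value` match a Tuple[...], List[...] or plain-class hint?"""
    origin, args = get_origin(hint), get_args(hint)
    if origin is tuple:
        if len(args) == 2 and args[1] is Ellipsis:
            raise NotImplementedError("Tuple[X, ...] is not handled in this sketch")
        # Fixed length: exactly one declared type per position.
        return (
            isinstance(value, tuple)
            and len(value) == len(args)
            and all(isinstance(v, t) for v, t in zip(value, args))
        )
    if origin is list:
        # Variable length: every element shares the single declared type.
        (elem_type,) = args
        return isinstance(value, list) and all(isinstance(v, elem_type) for v in value)
    return isinstance(value, hint)
```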



##########
python/tvm/relax/frontend/nn/spec.py:
##########
@@ -59,6 +59,32 @@ def __repr__(self) -> str:
         return f"Tensor([{shape}], '{self.dtype}')"
 
 
+class TupleList:
+    """A tuple input or a list input"""
+
+    name: str
+    elements: Union[List[Union[core.Tensor, "TupleList"]], Tuple[Union[core.Tensor, "TupleList"]]]
+
+    def __init__(
+        self,
+        name: str,
+        elements: Union[
+            List[Union[core.Tensor, "TupleList"]], Tuple[Union[core.Tensor, "TupleList"]]
+        ],
+    ) -> None:
+        assert type(elements) in [tuple, list], f"Unsupported container type: {type(elements)}"
+        for i, e in enumerate(elements):  # pylint: disable=invalid-name
+            assert isinstance(e, (core.Tensor, TupleList)), (
+                f"Expected all elements in the {name} tuple/list to be of type Tensor/tuple/list, "
+                f"but found a {type(e)} at index {i}."
+            )
+        self.name = name
+        self.elements = elements

Review Comment:
   Always normalize `elements` by casting it to a Python `list`.
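
A minimal sketch of that normalization, using a hypothetical `TupleSpec` stand-in rather than the class under review. Casting once in the constructor means downstream code handles a single container type, and `list(...)` also copies the caller's sequence, so later mutations of it do not leak into the spec:

```python
from typing import Any, List, Sequence


class TupleSpec:
    """Hypothetical stand-in for the spec class under review."""

    def __init__(self, name: str, elements: Sequence[Any]) -> None:
        if not isinstance(elements, (tuple, list)):
            raise TypeError(f"Unsupported container type: {type(elements)}")
        self.name = name
        # Normalize once: tuple and list inputs both end up as a fresh Python list.
        self.elements: List[Any] = list(elements)
```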



##########
python/tvm/relax/frontend/nn/spec.py:
##########
@@ -105,17 +131,23 @@ def from_raw(spec: MethodSpecType, method: Callable) -> "MethodSpec":
         method_signature = inspect.signature(method)
         arg_names = list(method_signature.parameters.keys())
         arg_specs = []
+
+        def _convert_arg_spec(arg_spec):
+            if arg_spec is Int or arg_spec is int:
+                return Int()
+            elif isinstance(arg_spec, str) and arg_spec == "int":
+                return Int()
+            elif isinstance(arg_spec, (Int, Tensor)):
+                return arg_spec
+            elif isinstance(arg_spec, (tuple, list)):
+                return type(arg_spec)([_convert_arg_spec(arg_spec_i) for arg_spec_i in arg_spec])
+            else:
+                raise TypeError(f"Invalid spec for argument {arg_name}: {arg_spec}")

Review Comment:
   Make this a module-level function, e.g. `parse_spec(arg_spec, arg_name)`.
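
A sketch of the suggested module-level helper, with hypothetical `Int` and `Tensor` stand-ins for the spec classes. Besides making the function reusable, taking `arg_name` as an explicit parameter removes the nested helper's reliance on an `arg_name` variable from the enclosing scope. In line with the comment about casting `elements` to a list, nested containers are returned as lists here:

```python
from typing import Any, List, Sequence, Union


class Int:
    """Hypothetical stand-in for spec.Int (a symbolic integer input)."""

    def __eq__(self, other: object) -> bool:
        return isinstance(other, Int)


class Tensor:
    """Hypothetical stand-in for spec.Tensor."""

    def __init__(self, shape: Sequence[int], dtype: str) -> None:
        self.shape = list(shape)
        self.dtype = dtype

    def __eq__(self, other: object) -> bool:
        return isinstance(other, Tensor) and (self.shape, self.dtype) == (other.shape, other.dtype)


def parse_spec(arg_spec: Any, arg_name: str) -> Union[Int, Tensor, List[Any]]:
    """Convert a user-written spec for argument `arg_name` into spec objects."""
    # `Int`, the builtin `int`, and the string "int" all mean a symbolic integer.
    if arg_spec is Int or arg_spec is int:
        return Int()
    if isinstance(arg_spec, str) and arg_spec == "int":
        return Int()
    # Already-constructed spec objects pass through unchanged.
    if isinstance(arg_spec, (Int, Tensor)):
        return arg_spec
    # Containers recurse with the same arg_name, so errors name the top-level argument.
    if isinstance(arg_spec, (tuple, list)):
        return [parse_spec(elem, arg_name) for elem in arg_spec]
    raise TypeError(f"Invalid spec for argument {arg_name}: {arg_spec}")
```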



##########
python/tvm/relax/frontend/nn/spec.py:
##########
@@ -59,6 +59,32 @@ def __repr__(self) -> str:
         return f"Tensor([{shape}], '{self.dtype}')"
 
 
+class TupleList:
+    """A tuple input or a list input"""
+
+    name: str
+    elements: Union[List[Union[core.Tensor, "TupleList"]], Tuple[Union[core.Tensor, "TupleList"]]]
+
+    def __init__(
+        self,
+        name: str,
+        elements: Union[
+            List[Union[core.Tensor, "TupleList"]], Tuple[Union[core.Tensor, "TupleList"]]
+        ],
+    ) -> None:
+        assert type(elements) in [tuple, list], f"Unsupported container type: {type(elements)}"
+        for i, e in enumerate(elements):  # pylint: disable=invalid-name
+            assert isinstance(e, (core.Tensor, TupleList)), (
+                f"Expected all elements in the {name} tuple/list to be of type Tensor/tuple/list, "
+                f"but found a {type(e)} at index {i}."
+            )

Review Comment:
   There isn't much point in limiting the element types to `Tuple`/`Tensor` while excluding `int`; it's possible to make `int` work as well.
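
A sketch of element validation that accepts `Int` alongside `Tensor` and nested `Tuple`, again with hypothetical stand-in classes rather than the PR's code. It raises `TypeError` instead of using `assert`, since assertions are skipped when Python runs with `-O`:

```python
from typing import Any, List, Sequence


class Int:
    """Hypothetical stand-in for spec.Int."""


class Tensor:
    """Hypothetical stand-in for spec.Tensor."""

    def __init__(self, shape: Sequence[int], dtype: str) -> None:
        self.shape = list(shape)
        self.dtype = dtype


class Tuple:
    """Hypothetical spec for a fixed-length tuple input; elements may nest."""

    def __init__(self, name: str, elements: Sequence[Any]) -> None:
        self.name = name
        self.elements: List[Any] = list(elements)
        for i, elem in enumerate(self.elements):
            # Int is accepted alongside Tensor and nested Tuple specs.
            if not isinstance(elem, (Int, Tensor, Tuple)):
                raise TypeError(
                    f"Expected all elements in the {name} tuple to be Int, Tensor or Tuple, "
                    f"but found a {type(elem)} at index {i}."
                )
```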



##########
python/tvm/relax/frontend/nn/spec.py:
##########
@@ -59,6 +59,32 @@ def __repr__(self) -> str:
         return f"Tensor([{shape}], '{self.dtype}')"
 
 
+class TupleList:
+    """A tuple input or a list input"""
+
+    name: str
+    elements: Union[List[Union[core.Tensor, "TupleList"]], Tuple[Union[core.Tensor, "TupleList"]]]

Review Comment:
   Define a global type `SpecAny`:
   
   ```
   SpecAny = Union["Int", "Tensor", "Tuple"]
   ```
   
   Then define the elements in Tuple as:
   
   ```
   elements: List[SpecAny]
   ```
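
A sketch of how the string forward references in `SpecAny` resolve, assuming stand-in `Int`, `Tensor` and `Tuple` classes. Because only `List` and `Union` are imported from `typing`, the name `Tuple` refers to the spec class, matching the naming suggestion above; `typing.get_type_hints` resolves the forward references once the classes exist:

```python
from typing import List, Union, get_args, get_type_hints

# String forward references let the alias be declared before the classes exist.
SpecAny = Union["Int", "Tensor", "Tuple"]


class Int:
    """Hypothetical stand-in for spec.Int."""


class Tensor:
    """Hypothetical stand-in for spec.Tensor."""


class Tuple:
    """Spec for a tuple input; `typing.Tuple` is deliberately not imported."""

    name: str
    elements: List[SpecAny]

    def __init__(self, name: str, elements: List[SpecAny]) -> None:
        self.name = name
        self.elements = list(elements)


# Resolve "Int", "Tensor" and "Tuple" against this module's globals.
hints = get_type_hints(Tuple, globalns=globals())
element_union = get_args(hints["elements"])[0]
```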



##########
tests/python/relax/test_frontend_nn_modules.py:
##########
@@ -529,5 +531,83 @@ def forward(
     assert_structural_equal(tvm_mod["forward"], forward, True)
 
 
+def test_nn_module_tuple_input_output():
+    class Layer(nn.Module):
+        def __init__(self):
+            pass
+
+        def forward(self, x: Tuple[nn.Tensor]):

Review Comment:
   `Tuple[nn.Tensor]` means a tuple of length 1 that contains a single tensor.
   
   ```suggestion
           def forward(self, x: Tuple[nn.Tensor, nn.Tensor]):
   ```
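
The arity encoded in a `Tuple[...]` hint can be read back with `typing.get_args`. `tuple_arity` below is a hypothetical helper and `Tensor` a stand-in for `nn.Tensor`; a variable-length homogeneous tuple is spelled `Tuple[X, ...]`:

```python
from typing import Optional, Tuple, get_args, get_origin


class Tensor:
    """Hypothetical stand-in for nn.Tensor."""


def tuple_arity(hint: object) -> Optional[int]:
    """Number of elements a Tuple[...] hint requires, or None if it is variable-length."""
    if get_origin(hint) is not tuple:
        raise TypeError(f"Not a Tuple hint: {hint!r}")
    args = get_args(hint)
    if len(args) == 2 and args[1] is Ellipsis:
        return None  # Tuple[X, ...]: any number of X
    return len(args)
```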



##########
python/tvm/relax/frontend/nn/spec.py:
##########
@@ -105,17 +131,23 @@ def from_raw(spec: MethodSpecType, method: Callable) -> "MethodSpec":
         method_signature = inspect.signature(method)
         arg_names = list(method_signature.parameters.keys())
         arg_specs = []
+
+        def _convert_arg_spec(arg_spec):
+            if arg_spec is Int or arg_spec is int:
+                return Int()
+            elif isinstance(arg_spec, str) and arg_spec == "int":
+                return Int()
+            elif isinstance(arg_spec, (Int, Tensor)):
+                return arg_spec
+            elif isinstance(arg_spec, (tuple, list)):

Review Comment:
   ```suggestion
               elif isinstance(arg_spec, (tuple, list, TupleList)):
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
