You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2023/03/06 19:23:36 UTC

[tvm] branch unity updated: [Unity][Frontend] FX translator supports unwrapping unit return tuple (#14212)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new cc991d547f [Unity][Frontend] FX translator supports unwrapping unit return tuple (#14212)
cc991d547f is described below

commit cc991d547f03a377f179e0def9717448041501b9
Author: Ruihang Lai <ru...@cs.cmu.edu>
AuthorDate: Mon Mar 6 14:23:29 2023 -0500

    [Unity][Frontend] FX translator supports unwrapping unit return tuple (#14212)
    
    Previously we found that the FX GraphModule captured by torch.dynamo
    will always return a tuple at the end, even if the Module being traced
    returns a single object. Unwrapping the unit tuple in that case can
    ease model deployment. Therefore, this PR introduces an option
    "unwrap_unit_return_tuple" to `from_fx`, to indicate whether the caller
    wants to unwrap the returned unit tuple.
    
    `dynamo_capture_subgraphs` now always sets this option to True.
---
 python/tvm/relax/frontend/torch/dynamo.py        |  7 +++++-
 python/tvm/relax/frontend/torch/fx_translator.py | 30 ++++++++++++++++++++----
 tests/python/relax/test_frontend_dynamo.py       |  8 +++----
 tests/python/relax/test_frontend_from_fx.py      | 29 +++++++++++++++++++++++
 4 files changed, 65 insertions(+), 9 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/dynamo.py b/python/tvm/relax/frontend/torch/dynamo.py
index a7fb9bc015..c71c1fbc84 100644
--- a/python/tvm/relax/frontend/torch/dynamo.py
+++ b/python/tvm/relax/frontend/torch/dynamo.py
@@ -148,7 +148,12 @@ def dynamo_capture_subgraphs(model, *params, **kwargs) -> tvm.IRModule:
     def _capture(graph_module: fx.GraphModule, example_inputs):
         assert isinstance(graph_module, torch.fx.GraphModule)
         input_info = [(tuple(tensor.shape), str(tensor.dtype)) for tensor in example_inputs]
-        mod_ = from_fx(graph_module, input_info, keep_params_as_input)
+        mod_ = from_fx(
+            graph_module,
+            input_info,
+            keep_params_as_input=keep_params_as_input,
+            unwrap_unit_return_tuple=True,
+        )
         mod[f"subgraph_{len(mod.get_global_vars())}"] = mod_["main"]
         return graph_module.forward
 
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index b4a77ccb33..d865984a14 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -896,7 +896,11 @@ class TorchFXImporter:
         }
 
     def from_fx(
-        self, model, input_info: List[Tuple[Tuple[int], str]], keep_params_as_input: bool
+        self,
+        model,
+        input_info: List[Tuple[Tuple[int], str]],
+        keep_params_as_input: bool,
+        unwrap_unit_return_tuple: bool,
     ) -> tvm.IRModule:
         """Convert a PyTorch FX GraphModule to a Relax program."""
         from torch import fx
@@ -947,7 +951,15 @@ class TorchFXImporter:
                         self.env[node] = inputs.pop(0)
                     elif node.op == "output":
                         args = self.retrieve_args(node)
-                        output = self.block_builder.emit_output(args[0])
+                        assert len(args) == 1
+                        if (
+                            unwrap_unit_return_tuple
+                            and isinstance(args[0], (tuple, relax.Tuple))
+                            and len(args[0]) == 1
+                        ):
+                            output = self.block_builder.emit_output(args[0][0])
+                        else:
+                            output = self.block_builder.emit_output(args[0])
                         break
                     elif node.op == "get_attr":
                         self.env[node] = TorchFXImporter._fetch_attr(model, node.target)
@@ -980,7 +992,11 @@ class TorchFXImporter:
 
 
 def from_fx(
-    model, input_info: List[Tuple[Tuple[int], str]], keep_params_as_input: bool = False
+    model,
+    input_info: List[Tuple[Tuple[int], str]],
+    *,
+    keep_params_as_input: bool = False,
+    unwrap_unit_return_tuple: bool = False,
 ) -> tvm.IRModule:
     """Convert a PyTorch FX GraphModule to a Relax program
 
@@ -995,6 +1011,10 @@ def from_fx(
     keep_params_as_input : bool
         Whether to keep model parameters as input variables.
 
+    unwrap_unit_return_tuple : bool
+        A boolean flag indicating whether to unwrap the return value when it is a unit tuple.
+        When the return value is not a unit tuple, no unwrapping takes place.
+
     Returns
     -------
     output : tvm.IRModule
@@ -1062,4 +1082,6 @@ def from_fx(
     to print out the tabular representation of the PyTorch module, and then
     check the placeholder rows in the beginning of the tabular.
     """
-    return TorchFXImporter().from_fx(model, input_info, keep_params_as_input)
+    return TorchFXImporter().from_fx(
+        model, input_info, keep_params_as_input, unwrap_unit_return_tuple
+    )
diff --git a/tests/python/relax/test_frontend_dynamo.py b/tests/python/relax/test_frontend_dynamo.py
index b47e3e22bd..192f8e8b10 100644
--- a/tests/python/relax/test_frontend_dynamo.py
+++ b/tests/python/relax/test_frontend_dynamo.py
@@ -135,14 +135,14 @@ def test_subgraph_capture():
             inp_0: R.Tensor((10, 100), dtype="float32"),
             w0: R.Tensor((10, 100), dtype="float32"),
             w1: R.Tensor((10,), dtype="float32"),
-        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+        ) -> R.Tensor((10, 10), dtype="float32"):
             # block 0
             with R.dataflow():
                 lv: R.Tensor((100, 10), dtype="float32") = R.permute_dims(w0, axes=None)
                 lv1: R.Tensor((10, 10), dtype="float32") = R.matmul(inp_0, lv, out_dtype="float32")
                 lv2: R.Tensor((10, 10), dtype="float32") = R.add(lv1, w1)
                 lv3: R.Tensor((10, 10), dtype="float32") = R.nn.relu(lv2)
-                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv3,)
+                gv: R.Tensor((10, 10), dtype="float32") = lv3
                 R.output(gv)
             return gv
 
@@ -182,11 +182,11 @@ def test_subgraph_capture():
         @R.function
         def subgraph_1(
             inp_01: R.Tensor((10,), dtype="float32"), inp_11: R.Tensor((10,), dtype="float32")
-        ) -> R.Tuple(R.Tensor((10,), dtype="float32")):
+        ) -> R.Tensor((10,), dtype="float32"):
             # block 0
             with R.dataflow():
                 lv5: R.Tensor((10,), dtype="float32") = R.multiply(inp_11, inp_01)
-                gv1: R.Tuple(R.Tensor((10,), dtype="float32")) = (lv5,)
+                gv1: R.Tensor((10,), dtype="float32") = lv5
                 R.output(gv1)
             return gv1
 
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index e36be8c3c8..458efd7fcb 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -2229,6 +2229,35 @@ def test_keep_params():
     tvm.testing.assert_allclose(params[1].numpy(), model.conv.bias.detach().numpy())
 
 
+@tvm.testing.requires_gpu
+def test_unwrap_unit_return_tuple():
+    import torch.fx as fx
+    from torch.nn import Module
+    from tvm.relax.frontend.torch import from_fx
+
+    class Identity(Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            return (x,)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((256, 256), dtype="float32")
+        ) -> R.Tensor((256, 256), dtype="float32"):
+            with R.dataflow():
+                gv: R.Tensor((256, 256), dtype="float32") = inp_0
+                R.output(gv)
+            return gv
+
+    graph_model = fx.symbolic_trace(Identity())
+    mod = from_fx(graph_model, [([256, 256], "float32")], unwrap_unit_return_tuple=True)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 @tvm.testing.requires_gpu
 def test_argmax():
     import torch