You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2023/03/06 19:23:36 UTC
[tvm] branch unity updated: [Unity][Frontend] FX translator supports unwrapping unit return tuple (#14212)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new cc991d547f [Unity][Frontend] FX translator supports unwrapping unit return tuple (#14212)
cc991d547f is described below
commit cc991d547f03a377f179e0def9717448041501b9
Author: Ruihang Lai <ru...@cs.cmu.edu>
AuthorDate: Mon Mar 6 14:23:29 2023 -0500
[Unity][Frontend] FX translator supports unwrapping unit return tuple (#14212)
We found that the FX GraphModule captured by torch.dynamo
always returns a tuple at the end, even if the Module being traced
returns a single object. Unwrapping the unit tuple in that case can
help ease model deployment. Therefore, this PR introduces an option
"unwrap_unit_return_tuple" to `from_fx` to indicate whether the caller wants
to unwrap the returned unit tuple.
`dynamo_capture_subgraphs` now always sets this option to True.
---
python/tvm/relax/frontend/torch/dynamo.py | 7 +++++-
python/tvm/relax/frontend/torch/fx_translator.py | 30 ++++++++++++++++++++----
tests/python/relax/test_frontend_dynamo.py | 8 +++----
tests/python/relax/test_frontend_from_fx.py | 29 +++++++++++++++++++++++
4 files changed, 65 insertions(+), 9 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/dynamo.py b/python/tvm/relax/frontend/torch/dynamo.py
index a7fb9bc015..c71c1fbc84 100644
--- a/python/tvm/relax/frontend/torch/dynamo.py
+++ b/python/tvm/relax/frontend/torch/dynamo.py
@@ -148,7 +148,12 @@ def dynamo_capture_subgraphs(model, *params, **kwargs) -> tvm.IRModule:
def _capture(graph_module: fx.GraphModule, example_inputs):
assert isinstance(graph_module, torch.fx.GraphModule)
input_info = [(tuple(tensor.shape), str(tensor.dtype)) for tensor in example_inputs]
- mod_ = from_fx(graph_module, input_info, keep_params_as_input)
+ mod_ = from_fx(
+ graph_module,
+ input_info,
+ keep_params_as_input=keep_params_as_input,
+ unwrap_unit_return_tuple=True,
+ )
mod[f"subgraph_{len(mod.get_global_vars())}"] = mod_["main"]
return graph_module.forward
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index b4a77ccb33..d865984a14 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -896,7 +896,11 @@ class TorchFXImporter:
}
def from_fx(
- self, model, input_info: List[Tuple[Tuple[int], str]], keep_params_as_input: bool
+ self,
+ model,
+ input_info: List[Tuple[Tuple[int], str]],
+ keep_params_as_input: bool,
+ unwrap_unit_return_tuple: bool,
) -> tvm.IRModule:
"""Convert a PyTorch FX GraphModule to a Relax program."""
from torch import fx
@@ -947,7 +951,15 @@ class TorchFXImporter:
self.env[node] = inputs.pop(0)
elif node.op == "output":
args = self.retrieve_args(node)
- output = self.block_builder.emit_output(args[0])
+ assert len(args) == 1
+ if (
+ unwrap_unit_return_tuple
+ and isinstance(args[0], (tuple, relax.Tuple))
+ and len(args[0]) == 1
+ ):
+ output = self.block_builder.emit_output(args[0][0])
+ else:
+ output = self.block_builder.emit_output(args[0])
break
elif node.op == "get_attr":
self.env[node] = TorchFXImporter._fetch_attr(model, node.target)
@@ -980,7 +992,11 @@ class TorchFXImporter:
def from_fx(
- model, input_info: List[Tuple[Tuple[int], str]], keep_params_as_input: bool = False
+ model,
+ input_info: List[Tuple[Tuple[int], str]],
+ *,
+ keep_params_as_input: bool = False,
+ unwrap_unit_return_tuple: bool = False,
) -> tvm.IRModule:
"""Convert a PyTorch FX GraphModule to a Relax program
@@ -995,6 +1011,10 @@ def from_fx(
keep_params_as_input : bool
Whether to keep model parameters as input variables.
+ unwrap_unit_return_tuple : bool
+ A boolean flag indicating whether to unwrap the return value when it is a unit tuple.
+ When the return value is not a unit tuple, no unwrapping takes place.
+
Returns
-------
output : tvm.IRModule
@@ -1062,4 +1082,6 @@ def from_fx(
to print out the tabular representation of the PyTorch module, and then
check the placeholder rows in the beginning of the tabular.
"""
- return TorchFXImporter().from_fx(model, input_info, keep_params_as_input)
+ return TorchFXImporter().from_fx(
+ model, input_info, keep_params_as_input, unwrap_unit_return_tuple
+ )
diff --git a/tests/python/relax/test_frontend_dynamo.py b/tests/python/relax/test_frontend_dynamo.py
index b47e3e22bd..192f8e8b10 100644
--- a/tests/python/relax/test_frontend_dynamo.py
+++ b/tests/python/relax/test_frontend_dynamo.py
@@ -135,14 +135,14 @@ def test_subgraph_capture():
inp_0: R.Tensor((10, 100), dtype="float32"),
w0: R.Tensor((10, 100), dtype="float32"),
w1: R.Tensor((10,), dtype="float32"),
- ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+ ) -> R.Tensor((10, 10), dtype="float32"):
# block 0
with R.dataflow():
lv: R.Tensor((100, 10), dtype="float32") = R.permute_dims(w0, axes=None)
lv1: R.Tensor((10, 10), dtype="float32") = R.matmul(inp_0, lv, out_dtype="float32")
lv2: R.Tensor((10, 10), dtype="float32") = R.add(lv1, w1)
lv3: R.Tensor((10, 10), dtype="float32") = R.nn.relu(lv2)
- gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv3,)
+ gv: R.Tensor((10, 10), dtype="float32") = lv3
R.output(gv)
return gv
@@ -182,11 +182,11 @@ def test_subgraph_capture():
@R.function
def subgraph_1(
inp_01: R.Tensor((10,), dtype="float32"), inp_11: R.Tensor((10,), dtype="float32")
- ) -> R.Tuple(R.Tensor((10,), dtype="float32")):
+ ) -> R.Tensor((10,), dtype="float32"):
# block 0
with R.dataflow():
lv5: R.Tensor((10,), dtype="float32") = R.multiply(inp_11, inp_01)
- gv1: R.Tuple(R.Tensor((10,), dtype="float32")) = (lv5,)
+ gv1: R.Tensor((10,), dtype="float32") = lv5
R.output(gv1)
return gv1
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index e36be8c3c8..458efd7fcb 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -2229,6 +2229,35 @@ def test_keep_params():
tvm.testing.assert_allclose(params[1].numpy(), model.conv.bias.detach().numpy())
+@tvm.testing.requires_gpu
+def test_unwrap_unit_return_tuple():
+ import torch.fx as fx
+ from torch.nn import Module
+ from tvm.relax.frontend.torch import from_fx
+
+ class Identity(Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x):
+ return (x,)
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(
+ inp_0: R.Tensor((256, 256), dtype="float32")
+ ) -> R.Tensor((256, 256), dtype="float32"):
+ with R.dataflow():
+ gv: R.Tensor((256, 256), dtype="float32") = inp_0
+ R.output(gv)
+ return gv
+
+ graph_model = fx.symbolic_trace(Identity())
+ mod = from_fx(graph_model, [([256, 256], "float32")], unwrap_unit_return_tuple=True)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
@tvm.testing.requires_gpu
def test_argmax():
import torch