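You are viewing a plain-text version of this content. The canonical link for it was provided as a hyperlink ("here") in the original message and is not preserved in this plain-text version.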
Posted to commits@tvm.apache.org by ma...@apache.org on 2020/03/28 03:35:59 UTC
[incubator-tvm] branch master updated: [Relay][Frontend][Pytorch] Fixed ConvTranspose2D parsing (#5157)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 9c80662 [Relay][Frontend][Pytorch] Fixed ConvTranspose2D parsing (#5157)
9c80662 is described below
commit 9c806621dbbc46e44c47d3f8e7e3cb8e3dcb3222
Author: Josh Fromm <jw...@uw.edu>
AuthorDate: Fri Mar 27 20:35:51 2020 -0700
[Relay][Frontend][Pytorch] Fixed ConvTranspose2D parsing (#5157)
* Fixed conv transpose parsing.
* Small format change.
* Change test module names.
* Simplified test syntax.
---
python/tvm/relay/frontend/pytorch.py | 6 +++++-
tests/python/frontend/pytorch/test_forward.py | 9 +++++++++
2 files changed, 14 insertions(+), 1 deletion(-)
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 92f917d..6a26711 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -251,7 +251,7 @@ def _hardtanh():
def _convolution():
def _impl(inputs, input_types):
# Use transpose or normal
- use_transpose = True if inputs[6] == "1" else False
+ use_transpose = True if inputs[6] == 1 else False
data = inputs[0]
weight = inputs[1]
@@ -268,6 +268,10 @@ def _convolution():
else:
assert "data type {} could not be parsed in conv op" % (type(weight))
+ # Transposed convolutions have IOHW layout.
+ if use_transpose:
+ weight_shape[0], weight_shape[1] = weight_shape[1], weight_shape[0]
+
channels = weight_shape[0]
groups = int(inputs[8])
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index a5557ce..1878266 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -448,6 +448,14 @@ def test_forward_conv():
input_data=torch.randn((1, 8, 16, 16)))
+def test_forward_conv_transpose():
+ torch.set_grad_enabled(False)
+ input_shape = [1, 3, 10, 10]
+ input_data = torch.rand(input_shape).float()
+ verify_model(torch.nn.ConvTranspose2d(3, 6, 7, bias=True), input_data=input_data)
+ verify_model(torch.nn.ConvTranspose2d(3, 12, 3, bias=False), input_data=input_data)
+
+
def test_forward_threshold():
torch.set_grad_enabled(False)
input_shape = [1, 3]
@@ -1050,6 +1058,7 @@ if __name__ == "__main__":
test_forward_maxpool1d()
test_forward_hardtanh()
test_forward_conv()
+ test_forward_conv_transpose()
test_forward_threshold()
test_forward_contiguous()
test_forward_batchnorm()