You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2020/03/28 03:35:59 UTC

[incubator-tvm] branch master updated: [Relay][Frontend][Pytorch] Fixed ConvTranspose2D parsing (#5157)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 9c80662  [Relay][Frontend][Pytorch] Fixed ConvTranspose2D parsing (#5157)
9c80662 is described below

commit 9c806621dbbc46e44c47d3f8e7e3cb8e3dcb3222
Author: Josh Fromm <jw...@uw.edu>
AuthorDate: Fri Mar 27 20:35:51 2020 -0700

    [Relay][Frontend][Pytorch] Fixed ConvTranspose2D parsing (#5157)
    
    * Fixed conv transpose parsing.
    
    * Small format change.
    
    * Change test module names.
    
    * Simplified test syntax.
---
 python/tvm/relay/frontend/pytorch.py          | 6 +++++-
 tests/python/frontend/pytorch/test_forward.py | 9 +++++++++
 2 files changed, 14 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 92f917d..6a26711 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -251,7 +251,7 @@ def _hardtanh():
 def _convolution():
     def _impl(inputs, input_types):
         # Use transpose or normal
-        use_transpose = True if inputs[6] == "1" else False
+        use_transpose = True if inputs[6] == 1 else False
 
         data = inputs[0]
         weight = inputs[1]
@@ -268,6 +268,10 @@ def _convolution():
         else:
             assert "data type {} could not be parsed in conv op" % (type(weight))
 
+        # Transposed convolutions have IOHW layout.
+        if use_transpose:
+            weight_shape[0], weight_shape[1] = weight_shape[1], weight_shape[0]
+
         channels = weight_shape[0]
         groups = int(inputs[8])
 
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index a5557ce..1878266 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -448,6 +448,14 @@ def test_forward_conv():
                  input_data=torch.randn((1, 8, 16, 16)))
 
 
+def test_forward_conv_transpose():
+    torch.set_grad_enabled(False)
+    input_shape = [1, 3, 10, 10]
+    input_data = torch.rand(input_shape).float()
+    verify_model(torch.nn.ConvTranspose2d(3, 6, 7, bias=True), input_data=input_data)
+    verify_model(torch.nn.ConvTranspose2d(3, 12, 3, bias=False), input_data=input_data)
+
+
 def test_forward_threshold():
     torch.set_grad_enabled(False)
     input_shape = [1, 3]
@@ -1050,6 +1058,7 @@ if __name__ == "__main__":
     test_forward_maxpool1d()
     test_forward_hardtanh()
     test_forward_conv()
+    test_forward_conv_transpose()
     test_forward_threshold()
     test_forward_contiguous()
     test_forward_batchnorm()