You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ec...@apache.org on 2023/04/14 07:12:42 UTC

[tvm] branch main updated: [Tensorflow] Fix conv2d_transpose for NHWC layout (#14546)

This is an automated email from the ASF dual-hosted git repository.

echuraev pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 62f9b1d29a [Tensorflow] Fix conv2d_transpose for NHWC layout (#14546)
62f9b1d29a is described below

commit 62f9b1d29ae25fbdeb425bfc600c5dac7c23f694
Author: Qingchao Shen <qi...@outlook.com>
AuthorDate: Fri Apr 14 15:12:29 2023 +0800

    [Tensorflow] Fix conv2d_transpose for NHWC layout (#14546)
    
    * [Tensorflow] Fix conv2d_transpose for NHWC layout
    
    If "data_format" == "NHWC", the kernel_layout should be "HWOI" rather than "HWIO".
    
    * remove dead code
    
    * add test cases
    
    * Update test_forward.py
    
    * Update test_forward.py
    
    * Update tensorflow_ops.py
    
    * Update tensorflow_ops.py
---
 python/tvm/relay/frontend/tensorflow_ops.py      |  4 ++--
 tests/python/frontend/tensorflow/test_forward.py | 21 ++++++++++++++++++++-
 2 files changed, 22 insertions(+), 3 deletions(-)

diff --git a/python/tvm/relay/frontend/tensorflow_ops.py b/python/tvm/relay/frontend/tensorflow_ops.py
index 6b3f144619..27374fad1a 100644
--- a/python/tvm/relay/frontend/tensorflow_ops.py
+++ b/python/tvm/relay/frontend/tensorflow_ops.py
@@ -464,8 +464,8 @@ def _conv(opname):
             if opname == "conv":
                 attr["kernel_layout"] = "HWIO" if attr["data_format"] == "NHWC" else "OIHW"
             elif opname == "conv_transpose":
-                # conv_transpose in TVM has weights be IOHW for NCHW
-                attr["kernel_layout"] = "HWIO" if attr["data_format"] == "NHWC" else "IOHW"
+                # conv_transpose weights are IOHW; attr["data_format"] is always NCHW here
+                attr["kernel_layout"] = "IOHW"
             else:
                 attr["kernel_layout"] = "HWOI" if attr["data_format"] == "NHWC" else "OIHW"
 
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index 703df79942..bd966fa71c 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -742,7 +742,16 @@ def test_forward_convolution():
             "NCHW",
             [1, 1, 8, 8],
         )
-
+        _test_convolution(
+            "conv_transpose",
+            [4, 19, 8, 8],
+            [2, 2, 66, 19],
+            [1, 1],
+            [2, 2],
+            "VALID",
+            "NCHW",
+            [4, 66, 16, 16],
+        )
     _test_convolution("conv", [4, 8, 8, 176], [1, 1, 176, 32], [1, 1], [1, 1], "SAME", "NHWC")
     _test_convolution("conv", [4, 17, 17, 19], [3, 3, 19, 19], [1, 1], [2, 2], "VALID", "NHWC")
     _test_convolution("conv", [4, 17, 17, 124], [1, 1, 124, 19], [1, 1], [1, 1], "SAME", "NHWC")
@@ -917,6 +926,16 @@ def test_forward_convolution():
         [4, 8, 8, 176],
         add_shapes_to_graph_def=False,
     )
+    _test_convolution(
+        "conv_transpose",
+        [4, 8, 8, 19],
+        [2, 2, 66, 19],
+        [1, 1],
+        [2, 2],
+        "VALID",
+        "NHWC",
+        [4, 16, 16, 66],
+    )
     # Explicit padding
     if package_version.parse(tf.VERSION) >= package_version.parse("2.4.1"):
         _test_convolution(