You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ec...@apache.org on 2023/04/14 07:12:42 UTC
[tvm] branch main updated: [Tensorflow] Fix conv2d_transpose for NHWC layout (#14546)
This is an automated email from the ASF dual-hosted git repository.
echuraev pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 62f9b1d29a [Tensorflow] Fix conv2d_transpose for NHWC layout (#14546)
62f9b1d29a is described below
commit 62f9b1d29ae25fbdeb425bfc600c5dac7c23f694
Author: Qingchao Shen <qi...@outlook.com>
AuthorDate: Fri Apr 14 15:12:29 2023 +0800
[Tensorflow] Fix conv2d_transpose for NHWC layout (#14546)
* [Tensorflow] Fix conv2d_transpose for NHWC layout
If "data_format" == "NHWC", the kernel_layout should be "HWOI" rather than "HWIO".
* remove dead code
* add test cases
* Update test_forward.py
* Update test_forward.py
* Update tensorflow_ops.py
* Update tensorflow_ops.py
---
python/tvm/relay/frontend/tensorflow_ops.py | 4 ++--
tests/python/frontend/tensorflow/test_forward.py | 21 ++++++++++++++++++++-
2 files changed, 22 insertions(+), 3 deletions(-)
diff --git a/python/tvm/relay/frontend/tensorflow_ops.py b/python/tvm/relay/frontend/tensorflow_ops.py
index 6b3f144619..27374fad1a 100644
--- a/python/tvm/relay/frontend/tensorflow_ops.py
+++ b/python/tvm/relay/frontend/tensorflow_ops.py
@@ -464,8 +464,8 @@ def _conv(opname):
if opname == "conv":
attr["kernel_layout"] = "HWIO" if attr["data_format"] == "NHWC" else "OIHW"
elif opname == "conv_transpose":
- # conv_transpose in TVM has weights be IOHW for NCHW
- attr["kernel_layout"] = "HWIO" if attr["data_format"] == "NHWC" else "IOHW"
+ # conv_transpose has weights be IOHW, because the attr["data_format"] always be NCHW
+ attr["kernel_layout"] = "IOHW"
else:
attr["kernel_layout"] = "HWOI" if attr["data_format"] == "NHWC" else "OIHW"
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index 703df79942..bd966fa71c 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -742,7 +742,16 @@ def test_forward_convolution():
"NCHW",
[1, 1, 8, 8],
)
-
+ _test_convolution(
+ "conv_transpose",
+ [4, 19, 8, 8],
+ [2, 2, 66, 19],
+ [1, 1],
+ [2, 2],
+ "VALID",
+ "NCHW",
+ [4, 66, 16, 16],
+ )
_test_convolution("conv", [4, 8, 8, 176], [1, 1, 176, 32], [1, 1], [1, 1], "SAME", "NHWC")
_test_convolution("conv", [4, 17, 17, 19], [3, 3, 19, 19], [1, 1], [2, 2], "VALID", "NHWC")
_test_convolution("conv", [4, 17, 17, 124], [1, 1, 124, 19], [1, 1], [1, 1], "SAME", "NHWC")
@@ -917,6 +926,16 @@ def test_forward_convolution():
[4, 8, 8, 176],
add_shapes_to_graph_def=False,
)
+ _test_convolution(
+ "conv_transpose",
+ [4, 8, 8, 19],
+ [2, 2, 66, 19],
+ [1, 1],
+ [2, 2],
+ "VALID",
+ "NHWC",
+ [4, 16, 16, 66],
+ )
# Explicit padding
if package_version.parse(tf.VERSION) >= package_version.parse("2.4.1"):
_test_convolution(