You are viewing a plain-text version of this content. The link to its canonical (original) page was not preserved in this text copy.
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/02/12 20:35:28 UTC

[tvm] branch main updated: Make keras reshape less restrictive (#7446)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new a1260cc  Make keras reshape less restrictive (#7446)
a1260cc is described below

commit a1260cc19342c4db61c6942a11c2b2b2b58f8bad
Author: Trevor Morris <tr...@amazon.com>
AuthorDate: Fri Feb 12 12:35:17 2021 -0800

    Make keras reshape less restrictive (#7446)
---
 python/tvm/relay/frontend/keras.py          | 31 ++++++++---------------------
 tests/python/frontend/keras/test_forward.py | 10 ++++++++++
 2 files changed, 18 insertions(+), 23 deletions(-)

diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py
index 4bdca2c..eb16bf2 100644
--- a/python/tvm/relay/frontend/keras.py
+++ b/python/tvm/relay/frontend/keras.py
@@ -864,29 +864,14 @@ def _convert_reshape(inexpr, keras_layer, etab):
     _check_data_format(keras_layer)
     inshape = keras_layer.input_shape  # includes batch
     tshape = keras_layer.target_shape  # no batch
-    if len(inshape) == 3 and len(tshape) == 1:
-        # (?, a, b) -> (-1, ab)
-        shape = (-1, tshape[0])
-    elif len(inshape) in [2, 3] and len(tshape) == 2:
-        # (?, cc) -> (-1, c, c)
-        # (?, a, b) -> (-1, c, c)
-        assert tshape[0] == tshape[1], "Only supports square target shapes, but got {}".format(
-            tshape
-        )
-        shape = (-1,) + tshape
-    else:
-        # (?, h, w, c) -> (-1, c, H, W)
-        # (?, h, w, c) -> (-1, c, hw)
-        # (?, hw, c) -> (-1, c, h, w)
-        ch = inshape[-1]
-        assert ch == tshape[-1], (
-            "Only supports last dimension in target shape being equal to "
-            "the channel number of input tensor."
-        )
-        if etab.data_layout == "NCHW":
-            shape = (-1, ch) + tshape[:-1]
-        else:
-            shape = (-1,) + tshape[:-1] + (ch,)
+    shape = (-1,) + tshape
+
+    if etab.data_layout == "NCHW" and (len(inshape) > 3 or len(tshape) > 2):
+        # Perform reshape in original NHWC format.
+        inexpr = _op.transpose(inexpr, [0] + list(range(2, len(inshape))) + [1])
+        inexpr = _op.reshape(inexpr, newshape=shape)
+        return _op.transpose(inexpr, axes=[0, -1] + list(range(1, len(shape) - 1)))
+
     return _op.reshape(inexpr, newshape=shape)
 
 
diff --git a/tests/python/frontend/keras/test_forward.py b/tests/python/frontend/keras/test_forward.py
index 05d8904..561e444 100644
--- a/tests/python/frontend/keras/test_forward.py
+++ b/tests/python/frontend/keras/test_forward.py
@@ -350,6 +350,16 @@ class TestKeras:
         x = keras.layers.Reshape(target_shape=(4, 4))(data)
         keras_model = keras.models.Model(data, x)
         verify_keras_frontend(keras_model, need_transpose=False)
+        # "non-square" target shape
+        data = keras.layers.Input(shape=(15,))
+        x = keras.layers.Reshape(target_shape=(5, 3))(data)
+        keras_model = keras.models.Model(data, x)
+        verify_keras_frontend(keras_model, need_transpose=False)
+        # modify channel dim
+        data = keras.layers.Input(shape=(3, 2, 4))
+        x = keras.layers.Reshape(target_shape=(3, 8))(data)
+        keras_model = keras.models.Model(data, x)
+        verify_keras_frontend(keras_model)
 
     def test_forward_crop(self, keras):
         data = keras.layers.Input(shape=(32, 32, 3))