You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink is not preserved in this plain-text copy).
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/02/12 20:35:28 UTC
[tvm] branch main updated: Make keras reshape less restrictive (#7446)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new a1260cc Make keras reshape less restrictive (#7446)
a1260cc is described below
commit a1260cc19342c4db61c6942a11c2b2b2b58f8bad
Author: Trevor Morris <tr...@amazon.com>
AuthorDate: Fri Feb 12 12:35:17 2021 -0800
Make keras reshape less restrictive (#7446)
---
python/tvm/relay/frontend/keras.py | 31 ++++++++---------------------
tests/python/frontend/keras/test_forward.py | 10 ++++++++++
2 files changed, 18 insertions(+), 23 deletions(-)
diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py
index 4bdca2c..eb16bf2 100644
--- a/python/tvm/relay/frontend/keras.py
+++ b/python/tvm/relay/frontend/keras.py
@@ -864,29 +864,14 @@ def _convert_reshape(inexpr, keras_layer, etab):
_check_data_format(keras_layer)
inshape = keras_layer.input_shape # includes batch
tshape = keras_layer.target_shape # no batch
- if len(inshape) == 3 and len(tshape) == 1:
- # (?, a, b) -> (-1, ab)
- shape = (-1, tshape[0])
- elif len(inshape) in [2, 3] and len(tshape) == 2:
- # (?, cc) -> (-1, c, c)
- # (?, a, b) -> (-1, c, c)
- assert tshape[0] == tshape[1], "Only supports square target shapes, but got {}".format(
- tshape
- )
- shape = (-1,) + tshape
- else:
- # (?, h, w, c) -> (-1, c, H, W)
- # (?, h, w, c) -> (-1, c, hw)
- # (?, hw, c) -> (-1, c, h, w)
- ch = inshape[-1]
- assert ch == tshape[-1], (
- "Only supports last dimension in target shape being equal to "
- "the channel number of input tensor."
- )
- if etab.data_layout == "NCHW":
- shape = (-1, ch) + tshape[:-1]
- else:
- shape = (-1,) + tshape[:-1] + (ch,)
+ shape = (-1,) + tshape
+
+ if etab.data_layout == "NCHW" and (len(inshape) > 3 or len(tshape) > 2):
+ # Perform reshape in original NHWC format.
+ inexpr = _op.transpose(inexpr, [0] + list(range(2, len(inshape))) + [1])
+ inexpr = _op.reshape(inexpr, newshape=shape)
+ return _op.transpose(inexpr, axes=[0, -1] + list(range(1, len(shape) - 1)))
+
return _op.reshape(inexpr, newshape=shape)
diff --git a/tests/python/frontend/keras/test_forward.py b/tests/python/frontend/keras/test_forward.py
index 05d8904..561e444 100644
--- a/tests/python/frontend/keras/test_forward.py
+++ b/tests/python/frontend/keras/test_forward.py
@@ -350,6 +350,16 @@ class TestKeras:
x = keras.layers.Reshape(target_shape=(4, 4))(data)
keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model, need_transpose=False)
+ # "non-square" target shape
+ data = keras.layers.Input(shape=(15,))
+ x = keras.layers.Reshape(target_shape=(5, 3))(data)
+ keras_model = keras.models.Model(data, x)
+ verify_keras_frontend(keras_model, need_transpose=False)
+ # modify channel dim
+ data = keras.layers.Input(shape=(3, 2, 4))
+ x = keras.layers.Reshape(target_shape=(3, 8))(data)
+ keras_model = keras.models.Model(data, x)
+ verify_keras_frontend(keras_model)
def test_forward_crop(self, keras):
data = keras.layers.Input(shape=(32, 32, 3))