You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by mb...@apache.org on 2021/05/05 00:46:44 UTC
[tvm] branch main updated: [Frontend][Keras] Fix Dense with 3d inputs (#7753)
This is an automated email from the ASF dual-hosted git repository.
mbrookhart pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 26a5e29 [Frontend][Keras] Fix Dense with 3d inputs (#7753)
26a5e29 is described below
commit 26a5e299bedf8beb255baac7e1c4fb0946f4e17b
Author: Trevor Morris <tr...@amazon.com>
AuthorDate: Tue May 4 17:46:26 2021 -0700
[Frontend][Keras] Fix Dense with 3d inputs (#7753)
* Fix keras rnn dense
* Fix unit test
* Fix unit test
---
python/tvm/relay/frontend/keras.py | 2 +-
tests/python/frontend/keras/test_forward.py | 5 +++++
2 files changed, 6 insertions(+), 1 deletion(-)
diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py
index eb16bf2..2227e12 100644
--- a/python/tvm/relay/frontend/keras.py
+++ b/python/tvm/relay/frontend/keras.py
@@ -250,7 +250,7 @@ def _convert_dense(inexpr, keras_layer, etab):
raise tvm.error.OpAttributeInvalid(
"Input shape {} is not valid for operator Dense.".format(input_shape)
)
- inexpr = _op.squeeze(inexpr, axis=0)
+ inexpr = _op.squeeze(inexpr, axis=[0])
out = _op.nn.dense(data=inexpr, **params)
if keras_layer.use_bias:
bias = etab.new_const(weightList[1])
diff --git a/tests/python/frontend/keras/test_forward.py b/tests/python/frontend/keras/test_forward.py
index c7f734b..e9420a5 100644
--- a/tests/python/frontend/keras/test_forward.py
+++ b/tests/python/frontend/keras/test_forward.py
@@ -198,6 +198,11 @@ class TestKeras:
x = keras.layers.Dense(10, activation="relu", kernel_initializer="uniform")(x)
keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model)
+ # RNN dense
+ data = keras.layers.Input(shape=(1, 32))
+ x = keras.layers.Dense(32, activation="relu", kernel_initializer="uniform")(data)
+ keras_model = keras.models.Model(data, x)
+ verify_keras_frontend(keras_model, need_transpose=False)
def test_forward_permute(self, keras):
data = keras.layers.Input(shape=(2, 3, 4))