You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2019/05/30 22:34:29 UTC
[incubator-mxnet] branch master updated: fix gluon rnn cell single step unroll (#15081)
This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 3143858 fix gluon rnn cell single step unroll (#15081)
3143858 is described below
commit 31438583d72dcd72bedb6e83a0884f53a9a8fe37
Author: Sheng Zha <sz...@users.noreply.github.com>
AuthorDate: Thu May 30 15:34:06 2019 -0700
fix gluon rnn cell single step unroll (#15081)
---
python/mxnet/gluon/rnn/rnn_cell.py | 5 +++-
tests/python/unittest/test_gluon_rnn.py | 53 +++++++++++++++++----------------
2 files changed, 31 insertions(+), 27 deletions(-)
diff --git a/python/mxnet/gluon/rnn/rnn_cell.py b/python/mxnet/gluon/rnn/rnn_cell.py
index 9154ccf..71c7b3f 100644
--- a/python/mxnet/gluon/rnn/rnn_cell.py
+++ b/python/mxnet/gluon/rnn/rnn_cell.py
@@ -114,7 +114,10 @@ def _reverse_sequences(sequences, unroll_step, valid_length=None):
reversed_sequences = F.SequenceReverse(F.stack(*sequences, axis=0),
sequence_length=valid_length,
use_sequence_length=True)
- reversed_sequences = F.split(reversed_sequences, axis=0, num_outputs=unroll_step, squeeze_axis=True)
+ if unroll_step > 1 or F is symbol:
+ reversed_sequences = F.split(reversed_sequences, axis=0, num_outputs=unroll_step, squeeze_axis=True)
+ else:
+ reversed_sequences = [reversed_sequences[0]]
return reversed_sequences
diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py
index 9d78920..309756b 100644
--- a/tests/python/unittest/test_gluon_rnn.py
+++ b/tests/python/unittest/test_gluon_rnn.py
@@ -634,32 +634,33 @@ def test_layer_fill_shape():
def test_bidirectional_unroll_valid_length():
- # Test BidirectionalCell.
- # In 1.3.1 version, after hybridize( ), BidirectionalCell would failed when pass valid_length to unroll( ).
-
- class BiLSTM(gluon.nn.HybridBlock):
- def __init__(self, rnn_size, time_step, **kwargs):
- super(BiLSTM, self).__init__(**kwargs)
- self.time_step = time_step
- with self.name_scope():
- self.bi_lstm = gluon.rnn.BidirectionalCell(
- gluon.rnn.LSTMCell(rnn_size, prefix='rnn_l0_'),
- gluon.rnn.LSTMCell(rnn_size, prefix='rnn_r0_'),
- output_prefix='lstm_bi_')
-
- def hybrid_forward(self, F, inputs, valid_len):
- outputs, states = self.bi_lstm.unroll(self.time_step, inputs, valid_length=valid_len,
- layout='NTC', merge_outputs=True)
- return outputs, states
-
- rnn_size, time_step = 100, 3
- net = BiLSTM(rnn_size, time_step)
- net.initialize()
- net.hybridize()
- inputs_data = mx.nd.random.uniform(shape=(10, 3, 50))
- valid_len = mx.nd.array([1]*10)
- outputs, _ = net(inputs_data, valid_len)
- assert outputs.shape == (10, 3, 200)
+ def _check_bidirectional_unroll_valid_length(length):
+ class BiLSTM(gluon.nn.HybridBlock):
+ def __init__(self, rnn_size, time_step, **kwargs):
+ super(BiLSTM, self).__init__(**kwargs)
+ self.time_step = time_step
+ with self.name_scope():
+ self.bi_lstm = gluon.rnn.BidirectionalCell(
+ gluon.rnn.LSTMCell(rnn_size, prefix='rnn_l0_'),
+ gluon.rnn.LSTMCell(rnn_size, prefix='rnn_r0_'),
+ output_prefix='lstm_bi_')
+
+ def hybrid_forward(self, F, inputs, valid_len):
+ outputs, states = self.bi_lstm.unroll(self.time_step, inputs, valid_length=valid_len,
+ layout='NTC', merge_outputs=True)
+ return outputs, states
+
+ rnn_size = 100
+ net = BiLSTM(rnn_size, length)
+ net.initialize()
+ net.hybridize()
+ inputs_data = mx.nd.random.uniform(shape=(10, length, 50))
+ valid_len = mx.nd.array([length]*10)
+ outputs, _ = net(inputs_data, valid_len)
+ assert outputs.shape == (10, length, 200)
+
+ _check_bidirectional_unroll_valid_length(1)
+ _check_bidirectional_unroll_valid_length(3)
if __name__ == '__main__':