You are viewing a plain-text version of this content; the link to its canonical version was not preserved in this copy.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2019/05/30 22:34:29 UTC

[incubator-mxnet] branch master updated: fix gluon rnn cell single step unroll (#15081)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 3143858  fix gluon rnn cell single step unroll (#15081)
3143858 is described below

commit 31438583d72dcd72bedb6e83a0884f53a9a8fe37
Author: Sheng Zha <sz...@users.noreply.github.com>
AuthorDate: Thu May 30 15:34:06 2019 -0700

    fix gluon rnn cell single step unroll (#15081)
---
 python/mxnet/gluon/rnn/rnn_cell.py      |  5 +++-
 tests/python/unittest/test_gluon_rnn.py | 53 +++++++++++++++++----------------
 2 files changed, 31 insertions(+), 27 deletions(-)

diff --git a/python/mxnet/gluon/rnn/rnn_cell.py b/python/mxnet/gluon/rnn/rnn_cell.py
index 9154ccf..71c7b3f 100644
--- a/python/mxnet/gluon/rnn/rnn_cell.py
+++ b/python/mxnet/gluon/rnn/rnn_cell.py
@@ -114,7 +114,10 @@ def _reverse_sequences(sequences, unroll_step, valid_length=None):
         reversed_sequences = F.SequenceReverse(F.stack(*sequences, axis=0),
                                                sequence_length=valid_length,
                                                use_sequence_length=True)
-        reversed_sequences = F.split(reversed_sequences, axis=0, num_outputs=unroll_step, squeeze_axis=True)
+        if unroll_step > 1 or F is symbol:
+            reversed_sequences = F.split(reversed_sequences, axis=0, num_outputs=unroll_step, squeeze_axis=True)
+        else:
+            reversed_sequences = [reversed_sequences[0]]
 
     return reversed_sequences
 
diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py
index 9d78920..309756b 100644
--- a/tests/python/unittest/test_gluon_rnn.py
+++ b/tests/python/unittest/test_gluon_rnn.py
@@ -634,32 +634,33 @@ def test_layer_fill_shape():
 
 
 def test_bidirectional_unroll_valid_length():
-    # Test BidirectionalCell.
-    # In 1.3.1 version, after hybridize( ), BidirectionalCell would failed when pass valid_length to unroll( ).
-    
-    class BiLSTM(gluon.nn.HybridBlock):
-        def __init__(self, rnn_size, time_step, **kwargs):
-            super(BiLSTM, self).__init__(**kwargs)
-            self.time_step = time_step
-            with self.name_scope():
-                self.bi_lstm = gluon.rnn.BidirectionalCell(
-                    gluon.rnn.LSTMCell(rnn_size, prefix='rnn_l0_'),
-                    gluon.rnn.LSTMCell(rnn_size, prefix='rnn_r0_'),
-                    output_prefix='lstm_bi_')
-
-        def hybrid_forward(self, F, inputs, valid_len):
-            outputs, states = self.bi_lstm.unroll(self.time_step, inputs, valid_length=valid_len,
-                                                  layout='NTC', merge_outputs=True)
-            return outputs, states
-
-    rnn_size, time_step = 100, 3
-    net = BiLSTM(rnn_size, time_step)
-    net.initialize()
-    net.hybridize()
-    inputs_data = mx.nd.random.uniform(shape=(10, 3, 50))
-    valid_len = mx.nd.array([1]*10)
-    outputs, _ = net(inputs_data, valid_len)
-    assert outputs.shape == (10, 3, 200)
+    def _check_bidirectional_unroll_valid_length(length):
+        class BiLSTM(gluon.nn.HybridBlock):
+            def __init__(self, rnn_size, time_step, **kwargs):
+                super(BiLSTM, self).__init__(**kwargs)
+                self.time_step = time_step
+                with self.name_scope():
+                    self.bi_lstm = gluon.rnn.BidirectionalCell(
+                        gluon.rnn.LSTMCell(rnn_size, prefix='rnn_l0_'),
+                        gluon.rnn.LSTMCell(rnn_size, prefix='rnn_r0_'),
+                        output_prefix='lstm_bi_')
+
+            def hybrid_forward(self, F, inputs, valid_len):
+                outputs, states = self.bi_lstm.unroll(self.time_step, inputs, valid_length=valid_len,
+                                                      layout='NTC', merge_outputs=True)
+                return outputs, states
+
+        rnn_size = 100
+        net = BiLSTM(rnn_size, length)
+        net.initialize()
+        net.hybridize()
+        inputs_data = mx.nd.random.uniform(shape=(10, length, 50))
+        valid_len = mx.nd.array([length]*10)
+        outputs, _ = net(inputs_data, valid_len)
+        assert outputs.shape == (10, length, 200)
+
+    _check_bidirectional_unroll_valid_length(1)
+    _check_bidirectional_unroll_valid_length(3)
 
 
 if __name__ == '__main__':