You are viewing a plain-text version of this content; the canonical hyperlink to the original was removed during the text-only conversion.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/06/30 19:38:06 UTC

[incubator-mxnet] branch master updated: rnn_cell little bug fixed (#11003)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 92900f0  rnn_cell little bug fixed (#11003)
92900f0 is described below

commit 92900f03866e708a11765a224c21d7ebb479654a
Author: chinakook <ch...@msn.com>
AuthorDate: Sun Jul 1 03:37:59 2018 +0800

    rnn_cell little bug fixed (#11003)
---
 CONTRIBUTORS.md                         |  1 +
 python/mxnet/gluon/rnn/rnn_cell.py      | 80 ++++++++++++++++++++++++++++++++-
 tests/python/unittest/test_gluon_rnn.py | 68 ++++++++++++++++++++++++++++
 3 files changed, 147 insertions(+), 2 deletions(-)

diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index f1ab129..dca9b1f 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -172,3 +172,4 @@ List of Contributors
 * [Thomas Delteil](https://github.com/ThomasDelteil)
 * [Jesse Brizzi](https://github.com/jessebrizzi)
 * [Hang Zhang](http://hangzh.com)
+* [Kou Ding](https://github.com/chinakook)
diff --git a/python/mxnet/gluon/rnn/rnn_cell.py b/python/mxnet/gluon/rnn/rnn_cell.py
index f318b10..0cda938 100644
--- a/python/mxnet/gluon/rnn/rnn_cell.py
+++ b/python/mxnet/gluon/rnn/rnn_cell.py
@@ -22,7 +22,7 @@
 """Definition of various recurrent neural network cells."""
 __all__ = ['RecurrentCell', 'HybridRecurrentCell',
            'RNNCell', 'LSTMCell', 'GRUCell',
-           'SequentialRNNCell', 'DropoutCell',
+           'SequentialRNNCell', 'HybridSequentialRNNCell', 'DropoutCell',
            'ModifierCell', 'ZoneoutCell', 'ResidualCell',
            'BidirectionalCell']
 
@@ -81,7 +81,7 @@ def _format_sequence(length, inputs, layout, merge, in_layout=None):
             F = symbol
         else:
             F = ndarray
-            batch_size = inputs[0].shape[batch_axis]
+            batch_size = inputs[0].shape[0]
         if merge is True:
             inputs = F.stack(*inputs, axis=axis)
             in_axis = axis
@@ -687,6 +687,7 @@ class SequentialRNNCell(RecurrentCell):
         self._counter += 1
         next_states = []
         p = 0
+        assert all(not isinstance(cell, BidirectionalCell) for cell in self._children.values())
         for cell in self._children.values():
             assert not isinstance(cell, BidirectionalCell)
             n = len(cell.state_info())
@@ -730,6 +731,81 @@ class SequentialRNNCell(RecurrentCell):
         raise NotImplementedError
 
 
+class HybridSequentialRNNCell(HybridRecurrentCell):
+    """Sequentially stacking multiple HybridRNN cells."""
+    def __init__(self, prefix=None, params=None):
+        super(HybridSequentialRNNCell, self).__init__(prefix=prefix, params=params)
+
+    def __repr__(self):
+        s = '{name}(\n{modstr}\n)'
+        return s.format(name=self.__class__.__name__,
+                        modstr='\n'.join(['({i}): {m}'.format(i=i, m=_indent(m.__repr__(), 2))
+                                          for i, m in self._children.items()]))
+
+    def add(self, cell):
+        """Appends a cell into the stack.
+
+        Parameters
+        ----------
+        cell : RecurrentCell
+            The cell to add.
+        """
+        self.register_child(cell)
+
+    def state_info(self, batch_size=0):
+        return _cells_state_info(self._children.values(), batch_size)
+
+    def begin_state(self, **kwargs):
+        assert not self._modified, \
+            "After applying modifier cells (e.g. ZoneoutCell) the base " \
+            "cell cannot be called directly. Call the modifier cell instead."
+        return _cells_begin_state(self._children.values(), **kwargs)
+
+    def __call__(self, inputs, states):
+        self._counter += 1
+        next_states = []
+        p = 0
+        assert all(not isinstance(cell, BidirectionalCell) for cell in self._children.values())
+        for cell in self._children.values():
+            n = len(cell.state_info())
+            state = states[p:p+n]
+            p += n
+            inputs, state = cell(inputs, state)
+            next_states.append(state)
+        return inputs, sum(next_states, [])
+
+    def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None,
+               valid_length=None):
+        self.reset()
+
+        inputs, _, F, batch_size = _format_sequence(length, inputs, layout, None)
+        num_cells = len(self._children)
+        begin_state = _get_begin_state(self, F, begin_state, inputs, batch_size)
+
+        p = 0
+        next_states = []
+        for i, cell in enumerate(self._children.values()):
+            n = len(cell.state_info())
+            states = begin_state[p:p+n]
+            p += n
+            inputs, states = cell.unroll(length, inputs=inputs, begin_state=states,
+                                         layout=layout,
+                                         merge_outputs=None if i < num_cells-1 else merge_outputs,
+                                         valid_length=valid_length)
+            next_states.extend(states)
+
+        return inputs, next_states
+
+    def __getitem__(self, i):
+        return self._children[str(i)]
+
+    def __len__(self):
+        return len(self._children)
+
+    def hybrid_forward(self, F, inputs, states):
+        return self.__call__(inputs, states)
+
+
 class DropoutCell(HybridRecurrentCell):
     """Applies dropout on input.
 
diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py
index 9dbcb3b..3e73218 100644
--- a/tests/python/unittest/test_gluon_rnn.py
+++ b/tests/python/unittest/test_gluon_rnn.py
@@ -171,6 +171,54 @@ def test_stack():
     assert outs == [(10, 100), (10, 100), (10, 100)]
 
 
+def test_hybridstack():
+    cell = gluon.rnn.HybridSequentialRNNCell()
+    for i in range(5):
+        if i == 1:
+            cell.add(gluon.rnn.ResidualCell(gluon.rnn.LSTMCell(100, prefix='rnn_stack%d_' % i)))
+        else:
+            cell.add(gluon.rnn.LSTMCell(100, prefix='rnn_stack%d_'%i))
+    inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(3)]
+    outputs, _ = cell.unroll(3, inputs)
+    outputs = mx.sym.Group(outputs)
+    keys = sorted(cell.collect_params().keys())
+    for i in range(5):
+        assert 'rnn_stack%d_h2h_weight'%i in keys
+        assert 'rnn_stack%d_h2h_bias'%i in keys
+        assert 'rnn_stack%d_i2h_weight'%i in keys
+        assert 'rnn_stack%d_i2h_bias'%i in keys
+    assert outputs.list_outputs() == ['rnn_stack4_t0_out_output', 'rnn_stack4_t1_out_output', 'rnn_stack4_t2_out_output']
+
+    args, outs, auxs = outputs.infer_shape(rnn_t0_data=(10,50), rnn_t1_data=(10,50), rnn_t2_data=(10,50))
+    assert outs == [(10, 100), (10, 100), (10, 100)]
+
+    # Test HybridSequentialRNNCell nested in nn.HybridBlock, SequentialRNNCell will fail in this case
+    class BidirectionalOfSequential(gluon.HybridBlock):
+        def __init__(self):
+            super(BidirectionalOfSequential, self).__init__()
+
+            with self.name_scope():
+                cell0 = gluon.rnn.HybridSequentialRNNCell()
+                cell0.add(gluon.rnn.LSTMCell(100))
+                cell0.add(gluon.rnn.LSTMCell(100))
+                
+                cell1 = gluon.rnn.HybridSequentialRNNCell()
+                cell1.add(gluon.rnn.LSTMCell(100))
+                cell1.add(gluon.rnn.LSTMCell(100))
+
+                self.rnncell = gluon.rnn.BidirectionalCell(cell0, cell1)
+        
+        def hybrid_forward(self, F, x):
+            return self.rnncell.unroll(3, x, layout="NTC", merge_outputs=True)
+            
+    x = mx.nd.random.uniform(shape=(10, 3, 100))
+    net = BidirectionalOfSequential()
+    net.collect_params().initialize()
+    outs, _ = net(x)
+    
+    assert outs.shape == (10, 3, 200)
+
+
 def test_bidirectional():
     cell = gluon.rnn.BidirectionalCell(
             gluon.rnn.LSTMCell(100, prefix='rnn_l0_'),
@@ -196,6 +244,26 @@ def test_zoneout():
     assert outs == [(10, 100), (10, 100), (10, 100)]
 
 
+def test_unroll_layout():
+    cell = gluon.rnn.HybridSequentialRNNCell()
+    for i in range(5):
+        if i == 1:
+            cell.add(gluon.rnn.ResidualCell(gluon.rnn.LSTMCell(100, prefix='rnn_stack%d_' % i)))
+        else:
+            cell.add(gluon.rnn.LSTMCell(100, prefix='rnn_stack%d_'%i))
+    cell.collect_params().initialize()
+    inputs = [mx.nd.random.uniform(shape=(10,50)) for _ in range(3)]
+    outputs, _ = cell.unroll(3, inputs, layout='TNC')
+    assert outputs[0].shape == (10, 100)
+    assert outputs[1].shape == (10, 100)
+    assert outputs[2].shape == (10, 100)
+
+    outputs, _ = cell.unroll(3, inputs, layout='NTC')
+    assert outputs[0].shape == (10, 100)
+    assert outputs[1].shape == (10, 100)
+    assert outputs[2].shape == (10, 100)
+
+
 def check_rnn_forward(layer, inputs, deterministic=True):
     inputs.attach_grad()
     layer.collect_params().initialize()