Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/06/30 19:38:01 UTC

[GitHub] szha closed pull request #11003: rnn_cell little bug fixed

URL: https://github.com/apache/incubator-mxnet/pull/11003

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:
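
In brief, the patch adds a hybridizable HybridSequentialRNNCell container
alongside the existing SequentialRNNCell, and fixes the batch-size lookup in
_format_sequence. A minimal usage sketch of the new class (assuming a build of
MXNet that includes this patch; the shapes are illustrative):

    import mxnet as mx
    from mxnet import gluon

    # Stack two LSTM cells in the new hybridizable container.
    stack = gluon.rnn.HybridSequentialRNNCell()
    stack.add(gluon.rnn.LSTMCell(100))
    stack.add(gluon.rnn.LSTMCell(100))
    stack.collect_params().initialize()

    # Unroll over 3 time steps of a batch-10, 50-feature sequence (NTC layout).
    x = mx.nd.random.uniform(shape=(10, 3, 50))
    outputs, states = stack.unroll(3, x, layout='NTC', merge_outputs=True)
    print(outputs.shape)  # (10, 3, 100)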

diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index f1ab129288a..dca9b1f7282 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -172,3 +172,4 @@ List of Contributors
 * [Thomas Delteil](https://github.com/ThomasDelteil)
 * [Jesse Brizzi](https://github.com/jessebrizzi)
 * [Hang Zhang](http://hangzh.com)
+* [Kou Ding](https://github.com/chinakook)
diff --git a/python/mxnet/gluon/rnn/rnn_cell.py b/python/mxnet/gluon/rnn/rnn_cell.py
index f318b10812a..0cda938f2f1 100644
--- a/python/mxnet/gluon/rnn/rnn_cell.py
+++ b/python/mxnet/gluon/rnn/rnn_cell.py
@@ -22,7 +22,7 @@
 """Definition of various recurrent neural network cells."""
 __all__ = ['RecurrentCell', 'HybridRecurrentCell',
            'RNNCell', 'LSTMCell', 'GRUCell',
-           'SequentialRNNCell', 'DropoutCell',
+           'SequentialRNNCell', 'HybridSequentialRNNCell', 'DropoutCell',
            'ModifierCell', 'ZoneoutCell', 'ResidualCell',
            'BidirectionalCell']
 
@@ -81,7 +81,7 @@ def _format_sequence(length, inputs, layout, merge, in_layout=None):
             F = symbol
         else:
             F = ndarray
-            batch_size = inputs[0].shape[batch_axis]
+            batch_size = inputs[0].shape[0]
         if merge is True:
             inputs = F.stack(*inputs, axis=axis)
             in_axis = axis
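
For context on the one-line fix just above: this branch of _format_sequence is
reached when inputs is a Python list of per-step arrays, each shaped
(batch_size, ...) with no time axis, so indexing the shape by batch_axis
(which is 1 for the 'TNC' layout) reads the feature dimension instead of the
batch size. A small illustration of the failure mode (shapes are hypothetical):

    import mxnet as mx

    # Three time steps for a batch of 10 with 50 features each; with a list
    # input the time axis is implicit, so every element is (batch, feature).
    inputs = [mx.nd.zeros((10, 50)) for _ in range(3)]
    batch_axis = 'TNC'.find('N')        # == 1
    print(inputs[0].shape[batch_axis])  # 50 -- the feature size, wrong
    print(inputs[0].shape[0])           # 10 -- the actual batch size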
@@ -687,6 +687,7 @@ def __call__(self, inputs, states):
         self._counter += 1
         next_states = []
         p = 0
+        assert all(not isinstance(cell, BidirectionalCell) for cell in self._children.values())
         for cell in self._children.values():
             assert not isinstance(cell, BidirectionalCell)
             n = len(cell.state_info())
@@ -730,6 +731,81 @@ def hybrid_forward(self, *args, **kwargs):
         raise NotImplementedError
 
 
+class HybridSequentialRNNCell(HybridRecurrentCell):
+    """Sequentially stacking multiple HybridRNN cells."""
+    def __init__(self, prefix=None, params=None):
+        super(HybridSequentialRNNCell, self).__init__(prefix=prefix, params=params)
+
+    def __repr__(self):
+        s = '{name}(\n{modstr}\n)'
+        return s.format(name=self.__class__.__name__,
+                        modstr='\n'.join(['({i}): {m}'.format(i=i, m=_indent(m.__repr__(), 2))
+                                          for i, m in self._children.items()]))
+
+    def add(self, cell):
+        """Appends a cell into the stack.
+
+        Parameters
+        ----------
+        cell : RecurrentCell
+            The cell to add.
+        """
+        self.register_child(cell)
+
+    def state_info(self, batch_size=0):
+        return _cells_state_info(self._children.values(), batch_size)
+
+    def begin_state(self, **kwargs):
+        assert not self._modified, \
+            "After applying modifier cells (e.g. ZoneoutCell) the base " \
+            "cell cannot be called directly. Call the modifier cell instead."
+        return _cells_begin_state(self._children.values(), **kwargs)
+
+    def __call__(self, inputs, states):
+        self._counter += 1
+        next_states = []
+        p = 0
+        assert all(not isinstance(cell, BidirectionalCell) for cell in self._children.values())
+        for cell in self._children.values():
+            n = len(cell.state_info())
+            state = states[p:p+n]
+            p += n
+            inputs, state = cell(inputs, state)
+            next_states.append(state)
+        return inputs, sum(next_states, [])
+
+    def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None,
+               valid_length=None):
+        self.reset()
+
+        inputs, _, F, batch_size = _format_sequence(length, inputs, layout, None)
+        num_cells = len(self._children)
+        begin_state = _get_begin_state(self, F, begin_state, inputs, batch_size)
+
+        p = 0
+        next_states = []
+        for i, cell in enumerate(self._children.values()):
+            n = len(cell.state_info())
+            states = begin_state[p:p+n]
+            p += n
+            inputs, states = cell.unroll(length, inputs=inputs, begin_state=states,
+                                         layout=layout,
+                                         merge_outputs=None if i < num_cells-1 else merge_outputs,
+                                         valid_length=valid_length)
+            next_states.extend(states)
+
+        return inputs, next_states
+
+    def __getitem__(self, i):
+        return self._children[str(i)]
+
+    def __len__(self):
+        return len(self._children)
+
+    def hybrid_forward(self, F, inputs, states):
+        return self.__call__(inputs, states)
+
+
 class DropoutCell(HybridRecurrentCell):
     """Applies dropout on input.
 
diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py
index 9dbcb3b3be8..3e73218df6b 100644
--- a/tests/python/unittest/test_gluon_rnn.py
+++ b/tests/python/unittest/test_gluon_rnn.py
@@ -171,6 +171,54 @@ def test_stack():
     assert outs == [(10, 100), (10, 100), (10, 100)]
 
 
+def test_hybridstack():
+    cell = gluon.rnn.HybridSequentialRNNCell()
+    for i in range(5):
+        if i == 1:
+            cell.add(gluon.rnn.ResidualCell(gluon.rnn.LSTMCell(100, prefix='rnn_stack%d_' % i)))
+        else:
+            cell.add(gluon.rnn.LSTMCell(100, prefix='rnn_stack%d_'%i))
+    inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(3)]
+    outputs, _ = cell.unroll(3, inputs)
+    outputs = mx.sym.Group(outputs)
+    keys = sorted(cell.collect_params().keys())
+    for i in range(5):
+        assert 'rnn_stack%d_h2h_weight'%i in keys
+        assert 'rnn_stack%d_h2h_bias'%i in keys
+        assert 'rnn_stack%d_i2h_weight'%i in keys
+        assert 'rnn_stack%d_i2h_bias'%i in keys
+    assert outputs.list_outputs() == ['rnn_stack4_t0_out_output', 'rnn_stack4_t1_out_output', 'rnn_stack4_t2_out_output']
+
+    args, outs, auxs = outputs.infer_shape(rnn_t0_data=(10,50), rnn_t1_data=(10,50), rnn_t2_data=(10,50))
+    assert outs == [(10, 100), (10, 100), (10, 100)]
+
+    # Test HybridSequentialRNNCell nested in nn.HybridBlock; SequentialRNNCell would fail in this case
+    class BidirectionalOfSequential(gluon.HybridBlock):
+        def __init__(self):
+            super(BidirectionalOfSequential, self).__init__()
+
+            with self.name_scope():
+                cell0 = gluon.rnn.HybridSequentialRNNCell()
+                cell0.add(gluon.rnn.LSTMCell(100))
+                cell0.add(gluon.rnn.LSTMCell(100))
+                
+                cell1 = gluon.rnn.HybridSequentialRNNCell()
+                cell1.add(gluon.rnn.LSTMCell(100))
+                cell1.add(gluon.rnn.LSTMCell(100))
+
+                self.rnncell = gluon.rnn.BidirectionalCell(cell0, cell1)
+        
+        def hybrid_forward(self, F, x):
+            return self.rnncell.unroll(3, x, layout="NTC", merge_outputs=True)
+            
+    x = mx.nd.random.uniform(shape=(10, 3, 100))
+    net = BidirectionalOfSequential()
+    net.collect_params().initialize()
+    outs, _ = net(x)
+    
+    assert outs.shape == (10, 3, 200)
+
+
 def test_bidirectional():
     cell = gluon.rnn.BidirectionalCell(
             gluon.rnn.LSTMCell(100, prefix='rnn_l0_'),
@@ -196,6 +244,26 @@ def test_zoneout():
     assert outs == [(10, 100), (10, 100), (10, 100)]
 
 
+def test_unroll_layout():
+    cell = gluon.rnn.HybridSequentialRNNCell()
+    for i in range(5):
+        if i == 1:
+            cell.add(gluon.rnn.ResidualCell(gluon.rnn.LSTMCell(100, prefix='rnn_stack%d_' % i)))
+        else:
+            cell.add(gluon.rnn.LSTMCell(100, prefix='rnn_stack%d_'%i))
+    cell.collect_params().initialize()
+    inputs = [mx.nd.random.uniform(shape=(10,50)) for _ in range(3)]
+    outputs, _ = cell.unroll(3, inputs, layout='TNC')
+    assert outputs[0].shape == (10, 100)
+    assert outputs[1].shape == (10, 100)
+    assert outputs[2].shape == (10, 100)
+
+    outputs, _ = cell.unroll(3, inputs, layout='NTC')
+    assert outputs[0].shape == (10, 100)
+    assert outputs[1].shape == (10, 100)
+    assert outputs[2].shape == (10, 100)
+
+
 def check_rnn_forward(layer, inputs, deterministic=True):
     inputs.attach_grad()
     layer.collect_params().initialize()

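
The new test_hybridstack case above doubles as a recipe. A condensed sketch of
the same pattern, hybridized end to end (again assuming this patch; the class
name StackedRNN is only illustrative):

    import mxnet as mx
    from mxnet import gluon

    class StackedRNN(gluon.HybridBlock):
        def __init__(self, **kwargs):
            super(StackedRNN, self).__init__(**kwargs)
            with self.name_scope():
                self.cell = gluon.rnn.HybridSequentialRNNCell()
                self.cell.add(gluon.rnn.LSTMCell(100))
                self.cell.add(gluon.rnn.LSTMCell(100))

        def hybrid_forward(self, F, x):
            # Unroll 3 steps; merge_outputs=True yields one (N, T, C) array.
            outputs, _ = self.cell.unroll(3, x, layout='NTC', merge_outputs=True)
            return outputs

    net = StackedRNN()
    net.collect_params().initialize()
    net.hybridize()  # every child is a HybridRecurrentCell, so the block can hybridize
    y = net(mx.nd.random.uniform(shape=(10, 3, 50)))
    print(y.shape)  # (10, 3, 100)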

 
