You are viewing a plain-text version of this content; the hyperlink to its canonical version was not preserved in this copy.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/06/30 19:38:06 UTC
[incubator-mxnet] branch master updated: rnn_cell little bug fixed (#11003)
This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 92900f0 rnn_cell little bug fixed (#11003)
92900f0 is described below
commit 92900f03866e708a11765a224c21d7ebb479654a
Author: chinakook <ch...@msn.com>
AuthorDate: Sun Jul 1 03:37:59 2018 +0800
rnn_cell little bug fixed (#11003)
---
CONTRIBUTORS.md | 1 +
python/mxnet/gluon/rnn/rnn_cell.py | 80 ++++++++++++++++++++++++++++++++-
tests/python/unittest/test_gluon_rnn.py | 68 ++++++++++++++++++++++++++++
3 files changed, 147 insertions(+), 2 deletions(-)
diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index f1ab129..dca9b1f 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -172,3 +172,4 @@ List of Contributors
* [Thomas Delteil](https://github.com/ThomasDelteil)
* [Jesse Brizzi](https://github.com/jessebrizzi)
* [Hang Zhang](http://hangzh.com)
+* [Kou Ding](https://github.com/chinakook)
diff --git a/python/mxnet/gluon/rnn/rnn_cell.py b/python/mxnet/gluon/rnn/rnn_cell.py
index f318b10..0cda938 100644
--- a/python/mxnet/gluon/rnn/rnn_cell.py
+++ b/python/mxnet/gluon/rnn/rnn_cell.py
@@ -22,7 +22,7 @@
"""Definition of various recurrent neural network cells."""
__all__ = ['RecurrentCell', 'HybridRecurrentCell',
'RNNCell', 'LSTMCell', 'GRUCell',
- 'SequentialRNNCell', 'DropoutCell',
+ 'SequentialRNNCell', 'HybridSequentialRNNCell', 'DropoutCell',
'ModifierCell', 'ZoneoutCell', 'ResidualCell',
'BidirectionalCell']
@@ -81,7 +81,7 @@ def _format_sequence(length, inputs, layout, merge, in_layout=None):
F = symbol
else:
F = ndarray
- batch_size = inputs[0].shape[batch_axis]
+ batch_size = inputs[0].shape[0]
if merge is True:
inputs = F.stack(*inputs, axis=axis)
in_axis = axis
@@ -687,6 +687,7 @@ class SequentialRNNCell(RecurrentCell):
self._counter += 1
next_states = []
p = 0
+ assert all(not isinstance(cell, BidirectionalCell) for cell in self._children.values())
for cell in self._children.values():
assert not isinstance(cell, BidirectionalCell)
n = len(cell.state_info())
@@ -730,6 +731,81 @@ class SequentialRNNCell(RecurrentCell):
raise NotImplementedError
+class HybridSequentialRNNCell(HybridRecurrentCell):
+    """Sequentially stacks multiple hybrid RNN cells; each cell's output is the next cell's input."""
+    def __init__(self, prefix=None, params=None):
+        super(HybridSequentialRNNCell, self).__init__(prefix=prefix, params=params)
+
+    def __repr__(self):  # one indented "(key): child-repr" entry per stacked cell
+        s = '{name}(\n{modstr}\n)'
+        return s.format(name=self.__class__.__name__,
+                        modstr='\n'.join(['({i}): {m}'.format(i=i, m=_indent(m.__repr__(), 2))
+                                          for i, m in self._children.items()]))
+
+    def add(self, cell):
+        """Appends a cell to the top of the stack.
+
+        Parameters
+        ----------
+        cell : HybridRecurrentCell
+            The cell to append; it is registered as a child block of this stack.
+        """
+        self.register_child(cell)
+
+    def state_info(self, batch_size=0):  # delegates to _cells_state_info over all children, in order
+        return _cells_state_info(self._children.values(), batch_size)
+
+    def begin_state(self, **kwargs):
+        assert not self._modified, \
+            "After applying modifier cells (e.g. ZoneoutCell) the base " \
+            "cell cannot be called directly. Call the modifier cell instead."
+        return _cells_begin_state(self._children.values(), **kwargs)
+
+    def __call__(self, inputs, states):  # one time step through every cell; does not call the base __call__
+        self._counter += 1
+        next_states = []
+        p = 0  # offset of the current cell's slice within the flat ``states`` list
+        assert all(not isinstance(cell, BidirectionalCell) for cell in self._children.values())  # no BidirectionalCell in step mode
+        for cell in self._children.values():
+            n = len(cell.state_info())  # number of state entries this cell consumes
+            state = states[p:p+n]
+            p += n
+            inputs, state = cell(inputs, state)  # this cell's output feeds the next cell
+            next_states.append(state)
+        return inputs, sum(next_states, [])  # flatten the per-cell state lists into one list
+
+    def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None,
+               valid_length=None):
+        self.reset()
+
+        inputs, _, F, batch_size = _format_sequence(length, inputs, layout, None)
+        num_cells = len(self._children)
+        begin_state = _get_begin_state(self, F, begin_state, inputs, batch_size)
+
+        p = 0  # offset of the current cell's slice within the flat ``begin_state`` list
+        next_states = []
+        for i, cell in enumerate(self._children.values()):
+            n = len(cell.state_info())
+            states = begin_state[p:p+n]
+            p += n
+            inputs, states = cell.unroll(length, inputs=inputs, begin_state=states,
+                                         layout=layout,
+                                         merge_outputs=None if i < num_cells-1 else merge_outputs,  # only the last cell honours the caller's merge_outputs
+                                         valid_length=valid_length)
+            next_states.extend(states)
+
+        return inputs, next_states
+
+    def __getitem__(self, i):
+        return self._children[str(i)]  # NOTE(review): assumes children are keyed by their insertion index as a string
+
+    def __len__(self):
+        return len(self._children)
+
+    def hybrid_forward(self, F, inputs, states):  # F is unused; delegates to the overridden __call__
+        return self.__call__(inputs, states)
+
+
class DropoutCell(HybridRecurrentCell):
"""Applies dropout on input.
diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py
index 9dbcb3b..3e73218 100644
--- a/tests/python/unittest/test_gluon_rnn.py
+++ b/tests/python/unittest/test_gluon_rnn.py
@@ -171,6 +171,54 @@ def test_stack():
assert outs == [(10, 100), (10, 100), (10, 100)]
+def test_hybridstack():
+    cell = gluon.rnn.HybridSequentialRNNCell()
+    for i in range(5):
+        if i == 1:  # wrap one layer in a ResidualCell so a modifier cell is stacked too
+            cell.add(gluon.rnn.ResidualCell(gluon.rnn.LSTMCell(100, prefix='rnn_stack%d_' % i)))
+        else:
+            cell.add(gluon.rnn.LSTMCell(100, prefix='rnn_stack%d_'%i))
+    inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(3)]
+    outputs, _ = cell.unroll(3, inputs)  # symbolic unroll over 3 time steps
+    outputs = mx.sym.Group(outputs)
+    keys = sorted(cell.collect_params().keys())
+    for i in range(5):
+        assert 'rnn_stack%d_h2h_weight'%i in keys
+        assert 'rnn_stack%d_h2h_bias'%i in keys
+        assert 'rnn_stack%d_i2h_weight'%i in keys
+        assert 'rnn_stack%d_i2h_bias'%i in keys
+    assert outputs.list_outputs() == ['rnn_stack4_t0_out_output', 'rnn_stack4_t1_out_output', 'rnn_stack4_t2_out_output']
+
+    args, outs, auxs = outputs.infer_shape(rnn_t0_data=(10,50), rnn_t1_data=(10,50), rnn_t2_data=(10,50))  # 10 x 50 per step
+    assert outs == [(10, 100), (10, 100), (10, 100)]
+
+    # HybridSequentialRNNCell nested inside a gluon.HybridBlock; SequentialRNNCell would fail here.
+    class BidirectionalOfSequential(gluon.HybridBlock):
+        def __init__(self):
+            super(BidirectionalOfSequential, self).__init__()
+
+            with self.name_scope():
+                cell0 = gluon.rnn.HybridSequentialRNNCell()
+                cell0.add(gluon.rnn.LSTMCell(100))
+                cell0.add(gluon.rnn.LSTMCell(100))
+
+                cell1 = gluon.rnn.HybridSequentialRNNCell()
+                cell1.add(gluon.rnn.LSTMCell(100))
+                cell1.add(gluon.rnn.LSTMCell(100))
+
+                self.rnncell = gluon.rnn.BidirectionalCell(cell0, cell1)
+
+        def hybrid_forward(self, F, x):
+            return self.rnncell.unroll(3, x, layout="NTC", merge_outputs=True)
+
+    x = mx.nd.random.uniform(shape=(10, 3, 100))  # (N, T, C), matching layout="NTC" above
+    net = BidirectionalOfSequential()
+    net.collect_params().initialize()
+    outs, _ = net(x)
+
+    assert outs.shape == (10, 3, 200)  # 200 = 100 hidden units from each of the two directions
+
+
def test_bidirectional():
cell = gluon.rnn.BidirectionalCell(
gluon.rnn.LSTMCell(100, prefix='rnn_l0_'),
@@ -196,6 +244,26 @@ def test_zoneout():
assert outs == [(10, 100), (10, 100), (10, 100)]
+def test_unroll_layout():  # regression check for deriving batch_size from list inputs
+    cell = gluon.rnn.HybridSequentialRNNCell()
+    for i in range(5):
+        if i == 1:
+            cell.add(gluon.rnn.ResidualCell(gluon.rnn.LSTMCell(100, prefix='rnn_stack%d_' % i)))
+        else:
+            cell.add(gluon.rnn.LSTMCell(100, prefix='rnn_stack%d_'%i))
+    cell.collect_params().initialize()
+    inputs = [mx.nd.random.uniform(shape=(10,50)) for _ in range(3)]  # 3 per-step arrays, each (batch=10, 50)
+    outputs, _ = cell.unroll(3, inputs, layout='TNC')  # batch must come from shape[0] of each step, not the layout's batch axis
+    assert outputs[0].shape == (10, 100)
+    assert outputs[1].shape == (10, 100)
+    assert outputs[2].shape == (10, 100)
+
+    outputs, _ = cell.unroll(3, inputs, layout='NTC')  # same list inputs must also work with the default layout
+    assert outputs[0].shape == (10, 100)
+    assert outputs[1].shape == (10, 100)
+    assert outputs[2].shape == (10, 100)
+
+
def check_rnn_forward(layer, inputs, deterministic=True):
inputs.attach_grad()
layer.collect_params().initialize()