Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/03/08 18:38:40 UTC
[incubator-mxnet] branch master updated: [MXNET-31] Support variable sequence length in gluon.RecurrentCell (#9934)
This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new f04edb5 [MXNET-31] Support variable sequence length in gluon.RecurrentCell (#9934)
f04edb5 is described below
commit f04edb5045c3134f8d055d17185f35ff6e0f98fa
Author: Xingjian Shi <xs...@ust.hk>
AuthorDate: Thu Mar 8 10:38:37 2018 -0800
[MXNET-31] Support variable sequence length in gluon.RecurrentCell (#9934)
* try to enable variable-length sequences in the RNN
improve the speed of stacking the outputs
also mask the outputs
fix bugs
add tests
* add the argument to cells in contrib
* Use one line to split the elements into a list
* remove redundant code
* fix bug in VariationalDropoutCell
* add test for VariationalDropoutCell
* fix bug
---
python/mxnet/gluon/contrib/rnn/rnn_cell.py | 28 ++++++--
python/mxnet/gluon/rnn/rnn_cell.py | 111 ++++++++++++++++++++++-------
tests/python/unittest/test_gluon_rnn.py | 76 ++++++++++++++++++++
3 files changed, 184 insertions(+), 31 deletions(-)
diff --git a/python/mxnet/gluon/contrib/rnn/rnn_cell.py b/python/mxnet/gluon/contrib/rnn/rnn_cell.py
index d74c107..d6402b7 100644
--- a/python/mxnet/gluon/contrib/rnn/rnn_cell.py
+++ b/python/mxnet/gluon/contrib/rnn/rnn_cell.py
@@ -20,7 +20,8 @@
__all__ = ['VariationalDropoutCell']
from ...rnn import BidirectionalCell, SequentialRNNCell, ModifierCell
-from ...rnn.rnn_cell import _format_sequence, _get_begin_state
+from ...rnn.rnn_cell import _format_sequence, _get_begin_state, _mask_sequence_variable_length
+from ... import tensor_types
class VariationalDropoutCell(ModifierCell):
@@ -113,7 +114,8 @@ class VariationalDropoutCell(ModifierCell):
return s.format(name=self.__class__.__name__,
**self.__dict__)
- def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None):
+ def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None,
+ valid_length=None):
"""Unrolls an RNN cell across time steps.
Parameters
@@ -143,6 +145,15 @@ class VariationalDropoutCell(ModifierCell):
(batch_size, length, ...) if layout is 'NTC',
or (length, batch_size, ...) if layout is 'TNC'.
If `None`, output whatever is faster.
+ valid_length : Symbol, NDArray or None
+ `valid_length` specifies the length of the sequences in the batch without padding.
+ This option is especially useful for building sequence-to-sequence models where
+ the input and output sequences would potentially be padded.
+ If `valid_length` is None, all sequences are assumed to have the same length.
+ If `valid_length` is a Symbol or NDArray, it should have shape (batch_size,).
+ The ith element will be the length of the ith sequence in the batch.
+ The last valid state will be returned and the padded outputs will be masked with 0.
+ Note that `valid_length` must be smaller than or equal to `length`.
Returns
-------
@@ -160,7 +171,8 @@ class VariationalDropoutCell(ModifierCell):
# only when state dropout is not present.
if self.drop_states:
return super(VariationalDropoutCell, self).unroll(length, inputs, begin_state,
- layout, merge_outputs)
+ layout, merge_outputs,
+ valid_length=valid_length)
self.reset()
@@ -172,12 +184,16 @@ class VariationalDropoutCell(ModifierCell):
self._initialize_input_masks(F, first_input, states)
inputs = F.broadcast_mul(inputs, self.drop_inputs_mask.expand_dims(axis=axis))
- outputs, states = self.base_cell.unroll(length, inputs, states, layout, merge_outputs=True)
+ outputs, states = self.base_cell.unroll(length, inputs, states, layout, merge_outputs=True,
+ valid_length=valid_length)
if self.drop_outputs:
first_output = outputs.slice_axis(axis, 0, 1).split(1, axis=axis, squeeze_axis=True)
self._initialize_output_mask(F, first_output)
outputs = F.broadcast_mul(outputs, self.drop_outputs_mask.expand_dims(axis=axis))
-
+ merge_outputs = isinstance(outputs, tensor_types) if merge_outputs is None else \
+ merge_outputs
outputs, _, _, _ = _format_sequence(length, outputs, layout, merge_outputs)
-
+ if valid_length is not None:
+ outputs = _mask_sequence_variable_length(F, outputs, length, valid_length, axis,
+ merge_outputs)
return outputs, states
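
As a quick illustration of the new argument, here is a minimal usage sketch (not part of the commit; the cell size, input shapes, and lengths are illustrative assumptions). With `valid_length` given, `VariationalDropoutCell.unroll` masks the padded output steps to zero:

import mxnet as mx
from mxnet import gluon

cell = gluon.contrib.rnn.VariationalDropoutCell(gluon.rnn.LSTMCell(20))
cell.collect_params().initialize()
data = mx.nd.random.normal(shape=(4, 10, 16))   # (batch, time, feature), layout 'NTC'
valid_length = mx.nd.array([3, 10, 5, 6])       # per-sample lengths, each <= 10
outputs, states = cell.unroll(10, data, layout='NTC', merge_outputs=True,
                              valid_length=valid_length)
# steps past each sample's valid_length in `outputs` are masked to zero
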
diff --git a/python/mxnet/gluon/rnn/rnn_cell.py b/python/mxnet/gluon/rnn/rnn_cell.py
index ea0e32f..61bf24e 100644
--- a/python/mxnet/gluon/rnn/rnn_cell.py
+++ b/python/mxnet/gluon/rnn/rnn_cell.py
@@ -83,8 +83,7 @@ def _format_sequence(length, inputs, layout, merge, in_layout=None):
F = ndarray
batch_size = inputs[0].shape[batch_axis]
if merge is True:
- inputs = [F.expand_dims(i, axis=axis) for i in inputs]
- inputs = F.concat(*inputs, dim=axis)
+ inputs = F.stack(*inputs, axis=axis)
in_axis = axis
if isinstance(inputs, tensor_types) and axis != in_axis:
@@ -92,6 +91,16 @@ def _format_sequence(length, inputs, layout, merge, in_layout=None):
return inputs, axis, F, batch_size
+def _mask_sequence_variable_length(F, data, length, valid_length, time_axis, merge):
+ assert valid_length is not None
+ if not isinstance(data, tensor_types):
+ data = F.stack(*data, axis=time_axis)
+ outputs = F.SequenceMask(data, sequence_length=valid_length, use_sequence_length=True,
+ axis=time_axis)
+ if not merge:
+ outputs = _as_list(F.split(outputs, num_outputs=length, axis=time_axis,
+ squeeze_axis=True))
+ return outputs
class RecurrentCell(Block):
"""Abstract base class for RNN cells
@@ -163,7 +172,8 @@ class RecurrentCell(Block):
states.append(state)
return states
- def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None):
+ def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None,
+ valid_length=None):
"""Unrolls an RNN cell across time steps.
Parameters
@@ -193,6 +203,15 @@ class RecurrentCell(Block):
(batch_size, length, ...) if layout is 'NTC',
or (length, batch_size, ...) if layout is 'TNC'.
If `None`, output whatever is faster.
+ valid_length : Symbol, NDArray or None
+ `valid_length` specifies the length of the sequences in the batch without padding.
+ This option is especially useful for building sequence-to-sequence models where
+ the input and output sequences would potentially be padded.
+ If `valid_length` is None, all sequences are assumed to have the same length.
+ If `valid_length` is a Symbol or NDArray, it should have shape (batch_size,).
+ The ith element will be the length of the ith sequence in the batch.
+ The last valid state will be returned and the padded outputs will be masked with 0.
+ Note that `valid_length` must be smaller than or equal to `length`.
Returns
-------
@@ -207,15 +226,24 @@ class RecurrentCell(Block):
"""
self.reset()
- inputs, _, F, batch_size = _format_sequence(length, inputs, layout, False)
+ inputs, axis, F, batch_size = _format_sequence(length, inputs, layout, False)
begin_state = _get_begin_state(self, F, begin_state, inputs, batch_size)
states = begin_state
outputs = []
+ all_states = []
for i in range(length):
output, states = self(inputs[i], states)
outputs.append(output)
-
+ if valid_length is not None:
+ all_states.append(states)
+ if valid_length is not None:
+ states = [F.SequenceLast(F.stack(*ele_list, axis=0),
+ sequence_length=valid_length,
+ use_sequence_length=True,
+ axis=0)
+ for ele_list in zip(*all_states)]
+ outputs = _mask_sequence_variable_length(F, outputs, length, valid_length, axis, True)
outputs, _, _, _ = _format_sequence(length, outputs, layout, merge_outputs)
return outputs, states
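
To see how the `SequenceLast` call above picks each sample's last valid state out of the stacked per-step states, here is a toy sketch (shapes and values are assumptions):

import mxnet as mx

# per-step states stacked on axis 0: shape (time, batch, hidden)
all_states = mx.nd.arange(3 * 2 * 4).reshape((3, 2, 4))
valid_length = mx.nd.array([1, 3])
last = mx.nd.SequenceLast(all_states, sequence_length=valid_length,
                          use_sequence_length=True, axis=0)
# last[0] equals all_states[0, 0] (sample 0 ends after step 1);
# last[1] equals all_states[2, 1] (sample 1 runs all 3 steps)
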
@@ -645,7 +673,8 @@ class SequentialRNNCell(RecurrentCell):
next_states.append(state)
return inputs, sum(next_states, [])
- def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None):
+ def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None,
+ valid_length=None):
self.reset()
inputs, _, F, batch_size = _format_sequence(length, inputs, layout, None)
@@ -658,8 +687,10 @@ class SequentialRNNCell(RecurrentCell):
n = len(cell.state_info())
states = begin_state[p:p+n]
p += n
- inputs, states = cell.unroll(length, inputs=inputs, begin_state=states, layout=layout,
- merge_outputs=None if i < num_cells-1 else merge_outputs)
+ inputs, states = cell.unroll(length, inputs=inputs, begin_state=states,
+ layout=layout,
+ merge_outputs=None if i < num_cells-1 else merge_outputs,
+ valid_length=valid_length)
next_states.extend(states)
return inputs, next_states
@@ -713,7 +744,8 @@ class DropoutCell(HybridRecurrentCell):
inputs = F.Dropout(data=inputs, p=self.rate, name='t%d_fwd'%self._counter)
return inputs, states
- def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None):
+ def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None,
+ valid_length=None):
self.reset()
inputs, _, F, _ = _format_sequence(length, inputs, layout, merge_outputs)
@@ -722,7 +754,7 @@ class DropoutCell(HybridRecurrentCell):
else:
return super(DropoutCell, self).unroll(
length, inputs, begin_state=begin_state, layout=layout,
- merge_outputs=merge_outputs)
+ merge_outputs=merge_outputs, valid_length=valid_length)
class ModifierCell(HybridRecurrentCell):
@@ -827,17 +859,23 @@ class ResidualCell(ModifierCell):
output = F.elemwise_add(output, inputs, name='t%d_fwd'%self._counter)
return output, states
- def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None):
+ def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None,
+ valid_length=None):
self.reset()
self.base_cell._modified = False
outputs, states = self.base_cell.unroll(length, inputs=inputs, begin_state=begin_state,
- layout=layout, merge_outputs=merge_outputs)
+ layout=layout, merge_outputs=merge_outputs,
+ valid_length=valid_length)
self.base_cell._modified = True
merge_outputs = isinstance(outputs, tensor_types) if merge_outputs is None else \
merge_outputs
- inputs, _, F, _ = _format_sequence(length, inputs, layout, merge_outputs)
+ inputs, axis, F, _ = _format_sequence(length, inputs, layout, merge_outputs)
+ if valid_length is not None:
+ # mask the padded inputs to zero
+ inputs = _mask_sequence_variable_length(F, inputs, length, valid_length, axis,
+ merge_outputs)
if merge_outputs:
outputs = F.elemwise_add(outputs, inputs)
else:
@@ -880,34 +918,57 @@ class BidirectionalCell(HybridRecurrentCell):
"cell cannot be called directly. Call the modifier cell instead."
return _cells_begin_state(self._children, **kwargs)
- def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None):
+ def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None,
+ valid_length=None):
self.reset()
inputs, axis, F, batch_size = _format_sequence(length, inputs, layout, False)
+ if valid_length is None:
+ reversed_inputs = list(reversed(inputs))
+ else:
+ reversed_inputs = F.SequenceReverse(F.stack(*inputs, axis=0),
+ sequence_length=valid_length,
+ use_sequence_length=True)
+ reversed_inputs = _as_list(F.split(reversed_inputs, axis=0, num_outputs=length,
+ squeeze_axis=True))
begin_state = _get_begin_state(self, F, begin_state, inputs, batch_size)
states = begin_state
l_cell, r_cell = self._children
l_outputs, l_states = l_cell.unroll(length, inputs=inputs,
begin_state=states[:len(l_cell.state_info(batch_size))],
- layout=layout, merge_outputs=merge_outputs)
+ layout=layout, merge_outputs=merge_outputs,
+ valid_length=valid_length)
r_outputs, r_states = r_cell.unroll(length,
- inputs=list(reversed(inputs)),
+ inputs=reversed_inputs,
begin_state=states[len(l_cell.state_info(batch_size)):],
- layout=layout, merge_outputs=merge_outputs)
-
+ layout=layout, merge_outputs=False,
+ valid_length=valid_length)
+ if valid_length is None:
+ reversed_r_outputs = list(reversed(r_outputs))
+ else:
+ reversed_r_outputs = F.SequenceReverse(F.stack(*r_outputs, axis=0),
+ sequence_length=valid_length,
+ use_sequence_length=True,
+ axis=0)
+ reversed_r_outputs = _as_list(F.split(reversed_r_outputs, axis=0, num_outputs=length,
+ squeeze_axis=True))
if merge_outputs is None:
- merge_outputs = (isinstance(l_outputs, tensor_types)
- and isinstance(r_outputs, tensor_types))
+ merge_outputs = isinstance(l_outputs, tensor_types)
l_outputs, _, _, _ = _format_sequence(None, l_outputs, layout, merge_outputs)
- r_outputs, _, _, _ = _format_sequence(None, r_outputs, layout, merge_outputs)
+ reversed_r_outputs, _, _, _ = _format_sequence(None, reversed_r_outputs, layout,
+ merge_outputs)
if merge_outputs:
- r_outputs = F.reverse(r_outputs, axis=axis)
- outputs = F.concat(l_outputs, r_outputs, dim=2, name='%sout'%self._output_prefix)
+ reversed_r_outputs = F.stack(*reversed_r_outputs, axis=axis)
+ outputs = F.concat(l_outputs, reversed_r_outputs, dim=2,
+ name='%sout'%self._output_prefix)
+
else:
outputs = [F.concat(l_o, r_o, dim=1, name='%st%d'%(self._output_prefix, i))
- for i, (l_o, r_o) in enumerate(zip(l_outputs, reversed(r_outputs)))]
-
+ for i, (l_o, r_o) in enumerate(zip(l_outputs, reversed_r_outputs))]
+ if valid_length is not None:
+ outputs = _mask_sequence_variable_length(F, outputs, length, valid_length, axis,
+ merge_outputs)
states = l_states + r_states
return outputs, states
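
A hedged usage sketch of the bidirectional path (cell sizes and shapes are assumptions): with `valid_length`, `SequenceReverse` reverses only each sample's first `valid_length` steps, so padding stays at the tail and the reverse cell never consumes it first.

import mxnet as mx
from mxnet import gluon

bi = gluon.rnn.BidirectionalCell(gluon.rnn.GRUCell(8), gluon.rnn.GRUCell(8))
bi.collect_params().initialize()
data = mx.nd.random.normal(shape=(2, 5, 4))     # (batch, time, feature), layout 'NTC'
valid_length = mx.nd.array([3, 5])
outputs, states = bi.unroll(5, data, layout='NTC', merge_outputs=True,
                            valid_length=valid_length)
# outputs: (2, 5, 16); steps beyond each sample's valid_length are zero
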
diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py
index 2288842..871deeb 100644
--- a/tests/python/unittest/test_gluon_rnn.py
+++ b/tests/python/unittest/test_gluon_rnn.py
@@ -257,6 +257,7 @@ def check_rnn_layer_forward(layer, inputs, states=None):
mx.test_utils.assert_almost_equal(np_out, out.asnumpy(), rtol=1e-3, atol=1e-5)
mx.test_utils.assert_almost_equal(np_dx, inputs.grad.asnumpy(), rtol=1e-3, atol=1e-5)
+
def test_rnn_layers():
check_rnn_layer_forward(gluon.rnn.RNN(10, 2), mx.nd.ones((8, 3, 20)))
check_rnn_layer_forward(gluon.rnn.RNN(10, 2), mx.nd.ones((8, 3, 20)), mx.nd.ones((2, 3, 10)))
@@ -274,6 +275,81 @@ def test_rnn_layers():
with mx.autograd.record():
net(mx.nd.ones((2, 3, 10))).backward()
+
+def test_rnn_unroll_variant_length():
+ # Test for imperative usage
+ cell_list = []
+ for base_cell_class in [gluon.rnn.RNNCell, gluon.rnn.LSTMCell, gluon.rnn.GRUCell]:
+ cell_list.append(base_cell_class(20))
+ cell_list.append(gluon.rnn.BidirectionalCell(
+ l_cell=base_cell_class(20),
+ r_cell=base_cell_class(20)))
+ cell_list.append(gluon.contrib.rnn.VariationalDropoutCell(base_cell=base_cell_class(20)))
+ stack_res_rnn_cell = gluon.rnn.SequentialRNNCell()
+ stack_res_rnn_cell.add(gluon.rnn.ResidualCell(base_cell=gluon.rnn.RNNCell(20)))
+ stack_res_rnn_cell.add(gluon.rnn.ResidualCell(base_cell=gluon.rnn.RNNCell(20)))
+ cell_list.append(stack_res_rnn_cell)
+ batch_size = 4
+ max_length = 10
+ valid_length = [3, 10, 5, 6]
+ valid_length_nd = mx.nd.array(valid_length)
+ for cell in cell_list:
+ cell.collect_params().initialize()
+ cell.hybridize()
+ # Test for NTC layout
+ data_nd = mx.nd.random.normal(0, 1, shape=(batch_size, max_length, 20))
+ outs, states = cell.unroll(length=max_length, inputs=data_nd,
+ valid_length=valid_length_nd,
+ merge_outputs=True,
+ layout='NTC')
+ for i, ele_length in enumerate(valid_length):
+ # Explicitly unroll each sequence and compare the final states and output
+ ele_out, ele_states = cell.unroll(length=ele_length,
+ inputs=data_nd[i:(i+1), :ele_length, :],
+ merge_outputs=True,
+ layout='NTC')
+ assert_allclose(ele_out.asnumpy(), outs[i:(i+1), :ele_length, :].asnumpy(),
+ atol=1E-4, rtol=1E-4)
+ if ele_length < max_length:
+ # Check the padded outputs are all zero
+ assert_allclose(outs[i:(i+1), ele_length:max_length, :].asnumpy(), 0)
+ for valid_out_state, gt_state in zip(states, ele_states):
+ assert_allclose(valid_out_state[i:(i+1)].asnumpy(), gt_state.asnumpy(),
+ atol=1E-4, rtol=1E-4)
+
+ # Test for TNC layout
+ data_nd = mx.nd.random.normal(0, 1, shape=(max_length, batch_size, 20))
+ outs, states = cell.unroll(length=max_length, inputs=data_nd,
+ valid_length=valid_length_nd,
+ layout='TNC')
+ for i, ele_length in enumerate(valid_length):
+ # Explicitly unroll each sequence and compare the final states and output
+ ele_out, ele_states = cell.unroll(length=ele_length,
+ inputs=data_nd[:ele_length, i:(i+1), :],
+ merge_outputs=True,
+ layout='TNC')
+ assert_allclose(ele_out.asnumpy(), outs[:ele_length, i:(i + 1), :].asnumpy(),
+ atol=1E-4, rtol=1E-4)
+ if ele_length < max_length:
+ # Check the padded outputs are all zero
+ assert_allclose(outs[ele_length:max_length, i:(i+1), :].asnumpy(), 0)
+ for valid_out_state, gt_state in zip(states, ele_states):
+ assert_allclose(valid_out_state[i:(i+1)].asnumpy(), gt_state.asnumpy(),
+ atol=1E-4, rtol=1E-4)
+ # For the symbolic test, we need to make sure that it can be bound and run
+ data = mx.sym.var('data', shape=(4, 10, 2))
+ cell = gluon.rnn.RNNCell(100)
+ valid_length = mx.sym.var('valid_length', shape=(4,))
+ outs, states = cell.unroll(length=10, inputs=data, valid_length=valid_length,
+ merge_outputs=True, layout='NTC')
+ mod = mx.mod.Module(states[0], data_names=('data', 'valid_length'), label_names=None,
+ context=mx.cpu())
+ mod.bind(data_shapes=[('data', (4, 10, 2)), ('valid_length', (4,))], label_shapes=None)
+ mod.init_params()
+ mod.forward(mx.io.DataBatch([mx.random.normal(0, 1, (4, 10, 2)), mx.nd.array([3, 6, 10, 2])]))
+ mod.get_outputs()[0].asnumpy()
+
+
def test_cell_fill_shape():
cell = gluon.rnn.LSTMCell(10)
cell.hybridize()
--
To stop receiving notification emails like this one, please contact
zhasheng@apache.org.