Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/03/08 18:38:40 UTC

[incubator-mxnet] branch master updated: [MXNET-31] Support variable sequence length in gluon.RecurrentCell (#9934)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new f04edb5  [MXNET-31] Support variable sequence length in gluon.RecurrentCell  (#9934)
f04edb5 is described below

commit f04edb5045c3134f8d055d17185f35ff6e0f98fa
Author: Xingjian Shi <xs...@ust.hk>
AuthorDate: Thu Mar 8 10:38:37 2018 -0800

    [MXNET-31] Support variable sequence length in gluon.RecurrentCell  (#9934)
    
    * try to enable variable length sequence in RNN
    
    improve the speed of stacking the outputs
    
    mask the output also
    
    fix bug
    
    fix bug
    
    add test
    
    * add the argument to cells in contrib
    
    * Use one line to split the elements to a list
    
    * remove redundant code
    
    * fix bug in VariationalDropoutCell
    
    * add test for VariationalDropoutCell
    
    * fix bug
---
 python/mxnet/gluon/contrib/rnn/rnn_cell.py |  28 ++++++--
 python/mxnet/gluon/rnn/rnn_cell.py         | 111 ++++++++++++++++++++++-------
 tests/python/unittest/test_gluon_rnn.py    |  76 ++++++++++++++++++++
 3 files changed, 184 insertions(+), 31 deletions(-)
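
In short, `unroll` on Gluon recurrent cells now accepts an optional `valid_length` argument: time steps beyond a sequence's valid length are masked to zero in the outputs, and the returned states are taken from each sequence's last valid step. A minimal imperative sketch mirroring the new unit test further down (the cell type, shapes and lengths here are illustrative, not part of the commit):

    import mxnet as mx
    from mxnet import gluon

    # A padded batch of 4 sequences in NTC layout: max length 10, feature size 20.
    data = mx.nd.random.normal(shape=(4, 10, 20))
    valid_length = mx.nd.array([3, 10, 5, 6])   # true length of each sequence

    cell = gluon.rnn.LSTMCell(20)
    cell.collect_params().initialize()
    # Padded steps in `outputs` come back as 0; `states` hold each sequence's last valid step.
    outputs, states = cell.unroll(length=10, inputs=data, valid_length=valid_length,
                                  layout='NTC', merge_outputs=True)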

diff --git a/python/mxnet/gluon/contrib/rnn/rnn_cell.py b/python/mxnet/gluon/contrib/rnn/rnn_cell.py
index d74c107..d6402b7 100644
--- a/python/mxnet/gluon/contrib/rnn/rnn_cell.py
+++ b/python/mxnet/gluon/contrib/rnn/rnn_cell.py
@@ -20,7 +20,8 @@
 __all__ = ['VariationalDropoutCell']
 
 from ...rnn import BidirectionalCell, SequentialRNNCell, ModifierCell
-from ...rnn.rnn_cell import _format_sequence, _get_begin_state
+from ...rnn.rnn_cell import _format_sequence, _get_begin_state, _mask_sequence_variable_length
+from ... import tensor_types
 
 
 class VariationalDropoutCell(ModifierCell):
@@ -113,7 +114,8 @@ class VariationalDropoutCell(ModifierCell):
         return s.format(name=self.__class__.__name__,
                         **self.__dict__)
 
-    def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None):
+    def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None,
+               valid_length=None):
         """Unrolls an RNN cell across time steps.
 
         Parameters
@@ -143,6 +145,15 @@ class VariationalDropoutCell(ModifierCell):
             (batch_size, length, ...) if layout is 'NTC',
             or (length, batch_size, ...) if layout is 'TNC'.
             If `None`, output whatever is faster.
+        valid_length : Symbol, NDArray or None
+            `valid_length` specifies the length of the sequences in the batch without padding.
+            This option is especially useful for building sequence-to-sequence models where
+            the input and output sequences may be padded.
+            If `valid_length` is None, all sequences are assumed to have the same length.
+            If `valid_length` is a Symbol or NDArray, it should have shape (batch_size,).
+            The ith element will be the length of the ith sequence in the batch.
+            The last valid state will be returned and the padded outputs will be masked with 0.
+            Note that `valid_length` must be smaller than or equal to `length`.
 
         Returns
         -------
@@ -160,7 +171,8 @@ class VariationalDropoutCell(ModifierCell):
         # only when state dropout is not present.
         if self.drop_states:
             return super(VariationalDropoutCell, self).unroll(length, inputs, begin_state,
-                                                              layout, merge_outputs)
+                                                              layout, merge_outputs,
+                                                              valid_length=valid_length)
 
         self.reset()
 
@@ -172,12 +184,16 @@ class VariationalDropoutCell(ModifierCell):
             self._initialize_input_masks(F, first_input, states)
             inputs = F.broadcast_mul(inputs, self.drop_inputs_mask.expand_dims(axis=axis))
 
-        outputs, states = self.base_cell.unroll(length, inputs, states, layout, merge_outputs=True)
+        outputs, states = self.base_cell.unroll(length, inputs, states, layout, merge_outputs=True,
+                                                valid_length=valid_length)
         if self.drop_outputs:
             first_output = outputs.slice_axis(axis, 0, 1).split(1, axis=axis, squeeze_axis=True)
             self._initialize_output_mask(F, first_output)
             outputs = F.broadcast_mul(outputs, self.drop_outputs_mask.expand_dims(axis=axis))
-
+        merge_outputs = isinstance(outputs, tensor_types) if merge_outputs is None else \
+            merge_outputs
         outputs, _, _, _ = _format_sequence(length, outputs, layout, merge_outputs)
-
+        if valid_length is not None:
+            outputs = _mask_sequence_variable_length(F, outputs, length, valid_length, axis,
+                                                     merge_outputs)
         return outputs, states
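
The contrib `VariationalDropoutCell.unroll` above now forwards `valid_length` to the base cell and masks its (possibly dropout-scaled) outputs through the new `_mask_sequence_variable_length` helper. A hedged usage sketch along the lines of the contrib test added below; the cell size, shapes and lengths are illustrative only:

    import mxnet as mx
    from mxnet import gluon

    # Wrap a GRU cell; with drop_states left at its default of 0, unroll takes the path
    # that unrolls the base cell in one shot and masks the padded outputs with SequenceMask.
    cell = gluon.contrib.rnn.VariationalDropoutCell(base_cell=gluon.rnn.GRUCell(20))
    cell.collect_params().initialize()

    data = mx.nd.random.normal(shape=(4, 10, 20))   # (batch, time, channel)
    valid_length = mx.nd.array([3, 10, 5, 6])       # per-sequence valid lengths
    outputs, states = cell.unroll(length=10, inputs=data, valid_length=valid_length,
                                  layout='NTC', merge_outputs=True)
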
diff --git a/python/mxnet/gluon/rnn/rnn_cell.py b/python/mxnet/gluon/rnn/rnn_cell.py
index ea0e32f..61bf24e 100644
--- a/python/mxnet/gluon/rnn/rnn_cell.py
+++ b/python/mxnet/gluon/rnn/rnn_cell.py
@@ -83,8 +83,7 @@ def _format_sequence(length, inputs, layout, merge, in_layout=None):
             F = ndarray
             batch_size = inputs[0].shape[batch_axis]
         if merge is True:
-            inputs = [F.expand_dims(i, axis=axis) for i in inputs]
-            inputs = F.concat(*inputs, dim=axis)
+            inputs = F.stack(*inputs, axis=axis)
             in_axis = axis
 
     if isinstance(inputs, tensor_types) and axis != in_axis:
@@ -92,6 +91,16 @@ def _format_sequence(length, inputs, layout, merge, in_layout=None):
 
     return inputs, axis, F, batch_size
 
+def _mask_sequence_variable_length(F, data, length, valid_length, time_axis, merge):
+    assert valid_length is not None
+    if not isinstance(data, tensor_types):
+        data = F.stack(*data, axis=time_axis)
+    outputs = F.SequenceMask(data, sequence_length=valid_length, use_sequence_length=True,
+                             axis=time_axis)
+    if not merge:
+        outputs = _as_list(F.split(outputs, num_outputs=length, axis=time_axis,
+                                   squeeze_axis=True))
+    return outputs
 
 class RecurrentCell(Block):
     """Abstract base class for RNN cells
@@ -163,7 +172,8 @@ class RecurrentCell(Block):
             states.append(state)
         return states
 
-    def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None):
+    def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None,
+               valid_length=None):
         """Unrolls an RNN cell across time steps.
 
         Parameters
@@ -193,6 +203,15 @@ class RecurrentCell(Block):
             (batch_size, length, ...) if layout is 'NTC',
             or (length, batch_size, ...) if layout is 'TNC'.
             If `None`, output whatever is faster.
+        valid_length : Symbol, NDArray or None
+            `valid_length` specifies the length of the sequences in the batch without padding.
+            This option is especially useful for building sequence-to-sequence models where
+            the input and output sequences may be padded.
+            If `valid_length` is None, all sequences are assumed to have the same length.
+            If `valid_length` is a Symbol or NDArray, it should have shape (batch_size,).
+            The ith element will be the length of the ith sequence in the batch.
+            The last valid state will be returned and the padded outputs will be masked with 0.
+            Note that `valid_length` must be smaller than or equal to `length`.
 
         Returns
         -------
@@ -207,15 +226,24 @@ class RecurrentCell(Block):
         """
         self.reset()
 
-        inputs, _, F, batch_size = _format_sequence(length, inputs, layout, False)
+        inputs, axis, F, batch_size = _format_sequence(length, inputs, layout, False)
         begin_state = _get_begin_state(self, F, begin_state, inputs, batch_size)
 
         states = begin_state
         outputs = []
+        all_states = []
         for i in range(length):
             output, states = self(inputs[i], states)
             outputs.append(output)
-
+            if valid_length is not None:
+                all_states.append(states)
+        if valid_length is not None:
+            states = [F.SequenceLast(F.stack(*ele_list, axis=0),
+                                     sequence_length=valid_length,
+                                     use_sequence_length=True,
+                                     axis=0)
+                      for ele_list in zip(*all_states)]
+            outputs = _mask_sequence_variable_length(F, outputs, length, valid_length, axis, True)
         outputs, _, _, _ = _format_sequence(length, outputs, layout, merge_outputs)
 
         return outputs, states
@@ -645,7 +673,8 @@ class SequentialRNNCell(RecurrentCell):
             next_states.append(state)
         return inputs, sum(next_states, [])
 
-    def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None):
+    def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None,
+               valid_length=None):
         self.reset()
 
         inputs, _, F, batch_size = _format_sequence(length, inputs, layout, None)
@@ -658,8 +687,10 @@ class SequentialRNNCell(RecurrentCell):
             n = len(cell.state_info())
             states = begin_state[p:p+n]
             p += n
-            inputs, states = cell.unroll(length, inputs=inputs, begin_state=states, layout=layout,
-                                         merge_outputs=None if i < num_cells-1 else merge_outputs)
+            inputs, states = cell.unroll(length, inputs=inputs, begin_state=states,
+                                         layout=layout,
+                                         merge_outputs=None if i < num_cells-1 else merge_outputs,
+                                         valid_length=valid_length)
             next_states.extend(states)
 
         return inputs, next_states
@@ -713,7 +744,8 @@ class DropoutCell(HybridRecurrentCell):
             inputs = F.Dropout(data=inputs, p=self.rate, name='t%d_fwd'%self._counter)
         return inputs, states
 
-    def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None):
+    def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None,
+               valid_length=None):
         self.reset()
 
         inputs, _, F, _ = _format_sequence(length, inputs, layout, merge_outputs)
@@ -722,7 +754,7 @@ class DropoutCell(HybridRecurrentCell):
         else:
             return super(DropoutCell, self).unroll(
                 length, inputs, begin_state=begin_state, layout=layout,
-                merge_outputs=merge_outputs)
+                merge_outputs=merge_outputs, valid_length=None)
 
 
 class ModifierCell(HybridRecurrentCell):
@@ -827,17 +859,23 @@ class ResidualCell(ModifierCell):
         output = F.elemwise_add(output, inputs, name='t%d_fwd'%self._counter)
         return output, states
 
-    def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None):
+    def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None,
+               valid_length=None):
         self.reset()
 
         self.base_cell._modified = False
         outputs, states = self.base_cell.unroll(length, inputs=inputs, begin_state=begin_state,
-                                                layout=layout, merge_outputs=merge_outputs)
+                                                layout=layout, merge_outputs=merge_outputs,
+                                                valid_length=valid_length)
         self.base_cell._modified = True
 
         merge_outputs = isinstance(outputs, tensor_types) if merge_outputs is None else \
                         merge_outputs
-        inputs, _, F, _ = _format_sequence(length, inputs, layout, merge_outputs)
+        inputs, axis, F, _ = _format_sequence(length, inputs, layout, merge_outputs)
+        if valid_length is not None:
+            # mask the padded inputs to zero
+            inputs = _mask_sequence_variable_length(F, inputs, length, valid_length, axis,
+                                                    merge_outputs)
         if merge_outputs:
             outputs = F.elemwise_add(outputs, inputs)
         else:
@@ -880,34 +918,57 @@ class BidirectionalCell(HybridRecurrentCell):
             "cell cannot be called directly. Call the modifier cell instead."
         return _cells_begin_state(self._children, **kwargs)
 
-    def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None):
+    def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None,
+               valid_length=None):
         self.reset()
 
         inputs, axis, F, batch_size = _format_sequence(length, inputs, layout, False)
+        if valid_length is None:
+            reversed_inputs = list(reversed(inputs))
+        else:
+            reversed_inputs = F.SequenceReverse(F.stack(*inputs, axis=0),
+                                                sequence_length=valid_length,
+                                                use_sequence_length=True)
+            reversed_inputs = _as_list(F.split(reversed_inputs, axis=0, num_outputs=length,
+                                               squeeze_axis=True))
         begin_state = _get_begin_state(self, F, begin_state, inputs, batch_size)
 
         states = begin_state
         l_cell, r_cell = self._children
         l_outputs, l_states = l_cell.unroll(length, inputs=inputs,
                                             begin_state=states[:len(l_cell.state_info(batch_size))],
-                                            layout=layout, merge_outputs=merge_outputs)
+                                            layout=layout, merge_outputs=merge_outputs,
+                                            valid_length=valid_length)
         r_outputs, r_states = r_cell.unroll(length,
-                                            inputs=list(reversed(inputs)),
+                                            inputs=reversed_inputs,
                                             begin_state=states[len(l_cell.state_info(batch_size)):],
-                                            layout=layout, merge_outputs=merge_outputs)
-
+                                            layout=layout, merge_outputs=False,
+                                            valid_length=valid_length)
+        if valid_length is None:
+            reversed_r_outputs = list(reversed(r_outputs))
+        else:
+            reversed_r_outputs = F.SequenceReverse(F.stack(*r_outputs, axis=0),
+                                                   sequence_length=valid_length,
+                                                   use_sequence_length=True,
+                                                   axis=0)
+            reversed_r_outputs = _as_list(F.split(reversed_r_outputs, axis=0, num_outputs=length,
+                                                  squeeze_axis=True))
         if merge_outputs is None:
-            merge_outputs = (isinstance(l_outputs, tensor_types)
-                             and isinstance(r_outputs, tensor_types))
+            merge_outputs = isinstance(l_outputs, tensor_types)
             l_outputs, _, _, _ = _format_sequence(None, l_outputs, layout, merge_outputs)
-            r_outputs, _, _, _ = _format_sequence(None, r_outputs, layout, merge_outputs)
+            reversed_r_outputs, _, _, _ = _format_sequence(None, reversed_r_outputs, layout,
+                                                           merge_outputs)
 
         if merge_outputs:
-            r_outputs = F.reverse(r_outputs, axis=axis)
-            outputs = F.concat(l_outputs, r_outputs, dim=2, name='%sout'%self._output_prefix)
+            reversed_r_outputs = F.stack(*reversed_r_outputs, axis=axis)
+            outputs = F.concat(l_outputs, reversed_r_outputs, dim=2,
+                               name='%sout'%self._output_prefix)
+
         else:
             outputs = [F.concat(l_o, r_o, dim=1, name='%st%d'%(self._output_prefix, i))
-                       for i, (l_o, r_o) in enumerate(zip(l_outputs, reversed(r_outputs)))]
-
+                       for i, (l_o, r_o) in enumerate(zip(l_outputs, reversed_r_outputs))]
+        if valid_length is not None:
+            outputs = _mask_sequence_variable_length(F, outputs, length, valid_length, axis,
+                                                     merge_outputs)
         states = l_states + r_states
         return outputs, states
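
The machinery behind the new argument rests on three sequence operators: `_mask_sequence_variable_length` zeroes padded steps with `SequenceMask`, `RecurrentCell.unroll` collects the per-step states and selects each sequence's last valid one with `SequenceLast`, and `BidirectionalCell.unroll` reverses the inputs and the right-to-left outputs with `SequenceReverse` so that padding never leaks into the reverse pass. A small standalone sketch of those operators on toy NDArrays (shapes and values are illustrative, not taken from the commit):

    import mxnet as mx

    # Time-major toy data: (time=4, batch=2, channel=3); the second sequence is only 2 steps long.
    data = mx.nd.arange(24).reshape((4, 2, 3))
    valid_length = mx.nd.array([4, 2])

    # Zero out every step beyond a sequence's valid length (what the masking helper does).
    masked = mx.nd.SequenceMask(data, sequence_length=valid_length,
                                use_sequence_length=True, axis=0)

    # Select the output/state at the last valid step of each sequence (used for the returned states).
    last = mx.nd.SequenceLast(data, sequence_length=valid_length,
                              use_sequence_length=True, axis=0)

    # Reverse only the valid prefix of each sequence (used to feed the reverse-direction cell).
    rev = mx.nd.SequenceReverse(data, sequence_length=valid_length,
                                use_sequence_length=True)
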
diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py
index 2288842..871deeb 100644
--- a/tests/python/unittest/test_gluon_rnn.py
+++ b/tests/python/unittest/test_gluon_rnn.py
@@ -257,6 +257,7 @@ def check_rnn_layer_forward(layer, inputs, states=None):
     mx.test_utils.assert_almost_equal(np_out, out.asnumpy(), rtol=1e-3, atol=1e-5)
     mx.test_utils.assert_almost_equal(np_dx, inputs.grad.asnumpy(), rtol=1e-3, atol=1e-5)
 
+
 def test_rnn_layers():
     check_rnn_layer_forward(gluon.rnn.RNN(10, 2), mx.nd.ones((8, 3, 20)))
     check_rnn_layer_forward(gluon.rnn.RNN(10, 2), mx.nd.ones((8, 3, 20)), mx.nd.ones((2, 3, 10)))
@@ -274,6 +275,81 @@ def test_rnn_layers():
     with mx.autograd.record():
         net(mx.nd.ones((2, 3, 10))).backward()
 
+
+def test_rnn_unroll_variant_length():
+    # Test for imperative usage
+    cell_list = []
+    for base_cell_class in [gluon.rnn.RNNCell, gluon.rnn.LSTMCell, gluon.rnn.GRUCell]:
+        cell_list.append(base_cell_class(20))
+        cell_list.append(gluon.rnn.BidirectionalCell(
+                         l_cell=base_cell_class(20),
+                         r_cell=base_cell_class(20)))
+        cell_list.append(gluon.contrib.rnn.VariationalDropoutCell(base_cell=base_cell_class(20)))
+    stack_res_rnn_cell = gluon.rnn.SequentialRNNCell()
+    stack_res_rnn_cell.add(gluon.rnn.ResidualCell(base_cell=gluon.rnn.RNNCell(20)))
+    stack_res_rnn_cell.add(gluon.rnn.ResidualCell(base_cell=gluon.rnn.RNNCell(20)))
+    cell_list.append(stack_res_rnn_cell)
+    batch_size = 4
+    max_length = 10
+    valid_length = [3, 10, 5, 6]
+    valid_length_nd = mx.nd.array(valid_length)
+    for cell in cell_list:
+        cell.collect_params().initialize()
+        cell.hybridize()
+        # Test for NTC layout
+        data_nd = mx.nd.random.normal(0, 1, shape=(batch_size, max_length, 20))
+        outs, states = cell.unroll(length=max_length, inputs=data_nd,
+                                   valid_length=valid_length_nd,
+                                   merge_outputs=True,
+                                   layout='NTC')
+        for i, ele_length in enumerate(valid_length):
+            # Explicitly unroll each sequence and compare the final states and output
+            ele_out, ele_states = cell.unroll(length=ele_length,
+                                              inputs=data_nd[i:(i+1), :ele_length, :],
+                                              merge_outputs=True,
+                                              layout='NTC')
+            assert_allclose(ele_out.asnumpy(), outs[i:(i+1), :ele_length, :].asnumpy(),
+                            atol=1E-4, rtol=1E-4)
+            if ele_length < max_length:
+                # Check the padded outputs are all zero
+                assert_allclose(outs[i:(i+1), ele_length:max_length, :].asnumpy(), 0)
+            for valid_out_state, gt_state in zip(states, ele_states):
+                assert_allclose(valid_out_state[i:(i+1)].asnumpy(), gt_state.asnumpy(),
+                                atol=1E-4, rtol=1E-4)
+
+        # Test for TNC layout
+        data_nd = mx.nd.random.normal(0, 1, shape=(max_length, batch_size, 20))
+        outs, states = cell.unroll(length=max_length, inputs=data_nd,
+                                   valid_length=valid_length_nd,
+                                   layout='TNC')
+        for i, ele_length in enumerate(valid_length):
+            # Explicitly unroll each sequence and compare the final states and output
+            ele_out, ele_states = cell.unroll(length=ele_length,
+                                              inputs=data_nd[:ele_length, i:(i+1), :],
+                                              merge_outputs=True,
+                                              layout='TNC')
+            assert_allclose(ele_out.asnumpy(), outs[:ele_length, i:(i + 1), :].asnumpy(),
+                            atol=1E-4, rtol=1E-4)
+            if ele_length < max_length:
+                # Check the padded outputs are all zero
+                assert_allclose(outs[ele_length:max_length, i:(i+1), :].asnumpy(), 0)
+            for valid_out_state, gt_state in zip(states, ele_states):
+                assert_allclose(valid_out_state[i:(i+1)].asnumpy(), gt_state.asnumpy(),
+                                atol=1E-4, rtol=1E-4)
+    # For the symbolic test, we need to make sure that the graph can be bound and run
+    data = mx.sym.var('data', shape=(4, 10, 2))
+    cell = gluon.rnn.RNNCell(100)
+    valid_length = mx.sym.var('valid_length', shape=(4,))
+    outs, states = cell.unroll(length=10, inputs=data, valid_length=valid_length,
+                               merge_outputs=True, layout='NTC')
+    mod = mx.mod.Module(states[0], data_names=('data', 'valid_length'), label_names=None,
+                        context=mx.cpu())
+    mod.bind(data_shapes=[('data', (4, 10, 2)), ('valid_length', (4,))], label_shapes=None)
+    mod.init_params()
+    mod.forward(mx.io.DataBatch([mx.random.normal(0, 1, (4, 10, 2)), mx.nd.array([3, 6, 10, 2])]))
+    mod.get_outputs()[0].asnumpy()
+
+
 def test_cell_fill_shape():
     cell = gluon.rnn.LSTMCell(10)
     cell.hybridize()

-- 
To stop receiving notification emails like this one, please contact
zhasheng@apache.org.