You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2019/02/16 02:23:32 UTC

[incubator-mxnet] branch master updated: [MXNET-766] add dynamic_unroll RNN for HybridBlock (#11948)

This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 4b527b6  [MXNET-766] add dynamic_unroll RNN for HybridBlock (#11948)
4b527b6 is described below

commit 4b527b681f01bff02a770c1b49a3b5c9278cf4a0
Author: Da Zheng <zh...@gmail.com>
AuthorDate: Fri Feb 15 18:23:01 2019 -0800

    [MXNET-766] add dynamic_unroll RNN for HybridBlock (#11948)
    
    * add contrib unroll.
    
    * reenable some tests.
    
    * fix a bug.
    
    * fix lint.
    
    * fix a bug.
    
    * support different layouts.
    
    * update doc.
    
    * use a different default layout.
    
    * remove _contrib_format_sequence.
    
    * fix lint.
    
    * rename.
---
 python/mxnet/gluon/contrib/rnn/rnn_cell.py  | 115 ++++++++++++++++++++++++++++
 tests/python/unittest/test_gluon_contrib.py |  95 ++++++++++++++++++++++-
 2 files changed, 209 insertions(+), 1 deletion(-)

diff --git a/python/mxnet/gluon/contrib/rnn/rnn_cell.py b/python/mxnet/gluon/contrib/rnn/rnn_cell.py
index 0cbc9ea..3bd8e78 100644
--- a/python/mxnet/gluon/contrib/rnn/rnn_cell.py
+++ b/python/mxnet/gluon/contrib/rnn/rnn_cell.py
@@ -22,6 +22,7 @@ __all__ = ['VariationalDropoutCell', 'LSTMPCell']
 from ...rnn import BidirectionalCell, SequentialRNNCell, ModifierCell, HybridRecurrentCell
 from ...rnn.rnn_cell import _format_sequence, _get_begin_state, _mask_sequence_variable_length
 from ... import tensor_types
+from ....base import _as_list
 
 class VariationalDropoutCell(ModifierCell):
     """
@@ -320,3 +321,117 @@ class LSTMPCell(HybridRecurrentCell):
 
         return next_r, [next_r, next_c]
     # pylint: enable= arguments-differ
+
+
+def dynamic_unroll(cell, inputs, begin_state, drop_inputs=0, drop_outputs=0,
+                   layout='TNC', valid_length=None):
+    """Unrolls an RNN cell across time steps.
+
+    Currently, 'TNC' is the preferred layout: unrolling an input in this layout
+    runs much faster because it needs no transposes.
+
+    Parameters
+    ----------
+    cell : an object whose base class is RNNCell.
+        The RNN cell to run on the input sequence.
+    inputs : Symbol or NDArray
+        It should have shape (batch_size, length, ...) if `layout` is 'NTC',
+        or (length, batch_size, ...) if `layout` is 'TNC'.
+    begin_state : nested list of Symbol
+        The initial states of the RNN sequence.
+    drop_inputs : float, default 0.
+        The dropout rate for inputs; one mask is shared by all time steps. 0 disables it.
+    drop_outputs : float, default 0.
+        The dropout rate for outputs; one mask is shared by all time steps. 0 disables it.
+    layout : str, default 'TNC'
+        Layout of `inputs`, e.g. 'TNC' or 'NTC'. Only used if `inputs`
+        is a single Symbol.
+    valid_length : Symbol, NDArray or None
+        `valid_length` specifies the length of the sequences in the batch without padding.
+        This option is especially useful for building sequence-to-sequence models where
+        the input and output sequences would potentially be padded.
+        If `valid_length` is None, all sequences are assumed to have the same length.
+        If `valid_length` is a Symbol or NDArray, it should have shape (batch_size,).
+        The ith element will be the length of the ith sequence in the batch.
+        The last valid state is returned and the padded outputs are masked with 0.
+        Note that `valid_length` must be less than or equal to `length`.
+
+    Returns
+    -------
+    outputs : Symbol or NDArray
+        The output of the RNN from this unrolling, in the same layout as `inputs`.
+
+    states : list of Symbol
+        The new state of this RNN after this unrolling.
+        The type of these states is the same as that of `begin_state`.
+
+    Examples
+    --------
+    >>> seq_len = 3
+    >>> batch_size = 2
+    >>> input_size = 5
+    >>> cell = mx.gluon.rnn.LSTMCell(input_size, prefix='rnn_')
+    >>> cell.initialize(ctx=mx.cpu())
+    >>> rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len, batch_size, input_size))
+    >>> state_shape = (batch_size, input_size)
+    >>> states = [mx.nd.normal(loc=0, scale=1, shape=state_shape) for i in range(2)]
+    >>> valid_length = mx.nd.array([2, 3])
+    >>> output, states = mx.gluon.contrib.rnn.rnn_cell.dynamic_unroll(
+    ...     cell, rnn_data, states, valid_length=valid_length,
+    ...     layout='TNC')
+    >>> print(output)
+    [[[ 0.00767238  0.00023103  0.03973929 -0.00925503 -0.05660512]
+      [ 0.00881535  0.05428379 -0.02493718 -0.01834097  0.02189514]]
+     [[-0.00676967  0.01447039  0.01287002 -0.00574152 -0.05734247]
+      [ 0.01568508  0.02650866 -0.04270559 -0.04328435  0.00904011]]
+     [[ 0.          0.          0.          0.          0.        ]
+      [ 0.01055336  0.02734251 -0.03153727 -0.03742751 -0.01378113]]]
+     <NDArray 3x2x5 @cpu(0)>
+    """
+
+    # Merge is always True, so we don't need length.
+    inputs, axis, F, _ = _format_sequence(0, inputs, layout, True)
+    if axis != 0:
+        # Swap the time axis into position 0 so that foreach iterates over time.
+        # The swap is its own inverse, so `axes` also restores the layout below.
+        axes = list(range(len(layout)))
+        axes[0], axes[axis] = axes[axis], axes[0]
+        inputs = F.transpose(inputs, axes=axes)
+    states = begin_state
+    # From here on `inputs` is time-major, so the time axis is always 0.
+    if drop_inputs:
+        inputs = F.Dropout(inputs, p=drop_inputs, axes=(0,))
+
+    if valid_length is None:
+        def loop_body(inputs, states):
+            return cell(inputs, states)
+    else:
+        # Append an iteration counter to the loop states. A cell state is only
+        # updated while the counter is below that sequence's valid length, so
+        # the last valid state is carried through the padded steps.
+        states = list(_as_list(states))
+        states.append(F.zeros((1,)))
+        def loop_body(inputs, states):
+            cell_states = states[:-1]
+            iter_no = states[-1]
+            out, new_states = cell(inputs, cell_states)
+            for i, state in enumerate(cell_states):
+                new_states[i] = F.where(F.broadcast_greater(valid_length, iter_no),
+                                        new_states[i], state)
+            new_states.append(iter_no + 1)
+            return out, new_states
+
+    outputs, states = F.contrib.foreach(loop_body, inputs, states)
+    if drop_outputs:
+        outputs = F.Dropout(outputs, p=drop_outputs, axes=(0,))
+    if valid_length is not None:
+        if axis != 0:
+            outputs = F.transpose(outputs, axes)
+        outputs = F.SequenceMask(outputs, sequence_length=valid_length,
+                                 use_sequence_length=True, axis=axis)
+        # The last state is the iteration counter; callers don't need it.
+        return outputs, states[:-1]
+    else:
+        if axis != 0:
+            outputs = F.transpose(outputs, axes)
+        return outputs, states
diff --git a/tests/python/unittest/test_gluon_contrib.py b/tests/python/unittest/test_gluon_contrib.py
index 6901e8b..1e05559 100644
--- a/tests/python/unittest/test_gluon_contrib.py
+++ b/tests/python/unittest/test_gluon_contrib.py
@@ -17,12 +17,14 @@
 
 from __future__ import print_function
 import mxnet as mx
+import copy
+from mxnet import gluon
 from mxnet.gluon import contrib
 from mxnet.gluon import nn
 from mxnet.gluon.contrib.nn import (
     Concurrent, HybridConcurrent, Identity, SparseEmbedding, PixelShuffle1D,
     PixelShuffle2D, PixelShuffle3D)
-from mxnet.test_utils import almost_equal
+from mxnet.test_utils import almost_equal, default_context, assert_almost_equal
 from common import setup_module, with_seed, teardown
 import numpy as np
 from numpy.testing import assert_allclose
@@ -313,6 +315,97 @@ def test_sampler():
     assert list(interval_sampler) == [0, 3, 6, 9]
 
 
+class TestRNNLayer(gluon.HybridBlock):
+    """HybridBlock that unrolls a single RNN cell with `dynamic_unroll`."""
+    def __init__(self, cell_type, hidden_size, layout, prefix=None, params=None):
+        super(TestRNNLayer, self).__init__(prefix=prefix, params=params)
+        self.cell = cell_type(hidden_size, prefix='rnn_')
+        self.layout = layout
+
+    def hybrid_forward(self, F, inputs, states, valid_length):
+        # An empty list is the placeholder for "no valid_length".
+        length = None if isinstance(valid_length, list) and not valid_length else valid_length
+        return contrib.rnn.rnn_cell.dynamic_unroll(self.cell, inputs, states,
+                                                   valid_length=length, layout=self.layout)
+
+def check_unroll(cell_type, num_states, layout):
+    """Compare dynamic_unroll with cell.unroll: outputs, states and one SGD step."""
+    batch_size = 20
+    input_size = 50
+    hidden_size = 30
+    seq_len = 10
+    if layout == 'TNC':
+        rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len, batch_size, input_size))
+    elif layout == 'NTC':
+        rnn_data = mx.nd.normal(loc=0, scale=1, shape=(batch_size, seq_len, input_size))
+    else:
+        raise ValueError("Wrong layout: %s" % layout)
+    valid_length = mx.nd.round(mx.nd.random.uniform(low=1, high=10, shape=(batch_size,)))
+    state_shape = (batch_size, hidden_size)
+    states = [mx.nd.normal(loc=0, scale=1, shape=state_shape) for _ in range(num_states)]
+
+    cell = cell_type(hidden_size, prefix='rnn_')
+    cell.initialize(ctx=default_context())
+    if layout == 'TNC':
+        cell(rnn_data[0], states)
+    else:
+        cell(rnn_data[:, 0, :], states)
+    params1 = cell.collect_params()
+    orig_params1 = copy.deepcopy(params1)
+    # Reference result: cell.unroll followed by one SGD step.
+    trainer = gluon.Trainer(params1, 'sgd', {'learning_rate' : 0.03})
+    with mx.autograd.record():
+        res1, states1 = cell.unroll(seq_len, rnn_data, states, valid_length=valid_length,
+                                    layout=layout, merge_outputs=True)
+    res1.backward()
+    trainer.step(batch_size)
+    # hybridize() takes its flags as keyword arguments, not as a positional dict.
+    configs = [
+        lambda layer: None,
+        lambda layer: layer.hybridize(),
+        lambda layer: layer.hybridize(inline_limit=0),
+        lambda layer: layer.hybridize(static_alloc=True),
+        lambda layer: layer.hybridize(static_alloc=True, static_shape=True)]
+    # We can't pass None to a HybridBlock, but it accepts an empty list,
+    # so an empty list represents a missing valid_length.
+    if valid_length is None:
+        valid_length = []
+    for config in configs:
+        layer = TestRNNLayer(cell_type, hidden_size, layout)
+        layer.initialize(ctx=default_context())
+        config(layer)
+        layer(rnn_data, states, valid_length)  # warm-up call before copying params
+        params2 = layer.collect_params()
+        for key, val in orig_params1.items():
+            params2[key].set_data(copy.deepcopy(val.data()))
+
+        trainer = gluon.Trainer(params2, 'sgd', {'learning_rate' : 0.03})
+        with mx.autograd.record():
+            res2, states2 = layer(rnn_data, states, valid_length)
+        assert_almost_equal(res1.asnumpy(), res2.asnumpy(), rtol=0.001, atol=0.0001)
+        assert len(states1) == len(states2)
+        for state1, state2 in zip(states1, states2):
+            assert_almost_equal(state1.asnumpy(), state2.asnumpy(),
+                                rtol=0.001, atol=0.0001)
+        res2.backward()
+        trainer.step(batch_size)
+
+        for key, val in params1.items():
+            weight1 = val.data()
+            weight2 = params2[key].data()
+            assert_almost_equal(weight1.asnumpy(), weight2.asnumpy(),
+                                rtol=0.001, atol=0.0001)
+
+
+@with_seed()
+def test_contrib_unroll():
+    """Run check_unroll for every (cell, layout) pair; LSTM carries two states."""
+    cell_types = [(gluon.rnn.RNNCell, 1), (gluon.rnn.LSTMCell, 2), (gluon.rnn.GRUCell, 1)]
+    for cell_type, num_states in cell_types:
+        for layout in ('TNC', 'NTC'):
+            check_unroll(cell_type, num_states, layout)
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()