You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/09/29 21:33:26 UTC

[GitHub] zheng-da closed pull request #11948: [MXNET-766] add unroll RNN for HybridBlock

zheng-da closed pull request #11948: [MXNET-766] add unroll RNN for HybridBlock
URL: https://github.com/apache/incubator-mxnet/pull/11948
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/python/mxnet/gluon/contrib/rnn/rnn_cell.py b/python/mxnet/gluon/contrib/rnn/rnn_cell.py
index 1b9afee14bf..d04ec0baa9a 100644
--- a/python/mxnet/gluon/contrib/rnn/rnn_cell.py
+++ b/python/mxnet/gluon/contrib/rnn/rnn_cell.py
@@ -22,6 +22,7 @@
 from ...rnn import BidirectionalCell, SequentialRNNCell, ModifierCell, HybridRecurrentCell
 from ...rnn.rnn_cell import _format_sequence, _get_begin_state, _mask_sequence_variable_length
 from ... import tensor_types
+from ....base import _as_list
 
 class VariationalDropoutCell(ModifierCell):
     """
@@ -315,3 +316,117 @@ def hybrid_forward(self, F, inputs, states, i2h_weight,
 
         return next_r, [next_r, next_c]
     # pylint: enable= arguments-differ
+
+
+def unroll(cell, inputs, begin_state, drop_inputs=0, drop_outputs=0,
+           layout='TNC', valid_length=None):
+    """Unrolls an RNN cell across time steps.
+
+    Currently, 'TNC' is the preferred layout: inputs in any other layout are
+    transposed to 'TNC' before the loop and the outputs are transposed back.
+
+    Parameters
+    ----------
+    cell : an object whose base class is RNNCell.
+        The RNN cell to run on the input sequence.
+    inputs : Symbol or NDArray
+        It should have shape (batch_size, length, ...) if `layout` is 'NTC',
+        or (length, batch_size, ...) if `layout` is 'TNC'.
+    begin_state : nested list of Symbol
+        The initial states of the RNN sequence.
+    drop_inputs : float, default 0.
+        Dropout rate for inputs, one mask shared by all time steps. 0 disables it.
+    drop_outputs : float, default 0.
+        Dropout rate for outputs, one mask shared by all time steps. 0 disables it.
+    layout : str, default 'TNC'
+        `layout` of input symbol. Only used if inputs
+        is a single Symbol.
+    valid_length : Symbol, NDArray or None
+        `valid_length` specifies the length of the sequences in the batch without padding.
+        This option is especially useful for building sequence-to-sequence models where
+        the input and output sequences would potentially be padded.
+        If `valid_length` is None, all sequences are assumed to have the same length.
+        If `valid_length` is a Symbol or NDArray, it should have shape (batch_size,).
+        The ith element will be the length of the ith sequence in the batch.
+        The last valid state will be returned and the padded outputs will be masked with 0.
+        Note that `valid_length` must be smaller than or equal to `length`.
+
+    Returns
+    -------
+    outputs : Symbol
+        the output of the RNN from this unrolling.
+
+    states : list of Symbol
+        The new state of this RNN after this unrolling.
+        The type of this symbol is same as the output of `begin_state`.
+
+    Examples
+    --------
+    >>> seq_len = 3
+    >>> batch_size = 2
+    >>> input_size = 5
+    >>> cell = mx.gluon.rnn.LSTMCell(input_size, prefix='rnn_')
+    >>> cell.initialize(ctx=mx.cpu())
+    >>> rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len, batch_size, input_size))
+    >>> state_shape = (batch_size, input_size)
+    >>> states = [mx.nd.normal(loc=0, scale=1, shape=state_shape) for i in range(2)]
+    >>> valid_length = mx.nd.array([2, 3])
+    >>> output, states = mx.gluon.contrib.rnn.rnn_cell.unroll(cell, rnn_data, states,
+    ...                                                       valid_length=valid_length,
+    ...                                                       layout='TNC')
+    >>> print(output)
+    [[[ 0.00767238  0.00023103  0.03973929 -0.00925503 -0.05660512]
+      [ 0.00881535  0.05428379 -0.02493718 -0.01834097  0.02189514]]
+     [[-0.00676967  0.01447039  0.01287002 -0.00574152 -0.05734247]
+      [ 0.01568508  0.02650866 -0.04270559 -0.04328435  0.00904011]]
+     [[ 0.          0.          0.          0.          0.        ]
+      [ 0.01055336  0.02734251 -0.03153727 -0.03742751 -0.01378113]]]
+     <NDArray 3x2x5 @cpu(0)>
+    """
+
+    # Merge is always True, so we don't need length.
+    inputs, axis, F, _ = _format_sequence(0, inputs, layout, True)
+    if axis != 0:
+        # Swap the time axis to the front so the loop below always sees 'TNC'.
+        axes = list(range(len(layout)))
+        axes[0], axes[axis] = axes[axis], axes[0]
+        inputs = F.transpose(inputs, axes=axes)
+    states = begin_state
+
+    # From here on the time axis is 0 regardless of `layout`. Sharing the
+    # dropout mask along it applies the same mask at every time step.
+    if drop_inputs:
+        inputs = F.Dropout(inputs, p=drop_inputs, axes=(0,))
+
+    if valid_length is None:
+        def loop_body(inputs, states):
+            return cell(inputs, states)
+    else:
+        # Append an iteration counter as an extra loop state so the body can
+        # tell which sequences have already reached their valid length.
+        states = list(_as_list(states))
+        states.append(F.zeros((1,)))
+        def loop_body(inputs, states):
+            cell_states = states[:-1]
+            iter_no = states[-1]
+            out, new_states = cell(inputs, cell_states)
+            for i, state in enumerate(cell_states):
+                # Keep the previous state for sequences that have already ended.
+                new_states[i] = F.where(F.broadcast_greater(valid_length, iter_no),
+                                        new_states[i], state)
+            new_states.append(iter_no + 1)
+            return out, new_states
+
+    outputs, states = F.contrib.foreach(loop_body, inputs, states)
+    if drop_outputs:
+        # `outputs` is still time-major here, so the time axis is 0.
+        outputs = F.Dropout(outputs, p=drop_outputs, axes=(0,))
+    # Restore the caller's layout before masking and returning.
+    if axis != 0:
+        outputs = F.transpose(outputs, axes)
+    if valid_length is not None:
+        outputs = F.SequenceMask(outputs, sequence_length=valid_length,
+                                 use_sequence_length=True, axis=axis)
+        # The last state is the iteration counter; callers don't need it.
+        states = states[:-1]
+    return outputs, states
diff --git a/tests/python/unittest/test_gluon_contrib.py b/tests/python/unittest/test_gluon_contrib.py
index a1cd8ea537d..91789bc09c1 100644
--- a/tests/python/unittest/test_gluon_contrib.py
+++ b/tests/python/unittest/test_gluon_contrib.py
@@ -17,10 +17,12 @@
 
 from __future__ import print_function
 import mxnet as mx
+import copy
+from mxnet import gluon
 from mxnet.gluon import contrib
 from mxnet.gluon import nn
 from mxnet.gluon.contrib.nn import Concurrent, HybridConcurrent, Identity, SparseEmbedding
-from mxnet.test_utils import almost_equal
+from mxnet.test_utils import almost_equal, default_context, assert_almost_equal
 from common import setup_module, with_seed, teardown
 import numpy as np
 from numpy.testing import assert_allclose
@@ -228,6 +230,96 @@ def test_sampler():
     assert list(interval_sampler) == [0, 3, 6, 9]
 
 
+class TestRNNLayer(gluon.HybridBlock):
+    def __init__(self, cell_type, hidden_size, layout, prefix=None, params=None):
+        super(TestRNNLayer, self).__init__(prefix=prefix, params=params)
+        self.cell, self.layout = cell_type(hidden_size, prefix='rnn_'), layout
+
+    def hybrid_forward(self, F, inputs, states, valid_length):
+        # An empty list stands in for a missing valid_length, since None can't be passed in.
+        no_length = isinstance(valid_length, list) and not valid_length
+        return contrib.rnn.rnn_cell.unroll(self.cell, inputs, states,
+                                           valid_length=None if no_length else valid_length,
+                                           layout=self.layout)
+
+def check_unroll(cell_type, num_states, layout, use_valid_length=True):
+    """Check contrib unroll inside a HybridBlock against `cell.unroll`."""
+    batch_size = 20
+    input_size = 50
+    hidden_size = 30
+    seq_len = 10
+    if layout == 'TNC':
+        rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len, batch_size, input_size))
+    elif layout == 'NTC':
+        rnn_data = mx.nd.normal(loc=0, scale=1, shape=(batch_size, seq_len, input_size))
+    else:
+        raise ValueError("Wrong layout: %s" % layout)
+    valid_length = mx.nd.round(mx.nd.random.uniform(
+        low=1, high=seq_len, shape=(batch_size,))) if use_valid_length else None
+    state_shape = (batch_size, hidden_size)
+    states = [mx.nd.normal(loc=0, scale=1, shape=state_shape) for _ in range(num_states)]
+
+    cell = cell_type(hidden_size, prefix='rnn_')
+    cell.initialize(ctx=default_context())
+    if layout == 'TNC':
+        cell(rnn_data[0], states)
+    else:
+        cell(rnn_data[:,0,:], states)
+    params1 = cell.collect_params()
+    orig_params1 = copy.deepcopy(params1)
+
+    trainer = gluon.Trainer(params1, 'sgd', {'learning_rate' : 0.03})
+    with mx.autograd.record():
+        res1, states1 = cell.unroll(seq_len, rnn_data, states, valid_length=valid_length,
+                                    layout=layout, merge_outputs=True)
+    res1.backward()
+    trainer.step(batch_size)
+
+    # Flags must be keyword arguments: a positional dict is taken as `active`.
+    configs = [
+            lambda layer: None,
+            lambda layer: layer.hybridize(),
+            lambda layer: layer.hybridize(inline_limit=0),
+            lambda layer: layer.hybridize(static_alloc=True),
+            lambda layer: layer.hybridize(static_alloc=True, static_shape=True) ]
+    # We can't pass None to a hybrid block, but it accepts an empty list,
+    # so we use an empty list to represent valid_length if it's None.
+    layer_valid_length = [] if valid_length is None else valid_length
+    for config in configs:
+        layer = TestRNNLayer(cell_type, hidden_size, layout)
+        layer.initialize(ctx=default_context())
+        config(layer)
+        # Run once so the layer's parameters exist before set_data below.
+        layer(rnn_data, states, layer_valid_length)
+        params2 = layer.collect_params()
+        for key, val in orig_params1.items():
+            params2[key].set_data(copy.deepcopy(val.data()))
+
+        trainer = gluon.Trainer(params2, 'sgd', {'learning_rate' : 0.03})
+        with mx.autograd.record():
+            res2, states2 = layer(rnn_data, states, layer_valid_length)
+        assert_almost_equal(res1.asnumpy(), res2.asnumpy(), rtol=0.001, atol=0.0001)
+        assert len(states1) == len(states2)
+        for state1, state2 in zip(states1, states2):
+            assert_almost_equal(state1.asnumpy(), state2.asnumpy(),
+                                rtol=0.001, atol=0.0001)
+        res2.backward()
+        trainer.step(batch_size)
+
+        for key, val in params1.items():
+            assert_almost_equal(val.data().asnumpy(), params2[key].data().asnumpy(),
+                                rtol=0.001, atol=0.0001)
+
+
+@with_seed()
+def test_contrib_unroll():
+    cell_types = [(gluon.rnn.RNNCell, 1), (gluon.rnn.LSTMCell, 2), (gluon.rnn.GRUCell, 1)]
+    for cell_type, num_states in cell_types:
+        for layout in ('TNC', 'NTC'):
+            check_unroll(cell_type, num_states, layout)
+            check_unroll(cell_type, num_states, layout, use_valid_length=False)
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services