You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2019/02/16 02:23:32 UTC
[incubator-mxnet] branch master updated: [MXNET-766] add dynamic_unroll RNN for HybridBlock (#11948)
This is an automated email from the ASF dual-hosted git repository.
haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 4b527b6 [MXNET-766] add dynamic_unroll RNN for HybridBlock (#11948)
4b527b6 is described below
commit 4b527b681f01bff02a770c1b49a3b5c9278cf4a0
Author: Da Zheng <zh...@gmail.com>
AuthorDate: Fri Feb 15 18:23:01 2019 -0800
[MXNET-766] add dynamic_unroll RNN for HybridBlock (#11948)
* add contrib unroll.
* reenable some tests.
* fix a bug.
* fix lint.
* fix a bug.
* support diff layouts.
* update doc.
* use a diff default layout.
* remove _contrib_format_sequence.
* fix lint.
* rename.
---
python/mxnet/gluon/contrib/rnn/rnn_cell.py | 115 ++++++++++++++++++++++++++++
tests/python/unittest/test_gluon_contrib.py | 95 ++++++++++++++++++++++-
2 files changed, 209 insertions(+), 1 deletion(-)
diff --git a/python/mxnet/gluon/contrib/rnn/rnn_cell.py b/python/mxnet/gluon/contrib/rnn/rnn_cell.py
index 0cbc9ea..3bd8e78 100644
--- a/python/mxnet/gluon/contrib/rnn/rnn_cell.py
+++ b/python/mxnet/gluon/contrib/rnn/rnn_cell.py
@@ -22,6 +22,7 @@ __all__ = ['VariationalDropoutCell', 'LSTMPCell']
from ...rnn import BidirectionalCell, SequentialRNNCell, ModifierCell, HybridRecurrentCell
from ...rnn.rnn_cell import _format_sequence, _get_begin_state, _mask_sequence_variable_length
from ... import tensor_types
+from ....base import _as_list
class VariationalDropoutCell(ModifierCell):
"""
@@ -320,3 +321,117 @@ class LSTMPCell(HybridRecurrentCell):
return next_r, [next_r, next_c]
# pylint: enable= arguments-differ
+
+
+def dynamic_unroll(cell, inputs, begin_state, drop_inputs=0, drop_outputs=0,
+ layout='TNC', valid_length=None):
+ """Unrolls an RNN cell across time steps.
+
+ Currently, 'TNC' is a preferred layout. unroll on the input of this layout
+ runs much faster.
+
+ Parameters
+ ----------
+ cell : an object whose base class is RNNCell.
+ The RNN cell to run on the input sequence.
+ inputs : Symbol
+ It should have shape (batch_size, length, ...) if `layout` is 'NTC',
+ or (length, batch_size, ...) if `layout` is 'TNC'.
+ begin_state : nested list of Symbol
+ The initial states of the RNN sequence.
+ drop_inputs : float, default 0.
+ The dropout rate for inputs. Won't apply dropout if it equals 0.
+ drop_outputs : float, default 0.
+ The dropout rate for outputs. Won't apply dropout if it equals 0.
+ layout : str, optional
+ `layout` of input symbol. Only used if inputs
+ is a single Symbol.
+ valid_length : Symbol, NDArray or None
+ `valid_length` specifies the length of the sequences in the batch without padding.
+ This option is especially useful for building sequence-to-sequence models where
+ the input and output sequences would potentially be padded.
+ If `valid_length` is None, all sequences are assumed to have the same length.
+ If `valid_length` is a Symbol or NDArray, it should have shape (batch_size,).
+ The ith element will be the length of the ith sequence in the batch.
+ The last valid state will be return and the padded outputs will be masked with 0.
+ Note that `valid_length` must be smaller or equal to `length`.
+
+ Returns
+ -------
+ outputs : Symbol
+ the output of the RNN from this unrolling.
+
+ states : list of Symbol
+ The new state of this RNN after this unrolling.
+ The type of this symbol is same as the output of `begin_state`.
+
+ Examples
+ --------
+ >>> seq_len = 3
+ >>> batch_size = 2
+ >>> input_size = 5
+ >>> cell = mx.gluon.rnn.LSTMCell(input_size, prefix='rnn_')
+ >>> cell.initialize(ctx=mx.cpu())
+ >>> rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len, batch_size, input_size))
+ >>> state_shape = (batch_size, input_size)
+ >>> states = [mx.nd.normal(loc=0, scale=1, shape=state_shape) for i in range(2)]
+ >>> valid_length = mx.nd.array([2, 3])
+ >>> output, states = mx.gluon.contrib.rnn.rnn_cell.dynamic_unroll(cell, rnn_data, states,
+ valid_length=valid_length,
+ layout='TNC')
+ >>> print(output)
+ [[[ 0.00767238 0.00023103 0.03973929 -0.00925503 -0.05660512]
+ [ 0.00881535 0.05428379 -0.02493718 -0.01834097 0.02189514]]
+ [[-0.00676967 0.01447039 0.01287002 -0.00574152 -0.05734247]
+ [ 0.01568508 0.02650866 -0.04270559 -0.04328435 0.00904011]]
+ [[ 0. 0. 0. 0. 0. ]
+ [ 0.01055336 0.02734251 -0.03153727 -0.03742751 -0.01378113]]]
+ <NDArray 3x2x5 @cpu(0)>
+ """
+
+ # Merge is always True, so we don't need length.
+ inputs, axis, F, _ = _format_sequence(0, inputs, layout, True)
+ if axis != 0:
+ axes = list(range(len(layout)))
+ tmp = axes[0]
+ axes[0] = axes[axis]
+ axes[axis] = tmp
+ inputs = F.transpose(inputs, axes=axes)
+ states = begin_state
+
+ if drop_inputs:
+ inputs = F.Dropout(inputs, p=drop_inputs, axes=(axis,))
+
+ if valid_length is None:
+ def loop_body(inputs, states):
+ return cell(inputs, states)
+ else:
+ zeros = []
+ for s in states:
+ zeros.append(F.zeros_like(s))
+ states = list(_as_list(states))
+ states.append(F.zeros((1)))
+ def loop_body(inputs, states):
+ cell_states = states[:-1]
+ iter_no = states[-1]
+ out, new_states = cell(inputs, cell_states)
+ for i, state in enumerate(cell_states):
+ new_states[i] = F.where(F.broadcast_greater(valid_length, iter_no),
+ new_states[i], state)
+ new_states.append(iter_no + 1)
+ return out, new_states
+
+ outputs, states = F.contrib.foreach(loop_body, inputs, states)
+ if drop_outputs:
+ outputs = F.Dropout(outputs, p=drop_outputs, axes=(axis,))
+ if valid_length is not None:
+ if axis != 0:
+ outputs = F.transpose(outputs, axes)
+ outputs = F.SequenceMask(outputs, sequence_length=valid_length,
+ use_sequence_length=True, axis=axis)
+ # the last state is the iteration number. We don't need it.
+ return outputs, states[:-1]
+ else:
+ if axis != 0:
+ outputs = F.transpose(outputs, axes)
+ return outputs, states
diff --git a/tests/python/unittest/test_gluon_contrib.py b/tests/python/unittest/test_gluon_contrib.py
index 6901e8b..1e05559 100644
--- a/tests/python/unittest/test_gluon_contrib.py
+++ b/tests/python/unittest/test_gluon_contrib.py
@@ -17,12 +17,14 @@
from __future__ import print_function
import mxnet as mx
+import copy
+from mxnet import gluon
from mxnet.gluon import contrib
from mxnet.gluon import nn
from mxnet.gluon.contrib.nn import (
Concurrent, HybridConcurrent, Identity, SparseEmbedding, PixelShuffle1D,
PixelShuffle2D, PixelShuffle3D)
-from mxnet.test_utils import almost_equal
+from mxnet.test_utils import almost_equal, default_context, assert_almost_equal
from common import setup_module, with_seed, teardown
import numpy as np
from numpy.testing import assert_allclose
@@ -313,6 +315,97 @@ def test_sampler():
assert list(interval_sampler) == [0, 3, 6, 9]
+class TestRNNLayer(gluon.HybridBlock):
+ def __init__(self, cell_type, hidden_size, layout, prefix=None, params=None):
+ super(TestRNNLayer, self).__init__(prefix=prefix, params=params)
+ self.cell = cell_type(hidden_size, prefix='rnn_')
+ self.layout = layout
+
+ def hybrid_forward(self, F, inputs, states, valid_length):
+ if isinstance(valid_length, list) and len(valid_length) == 0:
+ valid_length = None
+ return contrib.rnn.rnn_cell.dynamic_unroll(self.cell, inputs, states,
+ valid_length=valid_length,
+ layout=self.layout)
+
+def check_unroll(cell_type, num_states, layout):
+ batch_size = 20
+ input_size = 50
+ hidden_size = 30
+ seq_len = 10
+ if layout == 'TNC':
+ rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len, batch_size, input_size))
+ elif layout == 'NTC':
+ rnn_data = mx.nd.normal(loc=0, scale=1, shape=(batch_size, seq_len, input_size))
+ else:
+ print("Wrong layout")
+ return
+ valid_length = mx.nd.round(mx.nd.random.uniform(low=1, high=10, shape=(batch_size)))
+ state_shape = (batch_size, hidden_size)
+ states = [mx.nd.normal(loc=0, scale=1, shape=state_shape) for i in range(num_states)]
+
+ cell = cell_type(hidden_size, prefix='rnn_')
+ cell.initialize(ctx=default_context())
+ if layout == 'TNC':
+ cell(rnn_data[0], states)
+ else:
+ cell(rnn_data[:,0,:], states)
+ params1 = cell.collect_params()
+ orig_params1 = copy.deepcopy(params1)
+
+ trainer = gluon.Trainer(params1, 'sgd', {'learning_rate' : 0.03})
+ with mx.autograd.record():
+ res1, states1 = cell.unroll(seq_len, rnn_data, states, valid_length=valid_length,
+ layout=layout, merge_outputs=True)
+ res1.backward()
+ trainer.step(batch_size)
+
+ configs = [
+ lambda layer: None,
+ lambda layer: layer.hybridize(),
+ lambda layer: layer.hybridize({'inline_limit': 0}),
+ lambda layer: layer.hybridize({'static_alloc': True}),
+ lambda layer: layer.hybridize({'static_alloc': True, 'static_shape': True}) ]
+ # We can't pass None to a hybrid block, but it accepts an empty list.
+ # so we use an empty list to represent valid_length if it's None.
+ if valid_length is None:
+ valid_length = []
+ for config in configs:
+ layer = TestRNNLayer(cell_type, hidden_size, layout)
+ layer.initialize(ctx=default_context())
+ config(layer)
+ res2, states2 = layer(rnn_data, states, valid_length)
+ params2 = layer.collect_params()
+ for key, val in orig_params1.items():
+ params2[key].set_data(copy.deepcopy(val.data()))
+
+ trainer = gluon.Trainer(params2, 'sgd', {'learning_rate' : 0.03})
+ with mx.autograd.record():
+ res2, states2 = layer(rnn_data, states, valid_length)
+ assert_almost_equal(res1.asnumpy(), res2.asnumpy(), rtol=0.001, atol=0.0001)
+ assert len(states1) == len(states2)
+ for i in range(len(states1)):
+ assert_almost_equal(states1[i].asnumpy(), states2[i].asnumpy(),
+ rtol=0.001, atol=0.0001)
+ res2.backward()
+ trainer.step(batch_size)
+
+ for key, val in params1.items():
+ weight1 = val.data()
+ weight2 = params2[key].data()
+ assert_almost_equal(weight1.asnumpy(), weight2.asnumpy(),
+ rtol=0.001, atol=0.0001)
+
+
+@with_seed()
+def test_contrib_unroll():
+ cell_types = [(gluon.rnn.RNNCell, 1), (gluon.rnn.LSTMCell, 2),
+ (gluon.rnn.GRUCell, 1)]
+ for cell_type, num_states in cell_types:
+ check_unroll(cell_type, num_states, 'TNC')
+ check_unroll(cell_type, num_states, 'NTC')
+
+
if __name__ == '__main__':
    # Running this file directly hands the module to nose's test runner.
    import nose
    nose.runmodule()