Posted to commits@mxnet.apache.org by lx...@apache.org on 2017/07/07 15:58:53 UTC

[40/50] [abbrv] incubator-mxnet-test git commit: Add ConvRNN Cell, ConvLSTM Cell (#6832)

Add ConvRNN Cell, ConvLSTM Cell (#6832)

* Add ConvLSTM cell

* Fix lint

* Fix typo

* Add activation parameters to ConvLSTM and ConvGRU

* Add leaky relu to activation options

* Change default params

* Remove h2h_pad

* Fix python3 compatibility bug

* Fix wrong padding

* Add base class for Conv RNN


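The new activation parameter noted above accepts either an operator name or a
callable Symbol constructor. A minimal usage sketch, assuming an MXNet build
that includes this commit (the arguments mirror the defaults in the diff below):

    import functools
    import mxnet as mx

    # functools.partial binds LeakyReLU's arguments, matching the default
    # activation these cells ship with (act_type='leaky', slope=0.2).
    act = functools.partial(mx.sym.LeakyReLU, act_type='leaky', slope=0.2)
    cell = mx.rnn.ConvGRUCell(input_shape=(1, 3, 16, 10), num_hidden=10,
                              activation=act, prefix='rnn_')
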
Project: http://git-wip-us.apache.org/repos/asf/incubator-mxnet-test/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-mxnet-test/commit/de5b0fe4
Tree: http://git-wip-us.apache.org/repos/asf/incubator-mxnet-test/tree/de5b0fe4
Diff: http://git-wip-us.apache.org/repos/asf/incubator-mxnet-test/diff/de5b0fe4

Branch: refs/heads/master
Commit: de5b0fe4e584958600551ac151aa6c03755face8
Parents: a21cab7
Author: Xu Dong <ds...@gmail.com>
Authored: Wed Jul 5 10:40:37 2017 +0800
Committer: Eric Junyuan Xie <pi...@users.noreply.github.com>
Committed: Tue Jul 4 19:40:37 2017 -0700

----------------------------------------------------------------------
 python/mxnet/rnn/rnn_cell.py      | 346 +++++++++++++++++++++++++++++++++
 tests/python/unittest/test_rnn.py |  47 +++++
 2 files changed, 393 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-mxnet-test/blob/de5b0fe4/python/mxnet/rnn/rnn_cell.py
----------------------------------------------------------------------
diff --git a/python/mxnet/rnn/rnn_cell.py b/python/mxnet/rnn/rnn_cell.py
index d0505f8..320f781 100644
--- a/python/mxnet/rnn/rnn_cell.py
+++ b/python/mxnet/rnn/rnn_cell.py
@@ -6,6 +6,7 @@
 from __future__ import print_function
 
 import warnings
+import functools
 
 from .. import symbol, init, ndarray
 from ..base import string_types, numeric_types
@@ -1068,3 +1069,348 @@ class BidirectionalCell(BaseRNNCell):
 
         states = [l_states, r_states]
         return outputs, states
+
+
+class BaseConvRNNCell(BaseRNNCell):
+    """Abstract base class for convolutional RNN cells.
+
+    Parameters
+    ----------
+    input_shape : tuple of int
+        Shape of the input in a single timestep.
+    num_hidden : int
+        Number of units in output symbol.
+    h2h_kernel : tuple of int
+        Kernel of Convolution operator in state-to-state transitions.
+    h2h_dilate : tuple of int
+        Dilation of Convolution operator in state-to-state transitions.
+    i2h_kernel : tuple of int
+        Kernel of Convolution operator in input-to-state transitions.
+    i2h_stride : tuple of int
+        Stride of Convolution operator in input-to-state transitions.
+    i2h_pad : tuple of int
+        Pad of Convolution operator in input-to-state transitions.
+    i2h_dilate : tuple of int
+        Dilation of Convolution operator in input-to-state transitions.
+    activation : str or Symbol
+        Type of activation function.
+    prefix : str, default ''
+        Prefix for name of layers (and name of weight if params is None).
+    params : RNNParams, default None
+        Container for weight sharing between cells. Created if None.
+    conv_layout : str, default 'NCHW'
+        Layout of ConvolutionOp.
+    """
+    def __init__(self, input_shape, num_hidden,
+                 h2h_kernel, h2h_dilate,
+                 i2h_kernel, i2h_stride,
+                 i2h_pad, i2h_dilate,
+                 activation,
+                 prefix='', params=None, conv_layout='NCHW'):
+        super(BaseConvRNNCell, self).__init__(prefix=prefix, params=params)
+        # Convolution setting
+        self._h2h_kernel = h2h_kernel
+        assert (self._h2h_kernel[0] % 2 == 1) and (self._h2h_kernel[1] % 2 == 1), \
+            "Only odd numbers are supported, got h2h_kernel= %s" % str(h2h_kernel)
+        self._h2h_pad = (h2h_dilate[0] * (h2h_kernel[0] - 1) // 2,
+                         h2h_dilate[1] * (h2h_kernel[1] - 1) // 2)
+        self._h2h_dilate = h2h_dilate
+        self._i2h_kernel = i2h_kernel
+        self._i2h_stride = i2h_stride
+        self._i2h_pad = i2h_pad
+        self._i2h_dilate = i2h_dilate
+
+        self._num_hidden = num_hidden
+        self._input_shape = input_shape
+        self._conv_layout = conv_layout
+        self._activation = activation
+
+        # Infer state shape
+        data = symbol.Variable('data')
+        self._state_shape = symbol.Convolution(data=data,
+                                               num_filter=self._num_hidden,
+                                               kernel=self._i2h_kernel,
+                                               stride=self._i2h_stride,
+                                               pad=self._i2h_pad,
+                                               dilate=self._i2h_dilate,
+                                               layout=conv_layout)
+        self._state_shape = self._state_shape.infer_shape(data=input_shape)[1][0]
+        self._state_shape = (0, ) + self._state_shape[1:]
+
+    @property
+    def state_info(self):
+        return [{'shape': self._state_shape, '__layout__': self._conv_layout},
+                {'shape': self._state_shape, '__layout__': self._conv_layout}]
+
+    def __call__(self, inputs, states):
+        raise NotImplementedError("BaseConvRNNCell is an abstract class for convolutional RNNs")
+
+
+class ConvRNNCell(BaseConvRNNCell):
+    """Convolutional RNN cell.
+
+    Parameters
+    ----------
+    input_shape : tuple of int
+        Shape of the input in a single timestep.
+    num_hidden : int
+        Number of units in output symbol.
+    h2h_kernel : tuple of int, default (3, 3)
+        Kernel of Convolution operator in state-to-state transitions.
+    h2h_dilate : tuple of int, default (1, 1)
+        Dilation of Convolution operator in state-to-state transitions.
+    i2h_kernel : tuple of int, default (3, 3)
+        Kernel of Convolution operator in input-to-state transitions.
+    i2h_stride : tuple of int, default (1, 1)
+        Stride of Convolution operator in input-to-state transitions.
+    i2h_pad : tuple of int, default (1, 1)
+        Pad of Convolution operator in input-to-state transitions.
+    i2h_dilate : tuple of int, default (1, 1)
+        Dilation of Convolution operator in input-to-state transitions.
+    activation : str or Symbol,
+        default functools.partial(symbol.LeakyReLU, act_type='leaky', slope=0.2)
+        Type of activation function.
+    prefix : str, default 'ConvRNN_'
+        Prefix for name of layers (and name of weight if params is None).
+    params : RNNParams, default None
+        Container for weight sharing between cells. Created if None.
+    conv_layout : str, default 'NCHW'
+        Layout of ConvolutionOp.
+    """
+    def __init__(self, input_shape, num_hidden,
+                 h2h_kernel=(3, 3), h2h_dilate=(1, 1),
+                 i2h_kernel=(3, 3), i2h_stride=(1, 1),
+                 i2h_pad=(1, 1), i2h_dilate=(1, 1),
+                 activation=functools.partial(symbol.LeakyReLU, act_type='leaky', slope=0.2),
+                 prefix='ConvRNN_', params=None, conv_layout='NCHW'):
+        super(ConvRNNCell, self).__init__(input_shape=input_shape, num_hidden=num_hidden,
+                                          h2h_kernel=h2h_kernel, h2h_dilate=h2h_dilate,
+                                          i2h_kernel=i2h_kernel, i2h_stride=i2h_stride,
+                                          i2h_pad=i2h_pad, i2h_dilate=i2h_dilate,
+                                          activation=activation, prefix=prefix,
+                                          params=params, conv_layout=conv_layout)
+        # Get params
+        self._iW = self.params.get('i2h_weight')
+        self._hW = self.params.get('h2h_weight')
+        self._iB = self.params.get('i2h_bias')
+        self._hB = self.params.get('h2h_bias')
+
+    @property
+    def _gate_names(self):
+        return ('',)
+
+    def __call__(self, inputs, states):
+        self._counter += 1
+        name = '%st%d_'%(self._prefix, self._counter)
+        i2h = symbol.Convolution(name='%si2h'%name,
+                                 data=inputs,
+                                 num_filter=self._num_hidden,
+                                 kernel=self._i2h_kernel,
+                                 stride=self._i2h_stride,
+                                 pad=self._i2h_pad,
+                                 dilate=self._i2h_dilate,
+                                 weight=self._iW,
+                                 bias=self._iB,)
+        h2h = symbol.Convolution(name='%sh2h'%name,
+                                 data=states[0],
+                                 num_filter=self._num_hidden,
+                                 kernel=self._h2h_kernel,
+                                 dilate=self._h2h_dilate,
+                                 pad=self._h2h_pad,
+                                 stride=(1, 1),
+                                 weight=self._hW,
+                                 bias=self._hB)
+        output = self._get_activation(i2h + h2h, self._activation,
+                                      name='%sout'%name)
+        return output, [output]
+
+
+class ConvLSTMCell(BaseConvRNNCell):
+    """Convolutional LSTM network cell.
+
+    Reference:
+        Shi et al., "Convolutional LSTM Network: A Machine Learning
+        Approach for Precipitation Nowcasting", NIPS 2015.
+
+    Parameters
+    ----------
+    input_shape : tuple of int
+        Shape of the input in a single timestep.
+    num_hidden : int
+        Number of units in output symbol.
+    h2h_kernel : tuple of int, default (3, 3)
+        Kernel of Convolution operator in state-to-state transitions.
+    h2h_dilate : tuple of int, default (1, 1)
+        Dilation of Convolution operator in state-to-state transitions.
+    i2h_kernel : tuple of int, default (3, 3)
+        Kernel of Convolution operator in input-to-state transitions.
+    i2h_stride : tuple of int, default (1, 1)
+        Stride of Convolution operator in input-to-state transitions.
+    i2h_pad : tuple of int, default (1, 1)
+        Pad of Convolution operator in input-to-state transitions.
+    i2h_dilate : tuple of int, default (1, 1)
+        Dilation of Convolution operator in input-to-state transitions.
+    activation : str or Symbol,
+        default functools.partial(symbol.LeakyReLU, act_type='leaky', slope=0.2)
+        Type of activation function.
+    prefix : str, default 'ConvLSTM_'
+        Prefix for name of layers (and name of weight if params is None).
+    params : RNNParams, default None
+        Container for weight sharing between cells. Created if None.
+    forget_bias : float, default 1.0
+        Bias added to the forget gate. Jozefowicz et al. 2015 recommends
+        setting it to 1.0.
+    conv_layout : str, default 'NCHW'
+        Layout of ConvolutionOp.
+    """
+    def __init__(self, input_shape, num_hidden,
+                 h2h_kernel=(3, 3), h2h_dilate=(1, 1),
+                 i2h_kernel=(3, 3), i2h_stride=(1, 1),
+                 i2h_pad=(1, 1), i2h_dilate=(1, 1),
+                 activation=functools.partial(symbol.LeakyReLU, act_type='leaky', slope=0.2),
+                 prefix='ConvLSTM_', params=None, forget_bias=1.0,
+                 conv_layout='NCHW'):
+        super(ConvLSTMCell, self).__init__(input_shape=input_shape, num_hidden=num_hidden,
+                                           h2h_kernel=h2h_kernel, h2h_dilate=h2h_dilate,
+                                           i2h_kernel=i2h_kernel, i2h_stride=i2h_stride,
+                                           i2h_pad=i2h_pad, i2h_dilate=i2h_dilate,
+                                           activation=activation, prefix=prefix,
+                                           params=params, conv_layout=conv_layout)
+
+        # Get params
+        self._iW = self.params.get('i2h_weight')
+        self._hW = self.params.get('h2h_weight')
+        # we add the forget_bias to i2h_bias, this adds the bias to the forget gate activation
+        self._iB = self.params.get('i2h_bias', init=init.LSTMBias(forget_bias=forget_bias))
+        self._hB = self.params.get('h2h_bias')
+
+    @property
+    def _gate_names(self):
+        return ['_i', '_f', '_c', '_o']
+
+    def __call__(self, inputs, states):
+        self._counter += 1
+        name = '%st%d_'%(self._prefix, self._counter)
+        i2h = symbol.Convolution(name='%si2h'%name,
+                                 data=inputs,
+                                 num_filter=self._num_hidden*4,
+                                 kernel=self._i2h_kernel,
+                                 stride=self._i2h_stride,
+                                 pad=self._i2h_pad,
+                                 dilate=self._i2h_dilate,
+                                 weight=self._iW,
+                                 bias=self._iB,)
+        h2h = symbol.Convolution(name='%sh2h'%name,
+                                 data=states[0],
+                                 num_filter=self._num_hidden*4,
+                                 kernel=self._h2h_kernel,
+                                 dilate=self._h2h_dilate,
+                                 pad=self._h2h_pad,
+                                 stride=(1, 1),
+                                 weight=self._hW,
+                                 bias=self._hB)
+
+        gates = i2h + h2h
+        slice_gates = symbol.SliceChannel(gates, num_outputs=4, axis=self._conv_layout.find('C'),
+                                          name="%sslice"%name)
+        in_gate = symbol.Activation(slice_gates[0], act_type="sigmoid",
+                                    name='%si'%name)
+        forget_gate = symbol.Activation(slice_gates[1], act_type="sigmoid",
+                                        name='%sf'%name)
+        in_transform = self._get_activation(slice_gates[2], self._activation,
+                                            name='%sc'%name)
+        out_gate = symbol.Activation(slice_gates[3], act_type="sigmoid",
+                                     name='%so'%name)
+        next_c = symbol._internal._plus(forget_gate * states[1], in_gate * in_transform,
+                                        name='%sstate'%name)
+        next_h = symbol._internal._mul(out_gate, self._get_activation(next_c, self._activation),
+                                       name='%sout'%name)
+
+        return next_h, [next_h, next_c]
+
+
+class ConvGRUCell(BaseConvRNNCell):
+    """Convolutional Gated Recurrent Unit (GRU) network cell.
+
+    Parameters
+    ----------
+    input_shape : tuple of int
+        Shape of the input in a single timestep.
+    num_hidden : int
+        Number of units in output symbol.
+    h2h_kernel : tuple of int, default (3, 3)
+        Kernel of Convolution operator in state-to-state transitions.
+    h2h_dilate : tuple of int, default (1, 1)
+        Dilation of Convolution operator in state-to-state transitions.
+    i2h_kernel : tuple of int, default (3, 3)
+        Kernel of Convolution operator in input-to-state transitions.
+    i2h_stride : tuple of int, default (1, 1)
+        Stride of Convolution operator in input-to-state transitions.
+    i2h_pad : tuple of int, default (1, 1)
+        Pad of Convolution operator in input-to-state transitions.
+    i2h_dilate : tuple of int, default (1, 1)
+        Dilation of Convolution operator in input-to-state transitions.
+    activation : str or Symbol,
+        default functools.partial(symbol.LeakyReLU, act_type='leaky', slope=0.2)
+        Type of activation function.
+    prefix : str, default 'ConvGRU_'
+        Prefix for name of layers (and name of weight if params is None).
+    params : RNNParams, default None
+        Container for weight sharing between cells. Created if None.
+    conv_layout : str, default 'NCHW'
+        Layout of ConvolutionOp.
+    """
+    def __init__(self, input_shape, num_hidden,
+                 h2h_kernel=(3, 3), h2h_dilate=(1, 1),
+                 i2h_kernel=(3, 3), i2h_stride=(1, 1),
+                 i2h_pad=(1, 1), i2h_dilate=(1, 1),
+                 activation=functools.partial(symbol.LeakyReLU, act_type='leaky', slope=0.2),
+                 prefix='ConvGRU_', params=None, conv_layout='NCHW'):
+        super(ConvGRUCell, self).__init__(input_shape=input_shape, num_hidden=num_hidden,
+                                          h2h_kernel=h2h_kernel, h2h_dilate=h2h_dilate,
+                                          i2h_kernel=i2h_kernel, i2h_stride=i2h_stride,
+                                          i2h_pad=i2h_pad, i2h_dilate=i2h_dilate,
+                                          activation=activation, prefix=prefix,
+                                          params=params, conv_layout=conv_layout)
+        # Get params
+        self._iW = self.params.get('i2h_weight')
+        self._hW = self.params.get('h2h_weight')
+        self._iB = self.params.get('i2h_bias')
+        self._hB = self.params.get('h2h_bias')
+
+    @property
+    def _gate_names(self):
+        return ['_r', '_z', '_o']
+
+    def __call__(self, inputs, states):
+        self._counter += 1
+        seq_idx = self._counter
+        name = '%st%d_' % (self._prefix, seq_idx)
+        i2h = symbol.Convolution(name='%si2h'%name, data=inputs,
+                                 num_filter=self._num_hidden * 3,
+                                 kernel=self._i2h_kernel,
+                                 stride=self._i2h_stride,
+                                 pad=self._i2h_pad,
+                                 dilate=self._i2h_dilate,
+                                 weight=self._iW,
+                                 bias=self._iB,)
+        h2h = symbol.Convolution(name='%sh2h'%name, data=states[0],
+                                 num_filter=self._num_hidden * 3,
+                                 kernel=self._h2h_kernel,
+                                 dilate=self._h2h_dilate,
+                                 pad=self._h2h_pad,
+                                 stride=(1, 1),
+                                 weight=self._hW,
+                                 bias=self._hB)
+
+        i2h_r, i2h_z, i2h = symbol.SliceChannel(i2h, num_outputs=3,
+                                                axis=self._conv_layout.find('C'),
+                                                name="%si2h_slice" % name)
+        h2h_r, h2h_z, h2h = symbol.SliceChannel(h2h, num_outputs=3,
+                                                axis=self._conv_layout.find('C'),
+                                                name="%sh2h_slice" % name)
+
+        reset_gate = symbol.Activation(i2h_r + h2h_r, act_type="sigmoid",
+                                       name="%sr_act" % name)
+        update_gate = symbol.Activation(i2h_z + h2h_z, act_type="sigmoid",
+                                        name="%sz_act" % name)
+
+        next_h_tmp = self._get_activation(i2h + reset_gate * h2h, self._activation,
+                                          name="%sh_act" % name)
+
+        next_h = symbol._internal._plus((1. - update_gate) * next_h_tmp, update_gate * states[0],
+                                        name='%sout' % name)
+
+        return next_h, [next_h]

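The h2h padding computed in BaseConvRNNCell above is the "same" padding for a
stride-1 dilated convolution, which is why the recurrent state keeps its
spatial shape across timesteps and why odd kernels are asserted. A standalone
sketch of that arithmetic (plain Python, helper names illustrative):

    # With an odd kernel k, dilation d, stride 1, and pad p = d*(k-1)//2,
    # the convolution output size equals the input size:
    #   out = (h + 2*p - d*(k-1) - 1) // 1 + 1 = h
    def same_pad(kernel, dilate):
        assert kernel % 2 == 1, "odd kernels only, as the code asserts"
        return dilate * (kernel - 1) // 2

    def conv_out_size(h, kernel, dilate, pad, stride=1):
        return (h + 2 * pad - dilate * (kernel - 1) - 1) // stride + 1

    for h, k, d in [(16, 3, 1), (16, 3, 2), (10, 5, 1)]:
        assert conv_out_size(h, k, d, same_pad(k, d)) == h
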
http://git-wip-us.apache.org/repos/asf/incubator-mxnet-test/blob/de5b0fe4/tests/python/unittest/test_rnn.py
----------------------------------------------------------------------
diff --git a/tests/python/unittest/test_rnn.py b/tests/python/unittest/test_rnn.py
index 419104d..6df8452 100644
--- a/tests/python/unittest/test_rnn.py
+++ b/tests/python/unittest/test_rnn.py
@@ -175,6 +175,50 @@ def test_unfuse():
     args, outs, auxs = outputs.infer_shape(rnn_t0_data=(10,50), rnn_t1_data=(10,50), rnn_t2_data=(10,50))
     assert outs == [(10, 200), (10, 200), (10, 200)]
 
+def test_convrnn():
+    cell = mx.rnn.ConvRNNCell(input_shape=(1, 3, 16, 10), num_hidden=10,
+                              h2h_kernel=(3, 3), h2h_dilate=(1, 1),
+                              i2h_kernel=(3, 3), i2h_stride=(1, 1),
+                              i2h_pad=(1, 1), i2h_dilate=(1, 1),
+                              prefix='rnn_')
+    inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(3)]
+    outputs, _ = cell.unroll(3, inputs)
+    outputs = mx.sym.Group(outputs)
+    assert sorted(cell.params._params.keys()) == ['rnn_h2h_bias', 'rnn_h2h_weight', 'rnn_i2h_bias', 'rnn_i2h_weight']
+    assert outputs.list_outputs() == ['rnn_t0_out_output', 'rnn_t1_out_output', 'rnn_t2_out_output']
+
+    args, outs, auxs = outputs.infer_shape(rnn_t0_data=(1, 3, 16, 10), rnn_t1_data=(1, 3, 16, 10), rnn_t2_data=(1, 3, 16, 10))
+    assert outs == [(1, 10, 16, 10), (1, 10, 16, 10), (1, 10, 16, 10)]
+
+def test_convlstm():
+    cell = mx.rnn.ConvLSTMCell(input_shape=(1, 3, 16, 10), num_hidden=10,
+                               h2h_kernel=(3, 3), h2h_dilate=(1, 1),
+                               i2h_kernel=(3, 3), i2h_stride=(1, 1),
+                               i2h_pad=(1, 1), i2h_dilate=(1, 1),
+                               prefix='rnn_', forget_bias=1.0)
+    inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(3)]
+    outputs, _ = cell.unroll(3, inputs)
+    outputs = mx.sym.Group(outputs)
+    assert sorted(cell.params._params.keys()) == ['rnn_h2h_bias', 'rnn_h2h_weight', 'rnn_i2h_bias', 'rnn_i2h_weight']
+    assert outputs.list_outputs() == ['rnn_t0_out_output', 'rnn_t1_out_output', 'rnn_t2_out_output']
+
+    args, outs, auxs = outputs.infer_shape(rnn_t0_data=(1, 3, 16, 10), rnn_t1_data=(1, 3, 16, 10), rnn_t2_data=(1, 3, 16, 10))
+    assert outs == [(1, 10, 16, 10), (1, 10, 16, 10), (1, 10, 16, 10)]
+
+def test_convgru():
+    cell = mx.rnn.ConvGRUCell(input_shape=(1, 3, 16, 10), num_hidden=10,
+                              h2h_kernel=(3, 3), h2h_dilate=(1, 1),
+                              i2h_kernel=(3, 3), i2h_stride=(1, 1),
+                              i2h_pad=(1, 1), i2h_dilate=(1, 1),
+                              prefix='rnn_')
+    inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(3)]
+    outputs, _ = cell.unroll(3, inputs)
+    outputs = mx.sym.Group(outputs)
+    assert sorted(cell.params._params.keys()) == ['rnn_h2h_bias', 'rnn_h2h_weight', 'rnn_i2h_bias', 'rnn_i2h_weight']
+    assert outputs.list_outputs() == ['rnn_t0_out_output', 'rnn_t1_out_output', 'rnn_t2_out_output']
+
+    args, outs, auxs = outputs.infer_shape(rnn_t0_data=(1, 3, 16, 10), rnn_t1_data=(1, 3, 16, 10), rnn_t2_data=(1, 3, 16, 10))
+    assert outs == [(1, 10, 16, 10), (1, 10, 16, 10), (1, 10, 16, 10)]
 
 if __name__ == '__main__':
     test_rnn()
@@ -184,3 +228,6 @@ if __name__ == '__main__':
     test_stack()
     test_bidirectional()
     test_unfuse()
+    test_convrnn()
+    test_convlstm()
+    test_convgru()
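
For reference, end-to-end usage mirroring the tests above, assuming an MXNet
build that includes this commit:

    import mxnet as mx

    # A ConvLSTM over 3 timesteps of NCHW maps shaped (1, 3, 16, 10); each
    # step emits num_hidden=10 channels at the same spatial resolution.
    cell = mx.rnn.ConvLSTMCell(input_shape=(1, 3, 16, 10), num_hidden=10,
                               prefix='rnn_')
    inputs = [mx.sym.Variable('rnn_t%d_data' % i) for i in range(3)]
    outputs, states = cell.unroll(3, inputs)
    outputs = mx.sym.Group(outputs)
    _, outs, _ = outputs.infer_shape(rnn_t0_data=(1, 3, 16, 10),
                                     rnn_t1_data=(1, 3, 16, 10),
                                     rnn_t2_data=(1, 3, 16, 10))
    print(outs)  # [(1, 10, 16, 10)] * 3, as test_convlstm asserts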