You are viewing a plain-text version of this content. The canonical link for it (a hyperlink in the original page) is not shown in this text version.
Posted to commits@mxnet.apache.org by gi...@git.apache.org on 2017/08/04 02:43:04 UTC

[GitHub] szha commented on a change in pull request #6832: Add ConvRNN Cell, ConvLSTM Cell

szha commented on a change in pull request #6832: Add ConvRNN Cell, ConvLSTM Cell
URL: https://github.com/apache/incubator-mxnet/pull/6832#discussion_r131301824
 
 

 ##########
 File path: python/mxnet/rnn/rnn_cell.py
 ##########
 @@ -1068,3 +1068,346 @@ def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=N
 
         states = [l_states, r_states]
         return outputs, states
+
+class ConvRNNCell(BaseRNNCell):
+    """Convolutional RNN cells
+
+    Parameters
+    ----------
+    input_shape : tuple of int
+        Shape of input in single timestep.
+    num_hidden : int
+        Number of units in output symbol.
+    h2h_kernel : tuple of int, default (3, 3)
+        Kernel of Convolution operator in state-to-state transitions.
+    h2h_pad : tuple of int, default (1, 1)
+        Pad of Convolution operator in state-to-state transitions.
+    h2h_dilate : tuple of int, default (1, 1)
+        Dilation of Convolution operator in state-to-state transitions.
+    i2h_kernel : tuple of int, default (3, 3)
+        Kernel of Convolution operator in input-to-state transitions.
+    i2h_stride : tuple of int, default (1, 1)
+        Stride of Convolution operator in input-to-state transitions.
+    i2h_pad : tuple of int, default (1, 1)
+        Pad of Convolution operator in input-to-state transitions.
+    i2h_dilate : tuple of int, default (1, 1)
+        Dilation of Convolution operator in input-to-state transitions.
+    activation : str or Symbol, default 'tanh'
+        Type of activation function. Options are 'relu' and 'tanh'.
+    prefix : str, default 'ConvRNN_'
+        Prefix for name of layers (and name of weight if params is None).
+    params : RNNParams, default None
+        Container for weight sharing between cells. Created if None.
+    conv_layout : str, , default 'NCHW'
+        Layout of ConvolutionOp
+    """
+    def __init__(self, input_shape, num_hidden,
+                 h2h_kernel=(3, 3), h2h_pad=(1, 1), h2h_dilate=(1, 1),
+                 i2h_kernel=(3, 3), i2h_stride=(1, 1), i2h_pad=(1, 1), i2h_dilate=(1, 1),
+                 activation='tanh', prefix='ConvRNN_', params=None, conv_layout='NCHW'):
+        super(ConvRNNCell, self).__init__(prefix=prefix, params=params)
+        # Convolution setting
+        self._h2h_kernel = h2h_kernel
+        self._h2h_pad = h2h_pad
+        self._h2h_dilate = h2h_dilate
+        self._i2h_kernel = i2h_kernel
+        self._i2h_stride = i2h_stride
+        self._i2h_pad = i2h_pad
+        self._i2h_dilate = i2h_dilate
+
+        self._num_hidden = num_hidden
+        self._input_shape = input_shape
+        self._conv_layout = conv_layout
+        self._activation = activation
+
+        # Get params
+        self._iW = self.params.get('i2h_weight')
+        self._hW = self.params.get('h2h_weight')
+        self._iB = self.params.get('i2h_bias')
+        self._hB = self.params.get('h2h_bias')
+
+        # Infer state shape
+        data = symbol.Variable('data')
+        self._state_shape = symbol.Convolution(data=data,
+                                               num_filter=self._num_hidden,
+                                               kernel=self._i2h_kernel,
+                                               stride=self._i2h_stride,
+                                               pad=self._i2h_pad,
+                                               dilate=self._i2h_dilate,
+                                               layout=conv_layout)
+        self._state_shape = self._state_shape.infer_shape(data=input_shape)[1][0]
+        self._state_shape = (0, ) + self._state_shape[1:]
+
+    @property
+    def state_info(self):
+        return [{'shape': self._state_shape, '__layout__': self._conv_layout},
+                {'shape': self._state_shape, '__layout__': self._conv_layout}]
+
+    @property
+    def _gate_names(self):
+        return ('',)
+
+    def __call__(self, inputs, states):
+        self._counter += 1
+        name = '%st%d_'%(self._prefix, self._counter)
+        i2h = symbol.Convolution(name='%si2h'%name,
+                                 data=inputs,
+                                 num_filter=self._num_hidden,
+                                 kernel=self._i2h_kernel,
+                                 stride=self._i2h_stride,
+                                 pad=self._i2h_pad,
+                                 dilate=self._i2h_dilate,
+                                 weight=self._iW,
+                                 bias=self._iB,)
+        h2h = symbol.Convolution(name='%sh2h'%name,
+                                 data=states[0],
+                                 num_filter=self._num_hidden,
+                                 kernel=self._h2h_kernel,
+                                 dilate=self._h2h_dilate,
+                                 pad=self._h2h_pad,
+                                 stride=(1, 1),
+                                 weight=self._hW,
+                                 bias=self._hB)
+        output = self._get_activation(i2h + h2h, self._activation,
+                                      name='%sout'%name)
+        return output, [output]
+
+class ConvLSTMCell(BaseRNNCell):
+    """Convolutional LSTM network cell.
+
+    Reference:
+        Xingjian et al. NIPS2015
+
+    Parameters
+    ----------
+    input_shape : tuple of int
+        Shape of input in single timestep.
+    num_hidden : int
+        Number of units in output symbol.
+    h2h_kernel : tuple of int, default (3, 3)
+        Kernel of Convolution operator in state-to-state transitions.
+    h2h_pad : tuple of int, default (1, 1)
+        Pad of Convolution operator in state-to-state transitions.
+    h2h_dilate : tuple of int, default (1, 1)
+        Dilation of Convolution operator in state-to-state transitions.
+    i2h_kernel : tuple of int, default (3, 3)
+        Kernel of Convolution operator in input-to-state transitions.
+    i2h_stride : tuple of int, default (1, 1)
+        Stride of Convolution operator in input-to-state transitions.
+    i2h_pad : tuple of int, default (1, 1)
+        Pad of Convolution operator in input-to-state transitions.
+    i2h_dilate : tuple of int, default (1, 1)
+        Dilation of Convolution operator in input-to-state transitions.
+    activation : str or Symbol, default 'tanh'
+        Type of activation function. Options are 'relu' and 'tanh'.
 
 Review comment:
  @sxjscience Thanks. I will set the default to tanh for now in #7264.
 
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log in to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services