You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2017/12/17 07:42:35 UTC

[GitHub] ceiba-w commented on a change in pull request #6832: Add ConvRNN Cell, ConvLSTM Cell

ceiba-w commented on a change in pull request #6832: Add ConvRNN Cell, ConvLSTM Cell
URL: https://github.com/apache/incubator-mxnet/pull/6832#discussion_r157361173
 
 

 ##########
 File path: python/mxnet/rnn/rnn_cell.py
 ##########
 @@ -1068,3 +1069,348 @@ def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=N
 
         states = [l_states, r_states]
         return outputs, states
+
+
class BaseConvRNNCell(BaseRNNCell):
    """Abstract base class for convolutional RNN cells.

    Holds the convolution hyper-parameters shared by all convolutional
    recurrent cells and infers the per-timestep state shape by running a
    symbolic shape inference on the input-to-hidden convolution.

    Parameters
    ----------
    input_shape : tuple of int
        Shape of input in a single timestep (in ``conv_layout`` order).
    num_hidden : int
        Number of units in output symbol.
    h2h_kernel : tuple of int
        Kernel of Convolution operator in state-to-state transitions.
        Both dimensions must be odd so that "same" padding is exact.
    h2h_dilate : tuple of int
        Dilation of Convolution operator in state-to-state transitions.
    i2h_kernel : tuple of int
        Kernel of Convolution operator in input-to-state transitions.
    i2h_stride : tuple of int
        Stride of Convolution operator in input-to-state transitions.
    i2h_pad : tuple of int
        Pad of Convolution operator in input-to-state transitions.
    i2h_dilate : tuple of int
        Dilation of Convolution operator in input-to-state transitions.
    activation : str or Symbol
        Type of activation function.
    prefix : str, default ''
        Prefix for name of layers (and name of weight if params is None).
    params : RNNParams, default None
        Container for weight sharing between cells. Created if None.
    conv_layout : str, default 'NCHW'
        Layout of the Convolution operator.
    """
    def __init__(self, input_shape, num_hidden,
                 h2h_kernel, h2h_dilate,
                 i2h_kernel, i2h_stride,
                 i2h_pad, i2h_dilate,
                 activation,
                 prefix='', params=None, conv_layout='NCHW'):
        super(BaseConvRNNCell, self).__init__(prefix=prefix, params=params)
        # Convolution settings.
        self._h2h_kernel = h2h_kernel
        # Odd kernels are required so the "same" padding computed below is
        # symmetric and the state spatial size is preserved across steps.
        assert (self._h2h_kernel[0] % 2 == 1) and (self._h2h_kernel[1] % 2 == 1), \
            "Only support odd number, got h2h_kernel= %s" % str(h2h_kernel)
        # "Same" padding for the state-to-state convolution, accounting for
        # dilation: pad = dilate * (kernel - 1) / 2 per spatial dimension.
        self._h2h_pad = (h2h_dilate[0] * (h2h_kernel[0] - 1) // 2,
                         h2h_dilate[1] * (h2h_kernel[1] - 1) // 2)
        self._h2h_dilate = h2h_dilate
        self._i2h_kernel = i2h_kernel
        self._i2h_stride = i2h_stride
        self._i2h_pad = i2h_pad
        self._i2h_dilate = i2h_dilate

        self._num_hidden = num_hidden
        self._input_shape = input_shape
        self._conv_layout = conv_layout
        self._activation = activation

        # Infer the state shape: build a throwaway i2h convolution symbol and
        # run shape inference on it with the known single-timestep input shape.
        data = symbol.Variable('data')
        self._state_shape = symbol.Convolution(data=data,
                                               num_filter=self._num_hidden,
                                               kernel=self._i2h_kernel,
                                               stride=self._i2h_stride,
                                               pad=self._i2h_pad,
                                               dilate=self._i2h_dilate,
                                               layout=conv_layout)
        # infer_shape returns (arg_shapes, out_shapes, aux_shapes); take the
        # first (only) output shape.
        self._state_shape = self._state_shape.infer_shape(data=input_shape)[1][0]
        # Zero out the batch dimension so it is filled in at bind time.
        self._state_shape = (0,) + self._state_shape[1:]

    @property
    def state_info(self):
        # Two states (e.g. h and c for ConvLSTM), both with the inferred
        # convolutional state shape and the cell's convolution layout.
        return [{'shape': self._state_shape, '__layout__': self._conv_layout},
                {'shape': self._state_shape, '__layout__': self._conv_layout}]

    def __call__(self, inputs, states):
        # Concrete cells (ConvRNNCell, ConvLSTMCell, ...) must implement
        # the single-step computation.
        raise NotImplementedError("BaseConvRNNCell is abstract class for convolutional RNN")
+
+class ConvRNNCell(BaseConvRNNCell):
+    """Convolutional RNN cells
 
 Review comment:
  Not only the ConvLSTMCell but also the LSTMCell has removed the contribution of C_{t-1} in the input gate, forget gate, and output gate.
   
  But I added the Hadamard product with C_{t-1} for personal use.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services