Posted to commits@mxnet.apache.org by gi...@git.apache.org on 2017/08/02 18:06:14 UTC

[GitHub] szha commented on a change in pull request #7264: gluon conv rnns

URL: https://github.com/apache/incubator-mxnet/pull/7264#discussion_r130952234
 
 

 ##########
 File path: python/mxnet/gluon/rnn/rnn_cell.py
 ##########
 @@ -784,3 +792,378 @@ def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=N
 
         states = l_states + r_states
         return outputs, states
+
+
+class BaseConvRNNCell(HybridRecurrentCell):
+    """Abstract base class for convolutional RNNs"""
+    def __init__(self, hidden_size, conv_layout, input_shape,
+                 h2h_kernel, h2h_dilate,
+                 i2h_kernel, i2h_stride,
+                 i2h_pad, i2h_dilate,
+                 i2h_weight_initializer, h2h_weight_initializer,
+                 i2h_bias_initializer, h2h_bias_initializer,
+                 activation,
+                 prefix=None, params=None):
+        super(BaseConvRNNCell, self).__init__(prefix=prefix, params=params)
+
+        self._hidden_size = hidden_size
+        self._input_shape = input_shape
+        self._conv_layout = conv_layout
+        self.activation = activation
+
+        # Convolution setting
+        self._h2h_kernel = h2h_kernel
+        assert all(k % 2 == 1 for k in self._h2h_kernel), \
+            "Only odd numbers are supported as h2h_kernel sizes, got h2h_kernel=%s" % str(h2h_kernel)
+        self._h2h_pad = tuple([d*(k-1)//2 for d, k in zip(h2h_dilate, h2h_kernel)])
+        self._h2h_dilate = h2h_dilate
+        self._i2h_kernel = i2h_kernel
+        self._i2h_stride = i2h_stride
+        self._i2h_pad = i2h_pad
+        self._i2h_dilate = i2h_dilate
+
+        self._channel_axis = conv_layout.find('C')
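+        # Channel axis in conv_layout (1 for 'NCHW', 3 for 'NHWC'); input_shape
+        # excludes the batch axis, hence the -1 when indexing into it.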
+        in_filters = input_shape[self._channel_axis-1]
+        if self._channel_axis == 1:
+            i2h_param_shape = (hidden_size*self._num_gates, in_filters)+i2h_kernel
+            h2h_param_shape = (hidden_size*self._num_gates, hidden_size)+h2h_kernel
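+            # The state spatial size follows from the h2h convolution settings
+            # (stride 1, "same" padding), so it matches the input spatial size.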
+            self._state_shape = (hidden_size,)+_get_conv_out_size(input_shape[1:],
+                                                                  self._h2h_kernel,
+                                                                  self._h2h_pad,
+                                                                  tuple([1]*len(self._h2h_kernel)),
+                                                                  self._h2h_dilate)
+        else:
+            i2h_param_shape = (hidden_size*self._num_gates,)+i2h_kernel+(in_filters,)
+            h2h_param_shape = (hidden_size*self._num_gates,)+h2h_kernel+(hidden_size,)
+            self._state_shape = _get_conv_out_size(input_shape[:-1],
+                                                   self._h2h_kernel,
+                                                   self._h2h_pad,
+                                                   tuple([1]*len(self._h2h_kernel)),
+                                                   self._h2h_dilate) + (hidden_size,)
+
+        self.i2h_weight = self.params.get('i2h_weight', shape=i2h_param_shape,
+                                          init=i2h_weight_initializer,
+                                          allow_deferred_init=True)
+        self.h2h_weight = self.params.get('h2h_weight', shape=h2h_param_shape,
+                                          init=h2h_weight_initializer,
+                                          allow_deferred_init=True)
+        self.i2h_bias = self.params.get('i2h_bias', shape=(hidden_size*self._num_gates,),
+                                        init=i2h_bias_initializer,
+                                        allow_deferred_init=True)
+        self.h2h_bias = self.params.get('h2h_bias', shape=(hidden_size*self._num_gates,),
+                                        init=h2h_bias_initializer,
+                                        allow_deferred_init=True)
+
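+    # _gate_names is defined by each concrete subclass (e.g. 1 gate for a
+    # vanilla RNN, 3 for GRU, 4 for LSTM), so conv channel counts scale with it.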
+    @property
+    def _num_gates(self):
+        return len(self._gate_names)
+
+    def _conv_forward(self, F, inputs, states,
+                      i2h_weight, h2h_weight, i2h_bias, h2h_bias,
+                      prefix):
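+        # Input-to-state convolution; all gates are computed in a single pass,
+        # so the output has hidden_size * num_gates channels.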
+        i2h = F.Convolution(data=inputs,
+                            num_filter=self._hidden_size*self._num_gates,
+                            kernel=self._i2h_kernel,
+                            stride=self._i2h_stride,
+                            pad=self._i2h_pad,
+                            dilate=self._i2h_dilate,
+                            weight=i2h_weight,
+                            bias=i2h_bias,
+                            layout=self._conv_layout,
+                            name=prefix+'i2h')
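+        # State-to-state convolution; stride is fixed at 1 so the state keeps a
+        # constant spatial size across time steps.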
+        h2h = F.Convolution(data=states[0],
+                            num_filter=self._hidden_size*self._num_gates,
+                            kernel=self._h2h_kernel,
+                            dilate=self._h2h_dilate,
+                            pad=self._h2h_pad,
+                            stride=tuple([1]*len(self._h2h_kernel)),
+                            weight=h2h_weight,
+                            bias=h2h_bias,
+                            layout=self._conv_layout,
+                            name=prefix+'h2h')
+        return i2h, h2h
+
+    def state_info(self, batch_size=0):
+        raise NotImplementedError("BaseConvRNNCell is an abstract class for convolutional RNNs")
+
+    def hybrid_forward(self, F, inputs, states):
+        raise NotImplementedError("BaseConvRNNCell is an abstract class for convolutional RNNs")
+
+
+class ConvRNNCell(BaseConvRNNCell):
+    """Convolutional RNN cell.
+
+    Parameters
+    ----------
+    hidden_size : int
+        Number of units in output symbol.
+    conv_layout : str, default 'NCHW'
+        Layout of ConvolutionOp
+    input_shape : tuple, default (3, 224, 224)
+        Input data tensor shape, with the same layout as conv_layout, excluding batch size.
+    h2h_kernel : tuple of int, default (3, 3)
+        Kernel of Convolution operator in state-to-state transitions.
+    h2h_dilate : tuple of int, default (1, 1)
+        Dilation of Convolution operator in state-to-state transitions.
+    i2h_kernel : tuple of int, default (3, 3)
+        Kernel of Convolution operator in input-to-state transitions.
+    i2h_stride : tuple of int, default (1, 1)
+        Stride of Convolution operator in input-to-state transitions.
+    i2h_pad : tuple of int, default (1, 1)
+        Pad of Convolution operator in input-to-state transitions.
+    i2h_dilate : tuple of int, default (1, 1)
+        Dilation of Convolution operator in input-to-state transitions.
+    i2h_weight_initializer : str or Initializer
+        Initializer for the input weights matrix, used for the convolution
+        transformation of the inputs.
+    h2h_weight_initializer : str or Initializer
+        Initializer for the recurrent weights matrix, used for the convolution
+        transformation of the recurrent state.
+    i2h_bias_initializer : str or Initializer, default zeros
+        Initializer for the input-to-hidden bias vector.
+    h2h_bias_initializer : str or Initializer, default zeros
+        Initializer for the hidden-to-hidden bias vector.
+    activation : str or Block, default LeakyReLU(alpha=0.2)
+        Type of activation function.
+    prefix : str, default 'conv_rnn_'
+        Prefix for name of layers (and name of weight if params is None).
+    params : RNNParams, default None
+        Container for weight sharing between cells. Created if None.
+    """
+    def __init__(self, hidden_size, conv_layout='NCHW', input_shape=(3, 224, 224),
 
 Review comment:
   How do I defer the shape_info initialization? It depends on the input.
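
   For reference, the usual Gluon pattern for this is to write the unknown
   dimensions as 0 in the parameter shape and let allow_deferred_init=True
   resolve them from the first input. A minimal sketch of that mechanism
   (the MyDense block below is hypothetical, not code from this PR):

       from mxnet import gluon, nd

       class MyDense(gluon.HybridBlock):
           """Hypothetical block whose input width is unknown until runtime."""
           def __init__(self, units, **kwargs):
               super(MyDense, self).__init__(**kwargs)
               with self.name_scope():
                   # in_units is unknown at construction, so it is written as 0;
                   # Gluon infers it from the first input because
                   # allow_deferred_init=True.
                   self.weight = self.params.get('weight', shape=(units, 0),
                                                 allow_deferred_init=True)

           def hybrid_forward(self, F, x, weight):
               return F.dot(x, weight, transpose_b=True)

       net = MyDense(16)
       net.initialize()          # allocation is deferred until shapes are known
       net(nd.ones((2, 8)))      # first forward resolves weight to shape (16, 8)

   Note that deferral only covers Parameter shapes; anything precomputed in
   __init__ from input_shape (like _state_shape here) would still have to move
   to first-forward time.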
 
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services