You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2017/12/12 22:31:34 UTC

[GitHub] piiswrong closed pull request #8945: Add param 'num_filter' for 'BaseConvRNNCell'.

piiswrong closed pull request #8945: Add param 'num_filter' for 'BaseConvRNNCell'.
URL: https://github.com/apache/incubator-mxnet/pull/8945
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/python/mxnet/rnn/rnn_cell.py b/python/mxnet/rnn/rnn_cell.py
index 3301102ba9..7b08c9bf1e 100644
--- a/python/mxnet/rnn/rnn_cell.py
+++ b/python/mxnet/rnn/rnn_cell.py
@@ -1093,7 +1093,7 @@ def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=N
 
 class BaseConvRNNCell(BaseRNNCell):
     """Abstract base class for Convolutional RNN cells"""
-    def __init__(self, input_shape, num_hidden,
+    def __init__(self, input_shape, num_hidden,num_filter,
                  h2h_kernel, h2h_dilate,
                  i2h_kernel, i2h_stride,
                  i2h_pad, i2h_dilate,
@@ -1112,6 +1112,7 @@ def __init__(self, input_shape, num_hidden,
         self._i2h_stride = i2h_stride
         self._i2h_pad = i2h_pad
         self._i2h_dilate = i2h_dilate
+        self._num_filter = num_filter
 
         self._num_hidden = num_hidden
         self._input_shape = input_shape
@@ -1121,7 +1122,7 @@ def __init__(self, input_shape, num_hidden,
         # Infer state shape
         data = symbol.Variable('data')
         self._state_shape = symbol.Convolution(data=data,
-                                               num_filter=self._num_hidden,
+                                               num_filter=self._num_filter,
                                                kernel=self._i2h_kernel,
                                                stride=self._i2h_stride,
                                                pad=self._i2h_pad,
@@ -1149,7 +1150,7 @@ def _conv_forward(self, inputs, states, name):
 
         i2h = symbol.Convolution(name='%si2h'%name,
                                  data=inputs,
-                                 num_filter=self._num_hidden*self._num_gates,
+                                 num_filter=self._num_filter*self._num_gates,
                                  kernel=self._i2h_kernel,
                                  stride=self._i2h_stride,
                                  pad=self._i2h_pad,
@@ -1160,7 +1161,7 @@ def _conv_forward(self, inputs, states, name):
 
         h2h = symbol.Convolution(name='%sh2h'%name,
                                  data=states[0],
-                                 num_filter=self._num_hidden*self._num_gates,
+                                 num_filter=self._num_filter*self._num_gates,
                                  kernel=self._h2h_kernel,
                                  dilate=self._h2h_dilate,
                                  pad=self._h2h_pad,
@@ -1182,6 +1183,8 @@ class ConvRNNCell(BaseConvRNNCell):
         Shape of input in single timestep.
     num_hidden : int
         Number of units in output symbol.
+    num_filter : int
+        Number of convolution filters (output channels).
     h2h_kernel : tuple of int, default (3, 3)
         Kernel of Convolution operator in state-to-state transitions.
     h2h_dilate : tuple of int, default (1, 1)
@@ -1214,7 +1217,7 @@ class ConvRNNCell(BaseConvRNNCell):
     conv_layout : str, , default 'NCHW'
         Layout of ConvolutionOp
     """
-    def __init__(self, input_shape, num_hidden,
+    def __init__(self, input_shape, num_hidden,num_filter,
                  h2h_kernel=(3, 3), h2h_dilate=(1, 1),
                  i2h_kernel=(3, 3), i2h_stride=(1, 1),
                  i2h_pad=(1, 1), i2h_dilate=(1, 1),
@@ -1222,7 +1225,7 @@ def __init__(self, input_shape, num_hidden,
                  i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
                  activation=functools.partial(symbol.LeakyReLU, act_type='leaky', slope=0.2),
                  prefix='ConvRNN_', params=None, conv_layout='NCHW'):
-        super(ConvRNNCell, self).__init__(input_shape=input_shape, num_hidden=num_hidden,
+        super(ConvRNNCell, self).__init__(input_shape=input_shape, num_hidden=num_hidden,num_filter=num_filter,
                                           h2h_kernel=h2h_kernel, h2h_dilate=h2h_dilate,
                                           i2h_kernel=i2h_kernel, i2h_stride=i2h_stride,
                                           i2h_pad=i2h_pad, i2h_dilate=i2h_dilate,
@@ -1262,6 +1265,8 @@ class ConvLSTMCell(BaseConvRNNCell):
         Shape of input in single timestep.
     num_hidden : int
         Number of units in output symbol.
+    num_filter : int
+        Number of convolution filters (output channels).
     h2h_kernel : tuple of int, default (3, 3)
         Kernel of Convolution operator in state-to-state transitions.
     h2h_dilate : tuple of int, default (1, 1)
@@ -1294,7 +1299,7 @@ class ConvLSTMCell(BaseConvRNNCell):
     conv_layout : str, , default 'NCHW'
         Layout of ConvolutionOp
     """
-    def __init__(self, input_shape, num_hidden,
+    def __init__(self, input_shape, num_hidden,num_filter,
                  h2h_kernel=(3, 3), h2h_dilate=(1, 1),
                  i2h_kernel=(3, 3), i2h_stride=(1, 1),
                  i2h_pad=(1, 1), i2h_dilate=(1, 1),
@@ -1303,7 +1308,7 @@ def __init__(self, input_shape, num_hidden,
                  activation=functools.partial(symbol.LeakyReLU, act_type='leaky', slope=0.2),
                  prefix='ConvLSTM_', params=None,
                  conv_layout='NCHW'):
-        super(ConvLSTMCell, self).__init__(input_shape=input_shape, num_hidden=num_hidden,
+        super(ConvLSTMCell, self).__init__(input_shape=input_shape, num_hidden=num_hidden,num_filter=num_filter,
                                            h2h_kernel=h2h_kernel, h2h_dilate=h2h_dilate,
                                            i2h_kernel=i2h_kernel, i2h_stride=i2h_stride,
                                            i2h_pad=i2h_pad, i2h_dilate=i2h_dilate,
@@ -1354,6 +1359,8 @@ class ConvGRUCell(BaseConvRNNCell):
         Shape of input in single timestep.
     num_hidden : int
         Number of units in output symbol.
+    num_filter : int
+        Number of convolution filters (output channels).
     h2h_kernel : tuple of int, default (3, 3)
         Kernel of Convolution operator in state-to-state transitions.
     h2h_dilate : tuple of int, default (1, 1)
@@ -1386,7 +1393,7 @@ class ConvGRUCell(BaseConvRNNCell):
     conv_layout : str, , default 'NCHW'
         Layout of ConvolutionOp
     """
-    def __init__(self, input_shape, num_hidden,
+    def __init__(self, input_shape, num_hidden,num_filter,
                  h2h_kernel=(3, 3), h2h_dilate=(1, 1),
                  i2h_kernel=(3, 3), i2h_stride=(1, 1),
                  i2h_pad=(1, 1), i2h_dilate=(1, 1),
@@ -1394,7 +1401,7 @@ def __init__(self, input_shape, num_hidden,
                  i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
                  activation=functools.partial(symbol.LeakyReLU, act_type='leaky', slope=0.2),
                  prefix='ConvGRU_', params=None, conv_layout='NCHW'):
-        super(ConvGRUCell, self).__init__(input_shape=input_shape, num_hidden=num_hidden,
+        super(ConvGRUCell, self).__init__(input_shape=input_shape, num_hidden=num_hidden,num_filter=num_filter,
                                           h2h_kernel=h2h_kernel, h2h_dilate=h2h_dilate,
                                           i2h_kernel=i2h_kernel, i2h_stride=i2h_stride,
                                           i2h_pad=i2h_pad, i2h_dilate=i2h_dilate,


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services