You are viewing a plain-text version of this content; the canonical link to it was not preserved in this copy.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2017/11/30 19:54:03 UTC

[GitHub] piiswrong closed pull request #8512: gluon rnn refactor

piiswrong closed pull request #8512: gluon rnn refactor
URL: https://github.com/apache/incubator-mxnet/pull/8512
 
 
   

This pull request was merged from a forked repository.
Because GitHub hides the original diff once a pull request is merged, the diff
is reproduced below for the sake of provenance:

Because this pull request comes from a fork, the diff is supplied below;
GitHub would not display it otherwise:

diff --git a/python/mxnet/gluon/contrib/rnn/conv_rnn_cell.py b/python/mxnet/gluon/contrib/rnn/conv_rnn_cell.py
index 09db5470ef..d0c8daa3ba 100644
--- a/python/mxnet/gluon/contrib/rnn/conv_rnn_cell.py
+++ b/python/mxnet/gluon/contrib/rnn/conv_rnn_cell.py
@@ -26,6 +26,8 @@
 from math import floor
 
 from ....base import numeric_types
+from .... import symbol
+from ...nn import HybridLambda
 from ...rnn import HybridRecurrentCell
 
 
@@ -50,6 +52,8 @@ def __init__(self, input_shape, hidden_channels,
         self._input_shape = input_shape
         self._conv_layout = conv_layout
         self._activation = activation
+        with self.name_scope():
+            self.activation = HybridLambda(activation, prefix='')
 
         # Convolution setting
         assert all(isinstance(spec, int) or len(spec) == dims
@@ -210,8 +214,10 @@ def hybrid_forward(self, F, inputs, states, i2h_weight,
         i2h, h2h = self._conv_forward(F, inputs, states,
                                       i2h_weight, h2h_weight, i2h_bias, h2h_bias,
                                       prefix)
-        output = self._get_activation(F, i2h + h2h, self._activation,
-                                      name=prefix+'out')
+        if F is symbol:
+            output = self.activation(i2h + h2h, prefix+'out')
+        else:
+            output = self.activation(i2h + h2h)
         return output, [output]
 
 
@@ -455,17 +461,17 @@ def hybrid_forward(self, F, inputs, states, i2h_weight,
         i2h, h2h = self._conv_forward(F, inputs, states,
                                       i2h_weight, h2h_weight, i2h_bias, h2h_bias,
                                       prefix)
-        gates = i2h + h2h
-        slice_gates = F.SliceChannel(gates, num_outputs=4, name=prefix+'slice',
-                                     axis=self._channel_axis)
-        in_gate = F.Activation(slice_gates[0], act_type="sigmoid", name=prefix+'i')
-        forget_gate = F.Activation(slice_gates[1], act_type="sigmoid", name=prefix+'f')
-        in_transform = self._get_activation(F, slice_gates[2], self._activation, name=prefix+'c')
-        out_gate = F.Activation(slice_gates[3], act_type="sigmoid", name=prefix+'o')
-        next_c = F._internal._plus(forget_gate * states[1], in_gate * in_transform,
-                                   name=prefix+'state')
-        next_h = F._internal._mul(out_gate, self._get_activation(F, next_c, self._activation),
-                                  name=prefix+'out')
+        i, f, c, o = (i2h + h2h).split(4, axis=self._channel_axis)
+        i = i.sigmoid(name=prefix+'i')
+        f = f.sigmoid(name=prefix+'f')
+        o = o.sigmoid(name=prefix+'o')
+        if F is symbol:
+            c = self.activation(c, prefix+'c')
+        else:
+            c = self.activation(c)
+
+        next_c = F._internal._plus(f * states[1], i * c, name=prefix+'state')
+        next_h = F._internal._mul(o, self.activation(next_c), name=prefix+'out')
 
         return next_h, [next_h, next_c]
 
@@ -738,20 +744,22 @@ def hybrid_forward(self, F, inputs, states, i2h_weight,
                                       i2h_weight, h2h_weight, i2h_bias, h2h_bias,
                                       prefix)
 
-        i2h_r, i2h_z, i2h = F.SliceChannel(i2h, num_outputs=3,
-                                           name=prefix+'i2h_slice',
-                                           axis=self._channel_axis)
-        h2h_r, h2h_z, h2h = F.SliceChannel(h2h, num_outputs=3,
-                                           name=prefix+'h2h_slice',
-                                           axis=self._channel_axis)
+        i2h_r, i2h_z, i2h = i2h.split(num_outputs=3,
+                                      name=prefix+'i2h_slice',
+                                      axis=self._channel_axis)
+        h2h_r, h2h_z, h2h = h2h.split(num_outputs=3,
+                                      name=prefix+'h2h_slice',
+                                      axis=self._channel_axis)
 
         reset_gate = F.Activation(i2h_r + h2h_r, act_type="sigmoid",
                                   name=prefix+'r_act')
         update_gate = F.Activation(i2h_z + h2h_z, act_type="sigmoid",
                                    name=prefix+'z_act')
 
-        next_h_tmp = self._get_activation(F, i2h + reset_gate * h2h, self._activation,
-                                          name=prefix+'h_act')
+        if F is symbol:
+            next_h_tmp = self.activation(i2h + reset_gate * h2h, prefix+'h_act')
+        else:
+            next_h_tmp = self.activation(i2h + reset_gate * h2h)
 
         next_h = F._internal._plus((1. - update_gate) * next_h_tmp, update_gate * states[0],
                                    name=prefix+'out')
diff --git a/python/mxnet/gluon/rnn/rnn_cell.py b/python/mxnet/gluon/rnn/rnn_cell.py
index 80bb8e3fb8..6b48547b0e 100644
--- a/python/mxnet/gluon/rnn/rnn_cell.py
+++ b/python/mxnet/gluon/rnn/rnn_cell.py
@@ -31,7 +31,7 @@
 from ..block import Block, HybridBlock
 from ..utils import _indent
 from .. import tensor_types
-from ..nn import LeakyReLU
+from ..nn import Dense
 
 
 def _cells_state_info(cells, batch_size):
@@ -116,7 +116,8 @@ def reset(self):
         self._init_counter = -1
         self._counter = -1
         for cell in self._children:
-            cell.reset()
+            if hasattr(cell, 'reset') and callable(cell.reset):
+                cell.reset()
 
     def state_info(self, batch_size=0):
         """shape and layout information of states"""
@@ -270,7 +271,40 @@ def hybrid_forward(self, F, x, *args, **kwargs):
         raise NotImplementedError
 
 
-class RNNCell(HybridRecurrentCell):
+# pylint: disable=abstract-method
+class _GatedRecurrentCell(HybridRecurrentCell):
+    def __init__(self, hidden_size,
+                 i2h_weight_initializer, h2h_weight_initializer,
+                 i2h_bias_initializer, h2h_bias_initializer,
+                 input_size, num_gates, num_recurrent_channels, alias,
+                 prefix, params):
+        super(_GatedRecurrentCell, self).__init__(prefix=prefix, params=params)
+        self._hidden_size = hidden_size
+        self._input_size = input_size
+        self._num_gates = num_gates
+        self._num_recurrent_channels = num_recurrent_channels
+        self._alias = alias
+        with self.name_scope():
+            self.i2h = Dense(num_gates*hidden_size,
+                             in_units=input_size,
+                             weight_initializer=i2h_weight_initializer,
+                             bias_initializer=i2h_bias_initializer,
+                             prefix='i2h_')
+            self.h2h = Dense(num_gates*hidden_size,
+                             in_units=hidden_size,
+                             weight_initializer=h2h_weight_initializer,
+                             bias_initializer=h2h_bias_initializer,
+                             prefix='h2h_')
+
+    def state_info(self, batch_size=0):
+        return [{'shape': (batch_size, self._hidden_size),
+                 '__layout__': 'NC'}] * self._num_recurrent_channels
+
+    def gate_forward(self, inputs, states):
+        return (self.i2h(inputs) + self.h2h(states[0])).split(self._num_gates)
+
+
+class RNNCell(_GatedRecurrentCell):
     r"""Elman RNN recurrent neural network cell.
 
     Each call computes the following function:
@@ -287,7 +321,7 @@ class RNNCell(HybridRecurrentCell):
     ----------
     hidden_size : int
         Number of units in output symbol
-    activation : str or Symbol, default 'tanh'
+    activation : str, HybridBlock, default 'tanh'
         Type of activation function.
     i2h_weight_initializer : str or Initializer
         Initializer for the input weights matrix, used for the linear
@@ -321,56 +355,33 @@ def __init__(self, hidden_size, activation='tanh',
                  i2h_weight_initializer=None, h2h_weight_initializer=None,
                  i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
                  input_size=0, prefix=None, params=None):
-        super(RNNCell, self).__init__(prefix=prefix, params=params)
-        self._hidden_size = hidden_size
+        super(RNNCell, self).__init__(hidden_size,
+                                      i2h_weight_initializer, h2h_weight_initializer,
+                                      i2h_bias_initializer, h2h_bias_initializer,
+                                      input_size, 1, 1, 'rnn', prefix, params)
         self._activation = activation
-        self._input_size = input_size
-        self.i2h_weight = self.params.get('i2h_weight', shape=(hidden_size, input_size),
-                                          dtype=None, init=i2h_weight_initializer,
-                                          allow_deferred_init=True)
-        self.h2h_weight = self.params.get('h2h_weight', shape=(hidden_size, hidden_size),
-                                          dtype=None, init=h2h_weight_initializer,
-                                          allow_deferred_init=True)
-        self.i2h_bias = self.params.get('i2h_bias', shape=(hidden_size,),
-                                        dtype=None, init=i2h_bias_initializer,
-                                        allow_deferred_init=True)
-        self.h2h_bias = self.params.get('h2h_bias', shape=(hidden_size,),
-                                        dtype=None, init=h2h_bias_initializer,
-                                        allow_deferred_init=True)
-
-    def state_info(self, batch_size=0):
-        return [{'shape': (batch_size, self._hidden_size), '__layout__': 'NC'}]
-
-    def _alias(self):
-        return 'rnn'
 
     def __repr__(self):
         s = '{name}({mapping}'
         if hasattr(self, '_activation'):
             s += ', {_activation}'
         s += ')'
-        shape = self.i2h_weight.shape
+        shape = self.i2h.weight.shape
         mapping = '{0} -> {1}'.format(shape[1] if shape[1] else None, shape[0])
         return s.format(name=self.__class__.__name__,
                         mapping=mapping,
                         **self.__dict__)
 
-    def hybrid_forward(self, F, inputs, states, i2h_weight,
-                       h2h_weight, i2h_bias, h2h_bias):
+    def hybrid_forward(self, F, inputs, states):
         prefix = 't%d_'%self._counter
-        i2h = F.FullyConnected(data=inputs, weight=i2h_weight, bias=i2h_bias,
-                               num_hidden=self._hidden_size,
-                               name=prefix+'i2h')
-        h2h = F.FullyConnected(data=states[0], weight=h2h_weight, bias=h2h_bias,
-                               num_hidden=self._hidden_size,
-                               name=prefix+'h2h')
-        output = self._get_activation(F, i2h + h2h, self._activation,
+        output = self._get_activation(F, self.gate_forward(inputs, states),
+                                      self._activation,
                                       name=prefix+'out')
 
         return output, [output]
 
 
-class LSTMCell(HybridRecurrentCell):
+class LSTMCell(_GatedRecurrentCell):
     r"""Long-Short Term Memory (LSTM) network cell.
 
     Each call computes the following function:
@@ -429,60 +440,35 @@ def __init__(self, hidden_size,
                  i2h_weight_initializer=None, h2h_weight_initializer=None,
                  i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
                  input_size=0, prefix=None, params=None):
-        super(LSTMCell, self).__init__(prefix=prefix, params=params)
-
-        self._hidden_size = hidden_size
-        self._input_size = input_size
-        self.i2h_weight = self.params.get('i2h_weight', shape=(4*hidden_size, input_size),
-                                          dtype=None, init=i2h_weight_initializer,
-                                          allow_deferred_init=True)
-        self.h2h_weight = self.params.get('h2h_weight', shape=(4*hidden_size, hidden_size),
-                                          dtype=None, init=h2h_weight_initializer,
-                                          allow_deferred_init=True)
-        self.i2h_bias = self.params.get('i2h_bias', shape=(4*hidden_size,),
-                                        dtype=None, init=i2h_bias_initializer,
-                                        allow_deferred_init=True)
-        self.h2h_bias = self.params.get('h2h_bias', shape=(4*hidden_size,),
-                                        dtype=None, init=h2h_bias_initializer,
-                                        allow_deferred_init=True)
-
-    def state_info(self, batch_size=0):
-        return [{'shape': (batch_size, self._hidden_size), '__layout__': 'NC'},
-                {'shape': (batch_size, self._hidden_size), '__layout__': 'NC'}]
-
-    def _alias(self):
-        return 'lstm'
+        super(LSTMCell, self).__init__(hidden_size,
+                                       i2h_weight_initializer, h2h_weight_initializer,
+                                       i2h_bias_initializer, h2h_bias_initializer,
+                                       input_size, 4, 2, 'lstm', prefix, params)
 
     def __repr__(self):
         s = '{name}({mapping})'
-        shape = self.i2h_weight.shape
+        shape = self.i2h.weight.shape
         mapping = '{0} -> {1}'.format(shape[1] if shape[1] else None, shape[0])
         return s.format(name=self.__class__.__name__,
                         mapping=mapping,
                         **self.__dict__)
 
-    def hybrid_forward(self, F, inputs, states, i2h_weight,
-                       h2h_weight, i2h_bias, h2h_bias):
+    def hybrid_forward(self, F, inputs, states):
         prefix = 't%d_'%self._counter
-        i2h = F.FullyConnected(data=inputs, weight=i2h_weight, bias=i2h_bias,
-                               num_hidden=self._hidden_size*4, name=prefix+'i2h')
-        h2h = F.FullyConnected(data=states[0], weight=h2h_weight, bias=h2h_bias,
-                               num_hidden=self._hidden_size*4, name=prefix+'h2h')
-        gates = i2h + h2h
-        slice_gates = F.SliceChannel(gates, num_outputs=4, name=prefix+'slice')
-        in_gate = F.Activation(slice_gates[0], act_type="sigmoid", name=prefix+'i')
-        forget_gate = F.Activation(slice_gates[1], act_type="sigmoid", name=prefix+'f')
-        in_transform = F.Activation(slice_gates[2], act_type="tanh", name=prefix+'c')
-        out_gate = F.Activation(slice_gates[3], act_type="sigmoid", name=prefix+'o')
-        next_c = F._internal._plus(forget_gate * states[1], in_gate * in_transform,
+        i, f, c, o = self.gate_forward(inputs, states)
+        i = i.sigmoid(name=prefix+'i')
+        f = f.sigmoid(name=prefix+'f')
+        c = c.tanh(name=prefix+'c')
+        o = o.sigmoid(name=prefix+'o')
+        next_c = F._internal._plus(f * states[1], i * c,
                                    name=prefix+'state')
-        next_h = F._internal._mul(out_gate, F.Activation(next_c, act_type="tanh"),
+        next_h = F._internal._mul(o, next_c.tanh(),
                                   name=prefix+'out')
 
         return next_h, [next_h, next_c]
 
 
-class GRUCell(HybridRecurrentCell):
+class GRUCell(_GatedRecurrentCell):
     r"""Gated Rectified Unit (GRU) network cell.
     Note: this is an implementation of the cuDNN version of GRUs
     (slight modification compared to Cho et al. 2014).
@@ -537,64 +523,30 @@ def __init__(self, hidden_size,
                  i2h_weight_initializer=None, h2h_weight_initializer=None,
                  i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
                  input_size=0, prefix=None, params=None):
-        super(GRUCell, self).__init__(prefix=prefix, params=params)
-        self._hidden_size = hidden_size
-        self._input_size = input_size
-        self.i2h_weight = self.params.get('i2h_weight', shape=(3*hidden_size, input_size),
-                                          dtype=None, init=i2h_weight_initializer,
-                                          allow_deferred_init=True)
-        self.h2h_weight = self.params.get('h2h_weight', shape=(3*hidden_size, hidden_size),
-                                          dtype=None, init=h2h_weight_initializer,
-                                          allow_deferred_init=True)
-        self.i2h_bias = self.params.get('i2h_bias', shape=(3*hidden_size,),
-                                        dtype=None, init=i2h_bias_initializer,
-                                        allow_deferred_init=True)
-        self.h2h_bias = self.params.get('h2h_bias', shape=(3*hidden_size,),
-                                        dtype=None, init=h2h_bias_initializer,
-                                        allow_deferred_init=True)
-
-    def state_info(self, batch_size=0):
-        return [{'shape': (batch_size, self._hidden_size), '__layout__': 'NC'}]
-
-    def _alias(self):
-        return 'gru'
+        super(GRUCell, self).__init__(hidden_size,
+                                      i2h_weight_initializer, h2h_weight_initializer,
+                                      i2h_bias_initializer, h2h_bias_initializer,
+                                      input_size, 3, 1, 'gru', prefix, params)
 
     def __repr__(self):
         s = '{name}({mapping})'
-        shape = self.i2h_weight.shape
+        shape = self.i2h.weight.shape
         mapping = '{0} -> {1}'.format(shape[1] if shape[1] else None, shape[0])
         return s.format(name=self.__class__.__name__,
                         mapping=mapping,
                         **self.__dict__)
 
-    def hybrid_forward(self, F, inputs, states, i2h_weight,
-                       h2h_weight, i2h_bias, h2h_bias):
-        # pylint: disable=too-many-locals
+    def hybrid_forward(self, F, inputs, states):
         prefix = 't%d_'%self._counter
         prev_state_h = states[0]
-        i2h = F.FullyConnected(data=inputs,
-                               weight=i2h_weight,
-                               bias=i2h_bias,
-                               num_hidden=self._hidden_size * 3,
-                               name=prefix+'i2h')
-        h2h = F.FullyConnected(data=prev_state_h,
-                               weight=h2h_weight,
-                               bias=h2h_bias,
-                               num_hidden=self._hidden_size * 3,
-                               name=prefix+'h2h')
-
-        i2h_r, i2h_z, i2h = F.SliceChannel(i2h, num_outputs=3,
-                                           name=prefix+'i2h_slice')
-        h2h_r, h2h_z, h2h = F.SliceChannel(h2h, num_outputs=3,
-                                           name=prefix+'h2h_slice')
-
-        reset_gate = F.Activation(i2h_r + h2h_r, act_type="sigmoid",
-                                  name=prefix+'r_act')
-        update_gate = F.Activation(i2h_z + h2h_z, act_type="sigmoid",
-                                   name=prefix+'z_act')
-
-        next_h_tmp = F.Activation(i2h + reset_gate * h2h, act_type="tanh",
-                                  name=prefix+'h_act')
+
+        i2h_r, i2h_z, i2h = self.i2h(inputs).split(3, name=prefix+'i2h_slice')
+        h2h_r, h2h_z, h2h = self.h2h(prev_state_h).split(3, name=prefix+'h2h_slice')
+
+        reset_gate = (i2h_r+h2h_r).sigmoid(name=prefix+'r_act')
+        update_gate = (i2h_z+h2h_z).sigmoid(name=prefix+'z_act')
+
+        next_h_tmp = (i2h + reset_gate * h2h).tanh(name=prefix+'h_act')
 
         next_h = F._internal._plus((1. - update_gate) * next_h_tmp, update_gate * prev_state_h,
                                    name=prefix+'out')
diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py
index 2288842192..ac38cc5349 100644
--- a/tests/python/unittest/test_gluon_rnn.py
+++ b/tests/python/unittest/test_gluon_rnn.py
@@ -25,6 +25,7 @@
 
 def test_rnn():
     cell = gluon.rnn.RNNCell(100, prefix='rnn_')
+    print(cell)
     inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(3)]
     outputs, _ = cell.unroll(3, inputs)
     outputs = mx.sym.Group(outputs)
@@ -37,6 +38,7 @@ def test_rnn():
 
 def test_lstm():
     cell = gluon.rnn.LSTMCell(100, prefix='rnn_')
+    print(cell)
     inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(3)]
     outputs, _ = cell.unroll(3, inputs)
     outputs = mx.sym.Group(outputs)
@@ -52,6 +54,7 @@ def test_lstm_forget_bias():
     stack = gluon.rnn.SequentialRNNCell()
     stack.add(gluon.rnn.LSTMCell(100, i2h_bias_initializer=mx.init.LSTMBias(forget_bias), prefix='l0_'))
     stack.add(gluon.rnn.LSTMCell(100, i2h_bias_initializer=mx.init.LSTMBias(forget_bias), prefix='l1_'))
+    print(stack)
 
     dshape = (32, 1, 200)
     data = mx.sym.Variable('data')
@@ -70,6 +73,7 @@ def test_lstm_forget_bias():
 
 def test_gru():
     cell = gluon.rnn.GRUCell(100, prefix='rnn_')
+    print(cell)
     inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(3)]
     outputs, _ = cell.unroll(3, inputs)
     outputs = mx.sym.Group(outputs)
@@ -278,7 +282,7 @@ def test_cell_fill_shape():
     cell = gluon.rnn.LSTMCell(10)
     cell.hybridize()
     check_rnn_forward(cell, mx.nd.ones((2, 3, 7)))
-    assert cell.i2h_weight.shape[1] == 7, cell.i2h_weight.shape[1]
+    assert cell.i2h.weight.shape[1] == 7, cell.i2h.weight.shape[1]
 
 def test_layer_fill_shape():
     layer = gluon.rnn.LSTM(10)


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to this message, please log in to GitHub and follow the
URL above to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services