You are viewing a plain-text version of this content. The canonical link for it is here (the link target is not preserved in this plain-text version).
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/06/05 22:52:17 UTC

[GitHub] szha closed pull request #10957: Make inner transform activation configurable for LSTMCell

szha closed pull request #10957: Make inner transform activation configurable for LSTMCell
URL: https://github.com/apache/incubator-mxnet/pull/10957
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below, because GitHub would not otherwise display it:

diff --git a/python/mxnet/gluon/rnn/rnn_cell.py b/python/mxnet/gluon/rnn/rnn_cell.py
index 281aba45257..f318b10812a 100644
--- a/python/mxnet/gluon/rnn/rnn_cell.py
+++ b/python/mxnet/gluon/rnn/rnn_cell.py
@@ -224,6 +224,7 @@ def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=N
             The new state of this RNN after this unrolling.
             The type of this symbol is same as the output of `begin_state()`.
         """
+        # pylint: disable=too-many-locals
         self.reset()
 
         inputs, axis, F, batch_size = _format_sequence(length, inputs, layout, False)
@@ -251,12 +252,19 @@ def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=N
     #pylint: disable=no-self-use
     def _get_activation(self, F, inputs, activation, **kwargs):
         """Get activation function. Convert if is string"""
-        if isinstance(activation, string_types):
+        if activation == 'tanh':
+            return F.tanh(inputs, **kwargs)
+        elif activation == 'sigmoid':
+            return F.sigmoid(inputs, **kwargs)
+        elif activation == 'relu':
+            return F.relu(inputs, **kwargs)
+        elif activation == 'softsign':
+            return F.softsign(inputs, **kwargs)
+        elif isinstance(activation, string_types):
             return F.Activation(inputs, act_type=activation, **kwargs)
         elif isinstance(activation, LeakyReLU):
             return F.LeakyReLU(inputs, act_type='leaky', slope=activation._alpha, **kwargs)
-        else:
-            return activation(inputs, **kwargs)
+        return activation(inputs, **kwargs)
 
     def forward(self, inputs, states):
         """Unrolls the recurrent cell for one time step.
@@ -441,7 +449,12 @@ class LSTMCell(HybridRecurrentCell):
     params : Parameter or None
         Container for weight sharing between cells.
         Created if `None`.
-
+    activation : str
+        Activation type to use. See nd/symbol Activation
+        for supported types.
+    recurrent_activation : str
+        Activation type to use for the recurrent step. See nd/symbol Activation
+        for supported types.
 
     Inputs:
         - **data**: input tensor with shape `(batch_size, input_size)`.
@@ -453,10 +466,12 @@ class LSTMCell(HybridRecurrentCell):
         - **next_states**: a list of two output recurrent state tensors. Each has
           the same shape as `states`.
     """
+    # pylint: disable=too-many-instance-attributes
     def __init__(self, hidden_size,
                  i2h_weight_initializer=None, h2h_weight_initializer=None,
                  i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
-                 input_size=0, prefix=None, params=None):
+                 input_size=0, prefix=None, params=None, activation='tanh',
+                 recurrent_activation='sigmoid'):
         super(LSTMCell, self).__init__(prefix=prefix, params=params)
 
         self._hidden_size = hidden_size
@@ -473,6 +488,9 @@ def __init__(self, hidden_size,
         self.h2h_bias = self.params.get('h2h_bias', shape=(4*hidden_size,),
                                         init=h2h_bias_initializer,
                                         allow_deferred_init=True)
+        self._activation = activation
+        self._recurrent_activation = recurrent_activation
+
 
     def state_info(self, batch_size=0):
         return [{'shape': (batch_size, self._hidden_size), '__layout__': 'NC'},
@@ -491,6 +509,7 @@ def __repr__(self):
 
     def hybrid_forward(self, F, inputs, states, i2h_weight,
                        h2h_weight, i2h_bias, h2h_bias):
+        # pylint: disable=too-many-locals
         prefix = 't%d_'%self._counter
         i2h = F.FullyConnected(data=inputs, weight=i2h_weight, bias=i2h_bias,
                                num_hidden=self._hidden_size*4, name=prefix+'i2h')
@@ -498,13 +517,17 @@ def hybrid_forward(self, F, inputs, states, i2h_weight,
                                num_hidden=self._hidden_size*4, name=prefix+'h2h')
         gates = i2h + h2h
         slice_gates = F.SliceChannel(gates, num_outputs=4, name=prefix+'slice')
-        in_gate = F.Activation(slice_gates[0], act_type="sigmoid", name=prefix+'i')
-        forget_gate = F.Activation(slice_gates[1], act_type="sigmoid", name=prefix+'f')
-        in_transform = F.Activation(slice_gates[2], act_type="tanh", name=prefix+'c')
-        out_gate = F.Activation(slice_gates[3], act_type="sigmoid", name=prefix+'o')
+        in_gate = self._get_activation(
+            F, slice_gates[0], self._recurrent_activation, name=prefix+'i')
+        forget_gate = self._get_activation(
+            F, slice_gates[1], self._recurrent_activation, name=prefix+'f')
+        in_transform = self._get_activation(
+            F, slice_gates[2], self._activation, name=prefix+'c')
+        out_gate = self._get_activation(
+            F, slice_gates[3], self._recurrent_activation, name=prefix+'o')
         next_c = F._internal._plus(forget_gate * states[1], in_gate * in_transform,
                                    name=prefix+'state')
-        next_h = F._internal._mul(out_gate, F.Activation(next_c, act_type="tanh"),
+        next_h = F._internal._mul(out_gate, F.Activation(next_c, act_type=self._activation),
                                   name=prefix+'out')
 
         return next_h, [next_h, next_c]
@@ -675,6 +698,7 @@ def __call__(self, inputs, states):
 
     def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None,
                valid_length=None):
+        # pylint: disable=too-many-locals
         self.reset()
 
         inputs, _, F, batch_size = _format_sequence(length, inputs, layout, None)
@@ -702,6 +726,7 @@ def __len__(self):
         return len(self._children)
 
     def hybrid_forward(self, *args, **kwargs):
+        # pylint: disable=missing-docstring
         raise NotImplementedError
 
 
@@ -755,10 +780,9 @@ def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=N
         inputs, _, F, _ = _format_sequence(length, inputs, layout, merge_outputs)
         if isinstance(inputs, tensor_types):
             return self.hybrid_forward(F, inputs, begin_state if begin_state else [])
-        else:
-            return super(DropoutCell, self).unroll(
-                length, inputs, begin_state=begin_state, layout=layout,
-                merge_outputs=merge_outputs, valid_length=None)
+        return super(DropoutCell, self).unroll(
+            length, inputs, begin_state=begin_state, layout=layout,
+            merge_outputs=merge_outputs, valid_length=None)
 
 
 class ModifierCell(HybridRecurrentCell):
@@ -856,6 +880,7 @@ class ResidualCell(ModifierCell):
     """
 
     def __init__(self, base_cell):
+        # pylint: disable=useless-super-delegation
         super(ResidualCell, self).__init__(base_cell)
 
     def hybrid_forward(self, F, inputs, states):
@@ -924,6 +949,7 @@ def begin_state(self, **kwargs):
 
     def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None,
                valid_length=None):
+        # pylint: disable=too-many-locals
         self.reset()
 
         inputs, axis, F, batch_size = _format_sequence(length, inputs, layout, False)
diff --git a/tests/python/unittest/test_rnn.py b/tests/python/unittest/test_rnn.py
index 9fe22ae72df..52a3dcf9934 100644
--- a/tests/python/unittest/test_rnn.py
+++ b/tests/python/unittest/test_rnn.py
@@ -92,15 +92,19 @@ def test_rnn():
 
 
 def test_lstm():
-    cell = mx.rnn.LSTMCell(100, prefix='rnn_', forget_bias=1.0)
-    inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(3)]
-    outputs, _ = cell.unroll(3, inputs)
-    outputs = mx.sym.Group(outputs)
-    assert sorted(cell.params._params.keys()) == ['rnn_h2h_bias', 'rnn_h2h_weight', 'rnn_i2h_bias', 'rnn_i2h_weight']
-    assert outputs.list_outputs() == ['rnn_t0_out_output', 'rnn_t1_out_output', 'rnn_t2_out_output']
-
-    args, outs, auxs = outputs.infer_shape(rnn_t0_data=(10,50), rnn_t1_data=(10,50), rnn_t2_data=(10,50))
-    assert outs == [(10, 100), (10, 100), (10, 100)]
+    for activation_type in ['', 'relu', 'sigmoid', 'softrelu', 'tanh', 'softsign']:
+        if activation_type == '':
+            cell = mx.gluon.rnn.LSTMCell(100, prefix='rnn_')
+        else:
+            cell = mx.gluon.rnn.LSTMCell(100, prefix='rnn_', activation=activation_type, recurrent_activation=activation_type)
+        inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(3)]
+        outputs, _ = cell.unroll(3, inputs)
+        outputs = mx.sym.Group(outputs)
+        assert sorted(cell.collect_params().keys()) == ['rnn_h2h_bias', 'rnn_h2h_weight', 'rnn_i2h_bias', 'rnn_i2h_weight']
+        assert outputs.list_outputs() == ['rnn_t0_out_output', 'rnn_t1_out_output', 'rnn_t2_out_output']
+
+        args, outs, auxs = outputs.infer_shape(rnn_t0_data=(10,50), rnn_t1_data=(10,50), rnn_t2_data=(10,50))
+        assert outs == [(10, 100), (10, 100), (10, 100)]
 
 
 def test_lstm_forget_bias():


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log in to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services