You are viewing a plain-text version of this content; the hyperlink to the canonical version did not survive the plain-text conversion.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/10/15 20:44:26 UTC

[incubator-mxnet] branch master updated: fixed symbols naming in RNNCell, LSTMCell, GRUCell (#12794)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new b89a36d  fixed symbols naming in RNNCell, LSTMCell, GRUCell (#12794)
b89a36d is described below

commit b89a36d94b5b694b8fd926e6249f7490b38432f6
Author: Lorenzo Stella <lo...@users.noreply.github.com>
AuthorDate: Mon Oct 15 22:44:06 2018 +0200

    fixed symbols naming in RNNCell, LSTMCell, GRUCell (#12794)
    
    * fixed symbols naming in RNNCell and LSTMCell
    
    * fixed GRUCell as well
    
    * added test
    
    * fixed tests?
---
 python/mxnet/gluon/rnn/rnn_cell.py      | 25 +++++++++++------
 tests/python/unittest/test_gluon_rnn.py | 48 +++++++++++++++++++++++++++++++++
 2 files changed, 65 insertions(+), 8 deletions(-)

diff --git a/python/mxnet/gluon/rnn/rnn_cell.py b/python/mxnet/gluon/rnn/rnn_cell.py
index 557837c..0f16a89 100644
--- a/python/mxnet/gluon/rnn/rnn_cell.py
+++ b/python/mxnet/gluon/rnn/rnn_cell.py
@@ -398,7 +398,8 @@ class RNNCell(HybridRecurrentCell):
         h2h = F.FullyConnected(data=states[0], weight=h2h_weight, bias=h2h_bias,
                                num_hidden=self._hidden_size,
                                name=prefix+'h2h')
-        output = self._get_activation(F, i2h + h2h, self._activation,
+        i2h_plus_h2h = F.elemwise_add(i2h, h2h, name=prefix+'plus0')
+        output = self._get_activation(F, i2h_plus_h2h, self._activation,
                                       name=prefix+'out')
 
         return output, [output]
@@ -511,7 +512,7 @@ class LSTMCell(HybridRecurrentCell):
                                num_hidden=self._hidden_size*4, name=prefix+'i2h')
         h2h = F.FullyConnected(data=states[0], weight=h2h_weight, bias=h2h_bias,
                                num_hidden=self._hidden_size*4, name=prefix+'h2h')
-        gates = i2h + h2h
+        gates = F.elemwise_add(i2h, h2h, name=prefix+'plus0')
         slice_gates = F.SliceChannel(gates, num_outputs=4, name=prefix+'slice')
         in_gate = self._get_activation(
             F, slice_gates[0], self._recurrent_activation, name=prefix+'i')
@@ -521,9 +522,10 @@ class LSTMCell(HybridRecurrentCell):
             F, slice_gates[2], self._activation, name=prefix+'c')
         out_gate = self._get_activation(
             F, slice_gates[3], self._recurrent_activation, name=prefix+'o')
-        next_c = F._internal._plus(forget_gate * states[1], in_gate * in_transform,
+        next_c = F._internal._plus(F.elemwise_mul(forget_gate, states[1], name=prefix+'mul0'),
+                                   F.elemwise_mul(in_gate, in_transform, name=prefix+'mul1'),
                                    name=prefix+'state')
-        next_h = F._internal._mul(out_gate, F.Activation(next_c, act_type=self._activation),
+        next_h = F._internal._mul(out_gate, F.Activation(next_c, act_type=self._activation, name=prefix+'activation0'),
                                   name=prefix+'out')
 
         return next_h, [next_h, next_c]
@@ -635,15 +637,22 @@ class GRUCell(HybridRecurrentCell):
         h2h_r, h2h_z, h2h = F.SliceChannel(h2h, num_outputs=3,
                                            name=prefix+'h2h_slice')
 
-        reset_gate = F.Activation(i2h_r + h2h_r, act_type="sigmoid",
+        reset_gate = F.Activation(F.elemwise_add(i2h_r, h2h_r, name=prefix+'plus0'), act_type="sigmoid",
                                   name=prefix+'r_act')
-        update_gate = F.Activation(i2h_z + h2h_z, act_type="sigmoid",
+        update_gate = F.Activation(F.elemwise_add(i2h_z, h2h_z, name=prefix+'plus1'), act_type="sigmoid",
                                    name=prefix+'z_act')
 
-        next_h_tmp = F.Activation(i2h + reset_gate * h2h, act_type="tanh",
+        next_h_tmp = F.Activation(F.elemwise_add(i2h,
+                                                 F.elemwise_mul(reset_gate, h2h, name=prefix+'mul0'),
+                                                 name=prefix+'plus2'),
+                                  act_type="tanh",
                                   name=prefix+'h_act')
 
-        next_h = F._internal._plus((1. - update_gate) * next_h_tmp, update_gate * prev_state_h,
+        ones = F.ones_like(update_gate, name=prefix+"ones_like0")
+        next_h = F._internal._plus(F.elemwise_mul(F.elemwise_sub(ones, update_gate, name=prefix+'minus0'),
+                                                  next_h_tmp,
+                                                  name=prefix+'mul1'),
+                                   F.elemwise_mul(update_gate, prev_state_h, name=prefix+'mul20'),
                                    name=prefix+'out')
 
         return next_h, [next_h]
diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py
index 4e8241f..c1d5f6a 100644
--- a/tests/python/unittest/test_gluon_rnn.py
+++ b/tests/python/unittest/test_gluon_rnn.py
@@ -379,6 +379,54 @@ def test_rnn_cells():
     net.add(gluon.rnn.GRUCell(100, input_size=100))
     check_rnn_forward(net, mx.nd.ones((8, 3, 200)))
 
+
+def test_rnn_cells_export_import():
+    class RNNLayer(gluon.HybridBlock):
+        def __init__(self):
+            super(RNNLayer, self).__init__()
+            with self.name_scope():
+                self.cell = gluon.rnn.RNNCell(hidden_size=1)
+
+        def hybrid_forward(self, F, seq):
+            outputs, state = self.cell.unroll(inputs=seq, length=2, merge_outputs=True)
+            return outputs
+
+    class LSTMLayer(gluon.HybridBlock):
+        def __init__(self):
+            super(LSTMLayer, self).__init__()
+            with self.name_scope():
+                self.cell = gluon.rnn.LSTMCell(hidden_size=1)
+
+        def hybrid_forward(self, F, seq):
+            outputs, state = self.cell.unroll(inputs=seq, length=2, merge_outputs=True)
+            return outputs
+
+    class GRULayer(gluon.HybridBlock):
+        def __init__(self):
+            super(GRULayer, self).__init__()
+            with self.name_scope():
+                self.cell = gluon.rnn.GRUCell(hidden_size=1)
+
+        def hybrid_forward(self, F, seq):
+            outputs, state = self.cell.unroll(inputs=seq, length=2, merge_outputs=True)
+            return outputs
+
+    for hybrid in [RNNLayer(), LSTMLayer(), GRULayer()]:
+        hybrid.initialize()
+        hybrid.hybridize()
+        input = mx.nd.ones(shape=(1, 2, 1))
+        output1 = hybrid(input)
+        hybrid.export(path="./model", epoch=0)
+        symbol = mx.gluon.SymbolBlock.imports(
+            symbol_file="./model-symbol.json",
+            input_names=["data"],
+            param_file="./model-0000.params",
+            ctx=mx.Context.default_ctx
+        )
+        output2 = symbol(input)
+        assert_almost_equal(output1.asnumpy(), output2.asnumpy())
+
+
 def check_rnn_layer_forward(layer, inputs, states=None, run_only=False):
     layer.collect_params().initialize()
     inputs.attach_grad()