You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/10/15 20:44:26 UTC
[incubator-mxnet] branch master updated: fixed symbols naming in RNNCell, LSTMCell, GRUCell (#12794)
This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new b89a36d fixed symbols naming in RNNCell, LSTMCell, GRUCell (#12794)
b89a36d is described below
commit b89a36d94b5b694b8fd926e6249f7490b38432f6
Author: Lorenzo Stella <lo...@users.noreply.github.com>
AuthorDate: Mon Oct 15 22:44:06 2018 +0200
fixed symbols naming in RNNCell, LSTMCell, GRUCell (#12794)
* fixed symbols naming in RNNCell and LSTMCell
* fixed GRUCell as well
* added test
* fixed tests?
---
python/mxnet/gluon/rnn/rnn_cell.py | 25 +++++++++++------
tests/python/unittest/test_gluon_rnn.py | 48 +++++++++++++++++++++++++++++++++
2 files changed, 65 insertions(+), 8 deletions(-)
diff --git a/python/mxnet/gluon/rnn/rnn_cell.py b/python/mxnet/gluon/rnn/rnn_cell.py
index 557837c..0f16a89 100644
--- a/python/mxnet/gluon/rnn/rnn_cell.py
+++ b/python/mxnet/gluon/rnn/rnn_cell.py
@@ -398,7 +398,8 @@ class RNNCell(HybridRecurrentCell):
h2h = F.FullyConnected(data=states[0], weight=h2h_weight, bias=h2h_bias,
num_hidden=self._hidden_size,
name=prefix+'h2h')
- output = self._get_activation(F, i2h + h2h, self._activation,
+ i2h_plus_h2h = F.elemwise_add(i2h, h2h, name=prefix+'plus0')
+ output = self._get_activation(F, i2h_plus_h2h, self._activation,
name=prefix+'out')
return output, [output]
@@ -511,7 +512,7 @@ class LSTMCell(HybridRecurrentCell):
num_hidden=self._hidden_size*4, name=prefix+'i2h')
h2h = F.FullyConnected(data=states[0], weight=h2h_weight, bias=h2h_bias,
num_hidden=self._hidden_size*4, name=prefix+'h2h')
- gates = i2h + h2h
+ gates = F.elemwise_add(i2h, h2h, name=prefix+'plus0')
slice_gates = F.SliceChannel(gates, num_outputs=4, name=prefix+'slice')
in_gate = self._get_activation(
F, slice_gates[0], self._recurrent_activation, name=prefix+'i')
@@ -521,9 +522,10 @@ class LSTMCell(HybridRecurrentCell):
F, slice_gates[2], self._activation, name=prefix+'c')
out_gate = self._get_activation(
F, slice_gates[3], self._recurrent_activation, name=prefix+'o')
- next_c = F._internal._plus(forget_gate * states[1], in_gate * in_transform,
+ next_c = F._internal._plus(F.elemwise_mul(forget_gate, states[1], name=prefix+'mul0'),
+ F.elemwise_mul(in_gate, in_transform, name=prefix+'mul1'),
name=prefix+'state')
- next_h = F._internal._mul(out_gate, F.Activation(next_c, act_type=self._activation),
+ next_h = F._internal._mul(out_gate, F.Activation(next_c, act_type=self._activation, name=prefix+'activation0'),
name=prefix+'out')
return next_h, [next_h, next_c]
@@ -635,15 +637,22 @@ class GRUCell(HybridRecurrentCell):
h2h_r, h2h_z, h2h = F.SliceChannel(h2h, num_outputs=3,
name=prefix+'h2h_slice')
- reset_gate = F.Activation(i2h_r + h2h_r, act_type="sigmoid",
+ reset_gate = F.Activation(F.elemwise_add(i2h_r, h2h_r, name=prefix+'plus0'), act_type="sigmoid",
name=prefix+'r_act')
- update_gate = F.Activation(i2h_z + h2h_z, act_type="sigmoid",
+ update_gate = F.Activation(F.elemwise_add(i2h_z, h2h_z, name=prefix+'plus1'), act_type="sigmoid",
name=prefix+'z_act')
- next_h_tmp = F.Activation(i2h + reset_gate * h2h, act_type="tanh",
+ next_h_tmp = F.Activation(F.elemwise_add(i2h,
+ F.elemwise_mul(reset_gate, h2h, name=prefix+'mul0'),
+ name=prefix+'plus2'),
+ act_type="tanh",
name=prefix+'h_act')
- next_h = F._internal._plus((1. - update_gate) * next_h_tmp, update_gate * prev_state_h,
+ ones = F.ones_like(update_gate, name=prefix+"ones_like0")
+ next_h = F._internal._plus(F.elemwise_mul(F.elemwise_sub(ones, update_gate, name=prefix+'minus0'),
+ next_h_tmp,
+ name=prefix+'mul1'),
+ F.elemwise_mul(update_gate, prev_state_h, name=prefix+'mul20'),
name=prefix+'out')
return next_h, [next_h]
diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py
index 4e8241f..c1d5f6a 100644
--- a/tests/python/unittest/test_gluon_rnn.py
+++ b/tests/python/unittest/test_gluon_rnn.py
@@ -379,6 +379,54 @@ def test_rnn_cells():
net.add(gluon.rnn.GRUCell(100, input_size=100))
check_rnn_forward(net, mx.nd.ones((8, 3, 200)))
+
+def test_rnn_cells_export_import():
+ class RNNLayer(gluon.HybridBlock):
+ def __init__(self):
+ super(RNNLayer, self).__init__()
+ with self.name_scope():
+ self.cell = gluon.rnn.RNNCell(hidden_size=1)
+
+ def hybrid_forward(self, F, seq):
+ outputs, state = self.cell.unroll(inputs=seq, length=2, merge_outputs=True)
+ return outputs
+
+ class LSTMLayer(gluon.HybridBlock):
+ def __init__(self):
+ super(LSTMLayer, self).__init__()
+ with self.name_scope():
+ self.cell = gluon.rnn.LSTMCell(hidden_size=1)
+
+ def hybrid_forward(self, F, seq):
+ outputs, state = self.cell.unroll(inputs=seq, length=2, merge_outputs=True)
+ return outputs
+
+ class GRULayer(gluon.HybridBlock):
+ def __init__(self):
+ super(GRULayer, self).__init__()
+ with self.name_scope():
+ self.cell = gluon.rnn.GRUCell(hidden_size=1)
+
+ def hybrid_forward(self, F, seq):
+ outputs, state = self.cell.unroll(inputs=seq, length=2, merge_outputs=True)
+ return outputs
+
+ for hybrid in [RNNLayer(), LSTMLayer(), GRULayer()]:
+ hybrid.initialize()
+ hybrid.hybridize()
+ input = mx.nd.ones(shape=(1, 2, 1))
+ output1 = hybrid(input)
+ hybrid.export(path="./model", epoch=0)
+ symbol = mx.gluon.SymbolBlock.imports(
+ symbol_file="./model-symbol.json",
+ input_names=["data"],
+ param_file="./model-0000.params",
+ ctx=mx.Context.default_ctx
+ )
+ output2 = symbol(input)
+ assert_almost_equal(output1.asnumpy(), output2.asnumpy())
+
+
def check_rnn_layer_forward(layer, inputs, states=None, run_only=False):
layer.collect_params().initialize()
inputs.attach_grad()