You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink was lost in the plain-text conversion).
Posted to commits@mxnet.apache.org by ha...@apache.org on 2020/01/16 18:47:48 UTC

[incubator-mxnet] branch v1.6.x updated: fix lstm layer with projection save params (#17266) (#17288)

This is an automated email from the ASF dual-hosted git repository.

haoj pushed a commit to branch v1.6.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.6.x by this push:
     new a003132  fix lstm layer with projection save params (#17266) (#17288)
a003132 is described below

commit a0031324174d8e6b44b633f8213fffe8b222acb2
Author: Frank Liu <fr...@gmail.com>
AuthorDate: Thu Jan 16 10:47:09 2020 -0800

    fix lstm layer with projection save params (#17266) (#17288)
    
    Co-authored-by: Sheng Zha <sz...@users.noreply.github.com>
---
 python/mxnet/gluon/rnn/rnn_layer.py | 2 +-
 tests/python/gpu/test_gluon_gpu.py  | 2 ++
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/python/mxnet/gluon/rnn/rnn_layer.py b/python/mxnet/gluon/rnn/rnn_layer.py
index 9807c5e..f4489b7 100644
--- a/python/mxnet/gluon/rnn/rnn_layer.py
+++ b/python/mxnet/gluon/rnn/rnn_layer.py
@@ -126,7 +126,7 @@ class _RNNLayer(HybridBlock):
     def _collect_params_with_prefix(self, prefix=''):
         if prefix:
             prefix += '.'
-        pattern = re.compile(r'(l|r)(\d)_(i2h|h2h)_(weight|bias)\Z')
+        pattern = re.compile(r'(l|r)(\d)_(i2h|h2h|h2r)_(weight|bias)\Z')
         def convert_key(m, bidirectional): # for compatibility with old parameter format
             d, l, g, t = [m.group(i) for i in range(1, 5)]
             if bidirectional:
diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py
index b938b57..64a8040 100644
--- a/tests/python/gpu/test_gluon_gpu.py
+++ b/tests/python/gpu/test_gluon_gpu.py
@@ -126,6 +126,8 @@ def test_lstmp():
     check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, bidirectional=True, dropout=0.5, projection_size=5),
                             mx.nd.ones((8, 3, 20)),
                             [mx.nd.ones((4, 3, 5)), mx.nd.ones((4, 3, 10))], run_only=True, ctx=ctx)
+    lstm_layer.save_parameters('gpu_tmp.params')
+    lstm_layer.load_parameters('gpu_tmp.params')
 
 
 @with_seed()