Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/08/20 21:43:05 UTC

[incubator-mxnet] branch master updated: Fix a bug in SequentialRNNCell.reset() (#7449)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 0efc326  Fix a bug in SequentialRNNCell.reset() (#7449)
0efc326 is described below

commit 0efc326e2243625d622a43287bf15c62e6afd1b0
Author: Ziyue Huang <zy...@gmail.com>
AuthorDate: Mon Aug 21 05:43:02 2017 +0800

    Fix a bug in SequentialRNNCell.reset() (#7449)
    
    * remove self-implemented speedometer
    
    * fix bug in SequentialRNNCell.reset
    
    * Revert "remove self-implemented speedometer"
    
    This reverts commit 17aa4c0887c099f22c4769de079ef0130ed5f3e8.
    
    * fix lint
    
    * fix
    
    * fix reset in origin rnn and gluon rnn
    
    * fix origin rnn
---
 python/mxnet/gluon/rnn/rnn_cell.py | 2 ++
 python/mxnet/rnn/rnn_cell.py       | 3 +++
 2 files changed, 5 insertions(+)

diff --git a/python/mxnet/gluon/rnn/rnn_cell.py b/python/mxnet/gluon/rnn/rnn_cell.py
index c9186fd..eb67fd7 100644
--- a/python/mxnet/gluon/rnn/rnn_cell.py
+++ b/python/mxnet/gluon/rnn/rnn_cell.py
@@ -121,6 +121,8 @@ class RecurrentCell(Block):
         """Reset before re-using the cell for another graph."""
         self._init_counter = -1
         self._counter = -1
+        for cell in self._children:
+            cell.reset()
 
     def state_info(self, batch_size=0):
         """shape and layout information of states"""
diff --git a/python/mxnet/rnn/rnn_cell.py b/python/mxnet/rnn/rnn_cell.py
index 1c34520..b2bf107 100644
--- a/python/mxnet/rnn/rnn_cell.py
+++ b/python/mxnet/rnn/rnn_cell.py
@@ -134,6 +134,9 @@ class BaseRNNCell(object):
         """Reset before re-using the cell for another graph."""
         self._init_counter = -1
         self._counter = -1
+        if hasattr(self, '_cells'):
+            for cell in self._cells:
+                cell.reset()
 
     def __call__(self, inputs, states):
         """Unroll the RNN for one time step.

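Both hunks make reset() recursive. A container cell (SequentialRNNCell in mxnet.rnn, or any RecurrentCell with child blocks in gluon) now also resets each child, so the children's step counters return to -1 before the cell is reused for another graph. The sketch below illustrates this. It uses hypothetical ToyCell and ToySequentialCell classes, which are not MXNet's API. They mimic only the counter bookkeeping of BaseRNNCell, and they assume that per-step names are derived from the counter in a "<prefix>t<counter>" style. Without the hasattr loop, the second unroll would keep counting from where the first one stopped and produce 'l0_t2', 'l0_t3', and so on.

```python
# Hypothetical, stripped-down cells: they mimic only the counter bookkeeping
# of mxnet.rnn.BaseRNNCell and are not the real MXNet classes.

class ToyCell(object):
    """A leaf cell that counts how many time steps it has been unrolled."""

    def __init__(self, prefix):
        self._prefix = prefix
        self._init_counter = -1
        self._counter = -1

    def reset(self):
        """Reset before re-using the cell for another graph."""
        self._init_counter = -1
        self._counter = -1
        # Post-fix behaviour: a container cell also resets its children.
        if hasattr(self, '_cells'):
            for cell in self._cells:
                cell.reset()

    def step(self):
        """Advance one time step and return a counter-based name."""
        self._counter += 1
        return '%st%d' % (self._prefix, self._counter)


class ToySequentialCell(ToyCell):
    """A container that steps each child cell in order."""

    def __init__(self):
        super(ToySequentialCell, self).__init__('seq_')
        self._cells = []

    def add(self, cell):
        self._cells.append(cell)

    def step(self):
        self._counter += 1
        return [cell.step() for cell in self._cells]


stack = ToySequentialCell()
stack.add(ToyCell('l0_'))
stack.add(ToyCell('l1_'))

first = [stack.step() for _ in range(2)]
stack.reset()
second = [stack.step() for _ in range(2)]
print(first)   # [['l0_t0', 'l1_t0'], ['l0_t1', 'l1_t1']]
print(second)  # identical to first, because reset() reached the children
```

The hasattr check mirrors the mxnet.rnn hunk, where only container cells define _cells. The gluon hunk can loop over self._children unconditionally because every Block carries that attribute.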