Posted to commits@mxnet.apache.org by jx...@apache.org on 2018/05/18 17:04:22 UTC

[incubator-mxnet] branch master updated: fix rnn layer kernel forward (#10982)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 5cddc2d  fix rnn layer kernel forward (#10982)
5cddc2d is described below

commit 5cddc2dad8b27985546cbba51ad98fff3d22a879
Author: Sheng Zha <sz...@users.noreply.github.com>
AuthorDate: Fri May 18 10:04:16 2018 -0700

    fix rnn layer kernel forward (#10982)
---
 python/mxnet/gluon/rnn/rnn_layer.py     | 4 ++--
 tests/python/unittest/test_gluon_rnn.py | 2 ++
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/python/mxnet/gluon/rnn/rnn_layer.py b/python/mxnet/gluon/rnn/rnn_layer.py
index 89224cf..2beae96 100644
--- a/python/mxnet/gluon/rnn/rnn_layer.py
+++ b/python/mxnet/gluon/rnn/rnn_layer.py
@@ -23,7 +23,7 @@
 from __future__ import print_function
 __all__ = ['RNN', 'LSTM', 'GRU']
 
-from ... import ndarray, autograd
+from ... import ndarray
 from .. import Block
 from . import rnn_cell
 
@@ -186,7 +186,7 @@ class _RNNLayer(Block):
                 self.i2h_weight[i].shape = (self._gates*self._hidden_size, inputs.shape[2])
                 self.i2h_weight[i]._finish_deferred_init()
         if inputs.context.device_type == 'gpu' or \
-           self._mode == 'lstm' and not (self._dropout and autograd.is_training()):
+           self._mode == 'lstm' and not self._dropout:
             out = self._forward_kernel(inputs, states)
         else:
             out = self._forward(inputs, states)
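
For context, the condition this hunk changes selects between the fused kernel and the unfused cell-by-cell implementation. A minimal sketch of the dispatch after the fix (a hypothetical reconstruction of the relevant branch of `_RNNLayer.forward`, not the full method):

    # Hypothetical reconstruction of the dispatch in _RNNLayer.forward after
    # this commit; only the branch shown in the hunk above.
    def forward(self, inputs, states):
        # Fused kernel: always on GPU; on CPU only for an LSTM without dropout
        # (the `not self._dropout` guard suggests the fused CPU path does not
        # support dropout). The removed autograd.is_training() check had let an
        # LSTM with dropout reach the fused kernel at inference time.
        if inputs.context.device_type == 'gpu' or \
           (self._mode == 'lstm' and not self._dropout):
            return self._forward_kernel(inputs, states)  # fused implementation
        return self._forward(inputs, states)             # unfused, cell-by-cell

Note that in the committed code `and` binds tighter than `or`, so the parenthesized form above is semantically identical.
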
diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py
index 24d5a93..9dbcb3b 100644
--- a/tests/python/unittest/test_gluon_rnn.py
+++ b/tests/python/unittest/test_gluon_rnn.py
@@ -268,6 +268,8 @@ def check_rnn_layer_forward(layer, inputs, states=None, run_only=False):
             assert isinstance(out, mx.nd.NDArray)
         out.backward()
 
+    layer(inputs, states)  # test inference path (autograd.is_training() == False)
+
     if not run_only:
         mx.test_utils.assert_almost_equal(np_out, out.asnumpy(), rtol=1e-3, atol=1e-5)
         mx.test_utils.assert_almost_equal(np_dx, inputs.grad.asnumpy(), rtol=1e-3, atol=1e-5)
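
The call added here runs outside `autograd.record()`, so `autograd.is_training()` is False, which is exactly the case the rnn_layer.py change fixes. A standalone sketch of that scenario, with illustrative shapes and hyperparameters (not the exact cases the test suite uses):

    import mxnet as mx
    from mxnet import autograd, gluon

    # Illustrative repro of the fixed case: a CPU LSTM with dropout, run both
    # under autograd.record() (training) and without it (inference). Before
    # this commit the inference call could be dispatched to _forward_kernel.
    layer = gluon.rnn.LSTM(hidden_size=10, num_layers=2, dropout=0.5)
    layer.collect_params().initialize()
    inputs = mx.nd.ones((8, 3, 20))   # (seq_len, batch, input_size), 'TNC' layout
    inputs.attach_grad()

    with autograd.record():           # autograd.is_training() == True here
        out = layer(inputs)
        out.backward()

    layer(inputs)                     # inference: is_training() == False
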
