You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by an...@apache.org on 2018/11/13 23:42:01 UTC
[incubator-mxnet] branch master updated: Fix test failure due to
hybridize call in test_gluon_rnn.test_layer_fill_shape (#13043)
This is an automated email from the ASF dual-hosted git repository.
anirudh2290 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 7dfcc94 Fix test failure due to hybridize call in test_gluon_rnn.test_layer_fill_shape (#13043)
7dfcc94 is described below
commit 7dfcc94d3ea493569fd314cbf86dde7d5c0010bd
Author: Yuxi Hu <da...@gmail.com>
AuthorDate: Tue Nov 13 15:41:49 2018 -0800
Fix test failure due to hybridize call in test_gluon_rnn.test_layer_fill_shape (#13043)
* Restore hybridize call in test_gluon_rnn.test_layer_fill_shape
* reset bulk_size when the cached op forward hits an error to fix the test failure
* add try-catch blocks to reset bulk_size in more places to prevent potential bugs
* more cleanup upon exception in Imperative::Backward
---
CONTRIBUTORS.md | 1 +
src/imperative/cached_op.cc | 26 ++++++++++++++++++--------
src/imperative/imperative.cc | 13 ++++++++++---
tests/python/unittest/test_gluon_rnn.py | 1 +
4 files changed, 30 insertions(+), 11 deletions(-)
diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index a690fb1..a2e19c5 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -190,6 +190,7 @@ List of Contributors
* [Denisa Roberts](https://github.com/D-Roberts)
* [Dick Carter](https://github.com/DickJC123)
* [Rahul Padmanabhan](https://github.com/rahul3)
+* [Yuxi Hu](https://github.com/yuxihu)
Label Bot
---------
diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index 1f115cd..a836765 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -855,10 +855,15 @@ OpStatePtr CachedOp::Forward(
int prev_bulk_size = Engine::Get()->set_bulk_size(config_.forward_bulk_size);
OpStatePtr op_state;
- if (config_.static_alloc) {
- op_state = StaticForward(default_ctx, inputs, outputs);
- } else {
- op_state = DynamicForward(default_ctx, inputs, outputs);
+ try {
+ if (config_.static_alloc) {
+ op_state = StaticForward(default_ctx, inputs, outputs);
+ } else {
+ op_state = DynamicForward(default_ctx, inputs, outputs);
+ }
+ } catch (const dmlc::Error& e) {
+ Engine::Get()->set_bulk_size(prev_bulk_size);
+ throw e;
}
Engine::Get()->set_bulk_size(prev_bulk_size);
@@ -1058,10 +1063,15 @@ void CachedOp::Backward(
int prev_bulk_size = Engine::Get()->set_bulk_size(config_.backward_bulk_size);
- if (config_.static_alloc) {
- StaticBackward(retain_graph, state, inputs, reqs, outputs);
- } else {
- DynamicBackward(retain_graph, state, inputs, reqs, outputs);
+ try {
+ if (config_.static_alloc) {
+ StaticBackward(retain_graph, state, inputs, reqs, outputs);
+ } else {
+ DynamicBackward(retain_graph, state, inputs, reqs, outputs);
+ }
+ } catch (const dmlc::Error& e) {
+ Engine::Get()->set_bulk_size(prev_bulk_size);
+ throw e;
}
Engine::Get()->set_bulk_size(prev_bulk_size);
diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc
index 0c5ff84..32ff8d3 100644
--- a/src/imperative/imperative.cc
+++ b/src/imperative/imperative.cc
@@ -494,9 +494,16 @@ std::vector<NDArray*> Imperative::Backward(
bool prev_training = set_is_training(is_train);
int prev_bulk_size = Engine::Get()->set_bulk_size(backward_bulk_size_);
- RunGraph(retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(),
- std::move(array_reqs), std::move(ref_count), &states, dispatch_modes,
- is_recording());
+ try {
+ RunGraph(retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(),
+ std::move(array_reqs), std::move(ref_count), &states, dispatch_modes,
+ is_recording());
+ } catch (const dmlc::Error& e) {
+ Engine::Get()->set_bulk_size(prev_bulk_size);
+ set_is_recording(prev_recording);
+ set_is_training(prev_training);
+ throw e;
+ }
Engine::Get()->set_bulk_size(prev_bulk_size);
set_is_recording(prev_recording);
diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py
index bfe9592..eee3add 100644
--- a/tests/python/unittest/test_gluon_rnn.py
+++ b/tests/python/unittest/test_gluon_rnn.py
@@ -594,6 +594,7 @@ def test_cell_fill_shape():
@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
def test_layer_fill_shape():
layer = gluon.rnn.LSTM(10)
+ layer.hybridize()
check_rnn_layer_forward(layer, mx.nd.ones((3, 2, 7)))
print(layer)
assert layer.l0_i2h_weight.shape[1] == 7, layer.l0_i2h_weight.shape[1]