Posted to commits@mxnet.apache.org by ta...@apache.org on 2020/07/09 04:40:58 UTC
[incubator-mxnet] branch v1.x updated: [v1.x] Backport of Fix BatchNorm backward synchronization (#18644) (#18654)
This is an automated email from the ASF dual-hosted git repository.
taolv pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.x by this push:
new 024daa6 [v1.x] Backport of Fix BatchNorm backward synchronization (#18644) (#18654)
024daa6 is described below
commit 024daa6b56fab4b96f135fd0c5c9489505ba307a
Author: Andrzej Kotłowski <An...@intel.com>
AuthorDate: Thu Jul 9 06:39:38 2020 +0200
[v1.x] Backport of Fix BatchNorm backward synchronization (#18644) (#18654)
* Add test for BatchNorm running variables synchronization
* Fix BatchNorm backward synchronization
It fixes issue #18610
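Some background on the race this fixes: MXNet queues operators asynchronously through its dependency engine, and a blocking read such as asnumpy() only waits for NDArrays that the engine knows have pending writes. Since _backward_BatchNorm updated moving_mean and moving_var in place without declaring them as mutated inputs, a read of the running statistics could overlap the in-flight update and return different values on consecutive calls. A minimal sketch of the read-after-write ordering the engine provides when dependencies are declared (the variable names here are only illustrative):

    import mxnet as mx

    a = mx.nd.ones((2, 2))
    b = a * 2           # queued asynchronously: read dependency on a, write dependency on b
    print(b.asnumpy())  # blocks until the pending write to b has completed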
---
src/operator/nn/batch_norm.cc | 3 +++
tests/python/unittest/test_gluon.py | 26 ++++++++++++++++++++++++++
2 files changed, 29 insertions(+)
diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index af8f25a..3e36559 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -640,6 +640,9 @@ then set ``gamma`` to 1 and its gradient to 0.
NNVM_REGISTER_OP(_backward_BatchNorm)
.set_num_inputs(8)
.set_num_outputs(3)
+.set_attr<nnvm::FMutateInputs>("FMutateInputs", [](const nnvm::NodeAttrs& attrs) {
+  return std::vector<uint32_t>{6, 7};  // moving_mean, moving_var
+})
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FInferStorageType>("FInferStorageType", BatchNormStorageType)
#if MXNET_USE_MKLDNN == 1
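FMutateInputs returns the indices of the inputs that an operator mutates in place; the executor then registers those NDArrays as mutate dependencies in the engine, so any later read is ordered after the backward pass. Inputs 6 and 7 of the 8 backward inputs are moving_mean and moving_var, as the comment above notes. A rough sketch of the user-visible guarantee after this change (Gluon API, mirroring the test added below):

    import mxnet as mx
    from mxnet.gluon import nn

    layer = nn.BatchNorm()
    layer.initialize()
    x = mx.nd.random.normal(shape=(1, 3, 10, 10))
    with mx.autograd.record():
        y = layer(x)
    y.backward()
    # with FMutateInputs registered, this read waits for backward's
    # in-place update of the running statistics
    mean = layer.running_mean.data().asnumpy()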
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index cf6bc36..60fd526 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -759,6 +759,32 @@ def test_pool():
@with_seed()
+def test_batchnorm_backward_synchronization():
+    """
+    Tests that the running variables of BatchNorm are synchronized correctly.
+    If they are not, the test fails intermittently, depending on timing.
+    """
+    ctx = mx.test_utils.default_context()
+
+    for variable in ['running_var', 'running_mean']:
+        for _ in range(20):
+            layer = nn.BatchNorm()
+            layer.initialize(ctx=ctx)
+            for _ in range(3):
+                data = mx.nd.random.normal(loc=10, scale=2, shape=(1, 3, 10, 10), ctx=ctx)
+                with mx.autograd.record():
+                    out = layer(data)
+                out.backward()
+
+            # check that each read gives the same value
+            var1 = getattr(layer, variable).data().asnumpy()
+            for _ in range(10):
+                var2 = getattr(layer, variable).data().asnumpy()
+                if (var1 != var2).any():
+                    raise AssertionError("Two consecutive reads of " + variable + " give different results")
+
+
+@with_seed()
def test_batchnorm():
layer = nn.BatchNorm(in_channels=10)
check_layer_forward(layer, (2, 10, 10, 10))
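To run the new test in isolation, something along these lines should work (the exact invocation depends on the branch's test runner; v1.x has used both nose and pytest):

    nosetests tests/python/unittest/test_gluon.py:test_batchnorm_backward_synchronization
    # or, with pytest:
    python -m pytest tests/python/unittest/test_gluon.py -k test_batchnorm_backward_synchronization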