You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ta...@apache.org on 2020/07/01 14:40:42 UTC
[incubator-mxnet] branch master updated: Fix BatchNorm backward
synchronization (#18644)
This is an automated email from the ASF dual-hosted git repository.
taolv pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 37bed6e Fix BatchNorm backward synchronization (#18644)
37bed6e is described below
commit 37bed6e3af794624d651e888101eceb30c27c001
Author: Andrzej Kotłowski <An...@intel.com>
AuthorDate: Wed Jul 1 16:39:22 2020 +0200
Fix BatchNorm backward synchronization (#18644)
* Add test for BatchNorm running variables synchronization
* Fix BatchNorm backward synchronization
It fixes issue #18610
---
src/operator/nn/batch_norm.cc | 3 +++
tests/python/unittest/test_gluon.py | 26 ++++++++++++++++++++++++++
2 files changed, 29 insertions(+)
diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index 8dbd271..7e540ca 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -653,6 +653,9 @@ then set ``gamma`` to 1 and its gradient to 0.
NNVM_REGISTER_OP(_backward_BatchNorm)
.set_num_inputs(8)
.set_num_outputs(3)
+.set_attr<nnvm::FMutateInputs>("FMutateInputs", [](const nnvm::NodeAttrs& attrs) {
+ return std::vector<uint32_t>{6, 7}; // moving_mean, moving_var
+})
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FInferStorageType>("FInferStorageType", BatchNormStorageType)
#if MXNET_USE_MKLDNN == 1
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index 47ef86f..77d5119 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -666,6 +666,32 @@ def test_pool():
@with_seed()
+@pytest.mark.parametrize('variable', ['running_var', 'running_mean'])
+def test_batchnorm_backward_synchronization(variable):
+ """
+ Tests if synchronization of BatchNorm running variables is done correctly.
+ If not, the test sometimes fails - depending on the timing.
+ """
+ ctx = mx.test_utils.default_context()
+
+ for _ in range(20):
+ layer = nn.BatchNorm()
+ layer.initialize(ctx=ctx)
+ for _ in range(3):
+ data = mx.nd.random.normal(loc=10, scale=2, shape=(1, 3, 10, 10), ctx=ctx)
+ with mx.autograd.record():
+ out = layer(data)
+ out.backward()
+
+ # check if each read gives the same value
+ var1 = getattr(layer, variable).data().asnumpy()
+ for _ in range(10):
+ var2 = getattr(layer, variable).data().asnumpy()
+ if (var1 != var2).any():
+ raise AssertionError("Two consecutive reads of " + variable + " give different results")
+
+
+@with_seed()
def test_batchnorm():
layer = nn.BatchNorm(in_channels=10)
check_layer_forward(layer, (2, 10, 10, 10))