You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by wk...@apache.org on 2019/07/20 22:11:20 UTC
[incubator-mxnet] branch master updated: fix normalize mean error bug (#15539)
This is an automated email from the ASF dual-hosted git repository.
wkcn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 8f5930b fix normalize mean error bug (#15539)
8f5930b is described below
commit 8f5930b2c95a6b7594ff6535a097e35b3315bc6d
Author: nicklhy <ni...@gmail.com>
AuthorDate: Sun Jul 21 06:10:36 2019 +0800
fix normalize mean error bug (#15539)
* fix normalize mean error bug
* add scalar mean/std tests for image_normalize
---
src/operator/image/image_random-inl.h | 2 +-
tests/python/unittest/test_operator.py | 53 ++++++++++++++++++++++++++++++++--
2 files changed, 52 insertions(+), 3 deletions(-)
diff --git a/src/operator/image/image_random-inl.h b/src/operator/image/image_random-inl.h
index aeb189f..e00b255 100644
--- a/src/operator/image/image_random-inl.h
+++ b/src/operator/image/image_random-inl.h
@@ -339,7 +339,7 @@ void NormalizeOpForward(const nnvm::NodeAttrs &attrs,
std::vector<float> mean(3);
std::vector<float> std(3);
if (param.mean.ndim() == 1) {
- mean[0] = mean[1] = mean[3] = param.mean[0];
+ mean[0] = mean[1] = mean[2] = param.mean[0];
} else {
mean[0] = param.mean[0];
mean[1] = param.mean[1];
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index fea07f5..915a83f 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -8678,7 +8678,7 @@ def test_invalid_max_pooling_pad_type_same():
@with_seed()
def test_image_normalize():
- # Part 1 - Test 3D Input
+ # Part 1 - Test 3D input with 3D mean/std
shape_3d = (3, 28, 28)
mean = (0, 1, 2)
std = (3, 2, 1)
@@ -8709,7 +8709,7 @@ def test_image_normalize():
# check backward using finite difference
check_numeric_gradient(img_norm_sym, [data_in_3d], atol=0.001)
- # Part 2 - Test 4D Input
+ # Part 2 - Test 4D input with 3D mean/std
shape_4d = (2, 3, 28, 28)
data_in_4d = mx.nd.random.uniform(0, 1, shape_4d)
@@ -8741,6 +8741,55 @@ def test_image_normalize():
# check backward using finite difference
check_numeric_gradient(img_norm_sym, [data_in_4d], atol=0.001)
+ # Part 3 - Test 3D input with scalar mean/std
+ shape_3d = (3, 28, 28)
+ mean = 1.0
+ std = 2.0
+
+ data_in_3d = mx.nd.random.uniform(0, 1, shape_3d)
+ data_expected_3d = data_in_3d.asnumpy()
+ data_expected_3d[:][:][:] = (data_expected_3d[:][:][:] - 1.0) / 2.0
+
+ data = mx.symbol.Variable('data')
+ img_norm_sym = mx.sym.image.normalize(data=data, mean=mean, std=std)
+
+ # check forward
+ check_symbolic_forward(img_norm_sym, [data_in_3d], [data_expected_3d],
+ rtol=1e-5, atol=1e-5)
+
+ # Gradient is 1/std_dev
+ grad_expected_3d = np.ones(shape_3d)
+ grad_expected_3d[:][:][:] = 1 / 2.0
+
+ # check backward
+ check_symbolic_backward(img_norm_sym, location=[data_in_3d], out_grads=[mx.nd.ones(shape_3d)],
+ expected=[grad_expected_3d], rtol=1e-5, atol=1e-5)
+
+ # check backward using finite difference
+ check_numeric_gradient(img_norm_sym, [data_in_3d], atol=0.001)
+
+ # Part 4 - Test 4D input with scalar mean/std
+ shape_4d = (2, 3, 28, 28)
+
+ data_in_4d = mx.nd.random.uniform(0, 1, shape_4d)
+ data_expected_4d = data_in_4d.asnumpy()
+ data_expected_4d[:][:][:][:] = (data_expected_4d[:][:][:][:] - 1.0) / 2.0
+
+ # check forward
+ check_symbolic_forward(img_norm_sym, [data_in_4d], [data_expected_4d],
+ rtol=1e-5, atol=1e-5)
+
+ # Gradient is 1/std_dev
+ grad_expected_4d = np.ones(shape_4d)
+ grad_expected_4d[:][:][:][:] = 1 / 2.0
+
+ # check backward
+ check_symbolic_backward(img_norm_sym, location=[data_in_4d], out_grads=[mx.nd.ones(shape_4d)],
+ expected=[grad_expected_4d], rtol=1e-5, atol=1e-5)
+
+ # check backward using finite difference
+ check_numeric_gradient(img_norm_sym, [data_in_4d], atol=0.001)
+
@with_seed()
def test_index_array():
def test_index_array_default():