You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2021/02/05 20:28:22 UTC
[incubator-mxnet] branch v1.x updated: batchnorm tests (#19836)
This is an automated email from the ASF dual-hosted git repository.
zha0q1 pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.x by this push:
new f651452 batchnorm tests (#19836)
f651452 is described below
commit f65145212a0f5c80db822581ae9b8d3d02a4aca1
Author: waytrue17 <52...@users.noreply.github.com>
AuthorDate: Fri Feb 5 12:26:38 2021 -0800
batchnorm tests (#19836)
Co-authored-by: Wei Chu <we...@amazon.com>
---
python/mxnet/contrib/onnx/mx2onnx/_op_translations.py | 4 ++++
tests/python-pytest/onnx/test_operators.py | 12 ++++++++++++
2 files changed, 16 insertions(+)
diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
index fbf55df..3576242 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -396,6 +396,10 @@ def convert_batchnorm(node, **kwargs):
 
     momentum = float(attrs.get("momentum", 0.9))
     eps = float(attrs.get("eps", 0.001))
+    axis = int(attrs.get("axis", 1))
+
+    if axis != 1:
+        raise NotImplementedError("batchnorm axis != 1 is currently not supported.")
 
     bn_node = onnx.helper.make_node(
         "BatchNormalization",
diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py
index 8bd7ef0..f4b44e5 100644
--- a/tests/python-pytest/onnx/test_operators.py
+++ b/tests/python-pytest/onnx/test_operators.py
@@ -927,3 +927,15 @@ def test_onnx_export_convolution(tmp_path, dtype, shape, num_filter, num_group,
**kwargs)
     inputs = [x, w] if no_bias else [x, w, b]
     op_export_test('convolution', M, inputs, tmp_path)
+
+
+@pytest.mark.parametrize('dtype', ['float32', 'float64'])
+@pytest.mark.parametrize('momentum', [0.9, 0.5, 0.1])
+def test_onnx_export_batchnorm(tmp_path, dtype, momentum):
+    x = mx.nd.random.normal(0, 10, (2, 3, 4, 5)).astype(dtype)
+    gamma = mx.nd.random.normal(0, 10, (3)).astype(dtype)
+    beta = mx.nd.random.normal(0, 10, (3)).astype(dtype)
+    moving_mean = mx.nd.random.normal(0, 10, (3)).astype(dtype)
+    moving_var = mx.nd.abs(mx.nd.random.normal(0, 10, (3))).astype(dtype)
+    M = def_model('BatchNorm', eps=1e-5, momentum=momentum, fix_gamma=False, use_global_stats=False)
+    op_export_test('BatchNorm1', M, [x, gamma, beta, moving_mean, moving_var], tmp_path)