You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2021/02/05 20:28:22 UTC

[incubator-mxnet] branch v1.x updated: batchnorm tests (#19836)

This is an automated email from the ASF dual-hosted git repository.

zha0q1 pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.x by this push:
     new f651452  batchnorm tests (#19836)
f651452 is described below

commit f65145212a0f5c80db822581ae9b8d3d02a4aca1
Author: waytrue17 <52...@users.noreply.github.com>
AuthorDate: Fri Feb 5 12:26:38 2021 -0800

    batchnorm tests (#19836)
    
    Co-authored-by: Wei Chu <we...@amazon.com>
---
 python/mxnet/contrib/onnx/mx2onnx/_op_translations.py |  4 ++++
 tests/python-pytest/onnx/test_operators.py            | 12 ++++++++++++
 2 files changed, 16 insertions(+)

diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
index fbf55df..3576242 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -396,6 +396,10 @@ def convert_batchnorm(node, **kwargs):
 
     momentum = float(attrs.get("momentum", 0.9))
     eps = float(attrs.get("eps", 0.001))
+    axis = int(attrs.get("axis", 1))
+
+    if axis != 1:
+        raise NotImplementedError("batchnorm axis != 1 is currently not supported.")
 
     bn_node = onnx.helper.make_node(
         "BatchNormalization",
diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py
index 8bd7ef0..f4b44e5 100644
--- a/tests/python-pytest/onnx/test_operators.py
+++ b/tests/python-pytest/onnx/test_operators.py
@@ -927,3 +927,15 @@ def test_onnx_export_convolution(tmp_path, dtype, shape, num_filter, num_group,
                   **kwargs)
     inputs = [x, w] if no_bias else [x, w, b]
     op_export_test('convolution', M, inputs, tmp_path)
+
+
+@pytest.mark.parametrize('dtype', ['float32', 'float64'])
+@pytest.mark.parametrize('momentum', [0.9, 0.5, 0.1])
+def test_onnx_export_batchnorm(tmp_path, dtype, momentum):
+    x = mx.nd.random.normal(0, 10, (2, 3, 4, 5)).astype(dtype)
+    gamma = mx.nd.random.normal(0, 10, (3)).astype(dtype)
+    beta = mx.nd.random.normal(0, 10, (3)).astype(dtype)
+    moving_mean = mx.nd.random.normal(0, 10, (3)).astype(dtype)
+    moving_var = mx.nd.abs(mx.nd.random.normal(0, 10, (3))).astype(dtype)
+    M = def_model('BatchNorm', eps=1e-5, momentum=momentum, fix_gamma=False, use_global_stats=False)
+    op_export_test('BatchNorm1', M, [x, gamma, beta, moving_mean, moving_var], tmp_path)