You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2021/02/03 18:54:57 UTC
[incubator-mxnet] branch v1.x updated: Add onnx export function for log2 operator, add operator unit test and update tests to allow comparing NaN values. (#19822)
This is an automated email from the ASF dual-hosted git repository.
zha0q1 pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.x by this push:
new 4fbe3d2 Add onnx export function for log2 operator, add operator unit test and update tests to allow comparing NaN values. (#19822)
4fbe3d2 is described below
commit 4fbe3d2d8b7f58f9916dfe6afa126150d2e9701c
Author: Joe Evans <jo...@gmail.com>
AuthorDate: Wed Feb 3 10:53:09 2021 -0800
Add onnx export function for log2 operator, add operator unit test and update tests to allow comparing NaN values. (#19822)
Co-authored-by: Joe Evans <jo...@amazon.com>
---
.../mxnet/contrib/onnx/mx2onnx/_op_translations.py | 24 ++++++++++++++++++++++
tests/python-pytest/onnx/test_operators.py | 11 ++++++++--
2 files changed, 33 insertions(+), 2 deletions(-)
diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
index 37fc542..b9a7ef0 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -3811,3 +3811,27 @@ def convert_batch_dot(node, **kwargs):
]
return nodes
+
+
+@mx_op.register("log2")
+def convert_log2(node, **kwargs):
+ """Map MXNet's log2 operator attributes to onnx's operator.
+ """
+ from onnx.helper import make_node, make_tensor
+ name, input_nodes, _ = get_inputs(node, kwargs)
+
+ input_type = kwargs["in_type"]
+ dtype = onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[input_type]
+
+ ln2 = np.array([0.693147180559945309], dtype=dtype)
+ if dtype == 'float16':
+ ln2 = ln2.view(dtype=np.uint16)
+ ln2v = make_tensor(name+'_ln2', input_type, [1], ln2)
+
+ nodes = [
+ make_node('Log', [input_nodes[0]], [name+'_log']),
+ make_node('Constant', [], [name+'_ln2'], value=ln2v),
+ make_node('Div', [name+'_log', name+'_ln2'], [name], name=name)
+ ]
+
+ return nodes
diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py
index 049f33e..de3b1e9 100644
--- a/tests/python-pytest/onnx/test_operators.py
+++ b/tests/python-pytest/onnx/test_operators.py
@@ -69,9 +69,9 @@ def op_export_test(model_name, Model, inputs, tmp_path, dummy_input=False):
pred_nat = pred_nat[0]
if isinstance(pred_nat, list):
for i in range(len(pred_nat)):
- assert_almost_equal(pred_nat[i], pred_onx[i])
+ assert_almost_equal(pred_nat[i], pred_onx[i], equal_nan=True)
else:
- assert_almost_equal(pred_nat, pred_onx[0])
+ assert_almost_equal(pred_nat, pred_onx[0], equal_nan=True)
def test_onnx_export_abs(tmp_path):
@@ -752,6 +752,13 @@ def test_onnx_export_batch_dot(tmp_path, dtype, transpose_a, transpose_b):
op_export_test('batch_dot2', M2, [x2, y2], tmp_path)
+@pytest.mark.parametrize('dtype', ['float16', 'float32'])
+def test_onnx_export_log2(tmp_path, dtype):
+ x = mx.random.normal(0, 10, (2, 3, 4, 5)).astype(dtype)
+ M = def_model('log2')
+ op_export_test('log2', M, [x], tmp_path)
+
+
@pytest.mark.parametrize('dtype', ['int32', 'int64', 'float16', 'float32', 'float64'])
@pytest.mark.parametrize('axis', [None, 1, [1,2], -1])
def test_onnx_export_sum(tmp_path, dtype, axis):