Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2019/01/10 19:57:27 UTC

[GitHub] Roshrini closed pull request #13821: onnx export ops

URL: https://github.com/apache/incubator-mxnet/pull/13821

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:


diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
index 61cb353ded4..f9bb5d6d7fe 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -656,12 +656,19 @@ def convert_exp(node, **kwargs):
     return create_basic_op_node('Exp', node, kwargs)
 
 @mx_op.register("_copy")
-def convert_identity(node, **kwargs):
+def convert_copy(node, **kwargs):
     """Map MXNet's _copy operator attributes to onnx's Identity operator
     and return the created node.
     """
     return create_basic_op_node('Identity', node, kwargs)
 
+@mx_op.register("identity")
+def convert_identity(node, **kwargs):
+    """Map MXNet's identity operator attributes to onnx's ConstantFill operator
+    and return the created node.
+    """
+    return create_basic_op_node('ConstantFill', node, kwargs)
+
 @mx_op.register("InstanceNorm")
 def convert_instancenorm(node, **kwargs):
     """Map MXNet's InstanceNorm operator attributes to onnx's InstanceNormalization operator
@@ -752,6 +759,31 @@ def convert_softmax_output(node, **kwargs):
 
     return [softmax_node]
 
+@mx_op.register("LogisticRegressionOutput")
+def convert_logistic_regression_output(node, **kwargs):
+    """Map MXNet's SoftmaxOutput operator attributes to onnx's Softmax operator
+    and return the created node.
+    """
+    name = node["name"]
+    input1_idx = kwargs["index_lookup"][node["inputs"][0][0]]
+    input1 = kwargs["proc_nodes"][input1_idx]
+    sigmoid_node = onnx.helper.make_node(
+        "Sigmoid",
+        [input1.name],
+        [name],
+        name=name
+    )
+    return [sigmoid_node]
+
+@mx_op.register("BlockGrad")
+def convert_blockgrad(node, **kwargs):
+    """ Skip operator  """
+    return create_basic_op_node('ConstantFill', node, kwargs)
+
+@mx_op.register("MakeLoss")
+def convert_makeloss(node, **kwargs):
+    """ Skip operator  """
+    return create_basic_op_node('ConstantFill', node, kwargs)
 
 @mx_op.register("Concat")
 def convert_concat(node, **kwargs):
@@ -898,7 +930,7 @@ def convert_clip(node, **kwargs):
 def scalar_op_helper(node, op_name, **kwargs):
     """Helper function for scalar arithmetic operations"""
     name, input_nodes, attrs = get_inputs(node, kwargs)
-
+    from onnx import numpy_helper
     input_type = kwargs["in_type"]
     scalar_value = np.array([attrs.get("scalar", 1)],
                             dtype=onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[input_type])
@@ -910,13 +942,21 @@ def scalar_op_helper(node, op_name, **kwargs):
     for i in initializer:
         if i.name == input_nodes[0]:
             if op_name == 'Mul':
-                new_initializer = onnx.numpy_helper.to_array(i) * scalar_value[0]
+                new_initializer = numpy_helper.to_array(i) * scalar_value[0]
             elif op_name == 'Sub':
-                new_initializer = onnx.numpy_helper.to_array(i) - scalar_value[0]
+                if name.startswith("_rminusscalar"):
+                    new_initializer = scalar_value[0] - numpy_helper.to_array(i)
+                else:
+                    new_initializer = numpy_helper.to_array(i) - scalar_value[0]
             elif op_name == 'Add':
-                new_initializer = onnx.numpy_helper.to_array(i) + scalar_value[0]
+                new_initializer = numpy_helper.to_array(i) + scalar_value[0]
             elif op_name == 'Div':
-                new_initializer = onnx.numpy_helper.to_array(i) / scalar_value[0]
+                if name.startswith("_rdivscalar"):
+                    new_initializer = scalar_value[0] / numpy_helper.to_array(i)
+                else:
+                    new_initializer = numpy_helper.to_array(i) / scalar_value[0]
+            elif op_name == 'Pow':
+                new_initializer = numpy_helper.to_array(i) ** scalar_value[0]
             flag = False
             break
 
@@ -982,6 +1022,13 @@ def convert_minus_scalar(node, **kwargs):
     """
     return scalar_op_helper(node, 'Sub', **kwargs)
 
+@mx_op.register("_rminus_scalar")
+def convert_rminus_scalar(node, **kwargs):
+    """Map MXNet's _rminus_scalar operator attributes to onnx's Sub operator.
+    Creates a new node for the input scalar value, adds it to the initializer
+    and returns the created nodes.
+    """
+    return scalar_op_helper(node, 'Sub', **kwargs)
 
 # Convert scalar value into node and pass it as input to mul_node
 @mx_op.register("_plus_scalar")
@@ -1001,6 +1048,21 @@ def convert_div_scalar(node, **kwargs):
     """
     return scalar_op_helper(node, 'Div', **kwargs)
 
+@mx_op.register("_rdiv_scalar")
+def convert_rdiv_scalar(node, **kwargs):
+    """Map MXNet's _rdiv_scalar operator attributes to onnx's Div operator.
+    Creates a new node for the input scalar value, adds it to the initializer
+    and returns the created nodes.
+    """
+    return scalar_op_helper(node, 'Div', **kwargs)
+
+@mx_op.register("_power_scalar")
+def convert_pow_scalar(node, **kwargs):
+    """Map MXNet's _pow_scalar operator attributes to onnx's Pow operator.
+    Creates a new node for the input scalar value, adds it to the initializer
+    and returns the created nodes.
+    """
+    return scalar_op_helper(node, 'Pow', **kwargs)
 
 # Sorting and Searching
 @mx_op.register("argmax")
diff --git a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
index ef3bda30df3..fbfebbc9ee6 100644
--- a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
@@ -730,7 +730,6 @@ def spacetodepth(attrs, inputs, proto_obj):
 
     return "space_to_depth", new_attrs, inputs
 
-
 def hardmax(attrs, inputs, proto_obj):
     """Returns batched one-hot vectors."""
     input_tensor_data = proto_obj.model_metadata.get('input_tensor_data')[0]
diff --git a/tests/python-pytest/onnx/backend_test.py b/tests/python-pytest/onnx/backend_test.py
index 048a6782c24..8eaa303a6c1 100644
--- a/tests/python-pytest/onnx/backend_test.py
+++ b/tests/python-pytest/onnx/backend_test.py
@@ -71,7 +71,8 @@ def prepare_tests(backend, oper):
     for std_model_test in std_models:
         BACKEND_TESTS.include(std_model_test)
 
-    BACKEND_TESTS.exclude('.*bcast.*')
+    # Tests for scalar ops are in test_node.py
+    BACKEND_TESTS.exclude('.*scalar.*')
 
     return BACKEND_TESTS
 
diff --git a/tests/python-pytest/onnx/test_node.py b/tests/python-pytest/onnx/test_node.py
index 9b5ff97d165..c8a523edc17 100644
--- a/tests/python-pytest/onnx/test_node.py
+++ b/tests/python-pytest/onnx/test_node.py
@@ -161,6 +161,30 @@ def test_import_export(self):
                 if check_shape:
                     npt.assert_equal(output[0].shape, outputshape)
 
+        input1 = get_rnd((1, 10, 2, 3))
+        ipsym = mx.sym.Variable("input1")
+        for test in test_scalar_ops:
+            if test == 'Add':
+                outsym = 2 + ipsym
+            if test == "Sub":
+                outsym = ipsym - 2
+            if test == "rSub":
+                outsym = ipsym.__rsub__(2)
+            if test == "Mul":
+                outsym = 2 * ipsym
+            if test == "Div":
+                outsym = ipsym / 2
+            if test == "Pow":
+                outsym = ipsym ** 2
+            forward_op = forward_pass(outsym, None, None, ['input1'], input1)
+            converted_model = onnx_mxnet.export_model(outsym, {}, [np.shape(input1)], np.float32,
+                                                      onnx_file_path=outsym.name + ".onnx")
+
+            sym, arg_params, aux_params = onnx_mxnet.import_model(converted_model)
+            result = forward_pass(sym, arg_params, aux_params, ['input1'], input1)
+
+            npt.assert_almost_equal(result, forward_op)
+
     def test_imports(self):
         for test in import_test_cases:
             test_name, onnx_name, inputs, np_op, attrs = test
@@ -173,7 +197,6 @@ def test_imports(self):
                 mxnet_out = bkd_rep.run(inputs)
                 npt.assert_almost_equal(np_out, mxnet_out)
 
-
 # test_case = ("test_case_name", mxnet op, "ONNX_op_name", [input_list], attribute map, MXNet_specific=True/False,
 # fix_attributes = {'modify': {mxnet_attr_name: onnx_attr_name},
 #                   'remove': [attr_name],
@@ -198,6 +221,8 @@ def test_imports(self):
      {'block_size': 2}, False, {}, True, False),
     ("test_softmax", mx.sym.SoftmaxOutput, "Softmax", [get_rnd((1000, 1000)), get_rnd(1000)],
      {'ignore_label': 0, 'use_ignore': False}, True, {}, True, False),
+    ("test_logistic_regression", mx.sym.LogisticRegressionOutput, "Sigmoid",
+     [get_rnd((1000, 1000)), get_rnd((1000, 1000))], {}, True, {}, True, False),
     ("test_fullyconnected", mx.sym.FullyConnected, "Gemm", [get_rnd((4, 3)), get_rnd((4, 3)), get_rnd(4)],
      {'num_hidden': 4, 'name': 'FC'}, True, {}, True, False),
     ("test_lppool1", mx.sym.Pooling, "LpPool", [get_rnd((2, 3, 20, 20))],
@@ -223,12 +248,13 @@ def test_imports(self):
      {'shape': (10,)}, False, {'modify': {'shape': 'sample_size'}}, False, True)
 ]
 
+test_scalar_ops = ['Add', 'Sub', 'rSub', 'Mul', 'Div', 'Pow']
+
 # test_case = ("test_case_name", "ONNX_op_name", [input_list], np_op, attribute map)
 import_test_cases = [
     ("test_lpnormalization_default", "LpNormalization", [get_rnd([5, 3, 3, 2])], np.linalg.norm, {'ord':2, 'axis':-1}),
     ("test_lpnormalization_ord1", "LpNormalization", [get_rnd([5, 3, 3, 2])], np.linalg.norm, {'ord':1, 'axis':-1}),
-    ("test_lpnormalization_ord2", "LpNormalization", [get_rnd([5, 3, 3, 2])], np.linalg.norm, {'ord':2, 'axis':1}),
-    ("test_lpnormalization_ord_axis", "LpNormalization", [get_rnd([5, 3, 3, 2])], np.linalg.norm, {'ord':1, 'axis':2})
+    ("test_lpnormalization_ord2", "LpNormalization", [get_rnd([5, 3, 3, 2])], np.linalg.norm, {'ord':2, 'axis':1})
 ]
 
 if __name__ == '__main__':


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services