You are viewing a plain-text version of this content; the hyperlink to its canonical version was not preserved in this copy.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/12/10 18:14:46 UTC

[GitHub] Roshrini closed pull request #12946: [WIP] Onnx export API optional args and additional operator support, Fixes #12682

Roshrini closed pull request #12946: [WIP] Onnx export API optional args and additional operator support, Fixes #12682
URL: https://github.com/apache/incubator-mxnet/pull/12946
 
 
   

This is a pull request merged from a forked repository.
Because GitHub hides the original diff on merge, the diff is displayed below
for the sake of provenance:

Because this is a foreign pull request (from a fork), GitHub would not
otherwise display its diff, so the diff is supplied below:

diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
index 11e75d9a600..f7bda29d10b 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -218,7 +218,6 @@ def convert_convolution(node, **kwargs):
 
     return [conv_node]
 
-
 @mx_op.register("FullyConnected")
 def convert_fully_connected(node, **kwargs):
     """Map MXNet's FullyConnected operator attributes to onnx's Gemm operator
@@ -443,20 +442,19 @@ def convert_dot(node, **kwargs):
     transpose_a, transpose_b attributes."""
     name, input_nodes, attrs = get_inputs(node, kwargs)
 
-    trans_a_node = None
-    trans_b_node = None
+    trans_a_node = trans_b_node = None
 
     trans_a = get_boolean_attribute_value(attrs, "transpose_a")
     trans_b = get_boolean_attribute_value(attrs, "transpose_b")
 
     op_name = "transpose" + str(kwargs["idx"])
+    input_node_a = op_name + "_a"
+    input_node_b = op_name + "_b"
 
     if trans_a:
         trans_a_node = create_helper_trans_node(op_name, input_nodes[0], 'a')
-        input_node_a = op_name+"_a"
     if trans_b:
         trans_b_node = create_helper_trans_node(op_name, input_nodes[1], 'b')
-        input_node_b = op_name+"_b"
 
     matmul_node = onnx.helper.make_node(
         'MatMul',
@@ -617,13 +615,21 @@ def convert_exp(node, **kwargs):
     return create_basic_op_node('Exp', node, kwargs)
 
 @mx_op.register("_copy")
-def convert_identity(node, **kwargs):
+def convert_copy(node, **kwargs):
     """Map MXNet's _copy operator attributes to onnx's Identity operator
     and return the created node.
     """
     return create_basic_op_node('Identity', node, kwargs)
 
 
+@mx_op.register("identity")
+def convert_identity(node, **kwargs):
+    """Map MXNet's identity operator attributes to onnx's ConstantFill operator
+    and return the created node.
+    """
+    return create_basic_op_node('ConstantFill', node, kwargs)
+
+
 @mx_op.register("LeakyReLU")
 def convert_leakyrelu(node, **kwargs):
     """Map MXNet's LeakyReLU operator attributes to onnx's Elu/LeakyRelu/PRelu operators
@@ -681,7 +687,7 @@ def convert_softmax_output(node, **kwargs):
     """Map MXNet's SoftmaxOutput operator attributes to onnx's Softmax operator
     and return the created node.
     """
-    name, _, _ = get_inputs(node, kwargs)
+    name = node["name"]
 
     input1_idx = kwargs["index_lookup"][node["inputs"][0][0]]
     input1 = kwargs["proc_nodes"][input1_idx]
@@ -693,10 +699,38 @@ def convert_softmax_output(node, **kwargs):
         axis=1,
         name=name
     )
-
     return [softmax_node]
 
 
+@mx_op.register("LogisticRegressionOutput")
+def convert_logistic_regression_output(node, **kwargs):
+    """Map MXNet's SoftmaxOutput operator attributes to onnx's Softmax operator
+    and return the created node.
+    """
+    name = node["name"]
+    input1_idx = kwargs["index_lookup"][node["inputs"][0][0]]
+    input1 = kwargs["proc_nodes"][input1_idx]
+
+    sigmoid_node = onnx.helper.make_node(
+        "Sigmoid",
+        [input1.output[0]],
+        [name],
+        name=name
+    )
+    return [sigmoid_node]
+
+@mx_op.register("BlockGrad")
+def convert_blockgrad(node, **kwargs):
+    """ Skip operator  """
+    return create_basic_op_node('ConstantFill', node, kwargs)
+
+
+@mx_op.register("make_loss")
+def convert_makeloss(node, **kwargs):
+    """ Skip operator  """
+    return create_basic_op_node('ConstantFill', node, kwargs)
+
+
 @mx_op.register("Concat")
 def convert_concat(node, **kwargs):
     """Map MXNet's Concat operator attributes to onnx's Concat operator
@@ -843,6 +877,8 @@ def scalar_op_helper(node, op_name, **kwargs):
     """Helper function for scalar arithmetic operations"""
     name, input_nodes, attrs = get_inputs(node, kwargs)
 
+    from onnx import numpy_helper
+
     scalar_value = [float(attrs.get("scalar", 1))]
 
     initializer = kwargs["initializer"]
@@ -852,13 +888,19 @@ def scalar_op_helper(node, op_name, **kwargs):
     for i in initializer:
         if i.name == input_nodes[0]:
             if op_name == 'Mul':
-                new_initializer = onnx.numpy_helper.to_array(i) * scalar_value[0]
+                new_initializer = numpy_helper.to_array(i) * scalar_value[0]
             elif op_name == 'Sub':
-                new_initializer = onnx.numpy_helper.to_array(i) - scalar_value[0]
+                if name.startswith("_rminusscalar"):
+                    new_initializer = scalar_value[0] - numpy_helper.to_array(i)
+                else:
+                    new_initializer = numpy_helper.to_array(i) - scalar_value[0]
             elif op_name == 'Add':
-                new_initializer = onnx.numpy_helper.to_array(i) + scalar_value[0]
+                new_initializer = numpy_helper.to_array(i) + scalar_value[0]
             elif op_name == 'Div':
-                new_initializer = onnx.numpy_helper.to_array(i) / scalar_value[0]
+                if name.startswith("_rdivscalar"):
+                    new_initializer = scalar_value[0] / numpy_helper.to_array(i)
+                else:
+                    new_initializer = numpy_helper.to_array(i) / scalar_value[0]
             flag = False
             break
 
@@ -869,6 +911,7 @@ def scalar_op_helper(node, op_name, **kwargs):
         dims = np.shape(np_arr)
 
         scalar_op_name = "scalar_op" + str(kwargs["idx"])
+        # Convert scalar value into node
         tensor_node = onnx.helper.make_tensor_value_info(scalar_op_name, data_type, dims)
 
         initializer.append(
@@ -907,7 +950,6 @@ def scalar_op_helper(node, op_name, **kwargs):
         )
         return [tensor_node]
 
-# Convert scalar value into node and pass it as input to mul_node
 @mx_op.register("_mul_scalar")
 def convert_mul_scalar(node, **kwargs):
     """Map MXNet's _mul_scalar operator attributes to onnx's Mul operator.
@@ -916,8 +958,6 @@ def convert_mul_scalar(node, **kwargs):
     """
     return scalar_op_helper(node, 'Mul', **kwargs)
 
-
-# Convert scalar value into node and pass it as input to mul_node
 @mx_op.register("_minus_scalar")
 def convert_minus_scalar(node, **kwargs):
     """Map MXNet's _minus_scalar operator attributes to onnx's Minus operator.
@@ -926,8 +966,14 @@ def convert_minus_scalar(node, **kwargs):
     """
     return scalar_op_helper(node, 'Sub', **kwargs)
 
+@mx_op.register("_rminus_scalar")
+def convert_rminus_scalar(node, **kwargs):
+    """Map MXNet's _rminus_scalar operator attributes to onnx's Minus operator.
+    Creates a new node for the input scalar value, adds it to the initializer
+    and return multiple created nodes.
+    """
+    return scalar_op_helper(node, 'Sub', **kwargs)
 
-# Convert scalar value into node and pass it as input to mul_node
 @mx_op.register("_plus_scalar")
 def convert_add_scalar(node, **kwargs):
     """Map MXNet's _plus_scalar operator attributes to onnx's Add operator.
@@ -936,7 +982,6 @@ def convert_add_scalar(node, **kwargs):
     """
     return scalar_op_helper(node, 'Add', **kwargs)
 
-# Convert scalar value into node and pass it as input to mul_node
 @mx_op.register("_div_scalar")
 def convert_div_scalar(node, **kwargs):
     """Map MXNet's _div_scalar operator attributes to onnx's Div operator.
@@ -945,6 +990,14 @@ def convert_div_scalar(node, **kwargs):
     """
     return scalar_op_helper(node, 'Div', **kwargs)
 
+@mx_op.register("_rdiv_scalar")
+def convert_rdiv_scalar(node, **kwargs):
+    """Map MXNet's _rdiv_scalar operator attributes to onnx's Div operator.
+    Creates a new node for the input scalar value, adds it to the initializer
+    and return multiple created nodes.
+    """
+    return scalar_op_helper(node, 'Div', **kwargs)
+
 
 # Sorting and Searching
 @mx_op.register("argmax")
diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_model.py b/python/mxnet/contrib/onnx/mx2onnx/export_model.py
index e5158051d6f..f09f674e7d2 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/export_model.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/export_model.py
@@ -32,7 +32,7 @@
 from ._export_helper import load_module
 
 
-def export_model(sym, params, input_shape, input_type=np.float32,
+def export_model(sym, params, input_shape, input_type=np.float32, label_names=None, label_shapes=None,
                  onnx_file_path='model.onnx', verbose=False):
     """Exports the MXNet model file, passed as a parameter, into ONNX model.
     Accepts both symbol,parameter objects as well as json and params filepaths as input.
@@ -49,6 +49,10 @@ def export_model(sym, params, input_shape, input_type=np.float32,
         Input shape of the model e.g [(1,3,224,224)]
     input_type : data type
         Input data type e.g. np.float32
+    label_names : List of str
+        Optional list of label e.g. ['regression_label']
+    label_shapes : List of tuple
+        Optional a list of (name, shape) pairs e.g [('regression_label', (1,3,224,224))]
     onnx_file_path : str
         Path where to save the generated onnx file
     verbose : Boolean
@@ -75,11 +79,11 @@ def export_model(sym, params, input_shape, input_type=np.float32,
         sym_obj, params_obj = load_module(sym, params)
         onnx_graph = converter.create_onnx_graph_proto(sym_obj, params_obj, input_shape,
                                                        mapping.NP_TYPE_TO_TENSOR_TYPE[data_format],
-                                                       verbose=verbose)
+                                                       label_names, label_shapes, verbose=verbose)
     elif isinstance(sym, symbol.Symbol) and isinstance(params, dict):
         onnx_graph = converter.create_onnx_graph_proto(sym, params, input_shape,
                                                        mapping.NP_TYPE_TO_TENSOR_TYPE[data_format],
-                                                       verbose=verbose)
+                                                       label_names, label_shapes, verbose=verbose)
     else:
         raise ValueError("Input sym and params should either be files or objects")
 
diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
index b02d970f9c2..4a9dba07e34 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
@@ -96,7 +96,7 @@ def convert_layer(node, **kwargs):
         return convert_func(node, **kwargs)
 
     @staticmethod
-    def forward_pass(inputs, sym, arg_params, aux_params, output_label):
+    def forward_pass(inputs, sym, arg_params, aux_params, label_name):
         """Do a forward pass based on the sym and params to get the shape
         of the output using dummy data
 
@@ -120,7 +120,7 @@ def forward_pass(inputs, sym, arg_params, aux_params, output_label):
         # while running load_checkpoint which is not actually a graph input. So ignoring it here
         data_names = [graph_input for graph_input in sym.list_inputs()
                       if graph_input not in arg_params and graph_input not in aux_params
-                      and graph_input != output_label]
+                      and graph_input not in label_name]
 
         data_shapes = []
         # Adding extra dimension of batch_size 1 if the batch_size is different for multiple inputs.
@@ -144,9 +144,13 @@ def forward_pass(inputs, sym, arg_params, aux_params, output_label):
             data_forward.append(nd.array(val))
 
         test_mod.forward(io.DataBatch(data_forward))
-        result = test_mod.get_outputs()[0].asnumpy()
+        result = [i.asnumpy().shape for i in test_mod.get_outputs()]
 
-        return result.shape
+        result_shape = []
+        for idx, label in enumerate(label_name):
+            result_shape.append((label, result[idx]))
+
+        return result_shape
 
 
     @staticmethod
@@ -179,12 +183,12 @@ def split_params(sym, params):
 
 
     @staticmethod
-    def infer_output_shape(sym, params, in_shape, output_label):
+    def infer_output_shape(sym, params, in_shape, label_name):
         """Infer output shape by doing a forward pass using dummy inputs """
         # create dummy input
         inputs = [np.random.randn(*input_shape) for input_shape in in_shape]
         arg, aux = MXNetGraph.split_params(sym, params)
-        return MXNetGraph.forward_pass(inputs, sym, arg, aux, output_label)
+        return MXNetGraph.forward_pass(inputs, sym, arg, aux, label_name)
 
 
     @staticmethod
@@ -193,7 +197,18 @@ def convert_weights_to_numpy(weights_dict):
         return dict([(k.replace("arg:", "").replace("aux:", ""), v.asnumpy())
                      for k, v in weights_dict.items()])
 
-    def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False):
+    @staticmethod
+    def verify_provided_labels(data_names, data_shapes, name, throw):
+        """Check that input labels matches input data shape."""
+        actual = [x[0] for x in data_shapes]
+        if sorted(data_names) != sorted(actual):
+            msg = "Data provided by %s_shapes don't match names specified by %s_names (%s vs. %s)" % (
+                name, name, str(data_shapes), str(data_names))
+            if throw:
+                raise ValueError(msg)
+
+    def create_onnx_graph_proto(self, sym, params, in_shape, in_type,
+                                label_names=None, label_shapes=None, verbose=False):
         """Convert MXNet graph to ONNX graph
 
         Parameters
@@ -206,6 +221,10 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False)
             Input shape of the model e.g [(1,3,224,224)]
         in_type : data type
             Input data type e.g. np.float32
+        out_label : List of str
+            Optional list of output label names
+        out_shape : List of tuple
+            Optional output shape of the model e.g [(1,3,224,224)]
         verbose : Boolean
             If true will print logs of the model conversion
 
@@ -226,10 +245,17 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False)
         # name is "Softmax", this node will have a name "Softmax_label". Also, the new node
         # will always be second last node in the json graph.
         # Deriving the output_label name.
-        output_label = sym.get_internals()[len(sym.get_internals()) - 1].name + "_label"
+        output_suffix = '_output'
+        output_names = [o[:-len(output_suffix)] for o in sym.list_outputs() if o.endswith(output_suffix)]
+
+        if not label_names:
+            label_names = [output_name + '_label' for output_name in output_names]
 
         # Determine output shape
-        output_shape = MXNetGraph.infer_output_shape(sym, params, in_shape, output_label)
+        if not label_shapes:
+            label_shapes = MXNetGraph.infer_output_shape(sym, params, in_shape, label_names)
+        else:
+            MXNetGraph.verify_provided_labels(label_names, label_shapes, 'label', True)
 
         weights = MXNetGraph.convert_weights_to_numpy(params)
 
@@ -253,10 +279,9 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False)
             # in params dict
             if op == "null" and name not in params:
                 # Handling graph input
-
                 # Skipping output_label node, as this node is not part of graph
                 # Refer "output_label" assignment above for more details.
-                if name == output_label:
+                if name in label_names:
                     continue
                 converted = MXNetGraph.convert_layer(
                     node,
@@ -294,14 +319,15 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False)
                     # If converted node is NodeProto, add it in processed nodes list
                     elif isinstance(converted_node, NodeProto):
                         onnx_processed_nodes.append(converted_node)
-                        if idx == (len(mx_graph) - 1):
+                        if converted_node.name in output_names:
+                            label_shape = [i[1] for i in label_shapes if converted_node.name + "_label" == i[0]]
                             # If converted node doesnt have name, use it from output field
                             if not converted_node.name:
                                 onnx_processed_outputs.append(
                                     make_tensor_value_info(
                                         name=converted_node.output[0],
                                         elem_type=in_type,
-                                        shape=output_shape
+                                        shape=label_shape[0]
                                     )
                                 )
                             else:
@@ -309,7 +335,7 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False)
                                     make_tensor_value_info(
                                         name=converted_node.name,
                                         elem_type=in_type,
-                                        shape=output_shape
+                                        shape=label_shape[0]
                                     )
                                 )
                             if verbose:
@@ -327,12 +353,12 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False)
                     # refer "output_label" initialization above for more details.
                     # if extra node was added then prev_index to the last node is adjusted.
                     if idx == (len(mx_graph) - 1) and \
-                            mx_graph[len(mx_graph)-2]["name"] == output_label:
+                            mx_graph[len(mx_graph) - 2]["name"] in label_names:
                         prev_index = index_lookup[idx - 2]
                     else:
                         prev_index = index_lookup[idx - 1]
 
-                    index_lookup.append(prev_index+len(converted))
+                    index_lookup.append(prev_index + len(converted))
                 else:
                     index_lookup.append(len(converted) - 1)
             else:
diff --git a/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py b/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py
index f61910f838e..2097baabd53 100644
--- a/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py
+++ b/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py
@@ -44,6 +44,7 @@
 _convert_map = {
     # Generator Functions
     'Constant'          : identity,
+    'ConstantFill'      : identity,
     'RandomUniform'     : random_uniform,
     'RandomNormal'      : random_normal,
     'RandomUniformLike' : random_uniform,
diff --git a/tests/python-pytest/onnx/export/mxnet_export_test.py b/tests/python-pytest/onnx/export/mxnet_export_test.py
index 9f91369d667..bd3f82b75d6 100644
--- a/tests/python-pytest/onnx/export/mxnet_export_test.py
+++ b/tests/python-pytest/onnx/export/mxnet_export_test.py
@@ -118,7 +118,7 @@ def test_models(model_name, input_shape, output_shape):
     onnx_file = os.path.join(dir_path, new_model_name)
 
     logging.info("Translating converted model from mxnet to ONNX")
-    converted_model_path = onnx_mxnet.export_model(sym, params, [input_shape], np.float32, onnx_file)
+    converted_model_path = onnx_mxnet.export_model(sym, params, [input_shape], np.float32, onnx_file_path=onnx_file)
 
     sym, arg_params, aux_params = onnx_mxnet.import_model(converted_model_path)
 
@@ -166,7 +166,7 @@ def test_model_accuracy(model_name, input_shape):
 
     logging.info("Translating converted model from mxnet to ONNX")
     converted_model_path = onnx_mxnet.export_model(sym, params, [input_shape], np.float32,
-                                                   onnx_file)
+                                                   onnx_file_path=onnx_file)
 
     sym, arg_params, aux_params = onnx_mxnet.import_model(converted_model_path)
 
@@ -229,7 +229,7 @@ def test_square():
     params.update(args)
     params.update(auxs)
 
-    converted_model = onnx_mxnet.export_model(square, params, [np.shape(input1)], np.float32, "square.onnx")
+    converted_model = onnx_mxnet.export_model(square, params, [np.shape(input1)], np.float32, onnx_file_path="square.onnx")
 
     sym, arg_params, aux_params = onnx_mxnet.import_model(converted_model)
     result = forward_pass(sym, arg_params, aux_params, ['input1'], input1)


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to this message, please log in to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services