You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/05/17 03:36:56 UTC

[GitHub] anirudh2290 closed pull request #10512: [MXNET-309] [ONNX-MXNet] Model Metadata API

anirudh2290 closed pull request #10512: [MXNET-309] [ONNX-MXNet] Model Metadata API
URL: https://github.com/apache/incubator-mxnet/pull/10512
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/docs/api/python/contrib/onnx.md b/docs/api/python/contrib/onnx.md
index 44aabaf4419..6fb546fc2b4 100644
--- a/docs/api/python/contrib/onnx.md
+++ b/docs/api/python/contrib/onnx.md
@@ -13,7 +13,7 @@ With ONNX format support for MXNet, developers can build and train models with a
 ```
 
 ### Installation Instructions
-- To use this module developers need to **install ONNX**, which requires protobuf compiler to be installed separately. Please follow the [instructions to install ONNX and its dependencies](https://github.com/onnx/onnx#installation). Once installed, you can go through the tutorials on how to use this module.
+- To use this module developers need to **install ONNX**, which requires the protobuf compiler to be installed separately. Please follow the [instructions to install ONNX and its dependencies](https://github.com/onnx/onnx#installation). **MXNet currently supports ONNX v1.1.1**. Once installed, you can go through the tutorials on how to use this module.
 
 
 This document describes all the ONNX-MXNet APIs.
@@ -23,6 +23,7 @@ This document describes all the ONNX-MXNet APIs.
     :nosignatures:
 
     mxnet.contrib.onnx.import_model
+    mxnet.contrib.onnx.get_model_metadata
 ```
 
 ## ONNX Tutorials
@@ -43,7 +44,8 @@ This document describes all the ONNX-MXNet APIs.
 ```eval_rst
 
 .. automodule:: mxnet.contrib.onnx
-    :members: import_model 
+    :members: import_model
+    :members: get_model_metadata
 
 ```
 
diff --git a/docs/tutorials/onnx/inference_on_onnx_model.md b/docs/tutorials/onnx/inference_on_onnx_model.md
index f342dad9bea..3d4072a5415 100644
--- a/docs/tutorials/onnx/inference_on_onnx_model.md
+++ b/docs/tutorials/onnx/inference_on_onnx_model.md
@@ -104,17 +104,26 @@ We pick a context, GPU if available, otherwise CPU
 ctx = mx.gpu() if mx.test_utils.list_gpus() else mx.cpu()
 ```
 
-We obtain the data names of the inputs to the model, by listing all the inputs to the symbol graph and excluding the argument and auxiliary parameters from that list:
+We obtain the data names of the inputs to the model by using the model metadata API:
 
 ```python
-data_names = [graph_input for graph_input in sym.list_inputs()
-                      if graph_input not in arg_params and graph_input not in aux_params]
-print(data_names)
+model_metadata = onnx_mxnet.get_model_metadata(onnx_path)
+print(model_metadata)
 ```
 
+```
+{'output_tensor_data': [(u'gpu_0/softmax_1', (1L, 1000L))],
+ 'input_tensor_data': [(u'gpu_0/data_0', (1L, 3L, 224L, 224L))]}
+```
 
-```['gpu_0/data_0']```
+```python
+data_names = [inputs[0] for inputs in model_metadata.get('input_tensor_data')]
+print(data_names)
+```
 
+```
+[u'gpu_0/data_0']
+```
 
 And load them into an MXNet Gluon symbol block.
 
diff --git a/example/onnx/super_resolution.py b/example/onnx/super_resolution.py
index a52f1a892a6..fcb8ccc88ed 100644
--- a/example/onnx/super_resolution.py
+++ b/example/onnx/super_resolution.py
@@ -55,10 +55,8 @@ def get_test_image():
 
 def perform_inference(sym, arg_params, aux_params, input_img, img_cb, img_cr):
     """Perform inference on image using mxnet"""
-    # To fetch the data names of the input to the model we list the inputs of the symbol graph
-    # and exclude the argument and auxiliary parameters from the list
-    data_names = [graph_input for graph_input in sym.list_inputs()
-                  if graph_input not in arg_params and graph_input not in aux_params]
+    metadata = onnx_mxnet.get_model_metadata('super_resolution.onnx')
+    data_names = [input_name[0] for input_name in metadata.get('input_tensor_data')]
     # create module
     mod = mx.mod.Module(symbol=sym, data_names=data_names, label_names=None)
     mod.bind(for_training=False, data_shapes=[(data_names[0], input_img.shape)])
diff --git a/python/mxnet/contrib/onnx/__init__.py b/python/mxnet/contrib/onnx/__init__.py
index 169ac673455..fb8488ca4f2 100644
--- a/python/mxnet/contrib/onnx/__init__.py
+++ b/python/mxnet/contrib/onnx/__init__.py
@@ -14,7 +14,6 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
 """Module for ONNX model format support for Apache MXNet."""
 
-from ._import.import_model import import_model
+from ._import.import_model import import_model, get_model_metadata
diff --git a/python/mxnet/contrib/onnx/_import/import_model.py b/python/mxnet/contrib/onnx/_import/import_model.py
index 1bd4b418bc3..4e4d7863755 100644
--- a/python/mxnet/contrib/onnx/_import/import_model.py
+++ b/python/mxnet/contrib/onnx/_import/import_model.py
@@ -52,3 +52,33 @@ def import_model(model_file):
     model_proto = onnx.load(model_file)
     sym, arg_params, aux_params = graph.from_onnx(model_proto.graph)
     return sym, arg_params, aux_params
+
+def get_model_metadata(model_file):
+    """
+    Returns the name and shape information of input and output tensors of the given ONNX model file.
+
+    Parameters
+    ----------
+    model_file : str
+        ONNX model file name
+
+    Returns
+    -------
+    model_metadata : dict
+        A dictionary object mapping various metadata to its corresponding value.
+        The dictionary will have the following template.
+        {
+            'input_tensor_data' : <list of tuples representing the shape of the input parameters>,
+            'output_tensor_data' : <list of tuples representing the shape of the output
+                                    of the model>
+        }
+    """
+    graph = GraphProto()
+    try:
+        import onnx
+    except ImportError:
+        raise ImportError("Onnx and protobuf need to be installed. "
+                          + "Instructions to install - https://github.com/onnx/onnx")
+    model_proto = onnx.load(model_file)
+    metadata = graph.get_graph_metadata(model_proto.graph)
+    return metadata
diff --git a/python/mxnet/contrib/onnx/_import/import_onnx.py b/python/mxnet/contrib/onnx/_import/import_onnx.py
index 5192c6f8a85..db233578ff9 100644
--- a/python/mxnet/contrib/onnx/_import/import_onnx.py
+++ b/python/mxnet/contrib/onnx/_import/import_onnx.py
@@ -132,6 +132,29 @@ def from_onnx(self, graph):
             out = out[0]
         return out, argDict, auxDict
 
+    def get_graph_metadata(self, graph):
+        """
+        Get the model metadata from a given onnx graph.
+        """
+        _params = set()
+        for tensor_vals in graph.initializer:
+            _params.add(tensor_vals.name)
+
+        input_data = []
+        for graph_input in graph.input:
+            if graph_input.name not in _params:
+                shape = [val.dim_value for val in graph_input.type.tensor_type.shape.dim]
+                input_data.append((graph_input.name, tuple(shape)))
+
+        output_data = []
+        for graph_out in graph.output:
+            shape = [val.dim_value for val in graph_out.type.tensor_type.shape.dim]
+            output_data.append((graph_out.name, tuple(shape)))
+        metadata = {'input_tensor_data' : input_data,
+                    'output_tensor_data' : output_data
+                   }
+        return metadata
+
     def _parse_array(self, tensor_proto):
         """Grab data in TensorProto and convert to numpy array."""
         try:
diff --git a/tests/python-pytest/onnx/onnx_test.py b/tests/python-pytest/onnx/onnx_test.py
index e75ef69eea4..b3718c9beb8 100644
--- a/tests/python-pytest/onnx/onnx_test.py
+++ b/tests/python-pytest/onnx/onnx_test.py
@@ -186,12 +186,17 @@ def test_bvlc_googlenet():
     model_path, inputs, outputs = get_test_files('bvlc_googlenet')
     logging.info("Translating Googlenet model from ONNX to Mxnet")
     sym, arg_params, aux_params = onnx_mxnet.import_model(model_path)
+    metadata = onnx_mxnet.get_model_metadata(model_path)
+    assert len(metadata) == 2
+    assert metadata.get('input_tensor_data')
+    assert metadata.get('input_tensor_data') == [(u'data_0', (1, 3, 224, 224))]
+    assert metadata.get('output_tensor_data')
+    assert metadata.get('output_tensor_data') == [(u'prob_1', (1, 1000))]
+    data_names = [input_name[0] for input_name in metadata.get('input_tensor_data')]
 
     # run test for each test file
     for input_data, output_data in zip(inputs, outputs):
         # create module
-        data_names = [graph_input for graph_input in sym.list_inputs()
-                      if graph_input not in arg_params and graph_input not in aux_params]
         mod = mx.mod.Module(symbol=sym, data_names=data_names, context=mx.cpu(), label_names=None)
         mod.bind(for_training=False, data_shapes=[(data_names[0], input_data.shape)], label_shapes=None)
         mod.set_params(arg_params=arg_params, aux_params=aux_params,
@@ -210,12 +215,17 @@ def test_bvlc_reference_caffenet():
     model_path, inputs, outputs = get_test_files('bvlc_reference_caffenet')
     logging.info("Translating Caffenet model from ONNX to Mxnet")
     sym, arg_params, aux_params = onnx_mxnet.import_model(model_path)
+    metadata = onnx_mxnet.get_model_metadata(model_path)
+    assert len(metadata) == 2
+    assert metadata.get('input_tensor_data')
+    assert metadata.get('input_tensor_data') == [(u'data_0', (1, 3, 224, 224))]
+    assert metadata.get('output_tensor_data')
+    assert metadata.get('output_tensor_data') == [(u'prob_1', (1, 1000))]
+    data_names = [input_name[0] for input_name in metadata.get('input_tensor_data')]
 
     # run test for each test file
     for input_data, output_data in zip(inputs, outputs):
         # create module
-        data_names = [graph_input for graph_input in sym.list_inputs()
-                      if graph_input not in arg_params and graph_input not in aux_params]
         mod = mx.mod.Module(symbol=sym, data_names=data_names, context=mx.cpu(), label_names=None)
         mod.bind(for_training=False, data_shapes=[(data_names[0], input_data.shape)], label_shapes=None)
         mod.set_params(arg_params=arg_params, aux_params=aux_params,
@@ -234,12 +244,17 @@ def test_bvlc_rcnn_ilsvrc13():
     model_path, inputs, outputs = get_test_files('bvlc_reference_rcnn_ilsvrc13')
     logging.info("Translating rcnn_ilsvrc13 model from ONNX to Mxnet")
     sym, arg_params, aux_params = onnx_mxnet.import_model(model_path)
+    metadata = onnx_mxnet.get_model_metadata(model_path)
+    assert len(metadata) == 2
+    assert metadata.get('input_tensor_data')
+    assert metadata.get('input_tensor_data') == [(u'data_0', (1, 3, 224, 224))]
+    assert metadata.get('output_tensor_data')
+    assert metadata.get('output_tensor_data') == [(u'fc-rcnn_1', (1, 200))]
+    data_names = [input_name[0] for input_name in metadata.get('input_tensor_data')]
 
     # run test for each test file
     for input_data, output_data in zip(inputs, outputs):
         # create module
-        data_names = [graph_input for graph_input in sym.list_inputs()
-                      if graph_input not in arg_params and graph_input not in aux_params]
         mod = mx.mod.Module(symbol=sym, data_names=data_names, context=mx.cpu(), label_names=None)
         mod.bind(for_training=False, data_shapes=[(data_names[0], input_data.shape)], label_shapes=None)
         mod.set_params(arg_params=arg_params, aux_params=aux_params,


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services