You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2021/03/12 03:33:51 UTC

[GitHub] [incubator-mxnet] Zha0q1 commented on a change in pull request #20001: Onnx Dynamic Shapes Support

Zha0q1 commented on a change in pull request #20001:
URL: https://github.com/apache/incubator-mxnet/pull/20001#discussion_r592887059



##########
File path: python/mxnet/contrib/onnx/mx2onnx/export_model.py
##########
@@ -62,42 +74,57 @@ def export_model(sym, params, input_shape, input_type=np.float32,
     """
 
     try:
-        from onnx import helper, mapping
+        from onnx import helper, mapping, shape_inference
         from onnx.defs import onnx_opset_version
     except ImportError:
         raise ImportError("Onnx and protobuf need to be installed. "
                           + "Instructions to install - https://github.com/onnx/onnx")
 
+    if input_type is not None:
+        in_types = input_type
+
+    if input_shape is not None:
+        in_shapes = input_shape
+
     converter = MXNetGraph()
     if opset_version is None:
         # default is to use latest opset version the onnx package supports
         opset_version = onnx_opset_version()
 
-    if not isinstance(input_type, list):
-        input_type = [input_type for _ in range(len(input_shape))]
-    input_dtype = [mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(inp_type)] for inp_type in input_type]
+    if not isinstance(in_types, list):
+        in_types = [in_types for _ in range(len(in_shapes))]
+    in_types_t = [mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(i_t)] for i_t in in_types]
     # if input parameters are strings(file paths), load files and create symbol parameter objects
     if isinstance(sym, string_types) and isinstance(params, string_types):
         logging.info("Converting json and weight file to sym and params")
         sym_obj, params_obj = load_module(sym, params)
-        onnx_graph = converter.create_onnx_graph_proto(sym_obj, params_obj, input_shape,
-                                                       input_dtype,
-                                                       verbose=verbose, opset_version=opset_version)
+        onnx_graph = converter.create_onnx_graph_proto(sym_obj, params_obj, in_shapes,
+                                                       in_types_t,
+                                                       verbose=verbose, opset_version=opset_version,
+                                                       dynamic=dynamic, dynamic_input_shapes=dynamic_input_shapes)
     elif isinstance(sym, symbol.Symbol) and isinstance(params, dict):
-        onnx_graph = converter.create_onnx_graph_proto(sym, params, input_shape,
-                                                       input_dtype,
-                                                       verbose=verbose, opset_version=opset_version)
+        onnx_graph = converter.create_onnx_graph_proto(sym, params, in_shapes,
+                                                       in_types_t,
+                                                       verbose=verbose, opset_version=opset_version,
+                                                       dynamic=dynamic, dynamic_input_shapes=dynamic_input_shapes)
     else:
         raise ValueError("Input sym and params should either be files or objects")
 
     # Create the model (ModelProto)
     onnx_model = helper.make_model(onnx_graph)
 
+    # Run shape inference on the model. Due to ONNX bug/incompatibility this may or may not crash

Review comment:
       This is an optional step. Running shape inference may help with some runtime optimizations, and it lets us visualize the graph better, since the nodes will have their input and output shapes labeled.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org