You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink is not preserved in this plain-text version).
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/06/04 21:52:27 UTC

[GitHub] anirudh2290 closed pull request #11106: [ONNX] Added Unsqueeze operator import support

anirudh2290 closed pull request #11106: [ONNX] Added Unsqueeze operator import support
URL: https://github.com/apache/incubator-mxnet/pull/11106
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below, because GitHub would not otherwise display it:

diff --git a/python/mxnet/contrib/onnx/_import/import_helper.py b/python/mxnet/contrib/onnx/_import/import_helper.py
index 175c2fb6a00..c8d45216729 100644
--- a/python/mxnet/contrib/onnx/_import/import_helper.py
+++ b/python/mxnet/contrib/onnx/_import/import_helper.py
@@ -28,7 +28,7 @@
 from .op_translations import sigmoid, pad, relu, matrix_multiplication, batch_norm
 from .op_translations import dropout, local_response_norm, conv, deconv
 from .op_translations import reshape, cast, split, _slice, transpose, squeeze, flatten
-from .op_translations import reciprocal, squareroot, power, exponent, _log
+from .op_translations import reciprocal, squareroot, power, exponent, _log, unsqueeze
 from .op_translations import reduce_max, reduce_mean, reduce_min, reduce_sum
 from .op_translations import reduce_prod, avg_pooling, max_pooling
 from .op_translations import argmax, argmin, maximum, minimum
@@ -83,6 +83,7 @@
     'Slice'             : _slice,
     'Transpose'         : transpose,
     'Squeeze'           : squeeze,
+    'Unsqueeze'         : unsqueeze,
     'Flatten'           : flatten,
     #Powers
     'Reciprocal'        : reciprocal,
diff --git a/python/mxnet/contrib/onnx/_import/op_translations.py b/python/mxnet/contrib/onnx/_import/op_translations.py
index 5df9d913f11..e02cb0c2b62 100644
--- a/python/mxnet/contrib/onnx/_import/op_translations.py
+++ b/python/mxnet/contrib/onnx/_import/op_translations.py
@@ -399,6 +399,15 @@ def squeeze(attrs, inputs, proto_obj):
         mxnet_op = symbol.split(mxnet_op, axis=i-1, num_outputs=1, squeeze_axis=1)
     return mxnet_op, new_attrs, inputs
 
+def unsqueeze(attrs, inputs, cls):
+    """Inserts a new axis of size 1 into the array shape"""
+    # MXNet can only add one axis at a time.
+    mxnet_op = inputs[0]
+    for axis in attrs["axes"]:
+        mxnet_op = symbol.expand_dims(mxnet_op, axis=axis)
+
+    return mxnet_op, attrs, inputs
+
 
 def flatten(attrs, inputs, proto_obj):
     """Flattens the input array into a 2-D array by collapsing the higher dimensions."""
diff --git a/tests/python-pytest/onnx/import/test_cases.py b/tests/python-pytest/onnx/import/test_cases.py
index d408930970b..8e6dc443bba 100644
--- a/tests/python-pytest/onnx/import/test_cases.py
+++ b/tests/python-pytest/onnx/import/test_cases.py
@@ -41,6 +41,7 @@
     'test_reduce_mean',
     'test_reduce_prod',
     'test_squeeze',
+    'test_unsqueeze',
     'test_softmax_example',
     'test_softmax_large_number',
     'test_softmax_axis_2',


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services