You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by sk...@apache.org on 2019/02/21 00:46:02 UTC

[incubator-mxnet] branch master updated: onnx broadcast ops fixes (#13604)

This is an automated email from the ASF dual-hosted git repository.

skm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 0398a7e  onnx broadcast ops fixes (#13604)
0398a7e is described below

commit 0398a7ea33f0a8da1f79b2792a7d79c6ebead73f
Author: Roshani Nagmote <ro...@gmail.com>
AuthorDate: Wed Feb 20 16:45:44 2019 -0800

    onnx broadcast ops fixes (#13604)
    
    * broadcasting fixes
    
    * fix
    
    * addressing comments
    
    * fix
    
    * fix
---
 .../mxnet/contrib/onnx/onnx2mx/_op_translations.py | 34 +++-------------------
 .../contrib/onnx/onnx2mx/_translation_utils.py     | 16 +++++++++-
 2 files changed, 19 insertions(+), 31 deletions(-)

diff --git a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
index a7cef76..1a8d2ce 100644
--- a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
@@ -62,48 +62,22 @@ def sample_multinomial(attrs, inputs, proto_obj):
     new_attrs['dtype'] = TENSOR_TYPE_TO_NP_TYPE[int(attrs.get('dtype', 6))]
     return 'sample_multinomial', new_attrs, inputs
 
-
 # Arithmetic Operations
 def add(attrs, inputs, proto_obj):
     """Adding two tensors"""
-    new_attr = {}
-    if 'broadcast' in attrs and attrs['broadcast'] == 1:
-        broadcast_axis = attrs['axis']
-        op_value = translation_utils._fix_broadcast('broadcast_add', inputs,
-                                                    broadcast_axis, proto_obj)
-        return op_value, new_attr, inputs
-    return 'broadcast_add', new_attr, inputs
+    return translation_utils.broadcast_arithmetic_helper(attrs, inputs, proto_obj, 'broadcast_add')
 
 def subtract(attrs, inputs, proto_obj):
     """Subtracting two tensors"""
-    new_attr = {}
-    if 'broadcast' in attrs and attrs['broadcast'] == 1:
-        broadcast_axis = attrs['axis']
-        op_value = translation_utils._fix_broadcast('broadcast_sub', inputs,
-                                                    broadcast_axis, proto_obj)
-        return op_value, new_attr, inputs
-    return 'broadcast_sub', new_attr, inputs
-
+    return translation_utils.broadcast_arithmetic_helper(attrs, inputs, proto_obj, 'broadcast_sub')
 
 def multiply(attrs, inputs, proto_obj):
     """Multiply two tensors"""
-    new_attr = {}
-    if 'broadcast' in attrs and attrs['broadcast'] == 1:
-        broadcast_axis = attrs['axis']
-        op_value = translation_utils._fix_broadcast('broadcast_mul', inputs,
-                                                    broadcast_axis, proto_obj)
-        return op_value, new_attr, inputs
-    return 'broadcast_mul', new_attr, inputs
+    return translation_utils.broadcast_arithmetic_helper(attrs, inputs, proto_obj, 'broadcast_mul')
 
 def divide(attrs, inputs, proto_obj):
     """Divide two tensors"""
-    new_attr = {}
-    if 'broadcast' in attrs and attrs['broadcast'] == 1:
-        broadcast_axis = attrs['axis']
-        op_value = translation_utils._fix_broadcast('broadcast_div', inputs,
-                                                    broadcast_axis, proto_obj)
-        return op_value, new_attr, inputs
-    return 'broadcast_div', new_attr, inputs
+    return translation_utils.broadcast_arithmetic_helper(attrs, inputs, proto_obj, 'broadcast_div')
 
 def mean(attrs, inputs, proto_obj):
     """Mean of all the input tensors."""
diff --git a/python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py b/python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py
index 6fd5266..0c67305 100644
--- a/python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py
+++ b/python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py
@@ -221,7 +221,7 @@ def get_input_shape(sym, proto_obj):
     model_input_shape = [data[1] for data  in proto_obj.model_metadata.get('input_tensor_data')]
     data_names = [data[0] for data  in proto_obj.model_metadata.get('input_tensor_data')]
 
-    #creating dummy inputs
+    # creating dummy inputs
     inputs = []
     for  in_shape in model_input_shape:
         inputs.append(nd.ones(shape=in_shape))
@@ -245,3 +245,17 @@ def get_input_shape(sym, proto_obj):
     result = mod.get_outputs()[0].asnumpy()
 
     return result.shape
+
+def broadcast_arithmetic_helper(attrs, inputs, proto_obj, current_op_name):
+    """Helper for broadcast arithmetic ops; returns (op, new_attrs, inputs)."""
+    new_attr = {}
+    op_names = ['batchnorm', 'convolution', 'deconvolution']
+    if 'broadcast' in attrs and attrs['broadcast'] == 1:
+        broadcast_axis = attrs['axis']
+        for op_name in op_names:
+            # only reshape the bias term when the first input was produced by a
+            # batchnorm / convolution / deconvolution node (matched by name prefix)
+            if inputs[0].name.startswith(op_name):
+                op_value = _fix_broadcast(current_op_name, inputs, broadcast_axis, proto_obj)
+                return op_value, new_attr, inputs
+    return current_op_name, new_attr, inputs