You are viewing a plain-text version of this content. The canonical link to it was a hyperlink that is not preserved in this plain-text version.
Posted to commits@mxnet.apache.org by sk...@apache.org on 2019/02/21 00:46:02 UTC
[incubator-mxnet] branch master updated: onnx broadcast ops fixes (#13604)
This is an automated email from the ASF dual-hosted git repository.
skm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 0398a7e onnx broadcast ops fixes (#13604)
0398a7e is described below
commit 0398a7ea33f0a8da1f79b2792a7d79c6ebead73f
Author: Roshani Nagmote <ro...@gmail.com>
AuthorDate: Wed Feb 20 16:45:44 2019 -0800
onnx broadcast ops fixes (#13604)
* broadcasting fixes
* fix
* addressing comments
* fix
* fix
---
.../mxnet/contrib/onnx/onnx2mx/_op_translations.py | 34 +++-------------------
.../contrib/onnx/onnx2mx/_translation_utils.py | 16 +++++++++-
2 files changed, 19 insertions(+), 31 deletions(-)
diff --git a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
index a7cef76..1a8d2ce 100644
--- a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
@@ -62,48 +62,22 @@ def sample_multinomial(attrs, inputs, proto_obj):
new_attrs['dtype'] = TENSOR_TYPE_TO_NP_TYPE[int(attrs.get('dtype', 6))]
return 'sample_multinomial', new_attrs, inputs
-
# Arithmetic Operations
def add(attrs, inputs, proto_obj):
"""Adding two tensors"""
- new_attr = {}
- if 'broadcast' in attrs and attrs['broadcast'] == 1:
- broadcast_axis = attrs['axis']
- op_value = translation_utils._fix_broadcast('broadcast_add', inputs,
- broadcast_axis, proto_obj)
- return op_value, new_attr, inputs
- return 'broadcast_add', new_attr, inputs
+ return translation_utils.broadcast_arithmetic_helper(attrs, inputs, proto_obj, 'broadcast_add')
def subtract(attrs, inputs, proto_obj):
"""Subtracting two tensors"""
- new_attr = {}
- if 'broadcast' in attrs and attrs['broadcast'] == 1:
- broadcast_axis = attrs['axis']
- op_value = translation_utils._fix_broadcast('broadcast_sub', inputs,
- broadcast_axis, proto_obj)
- return op_value, new_attr, inputs
- return 'broadcast_sub', new_attr, inputs
-
+ return translation_utils.broadcast_arithmetic_helper(attrs, inputs, proto_obj, 'broadcast_sub')
def multiply(attrs, inputs, proto_obj):
"""Multiply two tensors"""
- new_attr = {}
- if 'broadcast' in attrs and attrs['broadcast'] == 1:
- broadcast_axis = attrs['axis']
- op_value = translation_utils._fix_broadcast('broadcast_mul', inputs,
- broadcast_axis, proto_obj)
- return op_value, new_attr, inputs
- return 'broadcast_mul', new_attr, inputs
+ return translation_utils.broadcast_arithmetic_helper(attrs, inputs, proto_obj, 'broadcast_mul')
def divide(attrs, inputs, proto_obj):
"""Divide two tensors"""
- new_attr = {}
- if 'broadcast' in attrs and attrs['broadcast'] == 1:
- broadcast_axis = attrs['axis']
- op_value = translation_utils._fix_broadcast('broadcast_div', inputs,
- broadcast_axis, proto_obj)
- return op_value, new_attr, inputs
- return 'broadcast_div', new_attr, inputs
+ return translation_utils.broadcast_arithmetic_helper(attrs, inputs, proto_obj, 'broadcast_div')
def mean(attrs, inputs, proto_obj):
"""Mean of all the input tensors."""
diff --git a/python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py b/python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py
index 6fd5266..0c67305 100644
--- a/python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py
+++ b/python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py
@@ -221,7 +221,7 @@ def get_input_shape(sym, proto_obj):
model_input_shape = [data[1] for data in proto_obj.model_metadata.get('input_tensor_data')]
data_names = [data[0] for data in proto_obj.model_metadata.get('input_tensor_data')]
- #creating dummy inputs
+ # creating dummy inputs
inputs = []
for in_shape in model_input_shape:
inputs.append(nd.ones(shape=in_shape))
@@ -245,3 +245,17 @@ def get_input_shape(sym, proto_obj):
result = mod.get_outputs()[0].asnumpy()
return result.shape
+
+def broadcast_arithmetic_helper(attrs, inputs, proto_obj, current_op_name):
+ """Helper function for broadcast arithmetic ops."""
+ new_attr = {}
+ op_names = ['batchnorm, convolution, deconvolution']
+ if 'broadcast' in attrs and attrs['broadcast'] == 1:
+ broadcast_axis = attrs['axis']
+ for op_name in op_names:
+ # if input is bias which comes after conv, deconv, batchnorm operators
+ # then only reshape bias term
+ if inputs[0].name.startswith(op_name):
+ op_value = _fix_broadcast(current_op_name, inputs, broadcast_axis, proto_obj)
+ return op_value, new_attr, inputs
+ return current_op_name, new_attr, inputs