You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by sk...@apache.org on 2018/11/20 05:39:15 UTC
[incubator-mxnet] branch master updated: Fix ONNX export of keepdims param (#12924)
This is an automated email from the ASF dual-hosted git repository.
skm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new b8d1c02 Fix ONNX export of keepdims param (#12924)
b8d1c02 is described below
commit b8d1c02f79947968c5f06b89f15cb021115767dc
Author: Moritz Maxeiner <mo...@gmail.com>
AuthorDate: Tue Nov 20 06:38:56 2018 +0100
Fix ONNX export of keepdims param (#12924)
- keepdims is stored as a boolean string, not an int, so it is now read with
  get_boolean_attribute_value instead of being converted with int()
- Affected operators:
  - argmax
  - argmin
  - min
  - max
  - prod
  - mean
---
python/mxnet/contrib/onnx/mx2onnx/_op_translations.py | 12 ++++++------
1 file changed, 6 insertions(+), 6 deletions(-)
diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
index 7e84cea..73ca07b 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -955,7 +955,7 @@ def convert_argmax(node, **kwargs):
name, input_nodes, attrs = get_inputs(node, kwargs)
axis = int(attrs.get("axis"))
- keepdims = int(attrs.get("keepdims")) if "keepdims" in attrs else 1
+ keepdims = get_boolean_attribute_value(attrs, "keepdims")
node = onnx.helper.make_node(
'ArgMax',
@@ -975,7 +975,7 @@ def convert_argmin(node, **kwargs):
name, input_nodes, attrs = get_inputs(node, kwargs)
axis = int(attrs.get("axis"))
- keepdims = int(attrs.get("keepdims")) if "keepdims" in attrs else 1
+ keepdims = get_boolean_attribute_value(attrs, "keepdims")
node = onnx.helper.make_node(
'ArgMin',
@@ -1012,7 +1012,7 @@ def convert_min(node, **kwargs):
mx_axis = attrs.get("axis", None)
axes = convert_string_to_list(str(mx_axis)) if mx_axis is not None else None
- keepdims = int(attrs.get("keepdims", 0))
+ keepdims = get_boolean_attribute_value(attrs, "keepdims")
if axes is not None:
node = onnx.helper.make_node(
@@ -1047,7 +1047,7 @@ def convert_max(node, **kwargs):
mx_axis = attrs.get("axis", None)
axes = convert_string_to_list(str(mx_axis)) if mx_axis is not None else None
- keepdims = int(attrs.get("keepdims", 0))
+ keepdims = get_boolean_attribute_value(attrs, "keepdims")
if axes is not None:
node = onnx.helper.make_node(
@@ -1082,7 +1082,7 @@ def convert_mean(node, **kwargs):
mx_axis = attrs.get("axis", None)
axes = convert_string_to_list(str(mx_axis)) if mx_axis is not None else None
- keepdims = int(attrs.get("keepdims", 0))
+ keepdims = get_boolean_attribute_value(attrs, "keepdims")
if axes is not None:
node = onnx.helper.make_node(
@@ -1117,7 +1117,7 @@ def convert_prod(node, **kwargs):
mx_axis = attrs.get("axis", None)
axes = convert_string_to_list(str(mx_axis)) if mx_axis is not None else None
- keepdims = int(attrs.get("keepdims", 0))
+ keepdims = get_boolean_attribute_value(attrs, "keepdims")
if axes is not None:
node = onnx.helper.make_node(