You are viewing a plain text version of this content. The canonical link for it was a hyperlink in the original page and is not preserved in this plain-text version.
Posted to commits@mxnet.apache.org by sk...@apache.org on 2018/11/20 05:39:15 UTC

[incubator-mxnet] branch master updated: Fix ONNX export of keepdims param (#12924)

This is an automated email from the ASF dual-hosted git repository.

skm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new b8d1c02  Fix ONNX export of keepdims param (#12924)
b8d1c02 is described below

commit b8d1c02f79947968c5f06b89f15cb021115767dc
Author: Moritz Maxeiner <mo...@gmail.com>
AuthorDate: Tue Nov 20 06:38:56 2018 +0100

    Fix ONNX export of keepdims param (#12924)
    
    - keepdims is stored as a boolean string, not an int, so the previous int()-based parsing of it was wrong
    - Operators:
      - argmax
      - argmin
      - min
      - max
      - prod
      - mean
---
 python/mxnet/contrib/onnx/mx2onnx/_op_translations.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
index 7e84cea..73ca07b 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -955,7 +955,7 @@ def convert_argmax(node, **kwargs):
     name, input_nodes, attrs = get_inputs(node, kwargs)
 
     axis = int(attrs.get("axis"))
-    keepdims = int(attrs.get("keepdims")) if "keepdims" in attrs  else 1
+    keepdims = get_boolean_attribute_value(attrs, "keepdims")
 
     node = onnx.helper.make_node(
         'ArgMax',
@@ -975,7 +975,7 @@ def convert_argmin(node, **kwargs):
     name, input_nodes, attrs = get_inputs(node, kwargs)
 
     axis = int(attrs.get("axis"))
-    keepdims = int(attrs.get("keepdims")) if "keepdims" in attrs  else 1
+    keepdims = get_boolean_attribute_value(attrs, "keepdims")
 
     node = onnx.helper.make_node(
         'ArgMin',
@@ -1012,7 +1012,7 @@ def convert_min(node, **kwargs):
     mx_axis = attrs.get("axis", None)
     axes = convert_string_to_list(str(mx_axis)) if mx_axis is not None else None
 
-    keepdims = int(attrs.get("keepdims", 0))
+    keepdims = get_boolean_attribute_value(attrs, "keepdims")
 
     if axes is not None:
         node = onnx.helper.make_node(
@@ -1047,7 +1047,7 @@ def convert_max(node, **kwargs):
     mx_axis = attrs.get("axis", None)
     axes = convert_string_to_list(str(mx_axis)) if mx_axis is not None else None
 
-    keepdims = int(attrs.get("keepdims", 0))
+    keepdims = get_boolean_attribute_value(attrs, "keepdims")
 
     if axes is not None:
         node = onnx.helper.make_node(
@@ -1082,7 +1082,7 @@ def convert_mean(node, **kwargs):
     mx_axis = attrs.get("axis", None)
     axes = convert_string_to_list(str(mx_axis)) if mx_axis is not None else None
 
-    keepdims = int(attrs.get("keepdims", 0))
+    keepdims = get_boolean_attribute_value(attrs, "keepdims")
 
     if axes is not None:
         node = onnx.helper.make_node(
@@ -1117,7 +1117,7 @@ def convert_prod(node, **kwargs):
     mx_axis = attrs.get("axis", None)
     axes = convert_string_to_list(str(mx_axis)) if mx_axis is not None else None
 
-    keepdims = int(attrs.get("keepdims", 0))
+    keepdims = get_boolean_attribute_value(attrs, "keepdims")
 
     if axes is not None:
         node = onnx.helper.make_node(