You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by sk...@apache.org on 2018/12/29 16:39:58 UTC

[incubator-mxnet] branch master updated: ONNX import: Hardmax (#13717)

This is an automated email from the ASF dual-hosted git repository.

skm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 991bf3b  ONNX import: Hardmax (#13717)
991bf3b is described below

commit 991bf3b64f186295345a14f3ce4e6a8f364e8bed
Author: Vandana Kannan <va...@users.noreply.github.com>
AuthorDate: Sat Dec 29 08:39:42 2018 -0800

    ONNX import: Hardmax (#13717)
    
    * ONNX import: Hardmax
    
    * Fix lint errors
    
    * add github link for issue with reshape
---
 .../mxnet/contrib/onnx/onnx2mx/_import_helper.py   |  5 +++--
 .../mxnet/contrib/onnx/onnx2mx/_op_translations.py | 26 ++++++++++++++++++++++
 tests/python-pytest/onnx/test_cases.py             |  3 ++-
 3 files changed, 31 insertions(+), 3 deletions(-)

diff --git a/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py b/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py
index 2ceabae..5b33f9f 100644
--- a/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py
+++ b/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py
@@ -23,7 +23,7 @@ from ._op_translations import add, subtract, multiply, divide, absolute, negativ
 from ._op_translations import tanh, arccos, arcsin, arctan, _cos, _sin, _tan
 from ._op_translations import softplus, shape, gather, lp_pooling, size
 from ._op_translations import ceil, floor, hardsigmoid, global_lppooling
-from ._op_translations import concat
+from ._op_translations import concat, hardmax
 from ._op_translations import leaky_relu, _elu, _prelu, _selu, softmax, fully_connected
 from ._op_translations import global_avgpooling, global_maxpooling, linalg_gemm
 from ._op_translations import sigmoid, pad, relu, matrix_multiplication, batch_norm
@@ -144,5 +144,6 @@ _convert_map = {
     'HardSigmoid'       : hardsigmoid,
     'LpPool'            : lp_pooling,
     'DepthToSpace'      : depthtospace,
-    'SpaceToDepth'      : spacetodepth
+    'SpaceToDepth'      : spacetodepth,
+    'Hardmax'           : hardmax
 }
diff --git a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
index 7028325..ce0e0e5 100644
--- a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
@@ -714,3 +714,29 @@ def spacetodepth(attrs, inputs, proto_obj):
     new_attrs = translation_utils._fix_attribute_names(attrs, {'blocksize':'block_size'})
 
     return "space_to_depth", new_attrs, inputs
+
+
+def hardmax(attrs, inputs, proto_obj):
+    """Returns batched one-hot vectors marking the argmax over dims [axis:]."""
+    input_tensor_data = proto_obj.model_metadata.get('input_tensor_data')[0]
+    input_shape = input_tensor_data[1]  # NOTE(review): may not be inputs[0]'s shape
+
+    axis = int(attrs.get('axis', 1))
+    axis = axis if axis >= 0 else len(input_shape) + axis
+
+    if axis == len(input_shape) - 1:  # last axis: no reshape needed
+        amax = symbol.argmax(inputs[0], axis=-1)
+        one_hot = symbol.one_hot(amax, depth=input_shape[-1])
+        return one_hot, attrs, inputs
+
+    # reshape doesn't accept a symbolic tensor as its shape, so the
+    # 2-D shape is computed statically with np.prod. Switch to
+    # mx.sym.prod() once mx.sym.reshape() supports this.
+    # (https://github.com/apache/incubator-mxnet/issues/10789)
+    new_shape = (int(np.prod(input_shape[:axis])),
+                 int(np.prod(input_shape[axis:])))
+    reshape_op = symbol.reshape(inputs[0], new_shape)
+    amax = symbol.argmax(reshape_op, axis=-1)
+    one_hot = symbol.one_hot(amax, depth=new_shape[-1])
+    hardmax_op = symbol.reshape(one_hot, input_shape)  # restore original shape
+    return hardmax_op, attrs, inputs
diff --git a/tests/python-pytest/onnx/test_cases.py b/tests/python-pytest/onnx/test_cases.py
index 92e80e0..6a189b6 100644
--- a/tests/python-pytest/onnx/test_cases.py
+++ b/tests/python-pytest/onnx/test_cases.py
@@ -90,7 +90,8 @@ IMPLEMENTED_OPERATORS_TEST = {
                'test_averagepool_2d_strides',
                'test_averagepool_3d',
                'test_LpPool_',
-               'test_split_equal'
+               'test_split_equal',
+               'test_hardmax'
                ],
     'export': ['test_random_uniform',
                'test_random_normal',