Posted to commits@mxnet.apache.org by zh...@apache.org on 2020/12/20 03:43:44 UTC

[incubator-mxnet] branch v1.x updated: [v1.x] ONNX fix softmax (#19691)

This is an automated email from the ASF dual-hosted git repository.

zha0q1 pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.x by this push:
     new 5fce08a  [v1.x] ONNX fix softmax (#19691)
5fce08a is described below

commit 5fce08a57dcbc968d8bb49c394e44642235b8288
Author: waytrue17 <52...@users.noreply.github.com>
AuthorDate: Sat Dec 19 19:41:48 2020 -0800

    [v1.x] ONNX fix softmax (#19691)
    
    * fix softmax
    
    * add test
    
    * fix typo
    
    * fix test shape
    
    * update test data type
    
    * add more tests
    
    * fix temperature
    
    * fix onnx2mx
    
    * remove temperature
    
    * update msg
    
    * update msg
    
    Co-authored-by: Wei Chu <we...@amazon.com>
---
 .../mxnet/contrib/onnx/mx2onnx/_op_translations.py | 82 +++++++++++++++++++---
 .../mxnet/contrib/onnx/onnx2mx/_op_translations.py |  2 +-
 tests/python-pytest/onnx/test_operators.py         | 19 ++++-
 3 files changed, 93 insertions(+), 10 deletions(-)

diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
index 6caee80..d301975 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -874,20 +874,86 @@ def convert_softmax(node, **kwargs):
     """Map MXNet's softmax operator attributes to onnx's Softmax operator
     and return the created node.
     """
+    from onnx.helper import make_node
+    from onnx import TensorProto
     name, input_nodes, attrs = get_inputs(node, kwargs)
 
     axis = int(attrs.get("axis", -1))
+    temperature = attrs.get("temperature", None)
+    if temperature and float(temperature) != 1.0:
+        raise NotImplementedError("Temperature is not supported for now.")
+    use_length = attrs.get("use_length", None)
+    input_type = kwargs["in_type"]
+    data = input_nodes[0]
 
-    softmax_node = onnx.helper.make_node(
-        "Softmax",
-        input_nodes,
-        [name],
-        axis=axis,
-        name=name
-    )
+    nodes = [
+        make_node("Exp", [data], [name+"_exp_out"]),
+        make_node("ReduceSum", [name+"_exp_out"], [name+"_rsum_out"], axes=[axis], keepdims=1)
+    ]
+    if len(input_nodes) == 1:
+        nodes += [
+            make_node("Div", [name+"_exp_out", name+"_rsum_out"], [name], name=name)
+        ]
+        return nodes
+    elif use_length == "True":
+        length = input_nodes[1]
 
-    return [softmax_node]
+        nodes += [
+            # const nodes
+            create_tensor([axis], name+"_axis", kwargs["initializer"]),
+            create_tensor([], name+"_void", kwargs["initializer"]),
+            create_tensor([0], name+"_0", kwargs["initializer"]),
+            create_tensor([1], name+"_1", kwargs["initializer"]),
+            create_const_scalar_node(name+'_-1_s', np.int64(-1), kwargs),
+            create_const_scalar_node(name+'_0_s', np.int64(0), kwargs),
+            create_const_scalar_node(name+'_1_s', np.int64(1), kwargs),
+            # cast data type
+            make_node("Cast", [length], [name+"_length"], to=int(TensorProto.INT64)),
+            make_node("Cast", [name+"_0"], [name+"_0_itype"], to=input_type),
+            make_node("Cast", [name+"_1"], [name+"_1_itype"], to=input_type),
+            # softmax output
+            make_node("Div", [name+"_exp_out", name+"_rsum_out"], [name+"_div1_out"]),
+            # update axis
+            make_node("Shape", [data], [name+"_shape0_out"]),
+            make_node("Shape", [name+"_shape0_out"], [name+"_in_dim"]),
+            make_node("Add", [name+"_in_dim", name+"_axis"], [name+"_dim+axis"]),
+            make_node("Less", [name+"_axis", name+"_0_s"], [name+"_less0_out"]),
+            make_node("Where", [name+"_less0_out", name+"_dim+axis", name+"_axis"], [name+"_final_axis"]),
+            # data mask
+            make_node("Add", [name+"_final_axis", name+"_1_s"], [name+"_final_axis+1"]),
+            make_node("Slice", [name+"_shape0_out", name+"_final_axis", name+"_final_axis+1"], [name+"_axis_dim"]),
+            make_node("Reshape", [name+"_axis_dim", name+"_void"], [name+"_axis_dim_s"]),
+            make_node("Range", [name+"_0_s", name+"_axis_dim_s", name+"_1_s"], [name+"_range0_out"]),
+            # one hot for axis
+            make_node("Reshape", [name+"_in_dim", name+"_void"], [name+"_in_dim_s"]),
+            make_node("Range", [name+"_0_s", name+"_in_dim_s", name+"_1_s"], [name+"_range1_out"]),
+            make_node("Equal", [name+"_range1_out", name+"_final_axis"], [name+"_equal_out"]),
+            make_node("Cast", [name+"_equal_out"], [name+"_one_hot"], to=int(TensorProto.INT64)),
+            # reshape data mask for less
+            make_node("Sub", [name+"_axis_dim_s", name+"_1_s"], [name+"_sub0_out"]),
+            make_node("Mul", [name+"_one_hot", name+"_sub0_out"], [name+"_mul0_out"]),
+            make_node("Add", [name+"_mul0_out", name+"_1_s"], [name+"_add0_out"]),
+            make_node('Reshape', [name+"_range0_out", name+"_add0_out"], [name+"_reshape0_out"]),
+            # reshape length for less
+            make_node("Mul", [name+"_one_hot", name+"_-1_s"], [name+"_mul1_out"]),
+            make_node("Add", [name+"_mul1_out", name+"_1_s"], [name+"_add1_out"]),
+            make_node("Sub", [name+"_shape0_out", name+"_1_s"], [name+"_sub1_out"]),
+            make_node("Mul", [name+"_add1_out", name+"_sub1_out"], [name+"_mul2_out"]),
+            make_node("Add", [name+"_mul2_out", name+"_1_s"], [name+"_add2_out"]),
+            make_node('Reshape', [name+"_length", name+"_add2_out"], [name+"_reshape1_out"]),
+            # mask output
+            make_node("Less", [name+"_reshape0_out", name+"_reshape1_out"], [name+"_less_out"]),
+            make_node("Cast", [name+"_less_out"], [name+"_mask"], to=input_type),
+            make_node("Mul", [name+"_div1_out", name+"_mask"], [name+"_mul3_out"]),
+            make_node("ReduceSum", [name+"_mul3_out"], [name+"_rsum1_out"], axes=[axis], keepdims=1),
+            make_node("Equal", [name+"_rsum1_out", name+"_0_itype"], [name+"_equal1_out"]),
+            make_node("Where", [name+"_equal1_out", name+"_1_itype", name+"_rsum1_out"], [name+"_where_out"]),
+            make_node("Div", [name+"_mul3_out", name+"_where_out"], [name], name=name)
+        ]
+        return nodes
 
+    else:
+        raise NotImplementedError("use_length must be true when both data and length are paased in.")
 
 # There's also mx.sym.softmax(), which doesn't do cross-entropy loss,
 # just softmax for inference - hence the name convert_softmax_output.
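For reference, the exported graph above builds a length-masked softmax out of
elementary ONNX ops (Exp, ReduceSum, Range, Less, Where, Div). The sketch below
is a minimal NumPy rendition of that computation, not part of the commit; the
helper name length_masked_softmax and its exact masking layout are illustrative
assumptions.

    import numpy as np

    def length_masked_softmax(data, length, axis=-1):
        # resolve a possibly negative axis against the data rank
        axis = axis % data.ndim
        exp = np.exp(data)
        # index grid along `axis`, broadcast against the per-slice lengths
        idx_shape = [1] * data.ndim
        idx_shape[axis] = data.shape[axis]
        idx = np.arange(data.shape[axis]).reshape(idx_shape)
        mask = idx < np.expand_dims(length, axis)
        masked = exp * mask
        denom = masked.sum(axis=axis, keepdims=True)
        # mirror the Equal/Where guard so an all-masked slice divides by 1, not 0
        denom = np.where(denom == 0, 1, denom)
        return masked / denom

    x = np.random.uniform(0, 1, (2, 3, 4)).astype(np.float32)
    lengths = np.array([[2, 0, 4], [0, 0, 0]])        # same shape as l3 in the new test below
    out = length_masked_softmax(x, lengths, axis=-1)  # slices with length 0 come out as all zeros

Positions at or past the given length contribute nothing to the normalization
and come out as zeros, which is what the Mul/ReduceSum/Where/Div tail of the
exported graph computes.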
diff --git a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
index 51fe418..69bec1d 100644
--- a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
@@ -314,7 +314,7 @@ def _selu(attrs, inputs, proto_obj):
 def softmax(attrs, inputs, proto_obj):
     """Softmax function."""
     if 'axis' not in attrs:
-        attrs = translation_utils._add_extra_attributes(attrs, {'axis': 1})
+        attrs = translation_utils._add_extra_attributes(attrs, {'axis': -1})
     return 'softmax', attrs, inputs
 
 def log_softmax(attrs, inputs, proto_obj):
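The importer change above switches the default softmax axis from 1 to -1 when
an ONNX model omits the attribute, matching MXNet's own softmax default of
axis=-1. The snippet below is a small NumPy check, not part of the commit,
showing that the two defaults are not interchangeable once the input has more
than two dimensions.

    import numpy as np

    def softmax(x, axis):
        # subtract the max along `axis` for numerical stability
        e = np.exp(x - x.max(axis=axis, keepdims=True))
        return e / e.sum(axis=axis, keepdims=True)

    x = np.random.uniform(0, 1, (2, 3, 4)).astype(np.float32)
    # normalizing over axis=1 and axis=-1 gives different results in general
    print(np.allclose(softmax(x, axis=1), softmax(x, axis=-1)))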
diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py
index ef4310f..057a279 100644
--- a/tests/python-pytest/onnx/test_operators.py
+++ b/tests/python-pytest/onnx/test_operators.py
@@ -51,7 +51,8 @@ def op_export_test(model_name, Model, inputs, tmp_path):
 
     def onnx_rt(onnx_file, inputs):
         sess = rt.InferenceSession(onnx_file)
-        input_dict = dict((sess.get_inputs()[i].name, inputs[i].asnumpy()) for i in range(len(inputs)))
+        dtype_0 = inputs[0].asnumpy().dtype
+        input_dict = dict((sess.get_inputs()[i].name, inputs[i].asnumpy().astype(dtype_0)) for i in range(len(inputs)))
         pred = sess.run(None, input_dict)[0]
         return pred
 
@@ -309,3 +310,19 @@ def test_onnx_export_cast(tmp_path, src_dtype, dst_dtype, shape):
     M = def_model('Cast', dtype=dst_dtype)
     x = mx.nd.ones(shape, dtype=src_dtype)
     op_export_test('Cast', M, [x], tmp_path)
+
+
+@pytest.mark.parametrize('dtype', ['float16', 'float32'])
+def test_onnx_export_softmax(tmp_path, dtype):
+    x = mx.nd.random.uniform(0, 1, (2, 3, 4), dtype=dtype)
+    M1 = def_model('softmax')
+    op_export_test('softmax_1', M1, [x], tmp_path)
+    M2 = def_model('softmax', use_length=True, axis=0)
+    l2 = mx.nd.array([[2,0,2,1],[1,1,2,1], [0,0,0,1]], dtype=int)
+    op_export_test('softmax_2', M2, [x, l2], tmp_path)
+    M3 = def_model('softmax', use_length=True, axis=-1)
+    l3 = mx.nd.array([[2,0,4],[0,0,0]], dtype=int)
+    op_export_test('softmax_3', M3, [x, l3], tmp_path)
+    M4 = def_model('softmax', use_length=True, axis=1)
+    l4 = mx.nd.array([[2,0,3,1],[0,1,0,0]], dtype=int)
+    op_export_test('softmax_4', M4, [x, l4], tmp_path)