Posted to commits@mxnet.apache.org by ta...@apache.org on 2020/08/03 14:22:10 UTC

[incubator-mxnet] branch v1.x updated: [v1.x Backport] Fix softmax, logsoftmax failed on empty ndarray (#18602) (#18708)

This is an automated email from the ASF dual-hosted git repository.

taolv pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.x by this push:
     new 73d3a7b  [v1.x Backport] Fix softmax, logsoftmax failed on empty ndarray (#18602) (#18708)
73d3a7b is described below

commit 73d3a7bc9cef596b5cc8ba6d0e3cf21e1bc08fee
Author: bgawrych <ba...@intel.com>
AuthorDate: Mon Aug 3 16:20:49 2020 +0200

    [v1.x Backport] Fix softmax, logsoftmax failed on empty ndarray (#18602) (#18708)
    
    * [v1.x] Backport of fix npx.softmax for 0-sized inputs (#18158)
    
    Co-authored-by: Hao Jin <hj...@gmail.com>
    
    * Fix softmax, logsoftmax failed on empty ndarray (#18602)
    
    * Fix failing empty array (log_)softmax
    
    * Modify test for npx (log_)softmax
    
    * Fix softmax, logsoftmax backward failed on empty ndarray (#18710)
    
    Co-authored-by: Yiyan66 <57...@users.noreply.github.com>
    Co-authored-by: Hao Jin <hj...@gmail.com>
    Co-authored-by: Bart Gawrych <ga...@intel.com>
---
 src/operator/nn/log_softmax.cc               |  2 +
 src/operator/nn/softmax-inl.h                | 58 +++++++++++++++-------------
 src/operator/nn/softmax.cc                   |  2 +
 src/operator/numpy/np_boolean_mask_assign.cc |  6 ++-
 tests/python/unittest/test_numpy_op.py       | 57 +++++++++++++++++++++++++++
 5 files changed, 96 insertions(+), 29 deletions(-)
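
For context, a minimal repro sketch (not part of the patch) of the forward-pass failure this backport addresses; it mirrors the 0-size shapes exercised by the new test in tests/python/unittest/test_numpy_op.py and assumes an MXNet build that already contains this change:

    # Illustrative only: softmax on a zero-size ndarray.
    from mxnet import np, npx
    npx.set_np()

    a = np.random.uniform(size=(3, 0, 4))  # no elements along axis 1
    out = npx.softmax(a, axis=1)           # used to fail on an empty input
    print(out.shape)                       # (3, 0, 4): an equally empty output

With this patch the compute paths return early whenever the input has zero elements, so the result is simply an empty ndarray of the same shape.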

diff --git a/src/operator/nn/log_softmax.cc b/src/operator/nn/log_softmax.cc
index 16324b5..28ae8cf 100644
--- a/src/operator/nn/log_softmax.cc
+++ b/src/operator/nn/log_softmax.cc
@@ -40,6 +40,7 @@ static void LogSoftmaxComputeExCPU(const nnvm::NodeAttrs& attrs,
                                    const std::vector<NDArray>& inputs,
                                    const std::vector<OpReqType>& req,
                                    const std::vector<NDArray>& outputs) {
+  if (inputs[0].shape().Size() == 0U) return;
   const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
   if (SupportMKLDNNLogSoftmax(param, inputs[0], outputs[0])) {
     MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
@@ -57,6 +58,7 @@ static void LogSoftmaxGradComputeExCPU(const nnvm::NodeAttrs& attrs,
                                        const std::vector<NDArray>& inputs,
                                        const std::vector<OpReqType>& req,
                                        const std::vector<NDArray>& outputs) {
+  if (inputs[0].shape().Size() == 0U) return;
   const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
   if (SupportMKLDNNLogSoftmax(param, inputs[1], outputs[0])) {
     MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h
index f8a3fe4..018d851 100644
--- a/src/operator/nn/softmax-inl.h
+++ b/src/operator/nn/softmax-inl.h
@@ -71,6 +71,7 @@ template<typename OP, bool negate, typename AType, typename DType, typename OTyp
 inline void Softmax(Stream<cpu> *s, DType *in, OType *out, IType *length,
                     Shape<ndim> shape, int axis, const DType temperature) {
   index_t M = shape[axis];
+  if (M == 0) return;
   index_t N = shape.Size()/M;
   Shape<ndim> stride = calc_stride(shape);
   Shape<ndim> sshape = shape;
@@ -186,6 +187,7 @@ inline void SoftmaxGrad(Stream<cpu> *s, OType *out, OType *ograd,
                         DType *igrad, IType *length, Shape<ndim> shape,
                         int axis, const DType temperature) {
   index_t M = shape[axis];
+  if (M == 0) return;
   index_t N = shape.Size()/M;
   Shape<ndim> stride = calc_stride(shape);
   Shape<ndim> sshape = shape;
@@ -402,6 +404,7 @@ inline void Softmax(Stream<gpu> *s, DType *in, OType *out, IType *length,
   const int x_bits = 7;
   const int x_size = 1 << x_bits;
   index_t M = shape[axis];
+  if (M == 0 || shape.Size() == 0) return;
   index_t N = shape.Size()/M;
   Shape<ndim> stride = calc_stride(shape);
   Shape<ndim> sshape = shape;
@@ -555,6 +558,7 @@ inline void SoftmaxGrad(Stream<gpu> *s, OType *out, OType *ograd,
   const int x_bits = 7;
   const int x_size = 1 << x_bits;
   index_t M = shape[axis];
+  if (M == 0 || shape.Size() == 0) return;
   index_t N = shape.Size()/M;
   Shape<ndim> stride = calc_stride(shape);
   Shape<ndim> sshape = shape;
@@ -775,7 +779,7 @@ void SoftmaxCompute(const nnvm::NodeAttrs& attrs,
                     const std::vector<OpReqType>& req,
                     const std::vector<TBlob>& outputs) {
   using namespace mxnet_op;
-  if (req[0] == kNullOp) return;
+  if (req[0] == kNullOp || inputs[0].Size() == 0U) return;
   CHECK_NE(req[0], kAddTo);
   const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
   int axis = CheckAxis(param.axis, inputs[0].ndim());
@@ -798,35 +802,35 @@ void SoftmaxCompute(const nnvm::NodeAttrs& attrs,
         type = inputs[1].type_flag_;
       }
       MXNET_INT32_INT64_TYPE_SWITCH(type, IType, {
-          IType* mask_ptr = nullptr;
-          if (param.use_length.value()) {
-            mask_ptr = inputs[1].dptr<IType>();
+        IType* mask_ptr = nullptr;
+        if (param.use_length.value()) {
+          mask_ptr = inputs[1].dptr<IType>();
+        }
+        if (safe_acc) {
+          if (shape.ndim() == 2) {
+            Softmax<OP, negate, AType>(
+              ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
+              outputs[0].dptr<OType>(), mask_ptr, shape.get<2>(),
+              axis, static_cast<DType>(temperature));
+          } else {
+            Softmax<OP, negate, AType>(
+              ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
+              outputs[0].dptr<OType>(), mask_ptr, shape.get<3>(),
+              axis, static_cast<DType>(temperature));
           }
-          if (safe_acc) {
-            if (shape.ndim() == 2) {
-              Softmax<OP, negate, AType>(
-                  ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
-                  outputs[0].dptr<OType>(), mask_ptr, shape.get<2>(),
-                  axis, static_cast<DType>(temperature));
-            } else {
-              Softmax<OP, negate, AType>(
-                  ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
-                  outputs[0].dptr<OType>(), mask_ptr, shape.get<3>(),
-                  axis, static_cast<DType>(temperature));
-            }
+        } else {
+          if (shape.ndim() == 2) {
+            Softmax<OP, negate, DType>(
+              ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
+              outputs[0].dptr<OType>(), mask_ptr, shape.get<2>(),
+              axis, static_cast<DType>(temperature));
           } else {
-            if (shape.ndim() == 2) {
-              Softmax<OP, negate, DType>(
-                  ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
-                  outputs[0].dptr<OType>(), mask_ptr, shape.get<2>(),
-                  axis, static_cast<DType>(temperature));
-            } else {
-              Softmax<OP, negate, DType>(
-                  ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
-                  outputs[0].dptr<OType>(), mask_ptr, shape.get<3>(),
-                  axis, static_cast<DType>(temperature));
-            }
+            Softmax<OP, negate, DType>(
+              ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
+              outputs[0].dptr<OType>(), mask_ptr, shape.get<3>(),
+              axis, static_cast<DType>(temperature));
           }
+        }
       });
     });
   });
diff --git a/src/operator/nn/softmax.cc b/src/operator/nn/softmax.cc
index 50cfc2f..9b28b71 100644
--- a/src/operator/nn/softmax.cc
+++ b/src/operator/nn/softmax.cc
@@ -41,6 +41,7 @@ static void SoftmaxComputeExCPU(const nnvm::NodeAttrs& attrs,
                                 const std::vector<NDArray>& inputs,
                                 const std::vector<OpReqType>& req,
                                 const std::vector<NDArray>& outputs) {
+  if (inputs[0].shape().Size() == 0U) return;
   const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
   if (SupportMKLDNNSoftmax(param, inputs[0], outputs[0])) {
     MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
@@ -58,6 +59,7 @@ static void SoftmaxGradComputeExCPU(const nnvm::NodeAttrs& attrs,
                                     const std::vector<NDArray>& inputs,
                                     const std::vector<OpReqType>& req,
                                     const std::vector<NDArray>& outputs) {
+  if (inputs[0].shape().Size() == 0U) return;
   const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
   if (SupportMKLDNNSoftmax(param, inputs[1], outputs[0])) {
     MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
diff --git a/src/operator/numpy/np_boolean_mask_assign.cc b/src/operator/numpy/np_boolean_mask_assign.cc
index ef7cce4..cc58b5a 100644
--- a/src/operator/numpy/np_boolean_mask_assign.cc
+++ b/src/operator/numpy/np_boolean_mask_assign.cc
@@ -221,10 +221,9 @@ void NumpyBooleanAssignForwardCPU(const nnvm::NodeAttrs& attrs,
   // If there's no True in mask, return directly
   if (valid_num == 0) return;
 
-  const TShape& vshape = inputs[2].shape_;
-
   if (inputs.size() == 3U) {
     // tensor case
+    const TShape& vshape = inputs.at(2).shape_;
     if (inputs[2].shape_.Size() != 1) {
       auto vndim = vshape.ndim();
       auto dndim = dshape.ndim();
@@ -254,6 +253,8 @@ void NumpyBooleanAssignForwardCPU(const nnvm::NodeAttrs& attrs,
   }
 
   if (inputs.size() == 3U) {
+    // tensor case
+    const TShape& vshape = inputs.at(2).shape_;
     MSHADOW_TYPE_SWITCH_WITH_BOOL(data.type_flag_, DType, {
       if (inputs[2].shape_.Size() == 1) {
         Kernel<BooleanAssignCPUKernel<true>, cpu>::Launch(
@@ -269,6 +270,7 @@ void NumpyBooleanAssignForwardCPU(const nnvm::NodeAttrs& attrs,
       }
     });
   } else {
+    // scalar case
     CHECK(attrs.dict.find("value") != attrs.dict.end()) << "value needs be provided";
     MSHADOW_TYPE_SWITCH_WITH_BOOL(data.type_flag_, DType, {
       Kernel<BooleanAssignCPUKernel<true>, cpu>::Launch(
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 5d3c03e..97c7d86 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -1558,6 +1558,63 @@ def test_npx_batch_norm():
                             data_grad_req,
                             gamma_grad_req, beta_grad_req)
 
+@with_seed()
+@use_np
+def test_npx_softmax():
+    class TestSoftmax(HybridBlock):
+        def __init__(self, axis):
+            super(TestSoftmax, self).__init__()
+            self._axis = axis
+
+        def hybrid_forward(self, F, a):
+            return F.npx.softmax(a, axis=self._axis)
+
+    class TestLogSoftmax(HybridBlock):
+        def __init__(self, axis):
+            super(TestLogSoftmax, self).__init__()
+            self._axis = axis
+
+        def hybrid_forward(self, F, a):
+            return F.npx.log_softmax(a, axis=self._axis)
+
+    def np_softmax(x, axis=-1):
+        if (x.shape[axis] == 0):
+            return _np.sum(x, axis=axis, keepdims=True)
+        x = x - _np.max(x, axis=axis, keepdims=True)
+        x = _np.exp(x)
+        x /= _np.sum(x, axis=axis, keepdims=True)
+        return x
+
+    def np_log_softmax(x, axis=-1):
+        return _np.log(np_softmax(x, axis))
+
+    # (operator, function) tuples
+    tested_ops = [(TestSoftmax, np_softmax),
+                  (TestLogSoftmax, np_log_softmax)]
+
+    # only testing 0-size inputs here; other input cases have been tested in test_operator.py
+    for SoftmaxOp, softmax_function in tested_ops:
+        for hybridize in [True, False]:
+            for shape in [(3, 0, 4), (0, 0)]:
+                mx_a = np.random.uniform(size=shape)
+                mx_a.attach_grad()
+                for axis in range(-len(shape), len(shape)):
+                    test_softmax_op = SoftmaxOp(axis)
+                    if hybridize:
+                        test_softmax_op.hybridize()
+
+                    with mx.autograd.record():
+                        mx_out = test_softmax_op(mx_a)
+
+                    mx_out.wait_to_read()
+
+                    np_out = softmax_function(mx_a.asnumpy(), axis)
+                    assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, equal_nan=True)
+
+                    mx_out.backward()
+                    mx_a.grad.wait_to_read()
+                    assert_almost_equal(mx_a.grad.asnumpy(), _np.zeros(shape), rtol=1e-3, atol=1e-5)
+
 
 @with_seed()
 @use_np
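
The backward path covered by #18710 can be exercised the same way; a sketch (again illustrative, not part of the commit) that follows the gradient check in the new test:

    import mxnet as mx
    from mxnet import np, npx
    npx.set_np()

    a = np.random.uniform(size=(0, 0))  # fully empty input, as in the new test
    a.attach_grad()
    with mx.autograd.record():
        out = npx.log_softmax(a, axis=-1)
    out.backward()                      # previously failed on an empty ndarray
    print(a.grad.shape)                 # (0, 0): the gradient is itself empty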