You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ta...@apache.org on 2020/08/03 14:22:10 UTC
[incubator-mxnet] branch v1.x updated: [v1.x Backport] Fix softmax,
logsoftmax failed on empty ndarray (#18602) (#18708)
This is an automated email from the ASF dual-hosted git repository.
taolv pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.x by this push:
new 73d3a7b [v1.x Backport] Fix softmax, logsoftmax failed on empty ndarray (#18602) (#18708)
73d3a7b is described below
commit 73d3a7bc9cef596b5cc8ba6d0e3cf21e1bc08fee
Author: bgawrych <ba...@intel.com>
AuthorDate: Mon Aug 3 16:20:49 2020 +0200
[v1.x Backport] Fix softmax, logsoftmax failed on empty ndarray (#18602) (#18708)
* [v1.x] Backport of fix npx.softmax for 0-sized inputs (#18158)
Co-authored-by: Hao Jin <hj...@gmail.com>
* Fix softmax, logsoftmax failed on empty ndarray (#18602)
* Fix failing empty array (log_)softmax
* Modify test for npx (log_)softmax
* Fix softmax, logsoftmax backward failed on empty ndarray (#18710)
Co-authored-by: Yiyan66 <57...@users.noreply.github.com>
Co-authored-by: Hao Jin <hj...@gmail.com>
Co-authored-by: Bart Gawrych <ga...@intel.com>
---
src/operator/nn/log_softmax.cc | 2 +
src/operator/nn/softmax-inl.h | 58 +++++++++++++++-------------
src/operator/nn/softmax.cc | 2 +
src/operator/numpy/np_boolean_mask_assign.cc | 6 ++-
tests/python/unittest/test_numpy_op.py | 57 +++++++++++++++++++++++++++
5 files changed, 96 insertions(+), 29 deletions(-)
diff --git a/src/operator/nn/log_softmax.cc b/src/operator/nn/log_softmax.cc
index 16324b5..28ae8cf 100644
--- a/src/operator/nn/log_softmax.cc
+++ b/src/operator/nn/log_softmax.cc
@@ -40,6 +40,7 @@ static void LogSoftmaxComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
+ if (inputs[0].shape().Size() == 0U) return;
const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
if (SupportMKLDNNLogSoftmax(param, inputs[0], outputs[0])) {
MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
@@ -57,6 +58,7 @@ static void LogSoftmaxGradComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
+ if (inputs[0].shape().Size() == 0U) return;
const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
if (SupportMKLDNNLogSoftmax(param, inputs[1], outputs[0])) {
MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h
index f8a3fe4..018d851 100644
--- a/src/operator/nn/softmax-inl.h
+++ b/src/operator/nn/softmax-inl.h
@@ -71,6 +71,7 @@ template<typename OP, bool negate, typename AType, typename DType, typename OTyp
inline void Softmax(Stream<cpu> *s, DType *in, OType *out, IType *length,
Shape<ndim> shape, int axis, const DType temperature) {
index_t M = shape[axis];
+ if (M == 0) return;
index_t N = shape.Size()/M;
Shape<ndim> stride = calc_stride(shape);
Shape<ndim> sshape = shape;
@@ -186,6 +187,7 @@ inline void SoftmaxGrad(Stream<cpu> *s, OType *out, OType *ograd,
DType *igrad, IType *length, Shape<ndim> shape,
int axis, const DType temperature) {
index_t M = shape[axis];
+ if (M == 0) return;
index_t N = shape.Size()/M;
Shape<ndim> stride = calc_stride(shape);
Shape<ndim> sshape = shape;
@@ -402,6 +404,7 @@ inline void Softmax(Stream<gpu> *s, DType *in, OType *out, IType *length,
const int x_bits = 7;
const int x_size = 1 << x_bits;
index_t M = shape[axis];
+ if (M == 0 || shape.Size() == 0) return;
index_t N = shape.Size()/M;
Shape<ndim> stride = calc_stride(shape);
Shape<ndim> sshape = shape;
@@ -555,6 +558,7 @@ inline void SoftmaxGrad(Stream<gpu> *s, OType *out, OType *ograd,
const int x_bits = 7;
const int x_size = 1 << x_bits;
index_t M = shape[axis];
+ if (M == 0 || shape.Size() == 0) return;
index_t N = shape.Size()/M;
Shape<ndim> stride = calc_stride(shape);
Shape<ndim> sshape = shape;
@@ -775,7 +779,7 @@ void SoftmaxCompute(const nnvm::NodeAttrs& attrs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mxnet_op;
- if (req[0] == kNullOp) return;
+ if (req[0] == kNullOp || inputs[0].Size() == 0U) return;
CHECK_NE(req[0], kAddTo);
const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
int axis = CheckAxis(param.axis, inputs[0].ndim());
@@ -798,35 +802,35 @@ void SoftmaxCompute(const nnvm::NodeAttrs& attrs,
type = inputs[1].type_flag_;
}
MXNET_INT32_INT64_TYPE_SWITCH(type, IType, {
- IType* mask_ptr = nullptr;
- if (param.use_length.value()) {
- mask_ptr = inputs[1].dptr<IType>();
+ IType* mask_ptr = nullptr;
+ if (param.use_length.value()) {
+ mask_ptr = inputs[1].dptr<IType>();
+ }
+ if (safe_acc) {
+ if (shape.ndim() == 2) {
+ Softmax<OP, negate, AType>(
+ ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
+ outputs[0].dptr<OType>(), mask_ptr, shape.get<2>(),
+ axis, static_cast<DType>(temperature));
+ } else {
+ Softmax<OP, negate, AType>(
+ ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
+ outputs[0].dptr<OType>(), mask_ptr, shape.get<3>(),
+ axis, static_cast<DType>(temperature));
}
- if (safe_acc) {
- if (shape.ndim() == 2) {
- Softmax<OP, negate, AType>(
- ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
- outputs[0].dptr<OType>(), mask_ptr, shape.get<2>(),
- axis, static_cast<DType>(temperature));
- } else {
- Softmax<OP, negate, AType>(
- ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
- outputs[0].dptr<OType>(), mask_ptr, shape.get<3>(),
- axis, static_cast<DType>(temperature));
- }
+ } else {
+ if (shape.ndim() == 2) {
+ Softmax<OP, negate, DType>(
+ ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
+ outputs[0].dptr<OType>(), mask_ptr, shape.get<2>(),
+ axis, static_cast<DType>(temperature));
} else {
- if (shape.ndim() == 2) {
- Softmax<OP, negate, DType>(
- ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
- outputs[0].dptr<OType>(), mask_ptr, shape.get<2>(),
- axis, static_cast<DType>(temperature));
- } else {
- Softmax<OP, negate, DType>(
- ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
- outputs[0].dptr<OType>(), mask_ptr, shape.get<3>(),
- axis, static_cast<DType>(temperature));
- }
+ Softmax<OP, negate, DType>(
+ ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
+ outputs[0].dptr<OType>(), mask_ptr, shape.get<3>(),
+ axis, static_cast<DType>(temperature));
}
+ }
});
});
});
diff --git a/src/operator/nn/softmax.cc b/src/operator/nn/softmax.cc
index 50cfc2f..9b28b71 100644
--- a/src/operator/nn/softmax.cc
+++ b/src/operator/nn/softmax.cc
@@ -41,6 +41,7 @@ static void SoftmaxComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
+ if (inputs[0].shape().Size() == 0U) return;
const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
if (SupportMKLDNNSoftmax(param, inputs[0], outputs[0])) {
MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
@@ -58,6 +59,7 @@ static void SoftmaxGradComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
+ if (inputs[0].shape().Size() == 0U) return;
const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
if (SupportMKLDNNSoftmax(param, inputs[1], outputs[0])) {
MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
diff --git a/src/operator/numpy/np_boolean_mask_assign.cc b/src/operator/numpy/np_boolean_mask_assign.cc
index ef7cce4..cc58b5a 100644
--- a/src/operator/numpy/np_boolean_mask_assign.cc
+++ b/src/operator/numpy/np_boolean_mask_assign.cc
@@ -221,10 +221,9 @@ void NumpyBooleanAssignForwardCPU(const nnvm::NodeAttrs& attrs,
// If there's no True in mask, return directly
if (valid_num == 0) return;
- const TShape& vshape = inputs[2].shape_;
-
if (inputs.size() == 3U) {
// tensor case
+ const TShape& vshape = inputs.at(2).shape_;
if (inputs[2].shape_.Size() != 1) {
auto vndim = vshape.ndim();
auto dndim = dshape.ndim();
@@ -254,6 +253,8 @@ void NumpyBooleanAssignForwardCPU(const nnvm::NodeAttrs& attrs,
}
if (inputs.size() == 3U) {
+ // tensor case
+ const TShape& vshape = inputs.at(2).shape_;
MSHADOW_TYPE_SWITCH_WITH_BOOL(data.type_flag_, DType, {
if (inputs[2].shape_.Size() == 1) {
Kernel<BooleanAssignCPUKernel<true>, cpu>::Launch(
@@ -269,6 +270,7 @@ void NumpyBooleanAssignForwardCPU(const nnvm::NodeAttrs& attrs,
}
});
} else {
+ // scalar case
CHECK(attrs.dict.find("value") != attrs.dict.end()) << "value needs be provided";
MSHADOW_TYPE_SWITCH_WITH_BOOL(data.type_flag_, DType, {
Kernel<BooleanAssignCPUKernel<true>, cpu>::Launch(
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 5d3c03e..97c7d86 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -1558,6 +1558,63 @@ def test_npx_batch_norm():
data_grad_req,
gamma_grad_req, beta_grad_req)
+@with_seed()
+@use_np
+def test_npx_softmax():
+ class TestSoftmax(HybridBlock):
+ def __init__(self, axis):
+ super(TestSoftmax, self).__init__()
+ self._axis = axis
+
+ def hybrid_forward(self, F, a):
+ return F.npx.softmax(a, axis=axis)
+
+ class TestLogSoftmax(HybridBlock):
+ def __init__(self, axis):
+ super(TestLogSoftmax, self).__init__()
+ self._axis = axis
+
+ def hybrid_forward(self, F, a):
+ return F.npx.log_softmax(a, axis=axis)
+
+ def np_softmax(x, axis=-1):
+ if (x.shape[axis] == 0):
+ return _np.sum(x, axis=axis, keepdims=True)
+ x = x - _np.max(x, axis=axis, keepdims=True)
+ x = _np.exp(x)
+ x /= _np.sum(x, axis=axis, keepdims=True)
+ return x
+
+ def np_log_softmax(x, axis=-1):
+ return _np.log(np_softmax(x, axis))
+
+ #(operator, function) tuples
+ tested_ops = [(TestSoftmax, np_softmax),
+ (TestLogSoftmax, np_log_softmax)]
+
+ # only testing 0-size shaped inputs here, other input cases have been tested in test_operator.py
+ for SoftmaxOp, softmax_function in tested_ops:
+ for hybridize in [True, False]:
+ for shape in [(3, 0, 4), (0, 0)]:
+ mx_a = np.random.uniform(size=shape)
+ mx_a.attach_grad()
+ for axis in range(-len(shape), len(shape)):
+ test_softmax_op = SoftmaxOp(axis)
+ if hybridize:
+ test_softmax_op.hybridize()
+
+ with mx.autograd.record():
+ mx_out = test_softmax_op(mx_a)
+
+ mx_out.wait_to_read()
+
+ np_out = softmax_function(mx_a.asnumpy(), axis)
+ assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, equal_nan=True)
+
+ mx_out.backward()
+ mx_a.grad.wait_to_read()
+ assert_almost_equal(mx_a.grad.asnumpy(), _np.zeros(shape), rtol=1e-3, atol=1e-5)
+
@with_seed()
@use_np