Posted to commits@mxnet.apache.org by ha...@apache.org on 2019/07/19 17:43:36 UTC
[incubator-mxnet] branch master updated: Softmax with length (#15169)
This is an automated email from the ASF dual-hosted git repository.
haoj pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 076b2f3 Softmax with length (#15169)
076b2f3 is described below
commit 076b2f330c60f05cb939beea28dd04cd571a34c0
Author: Hao Jin <hj...@gmail.com>
AuthorDate: Fri Jul 19 10:43:10 2019 -0700
Softmax with length (#15169)
* softmax with length forward
* softmax with length backward
* new macro to reduce compile-time heap usage and limit length to integers only
* address comments
---
src/operator/mxnet_op.h | 51 ++++
src/operator/nn/softmax-inl.h | 428 ++++++++++++++++++++++++++++-----
src/operator/nn/softmax.cc | 33 ++-
tests/python/unittest/test_operator.py | 39 ++-
4 files changed, 487 insertions(+), 64 deletions(-)
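For reference, a minimal usage sketch of the new use_length option, mirroring the Python test added at the end of this diff (shapes and the int32 length dtype are illustrative; the executor calls assume the standard 1.x Symbol API):

    import mxnet as mx
    import numpy as np

    data = mx.sym.Variable("data")
    length = mx.sym.Variable("length")
    # Softmax over axis 1; positions at or beyond length are written as 0.
    sm = mx.sym.softmax(data=data, length=length, use_length=True, axis=1)

    # data: (batch, seq_len, n); length: (batch, n), integer dtype.
    ex = sm.bind(mx.cpu(), {"data": mx.nd.random.uniform(shape=(2, 5, 3)),
                            "length": mx.nd.array(np.random.randint(1, 6, (2, 3)),
                                                  dtype=np.int32)})
    out = ex.forward()[0]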
diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h
index f17b708..52788f6 100644
--- a/src/operator/mxnet_op.h
+++ b/src/operator/mxnet_op.h
@@ -363,6 +363,57 @@ inline int get_num_threads<cpu>(const int N) {
LOG(FATAL) << "Unknown type enum " << type; \
}
+#define MXNET_INT_TYPE_SWITCH(type, DType, ...)\
+ switch (type) { \
+ case mshadow::kFloat32: \
+ { \
+ typedef float DType; \
+ LOG(FATAL) << "This operation only support " \
+ "integer types, not float32"; \
+ } \
+ break; \
+ case mshadow::kFloat64: \
+ { \
+ typedef double DType; \
+ LOG(FATAL) << "This operation only support " \
+ "integer types, not float64"; \
+ } \
+ break; \
+ case mshadow::kFloat16: \
+ { \
+ typedef mshadow::half::half_t DType; \
+ LOG(FATAL) << "This operation only support " \
+ "integer types, not float16"; \
+ } \
+ break; \
+ case mshadow::kUint8: \
+ { \
+ typedef uint8_t DType; \
+ {__VA_ARGS__} \
+ } \
+ break; \
+ case mshadow::kInt8: \
+ { \
+ typedef int8_t DType; \
+ {__VA_ARGS__} \
+ } \
+ break; \
+ case mshadow::kInt32: \
+ { \
+ typedef int32_t DType; \
+ {__VA_ARGS__} \
+ } \
+ break; \
+ case mshadow::kInt64: \
+ { \
+ typedef int64_t DType; \
+ {__VA_ARGS__} \
+ } \
+ break; \
+ default: \
+ LOG(FATAL) << "Unknown type enum " << type; \
+ }
+
/*!
* \brief assign the val to out according
* to request in Kernel::Launch
diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h
index d6113b0..2c82d83 100644
--- a/src/operator/nn/softmax-inl.h
+++ b/src/operator/nn/softmax-inl.h
@@ -75,7 +75,7 @@ inline void Softmax(Stream<cpu> *s, DType *in, OType *out,
index_t sa = stride[axis];
#pragma omp parallel for
- for (int i = 0; i < static_cast<int>(N); ++i) {
+ for (index_t i = 0; i < N; ++i) {
index_t base = unravel_dot(i, sshape, stride);
DType mmax = negate ? -in[base] : in[base];
@@ -113,6 +113,60 @@ inline void Softmax(Stream<cpu> *s, DType *in, OType *out,
}
}
+template<typename OP, bool negate, typename AType, typename DType, typename OType,
+ typename IType, int ndim>
+inline void SoftmaxWithLength(Stream<cpu> *s, DType *in, OType *out, IType *length,
+ Shape<ndim> shape, int axis, const DType temperature) {
+ index_t M = shape[axis];
+ index_t N = shape.Size()/M;
+ Shape<ndim> stride = calc_stride(shape);
+ Shape<ndim> sshape = shape;
+ sshape[axis] = 1;
+ index_t sa = stride[axis];
+
+ #pragma omp parallel for
+ for (index_t i = 0; i < N; ++i) {
+ index_t len = static_cast<index_t>(length[i]);
+ index_t base = unravel_dot(i, sshape, stride);
+
+ DType mmax = negate ? -in[base] : in[base];
+ DType val;
+ for (index_t j = 1; j < len; ++j) {
+ val = negate ? -in[base + j*sa] : in[base + j*sa];
+ if (mmax < val) mmax = val;
+ }
+ for (index_t j = len; j < M; ++j) {
+ out[base + j*sa] = OType(0.0f);
+ }
+
+ AType sum = AType(0);
+ DType in_val;
+ // By default temperature is 1.0.
+ // Adding a branch here to save the CPU 'divide-by-1' computation at runtime
+ if (temperature == 1.0) {
+ for (index_t j = 0; j < len; ++j) {
+ in_val = negate ? -in[base + j*sa] : in[base + j*sa];
+ sum += std::exp(in_val - mmax);
+ }
+
+ for (index_t j = 0; j < len; ++j) {
+ in_val = negate ? -in[base + j*sa] : in[base + j*sa];
+ out[base + j*sa] = OP::Map(in_val - mmax, sum);
+ }
+ } else {
+ for (index_t j = 0; j < len; ++j) {
+ in_val = negate ? -in[base + j*sa] : in[base + j*sa];
+ sum += std::exp((in_val - mmax)/temperature);
+ }
+
+ for (index_t j = 0; j < len; ++j) {
+ in_val = negate ? -in[base + j*sa] : in[base + j*sa];
+ out[base + j*sa] = OP::Map((in_val - mmax)/temperature, sum);
+ }
+ }
+ }
+}
+
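The CPU path above finds the max and the normalizer only over the first length[i] entries of each row and writes zeros for everything at or beyond that length. As an aside, a NumPy sketch of the same computation for the 3-D, axis=1 layout used by the test at the end of this diff (plain softmax; temperature handled as in the loop above):

    import numpy as np

    def softmax_with_length(data, length, temperature=1.0):
        # data: (batch, seq_len, n); length: (batch, n) integers.
        # Softmax over axis 1, restricted to the first length[i, j] entries;
        # the remaining output entries stay 0.
        out = np.zeros_like(data)
        batch, seq_len, n = data.shape
        for i in range(batch):
            for j in range(n):
                k = int(length[i, j])
                x = (data[i, :k, j] - data[i, :k, j].max()) / temperature  # subtract max for stability
                e = np.exp(x)
                out[i, :k, j] = e / e.sum()
        return out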
struct softmax_bwd {
template<typename DType, typename AType>
@@ -136,7 +190,7 @@ struct log_softmax_bwd {
template<typename OP1, typename OP2, int Req, bool negate,
- typename AType, typename DType, typename OType, int ndim>
+ typename AType, typename DType, typename OType, int ndim>
inline void SoftmaxGrad(Stream<cpu> *s, OType *out, OType *ograd,
DType *igrad, Shape<ndim> shape, int axis,
const DType temperature) {
@@ -148,7 +202,7 @@ inline void SoftmaxGrad(Stream<cpu> *s, OType *out, OType *ograd,
index_t sa = stride[axis];
#pragma omp parallel for
- for (int i = 0; i < static_cast<int>(N); ++i) {
+ for (index_t i = 0; i < N; ++i) {
index_t base = unravel_dot(i, sshape, stride);
AType sum = AType(0);
@@ -177,10 +231,55 @@ inline void SoftmaxGrad(Stream<cpu> *s, OType *out, OType *ograd,
}
}
+template<typename OP1, typename OP2, int Req, bool negate,
+ typename AType, typename DType, typename OType, typename IType, int ndim>
+inline void SoftmaxWithLengthGrad(Stream<cpu> *s, OType *out, OType *ograd,
+ DType *igrad, IType *length, Shape<ndim> shape,
+ int axis, const DType temperature) {
+ index_t M = shape[axis];
+ index_t N = shape.Size()/M;
+ Shape<ndim> stride = calc_stride(shape);
+ Shape<ndim> sshape = shape;
+ sshape[axis] = 1;
+ index_t sa = stride[axis];
+
+ #pragma omp parallel for
+ for (index_t i = 0; i < N; ++i) {
+ index_t base = unravel_dot(i, sshape, stride);
+ index_t len = static_cast<index_t>(length[i]);
+
+ AType sum = AType(0);
+ for (index_t j = 0; j < len; ++j) {
+ sum += OP1::Map(ograd[base + j*sa], out[base + j*sa]);
+ }
+
+ // By default temperature is 1.0.
+ // Adding a branch here to save the CPU 'divide-by-1' computation at runtime
+ DType final_result;
+ if (temperature == 1.0) {
+ for (index_t j = 0; j < M; ++j) {
+ final_result = negate ?
+ -OP2::Map(ograd[base + j*sa], out[base + j*sa], sum) :
+ OP2::Map(ograd[base + j*sa], out[base + j*sa], sum);
+ final_result = (j < len) ? final_result : DType(0.0f);
+ KERNEL_ASSIGN(igrad[base + j*sa], Req, final_result);
+ }
+ } else {
+ for (index_t j = 0; j < M; ++j) {
+ final_result = negate ?
+ -OP2::Map(ograd[base + j*sa], out[base + j*sa], sum) / temperature :
+ OP2::Map(ograd[base + j*sa], out[base + j*sa], sum) / temperature;
+ final_result = (j < len) ? final_result : DType(0.0f);
+ KERNEL_ASSIGN(igrad[base + j*sa], Req, final_result);
+ }
+ }
+ }
+}
+
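The backward path mirrors this: the reduction runs only over the valid prefix, and positions at or beyond length[i] receive a zero gradient (the gradient with respect to the length input itself is zeroed out in SoftmaxGradCompute further down). A NumPy sketch for the plain-softmax case, where OP1 is an elementwise product and OP2 is softmax_bwd, i.e. out * (ograd - sum):

    import numpy as np

    def softmax_with_length_grad(out, ograd, length, temperature=1.0):
        # Gradient w.r.t. the data input; layout as in the forward sketch.
        # out, ograd: (batch, seq_len, n); length: (batch, n) integers.
        igrad = np.zeros_like(out)
        batch, seq_len, n = out.shape
        for i in range(batch):
            for j in range(n):
                k = int(length[i, j])
                s = np.sum(ograd[i, :k, j] * out[i, :k, j])  # OP1: ograd * out over the valid prefix
                # OP2 (softmax_bwd): out * (ograd - sum), divided by temperature
                igrad[i, :k, j] = out[i, :k, j] * (ograd[i, :k, j] - s) / temperature
                # positions >= k keep a zero gradient
        return igrad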
#ifdef __CUDACC__
template<int x_bits, typename OP, bool negate, typename AType, int ndim,
- typename DType, typename OType>
+ typename DType, typename OType>
__global__ void softmax_compute_kernel(DType *in, OType *out, index_t M, int axis,
Shape<ndim> sshape, Shape<ndim> stride,
const double temperature) {
@@ -235,9 +334,68 @@ inline void Softmax(Stream<gpu> *s, DType *in, OType *out,
MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_compute_kernel);
}
+template<int x_bits, typename OP, bool negate, typename AType, int ndim,
+ typename DType, typename OType, typename IType>
+__global__ void softmax_with_length_kernel(DType *in, OType *out, IType *length,
+ index_t M, int axis, Shape<ndim> sshape,
+ Shape<ndim> stride, const double temperature) {
+ const unsigned x_size = 1 << x_bits;
+ __shared__ AType smem[x_size];
+ index_t sa = stride[axis];
+ index_t base = unravel_dot(blockIdx.x, sshape, stride);
+ index_t x = threadIdx.x;
+ index_t len = static_cast<index_t>(length[blockIdx.x]);
+
+ red::maximum::SetInitValue(smem[x]);
+ for (index_t i = x; i < len; i += x_size) {
+ smem[x] = ::max(smem[x], negate ? -in[base + i*sa] : in[base + i*sa]);
+ }
+ __syncthreads();
+ cuda::Reduce1D<red::maximum, x_bits>(smem);
+ __syncthreads();
+ DType smax = smem[0];
+ __syncthreads();
+
+ red::sum::SetInitValue(smem[x]);
+ DType val;
+ for (index_t i = x; i < len; i += x_size) {
+ val = negate ? -in[base + i*sa]:in[base + i*sa];
+ smem[x] += static_cast<AType>(expf((val - smax) / static_cast<AType>(temperature)));
+ }
+ __syncthreads();
+ cuda::Reduce1D<red::sum, x_bits>(smem);
+ __syncthreads();
+ AType ssum = smem[0];
+ __syncthreads();
+
+ for (index_t i = x; i < M; i += x_size) {
+ val = negate ? -in[base + i*sa] : in[base + i*sa];
+ out[base + i*sa] =
+ (i < len) ? OType(OP::Map((val - smax)/static_cast<DType>(temperature), ssum)) : OType(0.0f);
+ }
+}
+
+template<typename OP, bool negate, typename AType, typename DType, typename OType,
+ typename IType, int ndim>
+inline void SoftmaxWithLength(Stream<gpu> *s, DType *in, OType *out, IType *length,
+ Shape<ndim> shape, int axis, const double temperature) {
+ const int x_bits = 7;
+ const int x_size = 1 << x_bits;
+ index_t M = shape[axis];
+ index_t N = shape.Size()/M;
+ Shape<ndim> stride = calc_stride(shape);
+ Shape<ndim> sshape = shape;
+ sshape[axis] = 1;
+
+ softmax_with_length_kernel<x_bits, OP, negate, AType, ndim>
+ <<<N, x_size, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
+ in, out, length, M, axis, sshape, stride, temperature);
+ MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_compute_kernel);
+}
+
template<int x_bits, typename OP1, typename OP2, int Req, bool negate, typename AType, int ndim,
- typename DType, typename OType>
+ typename DType, typename OType>
__global__ void softmax_gradient_kernel(OType *out, OType *ograd, DType *igrad,
index_t M, int axis, Shape<ndim> sshape,
Shape<ndim> stride, const double temperature) {
@@ -269,7 +427,7 @@ __global__ void softmax_gradient_kernel(OType *out, OType *ograd, DType *igrad,
template<typename OP1, typename OP2, int Req, bool negate, typename AType, int ndim,
- typename DType, typename OType>
+ typename DType, typename OType>
inline void SoftmaxGrad(Stream<gpu> *s, OType *out, OType *ograd,
DType *igrad, Shape<ndim> shape, int axis,
const double temperature) {
@@ -286,6 +444,60 @@ inline void SoftmaxGrad(Stream<gpu> *s, OType *out, OType *ograd,
out, ograd, igrad, M, axis, sshape, stride, temperature);
MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_gradient_kernel);
}
+
+template<int x_bits, typename OP1, typename OP2, int Req, bool negate, typename AType, int ndim,
+ typename DType, typename OType, typename IType>
+__global__ void softmax_with_length_grad_kernel(OType *out, OType *ograd, DType *igrad,
+ IType *length, index_t M, int axis,
+ Shape<ndim> sshape, Shape<ndim> stride,
+ const double temperature) {
+ const unsigned x_size = 1 << x_bits;
+ __shared__ AType smem[x_size];
+ index_t sa = stride[axis];
+ index_t base = unravel_dot(blockIdx.x, sshape, stride);
+ index_t x = threadIdx.x;
+ index_t len = static_cast<index_t>(length[blockIdx.x]);
+
+ red::sum::SetInitValue(smem[x]);
+ for (index_t i = x; i < len; i += x_size) {
+ smem[x] += OP1::Map(ograd[base + i*sa], out[base + i*sa]);
+ }
+ __syncthreads();
+ cuda::Reduce1D<red::sum, x_bits>(smem);
+ __syncthreads();
+ AType ssum = smem[0];
+ __syncthreads();
+
+ DType final_result;
+ for (index_t i = x; i < M; i += x_size) {
+ final_result =
+ negate ?
+ -OP2::Map(ograd[base + i*sa], out[base + i*sa], ssum) :
+ OP2::Map(ograd[base + i*sa], out[base + i*sa], ssum);
+ final_result = (i < len) ? final_result : DType(0.0f);
+ KERNEL_ASSIGN(igrad[base + i*sa], Req, final_result / static_cast<DType>(temperature));
+ }
+}
+
+
+template<typename OP1, typename OP2, int Req, bool negate, typename AType, int ndim,
+ typename DType, typename OType, typename IType>
+inline void SoftmaxWithLengthGrad(Stream<gpu> *s, OType *out, OType *ograd,
+ DType *igrad, IType *length, Shape<ndim> shape, int axis,
+ const double temperature) {
+ const int x_bits = 7;
+ const int x_size = 1 << x_bits;
+ index_t M = shape[axis];
+ index_t N = shape.Size()/M;
+ Shape<ndim> stride = calc_stride(shape);
+ Shape<ndim> sshape = shape;
+ sshape[axis] = 1;
+
+ softmax_with_length_grad_kernel<x_bits, OP1, OP2, Req, negate, AType, ndim>
+ <<<N, x_size, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
+ out, ograd, igrad, length, M, axis, sshape, stride, temperature);
+ MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_with_length_grad_kernel);
+}
#endif
} // namespace mxnet_op
@@ -295,6 +507,7 @@ struct SoftmaxParam : public dmlc::Parameter<SoftmaxParam> {
int axis;
dmlc::optional<double> temperature;
dmlc::optional<int> dtype;
+ dmlc::optional<bool> use_length;
DMLC_DECLARE_PARAMETER(SoftmaxParam) {
DMLC_DECLARE_FIELD(axis).set_default(-1)
.describe("The axis along which to compute softmax.");
@@ -307,6 +520,9 @@ struct SoftmaxParam : public dmlc::Parameter<SoftmaxParam> {
.set_default(dmlc::optional<int>())
.describe("DType of the output in case this can't be inferred. "
"Defaults to the same as input's dtype if not defined (dtype=None).");
+ DMLC_DECLARE_FIELD(use_length)
+ .set_default(dmlc::optional<bool>(false))
+ .describe("Whether to use the length input as a mask over the data input.");
}
};
@@ -315,27 +531,71 @@ static inline bool softmax_has_dtype_override(const nnvm::NodeAttrs& attrs) {
return param.dtype.has_value() && param.dtype.value() != -1;
}
+static inline bool softmax_use_length(const nnvm::NodeAttrs& attrs) {
+ const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
+ return param.use_length.value();
+}
+
static inline bool SoftmaxOpType(const nnvm::NodeAttrs& attrs,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
- CHECK_EQ(in_attrs->size(), 1);
CHECK_EQ(out_attrs->size(), 1);
const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
+ CHECK_EQ(in_attrs->size(), softmax_use_length(attrs) ? 2U : 1U);
if (softmax_has_dtype_override(attrs)) {
TYPE_ASSIGN_CHECK(*out_attrs, 0, param.dtype.value());
type_assign(&(*in_attrs)[0], (*out_attrs)[0]);
return true;
} else {
- return ElemwiseType<1, 1>(attrs, in_attrs, out_attrs);
+ std::vector<int> tmp = {in_attrs->at(0)};
+ return ElemwiseType<1, 1>(attrs, &tmp, out_attrs);
+ }
+}
+
+static inline bool SoftmaxOpShape(const nnvm::NodeAttrs& attrs,
+ mxnet::ShapeVector *in_attrs,
+ mxnet::ShapeVector *out_attrs) {
+ CHECK_EQ(out_attrs->size(), 1U);
+ const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
+ CHECK_EQ(in_attrs->size(), param.use_length.value() ? 2U : 1U);
+
+ if (param.use_length.value()) {
+ mxnet::TShape& dshape = in_attrs->at(0);
+ mxnet::TShape tmp_shape((dshape.ndim() == 1) ? 1U : dshape.ndim() - 1, 1);
+ int j = 0;
+ for (int i = 0; i < dshape.ndim(); ++i) {
+ if (i != param.axis) {
+ tmp_shape[j++] = dshape[i];
+ }
+ }
+ SHAPE_ASSIGN_CHECK(*in_attrs, 1, tmp_shape);
}
+ mxnet::ShapeVector tmp = {in_attrs->at(0)};
+ return ElemwiseShape<1, 1>(attrs, &tmp, out_attrs);
}
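SoftmaxOpShape above infers the expected shape of the length input by dropping the softmax axis from the data shape. A small sketch of that rule (negative axes normalized here for clarity; the 1-D data corner case is handled separately in the code above):

    def length_shape(data_shape, axis):
        # Length shape = data shape with the softmax axis removed,
        # e.g. data (B, T, N) with axis=1 -> length (B, N).
        axis = axis % len(data_shape)   # normalize a negative axis for this sketch
        return tuple(d for i, d in enumerate(data_shape) if i != axis)

    assert length_shape((4, 10, 6), axis=1) == (4, 6)
    assert length_shape((4, 10), axis=-1) == (4,)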
static inline bool SoftmaxGradOpShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *in_attrs,
mxnet::ShapeVector *out_attrs) {
- if (softmax_has_dtype_override(attrs)) {
- return ElemwiseShape<3, 1>(attrs, in_attrs, out_attrs);
+ if (softmax_has_dtype_override(attrs) || softmax_use_length(attrs)) {
+ if (softmax_use_length(attrs)) {
+ mxnet::ShapeVector ins = {in_attrs->at(0), in_attrs->at(1), in_attrs->at(3)};
+ mxnet::ShapeVector dgrad = {out_attrs->at(0)};
+ bool res = ElemwiseShape<3, 1>(attrs, &ins, &dgrad);
+ SHAPE_ASSIGN_CHECK(*in_attrs, 0, ins[0]);
+ SHAPE_ASSIGN_CHECK(*in_attrs, 1, ins[1]);
+ SHAPE_ASSIGN_CHECK(*in_attrs, 3, ins[2]);
+ SHAPE_ASSIGN_CHECK(*out_attrs, 0, dgrad[0]);
+ mxnet::ShapeVector length = {in_attrs->at(2)};
+ mxnet::ShapeVector lgrad = {out_attrs->at(1)};
+ res = (res && ElemwiseShape<1, 1>(attrs, &length, &lgrad));
+ SHAPE_ASSIGN_CHECK(*in_attrs, 2, length[0]);
+ SHAPE_ASSIGN_CHECK(*out_attrs, 1, lgrad[0]);
+ return res;
+ } else {
+ return ElemwiseShape<3, 1>(attrs, in_attrs, out_attrs);
+ }
} else {
return ElemwiseShape<2, 1>(attrs, in_attrs, out_attrs);
}
@@ -344,17 +604,21 @@ static inline bool SoftmaxGradOpShape(const nnvm::NodeAttrs& attrs,
static inline bool SoftmaxGradOpType(const nnvm::NodeAttrs& attrs,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
- CHECK_EQ(out_attrs->size(), 1);
- if (softmax_has_dtype_override(attrs)) {
- CHECK_EQ(in_attrs->size(), 3);
+ CHECK_EQ(out_attrs->size(), softmax_use_length(attrs) ? 2U : 1U);
+ if (softmax_has_dtype_override(attrs) || softmax_use_length(attrs)) {
+ CHECK_EQ(in_attrs->size(), softmax_use_length(attrs) ? 4U : 3U);
int in_dtype = (*in_attrs)[1];
- int out_dtype = (*in_attrs)[2];
+ int out_dtype = (*in_attrs)[softmax_use_length(attrs) ? 3 : 2];
TYPE_ASSIGN_CHECK(*in_attrs, 0, out_dtype);
TYPE_ASSIGN_CHECK(*out_attrs, 0, in_dtype);
+ if (softmax_use_length(attrs)) {
+ TYPE_ASSIGN_CHECK(*out_attrs, 1, in_attrs->at(2));
+ }
- return (*out_attrs)[0] != -1 && (*in_attrs)[0] != -1;
+ return (*out_attrs)[0] != -1 && (*in_attrs)[0] != -1 &&
+ (*out_attrs)[1] != -1 && (*in_attrs)[1] != -1;
} else {
- CHECK_EQ(in_attrs->size(), 2);
+ CHECK_EQ(in_attrs->size(), 2U);
int out_dtype = (*in_attrs)[1];
TYPE_ASSIGN_CHECK(*out_attrs, 0, out_dtype);
TYPE_ASSIGN_CHECK(*in_attrs, 0, out_dtype);
@@ -365,20 +629,31 @@ static inline bool SoftmaxGradOpType(const nnvm::NodeAttrs& attrs,
static inline std::vector<std::pair<int, int> >
SoftmaxGradOpInplaceOption(const nnvm::NodeAttrs& attrs) {
- if (softmax_has_dtype_override(attrs)) {
- return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}, {2, 0}};
+ if (softmax_has_dtype_override(attrs) || softmax_use_length(attrs)) {
+ if (softmax_use_length(attrs)) {
+ return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}, {2, 1}, {3, 0}};
+ } else {
+ return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}, {2, 0}};
+ }
} else {
return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}};
}
}
static inline uint32_t SoftmaxGradOpNumInputs(const nnvm::NodeAttrs& attrs) {
- return softmax_has_dtype_override(attrs) ? 3 : 2;
+ if (softmax_has_dtype_override(attrs) || softmax_use_length(attrs)) {
+ return softmax_use_length(attrs) ? 4 : 3;
+ }
+ return 2;
}
static inline std::vector<std::string> SoftmaxGradOpInputNames(const nnvm::NodeAttrs& attrs) {
- if (softmax_has_dtype_override(attrs)) {
- return std::vector<std::string>{"ograd", "data", "output"};
+ if (softmax_has_dtype_override(attrs) || softmax_use_length(attrs)) {
+ if (softmax_use_length(attrs)) {
+ return std::vector<std::string>{"ograd", "data", "length", "output"};
+ } else {
+ return std::vector<std::string>{"ograd", "data", "output"};
+ }
} else {
return std::vector<std::string>{"ograd", "output"};
}
@@ -388,7 +663,7 @@ struct SoftmaxFGradient {
const char *op_name;
std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr& n,
const std::vector<nnvm::NodeEntry>& ograds) const {
- if (softmax_has_dtype_override(n->attrs)) {
+ if (softmax_has_dtype_override(n->attrs) || softmax_use_length(n->attrs)) {
return ElemwiseGradUseInOut {op_name}(n, ograds);
} else {
return ElemwiseGradUseOut {op_name}(n, ograds);
@@ -419,30 +694,46 @@ void SoftmaxCompute(const nnvm::NodeAttrs& attrs,
MXNET_REAL_ACC_TYPE_SWITCH(inputs[0].type_flag_, DType, AType, {
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
- if (safe_acc) {
- if (shape.ndim() == 2) {
- Softmax<OP, negate, AType>(
- ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
- outputs[0].dptr<OType>(), shape.get<2>(), axis,
- static_cast<DType>(temperature));
+ if (!param.use_length.value()) {
+ if (safe_acc) {
+ if (shape.ndim() == 2) {
+ Softmax<OP, negate, AType>(
+ ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
+ outputs[0].dptr<OType>(), shape.get<2>(), axis,
+ static_cast<DType>(temperature));
+ } else {
+ Softmax<OP, negate, AType>(
+ ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
+ outputs[0].dptr<OType>(), shape.get<3>(), axis,
+ static_cast<DType>(temperature));
+ }
} else {
- Softmax<OP, negate, AType>(
- ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
- outputs[0].dptr<OType>(), shape.get<3>(), axis,
- static_cast<DType>(temperature));
+ if (shape.ndim() == 2) {
+ Softmax<OP, negate, DType>(
+ ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
+ outputs[0].dptr<OType>(), shape.get<2>(), axis,
+ static_cast<DType>(temperature));
+ } else {
+ Softmax<OP, negate, DType>(
+ ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
+ outputs[0].dptr<OType>(), shape.get<3>(), axis,
+ static_cast<DType>(temperature));
+ }
}
} else {
- if (shape.ndim() == 2) {
- Softmax<OP, negate, DType>(
+ MXNET_INT_TYPE_SWITCH(inputs[1].type_flag_, IType, {
+ if (shape.ndim() == 2) {
+ SoftmaxWithLength<OP, negate, AType>(
ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
- outputs[0].dptr<OType>(), shape.get<2>(), axis,
- static_cast<DType>(temperature));
- } else {
- Softmax<OP, negate, DType>(
+ outputs[0].dptr<OType>(), inputs[1].dptr<IType>(),
+ shape.get<2>(), axis, static_cast<DType>(temperature));
+ } else {
+ SoftmaxWithLength<OP, negate, AType>(
ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
- outputs[0].dptr<OType>(), shape.get<3>(), axis,
- static_cast<DType>(temperature));
- }
+ outputs[0].dptr<OType>(), inputs[1].dptr<IType>(),
+ shape.get<3>(), axis, static_cast<DType>(temperature));
+ }
+ });
}
});
});
@@ -464,35 +755,56 @@ void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs,
mxnet::TShape shape = AxisShapeCompact(inputs[0].shape_, &axis, true);
int out_idx = softmax_has_dtype_override(attrs) ? 2 : 1;
+ out_idx = softmax_use_length(attrs) ? 3 : out_idx;
bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false);
MXNET_REAL_ACC_TYPE_SWITCH(inputs[0].type_flag_, OType, AType, {
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
- if (safe_acc) {
- if (shape.ndim() == 2) {
- SoftmaxGrad<OP1, OP2, Req, negate, AType>(
- ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
- inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
- shape.get<2>(), axis, static_cast<DType>(temperature));
+ if (!softmax_use_length(attrs)) {
+ if (safe_acc) {
+ if (shape.ndim() == 2) {
+ SoftmaxGrad<OP1, OP2, Req, negate, AType>(
+ ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
+ inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
+ shape.get<2>(), axis, static_cast<DType>(temperature));
+ } else {
+ SoftmaxGrad<OP1, OP2, Req, negate, AType>(
+ ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
+ inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
+ shape.get<3>(), axis, static_cast<DType>(temperature));
+ }
} else {
- SoftmaxGrad<OP1, OP2, Req, negate, AType>(
- ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
- inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
- shape.get<3>(), axis, static_cast<DType>(temperature));
+ if (shape.ndim() == 2) {
+ SoftmaxGrad<OP1, OP2, Req, negate, DType>(
+ ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
+ inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
+ shape.get<2>(), axis, static_cast<DType>(temperature));
+ } else {
+ SoftmaxGrad<OP1, OP2, Req, negate, DType>(
+ ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
+ inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
+ shape.get<3>(), axis, static_cast<DType>(temperature));
+ }
}
} else {
- if (shape.ndim() == 2) {
- SoftmaxGrad<OP1, OP2, Req, negate, DType>(
+ MXNET_INT_TYPE_SWITCH(inputs[2].type_flag_, IType, {
+ if (req[1] != kNullOp) {
+ mxnet_op::Kernel<mxnet_op::set_zero, xpu>::Launch(
+ ctx.get_stream<xpu>(), outputs[1].Size(), outputs[1].dptr<IType>());
+ }
+ if (shape.ndim() == 2) {
+ SoftmaxWithLengthGrad<OP1, OP2, Req, negate, AType>(
ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
- shape.get<2>(), axis, static_cast<DType>(temperature));
- } else {
- SoftmaxGrad<OP1, OP2, Req, negate, DType>(
+ inputs[2].dptr<IType>(), shape.get<2>(), axis, static_cast<DType>(temperature));
+ } else {
+ SoftmaxWithLengthGrad<OP1, OP2, Req, negate, AType>(
ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
- shape.get<3>(), axis, static_cast<DType>(temperature));
- }
+ inputs[2].dptr<IType>(), shape.get<3>(), axis, static_cast<DType>(temperature));
+ }
+ });
}
});
});
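With use_length=True the backward op takes four inputs and produces two outputs, which is what the out_idx bump above accounts for. A short sketch of the index layout assumed by SoftmaxGradCompute (input names from SoftmaxGradOpInputNames, output roles from the dgrad/lgrad vectors in SoftmaxGradOpShape):

    # Input/output layout of _backward_softmax when use_length=True.
    grad_inputs  = ["ograd", "data", "length", "output"]  # inputs[0]..inputs[3]
    grad_outputs = ["dgrad", "lgrad"]                     # outputs[0], outputs[1]

    out_idx = 3     # inputs[out_idx] is the forward output consumed by OP1/OP2
    length_idx = 2  # inputs[2] supplies the per-row valid lengths
    # outputs[1] (the gradient w.r.t. length) is simply filled with zeros.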
diff --git a/src/operator/nn/softmax.cc b/src/operator/nn/softmax.cc
index e44bbbb..5a581e4 100644
--- a/src/operator/nn/softmax.cc
+++ b/src/operator/nn/softmax.cc
@@ -59,14 +59,23 @@ inline static bool SoftmaxStorageType(const nnvm::NodeAttrs& attrs,
DispatchMode* dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
- CHECK_EQ(in_attrs->size(), 1);
- CHECK_EQ(out_attrs->size(), 1);
+ const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
+ CHECK_EQ(in_attrs->size(), (param.use_length.value()) ? 2U : 1U);
+ CHECK_EQ(out_attrs->size(), 1U);
+
+ if (param.use_length.value()) {
+ auto& out_stype = out_attrs->at(0);
+ return storage_type_assign(&out_stype, kDefaultStorage,
+ dispatch_mode, DispatchMode::kFCompute);
+ }
return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs,
out_attrs);
}
#endif
+
+
NNVM_REGISTER_OP(softmax)
.describe(R"code(Applies the softmax function.
@@ -92,6 +101,13 @@ Example::
)code" ADD_FILELINE)
.set_attr_parser(ParamParser<SoftmaxParam>)
+.set_attr<nnvm::FListOutputNames>("FListInputNames",
+ [](const NodeAttrs& attrs){
+ const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
+ return (param.use_length.value()) ?
+ std::vector<std::string>{"data", "length"} :
+ std::vector<std::string>{"data"};
+})
.set_attr<nnvm::FListOutputNames>("FListOutputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"output"};
@@ -103,20 +119,27 @@ Example::
.set_attr<FInferStorageType>("FInferStorageType", SoftmaxStorageType)
#endif
.set_attr<nnvm::FGradient>("FGradient", SoftmaxFGradient{"_backward_softmax"})
+// .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.set_attr<nnvm::FInferType>("FInferType", SoftmaxOpType)
-.set_num_inputs(1)
+.set_num_inputs([](const nnvm::NodeAttrs& attrs) {
+ const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
+ return (param.use_length.value()) ? 2 : 1;
+ })
.set_num_outputs(1)
-.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
+.set_attr<mxnet::FInferShape>("FInferShape", SoftmaxOpShape)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
[](const NodeAttrs& attrs){
return std::vector<std::pair<int, int> >{{0, 0}};
})
.add_argument("data", "NDArray-or-Symbol", "The input array.")
+.add_argument("length", "NDArray-or-Symbol", "The length array.")
.add_arguments(SoftmaxParam::__FIELDS__());
NNVM_REGISTER_OP(_backward_softmax)
.set_num_inputs(SoftmaxGradOpNumInputs)
-.set_num_outputs(1)
+.set_num_outputs([](const nnvm::NodeAttrs& attrs) {
+ return (softmax_use_length(attrs) ? 2 : 1);
+ })
.set_attr<nnvm::FListInputNames>("FListInputNames", SoftmaxGradOpInputNames)
.set_attr<mxnet::FInferShape>("FInferShape", SoftmaxGradOpShape)
.set_attr<nnvm::FInferType>("FInferType", SoftmaxGradOpType)
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 749f0f2..fea07f5 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -5196,6 +5196,39 @@ def test_softmax_dtype():
check_dtypes_almost_equal('log_softmax', 1e-3, 1e-3, 1e-3, 1e-3,
'float32', 'float64', 'float64')
+
+@with_seed()
+def test_softmax_with_length():
+ def np_softmax_with_length(data, length):
+ res = np.zeros(data.shape)
+ for i in range(length.shape[0]):
+ for j in range(length.shape[1]):
+ leng = int(length[i, j])
+ res[i, 0:leng, j] = np_softmax(data[i, 0:leng, j])
+ return res
+
+ ndim = 3
+ shape = rand_shape_nd(ndim, dim=10)
+ len_shape = list(shape)
+ del len_shape[1]
+ len_shape = tuple(len_shape)
+ for dtype in [np.float16, np.float32, np.float64]:
+ mx_data = rand_ndarray(shape, dtype=dtype)
+ np_data = mx_data.asnumpy()
+ np_length = np.random.randint(1, shape[1] + 1, len_shape)
+ mx_length = mx.nd.array(np_length, dtype=np.int32)
+ np_out = np_softmax_with_length(np_data, np_length)
+ data = mx.sym.Variable("data")
+ length = mx.sym.Variable("length")
+ mx_sym = mx.sym.softmax(data=data, length=length, use_length=True, axis=1)
+ location = {"data": mx_data, "length": mx_length}
+ rtol = 1e-2 if dtype == np.float16 else 1e-3
+ atol = 1e-4 if dtype == np.float16 else 1e-5
+ check_symbolic_forward(mx_sym, location, [np_out], rtol=rtol, atol=atol, dtype="asnumpy")
+ check_symbolic_backward(mx_sym, location, [np.ones(shape, dtype=dtype)],
+ [np.zeros(shape), np.zeros(len_shape, dtype=np.int32)], rtol=1e-2, atol=1e-3, dtype="asnumpy")
+
+
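And a quick imperative spot-check of the masking behaviour (a sketch only; it assumes the NDArray front end accepts the new length argument positionally, the same way the Symbol API in the test above does):

    import mxnet as mx
    import numpy as np

    data = mx.nd.random.uniform(shape=(2, 4, 3))
    np_len = np.random.randint(1, 5, (2, 3))
    length = mx.nd.array(np_len, dtype=np.int32)

    out = mx.nd.softmax(data, length, axis=1, use_length=True).asnumpy()
    for i in range(2):
        for j in range(3):
            k = int(np_len[i, j])
            assert np.isclose(out[i, :k, j].sum(), 1.0, atol=1e-5)  # valid prefix sums to 1
            assert np.allclose(out[i, k:, j], 0.0)                  # masked tail is exactly 0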
@with_seed()
def test_pick():
def test_pick_helper(index_type=np.int32):
@@ -8034,7 +8067,11 @@ def test_op_all_names_monitor():
check_name(cc_sym, ['data', 'concat_arg0', 'data', 'concat_arg1', 'concat_output'])
sm_sym = mx.sym.softmax(data, name='softmax')
- check_name(sm_sym, ['data', 'softmax_input0', 'softmax_output'])
+ check_name(sm_sym, ['data', 'softmax_data', 'softmax_output'])
+
+ length = mx.sym.Variable("length", shape=(10, 10, 10))
+ sm_sym = mx.sym.softmax(data, length, axis=1, use_length=True, name='softmax')
+ check_name(sm_sym, ['data', 'softmax_data', 'length', 'softmax_length', 'softmax_output'])
sa_sym = mx.sym.SoftmaxActivation(data, name='softmax')
check_name(sa_sym, ['data', 'softmax_input0', 'softmax_output'])