You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2019/07/19 17:43:36 UTC

[incubator-mxnet] branch master updated: Softmax with length (#15169)

This is an automated email from the ASF dual-hosted git repository.

haoj pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 076b2f3  Softmax with length (#15169)
076b2f3 is described below

commit 076b2f330c60f05cb939beea28dd04cd571a34c0
Author: Hao Jin <hj...@gmail.com>
AuthorDate: Fri Jul 19 10:43:10 2019 -0700

    Softmax with length (#15169)
    
    * softmax with length forward
    
    * softmax with length backward
    
    * new macro to reduce compile-time heap usage and limit length to integers only
    
    * address comments
---
 src/operator/mxnet_op.h                |  51 ++++
 src/operator/nn/softmax-inl.h          | 428 ++++++++++++++++++++++++++++-----
 src/operator/nn/softmax.cc             |  33 ++-
 tests/python/unittest/test_operator.py |  39 ++-
 4 files changed, 487 insertions(+), 64 deletions(-)

diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h
index f17b708..52788f6 100644
--- a/src/operator/mxnet_op.h
+++ b/src/operator/mxnet_op.h
@@ -363,6 +363,57 @@ inline int get_num_threads<cpu>(const int N) {
     LOG(FATAL) << "Unknown type enum " << type;            \
   }
 
+#define MXNET_INT_TYPE_SWITCH(type, DType, ...)\
+  switch (type) {                                          \
+  case mshadow::kFloat32:                                  \
+    {                                                      \
+      typedef float DType;                                 \
+      LOG(FATAL) << "This operation only support "         \
+                    "integer types, not float32";          \
+    }                                                      \
+    break;                                                 \
+  case mshadow::kFloat64:                                  \
+    {                                                      \
+      typedef double DType;                                \
+      LOG(FATAL) << "This operation only support "         \
+                    "integer types, not float64";          \
+    }                                                      \
+    break;                                                 \
+  case mshadow::kFloat16:                                  \
+    {                                                      \
+      typedef mshadow::half::half_t DType;                 \
+      LOG(FATAL) << "This operation only support "         \
+                    "integer types, not float16";          \
+    }                                                      \
+    break;                                                 \
+  case mshadow::kUint8:                                    \
+    {                                                      \
+      typedef uint8_t DType;                               \
+      {__VA_ARGS__}                                        \
+    }                                                      \
+    break;                                                 \
+  case mshadow::kInt8:                                     \
+    {                                                      \
+      typedef int8_t DType;                                \
+      {__VA_ARGS__}                                        \
+    }                                                      \
+    break;                                                 \
+  case mshadow::kInt32:                                    \
+    {                                                      \
+      typedef int32_t DType;                               \
+      {__VA_ARGS__}                                        \
+    }                                                      \
+    break;                                                 \
+  case mshadow::kInt64:                                    \
+    {                                                      \
+      typedef int64_t DType;                               \
+      {__VA_ARGS__}                                        \
+    }                                                      \
+    break;                                                 \
+  default:                                                 \
+    LOG(FATAL) << "Unknown type enum " << type;            \
+  }
+
 /*!
  * \brief assign the val to out according
  * to request in Kernel::Launch
diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h
index d6113b0..2c82d83 100644
--- a/src/operator/nn/softmax-inl.h
+++ b/src/operator/nn/softmax-inl.h
@@ -75,7 +75,7 @@ inline void Softmax(Stream<cpu> *s, DType *in, OType *out,
   index_t sa = stride[axis];
 
   #pragma omp parallel for
-  for (int i = 0; i < static_cast<int>(N); ++i) {
+  for (index_t i = 0; i < N; ++i) {
     index_t base = unravel_dot(i, sshape, stride);
 
     DType mmax = negate ? -in[base] : in[base];
@@ -113,6 +113,60 @@ inline void Softmax(Stream<cpu> *s, DType *in, OType *out,
   }
 }
 
+template<typename OP, bool negate, typename AType, typename DType, typename OType,
+         typename IType, int ndim>
+inline void SoftmaxWithLength(Stream<cpu> *s, DType *in, OType *out, IType *length,
+                              Shape<ndim> shape, int axis, const DType temperature) {
+  index_t M = shape[axis];
+  index_t N = shape.Size()/M;
+  Shape<ndim> stride = calc_stride(shape);
+  Shape<ndim> sshape = shape;
+  sshape[axis] = 1;
+  index_t sa = stride[axis];
+  // Row i keeps only its first length[i] entries (clamped to [0, M]); the rest become 0.
+  #pragma omp parallel for
+  for (index_t i = 0; i < N; ++i) {
+    index_t len = std::max<index_t>(0, std::min<index_t>(length[i], M));
+    index_t base = unravel_dot(i, sshape, stride);
+
+    DType mmax = negate ? -in[base] : in[base];
+    DType val;
+    for (index_t j = 1; j < len; ++j) {
+      val = negate ? -in[base + j*sa] : in[base + j*sa];
+      if (mmax < val) mmax = val;
+    }
+    for (index_t j = len; j < M; ++j) {
+      out[base + j*sa] = OType(0.0f);
+    }
+
+    AType sum = AType(0);
+    DType in_val;
+    // By default temperature is 1.0.
+    // Adding a branch here to save the CPU 'divide-by-1' computation at runtime
+    if (temperature == 1.0) {
+      for (index_t j = 0; j < len; ++j) {
+        in_val = negate ? -in[base + j*sa] : in[base + j*sa];
+        sum += std::exp(in_val - mmax);
+      }
+
+      for (index_t j = 0; j < len; ++j) {
+        in_val = negate ? -in[base + j*sa] : in[base + j*sa];
+        out[base + j*sa] = OP::Map(in_val - mmax, sum);
+      }
+    } else {
+      for (index_t j = 0; j < len; ++j) {
+        in_val = negate ? -in[base + j*sa] : in[base + j*sa];
+        sum += std::exp((in_val - mmax)/temperature);
+      }
+
+      for (index_t j = 0; j < len; ++j) {
+        in_val = negate ? -in[base + j*sa] : in[base + j*sa];
+        out[base + j*sa] = OP::Map((in_val - mmax)/temperature, sum);
+      }
+    }
+  }
+}
+
 
 struct softmax_bwd {
   template<typename DType, typename AType>
@@ -136,7 +190,7 @@ struct log_softmax_bwd {
 
 
 template<typename OP1, typename OP2, int Req, bool negate,
-  typename AType, typename DType, typename OType, int ndim>
+         typename AType, typename DType, typename OType, int ndim>
 inline void SoftmaxGrad(Stream<cpu> *s, OType *out, OType *ograd,
                         DType *igrad, Shape<ndim> shape, int axis,
                         const DType temperature) {
@@ -148,7 +202,7 @@ inline void SoftmaxGrad(Stream<cpu> *s, OType *out, OType *ograd,
   index_t sa = stride[axis];
 
   #pragma omp parallel for
-  for (int i = 0; i < static_cast<int>(N); ++i) {
+  for (index_t i = 0; i < N; ++i) {
     index_t base = unravel_dot(i, sshape, stride);
 
     AType sum = AType(0);
@@ -177,10 +231,55 @@ inline void SoftmaxGrad(Stream<cpu> *s, OType *out, OType *ograd,
   }
 }
 
+template<typename OP1, typename OP2, int Req, bool negate,
+         typename AType, typename DType, typename OType, typename IType, int ndim>
+inline void SoftmaxWithLengthGrad(Stream<cpu> *s, OType *out, OType *ograd,
+                                  DType *igrad, IType *length, Shape<ndim> shape,
+                                  int axis, const DType temperature) {
+  index_t M = shape[axis];
+  index_t N = shape.Size()/M;
+  Shape<ndim> stride = calc_stride(shape);
+  Shape<ndim> sshape = shape;
+  sshape[axis] = 1;
+  index_t sa = stride[axis];
+  // Row i: gradient flows to entries [0, len) only; len = length[i] clamped to [0, M].
+  #pragma omp parallel for
+  for (index_t i = 0; i < N; ++i) {
+    index_t base = unravel_dot(i, sshape, stride);
+    index_t len = std::max<index_t>(0, std::min<index_t>(length[i], M));
+
+    AType sum = AType(0);
+    for (index_t j = 0; j < len; ++j) {
+      sum += OP1::Map(ograd[base + j*sa], out[base + j*sa]);
+    }
+
+    // By default temperature is 1.0.
+    // Adding a branch here to save the CPU 'divide-by-1' computation at runtime
+    DType final_result;
+    if (temperature == 1.0) {
+      for (index_t j = 0; j < M; ++j) {
+        final_result = negate ?
+                       -OP2::Map(ograd[base + j*sa], out[base + j*sa], sum) :
+                       OP2::Map(ograd[base + j*sa], out[base + j*sa], sum);
+        final_result = (j < len) ? final_result : DType(0.0f);
+        KERNEL_ASSIGN(igrad[base + j*sa], Req, final_result);
+      }
+    } else {
+      for (index_t j = 0; j < M; ++j) {
+        final_result = negate ?
+                       -OP2::Map(ograd[base + j*sa], out[base + j*sa], sum) / temperature :
+                       OP2::Map(ograd[base + j*sa], out[base + j*sa], sum) / temperature;
+        final_result = (j < len) ? final_result : DType(0.0f);
+        KERNEL_ASSIGN(igrad[base + j*sa], Req, final_result);
+      }
+    }
+  }
+}
+
 
 #ifdef __CUDACC__
 template<int x_bits, typename OP, bool negate, typename AType, int ndim,
-  typename DType, typename OType>
+         typename DType, typename OType>
 __global__ void softmax_compute_kernel(DType *in, OType *out, index_t M, int axis,
                                        Shape<ndim> sshape, Shape<ndim> stride,
                                        const double temperature) {
@@ -235,9 +334,68 @@ inline void Softmax(Stream<gpu> *s, DType *in, OType *out,
   MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_compute_kernel);
 }
 
+template<int x_bits, typename OP, bool negate, typename AType, int ndim,
+         typename DType, typename OType, typename IType>
+__global__ void softmax_with_length_kernel(DType *in, OType *out, IType *length,
+                                           index_t M, int axis, Shape<ndim> sshape,
+                                           Shape<ndim> stride, const double temperature) {
+  const unsigned x_size = 1 << x_bits;
+  __shared__ AType smem[x_size];
+  index_t sa = stride[axis];
+  index_t base = unravel_dot(blockIdx.x, sshape, stride);
+  index_t x = threadIdx.x;
+  index_t len = ::max(index_t(0), ::min(static_cast<index_t>(length[blockIdx.x]), M));
+  // Block-wide passes over row blockIdx.x: max over [0, len), exp-sum, then write/zero.
+  red::maximum::SetInitValue(smem[x]);
+  for (index_t i = x; i < len; i += x_size) {
+    smem[x] = ::max(smem[x], negate ? -in[base + i*sa] : in[base + i*sa]);
+  }
+  __syncthreads();
+  cuda::Reduce1D<red::maximum, x_bits>(smem);
+  __syncthreads();
+  DType smax = smem[0];
+  __syncthreads();
+
+  red::sum::SetInitValue(smem[x]);
+  DType val;
+  for (index_t i = x; i < len; i += x_size) {
+    val = negate ? -in[base + i*sa]:in[base + i*sa];
+    smem[x] += static_cast<AType>(expf((val - smax) / static_cast<AType>(temperature)));
+  }
+  __syncthreads();
+  cuda::Reduce1D<red::sum, x_bits>(smem);
+  __syncthreads();
+  AType ssum = smem[0];
+  __syncthreads();
+
+  for (index_t i = x; i < M; i += x_size) {
+    val = negate ? -in[base + i*sa] : in[base + i*sa];
+    out[base + i*sa] =
+      (i < len) ? OType(OP::Map((val - smax)/static_cast<DType>(temperature), ssum)) : OType(0.0f);
+  }
+}
+
+template<typename OP, bool negate, typename AType, typename DType, typename OType,
+         typename IType, int ndim>
+inline void SoftmaxWithLength(Stream<gpu> *s, DType *in, OType *out, IType *length,
+                              Shape<ndim> shape, int axis, const double temperature) {
+  const int x_bits = 7;
+  const int x_size = 1 << x_bits;
+  index_t M = shape[axis];
+  index_t N = shape.Size()/M;
+  Shape<ndim> stride = calc_stride(shape);
+  Shape<ndim> sshape = shape;
+  sshape[axis] = 1;
+  // Launch one block of x_size threads per row; the kernel masks each row by length.
+  softmax_with_length_kernel<x_bits, OP, negate, AType, ndim>
+    <<<N, x_size, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
+      in, out, length, M, axis, sshape, stride, temperature);
+  MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_with_length_kernel);
+}
+
 
 template<int x_bits, typename OP1, typename OP2, int Req, bool negate, typename AType, int ndim,
-  typename DType, typename OType>
+         typename DType, typename OType>
 __global__ void softmax_gradient_kernel(OType *out, OType *ograd, DType *igrad,
                                         index_t M, int axis, Shape<ndim> sshape,
                                         Shape<ndim> stride, const double temperature) {
@@ -269,7 +427,7 @@ __global__ void softmax_gradient_kernel(OType *out, OType *ograd, DType *igrad,
 
 
 template<typename OP1, typename OP2, int Req, bool negate, typename AType, int ndim,
-  typename DType, typename OType>
+         typename DType, typename OType>
 inline void SoftmaxGrad(Stream<gpu> *s, OType *out, OType *ograd,
                         DType *igrad, Shape<ndim> shape, int axis,
                         const double temperature) {
@@ -286,6 +444,60 @@ inline void SoftmaxGrad(Stream<gpu> *s, OType *out, OType *ograd,
       out, ograd, igrad, M, axis, sshape, stride, temperature);
   MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_gradient_kernel);
 }
+
+template<int x_bits, typename OP1, typename OP2, int Req, bool negate, typename AType, int ndim,
+         typename DType, typename OType, typename IType>
+__global__ void softmax_with_length_grad_kernel(OType *out, OType *ograd, DType *igrad,
+                                                IType *length, index_t M, int axis,
+                                                Shape<ndim> sshape, Shape<ndim> stride,
+                                                const double temperature) {
+  const unsigned x_size = 1 << x_bits;
+  __shared__ AType smem[x_size];
+  index_t sa = stride[axis];
+  index_t base = unravel_dot(blockIdx.x, sshape, stride);
+  index_t x = threadIdx.x;
+  index_t len = ::max(index_t(0), ::min(static_cast<index_t>(length[blockIdx.x]), M));
+  // Block-wide sum of OP1 over [0, len), then write igrad (0 for positions i >= len).
+  red::sum::SetInitValue(smem[x]);
+  for (index_t i = x; i < len; i += x_size) {
+    smem[x] += OP1::Map(ograd[base + i*sa], out[base + i*sa]);
+  }
+  __syncthreads();
+  cuda::Reduce1D<red::sum, x_bits>(smem);
+  __syncthreads();
+  AType ssum = smem[0];
+  __syncthreads();
+
+  DType final_result;
+  for (index_t i = x; i < M; i += x_size) {
+    final_result =
+      negate ?
+      -OP2::Map(ograd[base + i*sa], out[base + i*sa], ssum) :
+      OP2::Map(ograd[base + i*sa], out[base + i*sa], ssum);
+    final_result = (i < len) ? final_result : DType(0.0f);
+    KERNEL_ASSIGN(igrad[base + i*sa], Req, final_result / static_cast<DType>(temperature));
+  }
+}
+
+
+template<typename OP1, typename OP2, int Req, bool negate, typename AType, int ndim,
+         typename DType, typename OType, typename IType>
+inline void SoftmaxWithLengthGrad(Stream<gpu> *s, OType *out, OType *ograd,
+                                  DType *igrad, IType *length, Shape<ndim> shape, int axis,
+                                  const double temperature) {
+  const int x_bits = 7;
+  const int x_size = 1 << x_bits;
+  index_t M = shape[axis];
+  index_t N = shape.Size()/M;
+  Shape<ndim> stride = calc_stride(shape);
+  Shape<ndim> sshape = shape;
+  sshape[axis] = 1;
+
+  softmax_with_length_grad_kernel<x_bits, OP1, OP2, Req, negate, AType, ndim>
+    <<<N, x_size, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
+      out, ograd, igrad, length, M, axis, sshape, stride, temperature);
+  MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_with_length_grad_kernel);
+}
 #endif
 
 }  // namespace mxnet_op
@@ -295,6 +507,7 @@ struct SoftmaxParam : public dmlc::Parameter<SoftmaxParam> {
   int axis;
   dmlc::optional<double> temperature;
   dmlc::optional<int> dtype;
+  dmlc::optional<bool> use_length;
   DMLC_DECLARE_PARAMETER(SoftmaxParam) {
     DMLC_DECLARE_FIELD(axis).set_default(-1)
     .describe("The axis along which to compute softmax.");
@@ -307,6 +520,9 @@ struct SoftmaxParam : public dmlc::Parameter<SoftmaxParam> {
     .set_default(dmlc::optional<int>())
     .describe("DType of the output in case this can't be inferred. "
               "Defaults to the same as input's dtype if not defined (dtype=None).");
+    DMLC_DECLARE_FIELD(use_length)
+    .set_default(dmlc::optional<bool>(false))
+    .describe("Whether to use the length input as a mask over the data input.");
   }
 };
 
@@ -315,27 +531,71 @@ static inline bool softmax_has_dtype_override(const nnvm::NodeAttrs& attrs) {
   return param.dtype.has_value() && param.dtype.value() != -1;
 }
 
+static inline bool softmax_use_length(const nnvm::NodeAttrs& attrs) {
+  const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
+  return param.use_length.value();
+}
+
 static inline bool SoftmaxOpType(const nnvm::NodeAttrs& attrs,
                                  std::vector<int>* in_attrs,
                                  std::vector<int>* out_attrs) {
-  CHECK_EQ(in_attrs->size(), 1);
   CHECK_EQ(out_attrs->size(), 1);
   const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
+  CHECK_EQ(in_attrs->size(), softmax_use_length(attrs) ? 2U : 1U);
 
   if (softmax_has_dtype_override(attrs)) {
     TYPE_ASSIGN_CHECK(*out_attrs, 0, param.dtype.value());
     type_assign(&(*in_attrs)[0], (*out_attrs)[0]);
     return true;
   } else {
-    return ElemwiseType<1, 1>(attrs, in_attrs, out_attrs);
+    std::vector<int> tmp = {in_attrs->at(0)};
+    return ElemwiseType<1, 1>(attrs, &tmp, out_attrs);
+  }
+}
+
+static inline bool SoftmaxOpShape(const nnvm::NodeAttrs& attrs,
+                                  mxnet::ShapeVector *in_attrs,
+                                  mxnet::ShapeVector *out_attrs) {
+  CHECK_EQ(out_attrs->size(), 1U);
+  const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
+  CHECK_EQ(in_attrs->size(), param.use_length.value() ? 2U : 1U);
+  // With use_length, input 1 (length) is inferred from data once data's ndim is known.
+  if (param.use_length.value() && mxnet::ndim_is_known(in_attrs->at(0))) {
+    mxnet::TShape& dshape = in_attrs->at(0);
+    const int axis = param.axis < 0 ? param.axis + dshape.ndim() : param.axis;
+    CHECK(axis >= 0 && axis < dshape.ndim()) << "Invalid softmax axis " << param.axis;
+    // length has the shape of data with the softmax axis removed (at least 1-D).
+    mxnet::TShape tmp_shape((dshape.ndim() == 1) ? 1 : dshape.ndim() - 1, 1);
+    for (int i = 0, j = 0; i < dshape.ndim(); ++i) {
+      if (i != axis) tmp_shape[j++] = dshape[i];
+    }
+    SHAPE_ASSIGN_CHECK(*in_attrs, 1, tmp_shape);
   }
+  mxnet::ShapeVector tmp = {in_attrs->at(0)};
+  return ElemwiseShape<1, 1>(attrs, &tmp, out_attrs);
 }
 
 static inline bool SoftmaxGradOpShape(const nnvm::NodeAttrs& attrs,
                                       mxnet::ShapeVector *in_attrs,
                                       mxnet::ShapeVector *out_attrs) {
-  if (softmax_has_dtype_override(attrs)) {
-    return ElemwiseShape<3, 1>(attrs, in_attrs, out_attrs);
+  if (softmax_has_dtype_override(attrs) || softmax_use_length(attrs)) {
+    if (softmax_use_length(attrs)) {
+      mxnet::ShapeVector ins = {in_attrs->at(0), in_attrs->at(1), in_attrs->at(3)};
+      mxnet::ShapeVector dgrad = {out_attrs->at(0)};
+      bool res = ElemwiseShape<3, 1>(attrs, &ins, &dgrad);
+      SHAPE_ASSIGN_CHECK(*in_attrs, 0, ins[0]);
+      SHAPE_ASSIGN_CHECK(*in_attrs, 1, ins[1]);
+      SHAPE_ASSIGN_CHECK(*in_attrs, 3, ins[2]);
+      SHAPE_ASSIGN_CHECK(*out_attrs, 0, dgrad[0]);
+      mxnet::ShapeVector length = {in_attrs->at(2)};
+      mxnet::ShapeVector lgrad = {out_attrs->at(1)};
+      res = (res && ElemwiseShape<1, 1>(attrs, &length, &lgrad));
+      SHAPE_ASSIGN_CHECK(*in_attrs, 2, length[0]);
+      SHAPE_ASSIGN_CHECK(*out_attrs, 1, lgrad[0]);
+      return res;
+    } else {
+      return ElemwiseShape<3, 1>(attrs, in_attrs, out_attrs);
+    }
   } else {
     return ElemwiseShape<2, 1>(attrs, in_attrs, out_attrs);
   }
@@ -344,17 +604,21 @@ static inline bool SoftmaxGradOpShape(const nnvm::NodeAttrs& attrs,
 static inline bool SoftmaxGradOpType(const nnvm::NodeAttrs& attrs,
                                      std::vector<int>* in_attrs,
                                      std::vector<int>* out_attrs) {
-  CHECK_EQ(out_attrs->size(), 1);
-  if (softmax_has_dtype_override(attrs)) {
-    CHECK_EQ(in_attrs->size(), 3);
+  CHECK_EQ(out_attrs->size(), softmax_use_length(attrs) ? 2U : 1U);
+  if (softmax_has_dtype_override(attrs) || softmax_use_length(attrs)) {
+    CHECK_EQ(in_attrs->size(), softmax_use_length(attrs) ? 4U : 3U);
     int in_dtype = (*in_attrs)[1];
-    int out_dtype = (*in_attrs)[2];
+    int out_dtype = (*in_attrs)[softmax_use_length(attrs) ? 3 : 2];
     TYPE_ASSIGN_CHECK(*in_attrs, 0, out_dtype);
     TYPE_ASSIGN_CHECK(*out_attrs, 0, in_dtype);
+    if (softmax_use_length(attrs)) {
+      TYPE_ASSIGN_CHECK(*out_attrs, 1, in_attrs->at(2));
+    }
 
-    return (*out_attrs)[0] != -1 && (*in_attrs)[0] != -1;
+    return (*out_attrs)[0] != -1 && (*in_attrs)[0] != -1 && (*in_attrs)[1] != -1 &&
+           (!softmax_use_length(attrs) || (*out_attrs)[1] != -1);  // lgrad only w/ length
   } else {
-    CHECK_EQ(in_attrs->size(), 2);
+    CHECK_EQ(in_attrs->size(), 2U);
     int out_dtype = (*in_attrs)[1];
     TYPE_ASSIGN_CHECK(*out_attrs, 0, out_dtype);
     TYPE_ASSIGN_CHECK(*in_attrs, 0, out_dtype);
@@ -365,20 +629,31 @@ static inline bool SoftmaxGradOpType(const nnvm::NodeAttrs& attrs,
 
 static inline std::vector<std::pair<int, int> >
 SoftmaxGradOpInplaceOption(const nnvm::NodeAttrs& attrs) {
-  if (softmax_has_dtype_override(attrs)) {
-    return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}, {2, 0}};
+  if (softmax_has_dtype_override(attrs) || softmax_use_length(attrs)) {
+    if (softmax_use_length(attrs)) {
+      return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}, {2, 1}, {3, 0}};
+    } else {
+      return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}, {2, 0}};
+    }
   } else {
     return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}};
   }
 }
 
 static inline uint32_t SoftmaxGradOpNumInputs(const nnvm::NodeAttrs& attrs) {
-  return softmax_has_dtype_override(attrs) ? 3 : 2;
+  if (softmax_has_dtype_override(attrs) || softmax_use_length(attrs)) {
+    return softmax_use_length(attrs) ? 4 : 3;
+  }
+  return 2;
 }
 
 static inline std::vector<std::string> SoftmaxGradOpInputNames(const nnvm::NodeAttrs& attrs) {
-  if (softmax_has_dtype_override(attrs)) {
-    return std::vector<std::string>{"ograd", "data", "output"};
+  if (softmax_has_dtype_override(attrs) || softmax_use_length(attrs)) {
+    if (softmax_use_length(attrs)) {
+      return std::vector<std::string>{"ograd", "data", "length", "output"};
+    } else {
+      return std::vector<std::string>{"ograd", "data", "output"};
+    }
   } else {
     return std::vector<std::string>{"ograd", "output"};
   }
@@ -388,7 +663,7 @@ struct SoftmaxFGradient {
   const char *op_name;
   std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr& n,
                                           const std::vector<nnvm::NodeEntry>& ograds) const {
-    if (softmax_has_dtype_override(n->attrs)) {
+    if (softmax_has_dtype_override(n->attrs) || softmax_use_length(n->attrs)) {
       return ElemwiseGradUseInOut {op_name}(n, ograds);
     } else {
       return ElemwiseGradUseOut {op_name}(n, ograds);
@@ -419,30 +694,46 @@ void SoftmaxCompute(const nnvm::NodeAttrs& attrs,
 
   MXNET_REAL_ACC_TYPE_SWITCH(inputs[0].type_flag_, DType, AType, {
     MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
-      if (safe_acc) {
-        if (shape.ndim() == 2) {
-          Softmax<OP, negate, AType>(
-              ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
-              outputs[0].dptr<OType>(), shape.get<2>(), axis,
-              static_cast<DType>(temperature));
+      if (!param.use_length.value()) {
+        if (safe_acc) {
+          if (shape.ndim() == 2) {
+            Softmax<OP, negate, AType>(
+                ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
+                outputs[0].dptr<OType>(), shape.get<2>(), axis,
+                static_cast<DType>(temperature));
+          } else {
+            Softmax<OP, negate, AType>(
+                ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
+                outputs[0].dptr<OType>(), shape.get<3>(), axis,
+                static_cast<DType>(temperature));
+          }
         } else {
-          Softmax<OP, negate, AType>(
-              ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
-              outputs[0].dptr<OType>(), shape.get<3>(), axis,
-              static_cast<DType>(temperature));
+          if (shape.ndim() == 2) {
+            Softmax<OP, negate, DType>(
+                ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
+                outputs[0].dptr<OType>(), shape.get<2>(), axis,
+                static_cast<DType>(temperature));
+          } else {
+            Softmax<OP, negate, DType>(
+                ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
+                outputs[0].dptr<OType>(), shape.get<3>(), axis,
+                static_cast<DType>(temperature));
+          }
         }
       } else {
-        if (shape.ndim() == 2) {
-          Softmax<OP, negate, DType>(
+        MXNET_INT_TYPE_SWITCH(inputs[1].type_flag_, IType, {
+          if (shape.ndim() == 2) {
+            SoftmaxWithLength<OP, negate, AType>(
               ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
-              outputs[0].dptr<OType>(), shape.get<2>(), axis,
-              static_cast<DType>(temperature));
-        } else {
-          Softmax<OP, negate, DType>(
+              outputs[0].dptr<OType>(), inputs[1].dptr<IType>(),
+              shape.get<2>(), axis, static_cast<DType>(temperature));
+          } else {
+            SoftmaxWithLength<OP, negate, AType>(
               ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
-              outputs[0].dptr<OType>(), shape.get<3>(), axis,
-              static_cast<DType>(temperature));
-        }
+              outputs[0].dptr<OType>(), inputs[1].dptr<IType>(),
+              shape.get<3>(), axis, static_cast<DType>(temperature));
+          }
+        });
       }
     });
   });
@@ -464,35 +755,56 @@ void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs,
   mxnet::TShape shape = AxisShapeCompact(inputs[0].shape_, &axis, true);
 
   int out_idx = softmax_has_dtype_override(attrs) ? 2 : 1;
+  out_idx = softmax_use_length(attrs) ? 3 : out_idx;
   bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false);
 
   MXNET_REAL_ACC_TYPE_SWITCH(inputs[0].type_flag_, OType, AType, {
     MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
       MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
-        if (safe_acc) {
-          if (shape.ndim() == 2) {
-            SoftmaxGrad<OP1, OP2, Req, negate, AType>(
-                ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
-                inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
-                shape.get<2>(), axis, static_cast<DType>(temperature));
+        if (!softmax_use_length(attrs)) {
+          if (safe_acc) {
+            if (shape.ndim() == 2) {
+              SoftmaxGrad<OP1, OP2, Req, negate, AType>(
+                  ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
+                  inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
+                  shape.get<2>(), axis, static_cast<DType>(temperature));
+            } else {
+              SoftmaxGrad<OP1, OP2, Req, negate, AType>(
+                  ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
+                  inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
+                  shape.get<3>(), axis, static_cast<DType>(temperature));
+            }
           } else {
-            SoftmaxGrad<OP1, OP2, Req, negate, AType>(
-                ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
-                inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
-                shape.get<3>(), axis, static_cast<DType>(temperature));
+            if (shape.ndim() == 2) {
+              SoftmaxGrad<OP1, OP2, Req, negate, DType>(
+                  ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
+                  inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
+                  shape.get<2>(), axis, static_cast<DType>(temperature));
+            } else {
+              SoftmaxGrad<OP1, OP2, Req, negate, DType>(
+                  ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
+                  inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
+                  shape.get<3>(), axis, static_cast<DType>(temperature));
+            }
           }
         } else {
-          if (shape.ndim() == 2) {
-            SoftmaxGrad<OP1, OP2, Req, negate, DType>(
+          MXNET_INT_TYPE_SWITCH(inputs[2].type_flag_, IType, {
+            if (req[1] != kNullOp) {
+              mxnet_op::Kernel<mxnet_op::set_zero, xpu>::Launch(
+                ctx.get_stream<xpu>(), outputs[1].Size(), outputs[1].dptr<IType>());
+            }
+            if (shape.ndim() == 2) {
+              SoftmaxWithLengthGrad<OP1, OP2, Req, negate, AType>(
                 ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
                 inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
-                shape.get<2>(), axis, static_cast<DType>(temperature));
-          } else {
-            SoftmaxGrad<OP1, OP2, Req, negate, DType>(
+                inputs[2].dptr<IType>(), shape.get<2>(), axis, static_cast<DType>(temperature));
+            } else {
+              SoftmaxWithLengthGrad<OP1, OP2, Req, negate, AType>(
                 ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
                 inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
-                shape.get<3>(), axis, static_cast<DType>(temperature));
-          }
+                inputs[2].dptr<IType>(), shape.get<3>(), axis, static_cast<DType>(temperature));
+            }
+          });
         }
       });
     });
diff --git a/src/operator/nn/softmax.cc b/src/operator/nn/softmax.cc
index e44bbbb..5a581e4 100644
--- a/src/operator/nn/softmax.cc
+++ b/src/operator/nn/softmax.cc
@@ -59,14 +59,23 @@ inline static bool SoftmaxStorageType(const nnvm::NodeAttrs& attrs,
                                       DispatchMode* dispatch_mode,
                                       std::vector<int> *in_attrs,
                                       std::vector<int> *out_attrs) {
-  CHECK_EQ(in_attrs->size(), 1);
-  CHECK_EQ(out_attrs->size(), 1);
+  const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
+  CHECK_EQ(in_attrs->size(), (param.use_length.value()) ? 2U : 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+
+  if (param.use_length.value()) {
+    auto& out_stype = out_attrs->at(0);
+    return storage_type_assign(&out_stype, kDefaultStorage,
+                               dispatch_mode, DispatchMode::kFCompute);
+  }
 
   return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs,
                            out_attrs);
 }
 #endif
 
+
+
 NNVM_REGISTER_OP(softmax)
 .describe(R"code(Applies the softmax function.
 
@@ -92,6 +101,13 @@ Example::
 
 )code" ADD_FILELINE)
 .set_attr_parser(ParamParser<SoftmaxParam>)
+.set_attr<nnvm::FListInputNames>("FListInputNames",  // "length" only when use_length
+    [](const NodeAttrs& attrs) {
+    const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
+    return (param.use_length.value()) ?
+           std::vector<std::string>{"data", "length"} :
+           std::vector<std::string>{"data"};
+})
 .set_attr<nnvm::FListOutputNames>("FListOutputNames",
     [](const NodeAttrs& attrs) {
     return std::vector<std::string>{"output"};
@@ -103,20 +119,27 @@ Example::
 .set_attr<FInferStorageType>("FInferStorageType", SoftmaxStorageType)
 #endif
 .set_attr<nnvm::FGradient>("FGradient", SoftmaxFGradient{"_backward_softmax"})
+// .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
 .set_attr<nnvm::FInferType>("FInferType", SoftmaxOpType)
-.set_num_inputs(1)
+.set_num_inputs([](const nnvm::NodeAttrs& attrs) {
+    const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
+    return (param.use_length.value()) ? 2 : 1;
+  })
 .set_num_outputs(1)
-.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
+.set_attr<mxnet::FInferShape>("FInferShape", SoftmaxOpShape)
 .set_attr<nnvm::FInplaceOption>("FInplaceOption",
   [](const NodeAttrs& attrs){
     return std::vector<std::pair<int, int> >{{0, 0}};
   })
 .add_argument("data", "NDArray-or-Symbol", "The input array.")
+.add_argument("length", "NDArray-or-Symbol", "The length array.")
 .add_arguments(SoftmaxParam::__FIELDS__());
 
 NNVM_REGISTER_OP(_backward_softmax)
 .set_num_inputs(SoftmaxGradOpNumInputs)
-.set_num_outputs(1)
+.set_num_outputs([](const nnvm::NodeAttrs& attrs) {
+    return (softmax_use_length(attrs) ? 2 : 1);
+  })
 .set_attr<nnvm::FListInputNames>("FListInputNames", SoftmaxGradOpInputNames)
 .set_attr<mxnet::FInferShape>("FInferShape", SoftmaxGradOpShape)
 .set_attr<nnvm::FInferType>("FInferType", SoftmaxGradOpType)
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 749f0f2..fea07f5 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -5196,6 +5196,39 @@ def test_softmax_dtype():
         check_dtypes_almost_equal('log_softmax', 1e-3, 1e-3, 1e-3, 1e-3,
                                   'float32', 'float64', 'float64')
 
+
+@with_seed()
+def test_softmax_with_length():
+    def np_softmax_with_length(data, length):
+        res = np.zeros(data.shape)
+        for i in range(length.shape[0]):
+            for j in range(length.shape[1]):
+                leng = int(length[i, j])
+                res[i, 0:leng, j] = np_softmax(data[i, 0:leng, j])
+        return res
+
+    ndim = 3
+    shape = rand_shape_nd(ndim, dim=10)
+    len_shape = list(shape)
+    del len_shape[1]
+    len_shape = tuple(len_shape)
+    for dtype in [np.float16, np.float32, np.float64]:
+        mx_data = rand_ndarray(shape, dtype=dtype)
+        np_data = mx_data.asnumpy()
+        np_length = np.random.randint(1, shape[1] + 1, len_shape)
+        mx_length = mx.nd.array(np_length, dtype=np.int32)
+        np_out = np_softmax_with_length(np_data, np_length)
+        data = mx.sym.Variable("data")
+        length = mx.sym.Variable("length")
+        mx_sym = mx.sym.softmax(data=data, length=length, use_length=True, axis=1)
+        location = {"data": mx_data, "length": mx_length}
+        rtol = 1e-2 if dtype == np.float16 else 1e-3
+        atol = 1e-4 if dtype == np.float16 else 1e-5
+        check_symbolic_forward(mx_sym, location, [np_out], rtol=rtol, atol=atol, dtype="asnumpy")
+        check_symbolic_backward(mx_sym, location, [np.ones(shape, dtype=dtype)],
+                                [np.zeros(shape), np.zeros(len_shape, dtype=np.int32)], rtol=1e-2, atol=1e-3, dtype="asnumpy")
+
+
 @with_seed()
 def test_pick():
     def test_pick_helper(index_type=np.int32):
@@ -8034,7 +8067,11 @@ def test_op_all_names_monitor():
     check_name(cc_sym, ['data', 'concat_arg0', 'data', 'concat_arg1', 'concat_output'])
 
     sm_sym = mx.sym.softmax(data, name='softmax')
-    check_name(sm_sym, ['data', 'softmax_input0', 'softmax_output'])
+    check_name(sm_sym, ['data', 'softmax_data', 'softmax_output'])
+
+    length = mx.sym.Variable("length", shape=(10, 10, 10))
+    sm_sym = mx.sym.softmax(data, length, axis=1, use_length=True, name='softmax')
+    check_name(sm_sym, ['data', 'softmax_data', 'length', 'softmax_length', 'softmax_output'])
 
     sa_sym = mx.sym.SoftmaxActivation(data, name='softmax')
     check_name(sa_sym, ['data', 'softmax_input0', 'softmax_output'])