You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by sx...@apache.org on 2020/12/30 03:47:52 UTC

[incubator-mxnet] branch master updated: masked_log_softmax -inf for masked values (#19703)

This is an automated email from the ASF dual-hosted git repository.

sxjscience pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 3d8a214  masked_log_softmax -inf for masked values (#19703)
3d8a214 is described below

commit 3d8a214a242ca5c593c561387c34129098090445
Author: Moises Hernandez <50...@users.noreply.github.com>
AuthorDate: Tue Dec 29 19:45:16 2020 -0800

    masked_log_softmax -inf for masked values (#19703)
    
    * Set masked values to -INF in the forward pass of masked_log_softmax. Remove the scale_factor parameter.
    
    * Fix lint errors.
---
 src/operator/nn/log_softmax.cc         |   3 +-
 src/operator/nn/log_softmax.cu         |   3 +-
 src/operator/nn/softmax-inl.h          | 145 ++++++++++++++++-----------------
 src/operator/nn/softmax.cc             |   3 +-
 src/operator/nn/softmax.cu             |   3 +-
 tests/python/unittest/test_operator.py |  36 +++++---
 6 files changed, 100 insertions(+), 93 deletions(-)

diff --git a/src/operator/nn/log_softmax.cc b/src/operator/nn/log_softmax.cc
index 6aae7e9..2a1d1b3 100644
--- a/src/operator/nn/log_softmax.cc
+++ b/src/operator/nn/log_softmax.cc
@@ -168,7 +168,8 @@ This is equivalent to computing masked softmax followed by log.)code")
   [](const NodeAttrs& attrs){
     return std::vector<std::string>{"data", "mask"};
   })
-.set_attr<FCompute>("FCompute<cpu>", MaskedSoftmaxCompute<cpu, mxnet_op::log_softmax_fwd>)
+.set_attr<FCompute>("FCompute<cpu>", MaskedSoftmaxCompute<cpu, mxnet_op::log_softmax_fwd,
+                                     true>)
 .set_attr<nnvm::FGradient>("FGradient",
   [](const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
     auto data_grad = MakeNode("_backward_masked_log_softmax", n->attrs.name + "_backward_data",
diff --git a/src/operator/nn/log_softmax.cu b/src/operator/nn/log_softmax.cu
index 2a54cd3..396a4e8 100644
--- a/src/operator/nn/log_softmax.cu
+++ b/src/operator/nn/log_softmax.cu
@@ -36,7 +36,8 @@ NNVM_REGISTER_OP(_backward_log_softmax)
                                                         mxnet_op::log_softmax_bwd>);
 
 NNVM_REGISTER_OP(masked_log_softmax)
-.set_attr<FCompute>("FCompute<gpu>", MaskedSoftmaxCompute<gpu, mxnet_op::log_softmax_fwd>);
+.set_attr<FCompute>("FCompute<gpu>", MaskedSoftmaxCompute<gpu, mxnet_op::log_softmax_fwd,
+                                                          true>);
 
 NNVM_REGISTER_OP(_backward_masked_log_softmax)
 .set_attr<FCompute>("FCompute<gpu>", MaskedSoftmaxGradCompute<gpu, mshadow_op::left,
diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h
index b53b8a4..512d8d2 100644
--- a/src/operator/nn/softmax-inl.h
+++ b/src/operator/nn/softmax-inl.h
@@ -163,12 +163,11 @@ inline void Softmax(Stream<cpu> *s, DType *in, OType *out, IType *length,
   }
 }
 
-struct masked_softmax_where_scale {
+struct masked_softmax_where {
   template<typename DType, int ndim>
   MSHADOW_XINLINE static void Map(index_t id, DType* out, const bool* cond,
                                   const DType* x, const double y,
-                                  Shape<ndim> data_shape, Shape<ndim> mask_shape,
-                                  const double scale) {
+                                  Shape<ndim> data_shape, Shape<ndim> mask_shape) {
     index_t mask_pos = 0;
     index_t stride = 1;
     for (index_t i = ndim-1, j = id; i >=0; --i) {
@@ -179,31 +178,32 @@ struct masked_softmax_where_scale {
       stride *= mask_shape[i];
       j = tmp;
     }
-    KERNEL_ASSIGN(out[id], kWriteTo, (cond[mask_pos] ? x[id] / static_cast<DType>(scale) :
-                                                       static_cast<DType>(y)));
+    KERNEL_ASSIGN(out[id], kWriteTo, (cond[mask_pos] ? x[id] : static_cast<DType>(y)));
   }
 };
 
-template<typename OP, bool negate, typename AType, typename DType, int ndim>
+template<typename OP, bool masked_neg_inf, bool negate,
+         typename AType, typename DType, int ndim>
 inline void MaskedSoftmax(Stream<cpu> *s, DType *in, DType *out, bool *mask,
                           Shape<ndim> data_shape, Shape<ndim> mask_shape,
-                          int axis, const double scale,
-                          const double temperature, bool normalize,
+                          int axis, const double temperature, bool normalize,
                           const OpContext& ctx) {
   Tensor<cpu, 1, DType> workspace = ctx.requested[0].get_space_typed<cpu, 1, DType>(
       Shape1(data_shape.Size()), s);
-  DType* masked_scaled_input = TBlob(workspace).dptr<DType>();
+  DType* masked_input = TBlob(workspace).dptr<DType>();
 
   double neg = MinValue<DType>();
-  Kernel<masked_softmax_where_scale, cpu>::Launch(s, data_shape.Size(), masked_scaled_input,
-                                                  mask, in, neg, data_shape, mask_shape,
-                                                  scale);
+  Kernel<masked_softmax_where, cpu>::Launch(s, data_shape.Size(), masked_input,
+                                            mask, in, neg, data_shape, mask_shape);
   int* max_lenghts = nullptr;
-  Softmax<OP, negate, AType, DType>(s, masked_scaled_input, out, max_lenghts,
+  double masked_value = 0.0;
+  if (masked_neg_inf)
+    masked_value = -INFINITY;
+  Softmax<OP, negate, AType, DType>(s, masked_input, out, max_lenghts,
                                     data_shape, axis, temperature);
-  Kernel<masked_softmax_where_scale, cpu>::Launch(s, data_shape.Size(), out,
-                                                  mask, out, 0.0, data_shape, mask_shape,
-                                                  1.0);
+  Kernel<masked_softmax_where, cpu>::Launch(s, data_shape.Size(), out,
+                                            mask, out, masked_value, data_shape,
+                                            mask_shape);
 }
 
 struct softmax_bwd {
@@ -308,22 +308,20 @@ template<typename OP1, typename OP2, int Req, bool negate, typename AType, int n
 inline void MaskedSoftmaxGrad(Stream<cpu> *s, DType *out, DType *ograd,
                               DType *igrad, bool *mask, Shape<ndim> data_shape,
                               Shape<ndim> mask_shape, int axis,
-                              const double scale, const double temperature,
+                              const double temperature,
                               const OpContext& ctx) {
   Tensor<cpu, 1, DType> workspace = ctx.requested[0].get_space_typed<cpu, 1, DType>(
     Shape1(data_shape.Size()), s);
   DType* masked_ograd = TBlob(workspace).dptr<DType>();
-  Kernel<masked_softmax_where_scale, cpu>::Launch(s, data_shape.Size(), masked_ograd,
-                                                  mask, ograd, 0.0, data_shape, mask_shape,
-                                                  1.0);
+  Kernel<masked_softmax_where, cpu>::Launch(s, data_shape.Size(), masked_ograd,
+                                                  mask, ograd, 0.0, data_shape, mask_shape);
   int* max_lenghts = nullptr;
   SoftmaxGrad<OP1, OP2, Req, negate, AType, DType, DType, int, ndim>(
       s, out, masked_ograd, igrad,
       max_lenghts, data_shape,
       axis, temperature);
-  Kernel<masked_softmax_where_scale, cpu>::Launch(s, data_shape.Size(), igrad,
-                                                  mask, igrad, 0.0, data_shape, mask_shape,
-                                                  scale);
+  Kernel<masked_softmax_where, cpu>::Launch(s, data_shape.Size(), igrad,
+                                            mask, igrad, 0.0, data_shape, mask_shape);
 }
 
 #ifdef __CUDACC__
@@ -484,12 +482,12 @@ MSHADOW_XINLINE index_t get_mask_position(const index_t idx, const Shape<ndim>&
   return ret;
 }
 
-template<bool normalize, int x_bits, typename OP, bool negate, typename AType,
-         int ndim, typename DType>
+template<bool normalize, int x_bits, typename OP, bool masked_neg_inf,
+         bool negate, typename AType, int ndim, typename DType>
 __global__ void masked_softmax_kernel(DType *in, DType *out, bool *in_mask,
                                       index_t M, int axis, Shape<ndim> sshape,
                                       Shape<ndim> stride, Shape<ndim> mask_shape,
-                                      const double scale, const double temperature) {
+                                      const double temperature) {
   extern __shared__ double shared[];
   AType* smem = reinterpret_cast<AType*>(shared);  // x_size
 
@@ -512,7 +510,7 @@ __global__ void masked_softmax_kernel(DType *in, DType *out, bool *in_mask,
     __syncthreads();
     cuda::Reduce1D<red::maximum, x_bits>(smem);
     __syncthreads();
-    smax = smem[0] / scale;
+    smax = smem[0];
     __syncthreads();
   }
 
@@ -521,7 +519,7 @@ __global__ void masked_softmax_kernel(DType *in, DType *out, bool *in_mask,
   for (index_t i = x; i < M; i += x_size) {
     bool mask_value = bcst_mask_axis ? in_mask[base_mask] : in_mask[base_mask + i*sa_mask];
     if (mask_value) {
-      val = (negate ? -in[base + i*sa]:in[base + i*sa]) / scale;
+      val = (negate ? -in[base + i*sa]:in[base + i*sa]);
       smem[x] += static_cast<AType>(expf((val - smax) / static_cast<AType>(temperature)));
     }
   }
@@ -531,21 +529,25 @@ __global__ void masked_softmax_kernel(DType *in, DType *out, bool *in_mask,
   AType ssum = smem[0];
   __syncthreads();
 
+  double masked_value = 0.0;
+  if (masked_neg_inf)
+    masked_value = -INFINITY;
   for (index_t i = x; i < M; i += x_size) {
-    val = (negate ? -in[base + i*sa] : in[base + i*sa]) / scale;
+    val = (negate ? -in[base + i*sa] : in[base + i*sa]);
     bool mask_value = bcst_mask_axis ? in_mask[base_mask] : in_mask[base_mask + i*sa_mask];
     out[base + i*sa] =
       mask_value ? DType(OP::Map((val - smax)/static_cast<DType>(temperature), ssum)) :
-                             DType(0.0f);
+                             DType(masked_value);
   }
 }
 
-template<bool normalize, typename OP, bool negate, typename AType, typename LType,
-         typename LTypeMask, typename DType, int ndim>
+template<bool normalize, typename OP,  bool masked_neg_inf, bool negate, typename AType,
+         typename LType, typename LTypeMask, typename DType, int ndim>
 __global__ void masked_softmax_stride1_kernel(const DType *in, DType *out, bool *in_mask,
                                               const index_t M, int axis, Shape<ndim> sshape,
-                                              Shape<ndim> mask_shape, const double scale,
-                                              const double temperature, const int rows_per_block,
+                                              Shape<ndim> mask_shape,
+                                              const double temperature,
+                                              const int rows_per_block,
                                               const index_t total_rows,
                                               const size_t size_input_shared,
                                               const size_t size_mask_shared) {
@@ -616,7 +618,7 @@ __global__ void masked_softmax_stride1_kernel(const DType *in, DType *out, bool
       scratch[threadIdx.x] = my_value;
     }
     __syncthreads();
-    smax = scratch[threadIdx.x - threadIdx.x % threads_per_row]  / scale;
+    smax = scratch[threadIdx.x - threadIdx.x % threads_per_row];
     __syncthreads();
   }
 
@@ -624,7 +626,7 @@ __global__ void masked_softmax_stride1_kernel(const DType *in, DType *out, bool
   red::sum::SetInitValue(my_sum);
   for (index_t i = my_id; i < M; i += threads_per_row) {
     if (row_mask[i]) {
-      const DType val = (negate ? -row[i] : row[i]) / scale;
+      const DType val = (negate ? -row[i] : row[i]);
       my_sum += static_cast<AType>(expf((val - smax) / static_cast<AType>(temperature)));
     }
   }
@@ -646,10 +648,13 @@ __global__ void masked_softmax_stride1_kernel(const DType *in, DType *out, bool
   AType ssum = scratch[threadIdx.x - threadIdx.x % threads_per_row];
   __syncthreads();
 
+  double masked_value = 0.0;
+  if (masked_neg_inf)
+    masked_value = -INFINITY;
   for (index_t i = my_id; i < M; i += threads_per_row) {
-    const DType val = (negate ? -row[i] : row[i]) / scale;
+    const DType val = (negate ? -row[i] : row[i]);
     row[i] = row_mask[i] ? DType(OP::Map((val - smax)/static_cast<DType>(temperature), ssum)) :
-                           DType(0.0f);
+                           DType(masked_value);
   }
   __syncthreads();
 
@@ -699,11 +704,11 @@ inline void Softmax(Stream<gpu> *s, DType *in, OType *out, IType *length,
   }
 }
 
-template<typename OP, bool negate, typename AType, typename DType,
-         typename OType, int ndim>
+template<typename OP, bool masked_neg_inf, bool negate,
+         typename AType, typename DType, typename OType, int ndim>
 inline void MaskedSoftmax(Stream<gpu> *s, DType *in, OType *out, bool *mask,
                           Shape<ndim> data_shape, Shape<ndim> mask_shape,
-                          int axis, const double scale, const double temperature,
+                          int axis, const double temperature,
                           bool normalize, const OpContext& ctx) {
   const int x_bits = 7;
   const int x_size = 1 << x_bits;
@@ -747,16 +752,18 @@ inline void MaskedSoftmax(Stream<gpu> *s, DType *in, OType *out, bool *mask,
 
         int nblocks = (N + rows_per_block - 1) / rows_per_block;
         if (normalize) {
-          masked_softmax_stride1_kernel<true, OP, negate, AType, LType, LTypeMask>
+          masked_softmax_stride1_kernel<true, OP, masked_neg_inf, negate,
+                                        AType, LType, LTypeMask>
             <<<nblocks, softmax_threads_per_block, amount_shared,
                mshadow::Stream<gpu>::GetStream(s)>>>(
-              in, out, mask, M, axis, sshape, mask_shape, scale, temperature,
+              in, out, mask, M, axis, sshape, mask_shape, temperature,
               rows_per_block, N, size_input_shared, size_mask_shared);
         } else {
-          masked_softmax_stride1_kernel<false, OP, negate, AType, LType, LTypeMask>
+          masked_softmax_stride1_kernel<false, OP, masked_neg_inf, negate,
+                                        AType, LType, LTypeMask>
             <<<nblocks, softmax_threads_per_block, amount_shared,
                mshadow::Stream<gpu>::GetStream(s)>>>(
-              in, out, mask, M, axis, sshape, mask_shape, scale, temperature,
+              in, out, mask, M, axis, sshape, mask_shape, temperature,
               rows_per_block, N, size_input_shared, size_mask_shared);
         }
       });
@@ -765,13 +772,13 @@ inline void MaskedSoftmax(Stream<gpu> *s, DType *in, OType *out, bool *mask,
   } else {
     size_t amount_shared = x_size * sizeof(AType);
     if (normalize) {
-      masked_softmax_kernel<true, x_bits, OP, negate, AType, ndim>
+      masked_softmax_kernel<true, x_bits, OP, masked_neg_inf, negate, AType, ndim>
         <<<N, x_size, amount_shared, mshadow::Stream<gpu>::GetStream(s)>>>(
-          in, out, mask, M, axis, sshape, stride, mask_shape, scale, temperature);
+          in, out, mask, M, axis, sshape, stride, mask_shape, temperature);
     } else {
-      masked_softmax_kernel<false, x_bits, OP, negate, AType, ndim>
+      masked_softmax_kernel<false, x_bits, OP, masked_neg_inf, negate, AType, ndim>
         <<<N, x_size, amount_shared, mshadow::Stream<gpu>::GetStream(s)>>>(
-          in, out, mask, M, axis, sshape, stride, mask_shape, scale, temperature);
+          in, out, mask, M, axis, sshape, stride, mask_shape, temperature);
     }
     MSHADOW_CUDA_POST_KERNEL_CHECK(masked_softmax_kernel);
   }
@@ -898,7 +905,6 @@ __global__ void masked_softmax_stride1_grad_kernel(const OType *out, const OType
                                                    const index_t M, int axis,
                                                    Shape<ndim> sshape,
                                                    Shape<ndim> mask_shape,
-                                                   const double scale,
                                                    const double temperature,
                                                    const int rows_per_block,
                                                    const index_t total_rows,
@@ -975,14 +981,12 @@ __global__ void masked_softmax_stride1_grad_kernel(const OType *out, const OType
   AType ssum = scratch[threadIdx.x - threadIdx.x % threads_per_row];
   __syncthreads();
 
-  AType temperature_scale = static_cast<AType>(temperature) *
-                            static_cast<AType>(scale);
   for (index_t i = my_id; i < M; i += threads_per_row) {
     const DType val =
       negate ?
       -OP2::Map(row[i + M], row[i], ssum):
       OP2::Map(row[i + M], row[i], ssum);
-    row[i] = row_mask[i] ? DType(val / static_cast<DType>(temperature_scale)) :
+    row[i] = row_mask[i] ? DType(val / static_cast<DType>(temperature)) :
                            DType(0.0f);
     if (Req == kAddTo) {
       row[i] += igrad[my_row * M + i];
@@ -1003,7 +1007,7 @@ __global__ void masked_softmax_grad_kernel(OType *out, OType *ograd, DType *igra
                                            const bool *in_mask, index_t M, int axis,
                                            Shape<ndim> sshape, Shape<ndim> stride,
                                            Shape<ndim> mask_shape,
-                                           const double scale, const double temperature) {
+                                           const double temperature) {
   const unsigned x_size = 1 << x_bits;
   __shared__ AType smem[x_size];
   index_t sa = stride[axis];
@@ -1026,15 +1030,13 @@ __global__ void masked_softmax_grad_kernel(OType *out, OType *ograd, DType *igra
   __syncthreads();
 
   DType final_result;
-  AType temperature_scale = static_cast<AType>(temperature) *
-                            static_cast<AType>(scale);
   for (index_t i = x; i < M; i += x_size) {
     bool mask_value = bcst_mask_axis ? in_mask[base_mask] : in_mask[base_mask + i*sa_mask];
     final_result =
       negate ?
       -OP2::Map(ograd[base + i*sa], out[base + i*sa], ssum):
       OP2::Map(ograd[base + i*sa], out[base + i*sa], ssum);
-    final_result = mask_value ? final_result / static_cast<DType>(temperature_scale) : DType(0.0f);
+    final_result = mask_value ? final_result / static_cast<DType>(temperature) : DType(0.0f);
     KERNEL_ASSIGN(igrad[base + i*sa], Req, final_result);
   }
 }
@@ -1086,7 +1088,7 @@ template<typename OP1, typename OP2, int Req, bool negate, typename AType, int n
 inline void MaskedSoftmaxGrad(Stream<gpu> *s, OType *out, OType *ograd,
                               DType *igrad, bool *mask, Shape<ndim> data_shape,
                               Shape<ndim> mask_shape, int axis,
-                              const double scale, const double temperature,
+                              const double temperature,
                               const OpContext& ctx) {
   const int x_bits = 7;
   const int x_size = 1 << x_bits;
@@ -1133,14 +1135,14 @@ inline void MaskedSoftmaxGrad(Stream<gpu> *s, OType *out, OType *ograd,
           <<<nblocks, softmax_threads_per_block, amount_shared,
              mshadow::Stream<gpu>::GetStream(s)>>>(
             out, ograd, igrad, mask, M, axis, sshape, mask_shape,
-            scale, temperature, rows_per_block, N, size_input_shared, size_mask_shared);
+            temperature, rows_per_block, N, size_input_shared, size_mask_shared);
       });
     });
     MSHADOW_CUDA_POST_KERNEL_CHECK(masked_softmax_stride1_grad_kernel);
   } else {
     masked_softmax_grad_kernel<x_bits, OP1, OP2, Req, negate, AType, ndim>
       <<<N, x_size, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
-        out, ograd, igrad, mask, M, axis, sshape, stride, mask_shape, scale, temperature);
+        out, ograd, igrad, mask, M, axis, sshape, stride, mask_shape, temperature);
     MSHADOW_CUDA_POST_KERNEL_CHECK(masked_softmax_grad_kernel);
   }
 }
@@ -1181,15 +1183,12 @@ struct SoftmaxParam : public dmlc::Parameter<SoftmaxParam> {
 
 struct MaskedSoftmaxParam : public dmlc::Parameter<MaskedSoftmaxParam> {
   int axis;
-  dmlc::optional<double> scale_factor;
   dmlc::optional<double> temperature;
   dmlc::optional<int> dtype;
   dmlc::optional<bool> normalize;
   DMLC_DECLARE_PARAMETER(MaskedSoftmaxParam) {
     DMLC_DECLARE_FIELD(axis).set_default(-1)
     .describe("The axis along which to compute softmax.");
-    DMLC_DECLARE_FIELD(scale_factor).set_default(dmlc::optional<double>())
-    .describe("Scaling factor applied before softmax");
     DMLC_DECLARE_FIELD(temperature).set_default(dmlc::optional<double>())
     .describe("Temperature parameter in softmax");
     DMLC_DECLARE_FIELD(normalize)
@@ -1492,7 +1491,7 @@ void SoftmaxCompute(const nnvm::NodeAttrs& attrs,
   });
 }
 
-template<typename xpu, typename OP, bool negate = false>
+template<typename xpu, typename OP, bool masked_neg_inf, bool negate = false>
 void MaskedSoftmaxCompute(const nnvm::NodeAttrs& attrs,
                           const OpContext& ctx,
                           const std::vector<TBlob>& inputs,
@@ -1503,8 +1502,6 @@ void MaskedSoftmaxCompute(const nnvm::NodeAttrs& attrs,
   CHECK_NE(req[0], kAddTo);
   const MaskedSoftmaxParam& param = nnvm::get<MaskedSoftmaxParam>(attrs.parsed);
   int axis = CheckAxis(param.axis, inputs[0].ndim());
-  const double scale = param.scale_factor.has_value() ?
-    param.scale_factor.value() : 1.0;
   const double temperature = param.temperature.has_value() ?
     param.temperature.value() : 1.0;
   bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", true);
@@ -1518,17 +1515,17 @@ void MaskedSoftmaxCompute(const nnvm::NodeAttrs& attrs,
     MXNET_NDIM_SWITCH(inputs[0].ndim(), ndim, {
       bool* mask_ptr = inputs[1].dptr<bool>();
       if (safe_acc) {
-        MaskedSoftmax<OP, negate, AType>(
+        MaskedSoftmax<OP, masked_neg_inf, negate, AType>(
           ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
           outputs[0].dptr<DType>(), mask_ptr,
           inputs[0].shape_.get<ndim>(), inputs[1].shape_.get<ndim>(),
-          axis, scale, temperature, param.normalize.value(), ctx);
+          axis, temperature, param.normalize.value(), ctx);
       } else {
-        MaskedSoftmax<OP, negate, DType>(
+        MaskedSoftmax<OP, masked_neg_inf, negate, DType>(
           ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
           outputs[0].dptr<DType>(), mask_ptr,
           inputs[0].shape_.get<ndim>(), inputs[1].shape_.get<ndim>(),
-          axis, scale, temperature, param.normalize.value(), ctx);
+          axis, temperature, param.normalize.value(), ctx);
       }
     });
   });
@@ -1616,8 +1613,6 @@ void MaskedSoftmaxGradCompute(const nnvm::NodeAttrs& attrs,
   if (req[0] == kNullOp) return;
   const MaskedSoftmaxParam& param = nnvm::get<MaskedSoftmaxParam>(attrs.parsed);
   int axis = CheckAxis(param.axis, inputs[0].ndim());
-  const double scale = param.scale_factor.has_value() ?
-    param.scale_factor.value() : 1.0;
   const double temperature = param.temperature.has_value() ?
     param.temperature.value() : 1.0;
 
@@ -1634,15 +1629,13 @@ void MaskedSoftmaxGradCompute(const nnvm::NodeAttrs& attrs,
               ctx.get_stream<xpu>(), out_ptr,
               ograd_ptr, grad_data, mask_ptr,
               inputs[0].shape_.get<ndim>(), inputs[1].shape_.get<ndim>(),
-              axis, static_cast<DType>(scale),
-              static_cast<DType>(temperature), ctx);
+              axis, static_cast<DType>(temperature), ctx);
         } else {
           MaskedSoftmaxGrad<OP1, OP2, Req, negate, DType>(
               ctx.get_stream<xpu>(), out_ptr,
               ograd_ptr, grad_data, mask_ptr,
               inputs[0].shape_.get<ndim>(), inputs[1].shape_.get<ndim>(),
-              axis, static_cast<DType>(scale),
-              static_cast<DType>(temperature), ctx);
+              axis, static_cast<DType>(temperature), ctx);
         }
       });
     });
diff --git a/src/operator/nn/softmax.cc b/src/operator/nn/softmax.cc
index cf67853..b3ffd42 100644
--- a/src/operator/nn/softmax.cc
+++ b/src/operator/nn/softmax.cc
@@ -201,7 +201,8 @@ NNVM_REGISTER_OP(masked_softmax)
   [](const NodeAttrs& attrs) {
     return std::vector<std::string>{"output"};
   })
-.set_attr<FCompute>("FCompute<cpu>", MaskedSoftmaxCompute<cpu, mxnet_op::softmax_fwd>)
+.set_attr<FCompute>("FCompute<cpu>", MaskedSoftmaxCompute<cpu, mxnet_op::softmax_fwd,
+                                     false>)
 .set_attr<nnvm::FGradient>("FGradient",
   [](const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
     auto data_grad = MakeNode("_backward_masked_softmax", n->attrs.name + "_backward_data",
diff --git a/src/operator/nn/softmax.cu b/src/operator/nn/softmax.cu
index dc8fd99..c75f543 100644
--- a/src/operator/nn/softmax.cu
+++ b/src/operator/nn/softmax.cu
@@ -35,7 +35,8 @@ NNVM_REGISTER_OP(_backward_softmax)
 .set_attr<FCompute>("FCompute<gpu>", SoftmaxGradCompute<gpu, op::mshadow_op::mul,
                                                         mxnet_op::softmax_bwd>);
 NNVM_REGISTER_OP(masked_softmax)
-.set_attr<FCompute>("FCompute<gpu>", MaskedSoftmaxCompute<gpu, mxnet_op::softmax_fwd>);
+.set_attr<FCompute>("FCompute<gpu>", MaskedSoftmaxCompute<gpu, mxnet_op::softmax_fwd,
+                                                          false>);
 
 NNVM_REGISTER_OP(_backward_masked_softmax)
 .set_attr<FCompute>("FCompute<gpu>", MaskedSoftmaxGradCompute<gpu, op::mshadow_op::mul,
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 5034b07..7a85364 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -4941,27 +4941,31 @@ def test_softmax_with_length():
                                 [np.zeros(shape), np.zeros(len_shape, dtype=np.int32)],
                                 rtol=1e-2, atol=2e-3 if dtype == np.float16 else 1e-3, dtype="asnumpy")
 
-def np_softmax(x, axis=-1, scale_factor=1.0, temperature=1.0, normalize=True):
-    x = x / scale_factor
+def np_softmax(x, axis=-1, temperature=1.0, normalize=True):
     if normalize:
         x = x - np.max(x, axis=axis, keepdims=True)
     x = np.exp(x / temperature)
     x /= np.sum(x, axis=axis, keepdims=True)
     return x
 
-def np_masked_softmax(data, mask, axis=-1, scale_factor=1.0, temperature=1.0, normalize=True):
+def np_masked_softmax(data, mask, axis=-1, temperature=1.0, normalize=True):
     neg = -1e18
     if data.dtype == np.float16:
         neg = -1e4
     temp = np.where(mask, data, neg)
     result = np_softmax(temp, axis=axis,
-                        scale_factor=scale_factor,
                         temperature=temperature,
                         normalize=normalize) * mask
     return result
-def np_masked_softmax_grad(out, grad_out, axis=-1, scale_factor=1.0, temperature=1.0):
+def np_masked_softmax_grad(out, grad_out, axis=-1, temperature=1.0):
     temp = np.sum(out * grad_out, axis=axis, keepdims=True)
-    result = out * (grad_out - temp) / (temperature * scale_factor)
+    result = out * (grad_out - temp) / temperature
+    return result
+def np_masked_log_softmax_grad(out, grad_out, mask, axis=-1, temperature=1.0):
+    grad_out = np.where(mask, grad_out, 0)
+    temp = np.sum(grad_out, axis=axis, keepdims=True)
+    result = (grad_out - np.exp(out) * temp) / temperature
+    result = np.where(mask, result, 0)
     return result
 
 @pytest.mark.parametrize('dtype', [np.float16, np.float32, np.float64])
@@ -4969,9 +4973,8 @@ def np_masked_softmax_grad(out, grad_out, axis=-1, scale_factor=1.0, temperature
 @pytest.mark.parametrize('ndims', [3, 4, 5])
 @pytest.mark.parametrize('n_broadcast_axis', [0, 1, 2])
 @pytest.mark.parametrize('temperature', [1, 5, 9 ,11])
-@pytest.mark.parametrize('scale', [1, 2, 7, 12])
 @pytest.mark.parametrize('normalize', [True])
-def test_masked_softmax(dtype, axis, ndims, n_broadcast_axis, temperature, scale, normalize):
+def test_masked_softmax(dtype, axis, ndims, n_broadcast_axis, temperature, normalize):
     n_broadcast_axis = min(n_broadcast_axis, ndims - 1)
     shape = rand_shape_nd(ndims, dim=10)
     mx_data = rand_ndarray(shape, dtype=dtype)
@@ -4991,12 +4994,12 @@ def test_masked_softmax(dtype, axis, ndims, n_broadcast_axis, temperature, scale
     np_grad = mx_grad.asnumpy()
 
     np_out = np_masked_softmax(np_data, np_mask, axis,
-                                scale, temperature, normalize)
+                               temperature, normalize)
     np_grad_out = np_masked_softmax_grad(np_out, np_grad,
-                                         axis, scale, temperature)
+                                         axis, temperature)
     data = mx.sym.Variable("data")
     mask = mx.sym.Variable("mask")
-    mx_sym = mx.sym.masked_softmax(data=data, mask=mask, scale_factor=scale,
+    mx_sym = mx.sym.masked_softmax(data=data, mask=mask,
                                    temperature=temperature, axis=axis,
                                    normalize=normalize)
     location = {"data": mx_data, "mask": mx_mask}
@@ -5019,15 +5022,22 @@ def test_masked_log_softmax(dtype, ndims):
     np_data = mx_data.asnumpy()
     np_mask = np.random.randint(0, 2, shape)
     mx_mask = mx.nd.array(np_mask, dtype=np.bool)
+    mx_grad = rand_ndarray(shape, dtype=dtype)
+    np_grad = mx_grad.asnumpy()
     np_out = np.log(np_masked_softmax(np_data, np_mask, axis)+1e-20) * np_mask
+    np_out_inf = np.where(np_mask, np_out, -np.inf)
+    np_grad_out = np_masked_log_softmax_grad(np_out, np_grad, np_mask, axis)
     data = mx.sym.Variable("data")
     mask = mx.sym.Variable("mask")
     mx_sym = mx.sym.masked_log_softmax(data=data, mask=mask, axis=axis-ndims)
     location = {"data": mx_data, "mask": mx_mask}
     rtol = 1e-2 if dtype == np.float16 else 1e-3
     atol = 1e-4 if dtype == np.float16 else 1e-5
-    check_symbolic_forward(mx_sym, location, [np_out], rtol=rtol, atol=atol, dtype="asnumpy")
-    check_numeric_gradient(mx_sym, location, rtol=1e-1, atol=1e-2)
+    check_symbolic_forward(mx_sym, location, [np_out_inf], rtol=rtol, atol=atol, dtype="asnumpy")
+    check_symbolic_backward(mx_sym, location, [mx_grad],
+                            [np_grad_out, np.zeros(shape, dtype=np.bool)],
+                            rtol=1e-2, atol=2e-3 if dtype == np.float16 else 1e-3,
+                            dtype="asnumpy", equal_nan=True)
 
 
 def test_pick():