Posted to commits@mxnet.apache.org by sx...@apache.org on 2020/12/30 03:47:52 UTC
[incubator-mxnet] branch master updated: masked_log_softmax -inf for masked values (#19703)
This is an automated email from the ASF dual-hosted git repository.
sxjscience pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 3d8a214 masked_log_softmax -inf for masked values (#19703)
3d8a214 is described below
commit 3d8a214a242ca5c593c561387c34129098090445
Author: Moises Hernandez <50...@users.noreply.github.com>
AuthorDate: Tue Dec 29 19:45:16 2020 -0800
masked_log_softmax -inf for masked values (#19703)
* Set -INF in FWD masked values. Remove scale
* fix lint
---
src/operator/nn/log_softmax.cc | 3 +-
src/operator/nn/log_softmax.cu | 3 +-
src/operator/nn/softmax-inl.h | 145 ++++++++++++++++-----------------
src/operator/nn/softmax.cc | 3 +-
src/operator/nn/softmax.cu | 3 +-
tests/python/unittest/test_operator.py | 36 +++++---
6 files changed, 100 insertions(+), 93 deletions(-)
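In NumPy terms, the change is twofold: masked_log_softmax now writes -inf into masked positions of the forward output instead of 0, and the scale_factor parameter is removed from masked_softmax and masked_log_softmax. A minimal reference sketch of the new forward semantics, written to follow the test helpers below rather than the C++ kernels, assuming a boolean mask broadcastable to the data shape (illustrative only, not part of the commit):

import numpy as np

def masked_log_softmax_ref(data, mask, axis=-1, temperature=1.0):
    # Masked entries get a very negative value so they do not contribute
    # to the normalizing sum (mirrors MinValue<DType> in the kernels).
    neg = np.finfo(data.dtype).min
    x = np.where(mask, data, neg)
    x = x - np.max(x, axis=axis, keepdims=True)  # stabilize before exp
    log_z = np.log(np.sum(np.exp(x / temperature), axis=axis, keepdims=True))
    out = x / temperature - log_z
    # New behavior: masked positions are -inf instead of 0.
    return np.where(mask, out, -np.inf)

Returning -inf keeps the output consistent with log(0) for a fully masked probability, so a later exp() recovers exactly 0 at those positions.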
diff --git a/src/operator/nn/log_softmax.cc b/src/operator/nn/log_softmax.cc
index 6aae7e9..2a1d1b3 100644
--- a/src/operator/nn/log_softmax.cc
+++ b/src/operator/nn/log_softmax.cc
@@ -168,7 +168,8 @@ This is equivalent to computing masked softmax followed by log.)code")
[](const NodeAttrs& attrs){
return std::vector<std::string>{"data", "mask"};
})
-.set_attr<FCompute>("FCompute<cpu>", MaskedSoftmaxCompute<cpu, mxnet_op::log_softmax_fwd>)
+.set_attr<FCompute>("FCompute<cpu>", MaskedSoftmaxCompute<cpu, mxnet_op::log_softmax_fwd,
+ true>)
.set_attr<nnvm::FGradient>("FGradient",
[](const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
auto data_grad = MakeNode("_backward_masked_log_softmax", n->attrs.name + "_backward_data",
diff --git a/src/operator/nn/log_softmax.cu b/src/operator/nn/log_softmax.cu
index 2a54cd3..396a4e8 100644
--- a/src/operator/nn/log_softmax.cu
+++ b/src/operator/nn/log_softmax.cu
@@ -36,7 +36,8 @@ NNVM_REGISTER_OP(_backward_log_softmax)
mxnet_op::log_softmax_bwd>);
NNVM_REGISTER_OP(masked_log_softmax)
-.set_attr<FCompute>("FCompute<gpu>", MaskedSoftmaxCompute<gpu, mxnet_op::log_softmax_fwd>);
+.set_attr<FCompute>("FCompute<gpu>", MaskedSoftmaxCompute<gpu, mxnet_op::log_softmax_fwd,
+ true>);
NNVM_REGISTER_OP(_backward_masked_log_softmax)
.set_attr<FCompute>("FCompute<gpu>", MaskedSoftmaxGradCompute<gpu, mshadow_op::left,
diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h
index b53b8a4..512d8d2 100644
--- a/src/operator/nn/softmax-inl.h
+++ b/src/operator/nn/softmax-inl.h
@@ -163,12 +163,11 @@ inline void Softmax(Stream<cpu> *s, DType *in, OType *out, IType *length,
}
}
-struct masked_softmax_where_scale {
+struct masked_softmax_where {
template<typename DType, int ndim>
MSHADOW_XINLINE static void Map(index_t id, DType* out, const bool* cond,
const DType* x, const double y,
- Shape<ndim> data_shape, Shape<ndim> mask_shape,
- const double scale) {
+ Shape<ndim> data_shape, Shape<ndim> mask_shape) {
index_t mask_pos = 0;
index_t stride = 1;
for (index_t i = ndim-1, j = id; i >=0; --i) {
@@ -179,31 +178,32 @@ struct masked_softmax_where_scale {
stride *= mask_shape[i];
j = tmp;
}
- KERNEL_ASSIGN(out[id], kWriteTo, (cond[mask_pos] ? x[id] / static_cast<DType>(scale) :
- static_cast<DType>(y)));
+ KERNEL_ASSIGN(out[id], kWriteTo, (cond[mask_pos] ? x[id] : static_cast<DType>(y)));
}
};
-template<typename OP, bool negate, typename AType, typename DType, int ndim>
+template<typename OP, bool masked_neg_inf, bool negate,
+ typename AType, typename DType, int ndim>
inline void MaskedSoftmax(Stream<cpu> *s, DType *in, DType *out, bool *mask,
Shape<ndim> data_shape, Shape<ndim> mask_shape,
- int axis, const double scale,
- const double temperature, bool normalize,
+ int axis, const double temperature, bool normalize,
const OpContext& ctx) {
Tensor<cpu, 1, DType> workspace = ctx.requested[0].get_space_typed<cpu, 1, DType>(
Shape1(data_shape.Size()), s);
- DType* masked_scaled_input = TBlob(workspace).dptr<DType>();
+ DType* masked_input = TBlob(workspace).dptr<DType>();
double neg = MinValue<DType>();
- Kernel<masked_softmax_where_scale, cpu>::Launch(s, data_shape.Size(), masked_scaled_input,
- mask, in, neg, data_shape, mask_shape,
- scale);
+ Kernel<masked_softmax_where, cpu>::Launch(s, data_shape.Size(), masked_input,
+ mask, in, neg, data_shape, mask_shape);
int* max_lenghts = nullptr;
- Softmax<OP, negate, AType, DType>(s, masked_scaled_input, out, max_lenghts,
+ double masked_value = 0.0;
+ if (masked_neg_inf)
+ masked_value = -INFINITY;
+ Softmax<OP, negate, AType, DType>(s, masked_input, out, max_lenghts,
data_shape, axis, temperature);
- Kernel<masked_softmax_where_scale, cpu>::Launch(s, data_shape.Size(), out,
- mask, out, 0.0, data_shape, mask_shape,
- 1.0);
+ Kernel<masked_softmax_where, cpu>::Launch(s, data_shape.Size(), out,
+ mask, out, masked_value, data_shape,
+ mask_shape);
}
struct softmax_bwd {
@@ -308,22 +308,20 @@ template<typename OP1, typename OP2, int Req, bool negate, typename AType, int n
inline void MaskedSoftmaxGrad(Stream<cpu> *s, DType *out, DType *ograd,
DType *igrad, bool *mask, Shape<ndim> data_shape,
Shape<ndim> mask_shape, int axis,
- const double scale, const double temperature,
+ const double temperature,
const OpContext& ctx) {
Tensor<cpu, 1, DType> workspace = ctx.requested[0].get_space_typed<cpu, 1, DType>(
Shape1(data_shape.Size()), s);
DType* masked_ograd = TBlob(workspace).dptr<DType>();
- Kernel<masked_softmax_where_scale, cpu>::Launch(s, data_shape.Size(), masked_ograd,
- mask, ograd, 0.0, data_shape, mask_shape,
- 1.0);
+ Kernel<masked_softmax_where, cpu>::Launch(s, data_shape.Size(), masked_ograd,
+ mask, ograd, 0.0, data_shape, mask_shape);
int* max_lenghts = nullptr;
SoftmaxGrad<OP1, OP2, Req, negate, AType, DType, DType, int, ndim>(
s, out, masked_ograd, igrad,
max_lenghts, data_shape,
axis, temperature);
- Kernel<masked_softmax_where_scale, cpu>::Launch(s, data_shape.Size(), igrad,
- mask, igrad, 0.0, data_shape, mask_shape,
- scale);
+ Kernel<masked_softmax_where, cpu>::Launch(s, data_shape.Size(), igrad,
+ mask, igrad, 0.0, data_shape, mask_shape);
}
#ifdef __CUDACC__
@@ -484,12 +482,12 @@ MSHADOW_XINLINE index_t get_mask_position(const index_t idx, const Shape<ndim>&
return ret;
}
-template<bool normalize, int x_bits, typename OP, bool negate, typename AType,
- int ndim, typename DType>
+template<bool normalize, int x_bits, typename OP, bool masked_neg_inf,
+ bool negate, typename AType, int ndim, typename DType>
__global__ void masked_softmax_kernel(DType *in, DType *out, bool *in_mask,
index_t M, int axis, Shape<ndim> sshape,
Shape<ndim> stride, Shape<ndim> mask_shape,
- const double scale, const double temperature) {
+ const double temperature) {
extern __shared__ double shared[];
AType* smem = reinterpret_cast<AType*>(shared); // x_size
@@ -512,7 +510,7 @@ __global__ void masked_softmax_kernel(DType *in, DType *out, bool *in_mask,
__syncthreads();
cuda::Reduce1D<red::maximum, x_bits>(smem);
__syncthreads();
- smax = smem[0] / scale;
+ smax = smem[0];
__syncthreads();
}
@@ -521,7 +519,7 @@ __global__ void masked_softmax_kernel(DType *in, DType *out, bool *in_mask,
for (index_t i = x; i < M; i += x_size) {
bool mask_value = bcst_mask_axis ? in_mask[base_mask] : in_mask[base_mask + i*sa_mask];
if (mask_value) {
- val = (negate ? -in[base + i*sa]:in[base + i*sa]) / scale;
+ val = (negate ? -in[base + i*sa]:in[base + i*sa]);
smem[x] += static_cast<AType>(expf((val - smax) / static_cast<AType>(temperature)));
}
}
@@ -531,21 +529,25 @@ __global__ void masked_softmax_kernel(DType *in, DType *out, bool *in_mask,
AType ssum = smem[0];
__syncthreads();
+ double masked_value = 0.0;
+ if (masked_neg_inf)
+ masked_value = -INFINITY;
for (index_t i = x; i < M; i += x_size) {
- val = (negate ? -in[base + i*sa] : in[base + i*sa]) / scale;
+ val = (negate ? -in[base + i*sa] : in[base + i*sa]);
bool mask_value = bcst_mask_axis ? in_mask[base_mask] : in_mask[base_mask + i*sa_mask];
out[base + i*sa] =
mask_value ? DType(OP::Map((val - smax)/static_cast<DType>(temperature), ssum)) :
- DType(0.0f);
+ DType(masked_value);
}
}
-template<bool normalize, typename OP, bool negate, typename AType, typename LType,
- typename LTypeMask, typename DType, int ndim>
+template<bool normalize, typename OP, bool masked_neg_inf, bool negate, typename AType,
+ typename LType, typename LTypeMask, typename DType, int ndim>
__global__ void masked_softmax_stride1_kernel(const DType *in, DType *out, bool *in_mask,
const index_t M, int axis, Shape<ndim> sshape,
- Shape<ndim> mask_shape, const double scale,
- const double temperature, const int rows_per_block,
+ Shape<ndim> mask_shape,
+ const double temperature,
+ const int rows_per_block,
const index_t total_rows,
const size_t size_input_shared,
const size_t size_mask_shared) {
@@ -616,7 +618,7 @@ __global__ void masked_softmax_stride1_kernel(const DType *in, DType *out, bool
scratch[threadIdx.x] = my_value;
}
__syncthreads();
- smax = scratch[threadIdx.x - threadIdx.x % threads_per_row] / scale;
+ smax = scratch[threadIdx.x - threadIdx.x % threads_per_row];
__syncthreads();
}
@@ -624,7 +626,7 @@ __global__ void masked_softmax_stride1_kernel(const DType *in, DType *out, bool
red::sum::SetInitValue(my_sum);
for (index_t i = my_id; i < M; i += threads_per_row) {
if (row_mask[i]) {
- const DType val = (negate ? -row[i] : row[i]) / scale;
+ const DType val = (negate ? -row[i] : row[i]);
my_sum += static_cast<AType>(expf((val - smax) / static_cast<AType>(temperature)));
}
}
@@ -646,10 +648,13 @@ __global__ void masked_softmax_stride1_kernel(const DType *in, DType *out, bool
AType ssum = scratch[threadIdx.x - threadIdx.x % threads_per_row];
__syncthreads();
+ double masked_value = 0.0;
+ if (masked_neg_inf)
+ masked_value = -INFINITY;
for (index_t i = my_id; i < M; i += threads_per_row) {
- const DType val = (negate ? -row[i] : row[i]) / scale;
+ const DType val = (negate ? -row[i] : row[i]);
row[i] = row_mask[i] ? DType(OP::Map((val - smax)/static_cast<DType>(temperature), ssum)) :
- DType(0.0f);
+ DType(masked_value);
}
__syncthreads();
@@ -699,11 +704,11 @@ inline void Softmax(Stream<gpu> *s, DType *in, OType *out, IType *length,
}
}
-template<typename OP, bool negate, typename AType, typename DType,
- typename OType, int ndim>
+template<typename OP, bool masked_neg_inf, bool negate,
+ typename AType, typename DType, typename OType, int ndim>
inline void MaskedSoftmax(Stream<gpu> *s, DType *in, OType *out, bool *mask,
Shape<ndim> data_shape, Shape<ndim> mask_shape,
- int axis, const double scale, const double temperature,
+ int axis, const double temperature,
bool normalize, const OpContext& ctx) {
const int x_bits = 7;
const int x_size = 1 << x_bits;
@@ -747,16 +752,18 @@ inline void MaskedSoftmax(Stream<gpu> *s, DType *in, OType *out, bool *mask,
int nblocks = (N + rows_per_block - 1) / rows_per_block;
if (normalize) {
- masked_softmax_stride1_kernel<true, OP, negate, AType, LType, LTypeMask>
+ masked_softmax_stride1_kernel<true, OP, masked_neg_inf, negate,
+ AType, LType, LTypeMask>
<<<nblocks, softmax_threads_per_block, amount_shared,
mshadow::Stream<gpu>::GetStream(s)>>>(
- in, out, mask, M, axis, sshape, mask_shape, scale, temperature,
+ in, out, mask, M, axis, sshape, mask_shape, temperature,
rows_per_block, N, size_input_shared, size_mask_shared);
} else {
- masked_softmax_stride1_kernel<false, OP, negate, AType, LType, LTypeMask>
+ masked_softmax_stride1_kernel<false, OP, masked_neg_inf, negate,
+ AType, LType, LTypeMask>
<<<nblocks, softmax_threads_per_block, amount_shared,
mshadow::Stream<gpu>::GetStream(s)>>>(
- in, out, mask, M, axis, sshape, mask_shape, scale, temperature,
+ in, out, mask, M, axis, sshape, mask_shape, temperature,
rows_per_block, N, size_input_shared, size_mask_shared);
}
});
@@ -765,13 +772,13 @@ inline void MaskedSoftmax(Stream<gpu> *s, DType *in, OType *out, bool *mask,
} else {
size_t amount_shared = x_size * sizeof(AType);
if (normalize) {
- masked_softmax_kernel<true, x_bits, OP, negate, AType, ndim>
+ masked_softmax_kernel<true, x_bits, OP, masked_neg_inf, negate, AType, ndim>
<<<N, x_size, amount_shared, mshadow::Stream<gpu>::GetStream(s)>>>(
- in, out, mask, M, axis, sshape, stride, mask_shape, scale, temperature);
+ in, out, mask, M, axis, sshape, stride, mask_shape, temperature);
} else {
- masked_softmax_kernel<false, x_bits, OP, negate, AType, ndim>
+ masked_softmax_kernel<false, x_bits, OP, masked_neg_inf, negate, AType, ndim>
<<<N, x_size, amount_shared, mshadow::Stream<gpu>::GetStream(s)>>>(
- in, out, mask, M, axis, sshape, stride, mask_shape, scale, temperature);
+ in, out, mask, M, axis, sshape, stride, mask_shape, temperature);
}
MSHADOW_CUDA_POST_KERNEL_CHECK(masked_softmax_kernel);
}
@@ -898,7 +905,6 @@ __global__ void masked_softmax_stride1_grad_kernel(const OType *out, const OType
const index_t M, int axis,
Shape<ndim> sshape,
Shape<ndim> mask_shape,
- const double scale,
const double temperature,
const int rows_per_block,
const index_t total_rows,
@@ -975,14 +981,12 @@ __global__ void masked_softmax_stride1_grad_kernel(const OType *out, const OType
AType ssum = scratch[threadIdx.x - threadIdx.x % threads_per_row];
__syncthreads();
- AType temperature_scale = static_cast<AType>(temperature) *
- static_cast<AType>(scale);
for (index_t i = my_id; i < M; i += threads_per_row) {
const DType val =
negate ?
-OP2::Map(row[i + M], row[i], ssum):
OP2::Map(row[i + M], row[i], ssum);
- row[i] = row_mask[i] ? DType(val / static_cast<DType>(temperature_scale)) :
+ row[i] = row_mask[i] ? DType(val / static_cast<DType>(temperature)) :
DType(0.0f);
if (Req == kAddTo) {
row[i] += igrad[my_row * M + i];
@@ -1003,7 +1007,7 @@ __global__ void masked_softmax_grad_kernel(OType *out, OType *ograd, DType *igra
const bool *in_mask, index_t M, int axis,
Shape<ndim> sshape, Shape<ndim> stride,
Shape<ndim> mask_shape,
- const double scale, const double temperature) {
+ const double temperature) {
const unsigned x_size = 1 << x_bits;
__shared__ AType smem[x_size];
index_t sa = stride[axis];
@@ -1026,15 +1030,13 @@ __global__ void masked_softmax_grad_kernel(OType *out, OType *ograd, DType *igra
__syncthreads();
DType final_result;
- AType temperature_scale = static_cast<AType>(temperature) *
- static_cast<AType>(scale);
for (index_t i = x; i < M; i += x_size) {
bool mask_value = bcst_mask_axis ? in_mask[base_mask] : in_mask[base_mask + i*sa_mask];
final_result =
negate ?
-OP2::Map(ograd[base + i*sa], out[base + i*sa], ssum):
OP2::Map(ograd[base + i*sa], out[base + i*sa], ssum);
- final_result = mask_value ? final_result / static_cast<DType>(temperature_scale) : DType(0.0f);
+ final_result = mask_value ? final_result / static_cast<DType>(temperature) : DType(0.0f);
KERNEL_ASSIGN(igrad[base + i*sa], Req, final_result);
}
}
@@ -1086,7 +1088,7 @@ template<typename OP1, typename OP2, int Req, bool negate, typename AType, int n
inline void MaskedSoftmaxGrad(Stream<gpu> *s, OType *out, OType *ograd,
DType *igrad, bool *mask, Shape<ndim> data_shape,
Shape<ndim> mask_shape, int axis,
- const double scale, const double temperature,
+ const double temperature,
const OpContext& ctx) {
const int x_bits = 7;
const int x_size = 1 << x_bits;
@@ -1133,14 +1135,14 @@ inline void MaskedSoftmaxGrad(Stream<gpu> *s, OType *out, OType *ograd,
<<<nblocks, softmax_threads_per_block, amount_shared,
mshadow::Stream<gpu>::GetStream(s)>>>(
out, ograd, igrad, mask, M, axis, sshape, mask_shape,
- scale, temperature, rows_per_block, N, size_input_shared, size_mask_shared);
+ temperature, rows_per_block, N, size_input_shared, size_mask_shared);
});
});
MSHADOW_CUDA_POST_KERNEL_CHECK(masked_softmax_stride1_grad_kernel);
} else {
masked_softmax_grad_kernel<x_bits, OP1, OP2, Req, negate, AType, ndim>
<<<N, x_size, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
- out, ograd, igrad, mask, M, axis, sshape, stride, mask_shape, scale, temperature);
+ out, ograd, igrad, mask, M, axis, sshape, stride, mask_shape, temperature);
MSHADOW_CUDA_POST_KERNEL_CHECK(masked_softmax_grad_kernel);
}
}
@@ -1181,15 +1183,12 @@ struct SoftmaxParam : public dmlc::Parameter<SoftmaxParam> {
struct MaskedSoftmaxParam : public dmlc::Parameter<MaskedSoftmaxParam> {
int axis;
- dmlc::optional<double> scale_factor;
dmlc::optional<double> temperature;
dmlc::optional<int> dtype;
dmlc::optional<bool> normalize;
DMLC_DECLARE_PARAMETER(MaskedSoftmaxParam) {
DMLC_DECLARE_FIELD(axis).set_default(-1)
.describe("The axis along which to compute softmax.");
- DMLC_DECLARE_FIELD(scale_factor).set_default(dmlc::optional<double>())
- .describe("Scaling factor applied before softmax");
DMLC_DECLARE_FIELD(temperature).set_default(dmlc::optional<double>())
.describe("Temperature parameter in softmax");
DMLC_DECLARE_FIELD(normalize)
@@ -1492,7 +1491,7 @@ void SoftmaxCompute(const nnvm::NodeAttrs& attrs,
});
}
-template<typename xpu, typename OP, bool negate = false>
+template<typename xpu, typename OP, bool masked_neg_inf, bool negate = false>
void MaskedSoftmaxCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
@@ -1503,8 +1502,6 @@ void MaskedSoftmaxCompute(const nnvm::NodeAttrs& attrs,
CHECK_NE(req[0], kAddTo);
const MaskedSoftmaxParam& param = nnvm::get<MaskedSoftmaxParam>(attrs.parsed);
int axis = CheckAxis(param.axis, inputs[0].ndim());
- const double scale = param.scale_factor.has_value() ?
- param.scale_factor.value() : 1.0;
const double temperature = param.temperature.has_value() ?
param.temperature.value() : 1.0;
bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", true);
@@ -1518,17 +1515,17 @@ void MaskedSoftmaxCompute(const nnvm::NodeAttrs& attrs,
MXNET_NDIM_SWITCH(inputs[0].ndim(), ndim, {
bool* mask_ptr = inputs[1].dptr<bool>();
if (safe_acc) {
- MaskedSoftmax<OP, negate, AType>(
+ MaskedSoftmax<OP, masked_neg_inf, negate, AType>(
ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
outputs[0].dptr<DType>(), mask_ptr,
inputs[0].shape_.get<ndim>(), inputs[1].shape_.get<ndim>(),
- axis, scale, temperature, param.normalize.value(), ctx);
+ axis, temperature, param.normalize.value(), ctx);
} else {
- MaskedSoftmax<OP, negate, DType>(
+ MaskedSoftmax<OP, masked_neg_inf, negate, DType>(
ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
outputs[0].dptr<DType>(), mask_ptr,
inputs[0].shape_.get<ndim>(), inputs[1].shape_.get<ndim>(),
- axis, scale, temperature, param.normalize.value(), ctx);
+ axis, temperature, param.normalize.value(), ctx);
}
});
});
@@ -1616,8 +1613,6 @@ void MaskedSoftmaxGradCompute(const nnvm::NodeAttrs& attrs,
if (req[0] == kNullOp) return;
const MaskedSoftmaxParam& param = nnvm::get<MaskedSoftmaxParam>(attrs.parsed);
int axis = CheckAxis(param.axis, inputs[0].ndim());
- const double scale = param.scale_factor.has_value() ?
- param.scale_factor.value() : 1.0;
const double temperature = param.temperature.has_value() ?
param.temperature.value() : 1.0;
@@ -1634,15 +1629,13 @@ void MaskedSoftmaxGradCompute(const nnvm::NodeAttrs& attrs,
ctx.get_stream<xpu>(), out_ptr,
ograd_ptr, grad_data, mask_ptr,
inputs[0].shape_.get<ndim>(), inputs[1].shape_.get<ndim>(),
- axis, static_cast<DType>(scale),
- static_cast<DType>(temperature), ctx);
+ axis, static_cast<DType>(temperature), ctx);
} else {
MaskedSoftmaxGrad<OP1, OP2, Req, negate, DType>(
ctx.get_stream<xpu>(), out_ptr,
ograd_ptr, grad_data, mask_ptr,
inputs[0].shape_.get<ndim>(), inputs[1].shape_.get<ndim>(),
- axis, static_cast<DType>(scale),
- static_cast<DType>(temperature), ctx);
+ axis, static_cast<DType>(temperature), ctx);
}
});
});
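The renamed masked_softmax_where kernel keeps the broadcast-aware index mapping from the original masked_softmax_where_scale: each flat element index of the data tensor is walked dimension by dimension (last to first) to find the matching position in the possibly smaller mask tensor, with size-1 mask axes contributing nothing. A rough Python sketch of that mapping under row-major layout, for illustration only:

def mask_position(flat_id, data_shape, mask_shape):
    # Map a flat index into data_shape to the flat index of the broadcast
    # mask; axes where mask_shape[i] == 1 are clamped to coordinate 0.
    pos, stride, j = 0, 1, flat_id
    for i in reversed(range(len(data_shape))):
        tmp = j // data_shape[i]
        coord = j - tmp * data_shape[i]   # coordinate along axis i
        if mask_shape[i] != 1:
            pos += coord * stride
        stride *= mask_shape[i]
        j = tmp
    return pos

For example, with data_shape=(2, 3) and mask_shape=(1, 3), flat index 4 (row 1, column 1) maps to mask position 1, i.e. the single mask row is shared across both data rows.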
diff --git a/src/operator/nn/softmax.cc b/src/operator/nn/softmax.cc
index cf67853..b3ffd42 100644
--- a/src/operator/nn/softmax.cc
+++ b/src/operator/nn/softmax.cc
@@ -201,7 +201,8 @@ NNVM_REGISTER_OP(masked_softmax)
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"output"};
})
-.set_attr<FCompute>("FCompute<cpu>", MaskedSoftmaxCompute<cpu, mxnet_op::softmax_fwd>)
+.set_attr<FCompute>("FCompute<cpu>", MaskedSoftmaxCompute<cpu, mxnet_op::softmax_fwd,
+ false>)
.set_attr<nnvm::FGradient>("FGradient",
[](const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
auto data_grad = MakeNode("_backward_masked_softmax", n->attrs.name + "_backward_data",
diff --git a/src/operator/nn/softmax.cu b/src/operator/nn/softmax.cu
index dc8fd99..c75f543 100644
--- a/src/operator/nn/softmax.cu
+++ b/src/operator/nn/softmax.cu
@@ -35,7 +35,8 @@ NNVM_REGISTER_OP(_backward_softmax)
.set_attr<FCompute>("FCompute<gpu>", SoftmaxGradCompute<gpu, op::mshadow_op::mul,
mxnet_op::softmax_bwd>);
NNVM_REGISTER_OP(masked_softmax)
-.set_attr<FCompute>("FCompute<gpu>", MaskedSoftmaxCompute<gpu, mxnet_op::softmax_fwd>);
+.set_attr<FCompute>("FCompute<gpu>", MaskedSoftmaxCompute<gpu, mxnet_op::softmax_fwd,
+ false>);
NNVM_REGISTER_OP(_backward_masked_softmax)
.set_attr<FCompute>("FCompute<gpu>", MaskedSoftmaxGradCompute<gpu, op::mshadow_op::mul,
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 5034b07..7a85364 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -4941,27 +4941,31 @@ def test_softmax_with_length():
[np.zeros(shape), np.zeros(len_shape, dtype=np.int32)],
rtol=1e-2, atol=2e-3 if dtype == np.float16 else 1e-3, dtype="asnumpy")
-def np_softmax(x, axis=-1, scale_factor=1.0, temperature=1.0, normalize=True):
- x = x / scale_factor
+def np_softmax(x, axis=-1, temperature=1.0, normalize=True):
if normalize:
x = x - np.max(x, axis=axis, keepdims=True)
x = np.exp(x / temperature)
x /= np.sum(x, axis=axis, keepdims=True)
return x
-def np_masked_softmax(data, mask, axis=-1, scale_factor=1.0, temperature=1.0, normalize=True):
+def np_masked_softmax(data, mask, axis=-1, temperature=1.0, normalize=True):
neg = -1e18
if data.dtype == np.float16:
neg = -1e4
temp = np.where(mask, data, neg)
result = np_softmax(temp, axis=axis,
- scale_factor=scale_factor,
temperature=temperature,
normalize=normalize) * mask
return result
-def np_masked_softmax_grad(out, grad_out, axis=-1, scale_factor=1.0, temperature=1.0):
+def np_masked_softmax_grad(out, grad_out, axis=-1, temperature=1.0):
temp = np.sum(out * grad_out, axis=axis, keepdims=True)
- result = out * (grad_out - temp) / (temperature * scale_factor)
+ result = out * (grad_out - temp) / temperature
+ return result
+def np_masked_log_softmax_grad(out, grad_out, mask, axis=-1, temperature=1.0):
+ grad_out = np.where(mask, grad_out, 0)
+ temp = np.sum(grad_out, axis=axis, keepdims=True)
+ result = (grad_out - np.exp(out) * temp) / temperature
+ result = np.where(mask, result, 0)
return result
@pytest.mark.parametrize('dtype', [np.float16, np.float32, np.float64])
@@ -4969,9 +4973,8 @@ def np_masked_softmax_grad(out, grad_out, axis=-1, scale_factor=1.0, temperature
@pytest.mark.parametrize('ndims', [3, 4, 5])
@pytest.mark.parametrize('n_broadcast_axis', [0, 1, 2])
@pytest.mark.parametrize('temperature', [1, 5, 9 ,11])
-@pytest.mark.parametrize('scale', [1, 2, 7, 12])
@pytest.mark.parametrize('normalize', [True])
-def test_masked_softmax(dtype, axis, ndims, n_broadcast_axis, temperature, scale, normalize):
+def test_masked_softmax(dtype, axis, ndims, n_broadcast_axis, temperature, normalize):
n_broadcast_axis = min(n_broadcast_axis, ndims - 1)
shape = rand_shape_nd(ndims, dim=10)
mx_data = rand_ndarray(shape, dtype=dtype)
@@ -4991,12 +4994,12 @@ def test_masked_softmax(dtype, axis, ndims, n_broadcast_axis, temperature, scale
np_grad = mx_grad.asnumpy()
np_out = np_masked_softmax(np_data, np_mask, axis,
- scale, temperature, normalize)
+ temperature, normalize)
np_grad_out = np_masked_softmax_grad(np_out, np_grad,
- axis, scale, temperature)
+ axis, temperature)
data = mx.sym.Variable("data")
mask = mx.sym.Variable("mask")
- mx_sym = mx.sym.masked_softmax(data=data, mask=mask, scale_factor=scale,
+ mx_sym = mx.sym.masked_softmax(data=data, mask=mask,
temperature=temperature, axis=axis,
normalize=normalize)
location = {"data": mx_data, "mask": mx_mask}
@@ -5019,15 +5022,22 @@ def test_masked_log_softmax(dtype, ndims):
np_data = mx_data.asnumpy()
np_mask = np.random.randint(0, 2, shape)
mx_mask = mx.nd.array(np_mask, dtype=np.bool)
+ mx_grad = rand_ndarray(shape, dtype=dtype)
+ np_grad = mx_grad.asnumpy()
np_out = np.log(np_masked_softmax(np_data, np_mask, axis)+1e-20) * np_mask
+ np_out_inf = np.where(np_mask, np_out, -np.inf)
+ np_grad_out = np_masked_log_softmax_grad(np_out, np_grad, np_mask, axis)
data = mx.sym.Variable("data")
mask = mx.sym.Variable("mask")
mx_sym = mx.sym.masked_log_softmax(data=data, mask=mask, axis=axis-ndims)
location = {"data": mx_data, "mask": mx_mask}
rtol = 1e-2 if dtype == np.float16 else 1e-3
atol = 1e-4 if dtype == np.float16 else 1e-5
- check_symbolic_forward(mx_sym, location, [np_out], rtol=rtol, atol=atol, dtype="asnumpy")
- check_numeric_gradient(mx_sym, location, rtol=1e-1, atol=1e-2)
+ check_symbolic_forward(mx_sym, location, [np_out_inf], rtol=rtol, atol=atol, dtype="asnumpy")
+ check_symbolic_backward(mx_sym, location, [mx_grad],
+ [np_grad_out, np.zeros(shape, dtype=np.bool)],
+ rtol=1e-2, atol=2e-3 if dtype == np.float16 else 1e-3,
+ dtype="asnumpy", equal_nan=True)
def test_pick():
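Taken together with the updated test, the user-visible contract after this commit is: the forward output of masked_log_softmax carries -inf at masked positions (previously 0), the gradient is zero there, and scale_factor is no longer an accepted argument. A hypothetical quick check, assuming the auto-generated NDArray binding masked_log_softmax mirrors the symbol registration in the diff:

import numpy as np
import mxnet as mx

data = mx.nd.array([[1.0, 2.0, 3.0]])
mask = mx.nd.array(np.array([[1, 1, 0]]), dtype=np.bool)
out = mx.nd.masked_log_softmax(data=data, mask=mask, axis=-1)
print(out.asnumpy())  # last entry is now -inf rather than 0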