Posted to commits@mxnet.apache.org by ha...@apache.org on 2019/02/21 00:37:45 UTC
[incubator-mxnet] branch master updated: softmax for fp16 with fp32 accumulator (#14098)
This is an automated email from the ASF dual-hosted git repository.
haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 862cbc6 softmax for fp16 with fp32 accumulator (#14098)
862cbc6 is described below
commit 862cbc67aacf81990b8c885847686a4c3c734cd3
Author: Sheng Zha <sz...@users.noreply.github.com>
AuthorDate: Wed Feb 20 16:37:12 2019 -0800
softmax for fp16 with fp32 accumulator (#14098)
* softmax for fp16 with fp32 accumulator
* return AType in kernel
* add dtype
* kernel
* grad use in-out only when dtype override
* simplify infer type
* address comments
---
src/operator/mxnet_op.h | 42 ++++++
src/operator/nn/softmax-inl.h | 248 ++++++++++++++++++++++++---------
src/operator/nn/softmax.cc | 66 +++++++--
tests/python/unittest/test_operator.py | 41 ++++++
4 files changed, 326 insertions(+), 71 deletions(-)
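For readers skimming the patch: it adds an optional dtype argument to softmax, softmin and log_softmax so that a low-precision input can produce a higher-precision output, with the reduction accumulated in fp32 when the input is fp16. A minimal usage sketch in Python, modeled on the test added at the end of this patch (shapes and the printed expectation are illustrative, not part of the commit):

    import mxnet as mx

    x = mx.random.uniform(shape=(100, 500)).astype('float16')
    x.attach_grad()
    with mx.autograd.record():
        # fp16 input, fp32 output; the softmax denominator is accumulated in fp32 internally
        y = mx.nd.softmax(x, axis=-1, dtype='float32')
    y.backward()
    print(y.dtype, x.grad.dtype)  # expected: float32 for the output, float16 for the input gradient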
diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h
index 6cab199..d8fc503 100644
--- a/src/operator/mxnet_op.h
+++ b/src/operator/mxnet_op.h
@@ -249,6 +249,48 @@ inline int get_num_threads<cpu>(const int N) {
LOG(FATAL) << "Unknown type enum " << type; \
}
+#define MXNET_REAL_ACC_TYPE_SWITCH(type, DType, AType, ...)\
+ switch (type) { \
+ case mshadow::kFloat32: \
+ { \
+ typedef float DType; \
+ typedef double AType; \
+ {__VA_ARGS__} \
+ } \
+ break; \
+ case mshadow::kFloat64: \
+ { \
+ typedef double DType; \
+ typedef double AType; \
+ {__VA_ARGS__} \
+ } \
+ break; \
+ case mshadow::kFloat16: \
+ { \
+ typedef mshadow::half::half_t DType; \
+ typedef float AType; \
+ {__VA_ARGS__} \
+ } \
+ break; \
+ case mshadow::kUint8: \
+ LOG(FATAL) << "This operation only support " \
+ "floating point types not uint8"; \
+ break; \
+ case mshadow::kInt8: \
+ LOG(FATAL) << "This operation only support " \
+ "floating point types not int8"; \
+ break; \
+ case mshadow::kInt32: \
+ LOG(FATAL) << "This operation only support " \
+ "floating point types, not int32"; \
+ break; \
+ case mshadow::kInt64: \
+ LOG(FATAL) << "This operation only support " \
+ "floating point types, not int64"; \
+ break; \
+ default: \
+ LOG(FATAL) << "Unknown type enum " << type; \
+ }
/*!
* \brief assign the val to out according
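The new MXNET_REAL_ACC_TYPE_SWITCH macro pairs each floating-point input type DType with a wider accumulation type AType: float16 accumulates in float32, float32 in float64, and float64 in float64. A rough NumPy sketch of that idea for a single fp16 row, assuming the same max-shift formulation the CPU kernel uses (function and variable names here are mine, not from the patch):

    import numpy as np

    def softmax_row_fp16(row, acc_dtype=np.float32):
        # row: 1-D float16 array; acc_dtype plays the role of AType
        row = row.astype(np.float16)
        mmax = np.float32(row.max())          # shift by the row max for numerical stability
        ssum = acc_dtype(0)                   # sum of exponentials kept in the wider type
        for v in row:
            ssum += acc_dtype(np.exp(np.float32(v) - mmax))
        return np.array([np.exp(np.float32(v) - mmax) / ssum for v in row],
                        dtype=np.float16)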
diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h
index c063e38..90950bc 100644
--- a/src/operator/nn/softmax-inl.h
+++ b/src/operator/nn/softmax-inl.h
@@ -25,6 +25,9 @@
#ifndef MXNET_OPERATOR_NN_SOFTMAX_INL_H_
#define MXNET_OPERATOR_NN_SOFTMAX_INL_H_
+#include <algorithm>
+#include <string>
+#include <utility>
#include <vector>
#include "../mxnet_op.h"
@@ -36,23 +39,33 @@ namespace op {
namespace mxnet_op {
struct softmax_fwd {
- template<typename DType>
- MSHADOW_XINLINE static DType Map(DType a, DType b) {
- return DType(expf(a)/b);
+ template<typename AType>
+ MSHADOW_XINLINE static AType Map(float a, AType b) {
+ return AType(expf(a)/b);
+ }
+
+ template<typename AType>
+ MSHADOW_XINLINE static AType Map(double a, AType b) {
+ return AType(exp(a)/b);
}
};
struct log_softmax_fwd {
template<typename DType>
- MSHADOW_XINLINE static DType Map(DType a, DType b) {
- return DType(a - logf(b));
+ MSHADOW_XINLINE static float Map(DType a, float b) {
+ return a - logf(b);
+ }
+
+ template<typename DType>
+ MSHADOW_XINLINE static double Map(DType a, double b) {
+ return a - log(b);
}
};
-template<typename OP, bool negate, typename DType, int ndim>
-inline void Softmax(Stream<cpu> *s, DType *in, DType *out,
+template<typename OP, bool negate, typename AType, typename DType, typename OType, int ndim>
+inline void Softmax(Stream<cpu> *s, DType *in, OType *out,
Shape<ndim> shape, int axis, const DType temperature) {
index_t M = shape[axis];
index_t N = shape.Size()/M;
@@ -72,10 +85,9 @@ inline void Softmax(Stream<cpu> *s, DType *in, DType *out,
if (mmax < val) mmax = val;
}
- DType sum = DType(0);
+ AType sum = AType(0);
DType in_val;
- // By default temperature is 1.0, and only in reinforcement training
- // users would set it to other values.
+ // By default temperature is 1.0.
// Adding a branch here to save the CPU 'divide-by-1' computation at runtime
if (temperature == 1.0) {
for (index_t j = 0; j < M; ++j) {
@@ -103,23 +115,29 @@ inline void Softmax(Stream<cpu> *s, DType *in, DType *out,
struct softmax_bwd {
- template<typename DType>
- MSHADOW_XINLINE static DType Map(DType ograd, DType out, DType sum) {
- return DType(out * (ograd - sum));
+ template<typename DType, typename AType>
+ MSHADOW_XINLINE static AType Map(DType ograd, DType out, AType sum) {
+ return AType(out * (ograd - sum));
}
};
struct log_softmax_bwd {
- template<typename DType>
- MSHADOW_XINLINE static DType Map(DType ograd, DType out, DType sum) {
- return DType(ograd - expf(out)*sum);
+ template<typename AType>
+ MSHADOW_XINLINE static AType Map(float ograd, float out, AType sum) {
+ return AType(ograd - expf(out)*sum);
+ }
+
+ template<typename AType>
+ MSHADOW_XINLINE static AType Map(double ograd, double out, AType sum) {
+ return AType(ograd - exp(out)*sum);
}
};
-template<typename OP1, typename OP2, int Req, bool negate, typename DType, int ndim>
-inline void SoftmaxGrad(Stream<cpu> *s, DType *out, DType *ograd,
+template<typename OP1, typename OP2, int Req, bool negate,
+ typename AType, typename DType, typename OType, int ndim>
+inline void SoftmaxGrad(Stream<cpu> *s, OType *out, OType *ograd,
DType *igrad, Shape<ndim> shape, int axis,
const DType temperature) {
index_t M = shape[axis];
@@ -133,13 +151,12 @@ inline void SoftmaxGrad(Stream<cpu> *s, DType *out, DType *ograd,
for (int i = 0; i < static_cast<int>(N); ++i) {
index_t base = unravel_dot(i, sshape, stride);
- DType sum = DType(0);
+ AType sum = AType(0);
for (index_t j = 0; j < M; ++j) {
sum += OP1::Map(ograd[base + j*sa], out[base + j*sa]);
}
- // By default temperature is 1.0, and only in reinforcement training
- // users would set it to other values.
+ // By default temperature is 1.0.
// Adding a branch here to save the CPU 'divide-by-1' computation at runtime
DType final_result;
if (temperature == 1.0) {
@@ -162,19 +179,20 @@ inline void SoftmaxGrad(Stream<cpu> *s, DType *out, DType *ograd,
#ifdef __CUDACC__
-template<int x_bits, typename OP, bool negate, typename DType, int ndim>
-__global__ void softmax_compute_kernel(DType *in, DType *out, index_t M, int axis,
+template<int x_bits, typename OP, bool negate, typename AType, int ndim,
+ typename DType, typename OType>
+__global__ void softmax_compute_kernel(DType *in, OType *out, index_t M, int axis,
Shape<ndim> sshape, Shape<ndim> stride,
const double temperature) {
const unsigned x_size = 1 << x_bits;
- __shared__ DType smem[x_size];
+ __shared__ AType smem[x_size];
index_t sa = stride[axis];
index_t base = unravel_dot(blockIdx.x, sshape, stride);
index_t x = threadIdx.x;
red::maximum::SetInitValue(smem[x]);
for (index_t i = x; i < M; i += x_size) {
- red::maximum::Reduce(smem[x], negate ? -in[base + i*sa] : in[base + i*sa]);
+ smem[x] = ::max(smem[x], negate ? -in[base + i*sa] : in[base + i*sa]);
}
__syncthreads();
cuda::Reduce1D<red::maximum, x_bits>(smem);
@@ -186,13 +204,12 @@ __global__ void softmax_compute_kernel(DType *in, DType *out, index_t M, int axi
DType val;
for (index_t i = x; i < M; i += x_size) {
val = negate ? -in[base + i*sa]:in[base + i*sa];
- red::sum::Reduce(
- smem[x], static_cast<DType>(expf((val - smax) / static_cast<DType>(temperature))));
+ smem[x] += static_cast<AType>(expf((val - smax) / static_cast<AType>(temperature)));
}
__syncthreads();
cuda::Reduce1D<red::sum, x_bits>(smem);
__syncthreads();
- DType ssum = smem[0];
+ AType ssum = smem[0];
__syncthreads();
for (index_t i = x; i < M; i += x_size) {
@@ -201,8 +218,8 @@ __global__ void softmax_compute_kernel(DType *in, DType *out, index_t M, int axi
}
}
-template<typename OP, bool negate, typename DType, int ndim>
-inline void Softmax(Stream<gpu> *s, DType *in, DType *out,
+template<typename OP, bool negate, typename AType, typename DType, typename OType, int ndim>
+inline void Softmax(Stream<gpu> *s, DType *in, OType *out,
Shape<ndim> shape, int axis, const double temperature) {
const int x_bits = 7;
const int x_size = 1 << x_bits;
@@ -212,31 +229,32 @@ inline void Softmax(Stream<gpu> *s, DType *in, DType *out,
Shape<ndim> sshape = shape;
sshape[axis] = 1;
- softmax_compute_kernel<x_bits, OP, negate, DType, ndim>
+ softmax_compute_kernel<x_bits, OP, negate, AType, ndim>
<<<N, x_size, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
in, out, M, axis, sshape, stride, temperature);
MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_compute_kernel);
}
-template<int x_bits, typename OP1, typename OP2, int Req, bool negate, typename DType, int ndim>
-__global__ void softmax_gradient_kernel(DType *out, DType *ograd, DType *igrad,
+template<int x_bits, typename OP1, typename OP2, int Req, bool negate, typename AType, int ndim,
+ typename DType, typename OType>
+__global__ void softmax_gradient_kernel(OType *out, OType *ograd, DType *igrad,
index_t M, int axis, Shape<ndim> sshape,
Shape<ndim> stride, const double temperature) {
const unsigned x_size = 1 << x_bits;
- __shared__ DType smem[x_size];
+ __shared__ AType smem[x_size];
index_t sa = stride[axis];
index_t base = unravel_dot(blockIdx.x, sshape, stride);
index_t x = threadIdx.x;
red::sum::SetInitValue(smem[x]);
for (index_t i = x; i < M; i += x_size) {
- red::sum::Reduce(smem[x], OP1::Map(ograd[base + i*sa], out[base + i*sa]));
+ smem[x] += OP1::Map(ograd[base + i*sa], out[base + i*sa]);
}
__syncthreads();
cuda::Reduce1D<red::sum, x_bits>(smem);
__syncthreads();
- DType ssum = smem[0];
+ AType ssum = smem[0];
__syncthreads();
DType final_result;
@@ -250,8 +268,9 @@ __global__ void softmax_gradient_kernel(DType *out, DType *ograd, DType *igrad,
}
-template<typename OP1, typename OP2, int Req, bool negate, typename DType, int ndim>
-inline void SoftmaxGrad(Stream<gpu> *s, DType *out, DType *ograd,
+template<typename OP1, typename OP2, int Req, bool negate, typename AType, int ndim,
+ typename DType, typename OType>
+inline void SoftmaxGrad(Stream<gpu> *s, OType *out, OType *ograd,
DType *igrad, Shape<ndim> shape, int axis,
const double temperature) {
const int x_bits = 7;
@@ -262,7 +281,7 @@ inline void SoftmaxGrad(Stream<gpu> *s, DType *out, DType *ograd,
Shape<ndim> sshape = shape;
sshape[axis] = 1;
- softmax_gradient_kernel<x_bits, OP1, OP2, Req, negate, DType, ndim>
+ softmax_gradient_kernel<x_bits, OP1, OP2, Req, negate, AType, ndim>
<<<N, x_size, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
out, ograd, igrad, M, axis, sshape, stride, temperature);
MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_gradient_kernel);
@@ -275,11 +294,105 @@ inline void SoftmaxGrad(Stream<gpu> *s, DType *out, DType *ograd,
struct SoftmaxParam : public dmlc::Parameter<SoftmaxParam> {
int axis;
dmlc::optional<double> temperature;
+ dmlc::optional<int> dtype;
DMLC_DECLARE_PARAMETER(SoftmaxParam) {
DMLC_DECLARE_FIELD(axis).set_default(-1)
- .describe("The axis along which to compute softmax.");
+ .describe("The axis along which to compute softmax.");
DMLC_DECLARE_FIELD(temperature).set_default(dmlc::optional<double>())
- .describe("Temperature parameter in softmax");
+ .describe("Temperature parameter in softmax");
+ DMLC_DECLARE_FIELD(dtype)
+ .add_enum("float16", mshadow::kFloat16)
+ .add_enum("float32", mshadow::kFloat32)
+ .add_enum("float64", mshadow::kFloat64)
+ .set_default(dmlc::optional<int>())
+ .describe("DType of the output in case this can't be inferred. "
+ "Defaults to the same as input's dtype if not defined (dtype=None).");
+ }
+};
+
+static inline bool softmax_has_dtype_override(const nnvm::NodeAttrs& attrs) {
+ const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
+ return param.dtype.has_value() && param.dtype.value() != -1;
+}
+
+static inline bool SoftmaxOpType(const nnvm::NodeAttrs& attrs,
+ std::vector<int>* in_attrs,
+ std::vector<int>* out_attrs) {
+ CHECK_EQ(in_attrs->size(), 1);
+ CHECK_EQ(out_attrs->size(), 1);
+ const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
+
+ if (softmax_has_dtype_override(attrs)) {
+ TYPE_ASSIGN_CHECK(*out_attrs, 0, param.dtype.value());
+ type_assign(&(*in_attrs)[0], (*out_attrs)[0]);
+ return true;
+ } else {
+ return ElemwiseType<1, 1>(attrs, in_attrs, out_attrs);
+ }
+}
+
+static inline bool SoftmaxGradOpShape(const nnvm::NodeAttrs& attrs,
+ std::vector<TShape> *in_attrs,
+ std::vector<TShape> *out_attrs) {
+ if (softmax_has_dtype_override(attrs)) {
+ return ElemwiseShape<3, 1>(attrs, in_attrs, out_attrs);
+ } else {
+ return ElemwiseShape<2, 1>(attrs, in_attrs, out_attrs);
+ }
+}
+
+static inline bool SoftmaxGradOpType(const nnvm::NodeAttrs& attrs,
+ std::vector<int>* in_attrs,
+ std::vector<int>* out_attrs) {
+ CHECK_EQ(out_attrs->size(), 1);
+ if (softmax_has_dtype_override(attrs)) {
+ CHECK_EQ(in_attrs->size(), 3);
+ int in_dtype = (*in_attrs)[1];
+ int out_dtype = (*in_attrs)[2];
+ TYPE_ASSIGN_CHECK(*in_attrs, 0, out_dtype);
+ TYPE_ASSIGN_CHECK(*out_attrs, 0, in_dtype);
+
+ return (*out_attrs)[0] != -1 && (*in_attrs)[0] != -1;
+ } else {
+ CHECK_EQ(in_attrs->size(), 2);
+ int out_dtype = (*in_attrs)[1];
+ TYPE_ASSIGN_CHECK(*out_attrs, 0, out_dtype);
+ TYPE_ASSIGN_CHECK(*in_attrs, 0, out_dtype);
+
+ return (*out_attrs)[0] != -1 && (*in_attrs)[0] != -1;
+ }
+}
+
+static inline std::vector<std::pair<int, int> >
+SoftmaxGradOpInplaceOption(const nnvm::NodeAttrs& attrs) {
+ if (softmax_has_dtype_override(attrs)) {
+ return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}, {2, 0}};
+ } else {
+ return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}};
+ }
+}
+
+static inline uint32_t SoftmaxGradOpNumInputs(const nnvm::NodeAttrs& attrs) {
+ return softmax_has_dtype_override(attrs) ? 3 : 2;
+}
+
+static inline std::vector<std::string> SoftmaxGradOpInputNames(const nnvm::NodeAttrs& attrs) {
+ if (softmax_has_dtype_override(attrs)) {
+ return std::vector<std::string>{"ograd", "data", "output"};
+ } else {
+ return std::vector<std::string>{"ograd", "output"};
+ }
+}
+
+struct SoftmaxFGradient {
+ const char *op_name;
+ std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr& n,
+ const std::vector<nnvm::NodeEntry>& ograds) const {
+ if (softmax_has_dtype_override(n->attrs)) {
+ return ElemwiseGradUseInOut {op_name}(n, ograds);
+ } else {
+ return ElemwiseGradUseOut {op_name}(n, ograds);
+ }
}
};
@@ -297,16 +410,20 @@ void SoftmaxCompute(const nnvm::NodeAttrs& attrs,
const double temperature = param.temperature.has_value() ?
param.temperature.value() : 1.0;
TShape shape = AxisShapeCompact(inputs[0].shape_, &axis, true);
- MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
- if (shape.ndim() == 2) {
- Softmax<OP, negate>(ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
- outputs[0].dptr<DType>(), shape.get<2>(), axis,
- static_cast<DType>(temperature));
- } else {
- Softmax<OP, negate>(ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
- outputs[0].dptr<DType>(), shape.get<3>(), axis,
- static_cast<DType>(temperature));
- }
+ MXNET_REAL_ACC_TYPE_SWITCH(inputs[0].type_flag_, DType, AType, {
+ MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+ if (shape.ndim() == 2) {
+ Softmax<OP, negate, AType>(
+ ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
+ outputs[0].dptr<OType>(), shape.get<2>(), axis,
+ static_cast<DType>(temperature));
+ } else {
+ Softmax<OP, negate, AType>(
+ ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
+ outputs[0].dptr<OType>(), shape.get<3>(), axis,
+ static_cast<DType>(temperature));
+ }
+ });
});
}
@@ -324,17 +441,24 @@ void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs,
const double temperature = param.temperature.has_value() ?
param.temperature.value() : 1.0;
TShape shape = AxisShapeCompact(inputs[0].shape_, &axis, true);
- MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
- MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
- if (shape.ndim() == 2) {
- SoftmaxGrad<OP1, OP2, Req, negate>(ctx.get_stream<xpu>(), inputs[1].dptr<DType>(),
- inputs[0].dptr<DType>(), outputs[0].dptr<DType>(),
- shape.get<2>(), axis, static_cast<DType>(temperature));
- } else {
- SoftmaxGrad<OP1, OP2, Req, negate>(ctx.get_stream<xpu>(), inputs[1].dptr<DType>(),
- inputs[0].dptr<DType>(), outputs[0].dptr<DType>(),
- shape.get<3>(), axis, static_cast<DType>(temperature));
- }
+
+ int out_idx = softmax_has_dtype_override(attrs) ? 2 : 1;
+
+ MXNET_REAL_ACC_TYPE_SWITCH(inputs[0].type_flag_, OType, AType, {
+ MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+ MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
+ if (shape.ndim() == 2) {
+ SoftmaxGrad<OP1, OP2, Req, negate, AType>(
+ ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
+ inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
+ shape.get<2>(), axis, static_cast<DType>(temperature));
+ } else {
+ SoftmaxGrad<OP1, OP2, Req, negate, AType>(
+ ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
+ inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
+ shape.get<3>(), axis, static_cast<DType>(temperature));
+ }
+ });
});
});
}
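On the backward side, SoftmaxGrad now keeps the inner reduction sum_k(ograd_k * out_k) in the accumulation type AType, while the final out * (ograd - sum) result is written back in the gradient's own type. A hedged NumPy sketch of that per-row computation for fp16 data (names are illustrative, not from the patch):

    import numpy as np

    def softmax_grad_row_fp16(ograd, out, acc_dtype=np.float32):
        # Mirrors softmax_bwd: igrad_j = out_j * (ograd_j - sum_k ograd_k * out_k),
        # with the inner sum accumulated in the wider acc_dtype (the AType role).
        ssum = acc_dtype(0)
        for g, o in zip(ograd, out):
            ssum += acc_dtype(g) * acc_dtype(o)
        return np.array([o * (g - ssum) for g, o in zip(ograd, out)],
                        dtype=np.float16)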
diff --git a/src/operator/nn/softmax.cc b/src/operator/nn/softmax.cc
index 81e775c..c88f738 100644
--- a/src/operator/nn/softmax.cc
+++ b/src/operator/nn/softmax.cc
@@ -67,7 +67,7 @@ inline static bool SoftmaxStorageType(const nnvm::NodeAttrs& attrs,
}
#endif
-MXNET_OPERATOR_REGISTER_UNARY(softmax)
+NNVM_REGISTER_OP(softmax)
.describe(R"code(Applies the softmax function.
The resulting array contains elements in the range (0,1) and the elements along the given axis sum up to 1.
@@ -102,15 +102,31 @@ Example::
.set_attr<FComputeEx>("FComputeEx<cpu>", SoftmaxComputeExCPU)
.set_attr<FInferStorageType>("FInferStorageType", SoftmaxStorageType)
#endif
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_softmax"})
+.set_attr<nnvm::FGradient>("FGradient", SoftmaxFGradient{"_backward_softmax"})
+.set_attr<nnvm::FInferType>("FInferType", SoftmaxOpType)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+ [](const NodeAttrs& attrs){
+ return std::vector<std::pair<int, int> >{{0, 0}};
+ })
+.add_argument("data", "NDArray-or-Symbol", "The input array.")
.add_arguments(SoftmaxParam::__FIELDS__());
-MXNET_OPERATOR_REGISTER_BINARY(_backward_softmax)
+NNVM_REGISTER_OP(_backward_softmax)
+.set_num_inputs(SoftmaxGradOpNumInputs)
+.set_num_outputs(1)
+.set_attr<nnvm::FListInputNames>("FListInputNames", SoftmaxGradOpInputNames)
+.set_attr<nnvm::FInferShape>("FInferShape", SoftmaxGradOpShape)
+.set_attr<nnvm::FInferType>("FInferType", SoftmaxGradOpType)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption", SoftmaxGradOpInplaceOption)
+.add_argument("args", "NDArray-or-Symbol[]", "Positional input arguments")
.set_attr_parser(ParamParser<SoftmaxParam>)
.set_attr<FCompute>("FCompute<cpu>", SoftmaxGradCompute<cpu, op::mshadow_op::mul,
mxnet_op::softmax_bwd>);
-MXNET_OPERATOR_REGISTER_UNARY(softmin)
+NNVM_REGISTER_OP(softmin)
.describe(R"code(Applies the softmin function.
The resulting array contains elements in the range (0,1) and the elements along the given axis sum
@@ -141,15 +157,31 @@ Example::
return std::vector<std::string>{"output"};
})
.set_attr<FCompute>("FCompute<cpu>", SoftmaxCompute<cpu, mxnet_op::softmax_fwd, true>)
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_softmin"})
+.set_attr<nnvm::FGradient>("FGradient", SoftmaxFGradient{"_backward_softmin"})
+.set_attr<nnvm::FInferType>("FInferType", SoftmaxOpType)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+ [](const NodeAttrs& attrs){
+ return std::vector<std::pair<int, int> >{{0, 0}};
+ })
+.add_argument("data", "NDArray-or-Symbol", "The input array.")
.add_arguments(SoftmaxParam::__FIELDS__());
-MXNET_OPERATOR_REGISTER_BINARY(_backward_softmin)
+NNVM_REGISTER_OP(_backward_softmin)
+.set_num_inputs(SoftmaxGradOpNumInputs)
+.set_num_outputs(1)
+.set_attr<nnvm::FListInputNames>("FListInputNames", SoftmaxGradOpInputNames)
+.set_attr<nnvm::FInferShape>("FInferShape", SoftmaxGradOpShape)
+.set_attr<nnvm::FInferType>("FInferType", SoftmaxGradOpType)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption", SoftmaxGradOpInplaceOption)
+.add_argument("args", "NDArray-or-Symbol[]", "Positional input arguments")
.set_attr_parser(ParamParser<SoftmaxParam>)
.set_attr<FCompute>("FCompute<cpu>", SoftmaxGradCompute<cpu, op::mshadow_op::mul,
mxnet_op::softmax_bwd, true>);
-MXNET_OPERATOR_REGISTER_UNARY(log_softmax)
+NNVM_REGISTER_OP(log_softmax)
.describe(R"code(Computes the log softmax of the input.
This is equivalent to computing softmax followed by log.
@@ -168,10 +200,26 @@ Examples::
)code")
.set_attr_parser(ParamParser<SoftmaxParam>)
.set_attr<FCompute>("FCompute<cpu>", SoftmaxCompute<cpu, mxnet_op::log_softmax_fwd>)
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_log_softmax"})
+.set_attr<nnvm::FGradient>("FGradient", SoftmaxFGradient{"_backward_log_softmax"})
+.set_attr<nnvm::FInferType>("FInferType", SoftmaxOpType)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+ [](const NodeAttrs& attrs){
+ return std::vector<std::pair<int, int> >{{0, 0}};
+ })
+.add_argument("data", "NDArray-or-Symbol", "The input array.")
.add_arguments(SoftmaxParam::__FIELDS__());
-MXNET_OPERATOR_REGISTER_BINARY(_backward_log_softmax)
+NNVM_REGISTER_OP(_backward_log_softmax)
+.set_num_inputs(SoftmaxGradOpNumInputs)
+.set_num_outputs(1)
+.set_attr<nnvm::FListInputNames>("FListInputNames", SoftmaxGradOpInputNames)
+.set_attr<nnvm::FInferShape>("FInferShape", SoftmaxGradOpShape)
+.set_attr<nnvm::FInferType>("FInferType", SoftmaxGradOpType)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption", SoftmaxGradOpInplaceOption)
+.add_argument("args", "NDArray-or-Symbol[]", "Positional input arguments")
.set_attr_parser(ParamParser<SoftmaxParam>)
.set_attr<FCompute>("FCompute<cpu>", SoftmaxGradCompute<cpu, mshadow_op::left,
mxnet_op::log_softmax_bwd>);
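The registration changes above turn the forward ops into full NNVM registrations with the new SoftmaxOpType inference (the output dtype follows the dtype attribute when it is set, otherwise the input dtype), and switch the backward ops to a variable input count: three inputs (ograd, data, output) when dtype is overridden, two otherwise. A small symbolic check, assuming the standard infer_type API; the expectation in the comment is my reading of SoftmaxOpType, not output captured from the patch:

    import numpy as np
    import mxnet as mx

    x = mx.sym.Variable('x')
    y = mx.sym.softmax(x, dtype='float32')
    arg_types, out_types, _ = y.infer_type(x=np.float16)
    print(arg_types, out_types)  # expected: [numpy.float16] and [numpy.float32]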
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index a9b9cc8..ae7dc86 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -4535,6 +4535,47 @@ def test_softmax_with_large_inputs():
softmax_forward(mx.nd.array([[[[3.4e38,3.4e38]]]]), np.array([1.0,1.0]))
@with_seed()
+def test_softmax_dtype():
+ def check_dtypes_almost_equal(op_name,
+ atol, rtol,
+ grad_atol, grad_rtol,
+ idtype, ref_dtype, odtype=None):
+ op = getattr(mx.nd, op_name)
+ input_data = mx.random.uniform(shape=(100, 500))
+ dtype_input = input_data.astype(idtype)
+ ref_input = input_data.astype(ref_dtype)
+ dtype_input.attach_grad()
+ ref_input.attach_grad()
+ with mx.autograd.record():
+ dtype_softmax = op(dtype_input, axis=-1, dtype=odtype)
+ ref_softmax = op(ref_input, axis=-1, dtype=odtype)
+ dtype_softmax_np = dtype_softmax.asnumpy()
+ ref_softmax_np = ref_softmax.asnumpy()
+ assert_almost_equal(dtype_softmax_np, ref_softmax_np, rtol=rtol, atol=atol)
+ dtype_softmax.backward()
+ ref_softmax.backward()
+ dtype_grad_np = dtype_input.grad.asnumpy()
+ ref_grad_np = ref_input.grad.asnumpy()
+ assert_almost_equal(dtype_grad_np, ref_grad_np, rtol=grad_rtol, atol=grad_atol)
+
+ check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 'float16', 'float32')
+ check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 'float16', 'float32', 'float32')
+ check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 'float32', 'float64')
+ check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 'float32', 'float64', 'float64')
+ check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 'float16', 'float32')
+ check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 'float16', 'float32', 'float32')
+ check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 'float32', 'float64')
+ check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 'float32', 'float64', 'float64')
+ check_dtypes_almost_equal('log_softmax', 1e-2, 1e-2, 1e-2, 1e-2,
+ 'float16', 'float32')
+ check_dtypes_almost_equal('log_softmax', 1e-2, 1e-2, 1e-2, 1e-2,
+ 'float16', 'float32', 'float32')
+ check_dtypes_almost_equal('log_softmax', 1e-3, 1e-3, 1e-3, 1e-3,
+ 'float32', 'float64')
+ check_dtypes_almost_equal('log_softmax', 1e-3, 1e-3, 1e-3, 1e-3,
+ 'float32', 'float64', 'float64')
+
+@with_seed()
def test_pick():
def test_pick_helper(index_type=np.int32):
for _ in range(100):