Posted to commits@mxnet.apache.org by ha...@apache.org on 2019/02/21 00:37:45 UTC

[incubator-mxnet] branch master updated: softmax for fp16 with fp32 accumulator (#14098)

This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 862cbc6  softmax for fp16 with fp32 accumulator (#14098)
862cbc6 is described below

commit 862cbc67aacf81990b8c885847686a4c3c734cd3
Author: Sheng Zha <sz...@users.noreply.github.com>
AuthorDate: Wed Feb 20 16:37:12 2019 -0800

    softmax for fp16 with fp32 accumulator (#14098)
    
    * softmax for fp16 with fp32 accumulator
    
    * return AType in kernel
    
    * add dtype
    
    * kernel
    
    * grad use in-out only when dtype override
    
    * simplify infer type
    
    * address comments
---
 src/operator/mxnet_op.h                |  42 ++++++
 src/operator/nn/softmax-inl.h          | 248 ++++++++++++++++++++++++---------
 src/operator/nn/softmax.cc             |  66 +++++++--
 tests/python/unittest/test_operator.py |  41 ++++++
 4 files changed, 326 insertions(+), 71 deletions(-)
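
Why the float32 accumulator matters: float16 has a 10-bit mantissa, so once a running sum of exponentials grows into the thousands, addends smaller than the value spacing are simply rounded away. A toy NumPy sketch of the numerical effect (illustration only, not MXNet code):

    import numpy as np

    x = np.random.uniform(-1, 1, size=8192).astype(np.float16)
    e = np.exp(x - x.max())                    # float16 exponentials in (0, 1]

    acc16 = np.float16(0)
    for v in e:                                # float16 accumulator
        acc16 += v                             # spacing is 2.0 once acc16 exceeds 2048
    acc32 = e.astype(np.float32).sum()         # float32 accumulator

    ref = np.exp(x.astype(np.float64) - np.float64(x.max())).sum()
    print(abs(acc16 - ref) / ref)              # large relative error
    print(abs(acc32 - ref) / ref)              # far smaller relative error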

diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h
index 6cab199..d8fc503 100644
--- a/src/operator/mxnet_op.h
+++ b/src/operator/mxnet_op.h
@@ -249,6 +249,48 @@ inline int get_num_threads<cpu>(const int N) {
     LOG(FATAL) << "Unknown type enum " << type;            \
   }
 
+#define MXNET_REAL_ACC_TYPE_SWITCH(type, DType, AType, ...)\
+  switch (type) {                                          \
+  case mshadow::kFloat32:                                  \
+    {                                                      \
+      typedef float DType;                                 \
+      typedef double AType;                                \
+      {__VA_ARGS__}                                        \
+    }                                                      \
+    break;                                                 \
+  case mshadow::kFloat64:                                  \
+    {                                                      \
+      typedef double DType;                                \
+      typedef double AType;                                \
+      {__VA_ARGS__}                                        \
+    }                                                      \
+    break;                                                 \
+  case mshadow::kFloat16:                                  \
+    {                                                      \
+      typedef mshadow::half::half_t DType;                 \
+      typedef float AType;                                 \
+      {__VA_ARGS__}                                        \
+    }                                                      \
+    break;                                                 \
+  case mshadow::kUint8:                                    \
+    LOG(FATAL) << "This operation only support "           \
+                  "floating point types not uint8";        \
+    break;                                                 \
+  case mshadow::kInt8:                                     \
+    LOG(FATAL) << "This operation only support "           \
+                  "floating point types not int8";         \
+    break;                                                 \
+  case mshadow::kInt32:                                    \
+    LOG(FATAL) << "This operation only support "           \
+                  "floating point types, not int32";       \
+    break;                                                 \
+  case mshadow::kInt64:                                    \
+    LOG(FATAL) << "This operation only support "           \
+                  "floating point types, not int64";       \
+    break;                                                 \
+  default:                                                 \
+    LOG(FATAL) << "Unknown type enum " << type;            \
+  }
 
 /*!
  * \brief assign the val to out according
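
The new MXNET_REAL_ACC_TYPE_SWITCH macro pairs each real input type DType with a wider accumulation type AType: float16 accumulates in float32, float32 in float64, and float64 stays at float64; integer types hit LOG(FATAL). A small Python sketch of the same mapping (illustration only; the helper name is made up):

    import numpy as np

    # DType -> AType pairing implemented by MXNET_REAL_ACC_TYPE_SWITCH
    _ACC_TYPE = {
        np.dtype(np.float16): np.dtype(np.float32),
        np.dtype(np.float32): np.dtype(np.float64),
        np.dtype(np.float64): np.dtype(np.float64),
    }

    def acc_dtype(dtype):
        """Accumulation dtype for a real input dtype (hypothetical helper)."""
        dtype = np.dtype(dtype)
        if dtype not in _ACC_TYPE:
            raise TypeError("only floating point types are supported, got %s" % dtype)
        return _ACC_TYPE[dtype]
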
diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h
index c063e38..90950bc 100644
--- a/src/operator/nn/softmax-inl.h
+++ b/src/operator/nn/softmax-inl.h
@@ -25,6 +25,9 @@
 #ifndef MXNET_OPERATOR_NN_SOFTMAX_INL_H_
 #define MXNET_OPERATOR_NN_SOFTMAX_INL_H_
 
+#include <algorithm>
+#include <string>
+#include <utility>
 #include <vector>
 
 #include "../mxnet_op.h"
@@ -36,23 +39,33 @@ namespace op {
 namespace mxnet_op {
 
 struct softmax_fwd {
-  template<typename DType>
-  MSHADOW_XINLINE static DType Map(DType a, DType b) {
-    return DType(expf(a)/b);
+  template<typename AType>
+  MSHADOW_XINLINE static AType Map(float a, AType b) {
+    return AType(expf(a)/b);
+  }
+
+  template<typename AType>
+  MSHADOW_XINLINE static AType Map(double a, AType b) {
+    return AType(exp(a)/b);
   }
 };
 
 
 struct log_softmax_fwd {
   template<typename DType>
-  MSHADOW_XINLINE static DType Map(DType a, DType b) {
-    return DType(a - logf(b));
+  MSHADOW_XINLINE static float Map(DType a, float b) {
+    return a - logf(b);
+  }
+
+  template<typename DType>
+  MSHADOW_XINLINE static double Map(DType a, double b) {
+    return a - log(b);
   }
 };
 
 
-template<typename OP, bool negate, typename DType, int ndim>
-inline void Softmax(Stream<cpu> *s, DType *in, DType *out,
+template<typename OP, bool negate, typename AType, typename DType, typename OType, int ndim>
+inline void Softmax(Stream<cpu> *s, DType *in, OType *out,
                     Shape<ndim> shape, int axis, const DType temperature) {
   index_t M = shape[axis];
   index_t N = shape.Size()/M;
@@ -72,10 +85,9 @@ inline void Softmax(Stream<cpu> *s, DType *in, DType *out,
       if (mmax < val) mmax = val;
     }
 
-    DType sum = DType(0);
+    AType sum = AType(0);
     DType in_val;
-    // By default temperature is 1.0, and only in reinforcement training
-    // users would set it to other values.
+    // By default temperature is 1.0.
     // Adding a branch here to save the CPU 'divide-by-1' computation at runtime
     if (temperature == 1.0) {
       for (index_t j = 0; j < M; ++j) {
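
In the CPU forward path above, the row maximum and the element-wise exponentials stay in DType while the normalizer `sum` is accumulated in AType and the result is written as OType. A NumPy sketch of one row under those assumptions (not the MXNet kernel; softmin simply negates the input first):

    import numpy as np

    def softmax_row(x, acc_dtype=np.float32, out_dtype=np.float16, temperature=1.0):
        # x: 1-D float16 row (DType)
        mmax = x.max()                                       # DType row maximum
        e = np.exp((x - mmax) / x.dtype.type(temperature))   # DType exponentials
        denom = e.astype(acc_dtype).sum()                    # AType accumulator
        return (e.astype(acc_dtype) / denom).astype(out_dtype)  # OType output
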
@@ -103,23 +115,29 @@ inline void Softmax(Stream<cpu> *s, DType *in, DType *out,
 
 
 struct softmax_bwd {
-  template<typename DType>
-  MSHADOW_XINLINE static DType Map(DType ograd, DType out, DType sum) {
-    return DType(out * (ograd - sum));
+  template<typename DType, typename AType>
+  MSHADOW_XINLINE static AType Map(DType ograd, DType out, AType sum) {
+    return AType(out * (ograd - sum));
   }
 };
 
 
 struct log_softmax_bwd {
-  template<typename DType>
-  MSHADOW_XINLINE static DType Map(DType ograd, DType out, DType sum) {
-    return DType(ograd - expf(out)*sum);
+  template<typename AType>
+  MSHADOW_XINLINE static AType Map(float ograd, float out, AType sum) {
+    return AType(ograd - expf(out)*sum);
+  }
+
+  template<typename AType>
+  MSHADOW_XINLINE static AType Map(double ograd, double out, AType sum) {
+    return AType(ograd - exp(out)*sum);
   }
 };
 
 
-template<typename OP1, typename OP2, int Req, bool negate, typename DType, int ndim>
-inline void SoftmaxGrad(Stream<cpu> *s, DType *out, DType *ograd,
+template<typename OP1, typename OP2, int Req, bool negate,
+  typename AType, typename DType, typename OType, int ndim>
+inline void SoftmaxGrad(Stream<cpu> *s, OType *out, OType *ograd,
                         DType *igrad, Shape<ndim> shape, int axis,
                         const DType temperature) {
   index_t M = shape[axis];
@@ -133,13 +151,12 @@ inline void SoftmaxGrad(Stream<cpu> *s, DType *out, DType *ograd,
   for (int i = 0; i < static_cast<int>(N); ++i) {
     index_t base = unravel_dot(i, sshape, stride);
 
-    DType sum = DType(0);
+    AType sum = AType(0);
     for (index_t j = 0; j < M; ++j) {
       sum += OP1::Map(ograd[base + j*sa], out[base + j*sa]);
     }
 
-    // By default temperature is 1.0, and only in reinforcement training
-    // users would set it to other values.
+    // By default temperature is 1.0.
     // Adding a branch here to save the CPU 'divide-by-1' computation at runtime
     DType final_result;
     if (temperature == 1.0) {
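
The backward path mirrors this: the per-row reduction sum_j OP1(ograd_j, out_j) is accumulated in AType, and the final value OP2(ograd_j, out_j, sum) is cast back to the gradient's DType. For plain softmax OP1 is multiplication and OP2 is out * (ograd - sum); for log_softmax the pair is (ograd, ograd - exp(out)*sum). A NumPy sketch of the softmax case (illustration only):

    import numpy as np

    def softmax_grad_row(ograd, out, acc_dtype=np.float32, igrad_dtype=np.float16):
        # ograd, out: 1-D float16 rows (OType); reduce in AType
        og = ograd.astype(acc_dtype)
        o = out.astype(acc_dtype)
        s = (og * o).sum()                           # AType sum
        return (o * (og - s)).astype(igrad_dtype)    # gradient in the input's DType
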
@@ -162,19 +179,20 @@ inline void SoftmaxGrad(Stream<cpu> *s, DType *out, DType *ograd,
 
 
 #ifdef __CUDACC__
-template<int x_bits, typename OP, bool negate, typename DType, int ndim>
-__global__ void softmax_compute_kernel(DType *in, DType *out, index_t M, int axis,
+template<int x_bits, typename OP, bool negate, typename AType, int ndim,
+  typename DType, typename OType>
+__global__ void softmax_compute_kernel(DType *in, OType *out, index_t M, int axis,
                                        Shape<ndim> sshape, Shape<ndim> stride,
                                        const double temperature) {
   const unsigned x_size = 1 << x_bits;
-  __shared__ DType smem[x_size];
+  __shared__ AType smem[x_size];
   index_t sa = stride[axis];
   index_t base = unravel_dot(blockIdx.x, sshape, stride);
   index_t x = threadIdx.x;
 
   red::maximum::SetInitValue(smem[x]);
   for (index_t i = x; i < M; i += x_size) {
-    red::maximum::Reduce(smem[x], negate ? -in[base + i*sa] : in[base + i*sa]);
+    smem[x] = ::max(smem[x], negate ? -in[base + i*sa] : in[base + i*sa]);
   }
   __syncthreads();
   cuda::Reduce1D<red::maximum, x_bits>(smem);
@@ -186,13 +204,12 @@ __global__ void softmax_compute_kernel(DType *in, DType *out, index_t M, int axi
   DType val;
   for (index_t i = x; i < M; i += x_size) {
     val = negate ? -in[base + i*sa]:in[base + i*sa];
-    red::sum::Reduce(
-      smem[x], static_cast<DType>(expf((val - smax) / static_cast<DType>(temperature))));
+    smem[x] += static_cast<AType>(expf((val - smax) / static_cast<AType>(temperature)));
   }
   __syncthreads();
   cuda::Reduce1D<red::sum, x_bits>(smem);
   __syncthreads();
-  DType ssum = smem[0];
+  AType ssum = smem[0];
   __syncthreads();
 
   for (index_t i = x; i < M; i += x_size) {
@@ -201,8 +218,8 @@ __global__ void softmax_compute_kernel(DType *in, DType *out, index_t M, int axi
   }
 }
 
-template<typename OP, bool negate, typename DType, int ndim>
-inline void Softmax(Stream<gpu> *s, DType *in, DType *out,
+template<typename OP, bool negate, typename AType, typename DType, typename OType, int ndim>
+inline void Softmax(Stream<gpu> *s, DType *in, OType *out,
                     Shape<ndim> shape, int axis, const double temperature) {
   const int x_bits = 7;
   const int x_size = 1 << x_bits;
@@ -212,31 +229,32 @@ inline void Softmax(Stream<gpu> *s, DType *in, DType *out,
   Shape<ndim> sshape = shape;
   sshape[axis] = 1;
 
-  softmax_compute_kernel<x_bits, OP, negate, DType, ndim>
+  softmax_compute_kernel<x_bits, OP, negate, AType, ndim>
     <<<N, x_size, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
       in, out, M, axis, sshape, stride, temperature);
   MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_compute_kernel);
 }
 
 
-template<int x_bits, typename OP1, typename OP2, int Req, bool negate, typename DType, int ndim>
-__global__ void softmax_gradient_kernel(DType *out, DType *ograd, DType *igrad,
+template<int x_bits, typename OP1, typename OP2, int Req, bool negate, typename AType, int ndim,
+  typename DType, typename OType>
+__global__ void softmax_gradient_kernel(OType *out, OType *ograd, DType *igrad,
                                         index_t M, int axis, Shape<ndim> sshape,
                                         Shape<ndim> stride, const double temperature) {
   const unsigned x_size = 1 << x_bits;
-  __shared__ DType smem[x_size];
+  __shared__ AType smem[x_size];
   index_t sa = stride[axis];
   index_t base = unravel_dot(blockIdx.x, sshape, stride);
   index_t x = threadIdx.x;
 
   red::sum::SetInitValue(smem[x]);
   for (index_t i = x; i < M; i += x_size) {
-    red::sum::Reduce(smem[x], OP1::Map(ograd[base + i*sa], out[base + i*sa]));
+    smem[x] += OP1::Map(ograd[base + i*sa], out[base + i*sa]);
   }
   __syncthreads();
   cuda::Reduce1D<red::sum, x_bits>(smem);
   __syncthreads();
-  DType ssum = smem[0];
+  AType ssum = smem[0];
   __syncthreads();
 
   DType final_result;
@@ -250,8 +268,9 @@ __global__ void softmax_gradient_kernel(DType *out, DType *ograd, DType *igrad,
 }
 
 
-template<typename OP1, typename OP2, int Req, bool negate, typename DType, int ndim>
-inline void SoftmaxGrad(Stream<gpu> *s, DType *out, DType *ograd,
+template<typename OP1, typename OP2, int Req, bool negate, typename AType, int ndim,
+  typename DType, typename OType>
+inline void SoftmaxGrad(Stream<gpu> *s, OType *out, OType *ograd,
                         DType *igrad, Shape<ndim> shape, int axis,
                         const double temperature) {
   const int x_bits = 7;
@@ -262,7 +281,7 @@ inline void SoftmaxGrad(Stream<gpu> *s, DType *out, DType *ograd,
   Shape<ndim> sshape = shape;
   sshape[axis] = 1;
 
-  softmax_gradient_kernel<x_bits, OP1, OP2, Req, negate, DType, ndim>
+  softmax_gradient_kernel<x_bits, OP1, OP2, Req, negate, AType, ndim>
     <<<N, x_size, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
       out, ograd, igrad, M, axis, sshape, stride, temperature);
   MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_gradient_kernel);
@@ -275,11 +294,105 @@ inline void SoftmaxGrad(Stream<gpu> *s, DType *out, DType *ograd,
 struct SoftmaxParam : public dmlc::Parameter<SoftmaxParam> {
   int axis;
   dmlc::optional<double> temperature;
+  dmlc::optional<int> dtype;
   DMLC_DECLARE_PARAMETER(SoftmaxParam) {
     DMLC_DECLARE_FIELD(axis).set_default(-1)
-      .describe("The axis along which to compute softmax.");
+    .describe("The axis along which to compute softmax.");
     DMLC_DECLARE_FIELD(temperature).set_default(dmlc::optional<double>())
-      .describe("Temperature parameter in softmax");
+    .describe("Temperature parameter in softmax");
+    DMLC_DECLARE_FIELD(dtype)
+    .add_enum("float16", mshadow::kFloat16)
+    .add_enum("float32", mshadow::kFloat32)
+    .add_enum("float64", mshadow::kFloat64)
+    .set_default(dmlc::optional<int>())
+    .describe("DType of the output in case this can't be inferred. "
+              "Defaults to the same as input's dtype if not defined (dtype=None).");
+  }
+};
+
+static inline bool softmax_has_dtype_override(const nnvm::NodeAttrs& attrs) {
+  const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
+  return param.dtype.has_value() && param.dtype.value() != -1;
+}
+
+static inline bool SoftmaxOpType(const nnvm::NodeAttrs& attrs,
+                                 std::vector<int>* in_attrs,
+                                 std::vector<int>* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1);
+  CHECK_EQ(out_attrs->size(), 1);
+  const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
+
+  if (softmax_has_dtype_override(attrs)) {
+    TYPE_ASSIGN_CHECK(*out_attrs, 0, param.dtype.value());
+    type_assign(&(*in_attrs)[0], (*out_attrs)[0]);
+    return true;
+  } else {
+    return ElemwiseType<1, 1>(attrs, in_attrs, out_attrs);
+  }
+}
+
+static inline bool SoftmaxGradOpShape(const nnvm::NodeAttrs& attrs,
+                                      std::vector<TShape> *in_attrs,
+                                      std::vector<TShape> *out_attrs) {
+  if (softmax_has_dtype_override(attrs)) {
+    return ElemwiseShape<3, 1>(attrs, in_attrs, out_attrs);
+  } else {
+    return ElemwiseShape<2, 1>(attrs, in_attrs, out_attrs);
+  }
+}
+
+static inline bool SoftmaxGradOpType(const nnvm::NodeAttrs& attrs,
+                                     std::vector<int>* in_attrs,
+                                     std::vector<int>* out_attrs) {
+  CHECK_EQ(out_attrs->size(), 1);
+  if (softmax_has_dtype_override(attrs)) {
+    CHECK_EQ(in_attrs->size(), 3);
+    int in_dtype = (*in_attrs)[1];
+    int out_dtype = (*in_attrs)[2];
+    TYPE_ASSIGN_CHECK(*in_attrs, 0, out_dtype);
+    TYPE_ASSIGN_CHECK(*out_attrs, 0, in_dtype);
+
+    return (*out_attrs)[0] != -1 && (*in_attrs)[0] != -1;
+  } else {
+    CHECK_EQ(in_attrs->size(), 2);
+    int out_dtype = (*in_attrs)[1];
+    TYPE_ASSIGN_CHECK(*out_attrs, 0, out_dtype);
+    TYPE_ASSIGN_CHECK(*in_attrs, 0, out_dtype);
+
+    return (*out_attrs)[0] != -1 && (*in_attrs)[0] != -1;
+  }
+}
+
+static inline std::vector<std::pair<int, int> >
+SoftmaxGradOpInplaceOption(const nnvm::NodeAttrs& attrs) {
+  if (softmax_has_dtype_override(attrs)) {
+    return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}, {2, 0}};
+  } else {
+    return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}};
+  }
+}
+
+static inline uint32_t SoftmaxGradOpNumInputs(const nnvm::NodeAttrs& attrs) {
+  return softmax_has_dtype_override(attrs) ? 3 : 2;
+}
+
+static inline std::vector<std::string> SoftmaxGradOpInputNames(const nnvm::NodeAttrs& attrs) {
+  if (softmax_has_dtype_override(attrs)) {
+    return std::vector<std::string>{"ograd", "data", "output"};
+  } else {
+    return std::vector<std::string>{"ograd", "output"};
+  }
+}
+
+struct SoftmaxFGradient {
+  const char *op_name;
+  std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr& n,
+                                          const std::vector<nnvm::NodeEntry>& ograds) const {
+    if (softmax_has_dtype_override(n->attrs)) {
+      return ElemwiseGradUseInOut {op_name}(n, ograds);
+    } else {
+      return ElemwiseGradUseOut {op_name}(n, ograds);
+    }
   }
 };
 
@@ -297,16 +410,20 @@ void SoftmaxCompute(const nnvm::NodeAttrs& attrs,
   const double temperature = param.temperature.has_value() ?
     param.temperature.value() : 1.0;
   TShape shape = AxisShapeCompact(inputs[0].shape_, &axis, true);
-  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
-    if (shape.ndim() == 2) {
-      Softmax<OP, negate>(ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
-                          outputs[0].dptr<DType>(), shape.get<2>(), axis,
-                          static_cast<DType>(temperature));
-    } else {
-      Softmax<OP, negate>(ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
-                          outputs[0].dptr<DType>(), shape.get<3>(), axis,
-                          static_cast<DType>(temperature));
-    }
+  MXNET_REAL_ACC_TYPE_SWITCH(inputs[0].type_flag_, DType, AType, {
+    MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+      if (shape.ndim() == 2) {
+        Softmax<OP, negate, AType>(
+            ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
+            outputs[0].dptr<OType>(), shape.get<2>(), axis,
+            static_cast<DType>(temperature));
+      } else {
+        Softmax<OP, negate, AType>(
+            ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
+            outputs[0].dptr<OType>(), shape.get<3>(), axis,
+            static_cast<DType>(temperature));
+      }
+    });
   });
 }
 
@@ -324,17 +441,24 @@ void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs,
   const double temperature = param.temperature.has_value() ?
     param.temperature.value() : 1.0;
   TShape shape = AxisShapeCompact(inputs[0].shape_, &axis, true);
-  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
-    MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
-      if (shape.ndim() == 2) {
-        SoftmaxGrad<OP1, OP2, Req, negate>(ctx.get_stream<xpu>(), inputs[1].dptr<DType>(),
-                                           inputs[0].dptr<DType>(), outputs[0].dptr<DType>(),
-                                           shape.get<2>(), axis, static_cast<DType>(temperature));
-      } else {
-        SoftmaxGrad<OP1, OP2, Req, negate>(ctx.get_stream<xpu>(), inputs[1].dptr<DType>(),
-                                           inputs[0].dptr<DType>(), outputs[0].dptr<DType>(),
-                                           shape.get<3>(), axis, static_cast<DType>(temperature));
-      }
+
+  int out_idx = softmax_has_dtype_override(attrs) ? 2 : 1;
+
+  MXNET_REAL_ACC_TYPE_SWITCH(inputs[0].type_flag_, OType, AType, {
+    MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+      MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
+        if (shape.ndim() == 2) {
+          SoftmaxGrad<OP1, OP2, Req, negate, AType>(
+              ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
+              inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
+              shape.get<2>(), axis, static_cast<DType>(temperature));
+        } else {
+          SoftmaxGrad<OP1, OP2, Req, negate, AType>(
+              ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
+              inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
+              shape.get<3>(), axis, static_cast<DType>(temperature));
+        }
+      });
     });
   });
 }
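
When a dtype override is in effect, SoftmaxFGradient switches to ElemwiseGradUseInOut, so the backward node takes three inputs (ograd, data, output) and SoftmaxGradOpType gives the input gradient the data's dtype while ograd and output keep the output dtype; without an override the old two-input (ograd, output) form is kept. A usage sketch of the resulting dtypes, assuming the Python frontend exposes the new dtype argument as exercised by the test change below:

    import mxnet as mx

    x = mx.nd.random.uniform(shape=(2, 5)).astype('float16')
    x.attach_grad()
    with mx.autograd.record():
        y = mx.nd.softmax(x, axis=-1, dtype='float32')   # float16 in, float32 out
    y.backward()
    print(y.dtype)         # float32: follows the dtype override
    print(x.grad.dtype)    # float16: gradient follows the input's dtype
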
diff --git a/src/operator/nn/softmax.cc b/src/operator/nn/softmax.cc
index 81e775c..c88f738 100644
--- a/src/operator/nn/softmax.cc
+++ b/src/operator/nn/softmax.cc
@@ -67,7 +67,7 @@ inline static bool SoftmaxStorageType(const nnvm::NodeAttrs& attrs,
 }
 #endif
 
-MXNET_OPERATOR_REGISTER_UNARY(softmax)
+NNVM_REGISTER_OP(softmax)
 .describe(R"code(Applies the softmax function.
 
 The resulting array contains elements in the range (0,1) and the elements along the given axis sum up to 1.
@@ -102,15 +102,31 @@ Example::
 .set_attr<FComputeEx>("FComputeEx<cpu>", SoftmaxComputeExCPU)
 .set_attr<FInferStorageType>("FInferStorageType", SoftmaxStorageType)
 #endif
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_softmax"})
+.set_attr<nnvm::FGradient>("FGradient", SoftmaxFGradient{"_backward_softmax"})
+.set_attr<nnvm::FInferType>("FInferType", SoftmaxOpType)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+  [](const NodeAttrs& attrs){
+    return std::vector<std::pair<int, int> >{{0, 0}};
+  })
+.add_argument("data", "NDArray-or-Symbol", "The input array.")
 .add_arguments(SoftmaxParam::__FIELDS__());
 
-MXNET_OPERATOR_REGISTER_BINARY(_backward_softmax)
+NNVM_REGISTER_OP(_backward_softmax)
+.set_num_inputs(SoftmaxGradOpNumInputs)
+.set_num_outputs(1)
+.set_attr<nnvm::FListInputNames>("FListInputNames", SoftmaxGradOpInputNames)
+.set_attr<nnvm::FInferShape>("FInferShape", SoftmaxGradOpShape)
+.set_attr<nnvm::FInferType>("FInferType", SoftmaxGradOpType)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption", SoftmaxGradOpInplaceOption)
+.add_argument("args", "NDArray-or-Symbol[]", "Positional input arguments")
 .set_attr_parser(ParamParser<SoftmaxParam>)
 .set_attr<FCompute>("FCompute<cpu>", SoftmaxGradCompute<cpu, op::mshadow_op::mul,
                                                         mxnet_op::softmax_bwd>);
 
-MXNET_OPERATOR_REGISTER_UNARY(softmin)
+NNVM_REGISTER_OP(softmin)
 .describe(R"code(Applies the softmin function.
 
 The resulting array contains elements in the range (0,1) and the elements along the given axis sum
@@ -141,15 +157,31 @@ Example::
     return std::vector<std::string>{"output"};
 })
 .set_attr<FCompute>("FCompute<cpu>", SoftmaxCompute<cpu, mxnet_op::softmax_fwd, true>)
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_softmin"})
+.set_attr<nnvm::FGradient>("FGradient", SoftmaxFGradient{"_backward_softmin"})
+.set_attr<nnvm::FInferType>("FInferType", SoftmaxOpType)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+  [](const NodeAttrs& attrs){
+    return std::vector<std::pair<int, int> >{{0, 0}};
+  })
+.add_argument("data", "NDArray-or-Symbol", "The input array.")
 .add_arguments(SoftmaxParam::__FIELDS__());
 
-MXNET_OPERATOR_REGISTER_BINARY(_backward_softmin)
+NNVM_REGISTER_OP(_backward_softmin)
+.set_num_inputs(SoftmaxGradOpNumInputs)
+.set_num_outputs(1)
+.set_attr<nnvm::FListInputNames>("FListInputNames", SoftmaxGradOpInputNames)
+.set_attr<nnvm::FInferShape>("FInferShape", SoftmaxGradOpShape)
+.set_attr<nnvm::FInferType>("FInferType", SoftmaxGradOpType)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption", SoftmaxGradOpInplaceOption)
+.add_argument("args", "NDArray-or-Symbol[]", "Positional input arguments")
 .set_attr_parser(ParamParser<SoftmaxParam>)
 .set_attr<FCompute>("FCompute<cpu>", SoftmaxGradCompute<cpu, op::mshadow_op::mul,
                                                         mxnet_op::softmax_bwd, true>);
 
-MXNET_OPERATOR_REGISTER_UNARY(log_softmax)
+NNVM_REGISTER_OP(log_softmax)
 .describe(R"code(Computes the log softmax of the input.
 This is equivalent to computing softmax followed by log.
 
@@ -168,10 +200,26 @@ Examples::
 )code")
 .set_attr_parser(ParamParser<SoftmaxParam>)
 .set_attr<FCompute>("FCompute<cpu>", SoftmaxCompute<cpu, mxnet_op::log_softmax_fwd>)
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_log_softmax"})
+.set_attr<nnvm::FGradient>("FGradient", SoftmaxFGradient{"_backward_log_softmax"})
+.set_attr<nnvm::FInferType>("FInferType", SoftmaxOpType)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+  [](const NodeAttrs& attrs){
+    return std::vector<std::pair<int, int> >{{0, 0}};
+  })
+.add_argument("data", "NDArray-or-Symbol", "The input array.")
 .add_arguments(SoftmaxParam::__FIELDS__());
 
-MXNET_OPERATOR_REGISTER_BINARY(_backward_log_softmax)
+NNVM_REGISTER_OP(_backward_log_softmax)
+.set_num_inputs(SoftmaxGradOpNumInputs)
+.set_num_outputs(1)
+.set_attr<nnvm::FListInputNames>("FListInputNames", SoftmaxGradOpInputNames)
+.set_attr<nnvm::FInferShape>("FInferShape", SoftmaxGradOpShape)
+.set_attr<nnvm::FInferType>("FInferType", SoftmaxGradOpType)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption", SoftmaxGradOpInplaceOption)
+.add_argument("args", "NDArray-or-Symbol[]", "Positional input arguments")
 .set_attr_parser(ParamParser<SoftmaxParam>)
 .set_attr<FCompute>("FCompute<cpu>", SoftmaxGradCompute<cpu, mshadow_op::left,
                                                         mxnet_op::log_softmax_bwd>);
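
On the registration side, softmax, softmin and log_softmax move from the unary/binary helper macros to explicit NNVM_REGISTER_OP calls, spelling out the input/output counts, shape/type inference and in-place options so the backward op can take either two or three inputs. The user-visible effect is an optional dtype argument on all three operators; a short sketch (same frontend assumption as above):

    import mxnet as mx

    x = mx.nd.random.uniform(shape=(3, 4)).astype('float16')
    print(mx.nd.softmax(x).dtype)                       # float16: defaults to the input dtype
    print(mx.nd.softmax(x, dtype='float32').dtype)      # float32
    print(mx.nd.softmin(x, dtype='float32').dtype)      # float32
    print(mx.nd.log_softmax(x, dtype='float32').dtype)  # float32
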
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index a9b9cc8..ae7dc86 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -4535,6 +4535,47 @@ def test_softmax_with_large_inputs():
     softmax_forward(mx.nd.array([[[[3.4e38,3.4e38]]]]), np.array([1.0,1.0]))
 
 @with_seed()
+def test_softmax_dtype():
+    def check_dtypes_almost_equal(op_name,
+                                  atol, rtol,
+                                  grad_atol, grad_rtol,
+                                  idtype, ref_dtype, odtype=None):
+        op = getattr(mx.nd, op_name)
+        input_data = mx.random.uniform(shape=(100, 500))
+        dtype_input = input_data.astype(idtype)
+        ref_input = input_data.astype(ref_dtype)
+        dtype_input.attach_grad()
+        ref_input.attach_grad()
+        with mx.autograd.record():
+            dtype_softmax = op(dtype_input, axis=-1, dtype=odtype)
+            ref_softmax = op(ref_input, axis=-1, dtype=odtype)
+        dtype_softmax_np = dtype_softmax.asnumpy()
+        ref_softmax_np = ref_softmax.asnumpy()
+        assert_almost_equal(dtype_softmax_np, ref_softmax_np, rtol=rtol, atol=atol)
+        dtype_softmax.backward()
+        ref_softmax.backward()
+        dtype_grad_np = dtype_input.grad.asnumpy()
+        ref_grad_np = ref_input.grad.asnumpy()
+        assert_almost_equal(dtype_grad_np, ref_grad_np, rtol=grad_rtol, atol=grad_atol)
+
+    check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 'float16', 'float32')
+    check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 'float16', 'float32', 'float32')
+    check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 'float32', 'float64')
+    check_dtypes_almost_equal('softmax', 1e-5, 1e-5, 1e-5, 1e-5, 'float32', 'float64', 'float64')
+    check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 'float16', 'float32')
+    check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 'float16', 'float32', 'float32')
+    check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 'float32', 'float64')
+    check_dtypes_almost_equal('softmin', 1e-5, 1e-5, 1e-5, 1e-5, 'float32', 'float64', 'float64')
+    check_dtypes_almost_equal('log_softmax', 1e-2, 1e-2, 1e-2, 1e-2,
+                              'float16', 'float32')
+    check_dtypes_almost_equal('log_softmax', 1e-2, 1e-2, 1e-2, 1e-2,
+                              'float16', 'float32', 'float32')
+    check_dtypes_almost_equal('log_softmax', 1e-3, 1e-3, 1e-3, 1e-3,
+                              'float32', 'float64')
+    check_dtypes_almost_equal('log_softmax', 1e-3, 1e-3, 1e-3, 1e-3,
+                              'float32', 'float64', 'float64')
+
+@with_seed()
 def test_pick():
     def test_pick_helper(index_type=np.int32):
         for _ in range(100):