Posted to commits@mxnet.apache.org by ta...@apache.org on 2019/02/13 10:16:25 UTC

[incubator-mxnet] branch master updated: add mkldnn softmax_output (#13699)

This is an automated email from the ASF dual-hosted git repository.

taolv pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 45978a9  add mkldnn softmax_output (#13699)
45978a9 is described below

commit 45978a95d91634973003f7143974c615001cb75a
Author: rongzha1 <ro...@intel.com>
AuthorDate: Wed Feb 13 18:15:51 2019 +0800

    add mkldnn softmax_output (#13699)
    
    * add mkldnn softmax_output
    
    * fix gpu OP unittest error
    
    * fix ci/jenkins/mxnet-validation/unix-gpu compiler error
    
    * fix coding style
    
    * address Tao's comments
    
    * remove blank line, fix indent
    
    * modify according to Sandeep's comments
    
    * change the CPU engine getter method and private variable
    
    * move macro MXNET_USE_MKLDNN to the head
    
    * modify according to Tao's comments
    
    * make the output layout the same as the input layout
    
    * change API of GetSoftmaxOutputForward
    
    * add CommitOutput for mkldnn_softmax_output
    
    * trigger Jenkins re-test
    
    * add alias Softmax symbol for SoftmaxOutput OP
    
    * indent and remove blank line
---
 src/operator/nn/mkldnn/mkldnn_base-inl.h        |   2 +
 src/operator/nn/mkldnn/mkldnn_ops-inl.h         |  10 +-
 src/operator/nn/mkldnn/mkldnn_softmax_output.cc | 145 +++++++++++++++++++
 src/operator/softmax_output-inl.h               |  68 ++++++++-
 src/operator/softmax_output.cc                  | 181 ++++++++++++++++++++----
 src/operator/softmax_output.cu                  |  14 +-
 6 files changed, 380 insertions(+), 40 deletions(-)

diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h
index f770c4a..bf220b8 100644
--- a/src/operator/nn/mkldnn/mkldnn_base-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -174,11 +174,13 @@ struct ActivationParam;
 struct ConvolutionParam;
 struct DeconvolutionParam;
 struct SoftmaxParam;
+struct SoftmaxOutputParam;
 bool SupportMKLDNNAct(const ActivationParam& param);
 bool SupportMKLDNNAct(const ActivationParam& param, const NDArray &input);
 bool SupportMKLDNNConv(const ConvolutionParam& params, const NDArray &input);
 bool SupportMKLDNNDeconv(const DeconvolutionParam& params, const NDArray &input);
 bool SupportMKLDNNSoftmax(const SoftmaxParam& param);
+bool SupportMKLDNNSoftmaxOutput(const SoftmaxOutputParam &param);
 }  // namespace op
 
 static int GetTypeSize(int dtype) {
diff --git a/src/operator/nn/mkldnn/mkldnn_ops-inl.h b/src/operator/nn/mkldnn/mkldnn_ops-inl.h
index 5093770..39f2632 100644
--- a/src/operator/nn/mkldnn/mkldnn_ops-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_ops-inl.h
@@ -76,6 +76,12 @@ void MKLDNNSoftmaxForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                           const NDArray &in_data, const OpReqType &req,
                           const NDArray &out_data);
 
+/* For softmax_output */
+void MKLDNNSoftmaxOutputForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
+                                const std::vector<NDArray> &in_data,
+                                const std::vector<OpReqType> &req,
+                                const std::vector<NDArray> &out_data);
+
 /* For sum */
 void MKLDNNSumForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                       const std::vector<NDArray> &inputs, const OpReqType &req,
@@ -83,8 +89,8 @@ void MKLDNNSumForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
 
 /* For copy */
 void MKLDNNCopy(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
-    const NDArray &in_data, const OpReqType &req,
-    const NDArray &out_data);
+                const NDArray &in_data, const OpReqType &req,
+                const NDArray &out_data);
 
 /* For concat */
 void MKLDNNConcatForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
diff --git a/src/operator/nn/mkldnn/mkldnn_softmax_output.cc b/src/operator/nn/mkldnn/mkldnn_softmax_output.cc
new file mode 100644
index 0000000..ae34fe6
--- /dev/null
+++ b/src/operator/nn/mkldnn/mkldnn_softmax_output.cc
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file mkldnn_softmax_output.cc
+ * \brief integrate the MKLDNN softmax primitive into the softmax_output forward pass
+ * \author Zhang Rong A
+*/
+
+#if MXNET_USE_MKLDNN == 1
+#include "../../softmax_output-inl.h"
+#include "./mkldnn_ops-inl.h"
+#include "./mkldnn_base-inl.h"
+
+namespace mxnet {
+namespace op {
+
+static mkldnn::softmax_forward::primitive_desc GetSoftmaxOutputFwdDescImpl(
+               const SoftmaxOutputParam& param, bool is_train,
+               const int axis, const mkldnn::memory &input_mem) {
+  mkldnn::memory::primitive_desc data_mpd = input_mem.get_primitive_desc();
+  mkldnn::memory::desc data_md = data_mpd.desc();
+  auto cpu_engine = CpuEngine::Get()->get_engine();
+  auto prop = is_train ? mkldnn::prop_kind::forward_training
+                       : mkldnn::prop_kind::forward_scoring;
+  auto desc = mkldnn::softmax_forward::desc(prop, data_md, axis);
+  return mkldnn::softmax_forward::primitive_desc(desc, cpu_engine);
+}
+
+typedef ParamOpSign<SoftmaxOutputParam> MKLDNNSoftmaxOuputSignature;
+
+class MKLDNNSoftmaxOutputFwd {
+  std::shared_ptr<mkldnn::softmax_forward> fwd_;
+  std::shared_ptr<mkldnn::memory> data_;
+  std::shared_ptr<mkldnn::memory> out_;
+
+ public:
+  const mkldnn::softmax_forward::primitive_desc fwd_pd;
+
+  MKLDNNSoftmaxOutputFwd(const SoftmaxOutputParam& param, bool is_train,
+                         const int axis, const mkldnn::memory &mem): fwd_pd(
+                         GetSoftmaxOutputFwdDescImpl(param, is_train, axis, mem)) {
+  }
+
+  void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &output) {
+    if (this->data_ == nullptr)
+      this->data_ = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+        data.get_primitive_desc(), data.get_data_handle()));
+    else
+      this->data_->set_data_handle(data.get_data_handle());
+
+    if (this->out_ == nullptr)
+      this->out_ = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+        output.get_primitive_desc(), output.get_data_handle()));
+    else
+      this->out_->set_data_handle(output.get_data_handle());
+
+    if (this->fwd_ == nullptr) {
+      this->fwd_ = std::shared_ptr<mkldnn::softmax_forward>(
+        new mkldnn::softmax_forward(fwd_pd, mkldnn::primitive::at(*this->data_),
+        *this->out_));
+    }
+  }
+
+  const mkldnn::softmax_forward &GetFwd() const {
+    return *fwd_;
+  }
+};
+
+static MKLDNNSoftmaxOutputFwd &GetSoftmaxOutputForward(const SoftmaxOutputParam& param,
+                                                       const OpContext &ctx,
+                                                       const NDArray &in_data) {
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local
+    std::unordered_map<MKLDNNSoftmaxOuputSignature, MKLDNNSoftmaxOutputFwd, OpHash> fwds;
+#else
+  static MX_THREAD_LOCAL
+    std::unordered_map<MKLDNNSoftmaxOuputSignature, MKLDNNSoftmaxOutputFwd, OpHash> fwds;
+#endif
+  MKLDNNSoftmaxOuputSignature key(param);
+  key.AddSign(ctx.is_train);
+  key.AddSign(in_data);
+
+  // softmax_output has no axis parameter, so softmax over the last axis, as in the original implementation.
+  int axis = in_data.shape().ndim() - 1;
+
+  auto it = fwds.find(key);
+  if (it == fwds.end()) {
+    auto in_mem = *(in_data.GetMKLDNNData());
+    MKLDNNSoftmaxOutputFwd fwd(param, ctx.is_train, axis, in_mem);
+    it = AddToCache(&fwds, key, fwd);
+  }
+  return it->second;
+}
+
+// This is only used for the forward pass; compatibility for the backward pass still needs to be double-checked.
+bool SupportMKLDNNSoftmaxOutput(const SoftmaxOutputParam &param) {
+  return param.multi_output ? false : true;
+}
+
+void MKLDNNSoftmaxOutputForward(const nnvm::NodeAttrs& attrs,
+                                const OpContext &ctx,
+                                const std::vector<NDArray> &in_data,
+                                const std::vector<OpReqType> &req,
+                                const std::vector<NDArray> &out_data) {
+  const SoftmaxOutputParam &param = nnvm::get<SoftmaxOutputParam>(attrs.parsed);
+
+  NDArray idata = in_data[softmaxout_enum::kData];
+  NDArray odata = out_data[softmaxout_enum::kOut];
+  if (in_data[softmaxout_enum::kData].IsView() && in_data[softmaxout_enum::kData].IsMKLDNNData()) {
+    idata = in_data[softmaxout_enum::kData].Reorder2Default();
+  }
+
+  auto input_mem = idata.GetMKLDNNData();
+  auto out_mem = CreateMKLDNNMem(out_data[softmaxout_enum::kOut],
+                                 input_mem->get_primitive_desc(), req[softmaxout_enum::kOut]);
+
+  MKLDNNSoftmaxOutputFwd &fwd = GetSoftmaxOutputForward(param, ctx, idata);
+  fwd.SetNewMem(*input_mem, *out_mem.second);
+
+  MKLDNNStream *stream = MKLDNNStream::Get();
+  stream->RegisterPrim(fwd.GetFwd());
+
+  CommitOutput(out_data[softmaxout_enum::kOut], out_mem);
+  stream->Submit();
+}
+}   // namespace op
+}   // namespace mxnet
+#endif
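
For readers unfamiliar with the create-once/reuse pattern above, here is a minimal, self-contained sketch of how GetSoftmaxOutputForward caches forward primitives in a thread-local map. SoftmaxKey, SoftmaxFwd and main() are hypothetical stand-ins for illustration only; the real code keys on ParamOpSign<SoftmaxOutputParam> and caches MKLDNNSoftmaxOutputFwd objects.

// Simplified illustration of the thread-local primitive cache used by
// GetSoftmaxOutputForward. All names here are stand-ins, not MXNet code.
#include <cstddef>
#include <functional>
#include <iostream>
#include <unordered_map>

struct SoftmaxKey {                 // stands in for the op signature
  int axis;
  bool is_train;
  bool operator==(const SoftmaxKey &o) const {
    return axis == o.axis && is_train == o.is_train;
  }
};

struct SoftmaxKeyHash {             // stands in for OpHash
  size_t operator()(const SoftmaxKey &k) const {
    return std::hash<int>()(k.axis) * 31 + std::hash<bool>()(k.is_train);
  }
};

struct SoftmaxFwd {                 // stands in for MKLDNNSoftmaxOutputFwd
  explicit SoftmaxFwd(const SoftmaxKey &k) {
    std::cout << "building primitive for axis=" << k.axis << "\n";
  }
};

// Build the forward primitive on first use, then reuse it for identical keys.
SoftmaxFwd &GetCachedFwd(const SoftmaxKey &key) {
  static thread_local std::unordered_map<SoftmaxKey, SoftmaxFwd, SoftmaxKeyHash> fwds;
  auto it = fwds.find(key);
  if (it == fwds.end())
    it = fwds.emplace(key, SoftmaxFwd(key)).first;   // constructed once per key
  return it->second;
}

int main() {
  SoftmaxKey k{1, false};
  GetCachedFwd(k);   // builds the primitive
  GetCachedFwd(k);   // reuses the cached primitive, no rebuild
}

Because the map is thread_local, each worker thread builds its own primitives and no locking is needed around the cache.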
diff --git a/src/operator/softmax_output-inl.h b/src/operator/softmax_output-inl.h
index fec321b..5a01d3a 100644
--- a/src/operator/softmax_output-inl.h
+++ b/src/operator/softmax_output-inl.h
@@ -88,6 +88,17 @@ struct SoftmaxOutputParam : public dmlc::Parameter<SoftmaxOutputParam> {
               "one-hot encoding of the gold label and distributed uniformly to"
               "all other labels.");
   };
+
+  bool operator==(const SoftmaxOutputParam& other) const {
+    return this->grad_scale == other.grad_scale &&
+           this->ignore_label == other.ignore_label &&
+           this->multi_output == other.multi_output &&
+           this->use_ignore == other.use_ignore &&
+           this->preserve_shape == other.preserve_shape &&
+           this->normalization == other.normalization &&
+           this->out_grad == other.out_grad &&
+           this->smooth_alpha == other.smooth_alpha;
+  }
 };
 
 template<typename xpu, typename DType>
@@ -267,9 +278,43 @@ class SoftmaxOutputOp : public Operator {
   SoftmaxOutputParam param_;
 };  // class SoftmaxOutputOp
 
-// Decalre Factory function, used for dispatch specialization
 template<typename xpu>
-Operator* CreateOp(SoftmaxOutputParam param, int dtype);
+void SoftmaxOutputCompute(const nnvm::NodeAttrs& attrs,
+                          const OpContext& ctx, const std::vector<TBlob>& inputs,
+                          const std::vector<OpReqType>& req,
+                          const std::vector<TBlob>& outputs) {
+  const SoftmaxOutputParam &param = nnvm::get<SoftmaxOutputParam>(attrs.parsed);
+  const std::vector<TBlob> no_use_but_adapt_origin_api;
+  CHECK_EQ(inputs.size(), 2U);
+
+  MSHADOW_REAL_TYPE_SWITCH(inputs[softmaxout_enum::kData].type_flag_, DType, {
+    SoftmaxOutputOp<xpu, DType> op(param);
+    op.Forward(ctx, inputs, req, outputs, no_use_but_adapt_origin_api);
+  });
+}
+
+template<typename xpu>
+void SoftmaxOutputGradCompute(const nnvm::NodeAttrs& attrs,
+                              const OpContext& ctx,
+                              const std::vector<TBlob>& inputs,
+                              const std::vector<OpReqType>& req,
+                              const std::vector<TBlob>& outputs) {
+  const SoftmaxOutputParam& param = nnvm::get<SoftmaxOutputParam>(attrs.parsed);
+  const std::vector<TBlob> no_use_but_adapt_origin_api;
+  CHECK_EQ(inputs.size(), 2U);
+
+  std::vector<TBlob> out_grad{inputs[0]};
+  std::vector<TBlob> out_data{inputs[0]};
+  std::vector<TBlob> in_data(inputs.begin(), inputs.end());
+  int dtype = inputs[0].type_flag_;
+  const std::vector<TBlob> &in_grad = outputs;
+
+  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
+    SoftmaxOutputOp<xpu, DType> op(param);
+    op.Backward(ctx, out_grad, in_data, out_data, req, in_grad, no_use_but_adapt_origin_api);
+  });
+}
+
 
 #if DMLC_USE_CXX11
 class SoftmaxOutputProp : public OperatorProperty {
@@ -414,4 +459,23 @@ class DeprecatedSoftmaxProp : public SoftmaxOutputProp {
 
 }  // namespace op
 }  // namespace mxnet
+
+namespace std {
+template<>
+struct hash<mxnet::op::SoftmaxOutputParam> {
+  size_t operator()(const mxnet::op::SoftmaxOutputParam& val) {
+    size_t ret = 0;
+    ret = dmlc::HashCombine(ret, val.grad_scale);
+    ret = dmlc::HashCombine(ret, val.ignore_label);
+    ret = dmlc::HashCombine(ret, val.multi_output);
+    ret = dmlc::HashCombine(ret, val.use_ignore);
+    ret = dmlc::HashCombine(ret, val.preserve_shape);
+    ret = dmlc::HashCombine(ret, val.normalization);
+    ret = dmlc::HashCombine(ret, val.out_grad);
+    ret = dmlc::HashCombine(ret, val.smooth_alpha);
+    return ret;
+  }
+};
+}  // namespace std
+
 #endif  // MXNET_OPERATOR_SOFTMAX_OUTPUT_INL_H_
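
The operator== overload and the std::hash<SoftmaxOutputParam> specialization above exist so the parameter struct can participate in the primitive-cache key: equal params must compare equal and hash equal for lookups to hit the cache. A minimal sketch of that contract, assuming a simplified MiniParam and a local boost-style HashCombine in place of dmlc::HashCombine:

// Self-contained sketch of the param hashing/equality contract used for cache keys.
// MiniParam and this HashCombine are illustrative stand-ins only.
#include <cassert>
#include <cstddef>
#include <functional>

template <typename T>
size_t HashCombine(size_t seed, const T &v) {   // boost-style mixing, like dmlc::HashCombine
  return seed ^ (std::hash<T>()(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2));
}

struct MiniParam {
  float grad_scale;
  bool multi_output;
  bool operator==(const MiniParam &o) const {
    return grad_scale == o.grad_scale && multi_output == o.multi_output;
  }
};

namespace std {
template <>
struct hash<MiniParam> {
  size_t operator()(const MiniParam &p) const {
    size_t ret = 0;
    ret = HashCombine(ret, p.grad_scale);
    ret = HashCombine(ret, p.multi_output);
    return ret;
  }
};
}  // namespace std

int main() {
  MiniParam a{1.0f, false}, b{1.0f, false};
  // Equal params must hash equally, otherwise unordered_map lookups would miss the cache.
  assert(a == b && std::hash<MiniParam>()(a) == std::hash<MiniParam>()(b));
}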
diff --git a/src/operator/softmax_output.cc b/src/operator/softmax_output.cc
index 5ba421f..322ac0b 100644
--- a/src/operator/softmax_output.cc
+++ b/src/operator/softmax_output.cc
@@ -21,30 +21,137 @@
  * Copyright (c) 2015 by Contributors
  * \file softmax_output.cc
  * \brief
- * \author Bing Xu
+ * \author Bing Xu, Zhang Rong A
 */
 #include "./softmax_output-inl.h"
-
+#if MXNET_USE_MKLDNN == 1
+#include "./nn/mkldnn/mkldnn_ops-inl.h"
+#endif
 namespace mxnet {
 namespace op {
-template<>
-Operator *CreateOp<cpu>(SoftmaxOutputParam param, int dtype) {
-  Operator *op = nullptr;
-  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
-    op = new SoftmaxOutputOp<cpu, DType>(param);
-  })
-  return op;
+
+DMLC_REGISTER_PARAMETER(SoftmaxOutputParam);
+struct SoftmaxOutputGrad {
+  const char *op_name;
+  std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr& n,
+                                          const std::vector<nnvm::NodeEntry>& ograds) const {
+  std::vector<nnvm::NodeEntry> out_data(n->num_outputs());
+  for (uint32_t i = 0; i < out_data.size(); ++i) {
+    out_data[i] = nnvm::NodeEntry{n, i, 0};
+  }
+  std::vector<nnvm::NodeEntry> heads;
+  heads.push_back(out_data[softmaxout_enum::kOut]);
+  heads.push_back(n->inputs[softmaxout_enum::kLabel]);
+
+  nnvm::NodePtr gnode = nnvm::Node::Create();
+  gnode->inputs = std::move(heads);
+  gnode->control_deps.emplace_back(n);
+  gnode->attrs = n->attrs;
+  gnode->attrs.op = nnvm::Op::Get("_backward_SoftmaxOutput");
+  gnode->attrs.name = n->attrs.name + "_backward";
+  std::vector<nnvm::NodeEntry> in_grad(2);
+  in_grad[0] = nnvm::NodeEntry{gnode, 0, 0};
+  in_grad[1] = nnvm::NodeEntry{gnode, 1, 0};
+  return in_grad;
+  }
+};
+
+static inline std::vector<std::string> ListArguments() {
+  return {"data", "label"};
 }
 
-// DO_BIND_DISPATCH comes from operator_common.h
-Operator *SoftmaxOutputProp::CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
-                                     std::vector<int> *in_type) const {
-  DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]);
+static bool SoftmaxOutputType(const nnvm::NodeAttrs& attrs,
+                              std::vector<int> *in_type,
+                              std::vector<int> *out_type) {
+  CHECK_EQ(in_type->size(), 2U);
+  int dtype = (*in_type)[0];
+  CHECK_NE(dtype, -1) << "First input must have specified type";
+  for (size_t i = 0; i < in_type->size(); ++i) {
+    if ((*in_type)[i] == -1) {
+      (*in_type)[i] = dtype;
+    } else {
+      UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments()[i]);
+    }
+  }
+  out_type->clear();
+  out_type->push_back(dtype);
+  return true;
 }
 
-DMLC_REGISTER_PARAMETER(SoftmaxOutputParam);
+static bool SoftmaxOutputShape(const nnvm::NodeAttrs& attrs,
+                               std::vector<TShape> *in_shape,
+                               std::vector<TShape> *out_shape) {
+  using namespace mshadow;
+  const SoftmaxOutputParam& param = nnvm::get<SoftmaxOutputParam>(attrs.parsed);
+  CHECK_EQ(in_shape->size(), 2U) << "Input:[data, label]";
+  const TShape &dshape = in_shape->at(0);
+  if (dshape.ndim() == 0) return false;
 
-MXNET_REGISTER_OP_PROPERTY(SoftmaxOutput, SoftmaxOutputProp)
+  // label.shape == data.shape: use probability as label
+  if (dshape != (*in_shape)[softmaxout_enum::kLabel]) {
+    if (param.multi_output) {
+      TShape lshape1 = Shape2(dshape[0], dshape.Size()/dshape[0]/dshape[1]);
+      TShape lshape2(dshape.ndim() - 1);
+      lshape2[0] = dshape[0];
+      for (index_t i = 2; i < dshape.ndim(); ++i)
+        lshape2[i-1] = dshape[i];
+      TShape lshape3 = dshape;
+      lshape3[1] = 1;
+      if (in_shape->at(softmaxout_enum::kLabel).ndim() == 0) {
+        in_shape->at(softmaxout_enum::kLabel) = lshape1;
+      } else if (in_shape->at(softmaxout_enum::kLabel) == lshape1) {
+      } else if (in_shape->at(softmaxout_enum::kLabel) == lshape2) {
+      } else if (in_shape->at(softmaxout_enum::kLabel) == lshape3) {
+      } else {
+        std::ostringstream os;
+        os << "Expecting " << lshape1 << " or " << lshape2
+           << ". But got " << in_shape->at(softmaxout_enum::kLabel);
+        throw InferShapeError(os.str(), softmaxout_enum::kLabel);
+      }
+    } else {
+      TShape label_shape(dshape.ndim() - 1);
+      for (index_t i = 0; i + 1 < dshape.ndim(); ++i)
+        label_shape[i] = dshape[i];
+      SHAPE_ASSIGN_CHECK(*in_shape, softmaxout_enum::kLabel, label_shape);
+    }
+  }
+
+  out_shape->clear();
+  out_shape->push_back(dshape);
+  return true;
+}
+
+#if MXNET_USE_MKLDNN == 1
+inline static bool SoftmaxOutputStorageType(const nnvm::NodeAttrs& attrs,
+                                            const int dev_mask,
+                                            DispatchMode* dispatch_mode,
+                                            std::vector<int>* in_attrs,
+                                            std::vector<int>* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 2);
+  CHECK_EQ(out_attrs->size(), 1);
+
+  return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs,
+                           out_attrs);
+}
+
+void SoftmaxOutputComputeExCPU(const nnvm::NodeAttrs &attrs,
+                               const OpContext &ctx,
+                               const std::vector<NDArray> &inputs,
+                               const std::vector<OpReqType> &req,
+                               const std::vector<NDArray> &outputs) {
+  CHECK_EQ(inputs.size(), 2U);
+  const SoftmaxOutputParam &param = nnvm::get<SoftmaxOutputParam>(attrs.parsed);
+  if (SupportMKLDNN(inputs[0]) && !ctx.is_train && SupportMKLDNNSoftmaxOutput(param)) {
+    MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
+    MKLDNNSoftmaxOutputForward(attrs, ctx, inputs, req, outputs);
+    MKLDNN_OPCHECK_RUN(SoftmaxOutputCompute<cpu>, attrs, ctx, inputs, req, outputs);
+    return;
+  }
+  FallBackCompute(SoftmaxOutputCompute<cpu>, attrs, ctx, inputs, req, outputs);
+}
+#endif
+
+NNVM_REGISTER_OP(SoftmaxOutput)
 .describe(R"code(Computes the gradient of cross entropy loss with respect to softmax output.
 
 - This operator computes the gradient in two steps.
@@ -121,23 +228,41 @@ MXNET_REGISTER_OP_PROPERTY(SoftmaxOutput, SoftmaxOutputProp)
     - ``'valid'``: divide the gradient by the number of instances which are not ignored.
 
 )code" ADD_FILELINE)
+.set_num_inputs(2)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<SoftmaxOutputParam>)
+#if MXNET_USE_MKLDNN == 1
+.set_attr<FInferStorageType>("FInferStorageType", SoftmaxOutputStorageType)
+.set_attr<bool>("TIsMKLDNN", true)
+.set_attr<FComputeEx>("FComputeEx<cpu>", SoftmaxOutputComputeExCPU)
+#endif
+.set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
+  return std::vector<std::string>{"data", "label"};
+})
+.set_attr<nnvm::FListOutputNames>("FListOutputNames", [](const NodeAttrs& attrs) {
+  return std::vector<std::string>{"output"};
+})
+.set_attr<nnvm::FInferShape>("FInferShape", SoftmaxOutputShape)
+.set_attr<nnvm::FInferType>("FInferType", SoftmaxOutputType)
+.set_attr<FCompute>("FCompute<cpu>", SoftmaxOutputCompute<cpu>)
+.set_attr<nnvm::FGradient>("FGradient", SoftmaxOutputGrad{"_backward_SoftmaxOutput"})
+.set_attr<nnvm::FInplaceOption>("FInplaceOption", [](const NodeAttrs& attrs){
+  return std::vector<std::pair<int, int> >{{0, 0}};
+})
 .add_argument("data", "NDArray-or-Symbol", "Input array.")
 .add_argument("label", "NDArray-or-Symbol", "Ground truth label.")
 .add_arguments(SoftmaxOutputParam::__FIELDS__());
 
+// The Softmax symbol was renamed to SoftmaxOutput and has been deprecated since Dec 2015
+NNVM_REGISTER_OP(SoftmaxOutput).add_alias("Softmax");
 
-MXNET_REGISTER_OP_PROPERTY(Softmax, DeprecatedSoftmaxProp)
-.describe(R"code(Please use `SoftmaxOutput`.
-
-.. note::
-
-  This operator has been renamed to `SoftmaxOutput`, which
-  computes the gradient of cross-entropy loss w.r.t softmax output.
-  To just compute softmax output, use the `softmax` operator.
-
-)code" ADD_FILELINE)
-.add_argument("data", "NDArray-or-Symbol", "Input array.")
-.add_arguments(SoftmaxOutputParam::__FIELDS__());
-
+NNVM_REGISTER_OP(_backward_SoftmaxOutput)
+.set_num_outputs(2)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption", [](const NodeAttrs& attrs){
+  return std::vector<std::pair<int, int> >{{0, 0}};
+})
+.set_attr_parser(ParamParser<SoftmaxOutputParam>)
+.set_attr<FCompute>("FCompute<cpu>", SoftmaxOutputGradCompute<cpu>);
 }  // namespace op
 }  // namespace mxnet
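
SoftmaxOutputComputeExCPU above takes the MKLDNN path only when the input is MKLDNN-compatible, the pass is inference (!ctx.is_train), and multi_output is off; every other case falls back to the generic CPU kernel via FallBackCompute. A condensed, self-contained sketch of that dispatch (Param and the two kernels here are stand-ins, not MXNet code):

// Minimal sketch of the MKLDNN-or-fallback dispatch in SoftmaxOutputComputeExCPU.
#include <iostream>

struct Param { bool multi_output; };

bool SupportMKLDNNSoftmaxOutputSketch(const Param &p) { return !p.multi_output; }

void MKLDNNForward()   { std::cout << "MKLDNN softmax_output forward\n"; }
void FallbackForward() { std::cout << "generic CPU softmax_output forward\n"; }

void ComputeExCPU(const Param &param, bool is_train, bool input_is_mkldnn_ok) {
  // Fast path: inference only, single-output mode, MKLDNN-compatible input.
  if (input_is_mkldnn_ok && !is_train && SupportMKLDNNSoftmaxOutputSketch(param)) {
    MKLDNNForward();
    return;
  }
  FallbackForward();   // training, multi_output, or unsupported layout/dtype
}

int main() {
  ComputeExCPU({false}, /*is_train=*/false, /*input_is_mkldnn_ok=*/true);  // MKLDNN path
  ComputeExCPU({true},  /*is_train=*/false, /*input_is_mkldnn_ok=*/true);  // fallback
}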
diff --git a/src/operator/softmax_output.cu b/src/operator/softmax_output.cu
index afcc8f4..b2a4167 100644
--- a/src/operator/softmax_output.cu
+++ b/src/operator/softmax_output.cu
@@ -28,14 +28,12 @@
 
 namespace mxnet {
 namespace op {
-template<>
-Operator *CreateOp<gpu>(SoftmaxOutputParam param, int dtype) {
-  Operator *op = NULL;
-  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
-    op = new SoftmaxOutputOp<gpu, DType>(param);
-  })
-  return op;
-}
+
+NNVM_REGISTER_OP(SoftmaxOutput)
+.set_attr<FCompute>("FCompute<gpu>", SoftmaxOutputCompute<gpu>);
+
+NNVM_REGISTER_OP(_backward_SoftmaxOutput)
+.set_attr<FCompute>("FCompute<gpu>", SoftmaxOutputGradCompute<gpu>);
 
 }  // namespace op
 }  // namespace mxnet