You are viewing a plain-text version of this content. Its canonical link was a hyperlink that this plain-text rendering does not preserve.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2018/03/22 05:53:36 UTC

[incubator-mxnet] branch master updated: [MXNET-105] Fix CuDNN performance after code refactor (#10116)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 46e47cb  [MXNET-105] Fix CuDNN performance after code refactor (#10116)
46e47cb is described below

commit 46e47cbc6183d2812a2e405851f0b209383e72ad
Author: Da Zheng <zh...@gmail.com>
AuthorDate: Wed Mar 21 22:53:32 2018 -0700

    [MXNET-105] Fix CuDNN performance after code refactor (#10116)
    
    * Reduce #inputs/outputs of batchnorm backward.
    
    * Pass more arrays to BN.
    
    * Make std::vector thread local.
    
    * Set inputs of BN backward for other cases.
    
    * Fix for other cases.
    
    * Remove commented-out code.
    
    * Fix a potential memory leak.
    
    * Fix a compile error in mkldnn.
    
    * Fix an error.
    
    * Reserve space for std::vector.
    
    * Fix alignment.
    
    * Fix cpp unit test.
    
    * Fix BN CPP unit tests.
    
    * Fix a compile error.
    
    * Fix compilation error.
    
    * Move Op signature.
    
    * Cache CuDNN conv op.
    
    * Fix compile error.
    
    * Fix compile error.
    
    * Remove thread_local.
    
    * Reduce mem alloc when caching cudnn conv.
    
    * Fix a lint error.
    
    * Cache CuDNN deconv.
    
    * Fix lint error.
---
 src/operator/nn/batch_norm-inl.h               |  47 +++++-----
 src/operator/nn/batch_norm.cc                  |  74 +++++++++++----
 src/operator/nn/batch_norm.cu                  |  18 ++--
 src/operator/nn/convolution-inl.h              |   2 +
 src/operator/nn/convolution.cu                 |  37 +++++++-
 src/operator/nn/cudnn/cudnn_batch_norm-inl.h   |  51 +++++-----
 src/operator/nn/deconvolution-inl.h            |   2 +
 src/operator/nn/deconvolution.cu               |  32 ++++++-
 src/operator/nn/mkldnn/mkldnn_act.cc           |   4 +-
 src/operator/nn/mkldnn/mkldnn_base-inl.h       | 105 ---------------------
 src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h |   6 +-
 src/operator/nn/mkldnn/mkldnn_convolution.cc   |   4 +-
 src/operator/nn/mkldnn/mkldnn_deconvolution.cc |   8 +-
 src/operator/nn/mkldnn/mkldnn_pooling-inl.h    |   2 +-
 src/operator/nn/mkldnn/mkldnn_pooling.cc       |   2 +-
 src/operator/operator_common.h                 | 124 +++++++++++++++++++++++++
 tests/cpp/include/test_core_op.h               |   7 +-
 tests/cpp/operator/batchnorm_test.cc           |  11 +--
 18 files changed, 319 insertions(+), 217 deletions(-)

diff --git a/src/operator/nn/batch_norm-inl.h b/src/operator/nn/batch_norm-inl.h
index 48638de..3f47d58 100644
--- a/src/operator/nn/batch_norm-inl.h
+++ b/src/operator/nn/batch_norm-inl.h
@@ -224,16 +224,25 @@ void BatchNormForward(const OpContext &ctx, const BatchNormParam& param,
  */
 template <typename xpu, typename DType, typename AccReal>
 void BatchNormBackward(const OpContext &ctx, const BatchNormParam& param,
-                       const std::vector<TBlob> &out_grad,
-                       const std::vector<TBlob> &in_data,
-                       const std::vector<TBlob> &out_data,
+                       const std::vector<TBlob> &inputs,
                        const std::vector<OpReqType> &req,
-                       const std::vector<TBlob> &in_grad,
-                       const std::vector<TBlob> &aux_states) {
-  CHECK_EQ(out_grad.size(), param.output_mean_var ? 3U : 1U);
-  CHECK_EQ(in_data.size(), 3U);
-  CHECK_EQ(out_data.size(), 3U);
-  CHECK_EQ(in_grad.size(), 3U);
+                       const std::vector<TBlob> &outputs) {
+  CHECK_EQ(inputs.size(), 8U);
+  CHECK_EQ(outputs.size(), 3U);
+  std::vector<TBlob> out_grad(1);
+  std::vector<TBlob> out_data(3);
+  std::vector<TBlob> in_data(3);
+  std::vector<TBlob> aux_states(2);
+
+  out_grad[0] = inputs[0];
+  out_data[batchnorm::kMean] = inputs[1];
+  out_data[batchnorm::kVar] = inputs[2];
+  in_data[batchnorm::kData] = inputs[3];
+  in_data[batchnorm::kGamma] = inputs[4];
+  in_data[batchnorm::kBeta] = inputs[5];
+  aux_states[batchnorm::kMovingMean] = inputs[6];
+  aux_states[batchnorm::kMovingVar] = inputs[7];
+  const std::vector<TBlob> &in_grad = outputs;
   mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
   BatchNormBackwardImpl<xpu, DType, AccReal>(s, ctx, param, out_grad, in_data,
                                              out_data, req, in_grad, aux_states);
@@ -261,23 +270,11 @@ void BatchNormGradCompute(const nnvm::NodeAttrs& attrs,
                           const OpContext& ctx, const std::vector<TBlob>& inputs,
                           const std::vector<OpReqType>& req,
                           const std::vector<TBlob>& outputs) {
-  CHECK_EQ(inputs.size(), 11U);
+  CHECK_EQ(inputs.size(), 8U);
   const BatchNormParam& param = nnvm::get<BatchNormParam>(attrs.parsed);
-  int num_out_grads = param.output_mean_var ? 3U : 1U;
-  int in_data_start = 3;
-  int aux_states_start = in_data_start + batchnorm::kInMovingMean;
-  int out_data_start = in_data_start + batchnorm::kInMovingVar + 1;
-  std::vector<TBlob> out_grad(inputs.begin(), inputs.begin() + num_out_grads);
-  std::vector<TBlob> in_data(inputs.begin() + in_data_start,
-                             inputs.begin() + aux_states_start);
-  std::vector<TBlob> aux_states(inputs.begin() + aux_states_start,
-                                inputs.begin() + out_data_start);
-  std::vector<TBlob> out_data(inputs.begin() + out_data_start, inputs.end());
-  std::vector<TBlob> in_grad(outputs.begin(), outputs.begin() + 3);
-
-  MSHADOW_REAL_TYPE_SWITCH_EX(out_grad[0].type_flag_, DType, AccReal, {
-    BatchNormBackward<xpu, DType, AccReal>(ctx, param, out_grad, in_data, out_data, req,
-                                           in_grad, aux_states);
+
+  MSHADOW_REAL_TYPE_SWITCH_EX(inputs[0].type_flag_, DType, AccReal, {
+    BatchNormBackward<xpu, DType, AccReal>(ctx, param, inputs, req, outputs);
   });
 }
 
diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index c8b5d58..457f536 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -413,24 +413,26 @@ void BatchNormGradComputeExCPU(const nnvm::NodeAttrs &attrs,
                                const std::vector<NDArray> &inputs,
                                const std::vector<OpReqType> &req,
                                const std::vector<NDArray> &outputs) {
-  CHECK_EQ(inputs.size(), 11U);
+  CHECK_EQ(inputs.size(), 8U);
   const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
-  int num_out_grads = param.output_mean_var ? 3U : 1U;
-  int in_data_start = 3;
-  int aux_states_start = in_data_start + batchnorm::kInMovingMean;
-  int out_data_start = in_data_start + batchnorm::kInMovingVar + 1;
 
   TShape shape = inputs[0].shape();
   // MKLDNN batchnorm only works well on the special MKLDNN layout.
   if (SupportMKLDNNBN(inputs[0], param)
-      && (inputs[in_data_start].IsMKLDNNData() || inputs[0].IsMKLDNNData())) {
-    std::vector<NDArray> out_grad(inputs.begin(), inputs.begin() + num_out_grads);
-    std::vector<NDArray> in_data(inputs.begin() + in_data_start,
-                                 inputs.begin() + aux_states_start);
-    std::vector<NDArray> aux_states(inputs.begin() + aux_states_start,
-                                    inputs.begin() + out_data_start);
-    std::vector<NDArray> out_data(inputs.begin() + out_data_start, inputs.end());
-    std::vector<NDArray> in_grad(outputs.begin(), outputs.begin() + 3);
+      && (inputs[3].IsMKLDNNData() || inputs[0].IsMKLDNNData())) {
+    std::vector<NDArray> out_grad(1);
+    std::vector<NDArray> out_data(3);
+    std::vector<NDArray> in_data(3);
+    std::vector<NDArray> aux_states(2);
+    out_grad[0] = inputs[0];
+    out_data[batchnorm::kMean] = inputs[1];
+    out_data[batchnorm::kVar] = inputs[2];
+    in_data[batchnorm::kData] = inputs[3];
+    in_data[batchnorm::kGamma] = inputs[4];
+    in_data[batchnorm::kBeta] = inputs[5];
+    aux_states[batchnorm::kMovingMean] = inputs[6];
+    aux_states[batchnorm::kMovingVar] = inputs[7];
+    const std::vector<NDArray> &in_grad = outputs;
 
     if (inputs[0].dtype() == mshadow::kFloat32) {
       MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
@@ -470,8 +472,6 @@ static inline bool backward_BatchNormStorageType(const nnvm::NodeAttrs &attrs,
                                                  DispatchMode *dispatch_mode,
                                                  std::vector<int> *in_attrs,
                                                  std::vector<int> *out_attrs) {
-  CHECK_EQ(in_attrs->size(), 11);
-  CHECK_EQ(out_attrs->size(), 5);
   DispatchMode wanted_mode;
 #if MXNET_USE_MKLDNN == 1
   if (dev_mask == mshadow::cpu::kDevMask)
@@ -486,6 +486,46 @@ static inline bool backward_BatchNormStorageType(const nnvm::NodeAttrs &attrs,
                              dispatch_mode, wanted_mode);
 }
 
+std::vector<nnvm::NodeEntry> BatchNormGrad(const nnvm::NodePtr& n,
+                                           const std::vector<nnvm::NodeEntry>& ograds) {
+  std::vector<nnvm::NodeEntry> out_data(n->num_outputs());
+  for (uint32_t i = 0; i < out_data.size(); ++i) {
+    out_data[i] = nnvm::NodeEntry{n, i, 0};
+  }
+  std::vector<nnvm::NodeEntry> heads;
+  heads.reserve(8);
+  heads.push_back(ograds[0]);
+  heads.push_back(out_data[batchnorm::kMean]);
+  heads.push_back(out_data[batchnorm::kVar]);
+  heads.push_back(n->inputs[batchnorm::kData]);
+  heads.push_back(n->inputs[batchnorm::kGamma]);
+  heads.push_back(n->inputs[batchnorm::kBeta]);
+  heads.push_back(n->inputs[batchnorm::kInMovingMean]);
+  heads.push_back(n->inputs[batchnorm::kInMovingVar]);
+
+  nnvm::NodePtr gnode = nnvm::Node::Create();
+  gnode->inputs = std::move(heads);
+  gnode->control_deps.emplace_back(n);
+  gnode->attrs = n->attrs;
+  gnode->attrs.op = nnvm::Op::Get("_backward_BatchNorm");
+  gnode->attrs.name = n->attrs.name + "_backward";
+  // The input of batchnorm
+  std::vector<nnvm::NodeEntry> in_grad(5);
+  for (uint32_t i = 0; i < 3; ++i) {
+    in_grad[i] = nnvm::NodeEntry{gnode, i, 0};
+  }
+
+  // attach no gradient node to forbid gradient on aux_state
+  nnvm::NodePtr ng = nnvm::Node::Create();
+  ng->attrs.op = Op::Get("_NoGradient");
+  ng->attrs.name = "NoGradient";
+  // the aux state of batchnorm
+  for (uint32_t i = 0; i < 2; ++i) {
+    in_grad[i + 3] = nnvm::NodeEntry{ng, 0, 0};
+  }
+  return in_grad;
+}
+
 NNVM_REGISTER_OP(BatchNorm)
 .describe(R"code(Batch normalization.
 
@@ -559,7 +599,7 @@ then set ``gamma`` to 1 and its gradient to 0.
 #if MXNET_USE_MKLDNN == 1
 .set_attr<FComputeEx>("FComputeEx<cpu>", BatchNormComputeExCPU)
 #endif
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseInOut{"_backward_BatchNorm"})
+.set_attr<nnvm::FGradient>("FGradient", BatchNormGrad)
 #if MXNET_USE_MKLDNN == 1
 .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
   return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
@@ -583,7 +623,7 @@ then set ``gamma`` to 1 and its gradient to 0.
   });
 
 NNVM_REGISTER_OP(_backward_BatchNorm)
-.set_num_outputs(5)
+.set_num_outputs(3)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
 .set_attr<FInferStorageType>("FInferStorageType", backward_BatchNormStorageType)
 #if MXNET_USE_MKLDNN == 1
diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu
index b8657fc..c310a93 100644
--- a/src/operator/nn/batch_norm.cu
+++ b/src/operator/nn/batch_norm.cu
@@ -690,13 +690,8 @@ void BatchNormGradCompute<gpu>(const nnvm::NodeAttrs& attrs,
                                const OpContext& ctx, const std::vector<TBlob>& inputs,
                                const std::vector<OpReqType>& req,
                                const std::vector<TBlob>& outputs) {
-  CHECK_EQ(inputs.size(), 11U);
+  CHECK_EQ(inputs.size(), 8U);
   BatchNormParam param = nnvm::get<BatchNormParam>(attrs.parsed);
-  std::vector<TBlob> out_grad(1, inputs[0]);
-  std::vector<TBlob> in_data(inputs.begin() + 3, inputs.begin() + 6);
-  std::vector<TBlob> aux_states(inputs.begin() + 6, inputs.begin() + 8);
-  std::vector<TBlob> out_data(inputs.begin() + 8, inputs.end());
-  std::vector<TBlob> in_grad(outputs.begin(), outputs.begin() + 3);
   int dtype = inputs[0].type_flag_;
   TShape shape = inputs[0].shape_;
 
@@ -705,19 +700,16 @@ void BatchNormGradCompute<gpu>(const nnvm::NodeAttrs& attrs,
   if (!param.use_global_stats && !param.cudnn_off && shape.ndim() <= 4
       && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS) {
     MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
-      GetCuDNNOp<DType>(param).Backward(ctx, out_grad, in_data, out_data,
-        req, in_grad, aux_states);
+      GetCuDNNOp<DType>(param).Backward(ctx, inputs, req, outputs);
     })
   } else {
     MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DType, AccReal, {
-      BatchNormBackward<gpu, DType, AccReal>(ctx, param, out_grad,
-          in_data, out_data, req, in_grad, aux_states);
+      BatchNormBackward<gpu, DType, AccReal>(ctx, param, inputs, req, outputs);
     })
   }
 #else
-  MSHADOW_REAL_TYPE_SWITCH_EX(out_grad[0].type_flag_, DType, AccReal, {
-    BatchNormBackward<gpu, DType, AccReal>(ctx, param, out_grad,
-        in_data, out_data, req, in_grad, aux_states);
+  MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DType, AccReal, {
+    BatchNormBackward<gpu, DType, AccReal>(ctx, param, inputs, req, outputs);
   });
 #endif
 }
diff --git a/src/operator/nn/convolution-inl.h b/src/operator/nn/convolution-inl.h
index d0dd7dd..c98a010 100644
--- a/src/operator/nn/convolution-inl.h
+++ b/src/operator/nn/convolution-inl.h
@@ -124,6 +124,8 @@ struct ConvolutionParam : public dmlc::Parameter<ConvolutionParam> {
   }
 };
 
+typedef ParamOpSign<ConvolutionParam> ConvSignature;
+
 }  // namespace op
 }  // namespace mxnet
 
diff --git a/src/operator/nn/convolution.cu b/src/operator/nn/convolution.cu
index d7f9e56..f6d14e3 100644
--- a/src/operator/nn/convolution.cu
+++ b/src/operator/nn/convolution.cu
@@ -41,13 +41,40 @@ static CuDNNConvolutionOp<DType> &GetCuDNNConvOp(const ConvolutionParam& param,
     const std::vector<TShape>& in_shape, const std::vector<TShape>& out_shape,
     const Context& ctx) {
 #if DMLC_CXX11_THREAD_LOCAL
-  static thread_local CuDNNConvolutionOp<DType> op;
+  static thread_local std::unordered_map<ConvSignature,
+                                         std::shared_ptr<CuDNNConvolutionOp<DType> >,
+                                         OpHash> ops;
 #else
-  static MX_THREAD_LOCAL CuDNNConvolutionOp<DType> op;
+  static MX_THREAD_LOCAL std::unordered_map<ConvSignature,
+                                            std::shared_ptr<CuDNNConvolutionOp<DType> >,
+                                            OpHash> ops;
 #endif
-  op.Init(param, forward_compute_type, backward_compute_type,
-      in_shape, out_shape, ctx);
-  return op;
+  ConvSignature key(param);
+  size_t ndim = 0;
+  for (auto &s : in_shape)
+    ndim += s.ndim();
+  for (auto &s : out_shape)
+    ndim += s.ndim();
+  key.Reserve(1 /* for forward_compute_type */ + 1 /* for backward_compute_type */
+              + ndim + 1 /* for dev_id */);
+
+  key.AddSign(forward_compute_type);
+  key.AddSign(backward_compute_type);
+  key.AddSign(in_shape);
+  key.AddSign(out_shape);
+  key.AddSign(ctx.dev_id);
+
+  auto it = ops.find(key);
+  if (it == ops.end()) {
+    std::shared_ptr<CuDNNConvolutionOp<DType>> op(new CuDNNConvolutionOp<DType>());
+    auto ins_ret = ops.insert(std::pair<ConvSignature, std::shared_ptr<CuDNNConvolutionOp<DType>>>(
+                              key, op));
+    CHECK(ins_ret.second);
+    it = ins_ret.first;
+    it->second->Init(param, forward_compute_type, backward_compute_type, in_shape,
+                     out_shape, ctx);
+  }
+  return *it->second;
 }
 #endif
 
diff --git a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h
index e233704..e3d5dd9 100644
--- a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h
+++ b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h
@@ -67,10 +67,10 @@ class CuDNNBatchNormOp {
   }
 
   void Forward(const OpContext &ctx,
-                       const std::vector<TBlob> &in_data,
-                       const std::vector<OpReqType> &req,
-                       const std::vector<TBlob> &out_data,
-                       const std::vector<TBlob> &aux_states) {
+               const std::vector<TBlob> &in_data,
+               const std::vector<OpReqType> &req,
+               const std::vector<TBlob> &out_data,
+               const std::vector<TBlob> &aux_states) {
     using namespace mshadow;
     using namespace mshadow::expr;
     CHECK_EQ(in_data.size(), 3U);
@@ -158,29 +158,30 @@ class CuDNNBatchNormOp {
   }
 
   void Backward(const OpContext &ctx,
-                        const std::vector<TBlob> &out_grad,
-                        const std::vector<TBlob> &in_data,
-                        const std::vector<TBlob> &out_data,
-                        const std::vector<OpReqType> &req,
-                        const std::vector<TBlob> &in_grad,
-                        const std::vector<TBlob> &aux_states) {
+                const std::vector<TBlob> &inputs,
+                const std::vector<OpReqType> &req,
+                const std::vector<TBlob> &outputs) {
     using namespace mshadow;
     using namespace mshadow::expr;
-    CHECK_EQ(out_grad.size(), 1U);
-    CHECK_EQ(in_data.size(), 3U);
-    CHECK_EQ(out_data.size(), 3U);
-    CHECK_EQ(in_grad.size(), 3U);
+    CHECK_EQ(inputs.size(), 8U);
+    CHECK_EQ(outputs.size(), 3U);
     CHECK(ctx.is_train && !param_.use_global_stats)
         << "use global statistics is not yet supported in CuDNNBatchNorm";
 
-    Init(in_data[cudnnbatchnorm::kData]);
+    // Rename the inputs and outputs.
+    const TBlob &out_grad = inputs[0];
+    const TBlob &out_mean = inputs[1];
+    const TBlob &out_var = inputs[2];
+    const TBlob &in_data = inputs[3];
+    const TBlob &in_gamma = inputs[4];
+    const std::vector<TBlob> &in_grad = outputs;
+
+    Init(in_data);
     Stream<gpu> *s = ctx.get_stream<gpu>();
-    Tensor<gpu, 4, DType> x =
-      in_data[cudnnbatchnorm::kData].get_with_shape<gpu, 4, DType>(shape_, s);
+    Tensor<gpu, 4, DType> x = in_data.get_with_shape<gpu, 4, DType>(shape_, s);
     Tensor<gpu, 4, DType> dx =
       in_grad[cudnnbatchnorm::kData].get_with_shape<gpu, 4, DType>(shape_, s);
-    Tensor<gpu, 4, DType> dy =
-      out_grad[cudnnbatchnorm::kOut].get_with_shape<gpu, 4, DType>(shape_, s);
+    Tensor<gpu, 4, DType> dy = out_grad.get_with_shape<gpu, 4, DType>(shape_, s);
 
 #if CUDNN_VERSION >= 4007
 #if CUDNN_VERSION >= 7002
@@ -190,15 +191,15 @@ class CuDNNBatchNormOp {
 #endif
     MSHADOW_REAL_TYPE_SWITCH(dtype_param_, DTypeParam, {
       Tensor<gpu, 1, DTypeParam> gamma =
-        in_data[cudnnbatchnorm::kGamma].get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
+        in_gamma.get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
       Tensor<gpu, 1, DTypeParam> dbeta =
         in_grad[cudnnbatchnorm::kBeta].get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
       Tensor<gpu, 1, DTypeParam> dgamma =
         in_grad[cudnnbatchnorm::kGamma].get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
       Tensor<gpu, 1, DTypeParam> save_mean =
-        out_data[cudnnbatchnorm::kMean].get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
+        out_mean.get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
       Tensor<gpu, 1, DTypeParam> save_inv_var =
-        out_data[cudnnbatchnorm::kInvVar].get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
+        out_var.get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
 
       typename DataType<DType>::ScaleType a = 1.0f;
       typename DataType<DType>::ScaleType b = 0.0f;
@@ -232,15 +233,15 @@ class CuDNNBatchNormOp {
 #else  // CUDNN_VERSION < 4007
     MSHADOW_REAL_TYPE_SWITCH(dtype_param_, DTypeParam, {
       Tensor<gpu, 1, DTypeParam> gamma =
-        in_data[cudnnbatchnorm::kGamma].get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
+        in_gamma.get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
       Tensor<gpu, 1, DTypeParam> dbeta =
         in_grad[cudnnbatchnorm::kBeta].get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
       Tensor<gpu, 1, DTypeParam> dgamma =
         in_grad[cudnnbatchnorm::kGamma].get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
       Tensor<gpu, 1, DTypeParam> save_mean =
-        out_data[cudnnbatchnorm::kMean].get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
+        out_mean.get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
       Tensor<gpu, 1, DTypeParam> save_inv_var =
-        out_data[cudnnbatchnorm::kInvVar].get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
+        out_var.get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
 
       typename DataType<DType>::ScaleType a = 1.0f;
       typename DataType<DType>::ScaleType b = 0.0f;
diff --git a/src/operator/nn/deconvolution-inl.h b/src/operator/nn/deconvolution-inl.h
index badbb8b..b41ecf4 100644
--- a/src/operator/nn/deconvolution-inl.h
+++ b/src/operator/nn/deconvolution-inl.h
@@ -169,6 +169,8 @@ struct DeconvolutionParam : public dmlc::Parameter<DeconvolutionParam> {
   }
 };
 
+typedef ParamOpSign<DeconvolutionParam> DeconvSignature;
+
 }  // namespace op
 }  // namespace mxnet
 
diff --git a/src/operator/nn/deconvolution.cu b/src/operator/nn/deconvolution.cu
index c739542..086b470 100644
--- a/src/operator/nn/deconvolution.cu
+++ b/src/operator/nn/deconvolution.cu
@@ -40,9 +40,35 @@ static CuDNNDeconvolutionOp<DType> &GetCuDNNDeconvOp(const DeconvolutionParam& p
                                                      const std::vector<TShape>& in_shape,
                                                      const std::vector<TShape>& out_shape,
                                                      const Context& ctx) {
-  static thread_local CuDNNDeconvolutionOp<DType> op;
-  op.Init(param, forward_compute_type, backward_compute_type, in_shape, out_shape, ctx);
-  return op;
+  static thread_local std::unordered_map<DeconvSignature,
+                                         std::shared_ptr<CuDNNDeconvolutionOp<DType> >,
+                                         OpHash> ops;
+  DeconvSignature key(param);
+  size_t ndim = 0;
+  for (auto &s : in_shape)
+    ndim += s.ndim();
+  for (auto &s : out_shape)
+    ndim += s.ndim();
+  key.Reserve(1 /* for forward_compute_type */ + 1 /* for backward_compute_type */
+              + ndim + 1 /* for dev_id */);
+
+  key.AddSign(forward_compute_type);
+  key.AddSign(backward_compute_type);
+  key.AddSign(in_shape);
+  key.AddSign(out_shape);
+  key.AddSign(ctx.dev_id);
+
+  auto it = ops.find(key);
+  if (it == ops.end()) {
+    std::shared_ptr<CuDNNDeconvolutionOp<DType>> op(new CuDNNDeconvolutionOp<DType>());
+    auto ins_ret = ops.insert(
+            std::pair<DeconvSignature, std::shared_ptr<CuDNNDeconvolutionOp<DType>>>(key, op));
+    CHECK(ins_ret.second);
+    it = ins_ret.first;
+    it->second->Init(param, forward_compute_type, backward_compute_type, in_shape,
+                     out_shape, ctx);
+  }
+  return *it->second;
 }
 #endif
 
diff --git a/src/operator/nn/mkldnn/mkldnn_act.cc b/src/operator/nn/mkldnn/mkldnn_act.cc
index 71fdf4c..8c19850 100644
--- a/src/operator/nn/mkldnn/mkldnn_act.cc
+++ b/src/operator/nn/mkldnn/mkldnn_act.cc
@@ -93,7 +93,7 @@ static mkldnn::eltwise_forward::primitive_desc GetActFwdDescImpl(
   return mkldnn::eltwise_forward::primitive_desc(desc, cpu_engine);
 }
 
-typedef MKLDNNParamOpSign<ActivationParam> MKLDNNActSignature;
+typedef ParamOpSign<ActivationParam> MKLDNNActSignature;
 
 class MKLDNNActForward {
   std::shared_ptr<mkldnn::eltwise_forward> fwd;
@@ -137,7 +137,7 @@ class MKLDNNActForward {
 static MKLDNNActForward &GetActForward(const ActivationParam& param,
                                        const OpContext &ctx, const NDArray &in_data,
                                        const mkldnn::memory &in_mem) {
-  static thread_local std::unordered_map<MKLDNNActSignature, MKLDNNActForward, MKLDNNOpHash> fwds;
+  static thread_local std::unordered_map<MKLDNNActSignature, MKLDNNActForward, OpHash> fwds;
   MKLDNNActSignature key(param);
   key.AddSign(ctx.is_train);
   key.AddSign(param.act_type);
diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h
index 1c583e1..362f5fb 100644
--- a/src/operator/nn/mkldnn/mkldnn_base-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -296,111 +296,6 @@ class MKLDNNStream {
   }
 };
 
-class MKLDNNOpSignature {
-  std::vector<int> eles;
-  uint64_t hash;
-
- public:
-  MKLDNNOpSignature() {
-    hash = 0;
-  }
-
-  explicit MKLDNNOpSignature(uint64_t hash) {
-    this->hash = hash;
-  }
-
-  /*
-   * We provide different methods to add signature to an op.
-   * For operations, such as convolutin and fully connected, which determines
-   * the optimal data layout for the op, we only need to use the shape and data
-   * type to sign the op. For other operations, such as activation, which uses
-   * whatever layout in the input array, we have to use the shape, the data type
-   * and the layout to sign the op.
-   */
-
-  void AddSign(const mkldnn::memory &mem) {
-    auto desc = mem.get_primitive_desc().desc();
-    hash = hash * 2 + desc.data.format;
-    eles.push_back(desc.data.format);
-    hash = hash * 2 + desc.data.data_type;
-    eles.push_back(desc.data.data_type);
-    for (int i = 0; i < desc.data.ndims; i++) {
-      hash = hash * 2 + desc.data.dims[i];
-      eles.push_back(desc.data.dims[i]);
-    }
-  }
-
-  void AddSign(const std::vector<NDArray> &arrs) {
-    for (auto &arr : arrs) {
-      AddSign(arr);
-    }
-  }
-
-  void AddSign(const NDArray &arr) {
-    if (arr.IsMKLDNNData()) {
-      AddSign(*(arr.GetMKLDNNData()));
-    } else {
-      hash = hash * 2 + arr.dtype();
-      eles.push_back(arr.dtype());
-      AddSign(arr.shape());
-    }
-  }
-
-  void AddSign(const TShape &shape) {
-    for (size_t i = 0; i < shape.ndim(); i++) {
-      hash = hash * 2 + shape[i];
-      eles.push_back(shape[i]);
-    }
-  }
-
-  void AddSign(int val) {
-    hash = hash * 2 + val;
-    eles.push_back(val);
-  }
-
-  bool operator==(const MKLDNNOpSignature &sign) const {
-    if (hash != sign.hash)
-      return false;
-    if (eles.size() != sign.eles.size())
-      return false;
-    for (size_t i = 0; i < eles.size(); i++)
-      if (eles[i] != sign.eles[i])
-        return false;
-    return true;
-  }
-
-  uint64_t GetHash() const {
-    return hash;
-  }
-};
-
-struct MKLDNNOpHash {
-  size_t operator()(const MKLDNNOpSignature &sign) const {
-    return sign.GetHash();
-  }
-};
-
-template<typename ParamType>
-class MKLDNNParamOpSign: public MKLDNNOpSignature {
-  const ParamType param;
-
-  static size_t hash(const ParamType &param) {
-    std::hash<ParamType> fn;
-    return fn(param);
-  }
-
- public:
-  explicit MKLDNNParamOpSign(const ParamType &_param): MKLDNNOpSignature(
-      hash(_param)), param(_param) {
-  }
-
-  bool operator==(const MKLDNNParamOpSign<ParamType> &sign) const {
-    const MKLDNNOpSignature &this_upper = *this;
-    const MKLDNNOpSignature &other_upper = sign;
-    return this_upper == other_upper && param == sign.param;
-  }
-};
-
 enum OutDataOp {
   Noop,
   CopyBack,
diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
index a685ebf..16f9874 100644
--- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
@@ -98,7 +98,7 @@ inline static t_bn_b_pdesc _GetBwd(const mkldnn::memory &data_mem,
   return t_bn_b_pdesc(bnBwd_desc, engine, _GetFwd(data_mem, true, eps, flags));
 }
 
-typedef MKLDNNParamOpSign<BatchNormParam> MKLDNNBNSignature;
+typedef ParamOpSign<BatchNormParam> MKLDNNBNSignature;
 
 class MKLDNNBNForward {
   std::shared_ptr<const mkldnn::memory> data_m;
@@ -184,7 +184,7 @@ template<typename DType>
 static MKLDNNBNForward &GetBNForward(const BatchNormParam& param,
                                      const OpContext &ctx, const NDArray &in_data,
                                      unsigned flags) {
-  static thread_local std::unordered_map<MKLDNNBNSignature, MKLDNNBNForward, MKLDNNOpHash> fwds;
+  static thread_local std::unordered_map<MKLDNNBNSignature, MKLDNNBNForward, OpHash> fwds;
   MKLDNNBNSignature key(param);
   key.AddSign(ctx.is_train);
   key.AddSign(in_data);
@@ -302,7 +302,7 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam &param,
                              const std::vector<NDArray>    &in_grad,
                              const std::vector<NDArray>    &aux_states) {
   TmpMemMgr::Get()->Init(ctx.requested[batchnorm::kTempSpace]);
-  CHECK_EQ(out_grad.size(), param.output_mean_var ? 3U : 1U);
+  CHECK_EQ(out_grad.size(), 1U);
   CHECK_EQ(in_data.size(), 3U);
   CHECK_EQ(out_data.size(), 3U);
   CHECK_EQ(in_grad.size(), 3U);
diff --git a/src/operator/nn/mkldnn/mkldnn_convolution.cc b/src/operator/nn/mkldnn/mkldnn_convolution.cc
index 76efc24..453221f 100644
--- a/src/operator/nn/mkldnn/mkldnn_convolution.cc
+++ b/src/operator/nn/mkldnn/mkldnn_convolution.cc
@@ -226,13 +226,13 @@ class MKLDNNConvForward {
   }
 };
 
-typedef MKLDNNParamOpSign<ConvolutionParam> MKLDNNConvSignature;
+typedef ParamOpSign<ConvolutionParam> MKLDNNConvSignature;
 
 static inline MKLDNNConvForward &GetConvFwd(
     const nnvm::NodeAttrs& attrs, bool is_train,
     const NDArray &data, const NDArray &weights,
     const NDArray *bias, const NDArray &output) {
-  static thread_local std::unordered_map<MKLDNNConvSignature, MKLDNNConvForward, MKLDNNOpHash> fwds;
+  static thread_local std::unordered_map<MKLDNNConvSignature, MKLDNNConvForward, OpHash> fwds;
   const ConvolutionParam& param = nnvm::get<ConvolutionParam>(attrs.parsed);
   MKLDNNConvSignature key(param);
   key.AddSign(is_train);
diff --git a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
index a0d3df7..af57b68 100644
--- a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
+++ b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
@@ -289,16 +289,14 @@ static void MKLDNNDeconvFwdBiasPostProcess(const DeconvolutionParam& param,
   }
 }
 
-typedef MKLDNNParamOpSign<DeconvolutionParam> MKLDNNDeconvSignature;
-
 static inline MKLDNNDeconvForward &GetDeconvFwd(
     const nnvm::NodeAttrs& attrs, const NDArray &data,
     const NDArray &weights, const NDArray *bias,
     const NDArray &output) {
   static thread_local
-        std::unordered_map<MKLDNNDeconvSignature, MKLDNNDeconvForward, MKLDNNOpHash> fwds;
+        std::unordered_map<DeconvSignature, MKLDNNDeconvForward, OpHash> fwds;
   const DeconvolutionParam& param = nnvm::get<DeconvolutionParam>(attrs.parsed);
-  MKLDNNDeconvSignature key(param);
+  DeconvSignature key(param);
   // Here we can sign the conv op with NDArray because conv primitive will
   // decide the right layout for the, so we only need to get the shape and the
   // data type of the arrays.
@@ -313,7 +311,7 @@ static inline MKLDNNDeconvForward &GetDeconvFwd(
     bool has_bias = (bias != nullptr);
     MKLDNNDeconvForward fwd(param, data, weights, has_bias, output);
     auto ins_ret = fwds.insert(
-        std::pair<MKLDNNDeconvSignature, MKLDNNDeconvForward>(key, fwd));
+        std::pair<DeconvSignature, MKLDNNDeconvForward>(key, fwd));
     CHECK(ins_ret.second);
     it = ins_ret.first;
   }
diff --git a/src/operator/nn/mkldnn/mkldnn_pooling-inl.h b/src/operator/nn/mkldnn/mkldnn_pooling-inl.h
index 61895b4..2097d57 100644
--- a/src/operator/nn/mkldnn/mkldnn_pooling-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_pooling-inl.h
@@ -104,7 +104,7 @@ inline bool MKLDNNRequireWorkspace(const PoolingParam &param) {
   return param.pool_type != pool_enum::kAvgPooling;
 }
 
-typedef MKLDNNParamOpSign<PoolingParam> MKLDNNPoolingSignature;
+typedef ParamOpSign<PoolingParam> MKLDNNPoolingSignature;
 void MKLDNNPoolingCompute(const OpContext &ctx, const PoolingParam &param,
                           const NDArray &in_data, const OpReqType req,
                           const NDArray &out_data, const NDArray *workspace);
diff --git a/src/operator/nn/mkldnn/mkldnn_pooling.cc b/src/operator/nn/mkldnn/mkldnn_pooling.cc
index 86f1314..1aeb7d4 100644
--- a/src/operator/nn/mkldnn/mkldnn_pooling.cc
+++ b/src/operator/nn/mkldnn/mkldnn_pooling.cc
@@ -188,7 +188,7 @@ MKLDNNPoolingFwd &GetPoolingFwd(const PoolingParam &param,
                                 const NDArray &output) {
   static thread_local std::unordered_map<MKLDNNPoolingSignature,
                                          MKLDNNPoolingFwd,
-                                         MKLDNNOpHash> pooling_fwds;
+                                         OpHash> pooling_fwds;
 
   bool with_workspace = is_train && MKLDNNRequireWorkspace(param);
   MKLDNNPoolingSignature key(param);
diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h
index 10581d1..a629ba5 100644
--- a/src/operator/operator_common.h
+++ b/src/operator/operator_common.h
@@ -489,6 +489,152 @@ inline void LogUnimplementedOp(const nnvm::NodeAttrs& attrs,
     LOG(FATAL) << "Not implemented: " << operator_string(attrs, ctx, inputs, req, outputs);
 }
 
+/*!
+ * \brief Signature identifying an operator configuration, used as the key
+ *  of caches of expensive-to-create op objects (e.g. CuDNN/MKLDNN ops).
+ *
+ * A signature is the ordered sequence of integers that were added to it
+ * (data types, shapes, layouts, flags) plus a hash of that sequence. Two
+ * signatures are equal only if both the hashes and the full sequences are
+ * equal, so a hash collision can never produce a false match.
+ */
+class OpSignature {
+  // Every element added to the signature, in order. int64_t keeps shape
+  // dimensions (nnvm::dim_t) from being truncated.
+  std::vector<int64_t> eles;
+  uint64_t hash;
+
+  /*
+   * Append one element and fold it into the hash. The mixing step follows
+   * boost::hash_combine so every element keeps influencing the hash; a plain
+   * `hash * 2 + val` would shift early elements out after 64 additions.
+   */
+  void AddElement(int64_t val) {
+    hash ^= static_cast<uint64_t>(val) + 0x9e3779b97f4a7c15ULL +
+            (hash << 6) + (hash >> 2);
+    eles.push_back(val);
+  }
+
+ public:
+  OpSignature() : hash(0) {}
+
+  /*
+   * Start from a precomputed hash, e.g. the hash of an op's parameters.
+   */
+  explicit OpSignature(uint64_t seed) : hash(seed) {}
+
+  /*
+   * Reserve space for `num` elements to avoid reallocations while signing.
+   */
+  void Reserve(size_t num) {
+    eles.reserve(num);
+  }
+
+  /*
+   * We provide different methods to add signature to an op.
+   * For operations, such as convolution and fully connected, which determine
+   * the optimal data layout for the op, we only need to use the shape and data
+   * type to sign the op. For other operations, such as activation, which use
+   * whatever layout is in the input array, we have to use the shape, the data
+   * type and the layout to sign the op.
+   *
+   * Every shape (TShape or MKLDNN dims) is prefixed with its number of
+   * dimensions so that consecutive shapes cannot be confused; without it,
+   * shapes {(2, 3), (4)} and {(2), (3, 4)} would sign identically.
+   */
+
+#if MXNET_USE_MKLDNN == 1
+  void AddSign(const mkldnn::memory &mem) {
+    auto desc = mem.get_primitive_desc().desc();
+    AddElement(desc.data.format);
+    AddElement(desc.data.data_type);
+    AddElement(desc.data.ndims);
+    for (int i = 0; i < desc.data.ndims; i++) {
+      AddElement(desc.data.dims[i]);
+    }
+  }
+#endif
+
+  void AddSign(const std::vector<NDArray> &arrs) {
+    for (auto &arr : arrs) {
+      AddSign(arr);
+    }
+  }
+
+  void AddSign(const NDArray &arr) {
+#if MXNET_USE_MKLDNN == 1
+    if (arr.IsMKLDNNData()) {
+      AddSign(*(arr.GetMKLDNNData()));
+    } else {
+#endif
+      AddElement(arr.dtype());
+      AddSign(arr.shape());
+#if MXNET_USE_MKLDNN == 1
+    }
+#endif
+  }
+
+  void AddSign(const std::vector<TShape> &shapes) {
+    for (auto &shape : shapes) {
+      AddSign(shape);
+    }
+  }
+
+  void AddSign(const TShape &shape) {
+    AddElement(shape.ndim());
+    for (size_t i = 0; i < shape.ndim(); i++) {
+      AddElement(shape[i]);
+    }
+  }
+
+  void AddSign(int val) {
+    AddElement(val);
+  }
+
+  bool operator==(const OpSignature &sign) const {
+    // The cheap hash comparison rejects most mismatches; the element-wise
+    // comparison guarantees a hash collision is never taken as a match.
+    return hash == sign.hash && eles == sign.eles;
+  }
+
+  uint64_t GetHash() const {
+    return hash;
+  }
+};
+
+/*! \brief Hash functor for using an OpSignature as an unordered_map key. */
+struct OpHash {
+  size_t operator()(const OpSignature &sign) const {
+    return sign.GetHash();
+  }
+};
+
+/*!
+ * \brief OpSignature that also stores the op's parameter struct. The hash is
+ *  seeded with std::hash<ParamType>, and equality additionally requires the
+ *  stored parameters to compare equal with `==`.
+ */
+template<typename ParamType>
+class ParamOpSign: public OpSignature {
+  const ParamType param;
+
+  static size_t hash(const ParamType &param) {
+    std::hash<ParamType> fn;
+    return fn(param);
+  }
+
+ public:
+  explicit ParamOpSign(const ParamType &_param): OpSignature(
+      hash(_param)), param(_param) {
+  }
+
+  bool operator==(const ParamOpSign<ParamType> &sign) const {
+    const OpSignature &this_upper = *this;
+    const OpSignature &other_upper = sign;
+    return this_upper == other_upper && param == sign.param;
+  }
+};
+
 }  // namespace op
 }  // namespace mxnet
 #endif  // MXNET_OPERATOR_OPERATOR_COMMON_H_
diff --git a/tests/cpp/include/test_core_op.h b/tests/cpp/include/test_core_op.h
index 63f5c91..7dc05fd 100644
--- a/tests/cpp/include/test_core_op.h
+++ b/tests/cpp/include/test_core_op.h
@@ -141,8 +141,9 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer<DType>
     static auto gradient = nnvm::Op::GetAttr<nnvm::FGradient>("FGradient");
     nnvm::FGradient grad_fun = gradient.get(op_, nullptr);
     if (grad_fun) {
-      std::vector<nnvm::NodeEntry> out_grads;
-      std::vector<nnvm::NodeEntry> entries = grad_fun(MakeNode(), out_grads);
+      auto n = MakeNode();
+      std::vector<nnvm::NodeEntry> out_grads(n->num_outputs());
+      std::vector<nnvm::NodeEntry> entries = grad_fun(n, out_grads);
       CHECK_GE(entries.size(), 1U);
       res.reserve(entries.size());
       for (const nnvm::NodeEntry& node_entry : entries) {
@@ -467,7 +468,7 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer<DType>
             input_shapes_ = input_shapes;
             // BWD Output shapes
             output_shapes = backward_for_op->input_shapes_;
-            CHECK_EQ(output_shapes.size(), inferred_num_outputs);
+            output_shapes.resize(inferred_num_outputs);
           } else {
             output_shapes = input_shapes;
             output_shapes.resize(inferred_num_outputs);
diff --git a/tests/cpp/operator/batchnorm_test.cc b/tests/cpp/operator/batchnorm_test.cc
index 4b08d98..2f9de74 100644
--- a/tests/cpp/operator/batchnorm_test.cc
+++ b/tests/cpp/operator/batchnorm_test.cc
@@ -77,10 +77,10 @@ enum ForwardOutputs {
  * \brief Backward
  */
 enum BackwardInputs {
-  /* out_grad */    bwd_out_grad_Grad, bwd_out_grad_Mean, bwd_out_grad_Var,
+  /* out_grad */    bwd_out_grad_Grad,
+  /* out_data */    bwd_out_data_Mean, bwd_out_data_Var,
   /* in_data */     bwd_in_data_Data, bwd_in_data_Gamma, bwd_in_data_Beta,
-  /* aux_states */  bwd_aux_states_MovingMean, bwd_aux_states_MovingVar,
-  /* in_grad */     bwd_out_data_Data, bwd_out_data_Mean, bwd_out_data_Var
+  /* aux_states */  bwd_aux_states_MovingMean, bwd_aux_states_MovingVar
 };
 enum BackwardOutputs {
   /* in_grad */     bwd_in_grad_Data /* Original input data */,
@@ -250,17 +250,12 @@ class BNOperatorExecutor : public test::op::CoreOpExecutor<DType, AccReal> {
     test::try_fill(ctx().run_ctx, &GetBlob(bwd_aux_states_MovingMean), 0);
     test::try_fill(ctx().run_ctx, &GetBlob(bwd_aux_states_MovingVar), 1);
 
-    val = -.101;
-    test::patternFill(ctx().run_ctx, &GetBlob(bwd_out_data_Data), [&val]() -> double {
-      return val += 1; });
     test::try_fill(ctx().run_ctx, &GetBlob(bwd_out_data_Mean), 0.0);
     test::try_fill(ctx().run_ctx, &GetBlob(bwd_out_data_Var), 1.0);
 
     val = -.001;
     test::patternFill(ctx().run_ctx, &GetBlob(bwd_out_grad_Grad), [&val]() -> double {
       return val += 0.01; });
-    test::try_fill(ctx().run_ctx, &GetBlob(bwd_out_grad_Mean), 0.0);
-    test::try_fill(ctx().run_ctx, &GetBlob(bwd_out_grad_Var), 1.0);
   }
 
   const bool hasWeightAndBias_;  // This will cause forward pass validation to fail

-- 
To stop receiving notification emails like this one, please contact
jxie@apache.org.