You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by pa...@apache.org on 2019/09/30 06:08:07 UTC

[incubator-mxnet] branch mkldnn-v1.0 updated: [mkldnn-v1.0] Add MKL-DNN FC (#16221)

This is an automated email from the ASF dual-hosted git repository.

patriczhao pushed a commit to branch mkldnn-v1.0
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/mkldnn-v1.0 by this push:
     new 23093e6  [mkldnn-v1.0] Add MKL-DNN FC (#16221)
23093e6 is described below

commit 23093e6f774f6b1ce94007a267c32b2253b3cbcd
Author: rongzha1 <ro...@intel.com>
AuthorDate: Mon Sep 30 14:07:13 2019 +0800

    [mkldnn-v1.0] Add MKL-DNN FC (#16221)
    
    * add mkldnn fc; pass lint; pass mnist training
    
    * add TODO info for future debug
---
 src/imperative/imperative_utils.h                  |  12 +-
 src/operator/nn/fully_connected.cc                 |  12 +-
 .../nn/mkldnn/mkldnn_fully_connected-inl.h         |  15 +--
 src/operator/nn/mkldnn/mkldnn_fully_connected.cc   | 128 ++++++++-------------
 src/operator/nn/mkldnn/mkldnn_ops-inl.h            |  20 ++--
 5 files changed, 77 insertions(+), 110 deletions(-)

diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h
index bc61bc7..f0c199b 100644
--- a/src/imperative/imperative_utils.h
+++ b/src/imperative/imperative_utils.h
@@ -541,8 +541,18 @@ inline void PushOperator(const OpStatePtr& state,
         // copying A to B may not happen, and will corrupt A's memory.
         InvalidateOutputs(outputs, req);
       }
+      // add for mkldnn OP + no mkldnn OP
+      const auto is_mkldnn = Op::GetAttr<bool>("TIsMKLDNN");
+      if (!is_mkldnn.get(attrs.op, false)) {
+        std::vector<NDArray> inputs_fallback;
+        CreateDefaultInputs(inputs, &inputs_fallback);
+        fcompute_ex(state, opctx, inputs_fallback, req, outputs);
+      } else {
+#endif
+        fcompute_ex(state, opctx, inputs, req, outputs);
+#if MXNET_USE_MKLDNN == 100
+      }
 #endif
-      fcompute_ex(state, opctx, inputs, req, outputs);
       if (ctx.dev_mask() == gpu::kDevMask && exec_type == ExecType::kSync
           && rctx.get_stream<gpu>() && !rctx.is_bulk) {
         rctx.get_stream<gpu>()->Wait();
diff --git a/src/operator/nn/fully_connected.cc b/src/operator/nn/fully_connected.cc
index 06ad6d0..c80c08e 100644
--- a/src/operator/nn/fully_connected.cc
+++ b/src/operator/nn/fully_connected.cc
@@ -97,7 +97,7 @@ void FullyConnectedComputeExCPU(const nnvm::NodeAttrs& attrs,
     valid_bias = inputs[2].storage_type() == kDefaultStorage ||
                  inputs[2].storage_type() == kRowSparseStorage;
   }
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
   if (common::ContainsOnlyStorage(inputs, kDefaultStorage) &&
       common::ContainsOnlyStorage(outputs, kDefaultStorage)) {
     if (SupportMKLDNNFC(inputs[0])) {
@@ -141,7 +141,7 @@ void FullyConnectedComputeExCPU(const nnvm::NodeAttrs& attrs,
 #endif
 }
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 void FullyConnectedGradComputeExCPU(const nnvm::NodeAttrs& attrs,
                                     const OpContext &ctx,
                                     const std::vector<NDArray> &inputs,
@@ -199,7 +199,7 @@ inline static bool FCStorageType(const nnvm::NodeAttrs& attrs,
     dispatched = storage_type_assign(out_attrs, mxnet::kDefaultStorage,
                                      dispatch_mode, DispatchMode::kFComputeEx);
   }
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
   if (!MKLDNNEnvSet())
     *dispatch_mode = DispatchMode::kFComputeFallback;
 #endif
@@ -233,7 +233,7 @@ inline static bool BackwardFCStorageType(const nnvm::NodeAttrs& attrs,
     dispatched = storage_type_assign(out_attrs, mxnet::kDefaultStorage,
                                      dispatch_mode, DispatchMode::kFCompute);
   }
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
   if (!MKLDNNEnvSet())
     *dispatch_mode = DispatchMode::kFComputeFallback;
 #endif
@@ -295,7 +295,7 @@ If ``no_bias`` is set to be true, then the ``bias`` term is ignored.
     [](const NodeAttrs& attrs) {
     return std::vector<std::string>{"output"};
 })
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 .set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
   return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
@@ -326,7 +326,7 @@ NNVM_REGISTER_OP(_backward_FullyConnected)
 })
 .set_attr<FInferStorageType>("FInferStorageType", BackwardFCStorageType)
 .set_attr_parser(ParamParser<FullyConnectedParam>)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 .set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FComputeEx>("FComputeEx<cpu>", FullyConnectedGradComputeExCPU)
 #endif
diff --git a/src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h b/src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h
index fddaedc..db8cfdc 100644
--- a/src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h
@@ -27,7 +27,7 @@
 #ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_FULLY_CONNECTED_INL_H_
 #define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_FULLY_CONNECTED_INL_H_
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 
 #include <vector>
 #include <string>
@@ -50,7 +50,7 @@ struct MKLDNNFCParam: public dmlc::Parameter<MKLDNNFCParam> {
     DMLC_DECLARE_FIELD(enable_float_output).set_default(false)
     .describe("Whether to enable float32 output");
     DMLC_DECLARE_FIELD(with_eltwise).set_default(false)
-    .describe("Whether there's a post elemwise after FullyConnected operator");
+    .describe("Whether there's a post with_eltwise after FullyConnected operator");
     DMLC_DECLARE_FIELD(min_calib_range)
     .set_default(dmlc::optional<float>())
     .describe("The minimum scalar value in the form of float32 obtained "
@@ -85,10 +85,9 @@ class MKLDNNFullyConnectedForward {
                               const NDArray &data, const NDArray &weight,
                               const NDArray *bias,
                               const mkldnn::memory::desc &out_md)
-      : fwd_pd(GetFCFwdImpl(full_param, is_train, data, weight, bias, out_md)) {}
-
-  void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &weight,
-                 const mkldnn::memory *bias, const mkldnn::memory &output);
+      : fwd_pd(GetFCFwdImpl(full_param, is_train, data, weight, bias, out_md)) {
+          fwd_ = std::make_shared<mkldnn::inner_product_forward>(fwd_pd);
+      }
 
   const mkldnn::inner_product_forward &GetFwd() const {
     return *fwd_;
@@ -96,10 +95,6 @@ class MKLDNNFullyConnectedForward {
 
  private:
   std::shared_ptr<mkldnn::inner_product_forward> fwd_;
-  std::shared_ptr<mkldnn::memory> data_;
-  std::shared_ptr<mkldnn::memory> weight_;
-  std::shared_ptr<mkldnn::memory> bias_;
-  std::shared_ptr<mkldnn::memory> out_;
 };
 
 typedef ParamOpSign<FullyConnectedParam> MKLDNNFullyconSignature;
diff --git a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
index fbe37e2..80eb2d6 100644
--- a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
+++ b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
@@ -24,7 +24,7 @@
  * \author Da Zheng, Ciyong Chen
 */
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 #include "mkldnn_fully_connected-inl.h"
 
 namespace mxnet {
@@ -67,7 +67,6 @@ mkldnn::inner_product_forward::primitive_desc GetFCFwdImpl(
       }
 
       attr.set_output_scales(mask, scales);
-      attr.set_int_output_round_mode(round_nearest);
     }
   }
 
@@ -130,51 +129,6 @@ inline static mkldnn::inner_product_backward_weights::primitive_desc GetFCBwdWei
   }
 }
 
-void MKLDNNFullyConnectedForward::SetNewMem(const mkldnn::memory &data,
-                                            const mkldnn::memory &weight,
-                                            const mkldnn::memory *bias,
-                                            const mkldnn::memory &output) {
-  if (this->data_ == nullptr)
-    this->data_ = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
-            fwd_pd.src_primitive_desc(), data.get_data_handle()));
-  else
-    this->data_->set_data_handle(data.get_data_handle());
-
-  if (this->weight_ == nullptr)
-    this->weight_ = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
-            fwd_pd.weights_primitive_desc(), weight.get_data_handle()));
-  else
-    this->weight_->set_data_handle(weight.get_data_handle());
-
-  if (this->out_ == nullptr)
-    this->out_ = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
-            fwd_pd.dst_primitive_desc(), output.get_data_handle()));
-  else
-    this->out_->set_data_handle(output.get_data_handle());
-
-  if (bias != nullptr) {
-    if (this->bias_ == nullptr)
-      this->bias_ = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
-      fwd_pd.bias_primitive_desc(), bias->get_data_handle()));
-    else
-      this->bias_->set_data_handle(bias->get_data_handle());
-
-    if (this->fwd_ == nullptr)
-      this->fwd_ = std::shared_ptr<mkldnn::inner_product_forward>(
-          new mkldnn::inner_product_forward(
-              fwd_pd, mkldnn::primitive::at(*this->data_),
-              mkldnn::primitive::at(*this->weight_),
-              mkldnn::primitive::at(*this->bias_), *this->out_));
-  } else {
-     if (this->fwd_ == nullptr) {
-      this->fwd_ = std::shared_ptr<mkldnn::inner_product_forward>(
-          new mkldnn::inner_product_forward(
-              fwd_pd, mkldnn::primitive::at(*this->data_),
-              mkldnn::primitive::at(*this->weight_), *this->out_));
-    }
-  }
-}
-
 MKLDNNFullyConnectedForward &GetFCFwd(
     const FullyConnectedParam &param, const bool is_train,
     const NDArray &data, const NDArray &weight,
@@ -223,13 +177,13 @@ void MKLDNNFCFlattenData(const FullyConnectedParam &param,
       mkldnn::memory::dims out_dims{static_cast<int>(oshape.ProdShape(0, oshape.ndim()-1)),
         static_cast<int>(oshape[ishape.ndim()-1])};
       *out_md = mkldnn::memory::desc(out_dims, get_mkldnn_type(out_data.dtype()),
-        mkldnn::memory::format::any);
+                                     mkldnn::memory::format_tag::any);
     } else {
       *in_data = in_data->MKLDNNDataReshape(Shape2(ishape[0], ishape.ProdShape(1, ishape.ndim())));
       mkldnn::memory::dims out_dims{static_cast<int>(oshape[0]),
         static_cast<int>(oshape.ProdShape(1, oshape.ndim()))};
       *out_md = mkldnn::memory::desc(out_dims, get_mkldnn_type(out_data.dtype()),
-        mkldnn::memory::format::any);
+                                     mkldnn::memory::format_tag::any);
     }
   }
 }
@@ -244,35 +198,35 @@ void MKLDNNFCForwardFullFeature(const MKLDNNFCFullParam &full_param,
   NDArray weight = in_data[fullc::kWeight];
   NDArray data = in_data[fullc::kData];
 
-  auto data_mem = data.GetMKLDNNDataReorder(fwd->fwd_pd.src_primitive_desc());
+  auto data_mem = data.GetMKLDNNDataReorder(fwd->fwd_pd.src_desc());
   const mkldnn::memory *weight_mem;
   if (ctx.is_train) {
     if (weight.IsMKLDNNData()) {
       weight.Reorder2DefaultAsync();
     }
-    weight_mem = GetWeights(weight, fwd->fwd_pd.weights_primitive_desc(), 1);
+    weight_mem = GetWeights(weight, fwd->fwd_pd.weights_desc(), 1);
   } else {
-    if (weight.IsDefaultData()) {
-      // We also need to modify the layout on the original weight array.
-      // Don't switch below sequence because naive engine will executes
-      // pushAsync synchronously.
-      weight.MKLDNNDataReorderAsync(fwd->fwd_pd.weights_primitive_desc());
-      weight_mem = GetWeights(weight, fwd->fwd_pd.weights_primitive_desc(), 1);
-    } else {
-      weight_mem = weight.GetMKLDNNData();
-      CHECK(weight_mem->get_primitive_desc() == fwd->fwd_pd.weights_primitive_desc());
+    weight_mem = weight.GetMKLDNNData();
+    if (weight_mem->get_desc() != fwd->fwd_pd.weights_desc()) {
+      // TODO(rongzha1): rm following line for ut:test_contrib_rnn, need debug
+      // weight.MKLDNNDataReorderAsync(fwd->fwd_pd.weights_desc());
+      weight_mem = GetWeights(weight, fwd->fwd_pd.weights_desc(), 1);
     }
   }
   auto out_mem = CreateMKLDNNMem(out_data[fullc::kOut],
-      fwd->fwd_pd.dst_primitive_desc(), req[fullc::kOut], &data);
+                                 fwd->fwd_pd.dst_desc(), req[fullc::kOut], &data);
+
+  std::unordered_map<int, mkldnn::memory> args = {
+      {MKLDNN_ARG_SRC, *data_mem},
+      {MKLDNN_ARG_WEIGHTS, *weight_mem},
+      {MKLDNN_ARG_DST, *out_mem.second},
+  };
   if (!full_param.default_param.no_bias) {
     auto bias_mem = in_data[fullc::kBias].GetMKLDNNDataReorder(
-        fwd->fwd_pd.bias_primitive_desc());
-    fwd->SetNewMem(*data_mem, *weight_mem, bias_mem, *out_mem.second);
-  } else {
-    fwd->SetNewMem(*data_mem, *weight_mem, nullptr, *out_mem.second);
+        fwd->fwd_pd.bias_desc());
+    args.insert({ MKLDNN_ARG_BIAS, *bias_mem});
   }
-  MKLDNNStream::Get()->RegisterPrim(fwd->GetFwd());
+  MKLDNNStream::Get()->RegisterPrimArgs(fwd->GetFwd(), args);
   CommitOutput(out_data[fullc::kOut], out_mem);
   MKLDNNStream::Get()->Submit();
 }
@@ -339,13 +293,18 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
     mkldnn::inner_product_backward_data::primitive_desc ipBwdData_pd = GetFCBwdData(
         data, weight, out_grad, fwd_pd);
     auto out_grad_mem = out_grad.GetMKLDNNDataReorder(
-        ipBwdData_pd.diff_dst_primitive_desc());
-    auto weight_mem = weight.GetMKLDNNDataReorder(ipBwdData_pd.weights_primitive_desc());
+        ipBwdData_pd.diff_dst_desc());
+    auto weight_mem = weight.GetMKLDNNDataReorder(ipBwdData_pd.weights_desc());
     auto in_grad_mem = CreateMKLDNNMem(in_grad[fullc::kData],
-                                       ipBwdData_pd.diff_src_primitive_desc(),
+                                       ipBwdData_pd.diff_src_desc(),
                                        req[fullc::kData]);
-    MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_backward_data(
-          ipBwdData_pd, *out_grad_mem, *weight_mem, *in_grad_mem.second));
+    std::unordered_map<int, mkldnn::memory> args = {
+      {MKLDNN_ARG_DIFF_DST, *out_grad_mem},
+      {MKLDNN_ARG_WEIGHTS, *weight_mem},
+      {MKLDNN_ARG_DIFF_SRC, *in_grad_mem.second}
+    };
+
+    MKLDNNStream::Get()->RegisterPrimArgs(mkldnn::inner_product_backward_data(ipBwdData_pd), args);
     CommitOutput(in_grad[fullc::kData], in_grad_mem);
   }
   if (req[fullc::kWeight]) {
@@ -353,23 +312,26 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
       = GetFCBwdWeights(data, weight, param.no_bias ? nullptr : &in_grad[fullc::kBias],
           out_grad, fwd_pd);
     auto out_grad_mem = out_grad.GetMKLDNNDataReorder(
-        ipBwdWeights_pd.diff_dst_primitive_desc());
-    auto data_mem = data.GetMKLDNNDataReorder(ipBwdWeights_pd.src_primitive_desc());
+        ipBwdWeights_pd.diff_dst_desc());
+    auto data_mem = data.GetMKLDNNDataReorder(ipBwdWeights_pd.src_desc());
     auto in_grad_weight = CreateMKLDNNWeightGrad(in_grad[fullc::kWeight],
-                                                 ipBwdWeights_pd.diff_weights_primitive_desc(),
+                                                 ipBwdWeights_pd.diff_weights_desc(),
                                                  req[fullc::kWeight]);
+    std::unordered_map<int, mkldnn::memory> args = {
+      {MKLDNN_ARG_DIFF_DST, *out_grad_mem},
+      {MKLDNN_ARG_SRC, *data_mem},
+      {MKLDNN_ARG_DIFF_WEIGHTS, *in_grad_weight.second},
+    };
+
     mkldnn_output_t in_grad_bias;
-    if (param.no_bias) {
-      MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_backward_weights(
-            ipBwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight.second));
-    } else {
+    if (!param.no_bias) {
       in_grad_bias = CreateMKLDNNMem(in_grad[fullc::kBias],
-                                     ipBwdWeights_pd.diff_bias_primitive_desc(),
+                                     ipBwdWeights_pd.diff_bias_desc(),
                                      req[fullc::kBias]);
-      MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_backward_weights(
-            ipBwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight.second,
-            *in_grad_bias.second));
+      args.insert({MKLDNN_ARG_DIFF_BIAS, *in_grad_bias.second});
     }
+    MKLDNNStream::Get()->RegisterPrimArgs(
+        mkldnn::inner_product_backward_weights(ipBwdWeights_pd), args);
     CommitOutput(in_grad[fullc::kWeight], in_grad_weight);
     CommitOutput(in_grad[fullc::kBias], in_grad_bias);
   }
@@ -378,4 +340,4 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
 
 }  // namespace op
 }  // namespace mxnet
-#endif  // MXNET_USE_MKLDNN == 1
+#endif  // MXNET_USE_MKLDNN == 100
diff --git a/src/operator/nn/mkldnn/mkldnn_ops-inl.h b/src/operator/nn/mkldnn/mkldnn_ops-inl.h
index 951b075..20d80cd 100644
--- a/src/operator/nn/mkldnn/mkldnn_ops-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_ops-inl.h
@@ -44,16 +44,6 @@ namespace mxnet {
 namespace op {
 
 #if MXNET_USE_MKLDNN == 1
-/* For fully connected. */
-void MKLDNNFCForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
-                     const std::vector<NDArray> &in_data,
-                     const std::vector<OpReqType> &req,
-                     const std::vector<NDArray> &out_data);
-void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
-                      const std::vector<NDArray> &inputs,
-                      const std::vector<OpReqType> &req,
-                      const std::vector<NDArray> &outputs);
-
 /* For deconvolution */
 void MKLDNNDeconvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                                 const std::vector<NDArray> &in_data,
@@ -104,6 +94,16 @@ void MKLDNNFlattenForward(const nnvm::NodeAttrs &attrs,
 #endif
 
 #if MXNET_USE_MKLDNN == 100
+/* For fully connected. */
+void MKLDNNFCForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
+                     const std::vector<NDArray> &in_data,
+                     const std::vector<OpReqType> &req,
+                     const std::vector<NDArray> &out_data);
+void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
+                      const std::vector<NDArray> &inputs,
+                      const std::vector<OpReqType> &req,
+                      const std::vector<NDArray> &outputs);
+
 /* For convolution. */
 void MKLDNNConvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                               const std::vector<NDArray> &in_data,