You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2018/08/26 05:33:40 UTC

[incubator-mxnet] branch master updated: MKLDNN Forward FullyConnected op cache (#11611)

This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 7230bb9  MKLDNN Forward FullyConnected  op cache (#11611)
7230bb9 is described below

commit 7230bb9b5f2f8caabf7bb64689e49ef6a5529b66
Author: zhiyuan-huang <hu...@163.com>
AuthorDate: Sun Aug 26 13:33:23 2018 +0800

    MKLDNN Forward FullyConnected  op cache (#11611)
    
    * Enable primitive allocation cache for FullyConnected
    
    * Enable primitive allocation cache for FullyConnected
    
    * fix indent and pass in_data as last argument for CreateMKLDNNMem
    
    * fix indent and pass in_data as last argument for CreateMKLDNNMem
---
 src/operator/nn/fully_connected-inl.h            |  17 ++++
 src/operator/nn/mkldnn/mkldnn_fully_connected.cc | 118 ++++++++++++++++++++---
 2 files changed, 123 insertions(+), 12 deletions(-)

diff --git a/src/operator/nn/fully_connected-inl.h b/src/operator/nn/fully_connected-inl.h
index 2338f89..2b75419 100644
--- a/src/operator/nn/fully_connected-inl.h
+++ b/src/operator/nn/fully_connected-inl.h
@@ -61,6 +61,11 @@ struct FullyConnectedParam : public dmlc::Parameter<FullyConnectedParam> {
     DMLC_DECLARE_FIELD(flatten).set_default(true)
     .describe("Whether to collapse all but the first axis of the input data tensor.");
   }
+  /*!
+   * \brief Field-wise equality on num_hidden, no_bias and flatten.
+   * \param other the parameter set to compare against.
+   * \return true only when all three fields match.
+   */
+  bool operator==(const FullyConnectedParam& other) const {
+    if (num_hidden != other.num_hidden) return false;
+    if (no_bias != other.no_bias) return false;
+    return flatten == other.flatten;
+  }
 };
 
 template<typename xpu, typename DType>
@@ -228,4 +233,16 @@ void FullyConnectedGradCompute(const nnvm::NodeAttrs& attrs,
 
 }  // namespace op
 }  // namespace mxnet
+namespace std {
+/*!
+ * \brief std::hash specialization for FullyConnectedParam.
+ *
+ * Folds num_hidden, no_bias and flatten into one value with
+ * dmlc::HashCombine. These are the same fields that
+ * FullyConnectedParam::operator== compares, so equal parameter sets
+ * produce equal hashes.
+ */
+template<>
+struct hash<mxnet::op::FullyConnectedParam> {
+  /*!
+   * \brief Compute the hash of a parameter set.
+   * \param val the parameters to hash.
+   * \return the combined hash value.
+   *
+   * Declared const so the functor can be invoked through a const hasher
+   * object, as hash function objects are required to support.
+   */
+  size_t operator()(const mxnet::op::FullyConnectedParam& val) const {
+    size_t ret = 0;
+    ret = dmlc::HashCombine(ret, val.num_hidden);
+    ret = dmlc::HashCombine(ret, val.no_bias);
+    ret = dmlc::HashCombine(ret, val.flatten);
+    return ret;
+  }
+};
+}  // namespace std
 #endif  // MXNET_OPERATOR_NN_FULLY_CONNECTED_INL_H_
diff --git a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
index f86f8db..5f672cd 100644
--- a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
+++ b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
@@ -82,6 +82,100 @@ inline static mkldnn::inner_product_backward_weights::primitive_desc GetIPBwdWei
   }
 }
 
+/*!
+ * \brief Reusable state for one MKLDNN inner-product (FullyConnected) forward.
+ *
+ * The primitive descriptor is built once in the constructor. The memory
+ * objects and the forward primitive are created on the first SetNewMem()
+ * call; every later call only re-points the existing memory objects at the
+ * new buffers, so the primitive itself is never rebuilt.
+ */
+class MKLDNNFullyConnectForward {
+  std::shared_ptr<mkldnn::memory> data;
+  std::shared_ptr<mkldnn::memory> weight;
+  std::shared_ptr<mkldnn::memory> out;
+  std::shared_ptr<mkldnn::memory> bias;
+  std::shared_ptr<mkldnn::inner_product_forward> ipFwd;
+
+ public:
+  /*! \brief primitive descriptor obtained from GetIPFwd for the given inputs. */
+  mkldnn::inner_product_forward::primitive_desc ipFwd_pd;
+
+  /*!
+   * \brief Build the forward primitive descriptor.
+   * \param param operator parameters; not used by this constructor.
+   * \param is_train forwarded to GetIPFwd.
+   * \param data input array.
+   * \param weight weight array.
+   * \param bias bias array, or nullptr when there is no bias.
+   * \param output memory descriptor of the output.
+   */
+  MKLDNNFullyConnectForward(const FullyConnectedParam &param, bool is_train,
+                            const NDArray &data, const NDArray &weight,
+                            const NDArray *bias,
+                            const mkldnn::memory::desc &output)
+      : ipFwd_pd(GetIPFwd(data, weight, bias, output, is_train)) {}
+
+  /*!
+   * \brief Bind the buffers to use for the next execution.
+   * \param data source memory; only its data handle is taken.
+   * \param weight weight memory; only its data handle is taken.
+   * \param bias bias memory, or nullptr when there is no bias.
+   * \param output destination memory; only its data handle is taken.
+   */
+  void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &weight,
+                 const mkldnn::memory *bias, const mkldnn::memory &output) {
+    // Each memory object is created on first use with the layout taken from
+    // ipFwd_pd; afterwards only its data handle is swapped.
+    if (this->data == nullptr)
+      this->data = std::make_shared<mkldnn::memory>(
+          ipFwd_pd.src_primitive_desc(), data.get_data_handle());
+    else
+      this->data->set_data_handle(data.get_data_handle());
+
+    if (this->weight == nullptr)
+      this->weight = std::make_shared<mkldnn::memory>(
+          ipFwd_pd.weights_primitive_desc(), weight.get_data_handle());
+    else
+      this->weight->set_data_handle(weight.get_data_handle());
+
+    if (this->out == nullptr)
+      this->out = std::make_shared<mkldnn::memory>(
+          ipFwd_pd.dst_primitive_desc(), output.get_data_handle());
+    else
+      this->out->set_data_handle(output.get_data_handle());
+
+    if (bias != nullptr) {
+      if (this->bias == nullptr)
+        this->bias = std::make_shared<mkldnn::memory>(
+            ipFwd_pd.bias_primitive_desc(), bias->get_data_handle());
+      else
+        this->bias->set_data_handle(bias->get_data_handle());
+    }
+
+    // The primitive is created exactly once, bound to the member memory
+    // objects above; later calls reuse it with the updated handles.
+    if (this->ipFwd == nullptr) {
+      if (bias != nullptr) {
+        this->ipFwd = std::make_shared<mkldnn::inner_product_forward>(
+            ipFwd_pd, mkldnn::primitive::at(*this->data),
+            mkldnn::primitive::at(*this->weight),
+            mkldnn::primitive::at(*this->bias), *this->out);
+      } else {
+        this->ipFwd = std::make_shared<mkldnn::inner_product_forward>(
+            ipFwd_pd, mkldnn::primitive::at(*this->data),
+            mkldnn::primitive::at(*this->weight), *this->out);
+      }
+    }
+  }
+
+  /*!
+   * \brief Access the forward primitive.
+   * \return reference to the primitive created by SetNewMem(); SetNewMem()
+   *         must have been called at least once beforehand.
+   */
+  const mkldnn::inner_product_forward &GetIpFwd() const {
+    return *ipFwd;
+  }
+};
+
+/*! \brief cache key type: FullyConnectedParam plus the signatures added to it. */
+typedef ParamOpSign<FullyConnectedParam> MKLDNNFullyconSignature;
+
+/*!
+ * \brief Look up, or create and cache, the forward state for these inputs.
+ *
+ * The cache is a function-local thread-local map. Its key is built from the
+ * operator parameters, the data and weight arrays, the is_train flag and,
+ * when present, the bias array.
+ *
+ * \param attrs node attributes holding the parsed FullyConnectedParam.
+ * \param data input array.
+ * \param weight weight array.
+ * \param bias bias array, or nullptr when there is no bias.
+ * \param output memory descriptor of the output, used only on a cache miss.
+ * \param is_train whether the forward is run in training mode.
+ * \return reference to the cached MKLDNNFullyConnectForward entry.
+ */
+static inline MKLDNNFullyConnectForward &GetFCFwd(
+    const nnvm::NodeAttrs &attrs, const NDArray &data, const NDArray &weight,
+    const NDArray *bias, const mkldnn::memory::desc &output,
+    const bool is_train) {
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local std::unordered_map<MKLDNNFullyconSignature,
+              MKLDNNFullyConnectForward, OpHash> fcFwds;
+#else
+  static MX_THREAD_LOCAL std::unordered_map<MKLDNNFullyconSignature,
+              MKLDNNFullyConnectForward, OpHash> fcFwds;
+#endif
+  const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
+
+  // Assemble the lookup key.
+  MKLDNNFullyconSignature key(param);
+  key.AddSign(data);
+  key.AddSign(weight);
+  key.AddSign(is_train);
+  if (bias != nullptr) {
+    key.AddSign(*bias);
+  }
+
+  // Cache hit: hand back the existing entry.
+  auto cached = fcFwds.find(key);
+  if (cached != fcFwds.end()) {
+    return cached->second;
+  }
+
+  // Cache miss: build a new entry and store a copy of it in the map.
+  MKLDNNFullyConnectForward fresh(param, is_train, data, weight, bias,
+                                  output);
+  auto inserted = fcFwds.insert(
+      std::pair<MKLDNNFullyconSignature, MKLDNNFullyConnectForward>(key, fresh));
+  CHECK(inserted.second);
+  return inserted.first->second;
+}
+
 void MKLDNNFCForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                      const std::vector<NDArray> &in_data,
                      const std::vector<OpReqType> &req,
@@ -112,21 +206,21 @@ void MKLDNNFCForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
     out_md = mkldnn::memory::desc(out_dims, get_mkldnn_type(out_data[fullc::kOut].dtype()),
       mkldnn::memory::format::any);
   }
-
-  mkldnn::inner_product_forward::primitive_desc ipFwd_pd = GetIPFwd(data, weight,
-      param.no_bias ? nullptr : &in_data[fullc::kBias], out_md, ctx.is_train);
-  auto data_mem = data.GetMKLDNNDataReorder(ipFwd_pd.src_primitive_desc());
-  auto weight_mem = weight.GetMKLDNNDataReorder(ipFwd_pd.weights_primitive_desc());
+  MKLDNNFullyConnectForward &FCFwd =
+      GetFCFwd(attrs, data, weight, param.no_bias ? nullptr : &in_data[fullc::kBias],
+               out_md, ctx.is_train);
+  auto data_mem = data.GetMKLDNNDataReorder(FCFwd.ipFwd_pd.src_primitive_desc());
+  auto weight_mem = weight.GetMKLDNNDataReorder(FCFwd.ipFwd_pd.weights_primitive_desc());
   auto out_mem = CreateMKLDNNMem(out_data[fullc::kOut],
-      ipFwd_pd.dst_primitive_desc(), req[fullc::kOut]);
-  if (param.no_bias) {
-    MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_forward(
-          ipFwd_pd, *data_mem, *weight_mem, *out_mem.second));
+      FCFwd.ipFwd_pd.dst_primitive_desc(), req[fullc::kOut], &data);
+  if (!param.no_bias) {
+    auto bias_mem = in_data[fullc::kBias].GetMKLDNNDataReorder(
+        FCFwd.ipFwd_pd.bias_primitive_desc());
+    FCFwd.SetNewMem(*data_mem, *weight_mem, bias_mem, *out_mem.second);
   } else {
-    auto bias_mem = in_data[fullc::kBias].GetMKLDNNDataReorder(ipFwd_pd.bias_primitive_desc());
-    MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_forward(ipFwd_pd,
-          *data_mem, *weight_mem, *bias_mem, *out_mem.second));
+    FCFwd.SetNewMem(*data_mem, *weight_mem, nullptr, *out_mem.second);
   }
+  MKLDNNStream::Get()->RegisterPrim(FCFwd.GetIpFwd());
   CommitOutput(out_data[fullc::kOut], out_mem);
   MKLDNNStream::Get()->Submit();
 }