Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/08/26 05:33:25 UTC

[GitHub] eric-haibin-lin closed pull request #11611: MKLDNN Forward FullyConnected op cache

URL: https://github.com/apache/incubator-mxnet/pull/11611

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

diff --git a/src/operator/nn/fully_connected-inl.h b/src/operator/nn/fully_connected-inl.h
index 7eba2e20e57..bff58266118 100644
--- a/src/operator/nn/fully_connected-inl.h
+++ b/src/operator/nn/fully_connected-inl.h
@@ -60,6 +60,11 @@ struct FullyConnectedParam : public dmlc::Parameter<FullyConnectedParam> {
     DMLC_DECLARE_FIELD(flatten).set_default(true)
     .describe("Whether to collapse all but the first axis of the input data tensor.");
   }
+  bool operator==(const FullyConnectedParam& other) const {
+    return this->num_hidden == other.num_hidden &&
+           this->no_bias == other.no_bias &&
+           this->flatten == other.flatten;
+  }
 };
 
 template<typename xpu, typename DType>
@@ -227,4 +232,16 @@ void FullyConnectedGradCompute(const nnvm::NodeAttrs& attrs,
 
 }  // namespace op
 }  // namespace mxnet
+namespace std {
+template<>
+struct hash<mxnet::op::FullyConnectedParam> {
+  size_t operator()(const mxnet::op::FullyConnectedParam& val) {
+    size_t ret = 0;
+    ret = dmlc::HashCombine(ret, val.num_hidden);
+    ret = dmlc::HashCombine(ret, val.no_bias);
+    ret = dmlc::HashCombine(ret, val.flatten);
+    return ret;
+  }
+};
+}  // namespace std
 #endif  // MXNET_OPERATOR_NN_FULLY_CONNECTED_INL_H_
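
For readers skimming the diff: the std::hash specialization above exists so that FullyConnectedParam can be folded into the key of the primitive cache introduced in the second file. Below is a minimal standalone sketch of the same technique, with a hypothetical ToyParam in place of FullyConnectedParam and a hand-rolled Boost-style combine in place of dmlc::HashCombine. One difference worth noting: the sketch marks operator() const, which standard containers expect when they invoke the hasher directly (the patch gets away without it because the map is keyed through OpHash rather than std::hash<FullyConnectedParam>).

    #include <cstddef>
    #include <functional>
    #include <unordered_map>

    // Hypothetical stand-in for FullyConnectedParam.
    struct ToyParam {
      int num_hidden;
      bool no_bias;
      bool flatten;
      bool operator==(const ToyParam &other) const {
        return num_hidden == other.num_hidden &&
               no_bias == other.no_bias &&
               flatten == other.flatten;
      }
    };

    namespace std {
    template<>
    struct hash<ToyParam> {
      size_t operator()(const ToyParam &val) const {
        // Boost-style mixing step, standing in for dmlc::HashCombine.
        auto combine = [](size_t seed, size_t v) {
          return seed ^ (v + 0x9e3779b9 + (seed << 6) + (seed >> 2));
        };
        size_t ret = 0;
        ret = combine(ret, static_cast<size_t>(val.num_hidden));
        ret = combine(ret, static_cast<size_t>(val.no_bias));
        ret = combine(ret, static_cast<size_t>(val.flatten));
        return ret;
      }
    };
    }  // namespace std

    // With the specialization in place, the param can key a container directly:
    //   std::unordered_map<ToyParam, int> cache;
    //   cache[ToyParam{512, false, true}] = 42;
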
diff --git a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
index f86f8dbefa2..5f672cd51fd 100644
--- a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
+++ b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
@@ -82,6 +82,100 @@ inline static mkldnn::inner_product_backward_weights::primitive_desc GetIPBwdWei
   }
 }
 
+class MKLDNNFullyConnectForward {
+  std::shared_ptr<mkldnn::memory> data;
+  std::shared_ptr<mkldnn::memory> weight;
+  std::shared_ptr<mkldnn::memory> out;
+  std::shared_ptr<mkldnn::memory> bias;
+  std::shared_ptr<mkldnn::inner_product_forward> ipFwd;
+
+ public:
+  mkldnn::inner_product_forward::primitive_desc ipFwd_pd;
+
+  MKLDNNFullyConnectForward(const FullyConnectedParam &param, bool is_train,
+                            const NDArray &data, const NDArray &weight,
+                            const NDArray *bias,
+                            const mkldnn::memory::desc &output)
+      : ipFwd_pd(GetIPFwd(data, weight, bias, output, is_train)) {}
+
+  void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &weight,
+                 const mkldnn::memory *bias, const mkldnn::memory &output) {
+    if (this->data == nullptr)
+      this->data = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+              ipFwd_pd.src_primitive_desc(), data.get_data_handle()));
+    else
+      this->data->set_data_handle(data.get_data_handle());
+
+    if (this->weight == nullptr)
+      this->weight = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+              ipFwd_pd.weights_primitive_desc(), weight.get_data_handle()));
+    else
+      this->weight->set_data_handle(weight.get_data_handle());
+
+    if (this->out == nullptr)
+      this->out = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+              ipFwd_pd.dst_primitive_desc(), output.get_data_handle()));
+    else
+      this->out->set_data_handle(output.get_data_handle());
+
+    if (bias != nullptr) {
+      if (this->bias == nullptr)
+        this->bias = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+        ipFwd_pd.bias_primitive_desc(), bias->get_data_handle()));
+      else
+        this->bias->set_data_handle(bias->get_data_handle());
+      if (this->ipFwd == nullptr)
+        this->ipFwd = std::shared_ptr<mkldnn::inner_product_forward>(
+            new mkldnn::inner_product_forward(
+                ipFwd_pd, mkldnn::primitive::at(*this->data),
+                mkldnn::primitive::at(*this->weight),
+                mkldnn::primitive::at(*this->bias), *this->out));
+    } else if (this->ipFwd == nullptr) {
+      this->ipFwd = std::shared_ptr<mkldnn::inner_product_forward>(
+          new mkldnn::inner_product_forward(
+              ipFwd_pd, mkldnn::primitive::at(*this->data),
+              mkldnn::primitive::at(*this->weight), *this->out));
+    }
+  }
+  const mkldnn::inner_product_forward &GetIpFwd() const {
+    return *ipFwd;
+  }
+};
+
+typedef ParamOpSign<FullyConnectedParam> MKLDNNFullyconSignature;
+
+static inline MKLDNNFullyConnectForward &GetFCFwd(
+    const nnvm::NodeAttrs &attrs, const NDArray &data, const NDArray &weight,
+    const NDArray *bias, const mkldnn::memory::desc &output,
+    const bool is_train) {
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local std::unordered_map<MKLDNNFullyconSignature,
+              MKLDNNFullyConnectForward, OpHash> fcFwds;
+#else
+  static MX_THREAD_LOCAL std::unordered_map<MKLDNNFullyconSignature,
+              MKLDNNFullyConnectForward, OpHash> fcFwds;
+#endif
+  const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
+  MKLDNNFullyconSignature key(param);
+  key.AddSign(data);
+  key.AddSign(weight);
+  key.AddSign(is_train);
+
+  if (bias)
+    key.AddSign(*bias);
+
+  auto it = fcFwds.find(key);
+  if (it == fcFwds.end()) {
+    MKLDNNFullyConnectForward fcFwd(param, is_train, data, weight, bias,
+                                    output);
+    auto ins_ret = fcFwds.insert(
+        std::pair<MKLDNNFullyconSignature, MKLDNNFullyConnectForward>(key, fcFwd));
+    CHECK(ins_ret.second);
+    it = ins_ret.first;
+  }
+  return it->second;
+}
+
 void MKLDNNFCForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                      const std::vector<NDArray> &in_data,
                      const std::vector<OpReqType> &req,
@@ -112,21 +206,21 @@ void MKLDNNFCForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
     out_md = mkldnn::memory::desc(out_dims, get_mkldnn_type(out_data[fullc::kOut].dtype()),
       mkldnn::memory::format::any);
   }
-
-  mkldnn::inner_product_forward::primitive_desc ipFwd_pd = GetIPFwd(data, weight,
-      param.no_bias ? nullptr : &in_data[fullc::kBias], out_md, ctx.is_train);
-  auto data_mem = data.GetMKLDNNDataReorder(ipFwd_pd.src_primitive_desc());
-  auto weight_mem = weight.GetMKLDNNDataReorder(ipFwd_pd.weights_primitive_desc());
+  MKLDNNFullyConnectForward &FCFwd =
+      GetFCFwd(attrs, data, weight, param.no_bias ? nullptr : &in_data[fullc::kBias],
+               out_md, ctx.is_train);
+  auto data_mem = data.GetMKLDNNDataReorder(FCFwd.ipFwd_pd.src_primitive_desc());
+  auto weight_mem = weight.GetMKLDNNDataReorder(FCFwd.ipFwd_pd.weights_primitive_desc());
   auto out_mem = CreateMKLDNNMem(out_data[fullc::kOut],
-      ipFwd_pd.dst_primitive_desc(), req[fullc::kOut]);
-  if (param.no_bias) {
-    MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_forward(
-          ipFwd_pd, *data_mem, *weight_mem, *out_mem.second));
+      FCFwd.ipFwd_pd.dst_primitive_desc(), req[fullc::kOut], &data);
+  if (!param.no_bias) {
+    auto bias_mem = in_data[fullc::kBias].GetMKLDNNDataReorder(
+        FCFwd.ipFwd_pd.bias_primitive_desc());
+    FCFwd.SetNewMem(*data_mem, *weight_mem, bias_mem, *out_mem.second);
   } else {
-    auto bias_mem = in_data[fullc::kBias].GetMKLDNNDataReorder(ipFwd_pd.bias_primitive_desc());
-    MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_forward(ipFwd_pd,
-          *data_mem, *weight_mem, *bias_mem, *out_mem.second));
+    FCFwd.SetNewMem(*data_mem, *weight_mem, nullptr, *out_mem.second);
   }
+  MKLDNNStream::Get()->RegisterPrim(FCFwd.GetIpFwd());
   CommitOutput(out_data[fullc::kOut], out_mem);
   MKLDNNStream::Get()->Submit();
 }
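
Net effect of the second file: MKLDNNFCForward no longer constructs an mkldnn::inner_product_forward on every call. It looks up a cached MKLDNNFullyConnectForward in a thread-local map keyed by the op signature (param fields plus input arrays plus is_train), repoints the data handles through SetNewMem, and re-registers the cached primitive. A minimal sketch of that lookup-or-insert pattern follows, using hypothetical Signature and Primitive types in place of MKLDNNFullyconSignature and MKLDNNFullyConnectForward:

    #include <cstddef>
    #include <cstdint>
    #include <unordered_map>
    #include <vector>

    // Hypothetical stand-ins for the types in the patch.
    struct Signature {
      std::vector<int64_t> fields;  // param fields, shapes, dtypes, is_train
      bool operator==(const Signature &o) const { return fields == o.fields; }
    };

    struct SigHash {  // plays the role of OpHash
      std::size_t operator()(const Signature &s) const {
        std::size_t seed = 0;
        for (int64_t f : s.fields)
          seed ^= static_cast<std::size_t>(f) + 0x9e3779b9 +
                  (seed << 6) + (seed >> 2);
        return seed;
      }
    };

    struct Primitive {  // plays the role of MKLDNNFullyConnectForward
      // Repoint input/output handles; the expensive descriptor setup is reused.
      void SetNewMem(void *data, void *weight, void *bias, void *out) {}
    };

    Primitive &GetCachedPrimitive(const Signature &key) {
      // One cache per thread, as in the DMLC_CXX11_THREAD_LOCAL branch above.
      static thread_local std::unordered_map<Signature, Primitive, SigHash> cache;
      auto it = cache.find(key);
      if (it == cache.end())
        it = cache.emplace(key, Primitive{}).first;  // miss: build once, reuse
      return it->second;
    }

The thread_local map needs no locking: each worker thread builds its primitives once and reuses them on later calls with the same signature, trading a small amount of duplicated per-thread state for a contention-free fast path.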