You are viewing a plain-text version of this content; the hyperlink to its canonical version was not preserved in this copy.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2019/05/22 00:01:18 UTC

[incubator-mxnet] branch v1.5.x updated (c5265fb -> e95b551)

This is an automated email from the ASF dual-hosted git repository.

haibin pushed a change to branch v1.5.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git.


    omit c5265fb  Add primitive cache for MKL-DNN sum(elemwise_add operator (#14914)
     new e95b551  Add primitive cache for MKL-DNN sum(elemwise_add operator (#14914)

This update added new revisions after undoing existing revisions.
That is to say, some revisions that were in the old version of the
branch are not in the new version.  This situation occurs
when a user force-pushes a change (git push --force), producing a
history like this:

 * -- * -- B -- O -- O -- O   (c5265fb)
            \
             N -- N -- N   refs/heads/v1.5.x (e95b551)

You should already have received notification emails for all of the O
revisions, and so the following emails describe only the N revisions
from the common base, B.

Any revisions marked "omit" are not gone; other references still
refer to them.  Any revisions marked "discard" are gone forever.

The 1 revision listed above as "new" is entirely new to this
repository and will be described in a separate email.  Revisions
listed as "add" (none in this update) were already present in the
repository and have only been added to this reference.


Summary of changes:


[incubator-mxnet] 01/01: Add primitive cache for MKL-DNN sum(elemwise_add operator (#14914)

Posted by ha...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch v1.5.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit e95b551a69a54aac39b7c4b8d6e8423b1a782c2a
Author: ciyong <ci...@intel.com>
AuthorDate: Thu May 16 06:30:06 2019 +0800

    Add primitive cache for MKL-DNN sum(elemwise_add operator (#14914)
    
    * Add primitive cache for mkldnn sum
    
    * fix cpp test failure
---
 src/operator/nn/mkldnn/mkldnn_sum.cc            | 105 ++++++++++++++++++++----
 src/operator/tensor/elemwise_binary_op_basic.cc |   8 +-
 2 files changed, 94 insertions(+), 19 deletions(-)

diff --git a/src/operator/nn/mkldnn/mkldnn_sum.cc b/src/operator/nn/mkldnn/mkldnn_sum.cc
index dfb0e25..724b8a2 100644
--- a/src/operator/nn/mkldnn/mkldnn_sum.cc
+++ b/src/operator/nn/mkldnn/mkldnn_sum.cc
@@ -24,6 +24,7 @@
 */
 #include <iostream>
 
+#include "../../operator_common.h"
 #include "./mkldnn_ops-inl.h"
 #include "./mkldnn_base-inl.h"
 
@@ -58,37 +59,105 @@ void MKLDNNSum(const mkldnn::memory &arr1, const mkldnn::memory &arr2,
   MKLDNNStream::Get()->RegisterPrim(mkldnn::sum(sum_pd, inputs, out));
 }
 
+class MKLDNNSumFwd {
+ public:
+  mkldnn::sum::primitive_desc fwd_pd;
+
+  MKLDNNSumFwd(const std::vector<float> &scales,
+               const std::vector<mkldnn::memory::primitive_desc> &data_md)
+      : fwd_pd(scales, data_md) {
+    data_.resize(data_md.size());
+  }
+
+  void SetNewMem(const std::vector<const mkldnn::memory *> &in_data, const mkldnn::memory &output);
+
+  const mkldnn::sum &GetFwd() const { return *fwd_; }
+
+ private:
+  std::shared_ptr<mkldnn::sum> fwd_;
+  std::vector<std::shared_ptr<mkldnn::memory>> data_;
+  std::vector<mkldnn::primitive::at> data_mem_;
+  std::shared_ptr<mkldnn::memory> out_;
+};
+
+static MKLDNNSumFwd &GetSumForward(
+    const std::vector<float> &scales, const std::vector<NDArray> &in_data,
+    const std::vector<mkldnn::memory::primitive_desc> &data_md) {
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local std::unordered_map<OpSignature, MKLDNNSumFwd, OpHash> fwds;
+#else
+  static MX_THREAD_LOCAL std::unordered_map<OpSignature, MKLDNNSumFwd, OpHash> fwds;
+#endif
+  OpSignature key;
+  key.AddSign(in_data);
+
+  auto it = fwds.find(key);
+  if (it == fwds.end()) {
+    MKLDNNSumFwd fwd(scales, data_md);
+    it = AddToCache(&fwds, key, fwd);
+  }
+  return it->second;
+}
+
+void MKLDNNSumFwd::SetNewMem(const std::vector<const mkldnn::memory *> &in_data,
+                             const mkldnn::memory &output) {
+  auto num_inputs = data_.size();
+  CHECK_EQ(in_data.size(), num_inputs);
+  for (index_t i = 0; i < static_cast<index_t>(num_inputs); ++i) {
+    if (this->data_[i] == nullptr) {
+      this->data_[i] = std::shared_ptr<mkldnn::memory>(
+          new mkldnn::memory(in_data[i]->get_primitive_desc(), in_data[i]->get_data_handle()));
+      this->data_mem_.push_back(*this->data_[i]);
+    } else {
+      this->data_[i]->set_data_handle(in_data[i]->get_data_handle());
+    }
+  }
+  if (this->out_ == nullptr)
+    this->out_ = std::shared_ptr<mkldnn::memory>(
+        new mkldnn::memory(fwd_pd.dst_primitive_desc(), output.get_data_handle()));
+  else
+    this->out_->set_data_handle(output.get_data_handle());
+
+  if (this->fwd_ == nullptr)
+    this->fwd_.reset(new mkldnn::sum(fwd_pd, this->data_mem_, *this->out_));
+}
+
 void MKLDNNSumForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                       const std::vector<NDArray> &inputs, const OpReqType &req,
                       const NDArray &out_data) {
-  if (req == kNullOp) {
-    return;
-  }
-
   TmpMemMgr::Get()->Init(ctx.requested[0]);
-  std::vector<mkldnn::primitive::at> in_prims;
-  std::vector<mkldnn::memory::primitive_desc> in_pds(inputs.size());
-  std::vector<float> scales(inputs.size(), 1);
-  in_prims.reserve(inputs.size());
-  std::vector<NDArray> in_bufs(inputs.size());
-  for (size_t i = 0; i < inputs.size(); i++) {
+  auto num_inputs = inputs.size();
+  std::vector<mkldnn::memory::primitive_desc> data_md;
+  std::vector<const mkldnn::memory *> data_mem;
+  std::vector<float> scales(num_inputs, 1);
+  std::vector<NDArray> in_bufs(num_inputs);
+
+  data_md.reserve(num_inputs);
+  data_mem.reserve(num_inputs);
+
+  for (index_t i = 0; i < static_cast<index_t>(num_inputs); ++i) {
     const mkldnn::memory *in_mem;
     if (inputs[i].IsMKLDNNData() && inputs[i].IsView()) {
       in_bufs[i] = inputs[i].Reorder2Default();
       in_mem = in_bufs[i].GetMKLDNNData();
     } else {
+      in_bufs[i] = inputs[i];
       in_mem = inputs[i].GetMKLDNNData();
     }
-    in_prims.push_back(*in_mem);
-    in_pds[i] = in_mem->get_primitive_desc();
+    mkldnn::memory::primitive_desc tmp_pd = in_mem->get_primitive_desc();
+    data_md.push_back(tmp_pd);
+    data_mem.push_back(in_mem);
   }
 
-  mkldnn::sum::primitive_desc pdesc(scales, in_pds);
-  auto mem = CreateMKLDNNMem(out_data, pdesc.dst_primitive_desc(), req, &inputs[0]);
-  MKLDNNStream *stream = MKLDNNStream::Get();
-  stream->RegisterPrim(mkldnn::sum(pdesc, in_prims, *mem.second));
-  CommitOutput(out_data, mem);
-  stream->Submit();
+  MKLDNNSumFwd &fwd = GetSumForward(scales, in_bufs, data_md);
+  mxnet::mkldnn_output_t out_mem = CreateMKLDNNMem(out_data,
+                                                   fwd.fwd_pd.dst_primitive_desc(),
+                                                   req,
+                                                   &in_bufs[0]);
+  fwd.SetNewMem(data_mem, *out_mem.second);
+  MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd());
+  CommitOutput(out_data, out_mem);
+  MKLDNNStream::Get()->Submit();
 }
 
 }  // namespace op
diff --git a/src/operator/tensor/elemwise_binary_op_basic.cc b/src/operator/tensor/elemwise_binary_op_basic.cc
index 0ff73f4..c5e30c6 100644
--- a/src/operator/tensor/elemwise_binary_op_basic.cc
+++ b/src/operator/tensor/elemwise_binary_op_basic.cc
@@ -30,6 +30,12 @@
 namespace mxnet {
 namespace op {
 
+bool SupportMKLDNNSum(const NDArray& input) {
+  int ndim = input.shape().ndim();
+  return input.dtype() == mshadow::kFloat32 && (ndim >= 1 && ndim <= 4) &&
+         input.storage_type() == kDefaultStorage;
+}
+
 static void ElemwiseAddEx(const nnvm::NodeAttrs& attrs,
                           const OpContext& ctx,
                           const std::vector<NDArray>& inputs,
@@ -38,7 +44,7 @@ static void ElemwiseAddEx(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(inputs.size(), 2U);
   CHECK_EQ(outputs.size(), 1U);
 #if MXNET_USE_MKLDNN == 1
-  if (SupportMKLDNN(inputs[0]) && SupportMKLDNN(inputs[1])) {
+  if (SupportMKLDNNSum(inputs[0]) && SupportMKLDNNSum(inputs[1])) {
     MKLDNNSumForward(attrs, ctx, inputs, req[0], outputs[0]);
     return;
   } else if (inputs[0].storage_type() == kDefaultStorage