Posted to commits@mxnet.apache.org by zh...@apache.org on 2021/07/13 13:51:33 UTC

[incubator-mxnet] branch v1.x updated: [1.x][backport] Integrate matmul primitive from oneDNN in batch dot (#20382)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.x by this push:
     new 7e835a5  [1.x][backport] Integrate matmul primitive from oneDNN in batch dot (#20382)
7e835a5 is described below

commit 7e835a592da902066f2ccb0d74a78adc8e77fdb7
Author: bartekkuncer <ba...@intel.com>
AuthorDate: Tue Jul 13 15:50:01 2021 +0200

    [1.x][backport] Integrate matmul primitive from oneDNN in batch dot (#20382)
---
 src/operator/nn/mkldnn/mkldnn_base-inl.h      |   1 +
 src/operator/nn/mkldnn/mkldnn_batch_dot-inl.h |  65 +++++++++++++
 src/operator/nn/mkldnn/mkldnn_batch_dot.cc    | 132 ++++++++++++++++++++++++++
 src/operator/nn/mkldnn/mkldnn_ops-inl.h       |   6 ++
 src/operator/tensor/dot-inl.h                 |  18 +++-
 src/operator/tensor/dot.cc                    |  37 ++++++++
 6 files changed, 258 insertions(+), 1 deletion(-)
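For context, batch_dot folds all leading dimensions of its inputs into one batch dimension and multiplies the trailing two-dimensional slices pairwise; the oneDNN matmul primitive integrated below accelerates exactly that computation. A minimal sketch of the operator's semantics in plain C++ (illustrative only, not MXNet code; row-major float buffers, no transpose flags assumed):

    #include <cstddef>
    #include <vector>

    // Naive reference for C[b] = A[b] * B[b], b = 0..batch-1.
    // A: batch x m x k, B: batch x k x n, C: batch x m x n, row-major.
    void batch_dot_ref(const std::vector<float>& A, const std::vector<float>& B,
                       std::vector<float>& C, std::size_t batch,
                       std::size_t m, std::size_t k, std::size_t n) {
      for (std::size_t b = 0; b < batch; ++b)
        for (std::size_t i = 0; i < m; ++i)
          for (std::size_t j = 0; j < n; ++j) {
            float acc = 0.f;
            for (std::size_t p = 0; p < k; ++p)
              acc += A[(b * m + i) * k + p] * B[(b * k + p) * n + j];
            C[(b * m + i) * n + j] = acc;
          }
    }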

diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h
index 3e73103..cb30b0b 100644
--- a/src/operator/nn/mkldnn/mkldnn_base-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -215,6 +215,7 @@ bool SupportMKLDNNLogSoftmax(const SoftmaxParam& param, const NDArray &input,
                              const NDArray &output);
 bool SupportMKLDNNSoftmaxOutput(const SoftmaxOutputParam &param);
 bool SupportMKLDNNTranspose(const TransposeParam& param, const NDArray &data);
+bool SupportMKLDNNBatchDot(const std::vector<NDArray> &inputs, const NDArray &output);
 }  // namespace op
 
 static int GetTypeSize(int dtype) {
diff --git a/src/operator/nn/mkldnn/mkldnn_batch_dot-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_dot-inl.h
new file mode 100644
index 0000000..02bd9ad
--- /dev/null
+++ b/src/operator/nn/mkldnn/mkldnn_batch_dot-inl.h
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file mkldnn_batch_dot-inl.h
+ */
+
+#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_BATCH_DOT_INL_H_
+#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_BATCH_DOT_INL_H_
+
+#if MXNET_USE_MKLDNN == 1
+
+#include <numeric>
+#include <utility>
+#include <vector>
+#include "../../tensor/dot-inl.h"
+#include "./mkldnn_base-inl.h"
+#include "./mkldnn_ops-inl.h"
+
+namespace mxnet {
+namespace op {
+
+using batch_dot_fwd_t = mkldnn::matmul;
+using batch_dot_fwd_pd_t = mkldnn::matmul::primitive_desc;
+
+typedef ParamOpSign<DotParam> BatchDotSignature;
+
+class MKLDNNBatchDotFwd {
+ public:
+  static MKLDNNBatchDotFwd &GetCached(const DotParam &param,
+                                      const std::vector<NDArray> &inputs,
+                                      const std::vector<NDArray> &outputs);
+
+  MKLDNNBatchDotFwd(const DotParam &param, const std::vector<NDArray> &inputs,
+                    const std::vector<NDArray> &outputs);
+
+  void Execute(const std::vector<NDArray> &inputs,
+               const std::vector<OpReqType> &req,
+               const std::vector<NDArray> &outputs);
+
+ private:
+  std::shared_ptr<batch_dot_fwd_t> fwd;
+  std::shared_ptr<batch_dot_fwd_pd_t> fwd_pd;
+};
+
+}  // namespace op
+}  // namespace mxnet
+#endif  // MXNET_USE_MKLDNN == 1
+#endif  // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_BATCH_DOT_INL_H_
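The class above splits construction from execution: GetCached looks up (or builds) an MKLDNNBatchDotFwd keyed on the parameters, shapes, and dtypes, so the oneDNN primitive is created once and then reused by every call with the same signature. A generic sketch of that thread-local caching pattern (hypothetical helper, not part of the commit):

    #include <unordered_map>

    // Build-once, reuse-after cache, one per thread: oneDNN primitives are
    // not synchronized, so a thread-local map avoids locking on the hot path.
    template <typename Key, typename Fwd, typename Hash, typename Make>
    Fwd& GetOrCreate(const Key& key, Make make) {
      static thread_local std::unordered_map<Key, Fwd, Hash> cache;
      auto it = cache.find(key);
      if (it == cache.end()) {
        it = cache.emplace(key, make()).first;  // build the primitive only on a miss
      }
      return it->second;
    }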
diff --git a/src/operator/nn/mkldnn/mkldnn_batch_dot.cc b/src/operator/nn/mkldnn/mkldnn_batch_dot.cc
new file mode 100644
index 0000000..1a5006d
--- /dev/null
+++ b/src/operator/nn/mkldnn/mkldnn_batch_dot.cc
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file mkldnn_batch_dot.cc
+ */
+
+#if MXNET_USE_MKLDNN == 1
+
+#include "./mkldnn_batch_dot-inl.h"
+
+namespace mxnet {
+namespace op {
+
+bool SupportMKLDNNBatchDot(const std::vector<NDArray> &inputs,
+                           const NDArray &output) {
+  return inputs[0].shape().Size() != 0 && inputs[1].shape().Size() != 0 &&
+         output.shape().Size() != 0 &&
+         (inputs[0].dtype() == mshadow::kFloat32 ||
+          inputs[0].dtype() == mshadow::kBfloat16);
+}
+
+void MKLDNNBatchDotForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
+                           const std::vector<NDArray> &inputs,
+                           const std::vector<OpReqType> &req,
+                           const std::vector<NDArray> &outputs) {
+  const DotParam &param = nnvm::get<DotParam>(attrs.parsed);
+  MKLDNNBatchDotFwd &fwd = MKLDNNBatchDotFwd::GetCached(param, inputs, outputs);
+  fwd.Execute(inputs, req, outputs);
+}
+
+MKLDNNBatchDotFwd &MKLDNNBatchDotFwd::GetCached(
+    const DotParam &param, const std::vector<NDArray> &inputs,
+    const std::vector<NDArray> &outputs) {
+  using batch_dot_fwd_map =
+      std::unordered_map<BatchDotSignature, MKLDNNBatchDotFwd, OpHash>;
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local batch_dot_fwd_map fwds;
+#else
+  static MX_THREAD_LOCAL batch_dot_fwd_map fwds;
+#endif
+
+  BatchDotSignature key(param);
+  key.AddSign(inputs[0]);
+  key.AddSign(inputs[1]);
+  key.AddSign(outputs[0]);
+
+  auto it = fwds.find(key);
+  if (it == fwds.end()) {
+    const MKLDNNBatchDotFwd fwd(param, inputs, outputs);
+    it = AddToCache(&fwds, key, fwd);
+  }
+  return it->second;
+}
+
+MKLDNNBatchDotFwd::MKLDNNBatchDotFwd(const DotParam &param,
+                                     const std::vector<NDArray> &inputs,
+                                     const std::vector<NDArray> &outputs) {
+  auto shape = inputs[0].shape();
+  auto ndim = shape.ndim();
+  auto bigDim = shape[0];
+  for (int i = 1; i < ndim - 2; ++i) {
+    bigDim *= shape[i];
+  }
+
+  auto GetMemoryDesc = [&ndim, &bigDim](const NDArray &tensor,
+                                        const bool transpose) {
+    auto shape = tensor.shape();
+    if (transpose) {
+      return mkldnn::memory::desc(
+          mkldnn::memory::dims{bigDim, shape[ndim - 1], shape[ndim - 2]},
+          get_mkldnn_type(tensor.dtype()), mkldnn::memory::format_tag::acb);
+    } else {
+      return mkldnn::memory::desc(
+          mkldnn::memory::dims{bigDim, shape[ndim - 2], shape[ndim - 1]},
+          get_mkldnn_type(tensor.dtype()), mkldnn::memory::format_tag::any);
+    }
+  };
+
+  mkldnn::memory::desc data_md = GetMemoryDesc(inputs[0], param.transpose_a);
+  mkldnn::memory::desc weights_md = GetMemoryDesc(inputs[1], param.transpose_b);
+  mkldnn::memory::desc out_md({bigDim, data_md.dims()[1], weights_md.dims()[2]},
+                              get_mkldnn_type(outputs[0].dtype()),
+                              mkldnn::memory::format_tag::any);
+  mkldnn::matmul::desc fwd_desc(data_md, weights_md, out_md);
+  fwd_pd = std::make_shared<batch_dot_fwd_pd_t>(
+      fwd_desc, mxnet::CpuEngine::Get()->get_engine());
+  fwd = std::make_shared<batch_dot_fwd_t>(*fwd_pd);
+}
+
+void MKLDNNBatchDotFwd::Execute(const std::vector<NDArray> &inputs,
+                                const std::vector<OpReqType> &req,
+                                const std::vector<NDArray> &outputs) {
+  auto engine = mxnet::CpuEngine::Get()->get_engine();
+  auto data = mkldnn::memory(fwd_pd->src_desc(), engine,
+                             reinterpret_cast<void *>(inputs[0].data().dptr_));
+  auto weights =
+      mkldnn::memory(fwd_pd->weights_desc(), engine,
+                     reinterpret_cast<void *>(inputs[1].data().dptr_));
+  mkldnn_output_t out_mem =
+      CreateMKLDNNMem(outputs[0], fwd_pd->dst_desc(), req[0], &inputs[0]);
+
+  mkldnn_args_map_t args = {
+      {MKLDNN_ARG_SRC, data},
+      {MKLDNN_ARG_WEIGHTS, weights},
+      {MKLDNN_ARG_DST, *out_mem.second},
+  };
+
+  MKLDNNStream::Get()->RegisterPrimArgs(*fwd, args);
+  CommitOutput(outputs[0], out_mem);
+  MKLDNNStream::Get()->Submit();
+}
+
+}  // namespace op
+}  // namespace mxnet
+#endif  // MXNET_USE_MKLDNN == 1
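Two details in the constructor above do the heavy lifting: format_tag::acb describes a transposed operand by swapping the strides of the last two dimensions, so transpose_a/transpose_b cost no data copy, while format_tag::any leaves the layout choice to oneDNN. A self-contained sketch of the same primitive flow, assuming the oneDNN v1/v2 compatibility header that this tree's mkldnn:: aliases wrap (hypothetical sizes; buffers left uninitialized for brevity):

    #include <mkldnn.hpp>

    int main() {
      using namespace mkldnn;
      const memory::dim batch = 2, m = 3, k = 4, n = 5;
      engine eng(engine::kind::cpu, 0);
      stream strm(eng);

      // Plain 3-D row-major descriptors; abc == (batch, rows, cols).
      memory::desc a_md({batch, m, k}, memory::data_type::f32, memory::format_tag::abc);
      memory::desc b_md({batch, k, n}, memory::data_type::f32, memory::format_tag::abc);
      memory::desc c_md({batch, m, n}, memory::data_type::f32, memory::format_tag::abc);

      // Same create-and-execute sequence as the commit: desc -> pd -> primitive.
      matmul::desc desc(a_md, b_md, c_md);
      matmul::primitive_desc pd(desc, eng);
      matmul prim(pd);

      memory a_mem(a_md, eng), b_mem(b_md, eng), c_mem(c_md, eng);
      prim.execute(strm, {{MKLDNN_ARG_SRC, a_mem},
                          {MKLDNN_ARG_WEIGHTS, b_mem},
                          {MKLDNN_ARG_DST, c_mem}});
      strm.wait();
      return 0;
    }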
diff --git a/src/operator/nn/mkldnn/mkldnn_ops-inl.h b/src/operator/nn/mkldnn/mkldnn_ops-inl.h
index 32f2e9f..27cbabb 100644
--- a/src/operator/nn/mkldnn/mkldnn_ops-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_ops-inl.h
@@ -133,6 +133,12 @@ void MKLDNNConcatBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                           const std::vector<OpReqType>& req,
                           const std::vector<NDArray>& outputs);
 
+/* For batch dot */
+void MKLDNNBatchDotForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
+                           const std::vector<NDArray> &inputs,
+                           const std::vector<OpReqType> &req,
+                           const std::vector<NDArray> &outputs);
+
 void MKLDNNSum(const mkldnn::memory &arr1, const mkldnn::memory &arr2,
                const mkldnn::memory &out);
 
diff --git a/src/operator/tensor/dot-inl.h b/src/operator/tensor/dot-inl.h
index 8405404..2d6879f 100644
--- a/src/operator/tensor/dot-inl.h
+++ b/src/operator/tensor/dot-inl.h
@@ -64,6 +64,11 @@ struct DotParam : public dmlc::Parameter<DotParam> {
       .add_enum("csr", kCSRStorage)
       .set_default(dmlc::optional<int>());
   }
+  bool operator==(const DotParam& other) const {
+    return this->transpose_a == other.transpose_a &&
+           this->transpose_b == other.transpose_b &&
+           this->forward_stype == other.forward_stype;
+  }
 };
 
 template<typename xpu>
@@ -1454,5 +1459,16 @@ inline bool BatchDotShape(const nnvm::NodeAttrs& attrs,
 
 }  // namespace op
 }  // namespace mxnet
-
+namespace std {
+template<>
+struct hash<mxnet::op::DotParam> {
+  size_t operator()(const mxnet::op::DotParam& val) const {
+    size_t ret = 0;
+    ret = dmlc::HashCombine(ret, val.transpose_a);
+    ret = dmlc::HashCombine(ret, val.transpose_b);
+    ret = dmlc::HashCombine(ret, val.forward_stype);
+    return ret;
+  }
+};
+}  // namespace std
 #endif  // MXNET_OPERATOR_TENSOR_DOT_INL_H_
diff --git a/src/operator/tensor/dot.cc b/src/operator/tensor/dot.cc
index b3f6331..93f5813 100644
--- a/src/operator/tensor/dot.cc
+++ b/src/operator/tensor/dot.cc
@@ -23,6 +23,10 @@
  */
 
 #include "./dot-inl.h"
+#if MXNET_USE_MKLDNN == 1
+#include "./../nn/mkldnn/mkldnn_base-inl.h"
+#include "./../nn/mkldnn/mkldnn_ops-inl.h"
+#endif  // MXNET_USE_MKLDNN
 
 namespace mxnet {
 namespace op {
@@ -111,6 +115,34 @@ NNVM_REGISTER_OP(_backward_dot)
 .set_attr<FComputeEx>("FComputeEx<cpu>", DotBackwardEx<cpu>)
 .add_arguments(DotParam::__FIELDS__());
 
+#if MXNET_USE_MKLDNN == 1
+static void BatchDotComputeExCPU(const nnvm::NodeAttrs& attrs,
+                                 const OpContext& ctx,
+                                 const std::vector<NDArray>& inputs,
+                                 const std::vector<OpReqType>& req,
+                                 const std::vector<NDArray>& outputs) {
+  if (SupportMKLDNNBatchDot(inputs, outputs[0])) {
+    MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
+    MKLDNNRun(MKLDNNBatchDotForward, attrs, ctx, inputs, req, outputs);
+    MKLDNN_OPCHECK_RUN(BatchDotForward_<cpu>, attrs, ctx, inputs, req, outputs);
+    return;
+  }
+  FallBackCompute(BatchDotForward_<cpu>, attrs, ctx, inputs, req, outputs);
+}
+
+static bool BatchDotStorageType(const nnvm::NodeAttrs& attrs,
+                                const int dev_mask,
+                                DispatchMode* dispatch_mode,
+                                std::vector<int>* in_attrs,
+                                std::vector<int>* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 2);
+  CHECK_EQ(out_attrs->size(), 1);
+
+  return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs,
+                           out_attrs);
+}
+#endif
+
 NNVM_REGISTER_OP(batch_dot)
 .add_alias("_npx_batch_dot")
 .describe(R"doc(Batchwise dot product.
@@ -140,6 +172,11 @@ which is computed by::
   })
 .set_attr<THasDeterministicOutput>("THasDeterministicOutput", true)
 .set_attr<FCompute>("FCompute<cpu>", BatchDotForward_<cpu>)
+#if MXNET_USE_MKLDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
+.set_attr<FInferStorageType>("FInferStorageType", BatchDotStorageType)
+.set_attr<FComputeEx>("FComputeEx<cpu>", BatchDotComputeExCPU)
+#endif
 .set_attr<nnvm::FGradient>("FGradient",
     [](const nnvm::ObjectPtr& n,
        const std::vector<nnvm::NodeEntry>& ograds) {
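Taken together, the dot.cc changes register a storage-type inference hook and an FComputeEx entry point that tries the oneDNN kernel first and falls back to the generic implementation whenever SupportMKLDNNBatchDot rejects the inputs (empty tensors, or dtypes other than float32/bfloat16). A sketch of that dispatch shape (hypothetical stand-ins, not part of the commit):

    #include <iostream>

    // Hypothetical stand-ins for the real support check and kernels.
    bool SupportsFastPath(bool f32_or_bf16, bool non_empty) {
      return f32_or_bf16 && non_empty;  // mirrors SupportMKLDNNBatchDot's checks
    }
    void FastKernel()    { std::cout << "oneDNN matmul path\n"; }
    void GenericKernel() { std::cout << "generic fallback path\n"; }

    void BatchDotDispatch(bool f32_or_bf16, bool non_empty) {
      if (SupportsFastPath(f32_or_bf16, non_empty)) {
        FastKernel();                   // accelerated path
        return;
      }
      GenericKernel();                  // always-correct fallback
    }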