You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2021/07/13 13:51:33 UTC
[incubator-mxnet] branch v1.x updated: [1.x][backport] Integrate
matmul primitive from oneDNN in batch dot (#20382)
This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.x by this push:
new 7e835a5 [1.x][backport] Integrate matmul primitive from oneDNN in batch dot (#20382)
7e835a5 is described below
commit 7e835a592da902066f2ccb0d74a78adc8e77fdb7
Author: bartekkuncer <ba...@intel.com>
AuthorDate: Tue Jul 13 15:50:01 2021 +0200
[1.x][backport] Integrate matmul primitive from oneDNN in batch dot (#20382)
---
src/operator/nn/mkldnn/mkldnn_base-inl.h | 1 +
src/operator/nn/mkldnn/mkldnn_batch_dot-inl.h | 65 +++++++++++++
src/operator/nn/mkldnn/mkldnn_batch_dot.cc | 132 ++++++++++++++++++++++++++
src/operator/nn/mkldnn/mkldnn_ops-inl.h | 6 ++
src/operator/tensor/dot-inl.h | 18 +++-
src/operator/tensor/dot.cc | 37 ++++++++
6 files changed, 258 insertions(+), 1 deletion(-)
diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h
index 3e73103..cb30b0b 100644
--- a/src/operator/nn/mkldnn/mkldnn_base-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -215,6 +215,7 @@ bool SupportMKLDNNLogSoftmax(const SoftmaxParam& param, const NDArray &input,
const NDArray &output);
bool SupportMKLDNNSoftmaxOutput(const SoftmaxOutputParam &param);
bool SupportMKLDNNTranspose(const TransposeParam& param, const NDArray &data);
+bool SupportMKLDNNBatchDot(const std::vector<NDArray> &inputs, const NDArray &output);
} // namespace op
static int GetTypeSize(int dtype) {
diff --git a/src/operator/nn/mkldnn/mkldnn_batch_dot-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_dot-inl.h
new file mode 100644
index 0000000..02bd9ad
--- /dev/null
+++ b/src/operator/nn/mkldnn/mkldnn_batch_dot-inl.h
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file mkldnn_batch_dot-inl.h
+ */
+
+#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_BATCH_DOT_INL_H_
+#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_BATCH_DOT_INL_H_
+
+#if MXNET_USE_MKLDNN == 1
+
+#include <numeric>
+#include <utility>
+#include <vector>
+#include "../../tensor/dot-inl.h"
+#include "./mkldnn_base-inl.h"
+#include "./mkldnn_ops-inl.h"
+
+namespace mxnet {
+namespace op {
+
+// oneDNN matmul primitive and its primitive descriptor, as used by batch_dot.
+using batch_dot_fwd_t    = mkldnn::matmul;
+using batch_dot_fwd_pd_t = batch_dot_fwd_t::primitive_desc;
+
+// Key type for the primitive cache: a DotParam-based operator signature.
+using BatchDotSignature = ParamOpSign<DotParam>;
+
+/*!
+ * \brief Holder for the oneDNN matmul primitive used by the batch_dot
+ *        forward pass, together with its primitive descriptor.
+ */
+class MKLDNNBatchDotFwd {
+ public:
+  /*!
+   * \brief Fetch the instance matching the given parameters and arrays,
+   *        creating and caching it when none exists yet.
+   * \param param   batch_dot parameters
+   * \param inputs  operator inputs
+   * \param outputs operator outputs
+   * \return reference to the cached instance
+   */
+  static MKLDNNBatchDotFwd &GetCached(const DotParam &param,
+                                      const std::vector<NDArray> &inputs,
+                                      const std::vector<NDArray> &outputs);
+
+  /*!
+   * \brief Build the matmul primitive descriptor and primitive for the
+   *        given parameters and arrays.
+   */
+  MKLDNNBatchDotFwd(const DotParam &param, const std::vector<NDArray> &inputs,
+                    const std::vector<NDArray> &outputs);
+
+  /*!
+   * \brief Run the primitive on the given arrays.
+   * \param inputs  operator inputs (two arrays are read)
+   * \param req     write request; req[0] applies to outputs[0]
+   * \param outputs operator outputs (outputs[0] is written)
+   */
+  void Execute(const std::vector<NDArray> &inputs,
+               const std::vector<OpReqType> &req,
+               const std::vector<NDArray> &outputs);
+
+ private:
+  std::shared_ptr<batch_dot_fwd_t> fwd;        // the matmul primitive
+  std::shared_ptr<batch_dot_fwd_pd_t> fwd_pd;  // descriptor `fwd` was built from
+};
+
+} // namespace op
+} // namespace mxnet
+#endif // MXNET_USE_MKLDNN == 1
+#endif  // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_BATCH_DOT_INL_H_
diff --git a/src/operator/nn/mkldnn/mkldnn_batch_dot.cc b/src/operator/nn/mkldnn/mkldnn_batch_dot.cc
new file mode 100644
index 0000000..1a5006d
--- /dev/null
+++ b/src/operator/nn/mkldnn/mkldnn_batch_dot.cc
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file mkldnn_batch_dot.cc
+ */
+
+#if MXNET_USE_MKLDNN == 1
+
+#include "./mkldnn_batch_dot-inl.h"
+
+namespace mxnet {
+namespace op {
+
+/*!
+ * \brief Decide whether batch_dot can be handled by the oneDNN path.
+ *
+ * Requires both inputs and the output to be non-empty and the first input
+ * to be float32 or bfloat16.
+ */
+bool SupportMKLDNNBatchDot(const std::vector<NDArray> &inputs,
+                           const NDArray &output) {
+  // Any zero-sized tensor rules the oneDNN path out.
+  if (inputs[0].shape().Size() == 0 || inputs[1].shape().Size() == 0 ||
+      output.shape().Size() == 0) {
+    return false;
+  }
+  // Only the dtype of the first input is inspected.
+  const auto lhs_dtype = inputs[0].dtype();
+  return lhs_dtype == mshadow::kFloat32 || lhs_dtype == mshadow::kBfloat16;
+}
+
+/*!
+ * \brief oneDNN forward pass of batch_dot.
+ *
+ * Reads the DotParam parsed into attrs, obtains the matching cached
+ * MKLDNNBatchDotFwd and executes it.
+ *
+ * \param attrs   node attributes; attrs.parsed must hold a DotParam
+ * \param ctx     operator context (not used here)
+ * \param inputs  operator inputs
+ * \param req     write requests for the outputs
+ * \param outputs operator outputs
+ */
+void MKLDNNBatchDotForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
+                           const std::vector<NDArray> &inputs,
+                           const std::vector<OpReqType> &req,
+                           const std::vector<NDArray> &outputs) {
+  const DotParam &param = nnvm::get<DotParam>(attrs.parsed);
+  MKLDNNBatchDotFwd &fwd = MKLDNNBatchDotFwd::GetCached(param, inputs, outputs);
+  fwd.Execute(inputs, req, outputs);
+}
+
+/*!
+ * \brief Return the cached forward object for this parameter/array
+ *        combination, constructing and caching it on first use.
+ *
+ * The cache is a thread-local map keyed by a signature built from param,
+ * both inputs and the first output.
+ *
+ * \param param   batch_dot parameters
+ * \param inputs  operator inputs; inputs[0] and inputs[1] enter the key
+ * \param outputs operator outputs; outputs[0] enters the key
+ * \return reference to the cached MKLDNNBatchDotFwd
+ */
+MKLDNNBatchDotFwd &MKLDNNBatchDotFwd::GetCached(
+    const DotParam &param, const std::vector<NDArray> &inputs,
+    const std::vector<NDArray> &outputs) {
+  using batch_dot_fwd_map =
+      std::unordered_map<BatchDotSignature, MKLDNNBatchDotFwd, OpHash>;
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local batch_dot_fwd_map fwds;
+#else
+  static MX_THREAD_LOCAL batch_dot_fwd_map fwds;
+#endif
+
+  BatchDotSignature key(param);
+  key.AddSign(inputs[0]);
+  key.AddSign(inputs[1]);
+  key.AddSign(outputs[0]);
+
+  auto it = fwds.find(key);
+  if (it == fwds.end()) {
+    // Cache miss: build the primitive once and store it under this key.
+    const MKLDNNBatchDotFwd fwd(param, inputs, outputs);
+    it = AddToCache(&fwds, key, fwd);
+  }
+  return it->second;
+}
+
+/*!
+ * \brief Create the matmul primitive descriptor and primitive.
+ *
+ * All dimensions of inputs[0] except the last two are folded into a single
+ * batch dimension, so every memory descriptor built here is 3-D.
+ *
+ * \param param   batch_dot parameters; transpose_a / transpose_b are read
+ * \param inputs  operator inputs; inputs[0] and inputs[1] are described
+ * \param outputs operator outputs; only the dtype of outputs[0] is read
+ */
+MKLDNNBatchDotFwd::MKLDNNBatchDotFwd(const DotParam &param,
+                                     const std::vector<NDArray> &inputs,
+                                     const std::vector<NDArray> &outputs) {
+  auto shape = inputs[0].shape();
+  auto ndim = shape.ndim();
+  // Product of all leading (batch) dimensions of the first input.
+  // NOTE(review): assumes ndim >= 3 - presumably guaranteed by shape
+  // inference for batch_dot; confirm.
+  auto bigDim = shape[0];
+  for (int i = 1; i < ndim - 2; ++i) {
+    bigDim *= shape[i];
+  }
+
+  // Describes `tensor` as a 3-D {batch, rows, cols} matrix stack. The same
+  // ndim and bigDim (taken from inputs[0]) are applied to every tensor.
+  auto GetMemoryDesc = [&ndim, &bigDim](const NDArray &tensor,
+                                        const bool transpose) {
+    auto shape = tensor.shape();
+    if (transpose) {
+      // Logical dims are swapped and the layout tag is acb, i.e. the last
+      // two logical dims are stored in swapped order.
+      return mkldnn::memory::desc(
+          mkldnn::memory::dims{bigDim, shape[ndim - 1], shape[ndim - 2]},
+          get_mkldnn_type(tensor.dtype()), mkldnn::memory::format_tag::acb);
+    } else {
+      // NOTE(review): format_tag::any lets the primitive pick the layout,
+      // while Execute wraps the raw NDArray buffer with that descriptor -
+      // confirm the chosen layout always matches the buffer.
+      return mkldnn::memory::desc(
+          mkldnn::memory::dims{bigDim, shape[ndim - 2], shape[ndim - 1]},
+          get_mkldnn_type(tensor.dtype()), mkldnn::memory::format_tag::any);
+    }
+  };
+
+  mkldnn::memory::desc data_md = GetMemoryDesc(inputs[0], param.transpose_a);
+  mkldnn::memory::desc weights_md = GetMemoryDesc(inputs[1], param.transpose_b);
+  // Output is {batch, rows of lhs, cols of rhs} after any transposition.
+  mkldnn::memory::desc out_md({bigDim, data_md.dims()[1], weights_md.dims()[2]},
+                              get_mkldnn_type(outputs[0].dtype()),
+                              mkldnn::memory::format_tag::any);
+  mkldnn::matmul::desc fwd_desc(data_md, weights_md, out_md);
+  fwd_pd = std::make_shared<batch_dot_fwd_pd_t>(
+      fwd_desc, mxnet::CpuEngine::Get()->get_engine());
+  fwd = std::make_shared<batch_dot_fwd_t>(*fwd_pd);
+}
+
+/*!
+ * \brief Run the matmul primitive on the given arrays.
+ *
+ * oneDNN memory objects are constructed over the data pointers of both
+ * inputs using the descriptors held by the primitive descriptor. The
+ * destination memory comes from CreateMKLDNNMem according to req[0];
+ * CommitOutput is called for outputs[0] before the stream is submitted.
+ *
+ * \param inputs  operator inputs; inputs[0] and inputs[1] are read
+ * \param req     write requests; only req[0] is used
+ * \param outputs operator outputs; outputs[0] receives the result
+ */
+void MKLDNNBatchDotFwd::Execute(const std::vector<NDArray> &inputs,
+                                const std::vector<OpReqType> &req,
+                                const std::vector<NDArray> &outputs) {
+  const auto cpu_engine = mxnet::CpuEngine::Get()->get_engine();
+
+  void *lhs_ptr = reinterpret_cast<void *>(inputs[0].data().dptr_);
+  mkldnn::memory lhs_mem(fwd_pd->src_desc(), cpu_engine, lhs_ptr);
+
+  void *rhs_ptr = reinterpret_cast<void *>(inputs[1].data().dptr_);
+  mkldnn::memory rhs_mem(fwd_pd->weights_desc(), cpu_engine, rhs_ptr);
+
+  mkldnn_output_t dst_mem =
+      CreateMKLDNNMem(outputs[0], fwd_pd->dst_desc(), req[0], &inputs[0]);
+
+  mkldnn_args_map_t prim_args = {
+      {MKLDNN_ARG_SRC, lhs_mem},
+      {MKLDNN_ARG_WEIGHTS, rhs_mem},
+      {MKLDNN_ARG_DST, *dst_mem.second},
+  };
+
+  // Register the primitive, commit the output, then submit the stream.
+  auto *stream = MKLDNNStream::Get();
+  stream->RegisterPrimArgs(*fwd, prim_args);
+  CommitOutput(outputs[0], dst_mem);
+  stream->Submit();
+}
+
+} // namespace op
+} // namespace mxnet
+#endif // MXNET_USE_MKLDNN == 1
diff --git a/src/operator/nn/mkldnn/mkldnn_ops-inl.h b/src/operator/nn/mkldnn/mkldnn_ops-inl.h
index 32f2e9f..27cbabb 100644
--- a/src/operator/nn/mkldnn/mkldnn_ops-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_ops-inl.h
@@ -133,6 +133,12 @@ void MKLDNNConcatBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs);
+/* For batch dot: forward pass implemented with the oneDNN matmul primitive */
+void MKLDNNBatchDotForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
+ const std::vector<NDArray> &inputs,
+ const std::vector<OpReqType> &req,
+ const std::vector<NDArray> &outputs);
+
void MKLDNNSum(const mkldnn::memory &arr1, const mkldnn::memory &arr2,
const mkldnn::memory &out);
diff --git a/src/operator/tensor/dot-inl.h b/src/operator/tensor/dot-inl.h
index 8405404..2d6879f 100644
--- a/src/operator/tensor/dot-inl.h
+++ b/src/operator/tensor/dot-inl.h
@@ -64,6 +64,11 @@ struct DotParam : public dmlc::Parameter<DotParam> {
.add_enum("csr", kCSRStorage)
.set_default(dmlc::optional<int>());
}
+  /*! \brief Two DotParam objects are equal when both transpose flags and
+   *         forward_stype match. */
+  bool operator==(const DotParam& other) const {
+    if (!(this->transpose_a == other.transpose_a)) return false;
+    if (!(this->transpose_b == other.transpose_b)) return false;
+    return this->forward_stype == other.forward_stype;
+  }
};
template<typename xpu>
@@ -1454,5 +1459,16 @@ inline bool BatchDotShape(const nnvm::NodeAttrs& attrs,
} // namespace op
} // namespace mxnet
-
+namespace std {
+/*!
+ * \brief Hash for DotParam, combining the same three fields that
+ *        DotParam::operator== compares.
+ */
+template<>
+struct hash<mxnet::op::DotParam> {
+  // Declared const so the functor is callable through a const hasher
+  // object, as the standard Hash requirements expect.
+  size_t operator()(const mxnet::op::DotParam& val) const {
+    size_t ret = 0;
+    ret = dmlc::HashCombine(ret, val.transpose_a);
+    ret = dmlc::HashCombine(ret, val.transpose_b);
+    ret = dmlc::HashCombine(ret, val.forward_stype);
+    return ret;
+  }
+};
+}  // namespace std
#endif // MXNET_OPERATOR_TENSOR_DOT_INL_H_
diff --git a/src/operator/tensor/dot.cc b/src/operator/tensor/dot.cc
index b3f6331..93f5813 100644
--- a/src/operator/tensor/dot.cc
+++ b/src/operator/tensor/dot.cc
@@ -23,6 +23,10 @@
*/
#include "./dot-inl.h"
+#if MXNET_USE_MKLDNN == 1
+#include "./../nn/mkldnn/mkldnn_base-inl.h"
+#include "./../nn/mkldnn/mkldnn_ops-inl.h"
+#endif // MXNET_USE_MKLDNN
namespace mxnet {
namespace op {
@@ -111,6 +115,34 @@ NNVM_REGISTER_OP(_backward_dot)
.set_attr<FComputeEx>("FComputeEx<cpu>", DotBackwardEx<cpu>)
.add_arguments(DotParam::__FIELDS__());
+#if MXNET_USE_MKLDNN == 1
+/*!
+ * \brief FComputeEx<cpu> implementation of batch_dot for MKLDNN builds.
+ *
+ * Uses MKLDNNBatchDotForward when SupportMKLDNNBatchDot accepts the arrays;
+ * otherwise falls back to BatchDotForward_<cpu>.
+ */
+static void BatchDotComputeExCPU(const nnvm::NodeAttrs& attrs,
+                                 const OpContext& ctx,
+                                 const std::vector<NDArray>& inputs,
+                                 const std::vector<OpReqType>& req,
+                                 const std::vector<NDArray>& outputs) {
+  if (!SupportMKLDNNBatchDot(inputs, outputs[0])) {
+    FallBackCompute(BatchDotForward_<cpu>, attrs, ctx, inputs, req, outputs);
+    return;
+  }
+
+  // The OPCHECK macros bracket the oneDNN run and are given the default
+  // kernel (presumably a debug cross-check - confirm against the macros).
+  MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
+  MKLDNNRun(MKLDNNBatchDotForward, attrs, ctx, inputs, req, outputs);
+  MKLDNN_OPCHECK_RUN(BatchDotForward_<cpu>, attrs, ctx, inputs, req, outputs);
+}
+
+/*!
+ * \brief FInferStorageType implementation of batch_dot for MKLDNN builds.
+ *
+ * Checks that there are exactly two inputs and one output, then delegates
+ * to MKLDNNStorageType.
+ */
+static bool BatchDotStorageType(const nnvm::NodeAttrs& attrs,
+                                const int dev_mask,
+                                DispatchMode* dispatch_mode,
+                                std::vector<int>* in_attrs,
+                                std::vector<int>* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 2);
+  CHECK_EQ(out_attrs->size(), 1);
+
+  // Third argument of MKLDNNStorageType is always passed as true here.
+  const bool mkldnn_flag = true;
+  const bool inferred = MKLDNNStorageType(attrs, dev_mask, mkldnn_flag,
+                                          dispatch_mode, in_attrs, out_attrs);
+  return inferred;
+}
+#endif
+
NNVM_REGISTER_OP(batch_dot)
.add_alias("_npx_batch_dot")
.describe(R"doc(Batchwise dot product.
@@ -140,6 +172,11 @@ which is computed by::
})
.set_attr<THasDeterministicOutput>("THasDeterministicOutput", true)
.set_attr<FCompute>("FCompute<cpu>", BatchDotForward_<cpu>)
+#if MXNET_USE_MKLDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
+.set_attr<FInferStorageType>("FInferStorageType", BatchDotStorageType)
+.set_attr<FComputeEx>("FComputeEx<cpu>", BatchDotComputeExCPU)
+#endif
.set_attr<nnvm::FGradient>("FGradient",
[](const nnvm::ObjectPtr& n,
const std::vector<nnvm::NodeEntry>& ograds) {