You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by pa...@apache.org on 2019/09/30 06:08:07 UTC
[incubator-mxnet] branch mkldnn-v1.0 updated: [mkldnn-v1.0] Add MKL-DNN FC (#16221)
This is an automated email from the ASF dual-hosted git repository.
patriczhao pushed a commit to branch mkldnn-v1.0
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/mkldnn-v1.0 by this push:
new 23093e6 [mkldnn-v1.0] Add MKL-DNN FC (#16221)
23093e6 is described below
commit 23093e6f774f6b1ce94007a267c32b2253b3cbcd
Author: rongzha1 <ro...@intel.com>
AuthorDate: Mon Sep 30 14:07:13 2019 +0800
[mkldnn-v1.0] Add MKL-DNN FC (#16221)
* add mkldnn fc; pass lint; pass mnist training
* add TODO info for future debug
---
src/imperative/imperative_utils.h | 12 +-
src/operator/nn/fully_connected.cc | 12 +-
.../nn/mkldnn/mkldnn_fully_connected-inl.h | 15 +--
src/operator/nn/mkldnn/mkldnn_fully_connected.cc | 128 ++++++++-------------
src/operator/nn/mkldnn/mkldnn_ops-inl.h | 20 ++--
5 files changed, 77 insertions(+), 110 deletions(-)
diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h
index bc61bc7..f0c199b 100644
--- a/src/imperative/imperative_utils.h
+++ b/src/imperative/imperative_utils.h
@@ -541,8 +541,18 @@ inline void PushOperator(const OpStatePtr& state,
// copying A to B may not happen, and will corrupt A's memory.
InvalidateOutputs(outputs, req);
}
+ // add for mkldnn OP + no mkldnn OP
+ const auto is_mkldnn = Op::GetAttr<bool>("TIsMKLDNN");
+ if (!is_mkldnn.get(attrs.op, false)) {
+ std::vector<NDArray> inputs_fallback;
+ CreateDefaultInputs(inputs, &inputs_fallback);
+ fcompute_ex(state, opctx, inputs_fallback, req, outputs);
+ } else {
+#endif
+ fcompute_ex(state, opctx, inputs, req, outputs);
+#if MXNET_USE_MKLDNN == 100
+ }
#endif
- fcompute_ex(state, opctx, inputs, req, outputs);
if (ctx.dev_mask() == gpu::kDevMask && exec_type == ExecType::kSync
&& rctx.get_stream<gpu>() && !rctx.is_bulk) {
rctx.get_stream<gpu>()->Wait();
diff --git a/src/operator/nn/fully_connected.cc b/src/operator/nn/fully_connected.cc
index 06ad6d0..c80c08e 100644
--- a/src/operator/nn/fully_connected.cc
+++ b/src/operator/nn/fully_connected.cc
@@ -97,7 +97,7 @@ void FullyConnectedComputeExCPU(const nnvm::NodeAttrs& attrs,
valid_bias = inputs[2].storage_type() == kDefaultStorage ||
inputs[2].storage_type() == kRowSparseStorage;
}
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
if (common::ContainsOnlyStorage(inputs, kDefaultStorage) &&
common::ContainsOnlyStorage(outputs, kDefaultStorage)) {
if (SupportMKLDNNFC(inputs[0])) {
@@ -141,7 +141,7 @@ void FullyConnectedComputeExCPU(const nnvm::NodeAttrs& attrs,
#endif
}
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
void FullyConnectedGradComputeExCPU(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<NDArray> &inputs,
@@ -199,7 +199,7 @@ inline static bool FCStorageType(const nnvm::NodeAttrs& attrs,
dispatched = storage_type_assign(out_attrs, mxnet::kDefaultStorage,
dispatch_mode, DispatchMode::kFComputeEx);
}
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
if (!MKLDNNEnvSet())
*dispatch_mode = DispatchMode::kFComputeFallback;
#endif
@@ -233,7 +233,7 @@ inline static bool BackwardFCStorageType(const nnvm::NodeAttrs& attrs,
dispatched = storage_type_assign(out_attrs, mxnet::kDefaultStorage,
dispatch_mode, DispatchMode::kFCompute);
}
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
if (!MKLDNNEnvSet())
*dispatch_mode = DispatchMode::kFComputeFallback;
#endif
@@ -295,7 +295,7 @@ If ``no_bias`` is set to be true, then the ``bias`` term is ignored.
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"output"};
})
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
@@ -326,7 +326,7 @@ NNVM_REGISTER_OP(_backward_FullyConnected)
})
.set_attr<FInferStorageType>("FInferStorageType", BackwardFCStorageType)
.set_attr_parser(ParamParser<FullyConnectedParam>)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", FullyConnectedGradComputeExCPU)
#endif
diff --git a/src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h b/src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h
index fddaedc..db8cfdc 100644
--- a/src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h
@@ -27,7 +27,7 @@
#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_FULLY_CONNECTED_INL_H_
#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_FULLY_CONNECTED_INL_H_
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
#include <vector>
#include <string>
@@ -50,7 +50,7 @@ struct MKLDNNFCParam: public dmlc::Parameter<MKLDNNFCParam> {
DMLC_DECLARE_FIELD(enable_float_output).set_default(false)
.describe("Whether to enable float32 output");
DMLC_DECLARE_FIELD(with_eltwise).set_default(false)
- .describe("Whether there's a post elemwise after FullyConnected operator");
+ .describe("Whether there's a post with_eltwise after FullyConnected operator");
DMLC_DECLARE_FIELD(min_calib_range)
.set_default(dmlc::optional<float>())
.describe("The minimum scalar value in the form of float32 obtained "
@@ -85,10 +85,9 @@ class MKLDNNFullyConnectedForward {
const NDArray &data, const NDArray &weight,
const NDArray *bias,
const mkldnn::memory::desc &out_md)
- : fwd_pd(GetFCFwdImpl(full_param, is_train, data, weight, bias, out_md)) {}
-
- void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &weight,
- const mkldnn::memory *bias, const mkldnn::memory &output);
+ : fwd_pd(GetFCFwdImpl(full_param, is_train, data, weight, bias, out_md)) {
+ fwd_ = std::make_shared<mkldnn::inner_product_forward>(fwd_pd);
+ }
const mkldnn::inner_product_forward &GetFwd() const {
return *fwd_;
@@ -96,10 +95,6 @@ class MKLDNNFullyConnectedForward {
private:
std::shared_ptr<mkldnn::inner_product_forward> fwd_;
- std::shared_ptr<mkldnn::memory> data_;
- std::shared_ptr<mkldnn::memory> weight_;
- std::shared_ptr<mkldnn::memory> bias_;
- std::shared_ptr<mkldnn::memory> out_;
};
typedef ParamOpSign<FullyConnectedParam> MKLDNNFullyconSignature;
diff --git a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
index fbe37e2..80eb2d6 100644
--- a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
+++ b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
@@ -24,7 +24,7 @@
* \author Da Zheng, Ciyong Chen
*/
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
#include "mkldnn_fully_connected-inl.h"
namespace mxnet {
@@ -67,7 +67,6 @@ mkldnn::inner_product_forward::primitive_desc GetFCFwdImpl(
}
attr.set_output_scales(mask, scales);
- attr.set_int_output_round_mode(round_nearest);
}
}
@@ -130,51 +129,6 @@ inline static mkldnn::inner_product_backward_weights::primitive_desc GetFCBwdWei
}
}
-void MKLDNNFullyConnectedForward::SetNewMem(const mkldnn::memory &data,
- const mkldnn::memory &weight,
- const mkldnn::memory *bias,
- const mkldnn::memory &output) {
- if (this->data_ == nullptr)
- this->data_ = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
- fwd_pd.src_primitive_desc(), data.get_data_handle()));
- else
- this->data_->set_data_handle(data.get_data_handle());
-
- if (this->weight_ == nullptr)
- this->weight_ = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
- fwd_pd.weights_primitive_desc(), weight.get_data_handle()));
- else
- this->weight_->set_data_handle(weight.get_data_handle());
-
- if (this->out_ == nullptr)
- this->out_ = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
- fwd_pd.dst_primitive_desc(), output.get_data_handle()));
- else
- this->out_->set_data_handle(output.get_data_handle());
-
- if (bias != nullptr) {
- if (this->bias_ == nullptr)
- this->bias_ = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
- fwd_pd.bias_primitive_desc(), bias->get_data_handle()));
- else
- this->bias_->set_data_handle(bias->get_data_handle());
-
- if (this->fwd_ == nullptr)
- this->fwd_ = std::shared_ptr<mkldnn::inner_product_forward>(
- new mkldnn::inner_product_forward(
- fwd_pd, mkldnn::primitive::at(*this->data_),
- mkldnn::primitive::at(*this->weight_),
- mkldnn::primitive::at(*this->bias_), *this->out_));
- } else {
- if (this->fwd_ == nullptr) {
- this->fwd_ = std::shared_ptr<mkldnn::inner_product_forward>(
- new mkldnn::inner_product_forward(
- fwd_pd, mkldnn::primitive::at(*this->data_),
- mkldnn::primitive::at(*this->weight_), *this->out_));
- }
- }
-}
-
MKLDNNFullyConnectedForward &GetFCFwd(
const FullyConnectedParam ¶m, const bool is_train,
const NDArray &data, const NDArray &weight,
@@ -223,13 +177,13 @@ void MKLDNNFCFlattenData(const FullyConnectedParam ¶m,
mkldnn::memory::dims out_dims{static_cast<int>(oshape.ProdShape(0, oshape.ndim()-1)),
static_cast<int>(oshape[ishape.ndim()-1])};
*out_md = mkldnn::memory::desc(out_dims, get_mkldnn_type(out_data.dtype()),
- mkldnn::memory::format::any);
+ mkldnn::memory::format_tag::any);
} else {
*in_data = in_data->MKLDNNDataReshape(Shape2(ishape[0], ishape.ProdShape(1, ishape.ndim())));
mkldnn::memory::dims out_dims{static_cast<int>(oshape[0]),
static_cast<int>(oshape.ProdShape(1, oshape.ndim()))};
*out_md = mkldnn::memory::desc(out_dims, get_mkldnn_type(out_data.dtype()),
- mkldnn::memory::format::any);
+ mkldnn::memory::format_tag::any);
}
}
}
@@ -244,35 +198,35 @@ void MKLDNNFCForwardFullFeature(const MKLDNNFCFullParam &full_param,
NDArray weight = in_data[fullc::kWeight];
NDArray data = in_data[fullc::kData];
- auto data_mem = data.GetMKLDNNDataReorder(fwd->fwd_pd.src_primitive_desc());
+ auto data_mem = data.GetMKLDNNDataReorder(fwd->fwd_pd.src_desc());
const mkldnn::memory *weight_mem;
if (ctx.is_train) {
if (weight.IsMKLDNNData()) {
weight.Reorder2DefaultAsync();
}
- weight_mem = GetWeights(weight, fwd->fwd_pd.weights_primitive_desc(), 1);
+ weight_mem = GetWeights(weight, fwd->fwd_pd.weights_desc(), 1);
} else {
- if (weight.IsDefaultData()) {
- // We also need to modify the layout on the original weight array.
- // Don't switch below sequence because naive engine will executes
- // pushAsync synchronously.
- weight.MKLDNNDataReorderAsync(fwd->fwd_pd.weights_primitive_desc());
- weight_mem = GetWeights(weight, fwd->fwd_pd.weights_primitive_desc(), 1);
- } else {
- weight_mem = weight.GetMKLDNNData();
- CHECK(weight_mem->get_primitive_desc() == fwd->fwd_pd.weights_primitive_desc());
+ weight_mem = weight.GetMKLDNNData();
+ if (weight_mem->get_desc() != fwd->fwd_pd.weights_desc()) {
+ // TODO(rongzha1): rm following line for ut:test_contrib_rnn, need debug
+ // weight.MKLDNNDataReorderAsync(fwd->fwd_pd.weights_desc());
+ weight_mem = GetWeights(weight, fwd->fwd_pd.weights_desc(), 1);
}
}
auto out_mem = CreateMKLDNNMem(out_data[fullc::kOut],
- fwd->fwd_pd.dst_primitive_desc(), req[fullc::kOut], &data);
+ fwd->fwd_pd.dst_desc(), req[fullc::kOut], &data);
+
+ std::unordered_map<int, mkldnn::memory> args = {
+ {MKLDNN_ARG_SRC, *data_mem},
+ {MKLDNN_ARG_WEIGHTS, *weight_mem},
+ {MKLDNN_ARG_DST, *out_mem.second},
+ };
if (!full_param.default_param.no_bias) {
auto bias_mem = in_data[fullc::kBias].GetMKLDNNDataReorder(
- fwd->fwd_pd.bias_primitive_desc());
- fwd->SetNewMem(*data_mem, *weight_mem, bias_mem, *out_mem.second);
- } else {
- fwd->SetNewMem(*data_mem, *weight_mem, nullptr, *out_mem.second);
+ fwd->fwd_pd.bias_desc());
+ args.insert({ MKLDNN_ARG_BIAS, *bias_mem});
}
- MKLDNNStream::Get()->RegisterPrim(fwd->GetFwd());
+ MKLDNNStream::Get()->RegisterPrimArgs(fwd->GetFwd(), args);
CommitOutput(out_data[fullc::kOut], out_mem);
MKLDNNStream::Get()->Submit();
}
@@ -339,13 +293,18 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
mkldnn::inner_product_backward_data::primitive_desc ipBwdData_pd = GetFCBwdData(
data, weight, out_grad, fwd_pd);
auto out_grad_mem = out_grad.GetMKLDNNDataReorder(
- ipBwdData_pd.diff_dst_primitive_desc());
- auto weight_mem = weight.GetMKLDNNDataReorder(ipBwdData_pd.weights_primitive_desc());
+ ipBwdData_pd.diff_dst_desc());
+ auto weight_mem = weight.GetMKLDNNDataReorder(ipBwdData_pd.weights_desc());
auto in_grad_mem = CreateMKLDNNMem(in_grad[fullc::kData],
- ipBwdData_pd.diff_src_primitive_desc(),
+ ipBwdData_pd.diff_src_desc(),
req[fullc::kData]);
- MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_backward_data(
- ipBwdData_pd, *out_grad_mem, *weight_mem, *in_grad_mem.second));
+ std::unordered_map<int, mkldnn::memory> args = {
+ {MKLDNN_ARG_DIFF_DST, *out_grad_mem},
+ {MKLDNN_ARG_WEIGHTS, *weight_mem},
+ {MKLDNN_ARG_DIFF_SRC, *in_grad_mem.second}
+ };
+
+ MKLDNNStream::Get()->RegisterPrimArgs(mkldnn::inner_product_backward_data(ipBwdData_pd), args);
CommitOutput(in_grad[fullc::kData], in_grad_mem);
}
if (req[fullc::kWeight]) {
@@ -353,23 +312,26 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
= GetFCBwdWeights(data, weight, param.no_bias ? nullptr : &in_grad[fullc::kBias],
out_grad, fwd_pd);
auto out_grad_mem = out_grad.GetMKLDNNDataReorder(
- ipBwdWeights_pd.diff_dst_primitive_desc());
- auto data_mem = data.GetMKLDNNDataReorder(ipBwdWeights_pd.src_primitive_desc());
+ ipBwdWeights_pd.diff_dst_desc());
+ auto data_mem = data.GetMKLDNNDataReorder(ipBwdWeights_pd.src_desc());
auto in_grad_weight = CreateMKLDNNWeightGrad(in_grad[fullc::kWeight],
- ipBwdWeights_pd.diff_weights_primitive_desc(),
+ ipBwdWeights_pd.diff_weights_desc(),
req[fullc::kWeight]);
+ std::unordered_map<int, mkldnn::memory> args = {
+ {MKLDNN_ARG_DIFF_DST, *out_grad_mem},
+ {MKLDNN_ARG_SRC, *data_mem},
+ {MKLDNN_ARG_DIFF_WEIGHTS, *in_grad_weight.second},
+ };
+
mkldnn_output_t in_grad_bias;
- if (param.no_bias) {
- MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_backward_weights(
- ipBwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight.second));
- } else {
+ if (!param.no_bias) {
in_grad_bias = CreateMKLDNNMem(in_grad[fullc::kBias],
- ipBwdWeights_pd.diff_bias_primitive_desc(),
+ ipBwdWeights_pd.diff_bias_desc(),
req[fullc::kBias]);
- MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_backward_weights(
- ipBwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight.second,
- *in_grad_bias.second));
+ args.insert({MKLDNN_ARG_DIFF_BIAS, *in_grad_bias.second});
}
+ MKLDNNStream::Get()->RegisterPrimArgs(
+ mkldnn::inner_product_backward_weights(ipBwdWeights_pd), args);
CommitOutput(in_grad[fullc::kWeight], in_grad_weight);
CommitOutput(in_grad[fullc::kBias], in_grad_bias);
}
@@ -378,4 +340,4 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
} // namespace op
} // namespace mxnet
-#endif // MXNET_USE_MKLDNN == 1
+#endif // MXNET_USE_MKLDNN == 100
diff --git a/src/operator/nn/mkldnn/mkldnn_ops-inl.h b/src/operator/nn/mkldnn/mkldnn_ops-inl.h
index 951b075..20d80cd 100644
--- a/src/operator/nn/mkldnn/mkldnn_ops-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_ops-inl.h
@@ -44,16 +44,6 @@ namespace mxnet {
namespace op {
#if MXNET_USE_MKLDNN == 1
-/* For fully connected. */
-void MKLDNNFCForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
- const std::vector<NDArray> &in_data,
- const std::vector<OpReqType> &req,
- const std::vector<NDArray> &out_data);
-void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
- const std::vector<NDArray> &inputs,
- const std::vector<OpReqType> &req,
- const std::vector<NDArray> &outputs);
-
/* For deconvolution */
void MKLDNNDeconvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const std::vector<NDArray> &in_data,
@@ -104,6 +94,16 @@ void MKLDNNFlattenForward(const nnvm::NodeAttrs &attrs,
#endif
#if MXNET_USE_MKLDNN == 100
+/* For fully connected. */
+void MKLDNNFCForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
+ const std::vector<NDArray> &in_data,
+ const std::vector<OpReqType> &req,
+ const std::vector<NDArray> &out_data);
+void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
+ const std::vector<NDArray> &inputs,
+ const std::vector<OpReqType> &req,
+ const std::vector<NDArray> &outputs);
+
/* For convolution. */
void MKLDNNConvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const std::vector<NDArray> &in_data,