You are viewing a plain-text version of this content; the hyperlink to its canonical version ("here") was not preserved in this copy.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2018/08/26 05:33:40 UTC
[incubator-mxnet] branch master updated: MKLDNN Forward
FullyConnected op cache (#11611)
This is an automated email from the ASF dual-hosted git repository.
haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 7230bb9 MKLDNN Forward FullyConnected op cache (#11611)
7230bb9 is described below
commit 7230bb9b5f2f8caabf7bb64689e49ef6a5529b66
Author: zhiyuan-huang <hu...@163.com>
AuthorDate: Sun Aug 26 13:33:23 2018 +0800
MKLDNN Forward FullyConnected op cache (#11611)
* Enable primitive allocation cache for FullyConnected
* Enable primitive allocation cache for FullyConnected
* fix indent and pass in_data as last argument for CreateMKLDNNMem
* fix indent and pass in_data as last argument for CreateMKLDNNMem
---
src/operator/nn/fully_connected-inl.h | 17 ++++
src/operator/nn/mkldnn/mkldnn_fully_connected.cc | 118 ++++++++++++++++++++---
2 files changed, 123 insertions(+), 12 deletions(-)
diff --git a/src/operator/nn/fully_connected-inl.h b/src/operator/nn/fully_connected-inl.h
index 2338f89..2b75419 100644
--- a/src/operator/nn/fully_connected-inl.h
+++ b/src/operator/nn/fully_connected-inl.h
@@ -61,6 +61,11 @@ struct FullyConnectedParam : public dmlc::Parameter<FullyConnectedParam> {
DMLC_DECLARE_FIELD(flatten).set_default(true)
.describe("Whether to collapse all but the first axis of the input data tensor.");
}
+  /*!
+   * \brief Field-wise equality on num_hidden, no_bias and flatten.
+   *
+   * These are the same three fields combined by the std::hash specialization
+   * for FullyConnectedParam later in this header; keep the two in sync so that
+   * equal params always hash equally.
+   */
+  bool operator==(const FullyConnectedParam& other) const {
+    if (num_hidden != other.num_hidden) return false;
+    if (no_bias != other.no_bias) return false;
+    return flatten == other.flatten;
+  }
};
template<typename xpu, typename DType>
@@ -228,4 +233,16 @@ void FullyConnectedGradCompute(const nnvm::NodeAttrs& attrs,
} // namespace op
} // namespace mxnet
+namespace std {
+/*!
+ * \brief std::hash specialization for FullyConnectedParam.
+ *
+ * Combines num_hidden, no_bias and flatten, the same fields that
+ * FullyConnectedParam::operator== compares, so equal params hash equally.
+ *
+ * operator() is const: the standard Hash requirements allow the hasher
+ * object to be const (e.g. when invoked from inside std::unordered_map),
+ * so a non-const call operator is not a conforming hash.
+ */
+template<>
+struct hash<mxnet::op::FullyConnectedParam> {
+  size_t operator()(const mxnet::op::FullyConnectedParam& val) const {
+    size_t ret = 0;
+    ret = dmlc::HashCombine(ret, val.num_hidden);
+    ret = dmlc::HashCombine(ret, val.no_bias);
+    ret = dmlc::HashCombine(ret, val.flatten);
+    return ret;
+  }
+};
+}  // namespace std
#endif // MXNET_OPERATOR_NN_FULLY_CONNECTED_INL_H_
diff --git a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
index f86f8db..5f672cd 100644
--- a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
+++ b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
@@ -82,6 +82,100 @@ inline static mkldnn::inner_product_backward_weights::primitive_desc GetIPBwdWei
}
}
+/*!
+ * \brief Cached MKLDNN inner-product (FullyConnected) forward primitive.
+ *
+ * The primitive descriptor is built once in the constructor. The memory
+ * objects and the inner_product_forward primitive are created lazily on the
+ * first SetNewMem() call; later calls only re-point the existing memory
+ * objects at new data handles, so the same primitive object can be
+ * registered again on every forward pass.
+ */
+class MKLDNNFullyConnectForward {
+  std::shared_ptr<mkldnn::memory> data;
+  std::shared_ptr<mkldnn::memory> weight;
+  std::shared_ptr<mkldnn::memory> out;
+  std::shared_ptr<mkldnn::memory> bias;
+  std::shared_ptr<mkldnn::inner_product_forward> ipFwd;
+
+ public:
+  mkldnn::inner_product_forward::primitive_desc ipFwd_pd;
+
+  /*!
+   * \param param    operator parameters (not used directly here).
+   * \param is_train forwarded to GetIPFwd when building the descriptor.
+   * \param data     input array.
+   * \param weight   weight array.
+   * \param bias     bias array, or nullptr when the layer has no bias.
+   * \param output   requested output memory descriptor.
+   */
+  MKLDNNFullyConnectForward(const FullyConnectedParam &param, bool is_train,
+                            const NDArray &data, const NDArray &weight,
+                            const NDArray *bias,
+                            const mkldnn::memory::desc &output)
+      : ipFwd_pd(GetIPFwd(data, weight, bias, output, is_train)) {}
+
+  /*!
+   * \brief Bind the cached primitive to a new set of buffers.
+   *
+   * Only the data handles of the given memories are used; they are wrapped
+   * with (or copied into) memory objects built from ipFwd_pd's primitive
+   * descriptors. NOTE(review): the caller is assumed to have already
+   * reordered data/weight/bias into those formats (MKLDNNFCForward does).
+   *
+   * \param data   source memory.
+   * \param weight weight memory.
+   * \param bias   bias memory, or nullptr when the layer has no bias.
+   * \param output destination memory.
+   */
+  void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &weight,
+                 const mkldnn::memory *bias, const mkldnn::memory &output) {
+    if (this->data == nullptr)
+      this->data = std::make_shared<mkldnn::memory>(
+          ipFwd_pd.src_primitive_desc(), data.get_data_handle());
+    else
+      this->data->set_data_handle(data.get_data_handle());
+
+    if (this->weight == nullptr)
+      this->weight = std::make_shared<mkldnn::memory>(
+          ipFwd_pd.weights_primitive_desc(), weight.get_data_handle());
+    else
+      this->weight->set_data_handle(weight.get_data_handle());
+
+    if (this->out == nullptr)
+      this->out = std::make_shared<mkldnn::memory>(
+          ipFwd_pd.dst_primitive_desc(), output.get_data_handle());
+    else
+      this->out->set_data_handle(output.get_data_handle());
+
+    if (bias != nullptr) {
+      if (this->bias == nullptr)
+        this->bias = std::make_shared<mkldnn::memory>(
+            ipFwd_pd.bias_primitive_desc(), bias->get_data_handle());
+      else
+        this->bias->set_data_handle(bias->get_data_handle());
+    }
+
+    // The primitive is built once against the memory objects above; the
+    // caching scheme relies on later set_data_handle() calls being observed
+    // by that same primitive.
+    if (this->ipFwd == nullptr) {
+      if (this->bias != nullptr)
+        this->ipFwd = std::make_shared<mkldnn::inner_product_forward>(
+            ipFwd_pd, mkldnn::primitive::at(*this->data),
+            mkldnn::primitive::at(*this->weight),
+            mkldnn::primitive::at(*this->bias), *this->out);
+      else
+        this->ipFwd = std::make_shared<mkldnn::inner_product_forward>(
+            ipFwd_pd, mkldnn::primitive::at(*this->data),
+            mkldnn::primitive::at(*this->weight), *this->out);
+    }
+  }
+
+  /*!
+   * \brief The cached forward primitive.
+   * Only valid after SetNewMem() has been called at least once.
+   */
+  const mkldnn::inner_product_forward &GetIpFwd() const {
+    CHECK(ipFwd != nullptr) << "SetNewMem must be called before GetIpFwd";
+    return *ipFwd;
+  }
+};
+
+typedef ParamOpSign<FullyConnectedParam> MKLDNNFullyconSignature;
+
+/*!
+ * \brief Return the cached forward primitive for this layer, creating and
+ *        caching it on first use.
+ *
+ * The cache is a per-thread unordered_map keyed on the operator parameters,
+ * the signatures of the data, weight and (if present) bias arrays, and the
+ * is_train flag. The returned reference stays valid for the lifetime of the
+ * thread-local map, because unordered_map does not relocate elements on rehash.
+ *
+ * NOTE(review): `output` is not part of the key; this assumes the output
+ * descriptor is fully determined by the keyed inputs. Confirm if the output
+ * desc can vary independently.
+ */
+static inline MKLDNNFullyConnectForward &GetFCFwd(
+    const nnvm::NodeAttrs &attrs, const NDArray &data, const NDArray &weight,
+    const NDArray *bias, const mkldnn::memory::desc &output,
+    const bool is_train) {
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local std::unordered_map<MKLDNNFullyconSignature,
+              MKLDNNFullyConnectForward, OpHash> fcFwds;
+#else
+  static MX_THREAD_LOCAL std::unordered_map<MKLDNNFullyconSignature,
+              MKLDNNFullyConnectForward, OpHash> fcFwds;
+#endif
+  const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
+  MKLDNNFullyconSignature key(param);
+  key.AddSign(data);
+  key.AddSign(weight);
+  key.AddSign(is_train);
+  // Only a layer that actually has a bias contributes it to the key.
+  if (bias != nullptr)
+    key.AddSign(*bias);
+
+  auto it = fcFwds.find(key);
+  if (it == fcFwds.end()) {
+    MKLDNNFullyConnectForward fcFwd(param, is_train, data, weight, bias,
+                                    output);
+    auto ins_ret = fcFwds.emplace(key, fcFwd);
+    CHECK(ins_ret.second);
+    it = ins_ret.first;
+  }
+  return it->second;
+}
+
void MKLDNNFCForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const std::vector<NDArray> &in_data,
const std::vector<OpReqType> &req,
@@ -112,21 +206,21 @@ void MKLDNNFCForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
out_md = mkldnn::memory::desc(out_dims, get_mkldnn_type(out_data[fullc::kOut].dtype()),
mkldnn::memory::format::any);
}
-
- mkldnn::inner_product_forward::primitive_desc ipFwd_pd = GetIPFwd(data, weight,
- param.no_bias ? nullptr : &in_data[fullc::kBias], out_md, ctx.is_train);
- auto data_mem = data.GetMKLDNNDataReorder(ipFwd_pd.src_primitive_desc());
- auto weight_mem = weight.GetMKLDNNDataReorder(ipFwd_pd.weights_primitive_desc());
+ MKLDNNFullyConnectForward &FCFwd =
+ GetFCFwd(attrs, data, weight, param.no_bias ? nullptr : &in_data[fullc::kBias],
+ out_md, ctx.is_train);
+ auto data_mem = data.GetMKLDNNDataReorder(FCFwd.ipFwd_pd.src_primitive_desc());
+ auto weight_mem = weight.GetMKLDNNDataReorder(FCFwd.ipFwd_pd.weights_primitive_desc());
auto out_mem = CreateMKLDNNMem(out_data[fullc::kOut],
- ipFwd_pd.dst_primitive_desc(), req[fullc::kOut]);
- if (param.no_bias) {
- MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_forward(
- ipFwd_pd, *data_mem, *weight_mem, *out_mem.second));
+ FCFwd.ipFwd_pd.dst_primitive_desc(), req[fullc::kOut], &data);
+ if (!param.no_bias) {
+ auto bias_mem = in_data[fullc::kBias].GetMKLDNNDataReorder(
+ FCFwd.ipFwd_pd.bias_primitive_desc());
+ FCFwd.SetNewMem(*data_mem, *weight_mem, bias_mem, *out_mem.second);
} else {
- auto bias_mem = in_data[fullc::kBias].GetMKLDNNDataReorder(ipFwd_pd.bias_primitive_desc());
- MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_forward(ipFwd_pd,
- *data_mem, *weight_mem, *bias_mem, *out_mem.second));
+ FCFwd.SetNewMem(*data_mem, *weight_mem, nullptr, *out_mem.second);
}
+ MKLDNNStream::Get()->RegisterPrim(FCFwd.GetIpFwd());
CommitOutput(out_data[fullc::kOut], out_mem);
MKLDNNStream::Get()->Submit();
}