Posted to commits@mxnet.apache.org by ta...@apache.org on 2019/03/08 04:35:30 UTC
[incubator-mxnet] branch master updated: MKLDNN based Quantized
FullyConnected Operator and its fusion (#14128)
This is an automated email from the ASF dual-hosted git repository.
taolv pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 8668db7 MKLDNN based Quantized FullyConnected Operator and its fusion (#14128)
8668db7 is described below
commit 8668db79ad6e0f75cc848923e323f6c316e18937
Author: ciyong <ci...@intel.com>
AuthorDate: Fri Mar 8 12:35:03 2019 +0800
MKLDNN based Quantized FullyConnected Operator and its fusion (#14128)
* add MKL-DNN quantized innerproduct
* initial qfc with mkldnn
* Add MKL-DNN quantized_fully_connected
* refactor params order for fullyconnected
* update quantized_fully_connected unittest, force data to uint8 type temporarily
* change mkl based quantized fully_connected to FCompute
* add check data type for mkldnn quantized_fc
* add fuse requantize and dequantize for mkldnn quantized fullyconnected
* add env setting to enable/disable fusing requantize/dequantize for quantized fullyconnected
* fix requantize scaling error
* add fallback when input data is int8
* fix mkl quantized fullyconnected index error
* update quantized fc test cases
* add subgraph node for mkldnn fullyconnected
* fix compiling and lint error
* clean and refactor code
* enable quantized_fc for imagenet
* cleanup code
* Fix StorageType error for non-mkldnn path
* fix pylint
* revert BUILD_TAG for MKL IGEMM ut, remove IGEMM qfc check
* rename variables and refactor codes according to comments
* add subgraph qfc tests and fix shape error
* remove fuse_requantize and change fuse_dequantize to enable_float_output.
* change to use mxnet::Tuple and update tests
* update description in file header
* update input0 type check for quantized FullyConnected
* fix conflict in mkl/test_subgraph.py
* retrigger CI
* retrigger CI due to hang
---
example/quantization/imagenet_gen_qsym_mkldnn.py | 9 +-
python/mxnet/initializer.py | 13 +
.../nn/mkldnn/mkldnn_fully_connected-inl.h | 133 +++++++
src/operator/nn/mkldnn/mkldnn_fully_connected.cc | 343 +++++++++-------
.../mkldnn/mkldnn_quantized_fully_connected.cc | 134 +++++++
.../quantization/mkldnn/mkldnn_quantized_ops-inl.h | 48 +++
.../quantization/quantized_fully_connected.cc | 126 ++++--
src/operator/subgraph/mkldnn/mkldnn_fc.cc | 442 +++++++++++++++++++++
.../mkldnn/mkldnn_fc_post_quantize_property.cc | 217 ++++++++++
src/operator/subgraph/mkldnn/mkldnn_fc_property.cc | 193 +++++++++
tests/python/mkl/test_subgraph.py | 175 ++++++--
tests/python/quantization/test_quantization.py | 70 +++-
12 files changed, 1679 insertions(+), 224 deletions(-)
diff --git a/example/quantization/imagenet_gen_qsym_mkldnn.py b/example/quantization/imagenet_gen_qsym_mkldnn.py
index d807e7f..3f644fc 100644
--- a/example/quantization/imagenet_gen_qsym_mkldnn.py
+++ b/example/quantization/imagenet_gen_qsym_mkldnn.py
@@ -180,6 +180,7 @@ if __name__ == '__main__':
sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
sym = sym.get_backend_symbol('MKLDNN')
+ sym = sym.get_backend_symbol('MKLDNN_FC')
# get batch size
batch_size = args.batch_size
@@ -207,19 +208,18 @@ if __name__ == '__main__':
if args.model == 'imagenet1k-resnet-152':
rgb_mean = '0,0,0'
rgb_std = '1,1,1'
- excluded_sym_names += ['flatten0', 'fc1']
+ excluded_sym_names += ['flatten0']
if exclude_first_conv:
excluded_sym_names += ['conv0']
elif args.model == 'imagenet1k-inception-bn':
rgb_mean = '123.68,116.779,103.939'
rgb_std = '1,1,1'
- excluded_sym_names += ['flatten', 'fc1']
+ excluded_sym_names += ['flatten']
if exclude_first_conv:
excluded_sym_names += ['conv_1']
elif args.model in ['resnet50_v1', 'resnet101_v1']:
rgb_mean = '123.68,116.779,103.939'
rgb_std = '58.393, 57.12, 57.375'
- excluded_sym_names += ['resnetv10_dense0_fwd']
if exclude_first_conv:
excluded_sym_names += ['resnetv10_conv0_fwd']
elif args.model == 'squeezenet1.0':
@@ -232,14 +232,12 @@ if __name__ == '__main__':
rgb_mean = '123.68,116.779,103.939'
rgb_std = '58.393, 57.12, 57.375'
excluded_sym_names += ['mobilenet0_flatten0_flatten0',
- 'mobilenet0_dense0_fwd',
'mobilenet0_pool0_fwd']
if exclude_first_conv:
excluded_sym_names += ['mobilenet0_conv0_fwd']
elif args.model == 'inceptionv3':
rgb_mean = '123.68,116.779,103.939'
rgb_std = '58.393, 57.12, 57.375'
- excluded_sym_names += ['inception30_dense0_fwd']
if exclude_first_conv:
excluded_sym_names += ['inception30_conv0_fwd']
elif args.model == 'custom':
@@ -305,6 +303,7 @@ if __name__ == '__main__':
% calib_mode)
sym_name = '%s-symbol.json' % (prefix + suffix)
qsym = qsym.get_backend_symbol('MKLDNN_POST_QUANTIZE')
+ qsym = qsym.get_backend_symbol('MKLDNN_POST_FC_QUANTIZE')
save_symbol(sym_name, qsym, logger)
param_name = '%s-%04d.params' % (prefix + '-quantized', epoch)
save_params(param_name, qarg_params, aux_params, logger)
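A minimal sketch of the resulting quantization flow (assuming MXNet's contrib quantize_model API and an illustrative checkpoint name; not code from this commit): the new MKLDNN_FC pass fuses FullyConnected patterns before quantization, and MKLDNN_POST_FC_QUANTIZE folds the trailing requantize/dequantize nodes into the fused FC afterwards.

import mxnet as mx
from mxnet.contrib.quantization import quantize_model

# Load a pretrained float32 model (checkpoint name is illustrative).
sym, arg_params, aux_params = mx.model.load_checkpoint('resnet50_v1', 0)
sym = sym.get_backend_symbol('MKLDNN')       # existing conv-oriented fusion
sym = sym.get_backend_symbol('MKLDNN_FC')    # new: fuse FullyConnected patterns

qsym, qarg_params, aux_params = quantize_model(
    sym=sym, arg_params=arg_params, aux_params=aux_params,
    excluded_sym_names=[], calib_mode='none', quantized_dtype='uint8')

qsym = qsym.get_backend_symbol('MKLDNN_POST_QUANTIZE')     # fold conv requantize
qsym = qsym.get_backend_symbol('MKLDNN_POST_FC_QUANTIZE')  # new: fold FC requantize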
diff --git a/python/mxnet/initializer.py b/python/mxnet/initializer.py
index 611592a..aca7c58 100755
--- a/python/mxnet/initializer.py
+++ b/python/mxnet/initializer.py
@@ -159,6 +159,12 @@ class Initializer(object):
elif desc.endswith('max'):
self._init_one(desc, arr)
self._verbose_print(desc, 'max', arr)
+ elif desc.endswith('weight_quantize'):
+ self._init_quantized_weight(desc, arr)
+ self._verbose_print(desc, 'weight_quantize', arr)
+ elif desc.endswith('bias_quantize'):
+ self._init_quantized_bias(desc, arr)
+ self._verbose_print(desc, 'bias_quantize', arr)
else:
self._init_default(desc, arr)
@@ -235,6 +241,9 @@ class Initializer(object):
def _init_bias(self, _, arr):
arr[:] = 0.0
+ def _init_quantized_bias(self, _, arr):
+ arr[:] = 0
+
def _init_gamma(self, _, arr):
arr[:] = 1.0
@@ -245,6 +254,10 @@ class Initializer(object):
"""Abstract method to Initialize weight."""
raise NotImplementedError("Must override it")
+ def _init_quantized_weight(self, _, arr):
+ _arr = random.randint(-127, 127, dtype='int32').asnumpy()
+ arr[:] = np.int8(_arr)
+
def _init_default(self, name, _):
raise ValueError(
'Unknown initialization pattern for %s. ' \
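The two new branches above route parameters by name suffix; a hedged numpy-only sketch of their effect (function and parameter names below are illustrative):

import numpy as np

def init_quantized_param(name, arr):
    # Mirrors Initializer._init_quantized_weight / _init_quantized_bias:
    # random int8 weights, zero bias.
    if name.endswith('weight_quantize'):
        arr[:] = np.random.randint(-127, 127, size=arr.shape).astype(np.int8)
    elif name.endswith('bias_quantize'):
        arr[:] = 0

w = np.empty((4, 8), dtype=np.int8)
init_quantized_param('fc0_weight_quantize', w)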
diff --git a/src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h b/src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h
new file mode 100644
index 0000000..c083714
--- /dev/null
+++ b/src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file mkldnn_fully_connected-inl.h
+ * \brief Common functions used by MKLDNN (Quantized) FullyConnected operator
+ * \author Ciyong Chen
+*/
+
+#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_FULLY_CONNECTED_INL_H_
+#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_FULLY_CONNECTED_INL_H_
+
+#if MXNET_USE_MKLDNN == 1
+
+#include <vector>
+#include <string>
+#include "../fully_connected-inl.h"
+#include "./mkldnn_base-inl.h"
+
+namespace mxnet {
+namespace op {
+
+struct MKLDNNFCParam: public dmlc::Parameter<MKLDNNFCParam> {
+ bool quantized;
+ bool enable_float_output;
+ bool with_relu;
+ dmlc::optional<float> min_calib_range; // min float value calculated from calibration dataset
+ dmlc::optional<float> max_calib_range; // max float value calculated from calibration dataset
+
+ DMLC_DECLARE_PARAMETER(MKLDNNFCParam) {
+ DMLC_DECLARE_FIELD(quantized).set_default(false)
+ .describe("Whether it's a quantized FullyConnected operator");
+ DMLC_DECLARE_FIELD(enable_float_output).set_default(false)
+ .describe("Whether to enable float32 output");
+ DMLC_DECLARE_FIELD(with_relu).set_default(false)
+ .describe("Whether there's a post relu after FullyConnected operator");
+ DMLC_DECLARE_FIELD(min_calib_range)
+ .set_default(dmlc::optional<float>())
+ .describe("The minimum scalar value in the form of float32 obtained "
+ "through calibration. If present, it will be used to by "
+ "quantized fullyconnected op to calculate primitive scale");
+ DMLC_DECLARE_FIELD(max_calib_range)
+ .set_default(dmlc::optional<float>())
+ .describe("The maximum scalar value in the form of float32 obtained "
+ "through calibration. If present, it will be used to by "
+ "quantized fullyconnected op to calculate primitive scale");
+ }
+};
+
+struct MKLDNNFCFullParam {
+ FullyConnectedParam default_param;
+ MKLDNNFCParam mkldnn_param;
+ std::vector<float> output_scales = {0.0};
+ std::vector<float> requantize_scales = {0.0};
+};
+
+mkldnn::inner_product_forward::primitive_desc GetFCFwdImpl(
+ const MKLDNNFCFullParam &full_param, const bool is_train,
+ const NDArray &data, const NDArray &weight, const NDArray *bias,
+ const mkldnn::memory::desc &out_md);
+
+class MKLDNNFullyConnectedForward {
+ public:
+ mkldnn::inner_product_forward::primitive_desc fwd_pd;
+
+ MKLDNNFullyConnectedForward(const MKLDNNFCFullParam &full_param, const bool is_train,
+ const NDArray &data, const NDArray &weight,
+ const NDArray *bias,
+ const mkldnn::memory::desc &out_md)
+ : fwd_pd(GetFCFwdImpl(full_param, is_train, data, weight, bias, out_md)) {}
+
+
+ void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &weight,
+ const mkldnn::memory *bias, const mkldnn::memory &output);
+
+ const mkldnn::inner_product_forward &GetFwd() const {
+ return *fwd_;
+ }
+
+ private:
+ std::shared_ptr<mkldnn::inner_product_forward> fwd_;
+ std::shared_ptr<mkldnn::memory> data_;
+ std::shared_ptr<mkldnn::memory> weight_;
+ std::shared_ptr<mkldnn::memory> bias_;
+ std::shared_ptr<mkldnn::memory> out_;
+};
+
+typedef ParamOpSign<FullyConnectedParam> MKLDNNFullyconSignature;
+
+MKLDNNFullyConnectedForward &GetFCFwd(
+ const FullyConnectedParam ¶m, const bool is_train,
+ const NDArray &data, const NDArray &weight,
+ const NDArray *bias, const mkldnn::memory::desc &out_md);
+
+void MKLDNNFCFlattenData(const FullyConnectedParam ¶m,
+ const NDArray &out_data,
+ NDArray *in_data,
+ mkldnn::memory::desc *out_md);
+
+void MKLDNNFCForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
+ const std::vector<NDArray> &in_data,
+ const std::vector<OpReqType> &req,
+ const std::vector<NDArray> &out_data);
+
+void MKLDNNFCForwardFullFeature(const MKLDNNFCFullParam ¶m,
+ const OpContext &ctx,
+ MKLDNNFullyConnectedForward *fwd,
+ const std::vector<NDArray> &in_data,
+ const std::vector<OpReqType> &req,
+ const std::vector<NDArray> &out_data);
+
+} // namespace op
+} // namespace mxnet
+
+#endif // MXNET_USE_MKLDNN == 1
+#endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_FULLY_CONNECTED_INL_H_
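How these flags combine is easiest to see as a mapping of output dtypes; a hedged sketch (an inference from the subgraph type logic later in this patch, not code from it):

def fc_output_dtype(quantized, enable_float_output, with_relu, has_calib_range):
    # Sketch of the dtype a (quantized) FC node produces under these flags.
    if not quantized or enable_float_output:
        return 'float32'                          # dequantize fused into FC
    if has_calib_range:
        return 'uint8' if with_relu else 'int8'   # requantize fused into FC
    return 'int32'                                # raw accumulator output

assert fc_output_dtype(True, False, True, True) == 'uint8'
assert fc_output_dtype(True, False, False, False) == 'int32'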
diff --git a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
index 05ef7eb..03d7e62 100644
--- a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
+++ b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
@@ -18,220 +18,296 @@
*/
/*!
+ * Copyright (c) 2018 by Contributors
* \file mkldnn_fully_connected.cc
- * \brief
- * \author Da Zheng
+ * \brief MKLDNN FullyConnected operator
+ * \author Da Zheng, Ciyong Chen
*/
-#include "../fully_connected-inl.h"
-#include "./mkldnn_base-inl.h"
-
#if MXNET_USE_MKLDNN == 1
+#include "mkldnn_fully_connected-inl.h"
+
namespace mxnet {
namespace op {
-inline static mkldnn::inner_product_forward::primitive_desc GetIPFwd(
+DMLC_REGISTER_PARAMETER(MKLDNNFCParam);
+
+mkldnn::inner_product_forward::primitive_desc GetFCFwdImpl(
+ const MKLDNNFCFullParam &full_param, const bool is_train,
const NDArray &data, const NDArray &weight, const NDArray *bias,
- const mkldnn::memory::desc &out_md, const bool is_train) {
+ const mkldnn::memory::desc &out_md) {
auto data_md = GetMemDesc(data);
auto weight_md = GetMemDesc(weight);
auto engine = CpuEngine::Get()->get_engine();
auto propagation =
is_train ? mkldnn::prop_kind::forward_training : mkldnn::prop_kind::forward_scoring;
+
+ mkldnn::primitive_attr attr;
+ mkldnn::post_ops ops;
+ if (full_param.mkldnn_param.with_relu) {
+ const float scale = 1.0f;
+ const float alpha = 0.0f;
+ const float beta = 1.0f;
+ ops.append_eltwise(scale, eltwise_relu, alpha, beta);
+ }
+ attr.set_post_ops(ops);
+
+ if (full_param.mkldnn_param.quantized) {
+ if ((full_param.mkldnn_param.min_calib_range.has_value() &&
+ full_param.mkldnn_param.max_calib_range.has_value()) ||
+ full_param.mkldnn_param.enable_float_output) {
+ int mask = 0;
+ std::vector<float> scales = {0.0};
+ if (full_param.requantize_scales.size()) {
+ scales[0] = full_param.requantize_scales[0];
+ } else if (full_param.output_scales.size()) {
+ scales[0] = full_param.output_scales[0];
+ } else {
+ LOG(FATAL) << "Must specified either output_scales or requantize_scales!";
+ }
+
+ attr.set_output_scales(mask, scales);
+ attr.set_int_output_round_mode(round_nearest);
+ }
+ }
+
+ auto GetFCFwdPd = [&full_param, &attr,
+ &engine](const mkldnn::inner_product_forward::desc &desc) {
+ try {
+ return mkldnn::inner_product_forward::primitive_desc(desc, attr, engine);
+ } catch (mkldnn::error &e) {
+ if (e.status == mkldnn_unimplemented &&
+ full_param.mkldnn_param.quantized) {
+ LOG(ERROR) << "AVX512-BW support or MKLDNN v0.18 is required for INT8 fully_connected.";
+ } else {
+ LOG(ERROR) << e.message;
+ }
+ throw;
+ }
+ };
+
if (bias) {
auto bias_md = GetMemDesc(*bias);
- mkldnn::inner_product_forward::desc ipFwd_desc(propagation,
+ mkldnn::inner_product_forward::desc desc(propagation,
data_md, weight_md, bias_md, out_md);
- return mkldnn::inner_product_forward::primitive_desc(ipFwd_desc, engine);
+ return GetFCFwdPd(desc);
} else {
- mkldnn::inner_product_forward::desc ipFwd_desc(propagation,
+ mkldnn::inner_product_forward::desc desc(propagation,
data_md, weight_md, out_md);
- return mkldnn::inner_product_forward::primitive_desc(ipFwd_desc, engine);
+ return GetFCFwdPd(desc);
}
}
-inline static mkldnn::inner_product_backward_data::primitive_desc GetIpBwdData(
+inline static mkldnn::inner_product_backward_data::primitive_desc GetFCBwdData(
const NDArray &data, const NDArray &weight, const NDArray &output,
- mkldnn::inner_product_forward::primitive_desc ipFwd_pd) {
+ mkldnn::inner_product_forward::primitive_desc fwd_pd) {
auto data_md = GetMemDesc(data);
auto weight_md = GetMemDesc(weight);
auto out_md = GetMemDesc(output);
auto engine = CpuEngine::Get()->get_engine();
mkldnn::inner_product_backward_data::desc desc(data_md, weight_md, out_md);
- return mkldnn::inner_product_backward_data::primitive_desc(desc, engine, ipFwd_pd);
+ return mkldnn::inner_product_backward_data::primitive_desc(desc, engine, fwd_pd);
}
-inline static mkldnn::inner_product_backward_weights::primitive_desc GetIPBwdWeights(
+inline static mkldnn::inner_product_backward_weights::primitive_desc GetFCBwdWeights(
const NDArray &data, const NDArray &weight, const NDArray *bias,
- const NDArray &output, mkldnn::inner_product_forward::primitive_desc ipFwd_pd) {
+ const NDArray &output, mkldnn::inner_product_forward::primitive_desc fwd_pd) {
auto data_md = GetMemDesc(data);
auto weight_md = GetMemDesc(weight);
auto out_md = GetMemDesc(output);
auto engine = CpuEngine::Get()->get_engine();
if (bias) {
auto bias_md = GetMemDesc(*bias);
- mkldnn::inner_product_backward_weights::desc ipBwdWeights_desc(data_md,
+ mkldnn::inner_product_backward_weights::desc desc(data_md,
weight_md, bias_md, out_md);
return mkldnn::inner_product_backward_weights::primitive_desc(
- ipBwdWeights_desc, engine, ipFwd_pd);
+ desc, engine, fwd_pd);
} else {
- mkldnn::inner_product_backward_weights::desc ipBwdWeights_desc(data_md,
+ mkldnn::inner_product_backward_weights::desc desc(data_md,
weight_md, out_md);
return mkldnn::inner_product_backward_weights::primitive_desc(
- ipBwdWeights_desc, engine, ipFwd_pd);
+ desc, engine, fwd_pd);
}
}
-class MKLDNNFullyConnectForward {
- std::shared_ptr<mkldnn::memory> data;
- std::shared_ptr<mkldnn::memory> weight;
- std::shared_ptr<mkldnn::memory> out;
- std::shared_ptr<mkldnn::memory> bias;
- std::shared_ptr<mkldnn::inner_product_forward> ipFwd;
-
- public:
- mkldnn::inner_product_forward::primitive_desc ipFwd_pd;
-
- MKLDNNFullyConnectForward(const FullyConnectedParam ¶m, bool is_train,
- const NDArray &data, const NDArray &weight,
- const NDArray *bias,
- const mkldnn::memory::desc &output)
- : ipFwd_pd(GetIPFwd(data, weight, bias, output, is_train)) {}
-
- void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &weight,
- const mkldnn::memory *bias, const mkldnn::memory &output) {
- if (this->data == nullptr)
- this->data = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
- ipFwd_pd.src_primitive_desc(), data.get_data_handle()));
- else
- this->data->set_data_handle(data.get_data_handle());
+void MKLDNNFullyConnectedForward::SetNewMem(const mkldnn::memory &data,
+ const mkldnn::memory &weight,
+ const mkldnn::memory *bias,
+ const mkldnn::memory &output) {
+ if (this->data_ == nullptr)
+ this->data_ = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+ fwd_pd.src_primitive_desc(), data.get_data_handle()));
+ else
+ this->data_->set_data_handle(data.get_data_handle());
- if (this->weight == nullptr)
- this->weight = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
- ipFwd_pd.weights_primitive_desc(), weight.get_data_handle()));
- else
- this->weight->set_data_handle(weight.get_data_handle());
+ if (this->weight_ == nullptr)
+ this->weight_ = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+ fwd_pd.weights_primitive_desc(), weight.get_data_handle()));
+ else
+ this->weight_->set_data_handle(weight.get_data_handle());
+
+ if (this->out_ == nullptr)
+ this->out_ = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+ fwd_pd.dst_primitive_desc(), output.get_data_handle()));
+ else
+ this->out_->set_data_handle(output.get_data_handle());
- if (this->out == nullptr)
- this->out = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
- ipFwd_pd.dst_primitive_desc(), output.get_data_handle()));
+ if (bias != nullptr) {
+ if (this->bias_ == nullptr)
+ this->bias_ = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+ fwd_pd.bias_primitive_desc(), bias->get_data_handle()));
else
- this->out->set_data_handle(output.get_data_handle());
-
- if (bias != nullptr) {
- if (this->bias == nullptr)
- this->bias = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
- ipFwd_pd.bias_primitive_desc(), bias->get_data_handle()));
- else
- this->bias->set_data_handle(bias->get_data_handle());
- if (this->ipFwd == nullptr)
- this->ipFwd = std::shared_ptr<mkldnn::inner_product_forward>(
- new mkldnn::inner_product_forward(
- ipFwd_pd, mkldnn::primitive::at(*this->data),
- mkldnn::primitive::at(*this->weight),
- mkldnn::primitive::at(*this->bias), *this->out));
- } else if (this->ipFwd == nullptr) {
- this->ipFwd = std::shared_ptr<mkldnn::inner_product_forward>(
+ this->bias_->set_data_handle(bias->get_data_handle());
+
+ if (this->fwd_ == nullptr)
+ this->fwd_ = std::shared_ptr<mkldnn::inner_product_forward>(
new mkldnn::inner_product_forward(
- ipFwd_pd, mkldnn::primitive::at(*this->data),
- mkldnn::primitive::at(*this->weight), *this->out));
+ fwd_pd, mkldnn::primitive::at(*this->data_),
+ mkldnn::primitive::at(*this->weight_),
+ mkldnn::primitive::at(*this->bias_), *this->out_));
+ } else {
+ if (this->fwd_ == nullptr) {
+ this->fwd_ = std::shared_ptr<mkldnn::inner_product_forward>(
+ new mkldnn::inner_product_forward(
+ fwd_pd, mkldnn::primitive::at(*this->data_),
+ mkldnn::primitive::at(*this->weight_), *this->out_));
}
}
- const mkldnn::inner_product_forward &GetIpFwd() const {
- return *ipFwd;
- }
-};
-
-typedef ParamOpSign<FullyConnectedParam> MKLDNNFullyconSignature;
+}
-static inline MKLDNNFullyConnectForward &GetFCFwd(
- const nnvm::NodeAttrs &attrs, const NDArray &data, const NDArray &weight,
- const NDArray *bias, const mkldnn::memory::desc &output,
- const bool is_train) {
+MKLDNNFullyConnectedForward &GetFCFwd(
+ const FullyConnectedParam ¶m, const bool is_train,
+ const NDArray &data, const NDArray &weight,
+ const NDArray *bias, const mkldnn::memory::desc &out_md) {
#if DMLC_CXX11_THREAD_LOCAL
static thread_local std::unordered_map<MKLDNNFullyconSignature,
- MKLDNNFullyConnectForward, OpHash> fcFwds;
+ MKLDNNFullyConnectedForward, OpHash> fcFwds;
#else
static MX_THREAD_LOCAL std::unordered_map<MKLDNNFullyconSignature,
- MKLDNNFullyConnectForward, OpHash> fcFwds;
+ MKLDNNFullyConnectedForward, OpHash> fcFwds;
#endif
- const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
MKLDNNFullyconSignature key(param);
+ key.AddSign(is_train);
key.AddSign(data);
key.AddSign(weight);
- key.AddSign(is_train);
-
if (bias)
key.AddSign(*bias);
auto it = fcFwds.find(key);
if (it == fcFwds.end()) {
- MKLDNNFullyConnectForward fcFwd(param, is_train, data, weight, bias,
- output);
- auto ins_ret = fcFwds.insert(
- std::pair<MKLDNNFullyconSignature, MKLDNNFullyConnectForward>(key, fcFwd));
- CHECK(ins_ret.second);
- it = ins_ret.first;
+ MKLDNNFCFullParam full_param;
+ full_param.default_param = param;
+ full_param.mkldnn_param.Init(std::unordered_map<std::string, std::string>());
+ MKLDNNFullyConnectedForward fcFwd(full_param, is_train, data, weight, bias, out_md);
+ it = AddToCache(&fcFwds, key, fcFwd);
}
return it->second;
}
-void MKLDNNFCForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
- const std::vector<NDArray> &in_data,
- const std::vector<OpReqType> &req,
- const std::vector<NDArray> &out_data) {
+void MKLDNNFCFlattenData(const FullyConnectedParam ¶m,
+ const NDArray &out_data,
+ NDArray *in_data,
+ mkldnn::memory::desc *out_md) {
+ const mxnet::TShape ishape = in_data->shape();
+ const mxnet::TShape oshape = out_data.shape();
+
+ // If the input data is a view of an MKLDNN array, we should create a new
+ // NDArray with reordered data.
+ if (in_data->IsMKLDNNData() && in_data->IsView())
+ *in_data = in_data->Reorder2Default();
+
+ if (ishape.ndim() != 2) {
+ if (!param.flatten) {
+ *in_data = in_data->MKLDNNDataReshape(Shape2(ishape.ProdShape(0, ishape.ndim()-1),
+ ishape[ishape.ndim()-1]));
+ mkldnn::memory::dims out_dims{static_cast<int>(oshape.ProdShape(0, oshape.ndim()-1)),
+ static_cast<int>(oshape[ishape.ndim()-1])};
+ *out_md = mkldnn::memory::desc(out_dims, get_mkldnn_type(out_data.dtype()),
+ mkldnn::memory::format::any);
+ } else {
+ *in_data = in_data->MKLDNNDataReshape(Shape2(ishape[0], ishape.ProdShape(1, ishape.ndim())));
+ mkldnn::memory::dims out_dims{static_cast<int>(oshape[0]),
+ static_cast<int>(oshape.ProdShape(1, oshape.ndim()))};
+ *out_md = mkldnn::memory::desc(out_dims, get_mkldnn_type(out_data.dtype()),
+ mkldnn::memory::format::any);
+ }
+ }
+}
+
+void MKLDNNFCForwardFullFeature(const MKLDNNFCFullParam &full_param,
+ const OpContext &ctx,
+ MKLDNNFullyConnectedForward *fwd,
+ const std::vector<NDArray> &in_data,
+ const std::vector<OpReqType> &req,
+ const std::vector<NDArray> &out_data) {
TmpMemMgr::Get()->Init(ctx.requested[fullc::kTempSpace]);
- const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
- const mxnet::TShape& ishape = in_data[fullc::kData].shape();
- const mxnet::TShape& oshape = out_data[fullc::kOut].shape();
NDArray weight = in_data[fullc::kWeight];
NDArray data = in_data[fullc::kData];
- // If the input data is a view of an MKLDNN array, we should create a new
- // NDArray with reordered data.
- if (data.IsMKLDNNData() && data.IsView())
- data = in_data[fullc::kData].Reorder2Default();
- auto out_md = GetMemDesc(out_data[fullc::kOut]);
- if (data.shape().ndim() != 2 && !param.flatten) {
- data = data.MKLDNNDataReshape(Shape2(ishape.ProdShape(0, ishape.ndim()-1),
- ishape[ishape.ndim()-1]));
- mkldnn::memory::dims out_dims{static_cast<int>(oshape.ProdShape(0, oshape.ndim()-1)),
- static_cast<int>(oshape[ishape.ndim()-1])};
- out_md = mkldnn::memory::desc(out_dims, get_mkldnn_type(out_data[fullc::kOut].dtype()),
- mkldnn::memory::format::any);
- } else if (data.shape().ndim() != 2) {
- data = data.MKLDNNDataReshape(Shape2(ishape[0], ishape.ProdShape(1, ishape.ndim())));
- mkldnn::memory::dims out_dims{static_cast<int>(oshape[0]),
- static_cast<int>(oshape.ProdShape(1, oshape.ndim()))};
- out_md = mkldnn::memory::desc(out_dims, get_mkldnn_type(out_data[fullc::kOut].dtype()),
- mkldnn::memory::format::any);
+ auto data_mem = data.GetMKLDNNDataReorder(fwd->fwd_pd.src_primitive_desc());
+ const mkldnn::memory *weight_mem;
+ if (ctx.is_train) {
+ if (weight.IsMKLDNNData()) {
+ weight.Reorder2DefaultAsync();
+ }
+ weight_mem = GetWeights(weight, fwd->fwd_pd.weights_primitive_desc(), 1);
+ } else {
+ if (weight.IsDefaultData()) {
+ weight_mem = GetWeights(weight, fwd->fwd_pd.weights_primitive_desc(), 1);
+ weight.MKLDNNDataReorderAsync(fwd->fwd_pd.weights_primitive_desc());
+ } else {
+ weight_mem = weight.GetMKLDNNData();
+ CHECK(weight_mem->get_primitive_desc() == fwd->fwd_pd.weights_primitive_desc());
+ }
}
- MKLDNNFullyConnectForward &FCFwd =
- GetFCFwd(attrs, data, weight, param.no_bias ? nullptr : &in_data[fullc::kBias],
- out_md, ctx.is_train);
- auto data_mem = data.GetMKLDNNDataReorder(FCFwd.ipFwd_pd.src_primitive_desc());
- auto weight_mem = weight.GetMKLDNNDataReorder(FCFwd.ipFwd_pd.weights_primitive_desc());
auto out_mem = CreateMKLDNNMem(out_data[fullc::kOut],
- FCFwd.ipFwd_pd.dst_primitive_desc(), req[fullc::kOut], &data);
- if (!param.no_bias) {
+ fwd->fwd_pd.dst_primitive_desc(), req[fullc::kOut], &data);
+ if (!full_param.default_param.no_bias) {
auto bias_mem = in_data[fullc::kBias].GetMKLDNNDataReorder(
- FCFwd.ipFwd_pd.bias_primitive_desc());
- FCFwd.SetNewMem(*data_mem, *weight_mem, bias_mem, *out_mem.second);
+ fwd->fwd_pd.bias_primitive_desc());
+ fwd->SetNewMem(*data_mem, *weight_mem, bias_mem, *out_mem.second);
} else {
- FCFwd.SetNewMem(*data_mem, *weight_mem, nullptr, *out_mem.second);
+ fwd->SetNewMem(*data_mem, *weight_mem, nullptr, *out_mem.second);
}
- MKLDNNStream::Get()->RegisterPrim(FCFwd.GetIpFwd());
+ MKLDNNStream::Get()->RegisterPrim(fwd->GetFwd());
CommitOutput(out_data[fullc::kOut], out_mem);
MKLDNNStream::Get()->Submit();
}
+void MKLDNNFCForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
+ const std::vector<NDArray> &in_data,
+ const std::vector<OpReqType> &req,
+ const std::vector<NDArray> &out_data) {
+ MKLDNNFCFullParam full_param;
+ full_param.default_param = nnvm::get<FullyConnectedParam>(attrs.parsed);
+ full_param.mkldnn_param.Init(std::unordered_map<std::string, std::string>());
+
+ NDArray data = in_data[fullc::kData];
+ mkldnn::memory::desc out_md = GetMemDesc(out_data[fullc::kOut]);
+ MKLDNNFCFlattenData(full_param.default_param, out_data[fullc::kOut],
+ &data, &out_md);
+ auto &fwd = GetFCFwd(full_param.default_param, ctx.is_train, data,
+ in_data[fullc::kWeight],
+ full_param.default_param.no_bias ? nullptr : &in_data[fullc::kBias],
+ out_md);
+ std::vector<NDArray> new_inputs;
+ if (full_param.default_param.no_bias)
+ new_inputs = {data, in_data[fullc::kWeight]};
+ else
+ new_inputs = {data, in_data[fullc::kWeight], in_data[fullc::kBias]};
+ MKLDNNFCForwardFullFeature(full_param, ctx, &fwd, new_inputs, req, out_data);
+}
+
void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
TmpMemMgr::Get()->Init(ctx.requested[fullc::kTempSpace]);
const std::vector<NDArray> &in_grad = outputs;
- const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
+ MKLDNNFCFullParam full_param;
+ full_param.default_param = nnvm::get<FullyConnectedParam>(attrs.parsed);
+ full_param.mkldnn_param.Init(std::unordered_map<std::string, std::string>());
+ const FullyConnectedParam& param = full_param.default_param;
const mxnet::TShape& ishape = inputs[fullc::kData + 1].shape();
const mxnet::TShape& oshape = inputs[fullc::kOut].shape();
@@ -251,13 +327,14 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
out_grad = out_grad.MKLDNNDataReshape(Shape2(oshape[0],
oshape.ProdShape(1, oshape.ndim())));
- mkldnn::inner_product_forward::primitive_desc ipFwd_pd = GetIPFwd(data, weight,
- param.no_bias ? nullptr : &in_grad[fullc::kBias], GetMemDesc(out_grad), ctx.is_train);
+
+ mkldnn::inner_product_forward::primitive_desc fwd_pd = GetFCFwdImpl(full_param, ctx.is_train,
+ data, weight, param.no_bias ? nullptr : &in_grad[fullc::kBias], GetMemDesc(out_grad));
CHECK_NE(req[fullc::kWeight], kWriteInplace) << "cannot write weight inplace";
if (req[fullc::kData]) {
- mkldnn::inner_product_backward_data::primitive_desc ipBwdData_pd = GetIpBwdData(
- data, weight, out_grad, ipFwd_pd);
+ mkldnn::inner_product_backward_data::primitive_desc ipBwdData_pd = GetFCBwdData(
+ data, weight, out_grad, fwd_pd);
auto out_grad_mem = out_grad.GetMKLDNNDataReorder(
ipBwdData_pd.diff_dst_primitive_desc());
auto weight_mem = weight.GetMKLDNNDataReorder(ipBwdData_pd.weights_primitive_desc());
@@ -270,8 +347,8 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
}
if (req[fullc::kWeight]) {
mkldnn::inner_product_backward_weights::primitive_desc ipBwdWeights_pd
- = GetIPBwdWeights(data, weight, param.no_bias ? nullptr : &in_grad[fullc::kBias],
- out_grad, ipFwd_pd);
+ = GetFCBwdWeights(data, weight, param.no_bias ? nullptr : &in_grad[fullc::kBias],
+ out_grad, fwd_pd);
auto out_grad_mem = out_grad.GetMKLDNNDataReorder(
ipBwdWeights_pd.diff_dst_primitive_desc());
auto data_mem = data.GetMKLDNNDataReorder(ipBwdWeights_pd.src_primitive_desc());
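The scales fed into attr.set_output_scales() in GetFCFwdImpl are computed by the callers; a hedged sketch of that arithmetic (following the subgraph FC code later in this patch; kInt8Range and kUint8Range are MXNet's 127/255 constants):

kInt8Range, kUint8Range = 127.0, 255.0

def fc_primitive_scale(min_data, max_data, min_weight, max_weight,
                       uint8_data=True, float_output=False,
                       with_relu=False, min_calib=None, max_calib=None):
    data_range = kUint8Range if uint8_data else kInt8Range
    data_scale = data_range / max(abs(min_data), abs(max_data))
    weight_scale = kInt8Range / max(abs(min_weight), abs(max_weight))
    if float_output:
        return 1.0 / data_scale / weight_scale    # dequantize scale (output_scales)
    out_range = kUint8Range if with_relu else kInt8Range
    return (out_range / max(abs(min_calib), abs(max_calib))
            / data_scale / weight_scale)          # requantize scale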
diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
new file mode 100644
index 0000000..36def00
--- /dev/null
+++ b/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file mkldnn_quantized_fully_connected.cc
+ * \brief MKLDNN Quantized FullyConnected operator
+ * \author Ciyong Chen
+ */
+
+#if MXNET_USE_MKLDNN == 1
+#include "../../nn/mkldnn/mkldnn_fully_connected-inl.h"
+#include "../quantization_utils.h"
+
+namespace mxnet {
+namespace op {
+
+namespace quantized_fc_enum {
+enum QuantizedFCInputMinMax { kDataMin, kDataMax, kWeightMin, kWeightMax, kBiasMin, kBiasMax };
+enum QuantizedFCOutputs { kOut, kOutMin, kOutMax };
+}
+
+void MKLDNNQuantizedFullyConnectedForward(const nnvm::NodeAttrs &attrs,
+ const OpContext &ctx,
+ const std::vector<NDArray> &in_data,
+ const std::vector<OpReqType> &req,
+ const std::vector<NDArray> &out_data) {
+ TmpMemMgr::Get()->Init(ctx.requested[fullc::kTempSpace]);
+ FullyConnectedParam param = nnvm::get<FullyConnectedParam>(attrs.parsed);
+ const size_t num_inputs = param.no_bias ? 2 : 3;
+
+ CHECK_EQ(in_data.size(), static_cast<size_t>(num_inputs * 3));
+ CHECK_EQ(out_data.size(), 3U);
+
+ NDArray data = in_data[fullc::kData];
+ NDArray weight = in_data[fullc::kWeight];
+ const TShape &ishape = data.shape();
+
+ CHECK(data.dtype() == mshadow::kUint8)
+ << "MKLDNNQuantizedFullyConnected Op only supports uint8 for now, but got "
+ << mxnet::op::type_string(data.dtype());
+
+ if (ishape.ndim() != 2) {
+ CHECK(param.flatten)
+ << "QuantizedFullyConnected Op only supports flatten=true when ishape.ndim()!=2 for now.";
+ data = data.MKLDNNDataReshape(Shape2(ishape[0], ishape.ProdShape(1, ishape.ndim())));
+ }
+
+ const float min_data =
+ in_data[num_inputs + quantized_fc_enum::kDataMin].data().dptr<float>()[0];
+ const float max_data =
+ in_data[num_inputs + quantized_fc_enum::kDataMax].data().dptr<float>()[0];
+ const float min_weight =
+ in_data[num_inputs + quantized_fc_enum::kWeightMin].data().dptr<float>()[0];
+ const float max_weight =
+ in_data[num_inputs + quantized_fc_enum::kWeightMax].data().dptr<float>()[0];
+ float *min_output_ptr = out_data[quantized_fc_enum::kOutMin].data().dptr<float>();
+ float *max_output_ptr = out_data[quantized_fc_enum::kOutMax].data().dptr<float>();
+
+ auto data_range = (data.dtype() == mshadow::kInt8) ? kInt8Range : kUint8Range;
+ float data_scale = data_range / MaxAbs(min_data, max_data);
+ float weight_scale = kInt8Range / MaxAbs(min_weight, max_weight);
+
+ NDArray quantized_bias;
+ if (!param.no_bias) {
+ NDArray bias = in_data[fullc::kBias];
+ float min_bias = in_data[num_inputs + quantized_fc_enum::kBiasMin].data().dptr<float>()[0];
+ float max_bias = in_data[num_inputs + quantized_fc_enum::kBiasMax].data().dptr<float>()[0];
+ float bias_int32_rescale = data_scale * weight_scale * MaxAbs(min_bias, max_bias) / kInt8Range;
+
+ quantized_bias = NDArray(bias.storage_type(), bias.shape(),
+ bias.ctx(), true, mshadow::kInt32);
+ int8_t *bias_ptr = bias.data().dptr<int8_t>();
+ int32_t *quantized_bias_ptr = quantized_bias.data().dptr<int32_t>();
+ size_t bias_size = bias.shape().Size();
+ #pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
+ for (size_t i = 0; i < bias_size; ++i) {
+ quantized_bias_ptr[i] = bias_ptr[i] * bias_int32_rescale;
+ }
+ }
+
+ Stream<cpu> *s = ctx.get_stream<cpu>();
+ mxnet_op::Kernel<QuantizationRangeForMultiplicationStruct, cpu>::Launch(s, 1,
+ min_output_ptr, max_output_ptr, &min_data, &max_data, &min_weight, &max_weight);
+
+ bool is_train = false;
+ mkldnn::memory::desc out_md = GetMemDesc(out_data[fullc::kOut]);
+ MKLDNNFCFlattenData(param, out_data[fullc::kOut], &data, &out_md);
+ auto &fwd = GetFCFwd(param, is_train, data, weight,
+ param.no_bias ? nullptr : &quantized_bias, out_md);
+
+ auto data_mem = in_data[fullc::kData].GetMKLDNNDataReorder(fwd.fwd_pd.src_primitive_desc());
+ const mkldnn::memory *weight_mem = nullptr;
+
+ if (weight.IsDefaultData()) {
+ weight_mem = GetWeights(weight, fwd.fwd_pd.weights_primitive_desc(), 1);
+ weight.MKLDNNDataReorderAsync(fwd.fwd_pd.weights_primitive_desc());
+ } else {
+ weight_mem = weight.GetMKLDNNData();
+ CHECK(weight_mem->get_primitive_desc() == fwd.fwd_pd.weights_primitive_desc());
+ }
+ auto out_mem = CreateMKLDNNMem(out_data[fullc::kOut], fwd.fwd_pd.dst_primitive_desc(),
+ req[fullc::kOut]);
+ const mkldnn::memory *bias_mem = nullptr;
+ if (!param.no_bias)
+ bias_mem = quantized_bias.GetMKLDNNDataReorder(fwd.fwd_pd.bias_primitive_desc());
+
+ fwd.SetNewMem(*data_mem, *weight_mem, bias_mem, *out_mem.second);
+ MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd());
+
+ CommitOutput(out_data[fullc::kOut], out_mem);
+ MKLDNNStream::Get()->Submit();
+}
+
+} // namespace op
+} // namespace mxnet
+
+#endif // MXNET_USE_MKLDNN == 1
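A hedged numpy sketch of the bias handling above: the int8 bias is re-expressed on the int32 accumulator scale (data_scale * weight_scale), which is exactly what bias_int32_rescale computes.

import numpy as np

kInt8Range = 127.0

def rescale_bias(bias_int8, min_bias, max_bias, data_scale, weight_scale):
    # float bias = bias_int8 * MaxAbs(min_bias, max_bias) / kInt8Range
    # int32 bias = float bias * data_scale * weight_scale
    rescale = data_scale * weight_scale * max(abs(min_bias), abs(max_bias)) / kInt8Range
    return (bias_int8.astype(np.float64) * rescale).astype(np.int32)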
diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_ops-inl.h b/src/operator/quantization/mkldnn/mkldnn_quantized_ops-inl.h
new file mode 100644
index 0000000..88d77c8
--- /dev/null
+++ b/src/operator/quantization/mkldnn/mkldnn_quantized_ops-inl.h
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file mkldnn_quantized_ops-inl.h
+ * \brief Common functions used by MKLDNN Quantized FullyConnected operator
+ * \author Ciyong Chen
+ */
+
+#ifndef MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_QUANTIZED_OPS_INL_H_
+#define MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_QUANTIZED_OPS_INL_H_
+
+#if MXNET_USE_MKLDNN == 1
+
+#include <mxnet/ndarray.h>
+#include <vector>
+
+namespace mxnet {
+namespace op {
+
+void MKLDNNQuantizedFullyConnectedForward(const nnvm::NodeAttrs &attrs,
+ const OpContext &ctx,
+ const std::vector<NDArray> &in_data,
+ const std::vector<OpReqType> &req,
+ const std::vector<NDArray> &out_data);
+
+} // namespace op
+} // namespace mxnet
+
+#endif // MXNET_USE_MKLDNN == 1
+#endif // MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_QUANTIZED_OPS_INL_H_
diff --git a/src/operator/quantization/quantized_fully_connected.cc b/src/operator/quantization/quantized_fully_connected.cc
index 3b18e65..742825c 100644
--- a/src/operator/quantization/quantized_fully_connected.cc
+++ b/src/operator/quantization/quantized_fully_connected.cc
@@ -26,6 +26,10 @@
#include <vector>
#include "quantization_utils.h"
#include "../nn/fully_connected-inl.h"
+#if MXNET_USE_MKLDNN == 1
+#include "../nn/mkldnn/mkldnn_fully_connected-inl.h"
+#include "mkldnn/mkldnn_quantized_ops-inl.h"
+#endif
namespace mxnet {
namespace op {
@@ -38,7 +42,6 @@ bool QuantizedFullyConnectedShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *in_shape,
mxnet::ShapeVector *out_shape) {
const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
- CHECK(param.flatten) << "QuantizedFullyConnectedOp only supports flatten=true for now";
using namespace mshadow;
uint32_t num_inputs = param.no_bias ? 2 : 3;
CHECK_EQ(in_shape->size(), num_inputs * 3);
@@ -48,6 +51,10 @@ bool QuantizedFullyConnectedShape(const nnvm::NodeAttrs& attrs,
<< "QuantizedFullyConnectedOp input data shape must be given";
const mxnet::TShape& dshape = in_shape->at(0);
mxnet::TShape wshape = Shape2(param.num_hidden, dshape.ProdShape(1, dshape.ndim()));
+ if (dshape.ndim() != 2) {
+ CHECK(param.flatten)
+ << "QuantizedFullyConnectedOp only supports flatten=true when ishape.ndim()!=2 for now. ";
+ }
SHAPE_ASSIGN_CHECK(*in_shape, 1, wshape);
if (!param.no_bias) {
mxnet::TShape bshape = Shape1(param.num_hidden);
@@ -72,7 +79,14 @@ bool QuantizedFullyConnectedType(const nnvm::NodeAttrs& attrs,
CHECK_EQ(in_type->size(), num_inputs * 3);
CHECK_EQ(out_type->size(), 3U);
- for (size_t i = 0; i < num_inputs; ++i) {
+#if MXNET_USE_MKLDNN == 1
+ // TODO(ciyong): currently, only uint8 fully_connected is supported,
+ // int8 fully_connected will be supported after mkldnn v0.18
+ TYPE_ASSIGN_CHECK(*in_type, 0, mshadow::kUint8);
+#else
+ TYPE_ASSIGN_CHECK(*in_type, 0, mshadow::kInt8);
+#endif
+ for (size_t i = 1; i < num_inputs; ++i) {
TYPE_ASSIGN_CHECK(*in_type, i, mshadow::kInt8);
}
for (size_t i = num_inputs; i < 3 * num_inputs; ++i) {
@@ -90,10 +104,16 @@ bool QuantizedFullyConnectedStorageType(const nnvm::NodeAttrs& attrs,
DispatchMode* dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
+ const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
+ uint32_t num_inputs = param.no_bias ? 2 : 3;
+ CHECK_EQ(in_attrs->size(), num_inputs * 3);
+ CHECK_EQ(out_attrs->size(), 3U);
+
+#if MXNET_USE_MKLDNN == 1
+ return MKLDNNStorageType(attrs, dev_mask, true,
+ dispatch_mode, in_attrs, out_attrs);
+#else
*dispatch_mode = DispatchMode::kFCompute;
- if (dev_mask == mshadow::cpu::kDevMask) {
- *dispatch_mode = DispatchMode::kFComputeEx;
- }
for (auto &v : *out_attrs) {
v = kDefaultStorage;
@@ -109,6 +129,7 @@ bool QuantizedFullyConnectedStorageType(const nnvm::NodeAttrs& attrs,
}
}
return true;
+#endif
}
struct QuantizedSumInitKernelWithBias {
@@ -137,28 +158,41 @@ struct QuantizedSumInitKernelWithBias {
};
-template<typename SrcType>
-void QuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs,
- const OpContext &ctx,
- const std::vector<NDArray> &in_data,
- const std::vector<OpReqType> &req,
- const std::vector<NDArray> &out_data) {
+void QuantizedFullyConnectedForwardCPU(const nnvm::NodeAttrs& attrs,
+ const OpContext &ctx,
+ const std::vector<TBlob> &in_data,
+ const std::vector<OpReqType> &req,
+ const std::vector<TBlob> &out_data) {
#if MSHADOW_USE_MKL == 1
const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
using namespace mshadow;
using namespace mxnet_op;
+ Stream<cpu> *s = ctx.get_stream<cpu>();
size_t num_inputs = param.no_bias ? 2 : 3;
CHECK_EQ(in_data.size(), num_inputs * 3);
CHECK_EQ(out_data.size(), 3U);
- const NDArray& data = in_data[0];
- const NDArray& weight = in_data[1];
- const NDArray& out = out_data[0];
- mxnet::TShape dshape = data.shape();
- mxnet::TShape wshape = weight.shape();
- mxnet::TShape oshape = out.shape();
- auto output_temp = out.data().dptr<int32_t>();
- auto weight_temp = weight.data().dptr<SrcType>();
- auto data_temp = data.data().dptr<SrcType>();
+
+ const mxnet::TShape &dshape = in_data[fullc::kData].shape_;
+ const mxnet::TShape &wshape = in_data[fullc::kWeight].shape_;
+ const mxnet::TShape &oshape = out_data[fullc::kOut].shape_;
+
+ CHECK(in_data[fullc::kData].type_flag_ == mshadow::kInt8)
+ << "QuantizedFullyConnectedForwardCPU Op only supports int8 for now, but got "
+ << mxnet::op::type_string(in_data[fullc::kData].type_flag_);
+
+ if (dshape.ndim() != 2)
+ CHECK(param.flatten)
+ << "QuantizedFullyConnectedOp only supports flatten=true when input_shape!=2 for now. ";
+
+ Tensor<cpu, 2, int8_t> weight = in_data[fullc::kWeight].get<cpu, 2, int8_t>(s);
+ Tensor<cpu, 2, int8_t> data = in_data[fullc::kData].get_with_shape<cpu, 2, int8_t>(
+ Shape2(dshape[0], dshape.ProdShape(1, dshape.ndim())), s);
+ Tensor<cpu, 2, int32_t> out = out_data[fullc::kOut].get_with_shape<cpu, 2, int32_t>(
+ Shape2(oshape[0], oshape.ProdShape(1, oshape.ndim())), s);
+
+ auto data_temp = data.dptr_;
+ auto weight_temp = weight.dptr_;
+ auto output_temp = out.dptr_;
const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
const float alpha = 1.0f;
const float beta = 1.0f;
@@ -167,7 +201,6 @@ void QuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs,
const MKL_INT8 ob = 0;
MKL_INT32 oc = 0;
const int m = dshape[0], n = wshape[0], k = dshape.ProdShape(1, dshape.ndim());
- Stream<cpu> *s = ctx.get_stream<cpu>();
// cblas_gemm_s8u8s32 requires the first matrix to be uint8
// shift data from int8 (from -128 to 127) to uint8 (from 0 to 255)
int shift = 128;
@@ -179,16 +212,23 @@ void QuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs,
shiftdata.dptr_[i] = data_temp[i] + shift;
}
- Kernel<QuantizationRangeForMultiplicationStruct, cpu>::Launch(s, 1,
- out_data[1].data().dptr<float>(), out_data[2].data().dptr<float>(),
- in_data[num_inputs].data().dptr<float>(), in_data[num_inputs+1].data().dptr<float>(),
- in_data[num_inputs+2].data().dptr<float>(), in_data[num_inputs+3].data().dptr<float>());
+ Tensor<cpu, 1, float> min_output = out_data[1].get<cpu, 1, float>(s);
+ Tensor<cpu, 1, float> max_output = out_data[2].get<cpu, 1, float>(s);
+ Tensor<cpu, 1, float> min_data = in_data[num_inputs].get<cpu, 1, float>(s);
+ Tensor<cpu, 1, float> max_data = in_data[num_inputs + 1].get<cpu, 1, float>(s);
+ Tensor<cpu, 1, float> min_weight = in_data[num_inputs + 2].get<cpu, 1, float>(s);
+ Tensor<cpu, 1, float> max_weight = in_data[num_inputs + 3].get<cpu, 1, float>(s);
+
+ Kernel<QuantizationRangeForMultiplicationStruct, cpu>::Launch(s, 1, min_output.dptr_,
+ max_output.dptr_, min_data.dptr_, max_data.dptr_, min_weight.dptr_, max_weight.dptr_);
if (!param.no_bias) {
- const NDArray& bias = in_data[2];
- Kernel<QuantizedSumInitKernelWithBias, cpu>::Launch(s, n, out.data().dptr<int32_t>(),
- bias.data().dptr<int8_t>(), out_data[1].data().dptr<float>(),
- out_data[2].data().dptr<float>(), in_data[7].data().dptr<float>(),
- in_data[8].data().dptr<float>());
+ Tensor<cpu, 1, int8_t> bias = in_data[fullc::kBias].get_with_shape<cpu, 1, int8_t>(
+ Shape1(wshape[0]), s);
+ Tensor<cpu, 1, float> min_bias = in_data[num_inputs + 4].get<cpu, 1, float>(s);
+ Tensor<cpu, 1, float> max_bias = in_data[num_inputs + 5].get<cpu, 1, float>(s);
+
+ Kernel<QuantizedSumInitKernelWithBias, cpu>::Launch(s, n, out.dptr_,
+ bias.dptr_, min_output.dptr_, max_output.dptr_, min_bias.dptr_, max_bias.dptr_);
} else {
#pragma omp parallel for num_threads(omp_threads)
for (int i = 0; i < m * n; ++i) {
@@ -216,11 +256,11 @@ void QuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs,
shiftdata.dptr_,
k,
oa,
- weight.data().dptr<SrcType>(),
+ weight.dptr_,
k,
ob,
beta,
- out.data().dptr<int32_t>(),
+ out.dptr_,
n,
&oc);
#else
@@ -230,6 +270,21 @@ void QuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs,
#endif
}
+#if MXNET_USE_MKLDNN == 1
+void QuantizedFullyConnectedForwardExCPU(const nnvm::NodeAttrs &attrs,
+ const OpContext &ctx,
+ const std::vector<NDArray> &in_data,
+ const std::vector<OpReqType> &req,
+ const std::vector<NDArray> &out_data) {
+ if (in_data[fullc::kData].dtype() == mshadow::kInt8) {
+ FallBackCompute(QuantizedFullyConnectedForwardCPU, attrs, ctx, in_data, req, out_data);
+ return;
+ }
+
+ MKLDNNQuantizedFullyConnectedForward(attrs, ctx, in_data, req, out_data);
+}
+#endif
+
NNVM_REGISTER_OP(_contrib_quantized_fully_connected)
.describe(R"code(Fully Connected operator for input, weight and bias data type of int8,
and accumulates in type int32 for the output. For each argument, two more arguments of type
@@ -268,8 +323,11 @@ and max thresholds representing the thresholds for quantizing the float32 output
// will be reverted after the improvement of CachedOP is done.
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.set_attr<FNeedRequantize>("FNeedRequantize", [](const NodeAttrs& attrs) { return true; })
-.set_attr<FComputeEx>("FComputeEx<cpu>",
- QuantizedFullyConnectedForward<int8_t>)
+.set_attr<FCompute>("FCompute<cpu>", QuantizedFullyConnectedForwardCPU)
+#if MXNET_USE_MKLDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
+.set_attr<FComputeEx>("FComputeEx<cpu>", QuantizedFullyConnectedForwardExCPU)
+#endif
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
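For reference, a hedged numpy sketch of the shift trick in QuantizedFullyConnectedForwardCPU above: cblas_gemm_s8u8s32 needs a uint8 first operand, so the int8 data is shifted by +128 and the extra term is cancelled by pre-loading the accumulator with -128 times the row sums of the weight:

import numpy as np

rng = np.random.default_rng(0)
data = rng.integers(-128, 128, size=(2, 5), dtype=np.int8)    # int8 activations
weight = rng.integers(-128, 128, size=(3, 5), dtype=np.int8)  # int8 weights

shift = 128
shifted = data.astype(np.int32) + shift                 # now in the uint8 range
acc = -shift * weight.astype(np.int32).sum(axis=1)      # per-output compensation
out = shifted @ weight.astype(np.int32).T + acc
# (data + 128) @ W.T - 128 * rowsum(W) == data @ W.T
assert np.array_equal(out, data.astype(np.int32) @ weight.astype(np.int32).T)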
diff --git a/src/operator/subgraph/mkldnn/mkldnn_fc.cc b/src/operator/subgraph/mkldnn/mkldnn_fc.cc
new file mode 100644
index 0000000..94e2bda
--- /dev/null
+++ b/src/operator/subgraph/mkldnn/mkldnn_fc.cc
@@ -0,0 +1,442 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements. See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership. The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied. See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*/
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file mkldnn_fc.cc
+ * \brief MKLDNN (Quantized) FullyConnected operator based on subgraph
+ * \author Ciyong Chen
+*/
+
+#if MXNET_USE_MKLDNN == 1
+
+#include <utility>
+#include <vector>
+#include <string>
+#include "../common.h"
+#include "../../nn/mkldnn/mkldnn_base-inl.h"
+#include "../../nn/mkldnn/mkldnn_ops-inl.h"
+#include "../../nn/mkldnn/mkldnn_fully_connected-inl.h"
+#include "../../quantization/quantization_utils.h"
+
+namespace mxnet {
+namespace op {
+
+class SgMKLDNNFCOp {
+ public:
+ explicit SgMKLDNNFCOp(const nnvm::NodeAttrs &attrs)
+ : initialized_(false),
+ subgraph_sym_(*attrs.subgraphs[0]),
+ full_param_(nnvm::get<MKLDNNFCFullParam>(attrs.parsed)) {}
+
+ void Forward(const OpContext &ctx,
+ const std::vector<NDArray> &inputs,
+ const std::vector<OpReqType> &req,
+ const std::vector<NDArray> &outputs);
+
+ void Backward(const OpContext &ctx,
+ const std::vector<NDArray> &inputs,
+ const std::vector<OpReqType> &req,
+ const std::vector<NDArray> &outputs) {
+ LOG(FATAL) << "Not implemented: subgraph mkldnn fully connected only supports "
+ "inference computation.";
+ }
+
+ private:
+ bool initialized_;
+ nnvm::Symbol subgraph_sym_;
+ MKLDNNFCFullParam full_param_;
+ std::shared_ptr<MKLDNNFullyConnectedForward> fwd_;
+ NDArray cached_weight_;
+ NDArray cached_bias_;
+ float cached_min_data_;
+ float cached_max_data_;
+ float cached_min_weight_;
+ float cached_max_weight_;
+ float cached_min_bias_;
+ float cached_max_bias_;
+};
+
+void SgMKLDNNFCOp::Forward(const OpContext &ctx,
+ const std::vector<NDArray> &in_data,
+ const std::vector<OpReqType> &req,
+ const std::vector<NDArray> &out_data) {
+ auto &mkldnn_param = full_param_.mkldnn_param;
+ auto &default_param = full_param_.default_param;
+ bool has_bias = !default_param.no_bias;
+ size_t base_num_inputs = has_bias ? 3 : 2;
+ size_t total_num_inputs = base_num_inputs;
+ size_t base_num_outputs = 1;
+ size_t total_num_outputs = base_num_outputs;
+
+ float min_data = 0.0;
+ float max_data = 0.0;
+ float min_weight = 0.0;
+ float max_weight = 0.0;
+ float min_bias = 0.0;
+ float max_bias = 0.0;
+ float *min_output_ptr = nullptr;
+ float *max_output_ptr = nullptr;
+
+ if (mkldnn_param.quantized) {
+ total_num_inputs = base_num_inputs * 3;
+ min_data = in_data[base_num_inputs].data().dptr<float>()[0];
+ max_data = in_data[base_num_inputs + 1].data().dptr<float>()[0];
+ min_weight = in_data[base_num_inputs + 2].data().dptr<float>()[0];
+ max_weight = in_data[base_num_inputs + 3].data().dptr<float>()[0];
+ if (has_bias) {
+ min_bias = in_data[base_num_inputs + 4].data().dptr<float>()[0];
+ max_bias = in_data[base_num_inputs + 5].data().dptr<float>()[0];
+ }
+ if (!mkldnn_param.enable_float_output) {
+ total_num_outputs = base_num_outputs * 3;
+ min_output_ptr = out_data[1].data().dptr<float>();
+ max_output_ptr = out_data[2].data().dptr<float>();
+ }
+ }
+ CHECK_EQ(in_data.size(), total_num_inputs);
+ CHECK_EQ(out_data.size(), total_num_outputs);
+
+ NDArray data = in_data[fullc::kData];
+ NDArray weight = in_data[fullc::kWeight];
+ NDArray output = out_data[fullc::kOut];
+ const mxnet::TShape &ishape = data.shape();
+ if (mkldnn_param.quantized && ishape.ndim() != 2) {
+ CHECK(default_param.flatten)
+ << "QuantizedFullyConnected only supports flatten=true when ishape.ndim() != 2 for now.";
+ }
+
+ mkldnn::memory::desc out_md = GetMemDesc(output);
+ MKLDNNFCFlattenData(default_param, out_data[fullc::kOut], &data, &out_md);
+
+ if (initialized_ && mkldnn_param.quantized) {
+ if (cached_min_data_ != min_data || cached_max_data_ != max_data ||
+ cached_min_weight_ != min_weight || cached_max_weight_ != max_weight ||
+ (has_bias && (cached_min_bias_ != min_bias || cached_max_bias_ != max_bias))) {
+ initialized_ = false;
+ }
+ }
+
+ if (!initialized_) {
+ cached_min_data_ = min_data;
+ cached_max_data_ = max_data;
+ cached_min_weight_ = min_weight;
+ cached_max_weight_ = max_weight;
+ if (has_bias) {
+ cached_bias_ = in_data[fullc::kBias];
+ } else {
+ cached_bias_ = NDArray();
+ }
+
+ if (mkldnn_param.quantized) {
+ CHECK(data.dtype() == mshadow::kInt8 || data.dtype() == mshadow::kUint8);
+ auto data_range = (data.dtype() == mshadow::kInt8) ? kInt8Range : kUint8Range;
+ float data_scale = data_range / MaxAbs(cached_min_data_, cached_max_data_);
+ float weight_scale = kInt8Range / MaxAbs(cached_min_weight_, cached_max_weight_);
+ float quantized_out_range = mkldnn_param.with_relu ? kUint8Range : kInt8Range;
+
+ if (has_bias) {
+ NDArray bias = in_data[fullc::kBias];
+ float bias_int32_rescale = data_scale * weight_scale *
+ MaxAbs(min_bias, max_bias) / kInt8Range;
+
+ cached_bias_ = NDArray(bias.storage_type(), bias.shape(),
+ bias.ctx(), true, mshadow::kInt32);
+ int8_t *bias_ptr = bias.data().dptr<int8_t>();
+ int32_t *quantized_bias_ptr = cached_bias_.data().dptr<int32_t>();
+ size_t bias_size = bias.shape().Size();
+ #pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
+ for (size_t i = 0; i < bias_size; ++i) {
+ quantized_bias_ptr[i] = bias_ptr[i] * bias_int32_rescale;
+ }
+ }
+
+ if (mkldnn_param.enable_float_output) {
+ full_param_.output_scales[0] = 1.0 / data_scale / weight_scale;
+ full_param_.requantize_scales.resize(0);
+ } else if (mkldnn_param.min_calib_range.has_value() &&
+ mkldnn_param.max_calib_range.has_value()) {
+ full_param_.output_scales.resize(0);
+ *min_output_ptr = mkldnn_param.min_calib_range.value();
+ *max_output_ptr = mkldnn_param.max_calib_range.value();
+
+ full_param_.requantize_scales[0] = quantized_out_range /
+ MaxAbs(*min_output_ptr, *max_output_ptr) / data_scale / weight_scale;
+ } else {
+ Stream<cpu> *s = ctx.get_stream<cpu>();
+ mxnet_op::Kernel<QuantizationRangeForMultiplicationStruct, cpu>::Launch(s, 1,
+ min_output_ptr, max_output_ptr, &min_data, &max_data, &min_weight, &max_weight);
+ }
+ }
+
+ fwd_.reset(new MKLDNNFullyConnectedForward(full_param_, ctx.is_train, data, weight,
+ (has_bias ? &cached_bias_ : nullptr), out_md));
+ initialized_ = true;
+ }
+ std::vector<NDArray> new_inputs;
+ std::vector<OpReqType> new_req;
+ if (has_bias) {
+ new_inputs = {data, weight, cached_bias_};
+ new_req = {req[fullc::kData], req[fullc::kWeight], req[fullc::kBias]};
+ } else {
+ new_inputs = {data, weight};
+ new_req = {req[fullc::kData], req[fullc::kWeight]};
+ }
+
+ MKLDNNFCForwardFullFeature(full_param_, ctx, fwd_.get(), new_inputs, new_req, out_data);
+}
+
+static void SgMKLDNNFCParamParser(nnvm::NodeAttrs *attrs) {
+ MKLDNNFCFullParam full_param;
+ try {
+ full_param.mkldnn_param.Init(attrs->dict);
+ } catch (const dmlc::ParamError &e) {
+ std::ostringstream os;
+ os << e.what();
+ os << ", in operator " << attrs->op->name << "("
+ << "name=\"" << attrs->name << "\"";
+ for (const auto &k : attrs->dict) {
+ os << ", " << k.first << "=\"" << k.second << "\"";
+ }
+ os << ")";
+ throw dmlc::ParamError(os.str());
+ }
+ auto subgraph_sym = attrs->subgraphs[0];
+ DFSVisit(subgraph_sym->outputs, [&](const nnvm::NodePtr &node) {
+ if (node->is_variable()) return;
+ auto &node_name = node->op()->name;
+ if (node_name == "FullyConnected") {
+ full_param.default_param =
+ nnvm::get<FullyConnectedParam>(node->attrs.parsed);
+ }
+ });
+ attrs->parsed = std::move(full_param);
+}
+
+static std::vector<std::string> SgMKLDNNFCListInputNames(const NodeAttrs &attrs) {
+ auto const &full_param = nnvm::get<MKLDNNFCFullParam>(attrs.parsed);
+ std::vector<std::string> input_names = DefaultSubgraphOpListInputs(attrs);
+ if (full_param.mkldnn_param.quantized) {
+ input_names.emplace_back("min_data");
+ input_names.emplace_back("max_data");
+ input_names.emplace_back("min_weight");
+ input_names.emplace_back("max_weight");
+ if (!full_param.default_param.no_bias) {
+ input_names.emplace_back("min_bias");
+ input_names.emplace_back("max_bias");
+ }
+ }
+ return input_names;
+}
+
+static std::vector<std::string> SgMKLDNNFCListOutputNames(const NodeAttrs &attrs) {
+ auto const &full_param = nnvm::get<MKLDNNFCFullParam>(attrs.parsed);
+ if (full_param.mkldnn_param.quantized) {
+ if (full_param.mkldnn_param.enable_float_output)
+ return std::vector<std::string>{"output"};
+ else
+ return std::vector<std::string>{"output", "min_output", "max_output"};
+ } else {
+ return std::vector<std::string>{"output"};
+ }
+}
+
+template <typename T>
+static inline void FillBaseInputOutputInfo(const FullyConnectedParam ¶m,
+ std::vector<T> *base_in_attrs,
+ std::vector<T> *base_out_attrs,
+ std::vector<T> *in_attrs,
+ std::vector<T> *out_attrs) {
+ auto base_num_inputs = param.no_bias ? 2 : 3;
+
+ base_out_attrs->push_back(out_attrs->at(0));
+ for (int i = 0; i < base_num_inputs; ++i) {
+ base_in_attrs->push_back(in_attrs->at(i));
+ }
+}
+
+static bool SgMKLDNNFCInferShape(const nnvm::NodeAttrs &attrs,
+ mxnet::ShapeVector *in_shapes,
+ mxnet::ShapeVector *out_shapes) {
+ auto const &full_param = nnvm::get<MKLDNNFCFullParam>(attrs.parsed);
+ if (full_param.mkldnn_param.quantized) {
+ mxnet::ShapeVector base_in_shapes;
+ mxnet::ShapeVector base_out_shapes;
+ FillBaseInputOutputInfo(full_param.default_param, &base_in_shapes, &base_out_shapes,
+ in_shapes, out_shapes);
+ bool ret = DefaultSubgraphOpShape(attrs, &base_in_shapes, &base_out_shapes);
+
+ for (size_t i = 0; i < in_shapes->size(); ++i) {
+ if (i < base_in_shapes.size())
+ in_shapes->at(i) = base_in_shapes[i];
+ else
+ SHAPE_ASSIGN_CHECK(*in_shapes, i, Shape1(1));
+ }
+
+ out_shapes->at(0) = base_out_shapes[0];
+ if (!full_param.mkldnn_param.enable_float_output) {
+ SHAPE_ASSIGN_CHECK(*out_shapes, 1, Shape1(1));
+ SHAPE_ASSIGN_CHECK(*out_shapes, 2, Shape1(1));
+ }
+ return ret;
+ } else {
+ return DefaultSubgraphOpShape(attrs, in_shapes, out_shapes);
+ }
+}
+
+static bool SgMKLDNNFCInferType(const nnvm::NodeAttrs &attrs,
+ std::vector<int> *in_types,
+ std::vector<int> *out_types) {
+ auto const &full_param = nnvm::get<MKLDNNFCFullParam>(attrs.parsed);
+ if (full_param.mkldnn_param.quantized) {
+ size_t base_num_inputs = full_param.default_param.no_bias ? 2 : 3;
+
+ // TODO(ciyong): currently, only uint8 fully_connected is supported;
+ // int8 fully_connected will be supported after mkldnn v0.18
+ TYPE_ASSIGN_CHECK(*in_types, 0, mshadow::kUint8);
+ for (size_t i = 1; i < in_types->size(); ++i) {
+ if (i < base_num_inputs) {
+ TYPE_ASSIGN_CHECK(*in_types, i, mshadow::kInt8);
+ } else {
+ TYPE_ASSIGN_CHECK(*in_types, i, mshadow::kFloat32);
+ }
+ }
+
+ if (full_param.mkldnn_param.enable_float_output) {
+ TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kFloat32);
+ } else {
+ if (full_param.mkldnn_param.min_calib_range.has_value() &&
+ full_param.mkldnn_param.max_calib_range.has_value()) {
+ if (full_param.mkldnn_param.with_relu) {
+ TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kUint8);
+ } else {
+ TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kInt8);
+ }
+ } else {
+ TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kInt32);
+ }
+ TYPE_ASSIGN_CHECK(*out_types, 1, mshadow::kFloat32);
+ TYPE_ASSIGN_CHECK(*out_types, 2, mshadow::kFloat32);
+ }
+ return true;
+ } else {
+ return DefaultSubgraphOpType(attrs, in_types, out_types);
+ }
+}
+
+static bool SgMKLDNNFCStorageType(const nnvm::NodeAttrs &attrs,
+ const int dev_mask,
+ DispatchMode *dispatch_mode,
+ std::vector<int> *in_attrs,
+ std::vector<int> *out_attrs) {
+ auto const &full_param = nnvm::get<MKLDNNFCFullParam>(attrs.parsed);
+ if (full_param.mkldnn_param.quantized) {
+ std::vector<int> base_in_attrs;
+ std::vector<int> base_out_attrs;
+ FillBaseInputOutputInfo(full_param.default_param, &base_in_attrs, &base_out_attrs,
+ in_attrs, out_attrs);
+ bool ret = DefaultSubgraphOpStorageType(attrs, dev_mask, dispatch_mode,
+ &base_in_attrs, &base_out_attrs);
+
+ for (size_t i = 0; i < in_attrs->size(); ++i) {
+ if (i < base_in_attrs.size())
+ in_attrs->at(i) = base_in_attrs[i];
+ else
+ type_assign(&in_attrs->at(i), mxnet::kDefaultStorage);
+ }
+
+ out_attrs->at(0) = base_out_attrs[0];
+ if (!full_param.mkldnn_param.enable_float_output) {
+ type_assign(&out_attrs->at(1), mxnet::kDefaultStorage);
+ type_assign(&out_attrs->at(2), mxnet::kDefaultStorage);
+ }
+ return ret;
+ } else {
+ return DefaultSubgraphOpStorageType(attrs, dev_mask, dispatch_mode,
+ in_attrs, out_attrs);
+ }
+}
+
+static OpStatePtr CreateSgMKLDNNFCState(const nnvm::NodeAttrs &attrs,
+ Context ctx,
+ const mxnet::ShapeVector &in_shapes,
+ const std::vector<int> &in_types) {
+ return OpStatePtr::Create<SgMKLDNNFCOp>(attrs);
+}
+
+static void SgMKLDNNFCForward(const OpStatePtr &state_pointer,
+ const OpContext &ctx,
+ const std::vector<NDArray> &inputs,
+ const std::vector<OpReqType> &req,
+ const std::vector<NDArray> &outputs) {
+ SgMKLDNNFCOp &op = state_pointer.get_state<SgMKLDNNFCOp>();
+ op.Forward(ctx, inputs, req, outputs);
+}
+
+nnvm::NodePtr SgMKLDNNFCQuantizedOp(const NodeAttrs& attrs) {
+ nnvm::NodePtr node = nnvm::Node::Create();
+ node->attrs.op = Op::Get("_sg_mkldnn_fully_connected");
+ node->attrs.name = "quantized_" + attrs.name;
+ node->attrs.dict = attrs.dict;
+ node->attrs.dict["quantized"] = "True";
+ node->attrs.subgraphs.reserve(attrs.subgraphs.size());
+ for (auto sub : attrs.subgraphs) {
+ node->attrs.subgraphs.push_back(sub);
+ }
+ node->op()->attr_parser(&(node->attrs));
+ return node;
+}
+
+NNVM_REGISTER_OP(_sg_mkldnn_fully_connected)
+.describe(R"code(_sg_mkldnn_fully_connected)code" ADD_FILELINE)
+.set_num_inputs([](const NodeAttrs& attrs) {
+ auto const &full_param = nnvm::get<MKLDNNFCFullParam>(attrs.parsed);
+ auto num_inputs = full_param.default_param.no_bias ? 2 : 3;
+ if (full_param.mkldnn_param.quantized)
+ return num_inputs * 3;
+ else
+ return num_inputs;
+})
+.set_num_outputs([](const NodeAttrs& attrs) {
+ auto const &full_param = nnvm::get<MKLDNNFCFullParam>(attrs.parsed);
+ return (full_param.mkldnn_param.quantized &&
+ !full_param.mkldnn_param.enable_float_output) ? 3 : 1;
+})
+.set_attr_parser(SgMKLDNNFCParamParser)
+.set_attr<nnvm::FListInputNames>("FListInputNames", SgMKLDNNFCListInputNames)
+.set_attr<nnvm::FListOutputNames>("FListOutputNames", SgMKLDNNFCListOutputNames)
+.set_attr<mxnet::FInferShape>("FInferShape", SgMKLDNNFCInferShape)
+.set_attr<nnvm::FInferType>("FInferType", SgMKLDNNFCInferType)
+.set_attr<FInferStorageType>("FInferStorageType", SgMKLDNNFCStorageType)
+.set_attr<FCreateOpState>("FCreateOpState", CreateSgMKLDNNFCState)
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", SgMKLDNNFCForward)
+.set_attr<bool>("TIsMKLDNN", true)
+.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
+ return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+})
+.set_attr<nnvm::FMutateInputs>("FMutateInputs",
+ DefaultSubgraphOpMutableInputs)
+.set_attr<std::string>("key_var_num_args", "num_args")
+.set_attr<FQuantizedOp>("FQuantizedOp", SgMKLDNNFCQuantizedOp)
+.set_attr<FNeedRequantize>("FNeedRequantize", [](const NodeAttrs& attrs) { return true; });
+
+} // namespace op
+} // namespace mxnet
+
+#endif // if MXNET_USE_MKLDNN == 1
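For orientation, the fused op registered above is reached from Python through the
subgraph backend API. A minimal sketch of the fp32 fusion step on an MKL-DNN build,
assuming a toy FC + relu symbol (the layer names here are illustrative, not part of
this commit):

    import mxnet as mx

    data = mx.sym.Variable('data')
    fc = mx.sym.FullyConnected(data=data, num_hidden=64, name='fc')
    net = mx.sym.Activation(data=fc, act_type='relu', name='relu')

    # apply the MKLDNN_FC property (registered in mkldnn_fc_property.cc below);
    # FC and relu collapse into a single _sg_mkldnn_fully_connected node
    fused = net.get_backend_symbol('MKLDNN_FC')
    assert 'sg_mkldnn_fully_connected' in ''.join(fused.get_internals().list_outputs())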
diff --git a/src/operator/subgraph/mkldnn/mkldnn_fc_post_quantize_property.cc b/src/operator/subgraph/mkldnn/mkldnn_fc_post_quantize_property.cc
new file mode 100644
index 0000000..d2d176f
--- /dev/null
+++ b/src/operator/subgraph/mkldnn/mkldnn_fc_post_quantize_property.cc
@@ -0,0 +1,217 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file mkldnn_fc_post_quantize_property.cc
+ * \brief Partition graph property for MKLDNN Quantized FullyConnected operator
+ * \author Ciyong Chen
+*/
+
+#if MXNET_USE_MKLDNN == 1
+
+#include "../common.h"
+#include "../subgraph_property.h"
+#include "../../nn/fully_connected-inl.h"
+#include "../../quantization/requantize-inl.h"
+
+namespace mxnet {
+namespace op {
+
+#define QUANTIZED_FC_NAME "_sg_mkldnn_fully_connected"
+
+class SgMKLDNNFCPostQuantizeSelector : public SubgraphSelector {
+ public:
+ /*! \brief pattern match status */
+ enum SelectStatus {
+ kFail = 0,
+ kStart,
+ kRequantize,
+ kSuccess,
+ };
+
+ private:
+ bool disable_all;
+ bool disable_float_output;
+ SelectStatus status;
+ std::vector<const nnvm::Node *> matched_list;
+
+ public:
+ explicit SgMKLDNNFCPostQuantizeSelector(const bool dis_all,
+ const bool dis_float_output)
+ : disable_all(dis_all),
+ disable_float_output(dis_float_output) {}
+
+ bool Select(const nnvm::Node &n) override {
+ if ((!disable_all) && n.op() == Op::Get(QUANTIZED_FC_NAME)) {
+ status = kStart;  // disable_all is known to be false in this branch
+ matched_list.clear();
+ matched_list.push_back(&n);
+ return true;
+ }
+ return false;
+ }
+
+ bool SelectInput(const nnvm::Node &n, const nnvm::Node &new_node) override {
+ return false;
+ }
+
+ bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) override {
+ if (status == kFail || status == kSuccess || new_node.is_variable())
+ return false;
+ // If n isn't the last matched node, then we encountered an internal
+ // branch; pop the nodes behind n off the matched list and stop fusion.
+ if (matched_list.back() != &n) {
+ if (std::find(matched_list.begin(), matched_list.end(), &n) !=
+ matched_list.end()) {
+ while (matched_list.back() != &n) {
+ matched_list.pop_back();
+ }
+ }
+
+ status = kSuccess;
+ return false;
+ }
+
+ switch (status) {
+ case kStart:
+ if (new_node.op() == Op::Get("_contrib_requantize")) {
+ auto const &param = nnvm::get<RequantizeParam>(new_node.attrs.parsed);
+ if (param.min_calib_range.has_value() &&
+ param.max_calib_range.has_value()) {
+ matched_list.push_back(&new_node);
+ status = kRequantize;
+ return true;
+ }
+ }
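+ // no matching requantize: fall through to the dequantize check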
+ case kRequantize:
+ if ((!disable_float_output) && (new_node.op() == Op::Get("_contrib_dequantize"))) {
+ matched_list.push_back(&new_node);
+ status = kSuccess;
+ return true;
+ }
+ default:
+ status = kSuccess;
+ return false;
+ }
+ }
+
+ std::vector<nnvm::Node *> Filter(
+ const std::vector<nnvm::Node *> &candidates) override {
+ if ((status != kSuccess) || (matched_list.size() <= 1)) {
+ return std::vector<nnvm::Node *>(0);
+ } else {
+ std::vector<nnvm::Node *> ret;
+ for (auto i : matched_list) {
+ auto non_const_i = const_cast<nnvm::Node *>(i);
+ if (std::find(candidates.begin(), candidates.end(), non_const_i) !=
+ candidates.end()) {
+ ret.push_back(non_const_i);
+ }
+ }
+ return ret;
+ }
+ }
+};
+
+class SgMKLDNNFCPostQuantizeProperty : public SubgraphProperty {
+ public:
+ SgMKLDNNFCPostQuantizeProperty() {
+ disable_all = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_POST_OPT", false);
+ disable_fuse_all = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_QFC_FUSE_ALL", false);
+ disable_float_output = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_QFC_FLOAT_OUTPUT", false);
+
+ disable_all = disable_all || disable_fuse_all;
+ if (disable_all) {
+ LOG(INFO) << "MKLDNN FullyConnected post-quantization optimization pass is disabled.";
+ } else {
+ LOG(INFO) << "Start to execute MKLDNN FullyConected post-quantization optimization pass.";
+ }
+ }
+
+ static SubgraphPropertyPtr Create() {
+ return std::make_shared<SgMKLDNNFCPostQuantizeProperty>();
+ }
+
+ nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &sym,
+ const int subgraph_id = 0) const override {
+ nnvm::NodePtr fc_node = nullptr;
+ nnvm::NodePtr requantize_node = nullptr;
+ nnvm::NodePtr dequantize_node = nullptr;
+
+ DFSVisit(sym.outputs, [&](const nnvm::NodePtr &node) {
+ if (node->is_variable()) return;
+ if (node->op() == Op::Get(QUANTIZED_FC_NAME)) {
+ fc_node = node;
+ } else if (node->op() == Op::Get("_contrib_requantize")) {
+ requantize_node = node;
+ } else if (node->op() == Op::Get("_contrib_dequantize")) {
+ dequantize_node = node;
+ }
+ });
+
+ CHECK_NOTNULL(fc_node);
+ CHECK_NOTNULL(requantize_node);
+ auto const &requantize_param =
+ nnvm::get<RequantizeParam>(requantize_node->attrs.parsed);
+ CHECK(requantize_param.min_calib_range.has_value());
+ CHECK(requantize_param.max_calib_range.has_value());
+
+ // When only quantized_fullyconnected and requantize are fused, set min/max_calib_range;
+ // when quantized_fullyconnected + requantize + dequantize are fused, enable float output.
+ if (dequantize_node != nullptr) {
+ fc_node->attrs.dict["enable_float_output"] = "True";
+ } else {
+ fc_node->attrs.dict["min_calib_range"] =
+ std::to_string(requantize_param.min_calib_range.value());
+ fc_node->attrs.dict["max_calib_range"] =
+ std::to_string(requantize_param.max_calib_range.value());
+ }
+ fc_node->op()->attr_parser(&(fc_node->attrs));
+ return fc_node;
+ }
+
+ SubgraphSelectorPtr CreateSubgraphSelector() const override {
+ auto selector =
+ std::make_shared<SgMKLDNNFCPostQuantizeSelector>(disable_all,
+ disable_float_output);
+ return selector;
+ }
+
+ void ConnectSubgraphOutputs(
+ const nnvm::NodePtr n,
+ std::vector<nnvm::NodeEntry *> *output_entries) const override {
+ for (size_t i = 0; i < output_entries->size(); ++i) {
+ auto entry_ptr = output_entries->at(i);
+ *entry_ptr = nnvm::NodeEntry{n, entry_ptr->index, 0};
+ }
+ }
+
+ private:
+ bool disable_all;
+ bool disable_fuse_all;
+ bool disable_float_output;
+};
+
+MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_POST_FC_QUANTIZE, SgMKLDNNFCPostQuantizeProperty);
+
+} // namespace op
+} // namespace mxnet
+
+#endif // if MXNET_USE_MKLDNN == 1
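Both environment toggles read in the constructor above take effect when the pass
runs. A hedged sketch of the calibrated post-quantization flow, assuming that fused,
arg_params, aux_params, and a calib_data iterator already exist (placeholder names,
not from this commit):

    import os
    import mxnet as mx
    from mxnet.contrib import quantization

    # optional opt-outs honored by SgMKLDNNFCPostQuantizeProperty:
    # os.environ['MXNET_DISABLE_MKLDNN_QFC_FUSE_ALL'] = '1'
    # os.environ['MXNET_DISABLE_MKLDNN_QFC_FLOAT_OUTPUT'] = '1'

    qsym, qarg_params, qaux_params = quantization.quantize_model(
        sym=fused, arg_params=arg_params, aux_params=aux_params,
        ctx=mx.cpu(), calib_mode='naive', calib_data=calib_data,
        num_calib_examples=5)
    # fold requantize (and, when present, dequantize) into the fused FC node
    qsym = qsym.get_backend_symbol('MKLDNN_POST_FC_QUANTIZE')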
diff --git a/src/operator/subgraph/mkldnn/mkldnn_fc_property.cc b/src/operator/subgraph/mkldnn/mkldnn_fc_property.cc
new file mode 100644
index 0000000..e4fa02d
--- /dev/null
+++ b/src/operator/subgraph/mkldnn/mkldnn_fc_property.cc
@@ -0,0 +1,193 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file mkldnn_fc_property.cc
+ * \brief Partition graph property for FullyConnected operator
+ * \author Ciyong Chen
+*/
+
+#if MXNET_USE_MKLDNN == 1
+
+#include "../common.h"
+#include "../subgraph_property.h"
+
+namespace mxnet {
+namespace op {
+
+class SgMKLDNNFCSelector : public SubgraphSelector {
+ public:
+ /*! \brief pattern match status */
+ enum SelectStatus {
+ kFail = 0,
+ kStart,
+ kSuccess,
+ };
+
+ private:
+ bool disable_all;
+ bool disable_fc_relu;
+ SelectStatus status;
+ std::vector<const nnvm::Node *> matched_list;
+
+ public:
+ SgMKLDNNFCSelector(const bool dis_all, const bool dis_fc_relu)
+ : disable_all(dis_all),
+ disable_fc_relu(dis_fc_relu) {}
+
+ bool Select(const nnvm::Node &n) override {
+ if (n.op() == Op::Get("FullyConnected")) {
+ status = disable_all ? kSuccess : kStart;
+ matched_list.clear();
+ matched_list.push_back(&n);
+ return true;
+ }
+ return false;
+ }
+
+ bool SelectInput(const nnvm::Node &n, const nnvm::Node &new_node) override {
+ return false;
+ }
+
+ bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) override {
+ if (status == kFail || status == kSuccess || new_node.is_variable())
+ return false;
+
+ // If n isn't the last matched node, then we encountered an internal
+ // branch; pop the nodes behind n off the matched list and stop fusion.
+ if (matched_list.back() != &n) {
+ if (std::find(matched_list.begin(), matched_list.end(), &n) !=
+ matched_list.end()) {
+ while (matched_list.back() != &n) {
+ matched_list.pop_back();
+ }
+ }
+
+ status = kSuccess;
+ return false;
+ }
+
+ switch (status) {
+ case kStart:
+ if ((!disable_fc_relu) &&
+ new_node.op() == Op::Get("Activation") &&
+ new_node.attrs.dict.at("act_type") == "relu") {
+ matched_list.push_back(&new_node);
+ status = kSuccess;
+ return true;
+ }
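+ // not a fusable relu: fall through and finish the match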
+ default:
+ status = kSuccess;
+ return false;
+ }
+ }
+
+ std::vector<nnvm::Node *> Filter(
+ const std::vector<nnvm::Node *> &candidates) override {
+ if (status == kFail) {
+ return std::vector<nnvm::Node *>(0);
+ } else {
+ std::vector<nnvm::Node *> ret;
+ for (auto i : matched_list) {
+ auto non_const_i = const_cast<nnvm::Node *>(i);
+ if (std::find(candidates.begin(), candidates.end(), non_const_i) !=
+ candidates.end()) {
+ ret.push_back(non_const_i);
+ }
+ }
+ return ret;
+ }
+ }
+};
+
+class SgMKLDNNFCProperty : public SubgraphProperty {
+ public:
+ SgMKLDNNFCProperty() {
+ disable_all = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_OPT", false);
+ disable_fc_relu = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_FUSE_FC_RELU", false);
+
+ disable_all = disable_all || disable_fc_relu;
+ if (disable_all) {
+ LOG(INFO) << "MKLDNN FullyConnected optimization pass is disabled.";
+ } else {
+ LOG(INFO) << "Start to execute MKLDNN FullyConnected optimization pass.";
+ }
+ }
+
+ static SubgraphPropertyPtr Create() {
+ return std::make_shared<SgMKLDNNFCProperty>();
+ }
+
+ nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &sym,
+ const int subgraph_id = 0) const override {
+ nnvm::NodePtr n = nnvm::Node::Create();
+ // This op has a single output; remove duplicated output entries.
+ auto last_node = sym.outputs[0].node;
+ nnvm::Symbol new_sym;
+ new_sym.outputs.emplace_back(nnvm::NodeEntry{last_node, 0, 0});
+ std::ostringstream node_name;
+ node_name << "sg_mkldnn_";
+ DFSVisit(new_sym.outputs, [&](const nnvm::NodePtr &node) {
+ if (node->is_variable()) return;
+ auto &sub_name = node->op()->name;
+ if (sub_name == "FullyConnected") {
+ node_name << "fully_connected_";
+ } else if ((sub_name == "Activation") &&
+ (node->attrs.dict.at("act_type") == "relu")) {
+ node_name << "relu_";
+ n->attrs.dict["with_relu"] = "True";
+ }
+ });
+ node_name << std::to_string(subgraph_id);
+ n->attrs.name = node_name.str();
+ n->attrs.op = Op::Get("_sg_mkldnn_fully_connected");
+ CHECK(n->attrs.op);
+ n->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>(new_sym));
+ n->op()->attr_parser(&(n->attrs));
+ return n;
+ }
+
+ SubgraphSelectorPtr CreateSubgraphSelector() const override {
+ auto selector = std::make_shared<SgMKLDNNFCSelector>(
+ disable_all, disable_fc_relu);
+ return selector;
+ }
+
+ void ConnectSubgraphOutputs(
+ const nnvm::NodePtr n,
+ std::vector<nnvm::NodeEntry *> *output_entries) const override {
+ // Connect all extern output entries to output[0]
+ for (size_t i = 0; i < output_entries->size(); ++i) {
+ auto entry_ptr = output_entries->at(i);
+ *entry_ptr = nnvm::NodeEntry{n, entry_ptr->index, 0};
+ }
+ }
+
+ private:
+ bool disable_all;
+ bool disable_fc_relu;
+};
+
+MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_FC, SgMKLDNNFCProperty);
+
+} // namespace op
+} // namespace mxnet
+
+#endif // if MXNET_USE_MKLDNN == 1
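The selector above honors two opt-outs. A small sketch of the fc + relu opt-out,
assuming the same toy net symbol as earlier (MXNET_DISABLE_MKLDNN_FUSE_FC_RELU is
read when the property is constructed, so set it before running the pass):

    import os
    os.environ['MXNET_DISABLE_MKLDNN_FUSE_FC_RELU'] = '1'
    # Select() now starts in kSuccess, so only the FullyConnected node itself is
    # wrapped into _sg_mkldnn_fully_connected; the trailing relu stays outside
    fused = net.get_backend_symbol('MKLDNN_FC')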
diff --git a/tests/python/mkl/test_subgraph.py b/tests/python/mkl/test_subgraph.py
index 8de854c..e6fe001 100644
--- a/tests/python/mkl/test_subgraph.py
+++ b/tests/python/mkl/test_subgraph.py
@@ -32,17 +32,40 @@ curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.append(os.path.join(curr_path, '../unittest/'))
from common import with_seed
from mxnet.test_utils import assert_almost_equal
+import itertools
+
+OP_NAME='op_name'
+QUANTIZED_OP_NAME='quantized_op_name'
+SG_PASS_NAME='sg_pass_name'
+POST_SG_PASS_NAME='post_sg_pass_name'
+config = {
+ 'conv': {
+ OP_NAME: 'sg_mkldnn_conv',
+ QUANTIZED_OP_NAME: 'quantized_sg_mkldnn_conv',
+ SG_PASS_NAME: 'MKLDNN',
+ POST_SG_PASS_NAME: 'MKLDNN_POST_QUANTIZE'
+ },
+ 'fc': {
+ OP_NAME: 'sg_mkldnn_fully_connected',
+ QUANTIZED_OP_NAME: 'quantized_sg_mkldnn_fully_connected',
+ SG_PASS_NAME: 'MKLDNN_FC',
+ POST_SG_PASS_NAME: 'MKLDNN_POST_FC_QUANTIZE'
+ }
+}
DATA_SHAPE=[(4, 4, 10, 10), (32, 3, 24, 24), (64, 8, 64, 64)]
-def check_qsym_calibrated(qsym, out_type):
- assert ''.join(qsym.attr_dict().keys()).find('quantized_sg_mkldnn_conv') != -1
+def check_qsym_calibrated(qsym, out_type, name='conv'):
+ quantized_op_name = config[name][QUANTIZED_OP_NAME]
+ assert ''.join(qsym.attr_dict().keys()).find(quantized_op_name) != -1
for k, v in qsym.attr_dict().items():
- if k.find('quantized_sg_mkldnn_conv') != -1:
- assert 'min_calib_range' in v
- assert 'max_calib_range' in v
if k.find('_quantize') != -1:
assert v['out_type'] == out_type
+ if k.find(quantized_op_name) != -1:
+ if name == 'fc' and 'enable_float_output' in v:
+ continue
+ assert 'min_calib_range' in v
+ assert 'max_calib_range' in v
def check_qsym_forward(qsym, qarg_params, qaux_params, batch, data_shape, label_shape):
mod = mx.mod.Module(symbol=qsym, context=mx.current_context())
@@ -81,22 +104,27 @@ def check_qsym_gluon_forward(qsym, qarg_params, qaux_params, data_shape):
data = mx.random.uniform(-1.0, 1.0, shape=data_shape)
net(data)
-def check_quantize(sym, data_shape, out_type, check_conv=True, gluon_forward=False):
- fc = mx.sym.FullyConnected(data=sym, num_hidden=10, flatten=True, name='fc')
+def check_quantize(sym, data_shape, out_type, name='conv',
+ check_calibration=True, gluon_forward=False):
+ sg_pass_name = config[name][SG_PASS_NAME]
+ post_sg_pass_name = config[name][POST_SG_PASS_NAME]
+
+ fc = mx.sym.FullyConnected(data=sym, num_hidden=10, flatten=True, name='fc_softmax')
if gluon_forward == True:
sym = fc
- sym_sg = fc.get_backend_symbol("MKLDNN")
+ sym_sg = sym.get_backend_symbol(sg_pass_name)
mod = Module(symbol=sym, label_names=[])
mod.bind(for_training=False,
data_shapes=[('data', data_shape)])
else:
sym = mx.sym.SoftmaxOutput(data=fc, name='softmax')
- sym_sg = sym.get_backend_symbol("MKLDNN")
+ sym_sg = sym.get_backend_symbol(sg_pass_name)
label_shape = (data_shape[0], 10)
mod = Module(symbol=sym)
mod.bind(for_training=False,
data_shapes=[('data', data_shape)],
label_shapes=[('softmax_label', label_shape)])
+
mod.init_params(mx.init.Normal(0.5))
arg_params, aux_params = mod.get_params()
@@ -108,9 +136,12 @@ def check_quantize(sym, data_shape, out_type, check_conv=True, gluon_forward=Fal
output.wait_to_read()
ref_out = mod.get_outputs()
+ # TODO(ciyong): exclude the second fc since int8 fully_connected is not
+ # supported before mkldnn 0.18
excluded_sym_names = []
if mx.current_context() == mx.cpu():
- excluded_sym_names += ['fc']
+ excluded_sym_names += ['fc_softmax']
+ excluded_sym_names += ['sg_mkldnn_fully_connected_1']
calib_data = mx.nd.random.uniform(shape=data_shape)
calib_data = NDArrayIter(data=calib_data)
@@ -126,9 +157,9 @@ def check_quantize(sym, data_shape, out_type, check_conv=True, gluon_forward=Fal
calib_data=calib_data,
calib_layer=calib_layer,
num_calib_examples=5)
- qsym = qsym.get_backend_symbol("MKLDNN_POST_QUANTIZE")
- if check_conv:
- check_qsym_calibrated(qsym, out_type)
+ qsym = qsym.get_backend_symbol(post_sg_pass_name)
+ if check_calibration:
+ check_qsym_calibrated(qsym, out_type, name=name)
if gluon_forward == True:
check_qsym_gluon_forward(qsym, qarg_params, qaux_params, data_shape)
else:
@@ -137,22 +168,24 @@ def check_quantize(sym, data_shape, out_type, check_conv=True, gluon_forward=Fal
for i in range(len(ref_out)):
assert_almost_equal(ref_out[i].asnumpy(), quantized_out[i].asnumpy(), atol = 1)
-
@with_seed()
-def check_fusion(sym, data_shape, attrs_op):
- sym_sg = sym.get_backend_symbol("MKLDNN")
- assert ''.join(sym_sg.get_internals().list_outputs()).find('sg_mkldnn_conv') != -1
+def check_fusion(sym, data_shape, attrs_op, name='conv', check_quantization=True):
+ op_name = config[name][OP_NAME]
+ sg_pass_name = config[name][SG_PASS_NAME]
+
+ sym_sg = sym.get_backend_symbol(sg_pass_name)
+ assert ''.join(sym_sg.get_internals().list_outputs()).find(op_name) != -1
for k, v in sym_sg.attr_dict().items():
- if k.find('sg_mkldnn_conv') != -1:
+ if k.find(op_name) != -1:
for attr_op in attrs_op:
- assert v[attr_op] == 'true'
+ assert v[attr_op] in ['true', 'True']
arg_shapes, _, aux_shapes = sym.infer_shape()
arg_array = [mx.nd.random.uniform(-1, 1, shape=shape) for shape in arg_shapes]
aux_array = [mx.nd.random.uniform(shape=shape) for shape in aux_shapes]
exe = sym.bind(ctx=mx.current_context(), args=arg_array, aux_states=aux_array, grad_req='null')
exe.forward()
- os.environ['MXNET_SUBGRAPH_BACKEND'] = 'MKLDNN'
+ os.environ['MXNET_SUBGRAPH_BACKEND'] = sg_pass_name
exe_sg = sym.bind(ctx=mx.current_context(), args=arg_array, aux_states=aux_array, grad_req='null')
exe_sg.forward()
del os.environ['MXNET_SUBGRAPH_BACKEND']
@@ -160,18 +193,33 @@ def check_fusion(sym, data_shape, attrs_op):
assert_almost_equal(exe.outputs[i].asnumpy(), exe_sg.outputs[i].asnumpy(), rtol=1e-3, atol=1e-3)
# fp32 to int8
- for out_type in ('uint8', 'int8', 'auto'):
- check_quantize(sym, data_shape, out_type)
- check_quantize(sym, data_shape, out_type, gluon_forward=True)
+ # TODO(ciyong): int8 fully_connected will be supported after mkldnn 0.18
+ if name == 'fc':
+ out_type_list = ['uint8', 'auto']
+ else:
+ out_type_list = ['uint8', 'int8', 'auto']
+
+ if check_quantization:
+ for out_type in out_type_list:
+ check_quantize(sym, data_shape, out_type, name=name)
+ # TODO(ciyong): quantized fc saves its params in int8, while gluon treats the default
+ # variables from the symbol file as fp32, which results in a params dtype mismatch.
+ # Skip quantized fc in the gluon pass.
+ if name != 'fc':
+ check_quantize(sym, data_shape, out_type, name=name, gluon_forward=True)
+
+def check_neg_fusion(syms, attrs_name=None, excluded_attrs=None,
+ date_shape=(4,4,10,10), name='conv'):
+ op_name = config[name][OP_NAME]
+ sg_pass_name = config[name][SG_PASS_NAME]
-def check_neg_fusion(syms, attrs_name=None, excluded_attrs=None, date_shape=(4,4,10,10)):
for sym, attrs, excluded_attr in zip(syms, attrs_name, excluded_attrs):
- sym_sg = sym.get_backend_symbol("MKLDNN")
+ sym_sg = sym.get_backend_symbol(sg_pass_name)
exe_sg = sym_sg.simple_bind(mx.cpu(), data=date_shape, grad_req='null')
attrs_dict = sym_sg.attr_dict()
for k, v in attrs_dict.items():
- if k.find('sg_mkldnn_conv') != -1:
+ if k.find(op_name) != -1:
for attr in attrs:
assert v[attr] == 'true'
for exc_attr in excluded_attr:
@@ -443,6 +491,45 @@ def neg_conv_bn_add_relu(data_shape):
excluded_attrs.append(['with_postsum_relu'])
return syms, attrs, excluded_attrs
+def single_fc(no_bias, data_shape, flatten=True):
+ attr = ['']
+ data, weight = head_symbol(data_shape)
+ fc = mx.symbol.FullyConnected(name='fc', data=data, weight=weight, num_hidden=64,
+ no_bias=no_bias, flatten=flatten)
+ return fc, attr
+
+def fc_relu(no_bias, data_shape, flatten=True):
+ attr = ['with_relu']
+ data, weight = head_symbol(data_shape)
+ fc = mx.symbol.FullyConnected(name='fc', data=data, weight=weight, num_hidden=64,
+ no_bias=no_bias, flatten=flatten)
+ relu = mx.symbol.Activation(data=fc, name='relu', act_type="relu")
+ return relu, attr
+
+# cases where fc + relu can't be fused
+# eg.1
+# fc -----------> relu
+# |
+# |
+# ---------------> [custom op]
+def neg_fc_relu(no_bias, data_shape, flatten=True):
+ syms = []
+ attrs = []
+ excluded_attrs = []
+ data, weight = head_symbol(data_shape)
+
+ # eg.1 ([custom op] = pool)
+ fc = mx.symbol.FullyConnected(name='fc', data=data, weight=weight, num_hidden=64,
+ no_bias=no_bias, flatten=flatten)
+ relu = mx.symbol.Activation(data=fc, name='relu', act_type="relu")
+ sigmoid = mx.symbol.Activation(data=fc, name='sigmoid', act_type="sigmoid")
+ sym = tail_neg_symbol(relu, sigmoid)
+
+ syms.append(sym)
+ attrs.append([])
+ excluded_attrs.append([])
+ return syms, attrs, excluded_attrs
+
@with_seed()
def test_pos_single_conv():
for data_shape in DATA_SHAPE:
@@ -503,14 +590,14 @@ def test_pos_single_concat():
for data_shape in DATA_SHAPE:
for out_type in ('uint8', 'int8', 'auto'):
net = single_concat(data_shape, 2, 1)
- check_quantize(net, data_shape, out_type, False)
- check_quantize(net, data_shape, out_type, False, True)
+ check_quantize(net, data_shape, out_type, name='conv', check_calibration=False)
+ check_quantize(net, data_shape, out_type, name='conv', check_calibration=False, gluon_forward=True)
net = single_concat(data_shape, 4, 2)
- check_quantize(net, data_shape, out_type, False)
- check_quantize(net, data_shape, out_type, False, True)
+ check_quantize(net, data_shape, out_type, name='conv', check_calibration=False)
+ check_quantize(net, data_shape, out_type, name='conv', check_calibration=False, gluon_forward=True)
net = single_concat(data_shape, 4, 3)
- check_quantize(net, data_shape, out_type, False)
- check_quantize(net, data_shape, out_type, False, True)
+ check_quantize(net, data_shape, out_type, name='conv', check_calibration=False)
+ check_quantize(net, data_shape, out_type, name='conv', check_calibration=False, gluon_forward=True)
@with_seed()
def test_neg_conv_bn():
@@ -542,6 +629,30 @@ def test_neg_conv_bn_add_relu():
syms, attrs, excluded_attrs = neg_conv_bn_add_relu(data_shape)
check_neg_fusion(syms, attrs, excluded_attrs, data_shape)
+@with_seed()
+def test_single_fc():
+ for dshape, no_bias, flatten in itertools.product(DATA_SHAPE, [True, False], [True, False]):
+ syms, attrs = single_fc(no_bias, dshape, flatten)
+ if flatten is True:
+ check_fusion(syms, dshape, attrs, name='fc', check_quantization=True)
+ else:
+ check_fusion(syms, dshape, attrs, name='fc', check_quantization=False)
+
+
+@with_seed()
+def test_fc_relu():
+ for dshape, no_bias, flatten in itertools.product(DATA_SHAPE, [True, False], [True, False]):
+ syms, attrs = fc_relu(no_bias, dshape, flatten)
+ if flatten is True:
+ check_fusion(syms, dshape, attrs, name='fc', check_quantization=True)
+ else:
+ check_fusion(syms, dshape, attrs, name='fc', check_quantization=False)
+
+@with_seed()
+def test_neg_fc_relu():
+ for dshape, no_bias, flatten in itertools.product(DATA_SHAPE, [True, False], [True, False]):
+ syms, attrs, excluded_attrs = neg_fc_relu(no_bias, dshape, flatten)
+ check_neg_fusion(syms, attrs, excluded_attrs, dshape, name='fc')
if __name__ == "__main__":
import nose
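The test_quantization.py hunk below exercises the same scale arithmetic as the
MKL-DNN FC kernel above. A worked numeric sketch of the bias rescale, with
illustrative values only (uint8 data, int8 weight and bias):

    data_range, weight_range, bias_range = 63.0, 63.0, 63.0
    data_scale = 255.0 / data_range      # uint8 quantized range
    weight_scale = 127.0 / weight_range  # int8 quantized range
    bias_scale = 127.0 / bias_range
    # bring the int8 bias onto the int32 (data * weight) accumulation scale:
    bias_int32_rescale = data_scale * weight_scale / bias_scale
    # weight and bias share a range here, so the rescale equals data_scale (~4.05)
    # and quantized_bias[i] = bias[i] * bias_int32_rescale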
diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py
index d8c7f08..e4cc277 100644
--- a/tests/python/quantization/test_quantization.py
+++ b/tests/python/quantization/test_quantization.py
@@ -278,7 +278,7 @@ def test_quantized_pooling():
@with_seed()
def test_quantized_fc():
def check_quantized_fc(data_shape, num_hidden, no_bias, qdtype, flatten=True):
- if mx.current_context().device_type != 'gpu':
+ if is_test_for_native_cpu():
hasMKL = False;
for key in os.environ.keys():
if operator.eq(key, "BUILD_TAG"):
@@ -288,31 +288,62 @@ def test_quantized_fc():
if hasMKL == False:
print('skipped testing quantized_fc on cpu since s8u8s32 is only supported by MKL BLAS library')
return
+ elif qdtype == 'int8' and is_test_for_mkldnn():
+ print('skipped testing test_quantized_fc for mkldnn cpu int8 since it is not supported yet')
+ return
elif qdtype == 'uint8' and is_test_for_gpu():
print('skipped testing quantized_fc for gpu uint8 since it is not supported yet')
return
+ def maxabs(a, b):
+ return mx.nd.maximum(mx.nd.abs(a), mx.nd.abs(b))
+
data = mx.sym.Variable(name='data', shape=data_shape, dtype='float32')
fc_fp32 = mx.sym.FullyConnected(data=data, num_hidden=num_hidden, no_bias=no_bias, flatten=flatten)
arg_shapes, _, _ = fc_fp32.infer_shape(data=data_shape)
arg_names = fc_fp32.list_arguments()
fc_fp32_exe = fc_fp32.simple_bind(ctx=mx.current_context(), grad_req='null')
+ int8_range = 127.0
if qdtype == 'uint8':
data_low = 0.0
data_high = 63.0
+ quantized_range = 255.0
else:
data_low = -63.0
data_high = 63.0
- fc_fp32_exe.arg_dict[arg_names[0]][:] = mx.nd.random.uniform(low=data_low, high=data_high,
- shape=data_shape).astype('int32')
- fc_fp32_exe.arg_dict[arg_names[1]][:] = mx.nd.random.uniform(low=data_low, high=data_high,
- shape=arg_shapes[1]).astype('int32')
+ quantized_range = 127.0
+
+ data = mx.nd.random.uniform(low=data_low, high=data_high,
+ shape=data_shape).astype('int32')
+ weight = mx.nd.random.uniform(low=data_low, high=data_high,
+ shape=arg_shapes[1]).astype('int32')
+ fc_fp32_exe.arg_dict[arg_names[0]][:] = data
+ fc_fp32_exe.arg_dict[arg_names[1]][:] = weight
+
+ data_min = mx.nd.min(data).astype('float32')
+ data_max = mx.nd.max(data).astype('float32')
+ weight_min = mx.nd.min(weight).astype('float32')
+ weight_max = mx.nd.max(weight).astype('float32')
+ data_range = maxabs(data_min, data_max)
+ weight_range = maxabs(weight_min, weight_max)
+
if not no_bias:
- fc_fp32_exe.arg_dict[arg_names[2]][:] = mx.nd.random.uniform(low=data_low, high=data_high,
- shape=arg_shapes[2]).astype('int32')
+ bias = mx.nd.random.uniform(low=data_low, high=data_high,
+ shape=arg_shapes[2]).astype('int32')
+ bias_min = mx.nd.min(bias).astype('float32')
+ bias_max = mx.nd.max(bias).astype('float32')
+ bias_range = maxabs(bias_min, bias_max)
+
+ bias_scale = int8_range / bias_range
+ data_scale = quantized_range / data_range
+ weight_scale = int8_range / weight_range
+ bias_int32_rescale = data_scale * weight_scale / bias_scale
+ new_bias = mx.nd.cast(bias, dtype='float32') * bias_int32_rescale
+ fc_fp32_exe.arg_dict[arg_names[2]][:] = new_bias.astype('int32')
+
output = fc_fp32_exe.forward()[0]
- qdata = mx.sym.Variable(name='qdata', shape=data_shape, dtype='int8')
+ qdata = mx.sym.Variable(name='qdata', shape=data_shape, dtype=qdtype)
fc_int8 = mx.sym.contrib.quantized_fully_connected(data=qdata, num_hidden=num_hidden,
no_bias=no_bias, flatten=flatten)
qarg_names = fc_int8.list_arguments()
@@ -322,20 +353,19 @@ def test_quantized_fc():
fc_int8_exe = fc_int8.simple_bind(ctx=mx.current_context(), type_dict=type_dict, grad_req='null')
fc_int8_exe.arg_dict[qarg_names[0]][:] = fc_fp32_exe.arg_dict[arg_names[0]].astype(qdtype)
fc_int8_exe.arg_dict[qarg_names[1]][:] = fc_fp32_exe.arg_dict[arg_names[1]].astype('int8')
- quantized_range = 127.0
if no_bias:
- fc_int8_exe.arg_dict[qarg_names[2]][:] = -quantized_range
- fc_int8_exe.arg_dict[qarg_names[3]][:] = quantized_range
- fc_int8_exe.arg_dict[qarg_names[4]][:] = -quantized_range
- fc_int8_exe.arg_dict[qarg_names[5]][:] = quantized_range
+ fc_int8_exe.arg_dict[qarg_names[2]][:] = -data_range
+ fc_int8_exe.arg_dict[qarg_names[3]][:] = data_range
+ fc_int8_exe.arg_dict[qarg_names[4]][:] = -weight_range
+ fc_int8_exe.arg_dict[qarg_names[5]][:] = weight_range
else:
- fc_int8_exe.arg_dict[qarg_names[2]][:] = fc_fp32_exe.arg_dict[arg_names[2]].astype('int8')
- fc_int8_exe.arg_dict[qarg_names[3]][:] = -quantized_range
- fc_int8_exe.arg_dict[qarg_names[4]][:] = quantized_range
- fc_int8_exe.arg_dict[qarg_names[5]][:] = -quantized_range
- fc_int8_exe.arg_dict[qarg_names[6]][:] = quantized_range
- fc_int8_exe.arg_dict[qarg_names[7]][:] = -quantized_range
- fc_int8_exe.arg_dict[qarg_names[8]][:] = quantized_range
+ fc_int8_exe.arg_dict[qarg_names[2]][:] = bias.astype('int8')
+ fc_int8_exe.arg_dict[qarg_names[3]][:] = -data_range
+ fc_int8_exe.arg_dict[qarg_names[4]][:] = data_range
+ fc_int8_exe.arg_dict[qarg_names[5]][:] = -weight_range
+ fc_int8_exe.arg_dict[qarg_names[6]][:] = weight_range
+ fc_int8_exe.arg_dict[qarg_names[7]][:] = -bias_range
+ fc_int8_exe.arg_dict[qarg_names[8]][:] = bias_range
qoutput, min_range, max_range = fc_int8_exe.forward()
if no_bias: