Posted to commits@mxnet.apache.org by ak...@apache.org on 2021/09/29 15:37:18 UTC
[incubator-mxnet] branch v1.x updated: [FEATURE] Add fusing
FullyConnected with element-wise add (#20599)
This is an automated email from the ASF dual-hosted git repository.
akarbown pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.x by this push:
new a3ed54b [FEATURE] Add fusing FullyConnected with element-wise add (#20599)
a3ed54b is described below
commit a3ed54be2680ec8c65ce54433e550b8115c8525c
Author: Andrzej Kotłowski <An...@intel.com>
AuthorDate: Wed Sep 29 17:34:30 2021 +0200
[FEATURE] Add fusing FullyConnected with element-wise add (#20599)
* Add fusing FullyConnected with element-wise add
It adds an implementation for the MKLDNN_QUANTIZE backend with
quantize_mode='full' and for the MKLDNN backend (floating-point fuse)
* Disable fusing FC+add during the first quantization pass
* Align naming convention with the convolution operator
Convolution uses the data_name_[min|max] convention, which is object
oriented and more readable.
* Enable tests for full quantize mode and for the MKLDNN backend
* Add comment in CreateSubgraphNode
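For context, a minimal sketch of the pattern this change fuses (hypothetical
shapes and names; the partitioning call assumes the standard MXNet 1.x
Symbol.get_backend_symbol API):

    import mxnet as mx

    data = mx.symbol.Variable('data')
    weight = mx.symbol.Variable('weight')
    fc = mx.symbol.FullyConnected(data=data, weight=weight, no_bias=True,
                                  num_hidden=10)
    other = mx.symbol.Variable('data_2')
    net = mx.symbol.elemwise_add(fc, other)

    # Partitioning for the MKLDNN backend should now replace elemwise_add with
    # the MKLDNN FC "sum" post-op (floating-point fuse); the quantized path is
    # covered by MKLDNN_QUANTIZE with quantize_mode='full'.
    fused = net.get_backend_symbol('MKLDNN')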
---
.../nn/mkldnn/mkldnn_fully_connected-inl.h | 38 ++--
src/operator/subgraph/build_subgraph.cc | 13 ++
src/operator/subgraph/mkldnn/mkldnn_fc.cc | 194 +++++++++++----------
src/operator/subgraph/mkldnn/mkldnn_fc_property.h | 3 +
src/operator/subgraph/mkldnn/mkldnn_fc_sum_fuse.h | 80 ++++++---
.../subgraph/mkldnn/mkldnn_subgraph_property.cc | 9 +-
tests/python/mkl/test_subgraph.py | 21 ++-
7 files changed, 213 insertions(+), 145 deletions(-)
diff --git a/src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h b/src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h
index d2ccdef..28195af 100644
--- a/src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h
@@ -45,6 +45,7 @@ struct MKLDNNFCParam : public dmlc::Parameter<MKLDNNFCParam> {
bool enable_float_output;
bool with_eltwise;
bool with_sum;
+ bool for_quantization; // True for an operator created during the first quantization pass
float sum_scale = 1.0f;
dmlc::optional<float> min_calib_range; // min float value calculated from calibration dataset
dmlc::optional<float> max_calib_range; // max float value calculated from calibration dataset
@@ -63,6 +64,9 @@ struct MKLDNNFCParam : public dmlc::Parameter<MKLDNNFCParam> {
"Whether there's a post with_eltwise after FullyConnected "
"operator");
DMLC_DECLARE_FIELD(with_sum).set_default(false).describe("Add post sum");
+ DMLC_DECLARE_FIELD(for_quantization)
+ .set_default(false)
+ .describe("True for first quantization pass");
DMLC_DECLARE_FIELD(min_calib_range)
.set_default(dmlc::optional<float>())
.describe(
@@ -108,13 +112,12 @@ class FCInputIndex {
mkldnn_param.channel_wise_quantize.value();
// Calculate position of particular input in the input vector:
- int index = 0;
- data = index++;
- weight = index++;
- bias = has_bias ? index++ : 0;
- num_quantized = index + (sum_input_quantized ? 1 : 0);
- sum = mkldnn_param.with_sum ? index++ : 0;
- num_base = index;
+ int index = 0;
+ data = index++;
+ weight = index++;
+ bias = has_bias ? index++ : 0;
+ sum = mkldnn_param.with_sum ? index++ : 0;
+ num_base = index; // store the number of base inputs
data_min = quantized ? index++ : 0;
data_max = quantized ? index++ : 0;
@@ -124,10 +127,20 @@ class FCInputIndex {
bias_max = (quantized && !channel_wise && has_bias) ? index++ : 0;
sum_min = sum_input_quantized ? index++ : 0;
sum_max = sum_input_quantized ? index++ : 0;
- num_total = index;
+ num_total = index; // store the total number of inputs
+ }
+
+ // Returns true if sum input exists
+ bool IsSumExist() const {
+ return sum;
}
- // true if sum input is used and it is float number
+ // Returns true if bias input exists
+ bool IsBiasExist() const {
+ return bias;
+ }
+
+ // Returns true if the sum input exists and is a float number
bool IsSumInputFloat() const {
return (sum && !sum_min);
}
@@ -138,12 +151,6 @@ class FCInputIndex {
return num_base;
}
- // return number of standard inputs which are quantized (represented as
- // integer)
- int GetQuantized() const {
- return num_quantized;
- }
-
// Represent index of particular input in the input vector:
int data;
int weight;
@@ -162,7 +169,6 @@ class FCInputIndex {
int num_base; // Number of standard inputs
int num_total; // Number of total inputs: standard + additional needed for
// quantization
- int num_quantized; // Number of standard inputs which are quantized
};
mkldnn::inner_product_forward::primitive_desc GetFCFwdImpl(const MKLDNNFCFullParam& full_param,
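For illustration, the input layout FCInputIndex computes for one concrete
configuration, quantized=true, channel_wise=false, with bias and sum (a
hypothetical Python mirror, assuming weight_min/weight_max are assigned
between data_max and bias_min as the surrounding code suggests):

    # Positions in the input vector for a quantized per-tensor FC with bias and sum.
    names = ['data', 'weight', 'bias', 'sum',                  # base inputs
             'data_min', 'data_max', 'weight_min', 'weight_max',
             'bias_min', 'bias_max', 'sum_min', 'sum_max']     # quantization ranges
    layout = {name: i for i, name in enumerate(names)}
    assert layout['sum'] == 3   # base inputs end here, so num_base == 4
    assert len(names) == 12     # num_total == 12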
diff --git a/src/operator/subgraph/build_subgraph.cc b/src/operator/subgraph/build_subgraph.cc
index ba6ba03..e736680 100644
--- a/src/operator/subgraph/build_subgraph.cc
+++ b/src/operator/subgraph/build_subgraph.cc
@@ -663,6 +663,8 @@ void CreateSubgraphNode(nnvm::Graph* g,
subg_prop->ConnectSubgraphInputs(n, &input_entries, &orig_input_entries);
const auto& indexed_graph = g->indexed_graph();
+
+ // Clear previous outputs
for (size_t i = 0; i < n->inputs.size(); ++i) {
auto& e = n->inputs[i];
// update entry_top_order_map with newly created orig_input_entries
@@ -677,6 +679,17 @@ void CreateSubgraphNode(nnvm::Graph* g,
for (BiDirectedNode* dest_node : subgraph_nodes) {
sn->outputs.erase(dest_node->node);
}
+ }
+ }
+
+ // Set outputs according to current inputs
+ for (size_t i = 0; i < n->inputs.size(); ++i) {
+ auto& e = n->inputs[i];
+ // update input entries' source simple nodes' outputs map
+ nnvm::Node* node = e.node.get();
+ if (indexed_graph.exist(node)) {
+ const auto nid = indexed_graph.node_id(node);
+ BiDirectedNode* sn = simple_nodes[nid].get();
sn->outputs[n.get()].push_back(i);
}
}
diff --git a/src/operator/subgraph/mkldnn/mkldnn_fc.cc b/src/operator/subgraph/mkldnn/mkldnn_fc.cc
index 123d491..9c481d6 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_fc.cc
+++ b/src/operator/subgraph/mkldnn/mkldnn_fc.cc
@@ -77,18 +77,18 @@ class SgMKLDNNFCOp {
std::shared_ptr<mkldnn::memory> cached_out_mem_;
NDArray cached_weight_;
NDArray cached_bias_;
- float cached_min_data_;
- float cached_max_data_;
- float cached_min_weight_;
- float cached_max_weight_;
+ float cached_data_min_;
+ float cached_data_max_;
+ float cached_weight_min_;
+ float cached_weight_max_;
float cached_sum_min_;
float cached_sum_max_;
- float cached_min_bias_;
- float cached_max_bias_;
+ float cached_bias_min_;
+ float cached_bias_max_;
size_t weight_ver_;
size_t bias_ver_;
- float cached_min_output_;
- float cached_max_output_;
+ float cached_output_min_;
+ float cached_output_max_;
float data_scale_{0.0f};
std::vector<float> weight_scales_;
};
@@ -115,12 +115,12 @@ void SgMKLDNNFCOp::Forward(const OpContext& ctx,
const int out_max_index = out_quantized ? index++ : 0;
CHECK_EQ(out_data.size(), index); // index is equal to total number of outputs
- float min_data = 0.0f;
- float max_data = 0.0f;
- float min_weight = 0.0f;
- float max_weight = 0.0f;
- float min_bias = 0.0f;
- float max_bias = 0.0f;
+ float data_min = 0.0f;
+ float data_max = 0.0f;
+ float weight_min = 0.0f;
+ float weight_max = 0.0f;
+ float bias_min = 0.0f;
+ float bias_max = 0.0f;
const float sum_min = idx.sum_min ? in_data[idx.sum_min].data().dptr<float>()[0] : 0.0;
const float sum_max = idx.sum_max ? in_data[idx.sum_max].data().dptr<float>()[0] : 0.0;
@@ -171,31 +171,31 @@ void SgMKLDNNFCOp::Forward(const OpContext& ctx,
if (mkldnn_param.quantized) {
if (!channel_wise) {
- min_weight = in_data[idx.weight_min].data().dptr<float>()[0];
- max_weight = in_data[idx.weight_max].data().dptr<float>()[0];
+ weight_min = in_data[idx.weight_min].data().dptr<float>()[0];
+ weight_max = in_data[idx.weight_max].data().dptr<float>()[0];
if (has_bias) {
- min_bias = in_data[idx.bias_min].data().dptr<float>()[0];
- max_bias = in_data[idx.bias_max].data().dptr<float>()[0];
+ bias_min = in_data[idx.bias_min].data().dptr<float>()[0];
+ bias_max = in_data[idx.bias_max].data().dptr<float>()[0];
}
}
- min_data = in_data[idx.data_min].data().dptr<float>()[0];
- max_data = in_data[idx.data_max].data().dptr<float>()[0];
+ data_min = in_data[idx.data_min].data().dptr<float>()[0];
+ data_max = in_data[idx.data_max].data().dptr<float>()[0];
}
if (initialized_ && mkldnn_param.quantized &&
dmlc::GetEnv("MXNET_MKLDNN_QFC_DYNAMIC_PARAMS", 0)) {
if (channel_wise) {
- if (cached_min_data_ != min_data || cached_max_data_ != max_data ||
+ if (cached_data_min_ != data_min || cached_data_max_ != data_max ||
cached_sum_min_ != sum_min || cached_sum_max_ != sum_max ||
weight_ver_ != weight.version() ||
(has_bias && (bias_ver_ != in_data[idx.bias].version()))) {
initialized_ = false;
}
} else {
- if (cached_min_data_ != min_data || cached_max_data_ != max_data ||
+ if (cached_data_min_ != data_min || cached_data_max_ != data_max ||
cached_sum_min_ != sum_min || cached_sum_max_ != sum_max ||
- cached_min_weight_ != min_weight || cached_max_weight_ != max_weight ||
- (has_bias && (cached_min_bias_ != min_bias || cached_max_bias_ != max_bias))) {
+ cached_weight_min_ != weight_min || cached_weight_max_ != weight_max ||
+ (has_bias && (cached_bias_min_ != bias_min || cached_bias_max_ != bias_max))) {
initialized_ = false;
}
}
@@ -204,17 +204,17 @@ void SgMKLDNNFCOp::Forward(const OpContext& ctx,
if (!initialized_) {
const auto nthreads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
const auto engine = CpuEngine::Get()->get_engine();
- cached_min_data_ = min_data;
- cached_max_data_ = max_data;
- cached_min_weight_ = min_weight;
- cached_max_weight_ = max_weight;
+ cached_data_min_ = data_min;
+ cached_data_max_ = data_max;
+ cached_weight_min_ = weight_min;
+ cached_weight_max_ = weight_max;
weight_ver_ = weight.version();
cached_weight_ = weight;
cached_sum_min_ = sum_min;
cached_sum_max_ = sum_max;
if (has_bias) {
- cached_min_bias_ = min_bias;
- cached_max_bias_ = max_bias;
+ cached_bias_min_ = bias_min;
+ cached_bias_max_ = bias_max;
bias_ver_ = in_data[idx.bias].version();
cached_bias_ = in_data[idx.bias];
} else {
@@ -259,13 +259,13 @@ void SgMKLDNNFCOp::Forward(const OpContext& ctx,
bool support_channelwise_scale = false;
if (mkldnn_param.quantized) {
CHECK(data.dtype() == mshadow::kInt8 || data.dtype() == mshadow::kUint8);
- data_scale_ = GetQuantizeScale(data.dtype(), cached_min_data_, cached_max_data_);
+ data_scale_ = GetQuantizeScale(data.dtype(), cached_data_min_, cached_data_max_);
bool fuse_requantize = false;
// Channelwise scaling is only supported when fusion is enabled (requantize or dequantize).
if (mkldnn_param.min_calib_range.has_value() && mkldnn_param.max_calib_range.has_value()) {
- cached_min_output_ = mkldnn_param.min_calib_range.value();
- cached_max_output_ = mkldnn_param.max_calib_range.value();
+ cached_output_min_ = mkldnn_param.min_calib_range.value();
+ cached_output_max_ = mkldnn_param.max_calib_range.value();
support_channelwise_scale = true;
fuse_requantize = true;
}
@@ -297,16 +297,16 @@ void SgMKLDNNFCOp::Forward(const OpContext& ctx,
} else {
weight_scales_.resize(1);
weight_scales_[0] =
- GetQuantizeScale(cached_weight_.dtype(), cached_min_weight_, cached_max_weight_);
+ GetQuantizeScale(cached_weight_.dtype(), cached_weight_min_, cached_weight_max_);
if (has_bias) {
if (cached_bias_.dtype() == mshadow::kInt8) {
- float bias_scale = GetQuantizeScale(mshadow::kInt8, cached_min_bias_, cached_max_bias_);
+ float bias_scale = GetQuantizeScale(mshadow::kInt8, cached_bias_min_, cached_bias_max_);
float bias_int32_rescale = data_scale_ * weight_scales_[0] / bias_scale;
// TODO(zhennan): mkldnn has bug to handle INT_MAX in bias, so set
// the maximum value of bias to INT_MAX / 2.
float bias_max_rescale =
- MaxValue<int32_t>() / 2 / MaxAbs(cached_min_bias_, cached_max_bias_) / bias_scale;
+ MaxValue<int32_t>() / 2 / MaxAbs(cached_bias_min_, cached_bias_max_) / bias_scale;
if (bias_int32_rescale > bias_max_rescale) {
// avoid overflow on bias
bias_int32_rescale = bias_max_rescale;
@@ -343,9 +343,9 @@ void SgMKLDNNFCOp::Forward(const OpContext& ctx,
if (mkldnn_param.with_eltwise) {
tmp_scale_ = 1.0 / data_scale_;
full_param_.eltwise_param.scale =
- GetQuantizeScale(output.dtype(), cached_min_output_, cached_max_output_);
+ GetQuantizeScale(output.dtype(), cached_output_min_, cached_output_max_);
} else {
- out_scale = GetQuantizeScale(output.dtype(), cached_min_output_, cached_max_output_);
+ out_scale = GetQuantizeScale(output.dtype(), cached_output_min_, cached_output_max_);
tmp_scale_ = out_scale / data_scale_;
}
} else {
@@ -368,22 +368,22 @@ void SgMKLDNNFCOp::Forward(const OpContext& ctx,
mxnet_op::Kernel<QuantizationRangeForS8S8MultiplicationStruct, cpu>::Launch(
s,
1,
- &cached_min_output_,
- &cached_max_output_,
- &min_data,
- &max_data,
- &min_weight,
- &max_weight);
+ &cached_output_min_,
+ &cached_output_max_,
+ &data_min,
+ &data_max,
+ &weight_min,
+ &weight_max);
} else {
mxnet_op::Kernel<QuantizationRangeForS8U8MultiplicationStruct, cpu>::Launch(
s,
1,
- &cached_min_output_,
- &cached_max_output_,
- &min_data,
- &max_data,
- &min_weight,
- &max_weight);
+ &cached_output_min_,
+ &cached_output_max_,
+ &data_min,
+ &data_max,
+ &weight_min,
+ &weight_max);
}
full_param_.output_scales.resize(0);
out_scale = data_scale_ * weight_scales_[0];
@@ -470,15 +470,15 @@ void SgMKLDNNFCOp::Forward(const OpContext& ctx,
MKLDNNStream::Get()->Submit();
if (mkldnn_param.quantized && !mkldnn_param.enable_float_output) {
- float *min_output_ptr = out_data[out_min_index].data().dptr<float>();
- float *max_output_ptr = out_data[out_max_index].data().dptr<float>();
+ float* output_min_ptr = out_data[out_min_index].data().dptr<float>();
+ float* output_max_ptr = out_data[out_max_index].data().dptr<float>();
if (mkldnn_param.shifted_output.has_value() && mkldnn_param.shifted_output.value()) {
- *min_output_ptr = 0;
- *max_output_ptr = cached_max_output_ - cached_min_output_;
+ *output_min_ptr = 0;
+ *output_max_ptr = cached_output_max_ - cached_output_min_;
} else {
- *min_output_ptr = cached_min_output_;
- *max_output_ptr = cached_max_output_;
+ *output_min_ptr = cached_output_min_;
+ *output_max_ptr = cached_output_max_;
}
}
}
@@ -534,27 +534,25 @@ static void SgMKLDNNFCParamParser(nnvm::NodeAttrs* attrs) {
static std::vector<std::string> SgMKLDNNFCListInputNames(const NodeAttrs& attrs) {
auto const& full_param = nnvm::get<MKLDNNFCFullParam>(attrs.parsed);
+ auto const& mkldnn_param = full_param.mkldnn_param;
std::vector<std::string> input_names = DefaultSubgraphOpListInputs(attrs);
- if (full_param.mkldnn_param.quantized) {
- bool channel_wise = false;
- if (full_param.mkldnn_param.with_sum) {
- input_names.emplace_back("sum");
- }
-
- if (full_param.mkldnn_param.channel_wise_quantize.has_value() &&
- full_param.mkldnn_param.channel_wise_quantize) {
- channel_wise = true;
- }
- input_names.emplace_back("min_data");
- input_names.emplace_back("max_data");
+ if (mkldnn_param.quantized) {
+ const bool channel_wise =
+ mkldnn_param.channel_wise_quantize.has_value() && mkldnn_param.channel_wise_quantize;
+ input_names.emplace_back("data_min");
+ input_names.emplace_back("data_max");
if (!channel_wise) {
- input_names.emplace_back("min_weight");
- input_names.emplace_back("max_weight");
+ input_names.emplace_back("weight_min");
+ input_names.emplace_back("weight_max");
if (!full_param.default_param.no_bias) {
- input_names.emplace_back("min_bias");
- input_names.emplace_back("max_bias");
+ input_names.emplace_back("bias_min");
+ input_names.emplace_back("bias_max");
}
}
+ if (mkldnn_param.with_sum && !mkldnn_param.enable_float_output) {
+ input_names.emplace_back("sum_min");
+ input_names.emplace_back("sum_max");
+ }
}
return input_names;
}
@@ -565,7 +563,7 @@ static std::vector<std::string> SgMKLDNNFCListOutputNames(const NodeAttrs& attrs
if (full_param.mkldnn_param.enable_float_output)
return std::vector<std::string>{"output"};
else
- return std::vector<std::string>{"output", "min_output", "max_output"};
+ return std::vector<std::string>{"output", "output_min", "output_max"};
} else {
return std::vector<std::string>{"output"};
}
@@ -618,35 +616,43 @@ static bool SgMKLDNNFCInferType(const nnvm::NodeAttrs& attrs,
std::vector<int>* out_types) {
auto const& full_param = nnvm::get<MKLDNNFCFullParam>(attrs.parsed);
if (full_param.mkldnn_param.quantized) {
- bool channel_wise = false;
- if (full_param.mkldnn_param.channel_wise_quantize.has_value() &&
- full_param.mkldnn_param.channel_wise_quantize) {
- channel_wise = true;
- }
- size_t num_integer_inputs = FCInputIndex(full_param).GetQuantized();
- CHECK(in_types->at(0) == mshadow::kInt8 || in_types->at(0) == mshadow::kUint8)
- << "QuantizedFullyConnected only supports int8/uint8 input, while " << in_types->at(0)
- << " is given.";
+ const bool channel_wise = full_param.mkldnn_param.channel_wise_quantize.has_value() &&
+ full_param.mkldnn_param.channel_wise_quantize;
+ const FCInputIndex idx(full_param);
+
+ CHECK(in_types->at(idx.data) == mshadow::kInt8 || in_types->at(idx.data) == mshadow::kUint8)
+ << "QuantizedFullyConnected data input only supports int8/uint8, while "
+ << in_types->at(idx.data) << " is given.";
if (channel_wise) {
- for (size_t i = 1; i < in_types->size(); ++i) {
- TYPE_ASSIGN_CHECK(*in_types, i, mshadow::kFloat32);
+ TYPE_ASSIGN_CHECK(*in_types, idx.weight, mshadow::kFloat32);
+ if (idx.IsBiasExist()) {
+ TYPE_ASSIGN_CHECK(*in_types, idx.bias, mshadow::kFloat32);
}
} else {
- TYPE_ASSIGN_CHECK(*in_types, 1, mshadow::kInt8);
- if (!full_param.default_param.no_bias) {
- if (in_types->at(2) == -1) {
- TYPE_ASSIGN_CHECK(*in_types, 2, mshadow::kInt32);
+ TYPE_ASSIGN_CHECK(*in_types, idx.weight, mshadow::kInt8);
+ if (idx.IsBiasExist()) {
+ if (in_types->at(idx.bias) == -1) {
+ TYPE_ASSIGN_CHECK(*in_types, idx.bias, mshadow::kInt32);
} else {
- CHECK(in_types->at(2) == mshadow::kInt8 || in_types->at(2) == mshadow::kInt32)
- << "QuantizedFullyConnected only supports int8/int32 bias, "
- "while "
- << in_types->at(2) << " is given.";
+ CHECK(in_types->at(idx.bias) == mshadow::kInt8 ||
+ in_types->at(idx.bias) == mshadow::kInt32)
+ << "QuantizedFullyConnected bias input only supports int8/int32, while "
+ << in_types->at(idx.bias) << " is given.";
}
}
- for (size_t i = num_integer_inputs; i < in_types->size(); ++i) {
- TYPE_ASSIGN_CHECK(*in_types, i, mshadow::kFloat32);
+ }
+ if (idx.IsSumExist()) {
+ if (full_param.mkldnn_param.enable_float_output) {
+ TYPE_ASSIGN_CHECK(*in_types, idx.sum, mshadow::kFloat32);
+ } else {
+ CHECK(in_types->at(idx.sum) == mshadow::kInt8 || in_types->at(idx.sum) == mshadow::kUint8)
+ << "QuantizedFullyConnected sum input only supports int8/uint8, while "
+ << in_types->at(idx.sum) << " is given.";
}
}
+ for (size_t i = idx.data_min; i < in_types->size(); ++i) {
+ TYPE_ASSIGN_CHECK(*in_types, i, mshadow::kFloat32);
+ }
if (full_param.mkldnn_param.enable_float_output) {
TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kFloat32);
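As a rough illustration of the bias rescaling performed in Forward above, a
sketch assuming GetQuantizeScale follows the usual convention
scale = dtype_range / max(|min|, |max|), with a range of 127 for int8:

    INT32_MAX = 2**31 - 1

    def quantize_scale(vmin, vmax, dtype_range=127):  # int8 convention (assumed)
        return dtype_range / max(abs(vmin), abs(vmax))

    data_scale = quantize_scale(-1.0, 1.0)      # example calibration ranges
    weight_scale = quantize_scale(-0.5, 0.5)
    bias_scale = quantize_scale(-2.0, 2.0)

    bias_int32_rescale = data_scale * weight_scale / bias_scale
    # Cap the rescale so the rescaled bias stays below INT32_MAX / 2, working
    # around the MKLDNN bias-overflow issue mentioned in the TODO above.
    bias_max_rescale = INT32_MAX / 2 / max(abs(-2.0), abs(2.0)) / bias_scale
    bias_int32_rescale = min(bias_int32_rescale, bias_max_rescale)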
diff --git a/src/operator/subgraph/mkldnn/mkldnn_fc_property.h b/src/operator/subgraph/mkldnn/mkldnn_fc_property.h
index 8f50ff5..68cdbb0 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_fc_property.h
+++ b/src/operator/subgraph/mkldnn/mkldnn_fc_property.h
@@ -193,6 +193,9 @@ class SgMKLDNNFCProperty : public SubgraphProperty {
auto& sub_name = node->op()->name;
if (sub_name == "FullyConnected") {
node_name << "fully_connected_";
+ if (HasAttr("quantize") && GetAttr<bool>("quantize")) {
+ n->attrs.dict["for_quantization"] = "True";
+ }
} else if (SupportMKLDNNFCEltwiseFusion(sub_name)) {
node_name << "eltwise_";
n->attrs.dict["with_eltwise"] = "True";
diff --git a/src/operator/subgraph/mkldnn/mkldnn_fc_sum_fuse.h b/src/operator/subgraph/mkldnn/mkldnn_fc_sum_fuse.h
index 9d231e0..8e1b2a4 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_fc_sum_fuse.h
+++ b/src/operator/subgraph/mkldnn/mkldnn_fc_sum_fuse.h
@@ -18,8 +18,12 @@
*/
/*
- \brief It fuses FC + SUM for floating point output in second post quantization
- pass
+ \file
+ \brief For fusing the FullyConnected operator with element-wise add.
+
+ The element-wise add operator is replaced by the MKLDNN FC "sum" post-operator,
+ which adds the FC result to the existing values in the output. For the quantized
+ integer version this output is scaled to the proper range.
*/
#ifndef MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_FC_SUM_FUSE_H_
@@ -40,6 +44,14 @@
namespace mxnet {
namespace op {
+inline bool EndsWith(std::string const& value, std::string const& ending) {
+ if (ending.size() > value.size()) {
+ return false;
+ } else {
+ return std::equal(ending.rbegin(), ending.rend(), value.rbegin());
+ }
+}
+
class SgMKLDNNFCSumFuseSelector : public SubgraphSelector {
private:
/*! \brief pattern match status */
@@ -49,7 +61,6 @@ class SgMKLDNNFCSumFuseSelector : public SubgraphSelector {
kSuccess,
};
- private:
bool quantized_;
SelectStatus status_;
std::vector<const nnvm::Node*> matched_list_;
@@ -60,9 +71,10 @@ class SgMKLDNNFCSumFuseSelector : public SubgraphSelector {
bool Select(const nnvm::Node& n, const std::shared_ptr<NodeAttr>& node_attr) override {
if (n.op() == Op::Get("_sg_mkldnn_fully_connected") && SupportMKLDNNAttr(node_attr)) {
auto const& fc_param = nnvm::get<MKLDNNFCFullParam>(n.attrs.parsed);
- // TODO(anko) remove fc_param.mkldnn_param.quantized from if below
- // to fuse even for not quantized?
- if (fc_param.mkldnn_param.enable_float_output && fc_param.mkldnn_param.quantized) {
+ if ((!quantized_ && !fc_param.mkldnn_param.for_quantization) ||
+ fc_param.mkldnn_param.quantized) {
+ // Start a subgraph when fusing for floats (quantized_ is false for the MKLDNN
+ // backend) or when FC is already quantized (second pass for MKLDNN_QUANTIZE).
status_ = kStart;
matched_list_.clear();
matched_list_.push_back(&n);
@@ -94,7 +106,17 @@ class SgMKLDNNFCSumFuseSelector : public SubgraphSelector {
switch (status_) {
case kStart:
- if (new_node.op()->name == "elemwise_add") {
+ // Find _contrib_quantized_elemwise_add or elemwise_add
+ if (EndsWith(new_node.op()->name, "elemwise_add")) {
+ if (quantized_) {
+ auto const& fc_param = nnvm::get<MKLDNNFCFullParam>(n.attrs.parsed);
+ if (!fc_param.mkldnn_param.enable_float_output) {
+ // For a quantized graph, when FC floating-point output is not enabled,
+ // elementwise add must also be quantized (its min and max values must
+ // already be stored in elementwise add).
+ CHECK_EQ(new_node.attrs.dict.count("min_calib_range"), 1);
+ }
+ }
matched_list_.push_back(&new_node);
status_ = kSuccess;
return true;
@@ -133,7 +155,7 @@ class SgMKLDNNFCSumFuseProperty : public SubgraphProperty {
SgMKLDNNFCSumFuseProperty() {}
static SubgraphPropertyPtr Create() {
- static const std::string& name = "MKLDNN FullyConnected post quantization second pass";
+ static const std::string& name = "MKLDNN fuse FullyConnected with sum";
auto property = std::make_shared<SgMKLDNNFCSumFuseProperty>();
property->SetAttr<std::string>("property_name", name);
property->SetAttr<bool>("inference_only", true);
@@ -149,12 +171,13 @@ class SgMKLDNNFCSumFuseProperty : public SubgraphProperty {
nnvm::ObjectPtr ew_add_node = nullptr;
DFSVisit(sym.outputs, [&](const nnvm::ObjectPtr& node) {
- if (node->is_variable())
+ if (node->is_variable()) {
return;
+ }
auto& sub_name = node->op()->name;
if (sub_name == "_sg_mkldnn_fully_connected") {
fc_node = node;
- } else if (sub_name == "elemwise_add") {
+ } else if (EndsWith(sub_name, "elemwise_add")) {
ew_add_node = node;
}
});
@@ -163,13 +186,23 @@ class SgMKLDNNFCSumFuseProperty : public SubgraphProperty {
if (ew_add_node != nullptr) {
CHECK_NOTNULL(fc_node->attrs.subgraphs[0]);
auto fc_orginal = fc_node->attrs.subgraphs[0]->outputs[0].node;
+
if (fc_orginal->op() == Op::Get("FullyConnected")) {
nnvm::Symbol new_sym;
- nnvm::NodeEntry& ew_input_with_fc = (ew_add_node->inputs[1].node == fc_node)
- ? ew_add_node->inputs[1]
- : ew_add_node->inputs[0];
- ew_input_with_fc.node = fc_orginal;
- new_sym.outputs.emplace_back(ew_add_node);
+ // Create a new elemwise_add node so as not to alter the original one.
+ // It is needed in the subgraph to properly calculate InferShape.
+ nnvm::ObjectPtr n = nnvm::Node::Create();
+ n->attrs.op = Op::Get("elemwise_add");
+ n->attrs.name = ew_add_node->attrs.name;
+
+ if (ew_add_node->inputs[0].node == fc_node) {
+ n->inputs.emplace_back(fc_orginal);
+ n->inputs.emplace_back(ew_add_node->inputs[1]);
+ } else {
+ n->inputs.emplace_back(ew_add_node->inputs[0]);
+ n->inputs.emplace_back(fc_orginal);
+ }
+ new_sym.outputs.emplace_back(n);
fc_node->attrs.subgraphs.clear();
fc_node->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>(new_sym));
fc_node->attrs.dict["with_sum"] = "True";
@@ -201,10 +234,11 @@ class SgMKLDNNFCSumFuseProperty : public SubgraphProperty {
auto const& fc_param = nnvm::get<MKLDNNFCFullParam>(n->attrs.parsed);
std::unordered_set<const nnvm::Node*> node_sets;
DFSVisit(sym->outputs, [&](const nnvm::ObjectPtr& node) {
- if (node->is_variable())
+ if (node->is_variable()) {
return;
+ }
node_sets.insert(node.get());
- if (node->op()->name == "elemwise_add") {
+ if (EndsWith(node->op()->name, "elemwise_add")) {
const size_t base_inputs = fc_param.default_param.no_bias ? 3 : 4;
// Make sure n is the left operand of sum, if not,
@@ -219,12 +253,16 @@ class SgMKLDNNFCSumFuseProperty : public SubgraphProperty {
orig_input_entries->begin() + 1,
orig_input_entries->begin() + base_inputs);
} else {
+ const int not_rotated_end =
+ (fc_param.mkldnn_param.quantized && !fc_param.mkldnn_param.enable_float_output) ? 2
+ : 0;
+
std::rotate(input_entries->begin() + base_inputs - 1,
- input_entries->end() - 1,
- input_entries->end());
+ input_entries->end() - 1 - not_rotated_end,
+ input_entries->end() - not_rotated_end);
std::rotate(orig_input_entries->begin() + base_inputs - 1,
- orig_input_entries->end() - 1,
- orig_input_entries->end());
+ orig_input_entries->end() - 1 - not_rotated_end,
+ orig_input_entries->end() - not_rotated_end);
}
}
});
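The std::rotate calls above move the sum entry, which fusion left near the end
of the input list, to the position right after the base FC inputs, while for
the quantized non-float-output case the trailing sum_min/sum_max stay in
place. A Python sketch of the same reordering (illustrative names only):

    def rotate_sum(inputs, base_inputs, not_rotated_end):
        # Mirror of std::rotate(begin + base_inputs - 1,
        #                       end - 1 - not_rotated_end,
        #                       end - not_rotated_end).
        head = inputs[:base_inputs - 1]
        mid = inputs[base_inputs - 1:len(inputs) - 1 - not_rotated_end]
        sum_entry = inputs[len(inputs) - 1 - not_rotated_end]
        tail = inputs[len(inputs) - not_rotated_end:] if not_rotated_end else []
        return head + [sum_entry] + mid + tail

    # Quantized FC with bias (base_inputs=4), int8 output (not_rotated_end=2):
    inputs = ['data', 'weight', 'bias', 'data_min', 'data_max', 'weight_min',
              'weight_max', 'bias_min', 'bias_max', 'sum', 'sum_min', 'sum_max']
    assert rotate_sum(inputs, 4, 2)[3] == 'sum'  # sum lands after the base inputs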
diff --git a/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc b/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc
index cdac3e0..2701e0d 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc
+++ b/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc
@@ -37,26 +37,19 @@ MXNET_REGISTER_SUBGRAPH_BACKEND(MKLDNN)
.set_attr("context", Context::CPU());
MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNConvProperty);
-
MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNFCProperty);
-
MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNTransformerProperty);
+MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNFCSumFuseProperty);
MXNET_REGISTER_SUBGRAPH_BACKEND(MKLDNN_QUANTIZE).set_attr("context", Context::CPU());
MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNConvProperty).set_attr("quantize", true);
-
MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNFCProperty).set_attr("quantize", true);
-
MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNTransformerProperty);
-
MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNTransformerPostQuantizeProperty);
-
MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNPostQuantizeProperty);
-
MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNFCPostQuantizeProperty);
MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, ElemwiseMulPostQuantizeProperty);
-
MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNPostQuantizeAlignScaleProperty);
MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNFCSumFuseProperty)
.set_attr("quantize", true);
diff --git a/tests/python/mkl/test_subgraph.py b/tests/python/mkl/test_subgraph.py
index 1210c05..a6af10f 100644
--- a/tests/python/mkl/test_subgraph.py
+++ b/tests/python/mkl/test_subgraph.py
@@ -399,9 +399,14 @@ def conv_add2(no_bias, data_shape):
return sum, attr
-# fc + sum
-def fc_sum(no_bias, data_shape):
- attr = {'fc': {'with_sum': 'true', 'quantized' : 'true', 'enable_float_output': 'true'}}
+# FullyConnected + element-wise add
+def fc_sum(no_bias, data_shape, quantize_mode=None):
+ attr = {'fc': {'with_sum': 'true'}}
+ if quantize_mode is not None:
+ attr['fc']['quantized'] = 'true'
+ if quantize_mode == 'smart':
+ attr['fc']['enable_float_output'] = 'true'
+
data, weight = head_symbol(data_shape)
sym1 = mx.symbol.FullyConnected(data=data, weight=weight, no_bias=no_bias, num_hidden=10)
data2 = mx.symbol.var('data_2', shape= (data_shape[0], 10), dtype="float32", init = mx.init.Normal(0.3))
@@ -898,9 +903,13 @@ def test_pos_fc_sum():
check_qsym_dummy_forward_fc_sum(qsym, batch, data_shapes)
for data_shape in DATA_SHAPE:
- for out_type in ('auto',):
- net, attrs, inputs = fc_sum(False, data_shape)
- check_quantize_fc_sum(net, inputs, out_type, attrs)
+ net, attrs, inputs = fc_sum(False, data_shape)
+ check_fusion(net, data_shape, attrs, check_quantization=False)
+ for quantize_mode in ('smart', 'full'):
+ for data_shape in DATA_SHAPE:
+ for out_type in ('auto', 'int8'):
+ net, attrs, inputs = fc_sum(False, data_shape, quantize_mode)
+ check_quantize_fc_sum(net, inputs, out_type, attrs, quantize_mode=quantize_mode)
@with_seed()
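The updated cases can be exercised directly from the test file, e.g.
(assuming the v1.x pytest-based test setup):
python -m pytest tests/python/mkl/test_subgraph.py -k test_pos_fc_sum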