You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by bg...@apache.org on 2022/03/07 15:05:32 UTC
[incubator-mxnet] branch v1.x updated: Improve FC + add fusing (#20915)
This is an automated email from the ASF dual-hosted git repository.
bgawrych pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.x by this push:
new 723807f Improve FC + add fusing (#20915)
723807f is described below
commit 723807f503528a19217b1ddc67f69099db9618b6
Author: Andrzej Kotłowski <An...@intel.com>
AuthorDate: Mon Mar 7 16:02:52 2022 +0100
Improve FC + add fusing (#20915)
- Remove for_quantization flag from final graph,
- Rename for_quantization flag to enable_fuse_add and simplify the logic,
- Replace Node with BiDirectedNode so that FC is not fused with elemwise add when
the FC output is also used as an input of another operator,
- Do not fuse with add if already fused with another elemwise operation
for a quantized model.
---
.../nn/mkldnn/mkldnn_fully_connected-inl.h | 9 +--
src/operator/subgraph/mkldnn/mkldnn_fc_property.h | 3 -
src/operator/subgraph/mkldnn/mkldnn_fc_sum_fuse.h | 73 +++++++++++++---------
3 files changed, 48 insertions(+), 37 deletions(-)
diff --git a/src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h b/src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h
index 28195af..6d6719e 100644
--- a/src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h
@@ -45,7 +45,7 @@ struct MKLDNNFCParam : public dmlc::Parameter<MKLDNNFCParam> {
bool enable_float_output;
bool with_eltwise;
bool with_sum;
- bool for_quantization; // True for operator created during first quantization pass
+ bool enable_fuse_add;
float sum_scale = 1.0f;
dmlc::optional<float> min_calib_range; // min float value calculated from calibration dataset
dmlc::optional<float> max_calib_range; // max float value calculated from calibration dataset
@@ -64,9 +64,10 @@ struct MKLDNNFCParam : public dmlc::Parameter<MKLDNNFCParam> {
"Whether there's a post with_eltwise after FullyConnected "
"operator");
DMLC_DECLARE_FIELD(with_sum).set_default(false).describe("Add post sum");
- DMLC_DECLARE_FIELD(for_quantization)
- .set_default(false)
- .describe("True for first quantization pass");
+ DMLC_DECLARE_FIELD(enable_fuse_add)
+ .set_default(true)
+ .describe(
+ "True if fusing add should happened. Temporary set for false during quantization");
DMLC_DECLARE_FIELD(min_calib_range)
.set_default(dmlc::optional<float>())
.describe(
diff --git a/src/operator/subgraph/mkldnn/mkldnn_fc_property.h b/src/operator/subgraph/mkldnn/mkldnn_fc_property.h
index 68cdbb0..8f50ff5 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_fc_property.h
+++ b/src/operator/subgraph/mkldnn/mkldnn_fc_property.h
@@ -193,9 +193,6 @@ class SgMKLDNNFCProperty : public SubgraphProperty {
auto& sub_name = node->op()->name;
if (sub_name == "FullyConnected") {
node_name << "fully_connected_";
- if (HasAttr("quantize") && GetAttr<bool>("quantize")) {
- n->attrs.dict["for_quantization"] = "True";
- }
} else if (SupportMKLDNNFCEltwiseFusion(sub_name)) {
node_name << "eltwise_";
n->attrs.dict["with_eltwise"] = "True";
diff --git a/src/operator/subgraph/mkldnn/mkldnn_fc_sum_fuse.h b/src/operator/subgraph/mkldnn/mkldnn_fc_sum_fuse.h
index 8e1b2a4..891b58c 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_fc_sum_fuse.h
+++ b/src/operator/subgraph/mkldnn/mkldnn_fc_sum_fuse.h
@@ -52,7 +52,7 @@ inline bool EndsWith(std::string const& value, std::string const& ending) {
}
}
-class SgMKLDNNFCSumFuseSelector : public SubgraphSelector {
+class SgMKLDNNFCSumFuseSelector : public SubgraphSelectorV2 {
private:
/*! \brief pattern match status */
enum SelectStatus {
@@ -63,40 +63,60 @@ class SgMKLDNNFCSumFuseSelector : public SubgraphSelector {
bool quantized_;
SelectStatus status_;
- std::vector<const nnvm::Node*> matched_list_;
+ std::vector<const BiDirectedNode*> matched_list_;
public:
explicit SgMKLDNNFCSumFuseSelector(bool quantized) : quantized_(quantized) {}
- bool Select(const nnvm::Node& n, const std::shared_ptr<NodeAttr>& node_attr) override {
- if (n.op() == Op::Get("_sg_mkldnn_fully_connected") && SupportMKLDNNAttr(node_attr)) {
- auto const& fc_param = nnvm::get<MKLDNNFCFullParam>(n.attrs.parsed);
- if ((!quantized_ && !fc_param.mkldnn_param.for_quantization) ||
- fc_param.mkldnn_param.quantized) {
- // Start subgraph when fusing for floats (quantized_ is false for MKLDNN backend) or
- // when FC is already quantized (second pass for MKLDNN_QUANTIZE).
+ bool Select(const BiDirectedNode& seed_node,
+ const std::shared_ptr<NodeAttr>& node_attr) override {
+ const auto n = seed_node.node;
+ if ((n->op() == Op::Get("_sg_mkldnn_fully_connected")) && SupportMKLDNNAttr(node_attr) &&
+ (seed_node.outputs.size() == 1)) {
+ auto& fc_param = nnvm::get<MKLDNNFCFullParam>(n->attrs.parsed);
+ if (quantized_) {
+ if (fc_param.mkldnn_param.enable_fuse_add) {
+ // Do not fuse during first pass with MKLDNN_QUANTIZE backend (quantized_ = true)
+ // and mark to not fuse during quantization for run with MKLDNN backend during
+ // quantization
+ n->attrs.dict["enable_fuse_add"] = "False";
+ fc_param.mkldnn_param.enable_fuse_add = false;
+ return false;
+ } else {
+ // On second pass MKLDNN_QUANTIZE backend fusing should happen, so
+ // set to true (default value) and remove from dictionary
+ n->attrs.dict.erase("enable_fuse_add");
+ fc_param.mkldnn_param.enable_fuse_add = true;
+ }
+ }
+ // Do not fuse for quantization if already fused with element-wise operation
+ const bool fuse = fc_param.mkldnn_param.enable_fuse_add &&
+ (!quantized_ || !fc_param.mkldnn_param.with_eltwise);
+ if (fuse) {
status_ = kStart;
matched_list_.clear();
- matched_list_.push_back(&n);
+ matched_list_.push_back(&seed_node);
return true;
}
}
return false;
}
- bool SelectInput(const nnvm::Node& n, const nnvm::Node& new_node) override {
+ bool SelectInput(const BiDirectedNode& cur_node, const BiDirectedNode& input_node) override {
return false;
}
- bool SelectOutput(const nnvm::Node& n, const nnvm::Node& new_node) override {
- if (status_ == kFail || status_ == kSuccess || new_node.is_variable()) {
+ bool SelectOutput(const BiDirectedNode& cur_node, const BiDirectedNode& output_node) override {
+ const auto cur_n = cur_node.node;
+ const auto output_n = output_node.node;
+ if (status_ == kFail || status_ == kSuccess || output_n->is_variable()) {
return false;
}
// If n isn't the last matched node, then we encoutered a internal
// branch, we should pop out the node behind n and stop fusion.
- if (matched_list_.back() != &n) {
- if (std::find(matched_list_.begin(), matched_list_.end(), &n) != matched_list_.end()) {
- while (matched_list_.back() != &n) {
+ if (matched_list_.back() != &cur_node) {
+ if (std::find(matched_list_.begin(), matched_list_.end(), &cur_node) != matched_list_.end()) {
+ while (matched_list_.back() != &cur_node) {
matched_list_.pop_back();
}
}
@@ -107,17 +127,17 @@ class SgMKLDNNFCSumFuseSelector : public SubgraphSelector {
switch (status_) {
case kStart:
// Find _contrib_quantized_elemwise_add or elemwise_add
- if (EndsWith(new_node.op()->name, "elemwise_add")) {
+ if (EndsWith(output_n->op()->name, "elemwise_add")) {
if (quantized_) {
- auto const& fc_param = nnvm::get<MKLDNNFCFullParam>(n.attrs.parsed);
+ auto const& fc_param = nnvm::get<MKLDNNFCFullParam>(cur_n->attrs.parsed);
if (!fc_param.mkldnn_param.enable_float_output) {
// For quantized graph, when FC floating point output is not enabled
// elementwise add must also be quantized (min and max value have to be already stored
// in elementwise add).
- CHECK_EQ(new_node.attrs.dict.count("min_calib_range"), 1);
+ CHECK_EQ(output_n->attrs.dict.count("min_calib_range"), 1);
}
}
- matched_list_.push_back(&new_node);
+ matched_list_.push_back(&output_node);
status_ = kSuccess;
return true;
}
@@ -127,17 +147,10 @@ class SgMKLDNNFCSumFuseSelector : public SubgraphSelector {
}
}
- std::vector<nnvm::Node*> Filter(const std::vector<nnvm::Node*>& candidates) override {
+ std::vector<BiDirectedNode*> Filter(const std::vector<BiDirectedNode*>& candidates) override {
if (status_ == kFail) {
- return std::vector<nnvm::Node*>(0);
+ return std::vector<BiDirectedNode*>(0);
} else {
- std::vector<nnvm::Node*> ret;
- for (auto i : matched_list_) {
- auto non_const_i = const_cast<nnvm::Node*>(i);
- if (std::find(candidates.begin(), candidates.end(), non_const_i) != candidates.end()) {
- ret.push_back(non_const_i);
- }
- }
return candidates;
}
}
@@ -212,7 +225,7 @@ class SgMKLDNNFCSumFuseProperty : public SubgraphProperty {
return fc_node;
}
- SubgraphSelectorPtr CreateSubgraphSelector() const override {
+ SubgraphSelectorV2Ptr CreateSubgraphSelectorV2() const override {
bool quantized = HasAttr("quantize") ? GetAttr<bool>("quantize") : false;
auto selector = std::make_shared<SgMKLDNNFCSumFuseSelector>(quantized);
return selector;