You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by bg...@apache.org on 2022/03/07 15:05:32 UTC

[incubator-mxnet] branch v1.x updated: Improve FC + add fusing (#20915)

This is an automated email from the ASF dual-hosted git repository.

bgawrych pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.x by this push:
     new 723807f  Improve FC + add fusing (#20915)
723807f is described below

commit 723807f503528a19217b1ddc67f69099db9618b6
Author: Andrzej Kotłowski <An...@intel.com>
AuthorDate: Mon Mar 7 16:02:52 2022 +0100

    Improve FC + add fusing (#20915)
    
    - Remove for_quantization flag from final graph,
    - Rename for_quantization flag to enable_fuse_add and simplify the logic,
    - Replace Node with BiDirectedNode so that FC is not fused with elemwise add
    when the FC output is also used as an input of another operator,
    - Do not fuse with add if already fused with another elemwise operation
    for a quantized model.
---
 .../nn/mkldnn/mkldnn_fully_connected-inl.h         |  9 +--
 src/operator/subgraph/mkldnn/mkldnn_fc_property.h  |  3 -
 src/operator/subgraph/mkldnn/mkldnn_fc_sum_fuse.h  | 73 +++++++++++++---------
 3 files changed, 48 insertions(+), 37 deletions(-)

diff --git a/src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h b/src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h
index 28195af..6d6719e 100644
--- a/src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h
@@ -45,7 +45,7 @@ struct MKLDNNFCParam : public dmlc::Parameter<MKLDNNFCParam> {
   bool enable_float_output;
   bool with_eltwise;
   bool with_sum;
-  bool for_quantization;  // True for operator created during first quantization pass
+  bool enable_fuse_add;
   float sum_scale = 1.0f;
   dmlc::optional<float> min_calib_range;  // min float value calculated from calibration dataset
   dmlc::optional<float> max_calib_range;  // max float value calculated from calibration dataset
@@ -64,9 +64,10 @@ struct MKLDNNFCParam : public dmlc::Parameter<MKLDNNFCParam> {
             "Whether there's a post with_eltwise after FullyConnected "
             "operator");
     DMLC_DECLARE_FIELD(with_sum).set_default(false).describe("Add post sum");
-    DMLC_DECLARE_FIELD(for_quantization)
-        .set_default(false)
-        .describe("True for first quantization pass");
+    DMLC_DECLARE_FIELD(enable_fuse_add)
+        .set_default(true)
+        .describe(
+            "True if fusing add should happened. Temporary set for false during quantization");
     DMLC_DECLARE_FIELD(min_calib_range)
         .set_default(dmlc::optional<float>())
         .describe(
diff --git a/src/operator/subgraph/mkldnn/mkldnn_fc_property.h b/src/operator/subgraph/mkldnn/mkldnn_fc_property.h
index 68cdbb0..8f50ff5 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_fc_property.h
+++ b/src/operator/subgraph/mkldnn/mkldnn_fc_property.h
@@ -193,9 +193,6 @@ class SgMKLDNNFCProperty : public SubgraphProperty {
       auto& sub_name = node->op()->name;
       if (sub_name == "FullyConnected") {
         node_name << "fully_connected_";
-        if (HasAttr("quantize") && GetAttr<bool>("quantize")) {
-          n->attrs.dict["for_quantization"] = "True";
-        }
       } else if (SupportMKLDNNFCEltwiseFusion(sub_name)) {
         node_name << "eltwise_";
         n->attrs.dict["with_eltwise"] = "True";
diff --git a/src/operator/subgraph/mkldnn/mkldnn_fc_sum_fuse.h b/src/operator/subgraph/mkldnn/mkldnn_fc_sum_fuse.h
index 8e1b2a4..891b58c 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_fc_sum_fuse.h
+++ b/src/operator/subgraph/mkldnn/mkldnn_fc_sum_fuse.h
@@ -52,7 +52,7 @@ inline bool EndsWith(std::string const& value, std::string const& ending) {
   }
 }
 
-class SgMKLDNNFCSumFuseSelector : public SubgraphSelector {
+class SgMKLDNNFCSumFuseSelector : public SubgraphSelectorV2 {
  private:
   /*! \brief pattern match status */
   enum SelectStatus {
@@ -63,40 +63,60 @@ class SgMKLDNNFCSumFuseSelector : public SubgraphSelector {
 
   bool quantized_;
   SelectStatus status_;
-  std::vector<const nnvm::Node*> matched_list_;
+  std::vector<const BiDirectedNode*> matched_list_;
 
  public:
   explicit SgMKLDNNFCSumFuseSelector(bool quantized) : quantized_(quantized) {}
 
-  bool Select(const nnvm::Node& n, const std::shared_ptr<NodeAttr>& node_attr) override {
-    if (n.op() == Op::Get("_sg_mkldnn_fully_connected") && SupportMKLDNNAttr(node_attr)) {
-      auto const& fc_param = nnvm::get<MKLDNNFCFullParam>(n.attrs.parsed);
-      if ((!quantized_ && !fc_param.mkldnn_param.for_quantization) ||
-          fc_param.mkldnn_param.quantized) {
-        // Start subgraph when fusing for floats (quantized_ is false for MKLDNN backend) or
-        // when FC is already quantized (second pass for MKLDNN_QUANTIZE).
+  bool Select(const BiDirectedNode& seed_node,
+              const std::shared_ptr<NodeAttr>& node_attr) override {
+    const auto n = seed_node.node;
+    if ((n->op() == Op::Get("_sg_mkldnn_fully_connected")) && SupportMKLDNNAttr(node_attr) &&
+        (seed_node.outputs.size() == 1)) {
+      auto& fc_param = nnvm::get<MKLDNNFCFullParam>(n->attrs.parsed);
+      if (quantized_) {
+        if (fc_param.mkldnn_param.enable_fuse_add) {
+          // Do not fuse during first pass with MKLDNN_QUANTIZE backend (quantized_ = true)
+          // and mark to not fuse during quantization for run with MKLDNN backend during
+          // quantization
+          n->attrs.dict["enable_fuse_add"]      = "False";
+          fc_param.mkldnn_param.enable_fuse_add = false;
+          return false;
+        } else {
+          // On second pass MKLDNN_QUANTIZE backend fusing should happened, so
+          // set to true (default value) and remove from dictionary
+          n->attrs.dict.erase("enable_fuse_add");
+          fc_param.mkldnn_param.enable_fuse_add = true;
+        }
+      }
+      // Do not fuse for quantization if already fused with element-wise operation
+      const bool fuse = fc_param.mkldnn_param.enable_fuse_add &&
+                        (!quantized_ || !fc_param.mkldnn_param.with_eltwise);
+      if (fuse) {
         status_ = kStart;
         matched_list_.clear();
-        matched_list_.push_back(&n);
+        matched_list_.push_back(&seed_node);
         return true;
       }
     }
     return false;
   }
 
-  bool SelectInput(const nnvm::Node& n, const nnvm::Node& new_node) override {
+  bool SelectInput(const BiDirectedNode& cur_node, const BiDirectedNode& input_node) override {
     return false;
   }
 
-  bool SelectOutput(const nnvm::Node& n, const nnvm::Node& new_node) override {
-    if (status_ == kFail || status_ == kSuccess || new_node.is_variable()) {
+  bool SelectOutput(const BiDirectedNode& cur_node, const BiDirectedNode& output_node) override {
+    const auto cur_n    = cur_node.node;
+    const auto output_n = output_node.node;
+    if (status_ == kFail || status_ == kSuccess || output_n->is_variable()) {
       return false;
     }
     // If n isn't the last matched node, then we encoutered a internal
     // branch, we should pop out the node behind n and stop fusion.
-    if (matched_list_.back() != &n) {
-      if (std::find(matched_list_.begin(), matched_list_.end(), &n) != matched_list_.end()) {
-        while (matched_list_.back() != &n) {
+    if (matched_list_.back() != &cur_node) {
+      if (std::find(matched_list_.begin(), matched_list_.end(), &cur_node) != matched_list_.end()) {
+        while (matched_list_.back() != &cur_node) {
           matched_list_.pop_back();
         }
       }
@@ -107,17 +127,17 @@ class SgMKLDNNFCSumFuseSelector : public SubgraphSelector {
     switch (status_) {
       case kStart:
         // Find _contrib_quantized_elemwise_add or elemwise_add
-        if (EndsWith(new_node.op()->name, "elemwise_add")) {
+        if (EndsWith(output_n->op()->name, "elemwise_add")) {
           if (quantized_) {
-            auto const& fc_param = nnvm::get<MKLDNNFCFullParam>(n.attrs.parsed);
+            auto const& fc_param = nnvm::get<MKLDNNFCFullParam>(cur_n->attrs.parsed);
             if (!fc_param.mkldnn_param.enable_float_output) {
               // For quantized graph, when FC floating point output is not enabled
               // elementwise add must also be quantized (min and max value have to be already stored
               // in elementwise add).
-              CHECK_EQ(new_node.attrs.dict.count("min_calib_range"), 1);
+              CHECK_EQ(output_n->attrs.dict.count("min_calib_range"), 1);
             }
           }
-          matched_list_.push_back(&new_node);
+          matched_list_.push_back(&output_node);
           status_ = kSuccess;
           return true;
         }
@@ -127,17 +147,10 @@ class SgMKLDNNFCSumFuseSelector : public SubgraphSelector {
     }
   }
 
-  std::vector<nnvm::Node*> Filter(const std::vector<nnvm::Node*>& candidates) override {
+  std::vector<BiDirectedNode*> Filter(const std::vector<BiDirectedNode*>& candidates) override {
     if (status_ == kFail) {
-      return std::vector<nnvm::Node*>(0);
+      return std::vector<BiDirectedNode*>(0);
     } else {
-      std::vector<nnvm::Node*> ret;
-      for (auto i : matched_list_) {
-        auto non_const_i = const_cast<nnvm::Node*>(i);
-        if (std::find(candidates.begin(), candidates.end(), non_const_i) != candidates.end()) {
-          ret.push_back(non_const_i);
-        }
-      }
       return candidates;
     }
   }
@@ -212,7 +225,7 @@ class SgMKLDNNFCSumFuseProperty : public SubgraphProperty {
     return fc_node;
   }
 
-  SubgraphSelectorPtr CreateSubgraphSelector() const override {
+  SubgraphSelectorV2Ptr CreateSubgraphSelectorV2() const override {
     bool quantized = HasAttr("quantize") ? GetAttr<bool>("quantize") : false;
     auto selector  = std::make_shared<SgMKLDNNFCSumFuseSelector>(quantized);
     return selector;