You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ak...@apache.org on 2021/09/29 15:37:18 UTC

[incubator-mxnet] branch v1.x updated: [FEATURE] Add fusing FullyConnected with element-wise add (#20599)

This is an automated email from the ASF dual-hosted git repository.

akarbown pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.x by this push:
     new a3ed54b  [FEATURE] Add fusing FullyConnected with element-wise add (#20599)
a3ed54b is described below

commit a3ed54be2680ec8c65ce54433e550b8115c8525c
Author: Andrzej Kotłowski <An...@intel.com>
AuthorDate: Wed Sep 29 17:34:30 2021 +0200

    [FEATURE] Add fusing FullyConnected with element-wise add (#20599)
    
    * Add fusing FullyConnected with element-wise add
    
    It adds an implementation for the MKLDNN_QUANTIZE backend with
    quantize_mode='full' and for the MKLDNN backend (floating-point fusion)
    
    * Disable fusing FC+add during first quantization pass
    
    * Align naming convention with convolution operator
    
    Convolution uses convention data_name_[min|max] which is object
    oriented and more readable.
    
    * Enable tests for full quantize mode and for MKLDNN backend
    
    * Add comment in CreateSubgraphNode
---
 .../nn/mkldnn/mkldnn_fully_connected-inl.h         |  38 ++--
 src/operator/subgraph/build_subgraph.cc            |  13 ++
 src/operator/subgraph/mkldnn/mkldnn_fc.cc          | 194 +++++++++++----------
 src/operator/subgraph/mkldnn/mkldnn_fc_property.h  |   3 +
 src/operator/subgraph/mkldnn/mkldnn_fc_sum_fuse.h  |  80 ++++++---
 .../subgraph/mkldnn/mkldnn_subgraph_property.cc    |   9 +-
 tests/python/mkl/test_subgraph.py                  |  21 ++-
 7 files changed, 213 insertions(+), 145 deletions(-)

diff --git a/src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h b/src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h
index d2ccdef..28195af 100644
--- a/src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h
@@ -45,6 +45,7 @@ struct MKLDNNFCParam : public dmlc::Parameter<MKLDNNFCParam> {
   bool enable_float_output;
   bool with_eltwise;
   bool with_sum;
+  bool for_quantization;  // True for operator created during first quantization pass
   float sum_scale = 1.0f;
   dmlc::optional<float> min_calib_range;  // min float value calculated from calibration dataset
   dmlc::optional<float> max_calib_range;  // max float value calculated from calibration dataset
@@ -63,6 +64,9 @@ struct MKLDNNFCParam : public dmlc::Parameter<MKLDNNFCParam> {
             "Whether there's a post with_eltwise after FullyConnected "
             "operator");
     DMLC_DECLARE_FIELD(with_sum).set_default(false).describe("Add post sum");
+    DMLC_DECLARE_FIELD(for_quantization)
+        .set_default(false)
+        .describe("True for first quantization pass");
     DMLC_DECLARE_FIELD(min_calib_range)
         .set_default(dmlc::optional<float>())
         .describe(
@@ -108,13 +112,12 @@ class FCInputIndex {
                               mkldnn_param.channel_wise_quantize.value();
 
     // Calculate position of particular input in the input vector:
-    int index     = 0;
-    data          = index++;
-    weight        = index++;
-    bias          = has_bias ? index++ : 0;
-    num_quantized = index + (sum_input_quantized ? 1 : 0);
-    sum           = mkldnn_param.with_sum ? index++ : 0;
-    num_base      = index;
+    int index = 0;
+    data      = index++;
+    weight    = index++;
+    bias      = has_bias ? index++ : 0;
+    sum       = mkldnn_param.with_sum ? index++ : 0;
+    num_base  = index;  // note number of base inputs
 
     data_min   = quantized ? index++ : 0;
     data_max   = quantized ? index++ : 0;
@@ -124,10 +127,20 @@ class FCInputIndex {
     bias_max   = (quantized && !channel_wise && has_bias) ? index++ : 0;
     sum_min    = sum_input_quantized ? index++ : 0;
     sum_max    = sum_input_quantized ? index++ : 0;
-    num_total  = index;
+    num_total  = index;  // note number of total inputs
+  }
+
+  // Returns true if sum input exists
+  bool IsSumExist() const {
+    return sum;
   }
 
-  // true if sum input is used and it is float number
+  // Returns true if bias input exists
+  bool IsBiasExist() const {
+    return bias;
+  }
+
+  // Returns true if sum input exists and it is float number
   bool IsSumInputFloat() const {
     return (sum && !sum_min);
   }
@@ -138,12 +151,6 @@ class FCInputIndex {
     return num_base;
   }
 
-  // return number of standard inputs which are quantized (represented as
-  // integer)
-  int GetQuantized() const {
-    return num_quantized;
-  }
-
   // Represent index of particular input in the input vector:
   int data;
   int weight;
@@ -162,7 +169,6 @@ class FCInputIndex {
   int num_base;       // Number of standard inputs
   int num_total;      // Number of total inputs: standard + additional needed for
                       // quantization
-  int num_quantized;  // Number of standard inputs which are quantized
 };
 
 mkldnn::inner_product_forward::primitive_desc GetFCFwdImpl(const MKLDNNFCFullParam& full_param,
diff --git a/src/operator/subgraph/build_subgraph.cc b/src/operator/subgraph/build_subgraph.cc
index ba6ba03..e736680 100644
--- a/src/operator/subgraph/build_subgraph.cc
+++ b/src/operator/subgraph/build_subgraph.cc
@@ -663,6 +663,8 @@ void CreateSubgraphNode(nnvm::Graph* g,
       subg_prop->ConnectSubgraphInputs(n, &input_entries, &orig_input_entries);
 
     const auto& indexed_graph = g->indexed_graph();
+
+    // Clear previous outputs
     for (size_t i = 0; i < n->inputs.size(); ++i) {
       auto& e = n->inputs[i];
       // update entry_top_order_map with newly created orig_input_entries
@@ -677,6 +679,17 @@ void CreateSubgraphNode(nnvm::Graph* g,
         for (BiDirectedNode* dest_node : subgraph_nodes) {
           sn->outputs.erase(dest_node->node);
         }
+      }
+    }
+
+    // Set outputs according to current inputs
+    for (size_t i = 0; i < n->inputs.size(); ++i) {
+      auto& e = n->inputs[i];
+      // update input entries' source simple nodes' outputs map
+      nnvm::Node* node = e.node.get();
+      if (indexed_graph.exist(node)) {
+        const auto nid     = indexed_graph.node_id(node);
+        BiDirectedNode* sn = simple_nodes[nid].get();
         sn->outputs[n.get()].push_back(i);
       }
     }
diff --git a/src/operator/subgraph/mkldnn/mkldnn_fc.cc b/src/operator/subgraph/mkldnn/mkldnn_fc.cc
index 123d491..9c481d6 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_fc.cc
+++ b/src/operator/subgraph/mkldnn/mkldnn_fc.cc
@@ -77,18 +77,18 @@ class SgMKLDNNFCOp {
   std::shared_ptr<mkldnn::memory> cached_out_mem_;
   NDArray cached_weight_;
   NDArray cached_bias_;
-  float cached_min_data_;
-  float cached_max_data_;
-  float cached_min_weight_;
-  float cached_max_weight_;
+  float cached_data_min_;
+  float cached_data_max_;
+  float cached_weight_min_;
+  float cached_weight_max_;
   float cached_sum_min_;
   float cached_sum_max_;
-  float cached_min_bias_;
-  float cached_max_bias_;
+  float cached_bias_min_;
+  float cached_bias_max_;
   size_t weight_ver_;
   size_t bias_ver_;
-  float cached_min_output_;
-  float cached_max_output_;
+  float cached_output_min_;
+  float cached_output_max_;
   float data_scale_{0.0f};
   std::vector<float> weight_scales_;
 };
@@ -115,12 +115,12 @@ void SgMKLDNNFCOp::Forward(const OpContext& ctx,
   const int out_max_index = out_quantized ? index++ : 0;
   CHECK_EQ(out_data.size(), index);  // index is equal to total number of outpits
 
-  float min_data   = 0.0f;
-  float max_data   = 0.0f;
-  float min_weight = 0.0f;
-  float max_weight = 0.0f;
-  float min_bias   = 0.0f;
-  float max_bias   = 0.0f;
+  float data_min   = 0.0f;
+  float data_max   = 0.0f;
+  float weight_min = 0.0f;
+  float weight_max = 0.0f;
+  float bias_min   = 0.0f;
+  float bias_max   = 0.0f;
 
   const float sum_min = idx.sum_min ? in_data[idx.sum_min].data().dptr<float>()[0] : 0.0;
   const float sum_max = idx.sum_max ? in_data[idx.sum_max].data().dptr<float>()[0] : 0.0;
@@ -171,31 +171,31 @@ void SgMKLDNNFCOp::Forward(const OpContext& ctx,
 
   if (mkldnn_param.quantized) {
     if (!channel_wise) {
-      min_weight = in_data[idx.weight_min].data().dptr<float>()[0];
-      max_weight = in_data[idx.weight_max].data().dptr<float>()[0];
+      weight_min = in_data[idx.weight_min].data().dptr<float>()[0];
+      weight_max = in_data[idx.weight_max].data().dptr<float>()[0];
       if (has_bias) {
-        min_bias = in_data[idx.bias_min].data().dptr<float>()[0];
-        max_bias = in_data[idx.bias_max].data().dptr<float>()[0];
+        bias_min = in_data[idx.bias_min].data().dptr<float>()[0];
+        bias_max = in_data[idx.bias_max].data().dptr<float>()[0];
       }
     }
-    min_data = in_data[idx.data_min].data().dptr<float>()[0];
-    max_data = in_data[idx.data_max].data().dptr<float>()[0];
+    data_min = in_data[idx.data_min].data().dptr<float>()[0];
+    data_max = in_data[idx.data_max].data().dptr<float>()[0];
   }
 
   if (initialized_ && mkldnn_param.quantized &&
       dmlc::GetEnv("MXNET_MKLDNN_QFC_DYNAMIC_PARAMS", 0)) {
     if (channel_wise) {
-      if (cached_min_data_ != min_data || cached_max_data_ != max_data ||
+      if (cached_data_min_ != data_min || cached_data_max_ != data_max ||
           cached_sum_min_ != sum_min || cached_sum_max_ != sum_max ||
           weight_ver_ != weight.version() ||
           (has_bias && (bias_ver_ != in_data[idx.bias].version()))) {
         initialized_ = false;
       }
     } else {
-      if (cached_min_data_ != min_data || cached_max_data_ != max_data ||
+      if (cached_data_min_ != data_min || cached_data_max_ != data_max ||
           cached_sum_min_ != sum_min || cached_sum_max_ != sum_max ||
-          cached_min_weight_ != min_weight || cached_max_weight_ != max_weight ||
-          (has_bias && (cached_min_bias_ != min_bias || cached_max_bias_ != max_bias))) {
+          cached_weight_min_ != weight_min || cached_weight_max_ != weight_max ||
+          (has_bias && (cached_bias_min_ != bias_min || cached_bias_max_ != bias_max))) {
         initialized_ = false;
       }
     }
@@ -204,17 +204,17 @@ void SgMKLDNNFCOp::Forward(const OpContext& ctx,
   if (!initialized_) {
     const auto nthreads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
     const auto engine   = CpuEngine::Get()->get_engine();
-    cached_min_data_    = min_data;
-    cached_max_data_    = max_data;
-    cached_min_weight_  = min_weight;
-    cached_max_weight_  = max_weight;
+    cached_data_min_    = data_min;
+    cached_data_max_    = data_max;
+    cached_weight_min_  = weight_min;
+    cached_weight_max_  = weight_max;
     weight_ver_         = weight.version();
     cached_weight_      = weight;
     cached_sum_min_     = sum_min;
     cached_sum_max_     = sum_max;
     if (has_bias) {
-      cached_min_bias_ = min_bias;
-      cached_max_bias_ = max_bias;
+      cached_bias_min_ = bias_min;
+      cached_bias_max_ = bias_max;
       bias_ver_        = in_data[idx.bias].version();
       cached_bias_     = in_data[idx.bias];
     } else {
@@ -259,13 +259,13 @@ void SgMKLDNNFCOp::Forward(const OpContext& ctx,
     bool support_channelwise_scale = false;
     if (mkldnn_param.quantized) {
       CHECK(data.dtype() == mshadow::kInt8 || data.dtype() == mshadow::kUint8);
-      data_scale_ = GetQuantizeScale(data.dtype(), cached_min_data_, cached_max_data_);
+      data_scale_ = GetQuantizeScale(data.dtype(), cached_data_min_, cached_data_max_);
 
       bool fuse_requantize = false;
       // Channelwise scaling is only supported when fusion is enabled (requantize or dequantize).
       if (mkldnn_param.min_calib_range.has_value() && mkldnn_param.max_calib_range.has_value()) {
-        cached_min_output_        = mkldnn_param.min_calib_range.value();
-        cached_max_output_        = mkldnn_param.max_calib_range.value();
+        cached_output_min_        = mkldnn_param.min_calib_range.value();
+        cached_output_max_        = mkldnn_param.max_calib_range.value();
         support_channelwise_scale = true;
         fuse_requantize           = true;
       }
@@ -297,16 +297,16 @@ void SgMKLDNNFCOp::Forward(const OpContext& ctx,
       } else {
         weight_scales_.resize(1);
         weight_scales_[0] =
-            GetQuantizeScale(cached_weight_.dtype(), cached_min_weight_, cached_max_weight_);
+            GetQuantizeScale(cached_weight_.dtype(), cached_weight_min_, cached_weight_max_);
         if (has_bias) {
           if (cached_bias_.dtype() == mshadow::kInt8) {
-            float bias_scale = GetQuantizeScale(mshadow::kInt8, cached_min_bias_, cached_max_bias_);
+            float bias_scale = GetQuantizeScale(mshadow::kInt8, cached_bias_min_, cached_bias_max_);
 
             float bias_int32_rescale = data_scale_ * weight_scales_[0] / bias_scale;
             // TODO(zhennan): mkldnn has bug to handle INT_MAX in bias, so set
             // the maximum value of bias to INT_MAX / 2.
             float bias_max_rescale =
-                MaxValue<int32_t>() / 2 / MaxAbs(cached_min_bias_, cached_max_bias_) / bias_scale;
+                MaxValue<int32_t>() / 2 / MaxAbs(cached_bias_min_, cached_bias_max_) / bias_scale;
             if (bias_int32_rescale > bias_max_rescale) {
               // avoid overflow on bias
               bias_int32_rescale = bias_max_rescale;
@@ -343,9 +343,9 @@ void SgMKLDNNFCOp::Forward(const OpContext& ctx,
           if (mkldnn_param.with_eltwise) {
             tmp_scale_ = 1.0 / data_scale_;
             full_param_.eltwise_param.scale =
-                GetQuantizeScale(output.dtype(), cached_min_output_, cached_max_output_);
+                GetQuantizeScale(output.dtype(), cached_output_min_, cached_output_max_);
           } else {
-            out_scale  = GetQuantizeScale(output.dtype(), cached_min_output_, cached_max_output_);
+            out_scale  = GetQuantizeScale(output.dtype(), cached_output_min_, cached_output_max_);
             tmp_scale_ = out_scale / data_scale_;
           }
         } else {
@@ -368,22 +368,22 @@ void SgMKLDNNFCOp::Forward(const OpContext& ctx,
           mxnet_op::Kernel<QuantizationRangeForS8S8MultiplicationStruct, cpu>::Launch(
               s,
               1,
-              &cached_min_output_,
-              &cached_max_output_,
-              &min_data,
-              &max_data,
-              &min_weight,
-              &max_weight);
+              &cached_output_min_,
+              &cached_output_max_,
+              &data_min,
+              &data_max,
+              &weight_min,
+              &weight_max);
         } else {
           mxnet_op::Kernel<QuantizationRangeForS8U8MultiplicationStruct, cpu>::Launch(
               s,
               1,
-              &cached_min_output_,
-              &cached_max_output_,
-              &min_data,
-              &max_data,
-              &min_weight,
-              &max_weight);
+              &cached_output_min_,
+              &cached_output_max_,
+              &data_min,
+              &data_max,
+              &weight_min,
+              &weight_max);
         }
         full_param_.output_scales.resize(0);
         out_scale = data_scale_ * weight_scales_[0];
@@ -470,15 +470,15 @@ void SgMKLDNNFCOp::Forward(const OpContext& ctx,
   MKLDNNStream::Get()->Submit();
 
   if (mkldnn_param.quantized && !mkldnn_param.enable_float_output) {
-    float *min_output_ptr = out_data[out_min_index].data().dptr<float>();
-    float *max_output_ptr = out_data[out_max_index].data().dptr<float>();
+    float* output_min_ptr = out_data[out_min_index].data().dptr<float>();
+    float* output_max_ptr = out_data[out_max_index].data().dptr<float>();
 
     if (mkldnn_param.shifted_output.has_value() && mkldnn_param.shifted_output.value()) {
-      *min_output_ptr = 0;
-      *max_output_ptr = cached_max_output_ - cached_min_output_;
+      *output_min_ptr = 0;
+      *output_max_ptr = cached_output_max_ - cached_output_min_;
     } else {
-      *min_output_ptr = cached_min_output_;
-      *max_output_ptr = cached_max_output_;
+      *output_min_ptr = cached_output_min_;
+      *output_max_ptr = cached_output_max_;
     }
   }
 }
@@ -534,27 +534,25 @@ static void SgMKLDNNFCParamParser(nnvm::NodeAttrs* attrs) {
 
 static std::vector<std::string> SgMKLDNNFCListInputNames(const NodeAttrs& attrs) {
   auto const& full_param               = nnvm::get<MKLDNNFCFullParam>(attrs.parsed);
+  auto const& mkldnn_param             = full_param.mkldnn_param;
   std::vector<std::string> input_names = DefaultSubgraphOpListInputs(attrs);
-  if (full_param.mkldnn_param.quantized) {
-    bool channel_wise = false;
-    if (full_param.mkldnn_param.with_sum) {
-      input_names.emplace_back("sum");
-    }
-
-    if (full_param.mkldnn_param.channel_wise_quantize.has_value() &&
-        full_param.mkldnn_param.channel_wise_quantize) {
-      channel_wise = true;
-    }
-    input_names.emplace_back("min_data");
-    input_names.emplace_back("max_data");
+  if (mkldnn_param.quantized) {
+    const bool channel_wise =
+        mkldnn_param.channel_wise_quantize.has_value() && mkldnn_param.channel_wise_quantize;
+    input_names.emplace_back("data_min");
+    input_names.emplace_back("data_max");
     if (!channel_wise) {
-      input_names.emplace_back("min_weight");
-      input_names.emplace_back("max_weight");
+      input_names.emplace_back("weight_min");
+      input_names.emplace_back("weight_max");
       if (!full_param.default_param.no_bias) {
-        input_names.emplace_back("min_bias");
-        input_names.emplace_back("max_bias");
+        input_names.emplace_back("bias_min");
+        input_names.emplace_back("bias_max");
       }
     }
+    if (mkldnn_param.with_sum && !mkldnn_param.enable_float_output) {
+      input_names.emplace_back("sum_min");
+      input_names.emplace_back("sum_max");
+    }
   }
   return input_names;
 }
@@ -565,7 +563,7 @@ static std::vector<std::string> SgMKLDNNFCListOutputNames(const NodeAttrs& attrs
     if (full_param.mkldnn_param.enable_float_output)
       return std::vector<std::string>{"output"};
     else
-      return std::vector<std::string>{"output", "min_output", "max_output"};
+      return std::vector<std::string>{"output", "output_min", "output_max"};
   } else {
     return std::vector<std::string>{"output"};
   }
@@ -618,35 +616,43 @@ static bool SgMKLDNNFCInferType(const nnvm::NodeAttrs& attrs,
                                 std::vector<int>* out_types) {
   auto const& full_param = nnvm::get<MKLDNNFCFullParam>(attrs.parsed);
   if (full_param.mkldnn_param.quantized) {
-    bool channel_wise = false;
-    if (full_param.mkldnn_param.channel_wise_quantize.has_value() &&
-        full_param.mkldnn_param.channel_wise_quantize) {
-      channel_wise = true;
-    }
-    size_t num_integer_inputs = FCInputIndex(full_param).GetQuantized();
-    CHECK(in_types->at(0) == mshadow::kInt8 || in_types->at(0) == mshadow::kUint8)
-        << "QuantizedFullyConnected only supports int8/uint8 input, while " << in_types->at(0)
-        << " is given.";
+    const bool channel_wise = full_param.mkldnn_param.channel_wise_quantize.has_value() &&
+                              full_param.mkldnn_param.channel_wise_quantize;
+    const FCInputIndex idx(full_param);
+
+    CHECK(in_types->at(idx.data) == mshadow::kInt8 || in_types->at(idx.data) == mshadow::kUint8)
+        << "QuantizedFullyConnected  data input only supports int8/uint8, while "
+        << in_types->at(idx.data) << " is given.";
     if (channel_wise) {
-      for (size_t i = 1; i < in_types->size(); ++i) {
-        TYPE_ASSIGN_CHECK(*in_types, i, mshadow::kFloat32);
+      TYPE_ASSIGN_CHECK(*in_types, idx.weight, mshadow::kFloat32);
+      if (idx.IsBiasExist()) {
+        TYPE_ASSIGN_CHECK(*in_types, idx.bias, mshadow::kFloat32);
       }
     } else {
-      TYPE_ASSIGN_CHECK(*in_types, 1, mshadow::kInt8);
-      if (!full_param.default_param.no_bias) {
-        if (in_types->at(2) == -1) {
-          TYPE_ASSIGN_CHECK(*in_types, 2, mshadow::kInt32);
+      TYPE_ASSIGN_CHECK(*in_types, idx.weight, mshadow::kInt8);
+      if (idx.IsBiasExist()) {
+        if (in_types->at(idx.bias) == -1) {
+          TYPE_ASSIGN_CHECK(*in_types, idx.bias, mshadow::kInt32);
         } else {
-          CHECK(in_types->at(2) == mshadow::kInt8 || in_types->at(2) == mshadow::kInt32)
-              << "QuantizedFullyConnected only supports int8/int32 bias, "
-                 "while "
-              << in_types->at(2) << " is given.";
+          CHECK(in_types->at(idx.bias) == mshadow::kInt8 ||
+                in_types->at(idx.bias) == mshadow::kInt32)
+              << "QuantizedFullyConnected bias input only supports int8/int32, while "
+              << in_types->at(idx.bias) << " is given.";
         }
       }
-      for (size_t i = num_integer_inputs; i < in_types->size(); ++i) {
-        TYPE_ASSIGN_CHECK(*in_types, i, mshadow::kFloat32);
+    }
+    if (idx.IsSumExist()) {
+      if (full_param.mkldnn_param.enable_float_output) {
+        TYPE_ASSIGN_CHECK(*in_types, idx.sum, mshadow::kFloat32);
+      } else {
+        CHECK(in_types->at(idx.sum) == mshadow::kInt8 || in_types->at(idx.sum) == mshadow::kUint8)
+            << "QuantizedFullyConnected sum input only supports int8/uint8, while "
+            << in_types->at(idx.sum) << " is given.";
       }
     }
+    for (size_t i = idx.data_min; i < in_types->size(); ++i) {
+      TYPE_ASSIGN_CHECK(*in_types, i, mshadow::kFloat32);
+    }
 
     if (full_param.mkldnn_param.enable_float_output) {
       TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kFloat32);
diff --git a/src/operator/subgraph/mkldnn/mkldnn_fc_property.h b/src/operator/subgraph/mkldnn/mkldnn_fc_property.h
index 8f50ff5..68cdbb0 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_fc_property.h
+++ b/src/operator/subgraph/mkldnn/mkldnn_fc_property.h
@@ -193,6 +193,9 @@ class SgMKLDNNFCProperty : public SubgraphProperty {
       auto& sub_name = node->op()->name;
       if (sub_name == "FullyConnected") {
         node_name << "fully_connected_";
+        if (HasAttr("quantize") && GetAttr<bool>("quantize")) {
+          n->attrs.dict["for_quantization"] = "True";
+        }
       } else if (SupportMKLDNNFCEltwiseFusion(sub_name)) {
         node_name << "eltwise_";
         n->attrs.dict["with_eltwise"] = "True";
diff --git a/src/operator/subgraph/mkldnn/mkldnn_fc_sum_fuse.h b/src/operator/subgraph/mkldnn/mkldnn_fc_sum_fuse.h
index 9d231e0..8e1b2a4 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_fc_sum_fuse.h
+++ b/src/operator/subgraph/mkldnn/mkldnn_fc_sum_fuse.h
@@ -18,8 +18,12 @@
  */
 
 /*
-  \brief It fuses FC + SUM for floating point output in second post quantization
-  pass
+  \file
+  \brief For fusing FullyConnected operator with element-wise add.
+
+  Element-wise add operator is replaced by MKLDNN FC "sum" post operator.
+  It adds FC results to existing values in output. For quantized integer version
+  this output is scaled to the proper range.
 */
 
 #ifndef MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_FC_SUM_FUSE_H_
@@ -40,6 +44,14 @@
 namespace mxnet {
 namespace op {
 
+inline bool EndsWith(std::string const& value, std::string const& ending) {
+  if (ending.size() > value.size()) {
+    return false;
+  } else {
+    return std::equal(ending.rbegin(), ending.rend(), value.rbegin());
+  }
+}
+
 class SgMKLDNNFCSumFuseSelector : public SubgraphSelector {
  private:
   /*! \brief pattern match status */
@@ -49,7 +61,6 @@ class SgMKLDNNFCSumFuseSelector : public SubgraphSelector {
     kSuccess,
   };
 
- private:
   bool quantized_;
   SelectStatus status_;
   std::vector<const nnvm::Node*> matched_list_;
@@ -60,9 +71,10 @@ class SgMKLDNNFCSumFuseSelector : public SubgraphSelector {
   bool Select(const nnvm::Node& n, const std::shared_ptr<NodeAttr>& node_attr) override {
     if (n.op() == Op::Get("_sg_mkldnn_fully_connected") && SupportMKLDNNAttr(node_attr)) {
       auto const& fc_param = nnvm::get<MKLDNNFCFullParam>(n.attrs.parsed);
-      // TODO(anko) remove fc_param.mkldnn_param.quantized from if below
-      //            to fuse even for not quantized?
-      if (fc_param.mkldnn_param.enable_float_output && fc_param.mkldnn_param.quantized) {
+      if ((!quantized_ && !fc_param.mkldnn_param.for_quantization) ||
+          fc_param.mkldnn_param.quantized) {
+        // Start subgraph when fusing for floats (quantized_ is false for MKLDNN backend) or
+        // when FC is already quantized (second pass for MKLDNN_QUANTIZE).
         status_ = kStart;
         matched_list_.clear();
         matched_list_.push_back(&n);
@@ -94,7 +106,17 @@ class SgMKLDNNFCSumFuseSelector : public SubgraphSelector {
 
     switch (status_) {
       case kStart:
-        if (new_node.op()->name == "elemwise_add") {
+        // Find _contrib_quantized_elemwise_add or elemwise_add
+        if (EndsWith(new_node.op()->name, "elemwise_add")) {
+          if (quantized_) {
+            auto const& fc_param = nnvm::get<MKLDNNFCFullParam>(n.attrs.parsed);
+            if (!fc_param.mkldnn_param.enable_float_output) {
+              // For quantized graph, when FC floating point output is not enabled
+              // elementwise add must also be quantized (min and max value have to be already stored
+              // in elementwise add).
+              CHECK_EQ(new_node.attrs.dict.count("min_calib_range"), 1);
+            }
+          }
           matched_list_.push_back(&new_node);
           status_ = kSuccess;
           return true;
@@ -133,7 +155,7 @@ class SgMKLDNNFCSumFuseProperty : public SubgraphProperty {
   SgMKLDNNFCSumFuseProperty() {}
 
   static SubgraphPropertyPtr Create() {
-    static const std::string& name = "MKLDNN FullyConnected post quantization second pass";
+    static const std::string& name = "MKLDNN fuse FullyConnected with sum";
     auto property                  = std::make_shared<SgMKLDNNFCSumFuseProperty>();
     property->SetAttr<std::string>("property_name", name);
     property->SetAttr<bool>("inference_only", true);
@@ -149,12 +171,13 @@ class SgMKLDNNFCSumFuseProperty : public SubgraphProperty {
     nnvm::ObjectPtr ew_add_node = nullptr;
 
     DFSVisit(sym.outputs, [&](const nnvm::ObjectPtr& node) {
-      if (node->is_variable())
+      if (node->is_variable()) {
         return;
+      }
       auto& sub_name = node->op()->name;
       if (sub_name == "_sg_mkldnn_fully_connected") {
         fc_node = node;
-      } else if (sub_name == "elemwise_add") {
+      } else if (EndsWith(sub_name, "elemwise_add")) {
         ew_add_node = node;
       }
     });
@@ -163,13 +186,23 @@ class SgMKLDNNFCSumFuseProperty : public SubgraphProperty {
     if (ew_add_node != nullptr) {
       CHECK_NOTNULL(fc_node->attrs.subgraphs[0]);
       auto fc_orginal = fc_node->attrs.subgraphs[0]->outputs[0].node;
+
       if (fc_orginal->op() == Op::Get("FullyConnected")) {
         nnvm::Symbol new_sym;
-        nnvm::NodeEntry& ew_input_with_fc = (ew_add_node->inputs[1].node == fc_node)
-                                                ? ew_add_node->inputs[1]
-                                                : ew_add_node->inputs[0];
-        ew_input_with_fc.node             = fc_orginal;
-        new_sym.outputs.emplace_back(ew_add_node);
+        // Create a new elemwise_add node to not alter the original one.
+        // It is needed in subgraph to properly calculate InferShape.
+        nnvm::ObjectPtr n = nnvm::Node::Create();
+        n->attrs.op       = Op::Get("elemwise_add");
+        n->attrs.name     = ew_add_node->attrs.name;
+
+        if (ew_add_node->inputs[0].node == fc_node) {
+          n->inputs.emplace_back(fc_orginal);
+          n->inputs.emplace_back(ew_add_node->inputs[1]);
+        } else {
+          n->inputs.emplace_back(ew_add_node->inputs[0]);
+          n->inputs.emplace_back(fc_orginal);
+        }
+        new_sym.outputs.emplace_back(n);
         fc_node->attrs.subgraphs.clear();
         fc_node->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>(new_sym));
         fc_node->attrs.dict["with_sum"] = "True";
@@ -201,10 +234,11 @@ class SgMKLDNNFCSumFuseProperty : public SubgraphProperty {
     auto const& fc_param = nnvm::get<MKLDNNFCFullParam>(n->attrs.parsed);
     std::unordered_set<const nnvm::Node*> node_sets;
     DFSVisit(sym->outputs, [&](const nnvm::ObjectPtr& node) {
-      if (node->is_variable())
+      if (node->is_variable()) {
         return;
+      }
       node_sets.insert(node.get());
-      if (node->op()->name == "elemwise_add") {
+      if (EndsWith(node->op()->name, "elemwise_add")) {
         const size_t base_inputs = fc_param.default_param.no_bias ? 3 : 4;
 
         // Make sure n is the left operand of sum, if not,
@@ -219,12 +253,16 @@ class SgMKLDNNFCSumFuseProperty : public SubgraphProperty {
                       orig_input_entries->begin() + 1,
                       orig_input_entries->begin() + base_inputs);
         } else {
+          const int not_rotated_end =
+              (fc_param.mkldnn_param.quantized && !fc_param.mkldnn_param.enable_float_output) ? 2
+                                                                                              : 0;
+
           std::rotate(input_entries->begin() + base_inputs - 1,
-                      input_entries->end() - 1,
-                      input_entries->end());
+                      input_entries->end() - 1 - not_rotated_end,
+                      input_entries->end() - not_rotated_end);
           std::rotate(orig_input_entries->begin() + base_inputs - 1,
-                      orig_input_entries->end() - 1,
-                      orig_input_entries->end());
+                      orig_input_entries->end() - 1 - not_rotated_end,
+                      orig_input_entries->end() - not_rotated_end);
         }
       }
     });
diff --git a/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc b/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc
index cdac3e0..2701e0d 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc
+++ b/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc
@@ -37,26 +37,19 @@ MXNET_REGISTER_SUBGRAPH_BACKEND(MKLDNN)
     .set_attr("context", Context::CPU());
 
 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNConvProperty);
-
 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNFCProperty);
-
 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNTransformerProperty);
+MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNFCSumFuseProperty);
 
 MXNET_REGISTER_SUBGRAPH_BACKEND(MKLDNN_QUANTIZE).set_attr("context", Context::CPU());
 
 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNConvProperty).set_attr("quantize", true);
-
 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNFCProperty).set_attr("quantize", true);
-
 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNTransformerProperty);
-
 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNTransformerPostQuantizeProperty);
-
 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNPostQuantizeProperty);
-
 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNFCPostQuantizeProperty);
 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, ElemwiseMulPostQuantizeProperty);
-
 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNPostQuantizeAlignScaleProperty);
 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNFCSumFuseProperty)
     .set_attr("quantize", true);
diff --git a/tests/python/mkl/test_subgraph.py b/tests/python/mkl/test_subgraph.py
index 1210c05..a6af10f 100644
--- a/tests/python/mkl/test_subgraph.py
+++ b/tests/python/mkl/test_subgraph.py
@@ -399,9 +399,14 @@ def conv_add2(no_bias, data_shape):
   return sum, attr
 
 
-# fc + sum
-def fc_sum(no_bias, data_shape):
-  attr = {'fc': {'with_sum': 'true', 'quantized' : 'true', 'enable_float_output': 'true'}}
+# FullyConnected + element wise add
+def fc_sum(no_bias, data_shape, quantize_mode=None):
+  attr = {'fc': {'with_sum': 'true'}}
+  if quantize_mode is not None:
+    attr['fc']['quantized'] = 'true'
+    if quantize_mode == 'smart':
+      attr['fc']['enable_float_output'] = 'true'
+
   data, weight = head_symbol(data_shape)
   sym1 = mx.symbol.FullyConnected(data=data, weight=weight, no_bias=no_bias, num_hidden=10)
   data2 = mx.symbol.var('data_2', shape= (data_shape[0], 10), dtype="float32", init = mx.init.Normal(0.3))
@@ -898,9 +903,13 @@ def test_pos_fc_sum():
       check_qsym_dummy_forward_fc_sum(qsym, batch, data_shapes)
 
   for data_shape in DATA_SHAPE:
-    for out_type in ('auto',):
-      net, attrs, inputs = fc_sum(False, data_shape)
-      check_quantize_fc_sum(net, inputs, out_type, attrs)
+    net, attrs, inputs = fc_sum(False, data_shape)
+    check_fusion(net,data_shape, attrs, check_quantization = False)
+  for quantize_mode in ('smart', 'full'):
+    for data_shape in DATA_SHAPE:
+      for out_type in ('auto', 'int8'):
+        net, attrs, inputs = fc_sum(False, data_shape, quantize_mode)
+        check_quantize_fc_sum(net, inputs, out_type, attrs, quantize_mode = quantize_mode)
 
 
 @with_seed()