You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2021/07/07 17:00:42 UTC

[incubator-mxnet] branch v1.x updated: [FEATURE] Fuse FC + sum for quantization (#20400)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.x by this push:
     new ec5945e  [FEATURE] Fuse FC + sum for quantization (#20400)
ec5945e is described below

commit ec5945e3d653c52fdd37bcdc7794badfddd7612e
Author: Andrzej Kotłowski <An...@intel.com>
AuthorDate: Wed Jul 7 18:58:44 2021 +0200

    [FEATURE] Fuse FC + sum for quantization (#20400)
    
    * Fix FullyConnected channel-wise quantization
    
    * Fuse FullyConnected + sum
    
    A floating-point sum input can be fused with the quantized
    FullyConnected operator by using the oneDNN sum post-op. To keep the
    solution simple, a separate pass is applied after quantization when
    the FC output type is set to float.
    
    * Fix spelling
    
    * Apply review comments
---
 .../nn/mkldnn/mkldnn_fully_connected-inl.h         |   5 +
 src/operator/nn/mkldnn/mkldnn_fully_connected.cc   |   4 +
 src/operator/quantization/quantize_graph_pass.cc   |   1 +
 src/operator/subgraph/mkldnn/mkldnn_fc.cc          | 282 +++++++++++++++------
 src/operator/subgraph/mkldnn/mkldnn_fc_sum_fuse.h  | 241 ++++++++++++++++++
 .../subgraph/mkldnn/mkldnn_subgraph_property.cc    |   4 +
 tests/python/mkl/test_subgraph.py                  | 115 +++++++++
 7 files changed, 581 insertions(+), 71 deletions(-)

diff --git a/src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h b/src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h
index 1c9396e..13e6319 100644
--- a/src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h
@@ -31,6 +31,7 @@
 
 #include <vector>
 #include <string>
+#include <memory>
 #include "../fully_connected-inl.h"
 #include "./mkldnn_base-inl.h"
 
@@ -41,6 +42,8 @@ struct MKLDNNFCParam: public dmlc::Parameter<MKLDNNFCParam> {
   bool quantized;
   bool enable_float_output;
   bool with_eltwise;
+  bool with_sum;
+  float sum_scale = 1.0f;
   dmlc::optional<float> min_calib_range;  // min float value calculated from calibration dataset
   dmlc::optional<float> max_calib_range;  // max float value calculated from calibration dataset
   dmlc::optional<bool> channel_wise_quantize;
@@ -52,6 +55,8 @@ struct MKLDNNFCParam: public dmlc::Parameter<MKLDNNFCParam> {
     .describe("Whether to enable float32 output");
     DMLC_DECLARE_FIELD(with_eltwise).set_default(false)
     .describe("Whether there's a post with_eltwise after FullyConnected operator");
+    DMLC_DECLARE_FIELD(with_sum).set_default(false)
+    .describe("Add post sum");
     DMLC_DECLARE_FIELD(min_calib_range)
     .set_default(dmlc::optional<float>())
     .describe("The minimum scalar value in the form of float32 obtained "
diff --git a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
index 1cf9e22..b88eeab 100644
--- a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
+++ b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
@@ -25,6 +25,7 @@
 */
 
 #if MXNET_USE_MKLDNN == 1
+#include <unordered_map>
 #include "mkldnn_fully_connected-inl.h"
 
 namespace mxnet {
@@ -51,6 +52,9 @@ mkldnn::inner_product_forward::primitive_desc GetFCFwdImpl(
                        full_param.eltwise_param.alpha,
                        full_param.eltwise_param.beta);
   }
+  if (full_param.mkldnn_param.with_sum) {
+    ops.append_sum(full_param.mkldnn_param.sum_scale);
+  }
   attr.set_post_ops(ops);
 
   if (full_param.mkldnn_param.quantized && full_param.output_scales.size()) {
diff --git a/src/operator/quantization/quantize_graph_pass.cc b/src/operator/quantization/quantize_graph_pass.cc
index f5060c9..74da6e9 100644
--- a/src/operator/quantization/quantize_graph_pass.cc
+++ b/src/operator/quantization/quantize_graph_pass.cc
@@ -170,6 +170,7 @@ inline QuantizeType NeedQuantize(ObjectPtr node,
         if ((quantize_granularity == "channel-wise") &&
             (node->op() == Op::Get("_sg_mkldnn_fully_connected"))) {
           quantized_node->attrs.dict["channel_wise_quantize"] = "True";
+          quantized_node->op()->attr_parser(&(quantized_node->attrs));
         }
         quantized_node_map->insert(std::make_pair(node, quantized_node));
       }
diff --git a/src/operator/subgraph/mkldnn/mkldnn_fc.cc b/src/operator/subgraph/mkldnn/mkldnn_fc.cc
index 0eff06a..0726161 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_fc.cc
+++ b/src/operator/subgraph/mkldnn/mkldnn_fc.cc
@@ -45,6 +45,73 @@
 namespace mxnet {
 namespace op {
 
+
+static inline size_t GetInSumIndex(const MKLDNNFCFullParam &param) {  // position of sum input
+  CHECK(param.mkldnn_param.with_sum);  // unlike assert, not compiled out under NDEBUG
+  return fullc::kWeight + 1 + (param.default_param.no_bias ? 0 : 1);  // after weight [, bias]
+}
+
+
+class FCInputIndex {  // input layout of _sg_mkldnn_fully_connected
+ public:
+  explicit FCInputIndex(const MKLDNNFCFullParam &full_param) {
+    const auto &mkldnn_param = full_param.mkldnn_param;
+    const bool has_bias = !full_param.default_param.no_bias;
+    const bool quantized = mkldnn_param.quantized;
+    const bool sum_input_quantized = quantized && mkldnn_param.with_sum &&
+                                     !mkldnn_param.enable_float_output;
+    const bool channel_wise = quantized && mkldnn_param.channel_wise_quantize.has_value() &&
+                              mkldnn_param.channel_wise_quantize.value();
+
+    // Assign each input its position in the input vector (0 for absent optional inputs):
+    int index     = 0;
+    data          = index++;
+    weight        = index++;
+    bias          = has_bias ? index++ : 0;
+    num_quantized = index + (sum_input_quantized ? 1 : 0);  // quantized sum is integer too
+    sum           = mkldnn_param.with_sum ? index++ : 0;
+    num_base      = index;  // data, weight, [bias], [sum]
+
+    data_min      = quantized ? index++ : 0;
+    data_max      = quantized ? index++ : 0;
+    weight_min    = (quantized && !channel_wise) ? index++ : 0;
+    weight_max    = (quantized && !channel_wise) ? index++ : 0;
+    bias_min      = (quantized && !channel_wise && has_bias) ? index++ : 0;
+    bias_max      = (quantized && !channel_wise && has_bias) ? index++ : 0;
+    sum_min       = sum_input_quantized ? index++ : 0;
+    sum_max       = sum_input_quantized ? index++ : 0;
+    num_total     = index;
+  }
+
+  // True if a sum input is fused and it is not quantized (no sum min/max inputs).
+  bool IsSumInputFloat() const { return (sum && !sum_min); }
+  int GetTotal() const { return num_total; }
+  int GetBase() const { return num_base; }
+
+  // Number of leading standard inputs that are quantized (integer typed).
+  int GetQuantized() const { return num_quantized; }
+
+  // Position of each input in the input vector (0 for absent optional inputs):
+  int data;
+  int weight;
+  int bias;
+  int sum;
+  int data_min;
+  int data_max;
+  int weight_min;
+  int weight_max;
+  int bias_min;
+  int bias_max;
+  int sum_min;
+  int sum_max;
+
+ private:
+  int num_base;       // Number of standard inputs
+  int num_total;      // Number of total inputs: standard + additional needed for quantization
+  int num_quantized;  // Number of standard inputs which are quantized
+};
+
+
 class SgMKLDNNFCOp {
  public:
   explicit SgMKLDNNFCOp(const nnvm::NodeAttrs &attrs)
@@ -66,8 +133,8 @@ class SgMKLDNNFCOp {
 
  private:
   bool initialized_{false};
-  bool channel_wise_runtime_{false};
   bool reorder_data_{false};
+  bool inplace_{false};
   nnvm::Symbol subgraph_sym_;
   MKLDNNFCFullParam full_param_;
   mkldnn_args_map_t args_;
@@ -80,6 +147,8 @@ class SgMKLDNNFCOp {
   float cached_max_data_;
   float cached_min_weight_;
   float cached_max_weight_;
+  float cached_sum_min_;
+  float cached_sum_max_;
   float cached_min_bias_;
   float cached_max_bias_;
   size_t weight_ver_;
@@ -88,71 +157,109 @@ class SgMKLDNNFCOp {
   float cached_max_output_;
   float data_scale_{0.0f};
   std::vector<float> weight_scales_;
-  size_t total_num_inputs_;
-  size_t total_num_outputs_;
 };
 
 void SgMKLDNNFCOp::Forward(const OpContext &ctx,
                            const std::vector<NDArray> &in_data,
                            const std::vector<OpReqType> &req,
                            const std::vector<NDArray> &out_data) {
-  auto &mkldnn_param = full_param_.mkldnn_param;
-  auto &default_param = full_param_.default_param;
-  bool has_bias = !default_param.no_bias;
-  size_t base_num_inputs = has_bias ? 3 : 2;
-  size_t base_num_outputs = 1;
-
-  float min_data = 0.0f;
-  float max_data = 0.0f;
-  float min_weight = 0.0f;
-  float max_weight = 0.0f;
-  float min_bias = 0.0f;
-  float max_bias = 0.0f;
-
-  if (!initialized_) {
-    if (mkldnn_param.channel_wise_quantize.has_value() &&
-        mkldnn_param.channel_wise_quantize) {
-      channel_wise_runtime_ = true;
+  auto &mkldnn_param        = full_param_.mkldnn_param;
+  auto &default_param       = full_param_.default_param;
+  const bool has_bias       = !default_param.no_bias;
+  const bool quantized      = mkldnn_param.quantized;
+  const bool out_quantized  = mkldnn_param.quantized && !mkldnn_param.enable_float_output;
+  const bool channel_wise   = quantized && mkldnn_param.channel_wise_quantize.has_value() &&
+                              mkldnn_param.channel_wise_quantize.value();
+
+  const FCInputIndex idx(full_param_);
+
+  CHECK_EQ(in_data.size(), idx.GetTotal());
+
+  int index = 0;
+  const int out_index = index++;
+  const int out_min_index = out_quantized ? index++ : 0;
+  const int out_max_index = out_quantized ? index++ : 0;
+  CHECK_EQ(out_data.size(), index);   // index is equal to total number of outpits
+
+  float min_data    = 0.0f;
+  float max_data    = 0.0f;
+  float min_weight  = 0.0f;
+  float max_weight  = 0.0f;
+  float min_bias    = 0.0f;
+  float max_bias    = 0.0f;
+
+  const float sum_min = idx.sum_min ? in_data[idx.sum_min].data().dptr<float>()[0] : 0.0;
+  const float sum_max = idx.sum_max ? in_data[idx.sum_max].data().dptr<float>()[0] : 0.0;
+
+  NDArray data = in_data[idx.data];
+  const NDArray &weight = in_data[idx.weight];
+  NDArray output;
+
+  if (mkldnn_param.with_sum) {
+    if (!initialized_) {
+      // TODO(zhennan): Currently, mkldnn fallback mechanism will break inplace option,
+      // which make check (req[out_index] == kWriteInplace) useless.
+      auto in_mkl_mem = in_data[idx.sum].GetMKLDNNData();
+      auto out_mkl_mem = out_data[out_index].GetMKLDNNData();
+      if (in_mkl_mem->get_data_handle() == out_mkl_mem->get_data_handle()) {
+        inplace_ = true;
+      }
     }
-
-    total_num_inputs_ = base_num_inputs;
-    total_num_outputs_ = base_num_outputs;
-    if (mkldnn_param.quantized) {
-      total_num_inputs_ = channel_wise_runtime_ ? (base_num_inputs + 2) : (base_num_inputs * 3);
-      total_num_outputs_ =
-        mkldnn_param.enable_float_output ? base_num_outputs : (base_num_outputs * 3);
+    if (inplace_) {
+      output = in_data[idx.sum];
+    } else {
+      // Not in place: copy in_data[idx.sum] into outputs[out_index].
+      auto in_mkl_mem = in_data[idx.sum].GetMKLDNNData();
+      auto out_mkl_mem = out_data[out_index].GetMKLDNNData();
+      if (out_data[out_index].dtype() == mshadow::kInt32) {
+        auto mem_desc = in_mkl_mem->get_desc();
+        auto this_dtype = get_mkldnn_type(mshadow::kInt32);
+        mem_desc.data.data_type = static_cast<mkldnn_data_type_t>(this_dtype);
+        mkldnn_mem_ptr tmp_mem(new mkldnn::memory(mem_desc, CpuEngine::Get()->get_engine(),
+                                                  out_mkl_mem->get_data_handle()));
+        MKLDNNStream::Get()->RegisterMem(tmp_mem);
+        MKLDNNStream::Get()->RegisterPrimArgs(
+            mkldnn::reorder(*in_mkl_mem, *tmp_mem),
+            {{MKLDNN_ARG_FROM, *in_mkl_mem}, {MKLDNN_ARG_TO, *tmp_mem}});
+        output = NDArray(tmp_mem);
+      } else {
+        mkldnn_mem_ptr tmp_mem(new mkldnn::memory(in_mkl_mem->get_desc(),
+                                                  CpuEngine::Get()->get_engine(),
+                                                  out_mkl_mem->get_data_handle()));
+        MKLDNNStream::Get()->RegisterMem(tmp_mem);
+        MKLDNNMemoryCopy(*in_mkl_mem, tmp_mem.get());
+        output = NDArray(tmp_mem);
+      }
     }
+  } else {
+    output = out_data[out_index];
   }
-  CHECK_EQ(in_data.size(), total_num_inputs_);
-  CHECK_EQ(out_data.size(), total_num_outputs_);
-
-  NDArray data = in_data[fullc::kData];
-  const NDArray &weight = in_data[fullc::kWeight];
-  const NDArray &output = out_data[fullc::kOut];
 
   if (mkldnn_param.quantized) {
-    if (!channel_wise_runtime_) {
-      min_weight = in_data[base_num_inputs + quantized_fullc::kWeightMin].data().dptr<float>()[0];
-      max_weight = in_data[base_num_inputs + quantized_fullc::kWeightMax].data().dptr<float>()[0];
+    if (!channel_wise) {
+      min_weight = in_data[idx.weight_min].data().dptr<float>()[0];
+      max_weight = in_data[idx.weight_max].data().dptr<float>()[0];
       if (has_bias) {
-        min_bias = in_data[base_num_inputs + quantized_fullc::kBiasMin].data().dptr<float>()[0];
-        max_bias = in_data[base_num_inputs + quantized_fullc::kBiasMax].data().dptr<float>()[0];
+        min_bias = in_data[idx.bias_min].data().dptr<float>()[0];
+        max_bias = in_data[idx.bias_max].data().dptr<float>()[0];
       }
     }
-    min_data = in_data[base_num_inputs + quantized_fullc::kDataMin].data().dptr<float>()[0];
-    max_data = in_data[base_num_inputs + quantized_fullc::kDataMax].data().dptr<float>()[0];
+    min_data = in_data[idx.data_min].data().dptr<float>()[0];
+    max_data = in_data[idx.data_max].data().dptr<float>()[0];
   }
 
   if (initialized_ && mkldnn_param.quantized &&
       dmlc::GetEnv("MXNET_MKLDNN_QFC_DYNAMIC_PARAMS", 0)) {
-    if (channel_wise_runtime_) {
+    if (channel_wise) {
       if (cached_min_data_ != min_data || cached_max_data_ != max_data ||
+          cached_sum_min_ != sum_min || cached_sum_max_ != sum_max ||
           weight_ver_ != weight.version() ||
-          (has_bias && (bias_ver_ != in_data[fullc::kBias].version()))) {
+          (has_bias && (bias_ver_ != in_data[idx.bias].version()))) {
         initialized_ = false;
       }
     } else {
       if (cached_min_data_ != min_data || cached_max_data_ != max_data ||
+          cached_sum_min_ != sum_min || cached_sum_max_ != sum_max ||
           cached_min_weight_ != min_weight || cached_max_weight_ != max_weight ||
           (has_bias && (cached_min_bias_ != min_bias || cached_max_bias_ != max_bias))) {
         initialized_ = false;
@@ -169,11 +276,13 @@ void SgMKLDNNFCOp::Forward(const OpContext &ctx,
     cached_max_weight_ = max_weight;
     weight_ver_ = weight.version();
     cached_weight_ = weight;
+    cached_sum_min_ = sum_min;
+    cached_sum_max_ = sum_max;
     if (has_bias) {
       cached_min_bias_ = min_bias;
       cached_max_bias_ = max_bias;
-      bias_ver_ = in_data[fullc::kBias].version();
-      cached_bias_ = in_data[fullc::kBias];
+      bias_ver_ = in_data[idx.bias].version();
+      cached_bias_ = in_data[idx.bias];
     } else {
       cached_bias_ = NDArray();
     }
@@ -232,7 +341,7 @@ void SgMKLDNNFCOp::Forward(const OpContext &ctx,
       // True          True                       True
       // True          False                      Error
       // False         True/False                 False
-      if (channel_wise_runtime_ && !support_channelwise_scale) {
+      if (channel_wise && !support_channelwise_scale) {
         LOG(FATAL)
           << "Currently, channel-wise quantization requires fuse requantize or dequantize."
           << " Please make sure the `min_calib_range` and `max_calib_range` are set when only"
@@ -240,7 +349,7 @@ void SgMKLDNNFCOp::Forward(const OpContext &ctx,
           << " or the env var of `MXNET_DISABLE_MKLDNN_QFC_FLOAT_OUTPUT` and "
           << " `MXNET_DISABLE_MKLDNN_QFC_FUSE_ALL` are not set to true (default is false)";
       }
-      support_channelwise_scale = support_channelwise_scale && channel_wise_runtime_;
+      support_channelwise_scale = support_channelwise_scale && channel_wise;
 
       if (support_channelwise_scale) {
         MSHADOW_REAL_TYPE_SWITCH(cached_weight_.dtype(), DType, {
@@ -290,6 +399,7 @@ void SgMKLDNNFCOp::Forward(const OpContext &ctx,
       }
 
       size_t num_channel = cached_weight_.shape()[0];
+      float out_scale = 1.0f;
       if (fuse_requantize || mkldnn_param.enable_float_output) {
         float tmp_scale_ = 1.0f;
         if (fuse_requantize) {
@@ -298,10 +408,8 @@ void SgMKLDNNFCOp::Forward(const OpContext &ctx,
             full_param_.eltwise_param.scale =
               GetQuantizeScale(output.dtype(), cached_min_output_, cached_max_output_);
           } else {
-            tmp_scale_ =
-              GetQuantizeScale(output.dtype(),
-                               cached_min_output_,
-                               cached_max_output_) / data_scale_;
+          out_scale = GetQuantizeScale(output.dtype(), cached_min_output_, cached_max_output_);
+          tmp_scale_ = out_scale / data_scale_;
           }
         } else {
           tmp_scale_ = 1.0 / data_scale_;
@@ -329,8 +437,15 @@ void SgMKLDNNFCOp::Forward(const OpContext &ctx,
               &max_weight);
         }
         full_param_.output_scales.resize(0);
+        out_scale = data_scale_ * weight_scales_[0];
       }
-    }
+
+      if (mkldnn_param.with_sum && !mkldnn_param.enable_float_output) {
+        float sum_in_scale =
+          GetQuantizeScale(in_data[idx.sum].dtype(), cached_sum_min_, cached_sum_max_);
+        mkldnn_param.sum_scale = out_scale / sum_in_scale;
+      }
+    }   // if (mkldnn_param.quantized)
 
     fwd_.reset(new MKLDNNFullyConnectedForward(full_param_, ctx.is_train, data, cached_weight_,
       (has_bias ? &cached_bias_ : nullptr), out_md));
@@ -367,6 +482,21 @@ void SgMKLDNNFCOp::Forward(const OpContext &ctx,
     initialized_ = true;
   }
 
+  if (mkldnn_param.with_sum) {
+    const auto& output_mem = output.GetMKLDNNData();
+    const auto& out_mem_desc = output_mem->get_desc();
+    auto dst_mem_desc = fwd_->fwd_pd.dst_desc();
+    if (out_mem_desc != dst_mem_desc) {
+      auto tmp_out_mem = output.GetMKLDNNDataReorder(dst_mem_desc);
+      dst_mem_desc.data.data_type = out_mem_desc.data.data_type;
+      mkldnn_mem_ptr new_out_mem(new mkldnn::memory(dst_mem_desc, CpuEngine::Get()->get_engine(),
+                                                    output_mem->get_data_handle()));
+      MKLDNNStream::Get()->RegisterMem(new_out_mem);
+      MKLDNNMemoryCopy(*tmp_out_mem, new_out_mem.get());
+      output = NDArray(new_out_mem);
+    }
+  }
+
   if (reorder_data_) {
     data = data.Reorder2Default();
   }
@@ -380,8 +510,8 @@ void SgMKLDNNFCOp::Forward(const OpContext &ctx,
   MKLDNNStream::Get()->Submit();
 
   if (mkldnn_param.quantized && !mkldnn_param.enable_float_output) {
-    float *min_output_ptr = out_data[quantized_fullc::kOutMin].data().dptr<float>();
-    float *max_output_ptr = out_data[quantized_fullc::kOutMax].data().dptr<float>();
+    float *min_output_ptr = out_data[out_min_index].data().dptr<float>();
+    float *max_output_ptr = out_data[out_max_index].data().dptr<float>();
     *min_output_ptr = cached_min_output_;
     *max_output_ptr = cached_max_output_;
   }
@@ -441,6 +571,10 @@ static std::vector<std::string> SgMKLDNNFCListInputNames(const NodeAttrs &attrs)
   std::vector<std::string> input_names = DefaultSubgraphOpListInputs(attrs);
   if (full_param.mkldnn_param.quantized) {
     bool channel_wise = false;
+    if (full_param.mkldnn_param.with_sum) {
+      input_names.emplace_back("sum");
+    }
+
     if (full_param.mkldnn_param.channel_wise_quantize.has_value() &&
         full_param.mkldnn_param.channel_wise_quantize) {
       channel_wise = true;
@@ -472,12 +606,12 @@ static std::vector<std::string> SgMKLDNNFCListOutputNames(const NodeAttrs &attrs
 }
 
 template <typename T>
-static inline void FillBaseInputOutputInfo(const FullyConnectedParam &param,
+static inline void FillBaseInputOutputInfo(const MKLDNNFCFullParam &param,
                                            std::vector<T> *base_in_attrs,
                                            std::vector<T> *base_out_attrs,
                                            std::vector<T> *in_attrs,
                                            std::vector<T> *out_attrs) {
-  auto base_num_inputs = param.no_bias ? 2 : 3;
+  auto base_num_inputs = FCInputIndex(param).GetBase();
 
   base_out_attrs->push_back(out_attrs->at(0));
   for (int i = 0; i < base_num_inputs; ++i) {
@@ -492,7 +626,7 @@ static bool SgMKLDNNFCInferShape(const nnvm::NodeAttrs &attrs,
   if (full_param.mkldnn_param.quantized) {
     mxnet::ShapeVector base_in_shapes;
     mxnet::ShapeVector base_out_shapes;
-    FillBaseInputOutputInfo(full_param.default_param, &base_in_shapes, &base_out_shapes,
+    FillBaseInputOutputInfo(full_param, &base_in_shapes, &base_out_shapes,
                             in_shapes, out_shapes);
     bool ret = DefaultSubgraphOpShape(attrs, &base_in_shapes, &base_out_shapes);
 
@@ -524,7 +658,7 @@ static bool SgMKLDNNFCInferType(const nnvm::NodeAttrs &attrs,
         full_param.mkldnn_param.channel_wise_quantize) {
       channel_wise = true;
     }
-    size_t base_num_inputs = full_param.default_param.no_bias ? 2 : 3;
+    size_t num_integer_inputs = FCInputIndex(full_param).GetQuantized();
     CHECK(in_types->at(0) == mshadow::kInt8 ||
           in_types->at(0) == mshadow::kUint8)
         << "QuantizedFullyConnected only supports int8/uint8 input, while "
@@ -546,7 +680,7 @@ static bool SgMKLDNNFCInferType(const nnvm::NodeAttrs &attrs,
                 << in_types->at(2) << " is given.";
           }
         }
-        for (size_t i = base_num_inputs; i < in_types->size(); ++i) {
+        for (size_t i = num_integer_inputs; i < in_types->size(); ++i) {
           TYPE_ASSIGN_CHECK(*in_types, i, mshadow::kFloat32);
         }
     }
@@ -582,7 +716,7 @@ static bool SgMKLDNNFCStorageType(const nnvm::NodeAttrs &attrs,
   if (full_param.mkldnn_param.quantized) {
     std::vector<int> base_in_attrs;
     std::vector<int> base_out_attrs;
-    FillBaseInputOutputInfo(full_param.default_param, &base_in_attrs, &base_out_attrs,
+    FillBaseInputOutputInfo(full_param, &base_in_attrs, &base_out_attrs,
                             in_attrs, out_attrs);
     bool ret = DefaultSubgraphOpStorageType(attrs, dev_mask, dispatch_mode,
                                             &base_in_attrs, &base_out_attrs);
@@ -606,6 +740,18 @@ static bool SgMKLDNNFCStorageType(const nnvm::NodeAttrs &attrs,
   }
 }
 
+
+static std::vector<std::pair<int, int>> SgMKLDNNFCInplaceOption(
+    const NodeAttrs &attrs) {  // pairs the fused sum input with output 0
+  auto const &param = nnvm::get<MKLDNNFCFullParam>(attrs.parsed);
+  if (param.mkldnn_param.with_sum) {
+    return std::vector<std::pair<int, int>>{{FCInputIndex(param).sum, 0}};
+  } else {
+    return std::vector<std::pair<int, int>>();  // no in-place pair without a fused sum
+  }
+}
+
+
 static OpStatePtr CreateSgMKLDNNFCState(const nnvm::NodeAttrs &attrs,
                                         Context ctx,
                                         const mxnet::ShapeVector &in_shapes,
@@ -640,13 +786,16 @@ static bool SgMKLDNNAvoidFCQuantizeInput(const NodeAttrs& attrs, const size_t in
                                          const std::string quantize_granularity) {
   auto const &full_param = nnvm::get<MKLDNNFCFullParam>(attrs.parsed);
   std::unordered_set<size_t> avoid_indexes;
+  FCInputIndex idx(full_param);
   if (quantize_granularity == "channel-wise") {
     avoid_indexes.insert(fullc::kWeight);   // weight
     if (!full_param.default_param.no_bias) {
       avoid_indexes.insert(fullc::kBias);   // bias
     }
   }
-
+  if (idx.IsSumInputFloat()) {
+      avoid_indexes.insert(idx.sum);
+  }
   return avoid_indexes.count(index_to_check);
 }
 
@@ -654,17 +803,7 @@ NNVM_REGISTER_OP(_sg_mkldnn_fully_connected)
 .describe(R"code(_sg_mkldnn_fully_connected)code" ADD_FILELINE)
 .set_num_inputs([](const NodeAttrs& attrs) {
   auto const &full_param = nnvm::get<MKLDNNFCFullParam>(attrs.parsed);
-  auto num_inputs = full_param.default_param.no_bias ? 2 : 3;
-  if (full_param.mkldnn_param.quantized) {
-    if (full_param.mkldnn_param.channel_wise_quantize.has_value() &&
-        full_param.mkldnn_param.channel_wise_quantize) {
-      return num_inputs + 2;  // min_data, max_data
-    } else {
-      return num_inputs * 3;
-    }
-  } else {
-    return num_inputs;
-  }
+  return FCInputIndex(full_param).GetTotal();
 })
 .set_num_outputs([](const NodeAttrs& attrs) {
   auto const &full_param = nnvm::get<MKLDNNFCFullParam>(attrs.parsed);
@@ -689,6 +828,7 @@ NNVM_REGISTER_OP(_sg_mkldnn_fully_connected)
 .set_attr<nnvm::FMutateInputs>("FMutateInputs",
                                DefaultSubgraphOpMutableInputs)
 .set_attr<std::string>("key_var_num_args", "num_args")
+.set_attr<nnvm::FInplaceOption>("FInplaceOption", SgMKLDNNFCInplaceOption)
 .set_attr<FQuantizable>("FQuantizable", [](const NodeAttrs& attrs) {
     return QuantizeType::kMust;
 })
diff --git a/src/operator/subgraph/mkldnn/mkldnn_fc_sum_fuse.h b/src/operator/subgraph/mkldnn/mkldnn_fc_sum_fuse.h
new file mode 100755
index 0000000..3103859
--- /dev/null
+++ b/src/operator/subgraph/mkldnn/mkldnn_fc_sum_fuse.h
@@ -0,0 +1,241 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*
+  \brief It fuses FC + SUM for floating point output in second post quantization pass
+*/
+
+#ifndef MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_FC_SUM_FUSE_H_
+#define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_FC_SUM_FUSE_H_
+#if MXNET_USE_MKLDNN == 1
+
+#include <string>
+#include <vector>
+#include <memory>
+#include <unordered_set>
+#include <utility>
+#include "../common.h"
+#include "../../tensor/matrix_op-inl.h"
+#include "mkldnn_subgraph_base-inl.h"
+#include "mkldnn_fc-inl.h"
+
+namespace mxnet {
+namespace op {
+
+class SgMKLDNNFCSumFuseSelector : public SubgraphSelector {  // quantized FC -> elemwise_add
+ private:
+  /*! \brief pattern match status */
+  enum SelectStatus {
+    kFail = 0,
+    kStart,
+    kSuccess,
+  };
+
+ private:
+  bool quantized_;
+  SelectStatus status_{kFail};  // kStart once Select() accepts an FC node
+  std::vector<const nnvm::Node *> matched_list_;
+
+ public:
+  explicit SgMKLDNNFCSumFuseSelector(bool quantized) :
+      quantized_(quantized) {}
+
+  bool Select(const nnvm::Node &n, const std::shared_ptr<NodeAttr>& node_attr) override {
+    if (n.op() == Op::Get("_sg_mkldnn_fully_connected") && SupportMKLDNNAttr(node_attr)) {
+      auto const &fc_param = nnvm::get<MKLDNNFCFullParam>(n.attrs.parsed);
+      // TODO(anko) remove fc_param.mkldnn_param.quantized from if below
+      //            to fuse even for not quantized?
+      if (fc_param.mkldnn_param.enable_float_output && fc_param.mkldnn_param.quantized) {
+        status_ = kStart;
+        matched_list_.clear();
+        matched_list_.push_back(&n);
+        return true;
+      }
+    }
+    return false;
+  }
+
+  bool SelectInput(const nnvm::Node &n, const nnvm::Node &new_node) override {
+    return false;  // only consumers of the FC are matched
+  }
+
+  bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) override {
+    if (status_ == kFail || status_ == kSuccess || new_node.is_variable()) {
+      return false;
+    }
+    // If n isn't the last matched node, we encountered an internal branch:
+    // pop the nodes matched after n and stop fusion.
+    if (matched_list_.back() != &n) {
+      if (std::find(matched_list_.begin(), matched_list_.end(), &n) !=
+        matched_list_.end()) {
+        while (matched_list_.back() != &n) {
+          matched_list_.pop_back();
+        }
+      }
+      status_ = kSuccess;
+      return false;
+    }
+
+    switch (status_) {
+      case kStart:
+        if (new_node.op()->name == "elemwise_add") {
+          matched_list_.push_back(&new_node);
+          status_ = kSuccess;
+          return true;
+        }  // fall through - not an elemwise_add, finish matching
+      default:
+        status_ = kSuccess;
+        return false;
+    }
+  }
+
+  std::vector<nnvm::Node *> Filter(
+      const std::vector<nnvm::Node *> &candidates) override {
+    if (status_ == kFail) {
+      return std::vector<nnvm::Node *>(0);
+    } else {
+      std::vector<nnvm::Node *> ret;  // matched nodes that are still candidates
+      for (auto i : matched_list_) {
+        auto non_const_i = const_cast<nnvm::Node *>(i);
+        if (std::find(candidates.begin(), candidates.end(), non_const_i) !=
+            candidates.end()) {
+          ret.push_back(non_const_i);
+        }
+      }
+      return ret;
+    }
+  }
+
+  void Reset() override {
+    CHECK_GE(matched_list_.size(), 1);
+    auto new_selector = SgMKLDNNFCSumFuseSelector(quantized_);
+    new_selector.Select(*matched_list_[0], nullptr);
+    *this = new_selector;
+  }
+};
+
+class SgMKLDNNFCSumFuseProperty : public SubgraphProperty {  // folds elemwise_add into FC
+ public:
+  SgMKLDNNFCSumFuseProperty() {}
+
+  static SubgraphPropertyPtr Create() {
+    static const std::string &name = "MKLDNN FullyConnected post quantization second pass";
+    auto property = std::make_shared<SgMKLDNNFCSumFuseProperty>();
+    property->SetAttr<std::string>("property_name", name);
+    property->SetAttr<bool>("inference_only", true);
+    if (dmlc::GetEnv("MXNET_DISABLE_MKLDNN_FC_SUM", 0)) {
+      property->SetAttr<bool>("disable", true);
+    }
+    return property;
+  }
+
+  nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol &sym,
+                                     const int subgraph_id = 0) const override {
+    nnvm::ObjectPtr fc_node = nullptr;
+    nnvm::ObjectPtr ew_add_node = nullptr;
+
+    DFSVisit(sym.outputs, [&](const nnvm::ObjectPtr &node) {
+      if (node->is_variable()) return;
+      auto &sub_name = node->op()->name;
+      if (sub_name == "_sg_mkldnn_fully_connected") {
+        fc_node = node;
+      } else if (sub_name == "elemwise_add") {
+        ew_add_node = node;
+      }
+    });
+
+    CHECK_NOTNULL(fc_node);
+    if (ew_add_node != nullptr) {
+      CHECK(!fc_node->attrs.subgraphs.empty() && fc_node->attrs.subgraphs[0]);
+      auto fc_original = fc_node->attrs.subgraphs[0]->outputs[0].node;
+      if (fc_original->op() == Op::Get("FullyConnected")) {
+        nnvm::Symbol new_sym;
+        nnvm::NodeEntry &ew_input_with_fc = (ew_add_node->inputs[1].node == fc_node) ?
+                                        ew_add_node->inputs[1] :
+                                        ew_add_node->inputs[0];
+        ew_input_with_fc.node = fc_original;  // feed the inner FullyConnected into the add
+        new_sym.outputs.emplace_back(ew_add_node);
+        fc_node->attrs.subgraphs.clear();
+        fc_node->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>(new_sym));
+        fc_node->attrs.dict["with_sum"] = "True";
+        fc_node->op()->attr_parser(&(fc_node->attrs));
+      }
+    }
+    return fc_node;
+  }
+
+  SubgraphSelectorPtr CreateSubgraphSelector() const override {
+    bool quantized = HasAttr("quantize") ? GetAttr<bool>("quantize") : false;
+    auto selector =
+      std::make_shared<SgMKLDNNFCSumFuseSelector>(quantized);
+    return selector;
+  }
+
+  void ConnectSubgraphOutputs(
+      const nnvm::ObjectPtr n,
+      std::vector<nnvm::NodeEntry *> *output_entries) const override {
+    // Redirect every external output entry to n, keeping its output index.
+    for (size_t i = 0; i < output_entries->size(); ++i) {
+      auto entry_ptr = output_entries->at(i);
+      *entry_ptr = nnvm::NodeEntry{n, entry_ptr->index, 0};
+    }
+  }
+
+  void ConnectSubgraphInputs(
+      const nnvm::ObjectPtr n, std::vector<nnvm::NodeEntry *> *input_entries,
+      std::vector<nnvm::NodeEntry> *orig_input_entries) const override {
+    auto sym = n->attrs.subgraphs[0];
+    auto const &fc_param = nnvm::get<MKLDNNFCFullParam>(n->attrs.parsed);
+    std::unordered_set<const nnvm::Node *> node_sets;
+    DFSVisit(sym->outputs, [&](const nnvm::ObjectPtr &node) {
+        if (node->is_variable()) return;
+        node_sets.insert(node.get());
+        if (node->op()->name == "elemwise_add") {
+          const size_t base_inputs = fc_param.default_param.no_bias ? 3 : 4;  // incl. sum
+
+          // Make sure the FC output is the left operand of the sum (swap the
+          // operands if not), then move the extra sum operand to position
+          // base_inputs - 1, i.e. right after data, weight and optional bias.
+          if (node_sets.count(node->inputs[1].node.get())) {
+            std::swap(node->inputs[0], node->inputs[1]);
+            std::rotate(input_entries->begin(),
+                        input_entries->begin() + 1,
+                        input_entries->begin() + base_inputs);
+            std::rotate(orig_input_entries->begin(),
+                        orig_input_entries->begin() + 1,
+                        orig_input_entries->begin() + base_inputs);
+          } else {
+            std::rotate(input_entries->begin() + base_inputs - 1,
+                        input_entries->end() - 1,
+                        input_entries->end());
+            std::rotate(orig_input_entries->begin() + base_inputs - 1,
+                        orig_input_entries->end() - 1,
+                        orig_input_entries->end());
+          }
+        }
+      });
+    n->inputs = *orig_input_entries;
+  }
+};
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // if MXNET_USE_MKLDNN == 1
+#endif  // MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_FC_SUM_FUSE_H_
diff --git a/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc b/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc
index 9190ba4..5283189 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc
+++ b/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc
@@ -23,6 +23,7 @@
 #include "mkldnn_fc_property.h"
 #include "mkldnn_post_quantize_property.h"
 #include "mkldnn_fc_post_quantize_property.h"
+#include "mkldnn_fc_sum_fuse.h"
 #include "mkldnn_elemwisemul_post_quantize_property.h"
 #include "mkldnn_post_quantize_align_scale_property.h"
 #include "mkldnn_transformer_property.h"
@@ -60,6 +61,9 @@ MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNFCPostQuantizeProperty
 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, ElemwiseMulPostQuantizeProperty);
 
 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNPostQuantizeAlignScaleProperty);
+// FC + sum fusion: folds a floating-point elemwise_add input into a preceding
+// quantized FullyConnected whose output type is float, using the oneDNN `sum`
+// post-op. It is a separate pass meant to run after quantization, which is why it
+// is registered after the other MKLDNN_QUANTIZE properties.
+// NOTE(review): "quantize" presumably restricts this property to the quantization
+// flow, and registration order presumably fixes pass order -- confirm both against
+// the subgraph property registry.
+MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNFCSumFuseProperty)
+.set_attr("quantize", true);
+
 }  // namespace op
 }  // namespace mxnet
 #endif  // MXNET_USE_MKLDNN == 1
diff --git a/tests/python/mkl/test_subgraph.py b/tests/python/mkl/test_subgraph.py
index 811b006..1210c05 100644
--- a/tests/python/mkl/test_subgraph.py
+++ b/tests/python/mkl/test_subgraph.py
@@ -398,6 +398,18 @@ def conv_add2(no_bias, data_shape):
   sum = pool + conv1
   return sum, attr
 
+
+# fc + sum
+def fc_sum(no_bias, data_shape):
+  attr = {'fc': {'with_sum': 'true', 'quantized' : 'true', 'enable_float_output': 'true'}}
+  data, weight = head_symbol(data_shape)
+  sym1 = mx.symbol.FullyConnected(data=data, weight=weight, no_bias=no_bias, num_hidden=10)
+  data2 = mx.symbol.var('data_2', shape= (data_shape[0], 10), dtype="float32", init = mx.init.Normal(0.3))
+  sum = mx.symbol.elemwise_add(data2, sym1)
+  inputs = [('data', data_shape), ('data_2', (data_shape[0], 10))]
+  return sum, attr, inputs
+
+
 # conv + bn + act fusion case
 def conv_bn_act(no_bias, data_shape, alg):
   attr = {'conv': {'with_bn': 'true', 'with_act': 'true'}}
@@ -788,6 +800,109 @@ def test_pos_conv_bn_act():
       net, attrs = conv_bn_act(True, data_shape, alg)
       check_fusion(net, data_shape, attrs, check_quantization=quantize)
 
+
+@with_seed()
+def test_pos_fc_sum():
+  def check_fusion_parameter(sym_sg, attrs_dict):
+    for name, attrs in attrs_dict.items():
+      if name in config:
+        op_name = config[name][OP_NAME]
+      else:
+        op_name = name
+      assert ''.join(sym_sg.get_internals().list_outputs()).find(op_name) != -1
+      if len(attrs):
+          found = False
+          for k, v in sym_sg.attr_dict().items():
+            if k.find(op_name) != -1:
+              found = True
+              for attr_name, attr_value in attrs.items():
+                assert v[attr_name].lower() == attr_value.lower()
+          assert found
+
+  def check_qsym_dummy_forward_fc_sum(qsym, batch, data_shapes):
+    data_names = list(i[0] for i in data_shapes)
+    mod = Module(symbol=qsym, data_names=data_names, label_names=None, context=mx.current_context())
+    mod.bind(for_training=False, data_shapes=data_shapes)
+    mod.init_params(initializer=mx.init.Xavier(magnitude=2.))
+    mod.forward(batch, is_train=False)
+    for output in mod.get_outputs():
+      output.wait_to_read()
+    return mod.get_outputs()
+
+  def check_qsym_forward_fc_sum(qsym, qarg_params, qaux_params, batch, data_shapes):
+    data_names = list(i[0] for i in data_shapes)
+    mod = Module(symbol=qsym, data_names=data_names, label_names=None, context=mx.current_context())
+    mod.bind(for_training=False,
+            data_shapes=data_shapes)
+    mod.set_params(qarg_params, qaux_params)
+    mod.forward(batch, is_train=False)
+    for output in mod.get_outputs():
+      output.wait_to_read()
+    return mod.get_outputs()
+
+  def check_quantize_fc_sum(sym, data_shapes, out_type, atrs, name='fc', quantize_mode='smart'):
+    quantize_granularity_list = ['tensor-wise']
+    if name == 'fc':
+      quantize_granularity_list += ['channel-wise']
+
+    if name in config:
+      name = config[name][OP_NAME]
+    sym_sg = sym.get_backend_symbol(QUANTIZE_SG_PASS_NAME)
+    data_names = ()
+    for data_shape in data_shapes:
+      data_names += (data_shape[0],)
+    mod = Module(symbol=sym, data_names=data_names, label_names=None)
+    mod.bind(for_training=False, data_shapes=data_shapes)
+    mod.init_params(mx.init.Normal(0.5))
+    arg_params, aux_params = mod.get_params()
+
+    if out_type == 'uint8':
+      data = [mx.random.uniform(0.0, 1.0, shape=shape, ctx=mx.current_context()) for _, shape in mod.data_shapes]
+    else:
+      data = [mx.random.uniform(-1.0, 1.0, shape=shape, ctx=mx.current_context()) for _, shape in mod.data_shapes]
+    batch = mx.io.DataBatch(data, [])
+
+    mod.forward(batch, is_train=False)
+    for output in mod.get_outputs():
+        output.wait_to_read()
+    ref_out = mod.get_outputs()
+
+    excluded_sym_names = []
+    excluded_op_names = []
+
+    calib_data = CalibIter(batch, data_shapes, 1)
+
+    for quantize_granularity in quantize_granularity_list:
+      qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=sym_sg,
+                                                                      arg_params=arg_params,
+                                                                      aux_params=aux_params,
+                                                                      ctx=mx.current_context(),
+                                                                      excluded_sym_names=excluded_sym_names,
+                                                                      excluded_op_names=excluded_op_names,
+                                                                      quantized_dtype=out_type,
+                                                                      data_names=data_names,
+                                                                      calib_mode='naive',
+                                                                      calib_data=calib_data,
+                                                                      label_names=None,
+                                                                      num_calib_examples=1,
+                                                                      quantize_mode=quantize_mode,
+                                                                      quantize_granularity=quantize_granularity)
+      qsym = qsym.get_backend_symbol(QUANTIZE_SG_PASS_NAME)
+      check_fusion_parameter(qsym, atrs)
+      quantized_out = check_qsym_forward_fc_sum(qsym, qarg_params, qaux_params, batch, data_shapes)
+      for i in range(len(ref_out)):
+        min_range = mx.nd.min(ref_out[i]).asscalar()
+        max_range = mx.nd.max(ref_out[i]).asscalar()
+        atol = 0.1 * max(abs(min_range), abs(max_range))
+        assert_almost_equal_with_err(quantized_out[i].asnumpy(), ref_out[i].asnumpy(), rtol=0.1, atol=atol, etol=0.2)
+      check_qsym_dummy_forward_fc_sum(qsym, batch, data_shapes)
+
+  for data_shape in DATA_SHAPE:
+    for out_type in ('auto',):
+      net, attrs, inputs = fc_sum(False, data_shape)
+      check_quantize_fc_sum(net, inputs, out_type, attrs)
+
+
 @with_seed()
 def test_pos_conv_bn_sum_act():
   act_list = {"relu": True,