Posted to commits@mxnet.apache.org by zh...@apache.org on 2021/07/04 22:27:21 UTC

[incubator-mxnet] branch v1.x updated: [FEATURE] Performance improvement by asymmetric quantization Quantize+FC (#20302)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.x by this push:
     new 59120fb  [FEATURE] Performance improvement by asymmetric quantization Quantize+FC (#20302)
59120fb is described below

commit 59120fba640a5d1a87fd5f28809c70b5e188698d
Author: Sylwester Fraczek <sy...@intel.com>
AuthorDate: Mon Jul 5 00:25:09 2021 +0200

    [FEATURE] Performance improvement by asymmetric quantization Quantize+FC (#20302)
    
    * mod of quantize and part of fc
    
    * lint fixes
    
    * for now switch shifted quantization to disabled by default
    
    * ci errors fix
    
    * ci errors fix
    
    * ci fix
    
    * ci fix
    
    * ci fixes
    
    * review fixes
    
    * small fixes
    
    * add tolerance to assert_almost_equal
    
    * review fixes
    
    * fix lint
    
    * fix for ci
    
    * fixup
---
 python/mxnet/contrib/quantization.py               |  13 +-
 .../quantization/mkldnn/mkldnn_quantize_v2-inl.h   |  37 ++++-
 src/operator/quantization/quantization_utils.h     |   8 +-
 src/operator/quantization/quantize_graph_pass.cc   | 155 +++++++++++++++++++++
 src/operator/quantization/quantize_v2-inl.h        |  14 +-
 src/operator/quantization/requantize-inl.h         |   8 +-
 src/operator/subgraph/mkldnn/mkldnn_fc.cc          |  95 +++++++------
 tests/python/quantization/test_quantization.py     | 102 +++++++++++++-
 8 files changed, 365 insertions(+), 67 deletions(-)
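
For readers of the archive: the change replaces symmetric int8 quantization of
negative-ranged FullyConnected inputs with asymmetric (shifted) uint8
quantization and compensates in the bias, b' = b - shift * w (see the comment in
the new test below). A minimal numpy sketch of that identity, assuming plain
integer FC arithmetic and the 255.5 uint8 range used in the tests (the sketch is
illustrative, not part of the commit):

    import numpy as np

    rng = np.random.default_rng(0)
    x = rng.uniform(-1.0, 1.0, size=(1, 8)).astype(np.float32)    # fp32 activations with min < 0
    w = rng.integers(-127, 128, size=(4, 8)).astype(np.int8)      # int8 FC weights
    b = rng.integers(-1000, 1000, size=(4,)).astype(np.int64)     # bias, kept wide for clarity

    min_data, max_data = float(x.min()), float(x.max())
    data_scale = 255.5 / (max_data - min_data)                    # kUint8Range / data range
    shift = int(np.round(data_scale * -min_data))                 # cached_shift_ in the new quantize op

    q_signed = np.round(data_scale * x).astype(np.int64)          # unshifted quantized data (same scale)
    out_ref = q_signed @ w.astype(np.int64).T + b                 # reference integer FC output

    q_shifted = q_signed + shift                                  # what the shifted quantize op emits (uint8 in practice)
    b_shifted = b - shift * w.astype(np.int64).sum(axis=1)        # b' = b - shift * sum_j W[i, j]
    out_shifted = q_shifted @ w.astype(np.int64).T + b_shifted

    assert np.array_equal(out_ref, out_shifted)                   # the shift cancels exactly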

diff --git a/python/mxnet/contrib/quantization.py b/python/mxnet/contrib/quantization.py
index 9b2c756..82d1b95 100644
--- a/python/mxnet/contrib/quantization.py
+++ b/python/mxnet/contrib/quantization.py
@@ -632,7 +632,7 @@ def quantize_model_mkldnn(sym, arg_params, aux_params,
         raise ValueError(
             'quantize_model_mkldnn only support Intel cpu platform with MKL-DNN Backend')
 
-    sym = sym.get_backend_symbol('MKLDNN_QUANTIZE')
+    sym = sym.optimize_for('MKLDNN_QUANTIZE')
 
     qsym, qarg_params, aux_params = quantize_model(sym=sym, arg_params=arg_params, aux_params=aux_params,
                                                    data_names=data_names, label_names=label_names,
@@ -643,7 +643,7 @@ def quantize_model_mkldnn(sym, arg_params, aux_params,
                                                    quantized_dtype=quantized_dtype, quantize_mode=quantize_mode,
                                                    quantize_granularity=quantize_granularity, logger=logger)
 
-    qsym = qsym.get_backend_symbol('MKLDNN_QUANTIZE')
+    qsym = qsym.optimize_for('MKLDNN_QUANTIZE')
 
     return qsym, qarg_params, aux_params
 
@@ -942,7 +942,7 @@ def quantize_net_v2(network, quantized_dtype='auto', quantize_mode='full', quant
         logger.info('These layers have been excluded %s' % exclude_layers)
 
     if ctx == mx.cpu():
-        symnet = symnet.get_backend_symbol('MKLDNN_QUANTIZE')
+        symnet = symnet.optimize_for('MKLDNN_QUANTIZE')
 
     qsym, qarg_params, aux_params, collector = quantize_graph(
         sym=symnet, arg_params=args, aux_params=auxs, ctx=ctx,
@@ -979,7 +979,7 @@ def quantize_net_v2(network, quantized_dtype='auto', quantize_mode='full', quant
         data_names = [pair[0] for pair in data_shapes]
 
     if ctx == mx.cpu():
-        qsym = qsym.get_backend_symbol('MKLDNN_QUANTIZE')
+        qsym = qsym.optimize_for('MKLDNN_QUANTIZE')
 
     from ..gluon import SymbolBlock
     data_sym = []
@@ -998,6 +998,11 @@ def quantize_net_v2(network, quantized_dtype='auto', quantize_mode='full', quant
         nd_save(param_name, save_dict)
         net.collect_params().load(param_name, cast_dtype=True, dtype_source='saved')
         net.collect_params().reset_ctx(ctx)
+        if quantized_dtype == 'auto':
+            net.optimize_for(x=data_nd, backend="OneDNNShiftedQuantization")
+            tmp_file = os.path.join(tmpdirname, 'model')
+            net.export(tmp_file)
+            net = SymbolBlock.imports(tmp_file + '-symbol.json', data_names, tmp_file + '-0000.params')
     return net
 
 def quantize_net(network, quantized_dtype='auto', quantize_mode='full',
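
The hook added above runs the new backend pass only when quantized_dtype is
'auto'. A minimal usage sketch, assuming an MKL-DNN build of MXNet 1.x and a toy
Dense block (variable names, shapes and ranges are illustrative):

    import mxnet as mx
    from mxnet.gluon import nn
    from mxnet.io import NDArrayIter

    net = nn.Dense(16, use_bias=True)
    net.initialize()
    data = mx.nd.random.uniform(-1, 1, shape=(8, 32))   # negative data range, so shifting can apply
    net(data)

    calib = NDArrayIter(data=data, batch_size=8)
    # quantized_dtype='auto' lets quantize_net call
    # net.optimize_for(..., backend="OneDNNShiftedQuantization"); the pass itself
    # is a no-op unless MXNET_DISABLE_SHIFTED_QUANTIZATION_OPTIMIZATIONS=0 is set.
    qnet = mx.contrib.quant.quantize_net(net, quantized_dtype='auto',
                                         calib_data=calib, calib_mode='naive',
                                         num_calib_examples=8, ctx=mx.cpu())
    qnet(data).wait_to_read()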
diff --git a/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h b/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h
index 6e10efa..f97e719 100644
--- a/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h
+++ b/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h
@@ -26,6 +26,7 @@
 #define MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_QUANTIZE_V2_INL_H_
 #if MXNET_USE_MKLDNN == 1
 #include <algorithm>
+#include <memory>
 #include <string>
 #include <vector>
 #include "../../nn/mkldnn/mkldnn_base-inl.h"
@@ -47,6 +48,8 @@ class SgMKLDNNQuantizeOperator {
   QuantizeV2Param param_;
   float cached_data_min_{0.f};
   float cached_data_max_{0.f};
+  float cached_scale_;
+  uint8_t cached_shift_{0};
   mkldnn::memory::desc o_desc_;
   mkldnn_args_map_t args_;
   std::shared_ptr<mkldnn::reorder> fwd_pd_;
@@ -55,7 +58,6 @@ class SgMKLDNNQuantizeOperator {
 void SgMKLDNNQuantizeOperator::Forward(const OpContext &ctx, const std::vector<NDArray> &inputs,
                                        const std::vector<OpReqType> &req,
                                        const std::vector<NDArray> &outputs) {
-  float quantized_range = 0.0;
   NDArray in_buffer = inputs[0];
   float data_min = mshadow::red::limits::MaxValue<float>();
   float data_max = mshadow::red::limits::MinValue<float>();
@@ -109,13 +111,19 @@ void SgMKLDNNQuantizeOperator::Forward(const OpContext &ctx, const std::vector<N
 
     // Write output min/max
     auto out_type = GetQuantizeOutputType(param_);
-    if (out_type == mshadow::kUint8) {
-      quantized_range = kUint8Range;
+    const bool shifted = param_.shifted.has_value() && param_.shifted.value();
+    if (shifted) {
+      // if shifted == true we have guarantee that data_min is negative because
+      // we require that in shifted quantization pass in quantize_graph_pass
+      // Modify out min/max range to reflect shifted data
+      out_type = mshadow::kUint8;
+      *outputs[1].data().dptr<float>() = 0;
+      *outputs[2].data().dptr<float>() = data_max - data_min;
+    } else if (out_type == mshadow::kUint8) {
       *outputs[1].data().dptr<float>() = data_min;
       *outputs[2].data().dptr<float>() = data_max;
     } else if (out_type == mshadow::kInt8) {
       float real_range = MaxAbs(data_min, data_max);
-      quantized_range = kInt8Range;
       *outputs[1].data().dptr<float>() = -real_range;
       *outputs[2].data().dptr<float>() = real_range;
     } else {
@@ -125,12 +133,23 @@ void SgMKLDNNQuantizeOperator::Forward(const OpContext &ctx, const std::vector<N
     if (!initalized_) {
       cached_data_min_ = data_min;
       cached_data_max_ = data_max;
-      float real_range = MaxAbs(data_min, data_max);
-      float scale = quantized_range / real_range;
+      if (shifted) {
+        CHECK_LT(data_min, 0);  // assert that we are working on signed
+        cached_scale_ = kUint8Range / (data_max - data_min);
+        cached_shift_ = static_cast<uint8_t>(std::round(cached_scale_ * -cached_data_min_));
+      } else {
+        cached_scale_ = GetQuantizeScale(out_type, data_min, data_max);
+      }
       mkldnn::primitive_attr attr;
       const int mask = 0;
-      std::vector<float> scales = {scale};
+      std::vector<float> scales = {cached_scale_};
       attr.set_output_scales(mask, scales);
+      if (shifted) {
+        // TODO(sfraczek): change to zero point when optimized in oneDNN
+        dnnl::post_ops po;
+        po.append_sum();
+        attr.set_post_ops(po);
+      }
       mkldnn::engine cpu_engine = mxnet::CpuEngine::Get()->get_engine();
       auto i_desc = i_mem->get_desc();
       size_t i_ndim = in_buffer.shape().ndim();
@@ -152,6 +171,10 @@ void SgMKLDNNQuantizeOperator::Forward(const OpContext &ctx, const std::vector<N
     args_[MKLDNN_ARG_TO] = *o_mem.second;
     MKLDNNStream::Get()->RegisterPrimArgs(*fwd_pd_, args_);
     CommitOutput(outputs[0], o_mem);
+    if (shifted) {
+      uint8_t *raw_out_mem = static_cast<uint8_t *>(o_mem.second->get_data_handle());
+      std::fill_n(raw_out_mem, outputs[0].shape().Size(), cached_shift_);
+    }
     MKLDNNStream::Get()->Submit();
   }
 }
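
For orientation, a numpy sketch of what the modified quantize kernel computes
when the shifted flag is set (the real code path is a oneDNN reorder with an
output scale plus a sum post-op over a buffer pre-filled with the shift; the
255.5 range matches the constant used in the tests, and the sketch is
illustrative only):

    import numpy as np

    def shifted_quantize(x):
        data_min, data_max = float(x.min()), float(x.max())
        assert data_min < 0, "the graph pass only marks quantize nodes whose min is negative"
        scale = 255.5 / (data_max - data_min)              # cached_scale_
        shift = int(np.round(scale * -data_min))           # cached_shift_
        q = np.clip(np.round(scale * x) + shift, 0, 255).astype(np.uint8)
        return q, 0.0, data_max - data_min                 # output tensor, out_min, out_max

    x = np.random.uniform(-2.0, 1.0, size=(2, 3)).astype(np.float32)
    q, out_min, out_max = shifted_quantize(x)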
diff --git a/src/operator/quantization/quantization_utils.h b/src/operator/quantization/quantization_utils.h
index 5230576..e6feb2e 100644
--- a/src/operator/quantization/quantization_utils.h
+++ b/src/operator/quantization/quantization_utils.h
@@ -187,12 +187,12 @@ inline size_t ConfigReduce(mshadow::Stream<xpu>* s,
   return broadcast::ReduceWorkspaceSize<NDim, DType>(s, *dst_shape, kWriteTo, *src_shape);
 }
 
-enum QuantizeOutType { kAuto = 0, kInt8, kUint8 };
+enum QuantizeOutType { qAuto = 0, qInt8, qUint8 };
 
 template<typename Param>
 static mshadow::TypeFlag GetQuantizeOutputType(const Param &param) {
   auto out_type = mshadow::kInt8;
-  if (param.out_type == QuantizeOutType::kAuto) {
+  if (param.out_type == QuantizeOutType::qAuto) {
     if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
       if (param.min_calib_range.value() >= 0.0) {
         out_type = mshadow::kUint8;
@@ -200,9 +200,9 @@ static mshadow::TypeFlag GetQuantizeOutputType(const Param &param) {
         out_type = mshadow::kInt8;
       }
     }
-  } else if (param.out_type == QuantizeOutType::kInt8) {
+  } else if (param.out_type == QuantizeOutType::qInt8) {
     out_type = mshadow::kInt8;
-  } else if (param.out_type == QuantizeOutType::kUint8) {
+  } else if (param.out_type == QuantizeOutType::qUint8) {
     out_type = mshadow::kUint8;
   } else {
     LOG(FATAL) << "Unsupported out_type in params: " <<param.out_type;
diff --git a/src/operator/quantization/quantize_graph_pass.cc b/src/operator/quantization/quantize_graph_pass.cc
index 5c43e13..f5060c9 100644
--- a/src/operator/quantization/quantize_graph_pass.cc
+++ b/src/operator/quantization/quantize_graph_pass.cc
@@ -31,6 +31,7 @@
 #include <unordered_set>
 #include <vector>
 #include "quantize_v2-inl.h"
+#include "../nn/mkldnn/mkldnn_fully_connected-inl.h"
 #include "../../common/utils.h"
 
 namespace mxnet {
@@ -578,6 +579,155 @@ Graph SetCalibTableToQuantizedGraph(Graph&& g) {
   return g;
 }
 
+static NDArray* FindInArgByName(const Graph &g, const std::string& name) {
+  const std::vector<std::string>& in_arg_names =
+      g.GetAttr<std::vector<std::string>>("in_arg_names");
+  size_t i = std::distance(in_arg_names.begin(),
+                           std::find(in_arg_names.begin(), in_arg_names.end(), name));
+  if (i == in_arg_names.size()) {
+    LOG(FATAL) << name << " not found in in_arg_names";
+  }
+  return g.GetAttr<NDArray **>("in_args")[i];
+}
+
+static inline bool IsOneDNNFullyConnected(const ObjectPtr& n) {
+#if MXNET_USE_MKLDNN == 1
+  if (n->op() == Op::Get("_sg_mkldnn_fully_connected")) {
+    auto const& param = nnvm::get<MKLDNNFCFullParam>(n->attrs.parsed);
+    if (param.default_param.no_bias == false &&
+        n->inputs[2].node->is_variable()) {
+      if (!(param.mkldnn_param.channel_wise_quantize.has_value() &&
+            param.mkldnn_param.channel_wise_quantize.value())) {
+        return true;
+      }
+    }
+  }
+#endif
+  return false;
+}
+
+static inline bool IsQuantize(const ObjectPtr& n) {
+  if (n->op() == Op::Get("_contrib_quantize_v2")) {
+    auto const &param = nnvm::get<QuantizeV2Param>(n->attrs.parsed);
+    if (param.min_calib_range.has_value() &&
+        param.min_calib_range.value() < 0.0f) {
+      return true;
+    }
+  }
+  return false;
+}
+
+// Rescales weights, min_weight and max_weight. Returns bias_int32_rescale.
+static inline float RescaleWeights(const Graph &g, const ObjectPtr &fc, NDArray* weight_tensor) {
+  ObjectPtr &quantize = fc->inputs[0].node;
+  auto min_data = std::stof(quantize->attrs.dict.at("min_calib_range"));
+  auto max_data = std::stof(quantize->attrs.dict.at("max_calib_range"));
+
+  float *min_weight = FindInArgByName(g, fc->inputs[5].node->attrs.name)->data().dptr<float>();
+  float *max_weight = FindInArgByName(g, fc->inputs[6].node->attrs.name)->data().dptr<float>();
+  float min_bias = *FindInArgByName(g, fc->inputs[7].node->attrs.name)->data().dptr<float>();
+  float max_bias = *FindInArgByName(g, fc->inputs[8].node->attrs.name)->data().dptr<float>();
+
+  float data_scale_ = kUint8Range / (max_data - min_data);
+  float weight_scale = GetQuantizeScale(mshadow::kInt8, *min_weight, *max_weight);
+  float bias_scale = GetQuantizeScale(mshadow::kInt8, min_bias, max_bias);
+  float bias_int32_rescale = data_scale_ * weight_scale / bias_scale;
+
+  // TODO(zhennan): mkldnn has bug to handle INT_MAX in bias, so set the
+  // maximum value of bias to INT_MAX / 2.
+  float bias_max_rescale = mshadow::red::limits::MaxValue<int32_t>() / 2 /
+                           MaxAbs(min_bias, max_bias) / bias_scale;
+  if (bias_int32_rescale > bias_max_rescale) {
+    LOG(INFO) << "RESCALING WEIGHTS in shifted quantization because bias scale "
+                 "is too big in layer " << fc->attrs.name;
+    // avoid overflow on bias
+    bias_int32_rescale = bias_max_rescale;
+    float weight_rescale =
+        bias_int32_rescale * bias_scale / data_scale_ / weight_scale;
+
+    size_t weight_size = weight_tensor->shape().Size();
+    int8_t *weight_ptr = weight_tensor->data().dptr<int8_t>();
+    for (int32_t i = 0; i < static_cast<int32_t>(weight_size); ++i) {
+      weight_ptr[i] = std::round(weight_ptr[i] * weight_rescale);
+    }
+    *min_weight *= weight_rescale;
+    *max_weight *= weight_rescale;
+  }
+  return bias_int32_rescale;
+}
+
+static inline void ShiftBias(int32_t* bias_ptr_int32, size_t bias_size,
+                             NDArray* weight_tensor, int32_t shift_value) {
+  CHECK_EQ(static_cast<size_t>(weight_tensor->shape()[0]), bias_size);
+  int8_t* weight_ptr = weight_tensor->data().dptr<int8_t>();
+  for (dim_t i = 0; i < weight_tensor->shape()[0]; ++i) {
+    for (dim_t j = 0; j < weight_tensor->shape()[1]; j++) {
+      bias_ptr_int32[i] -= shift_value * (*weight_ptr++);
+    }
+  }
+}
+
+Graph OneDNNShiftedQuantization(Graph&& g) {
+  bool disable_shifted_quant =
+      dmlc::GetEnv("MXNET_DISABLE_SHIFTED_QUANTIZATION_OPTIMIZATIONS", true);
+  LOG(INFO) << "Running OneDNN shifted quantization: " << !disable_shifted_quant;
+  // No change to aux params
+  g.attrs["new_aux_names"] = std::make_shared<nnvm::any>(std::vector<std::string>());
+  g.attrs["new_aux"] = std::make_shared<nnvm::any>(std::vector<NDArray *>());
+
+  // New args to replace the old
+  std::vector<std::string> new_arg_names;
+  std::vector<NDArray *> new_arg_vector;
+
+#if MXNET_USE_MKLDNN == 1
+  if (!disable_shifted_quant) {
+    DFSVisit(g.outputs, [&](const ObjectPtr &fc) {
+      // Find Quantize->FC pattern and rescale bias from int8 to int32 and shift
+      if (IsOneDNNFullyConnected(fc)) {
+        ObjectPtr &quantize = fc->inputs[0].node;
+        if (IsQuantize(quantize)) {
+          ObjectPtr& bias_node = fc->inputs[2].node;
+          std::string bias_name_old = bias_node->attrs.name;
+          NDArray* bias_in_arg_ptr = FindInArgByName(g, bias_name_old);
+          if (bias_in_arg_ptr->dtype() != mshadow::kInt8) return;
+          std::string bias_name_s32 = bias_node->attrs.name + "_s32";
+          bias_node = CreateNode("nullptr", bias_name_s32);
+          new_arg_names.push_back(bias_name_s32);
+
+          quantize->attrs.dict["shifted"] = "True";
+          if (quantize->op()->attr_parser) quantize->op()->attr_parser(&(quantize->attrs));
+
+          NDArray *weight_tensor = FindInArgByName(g, fc->inputs[1].node->attrs.name);
+
+          float bias_int32_rescale = RescaleWeights(g, fc, weight_tensor);
+
+          new_arg_vector.push_back(
+              new NDArray(kDefaultStorage, bias_in_arg_ptr->shape(),
+                          Context::CPU(), false, mshadow::kInt32));
+          int32_t *bias_ptr_int32 = new_arg_vector.back()->data().dptr<int32_t>();
+          size_t bias_size = bias_in_arg_ptr->shape().Size();
+          int8_t *bias_ptr_old = bias_in_arg_ptr->data().dptr<int8_t>();
+
+          for (size_t i = 0; i < bias_size; ++i) {
+            bias_ptr_int32[i] = static_cast<int32_t>(
+                std::round(bias_ptr_old[i] * bias_int32_rescale));
+          }
+          float min_data = std::stof(quantize->attrs.dict.at("min_calib_range"));
+          float max_data = std::stof(quantize->attrs.dict.at("max_calib_range"));
+          float data_scale = kUint8Range / (max_data - min_data);
+          int32_t shift_value = static_cast<int32_t>(std::round(data_scale * -min_data));
+          ShiftBias(bias_ptr_int32, bias_size, weight_tensor, shift_value);
+          LOG(INFO) << "applied shifted quantization on QUANTIZE->FC";
+        }
+      }
+    });
+  }
+#endif
+  g.attrs["new_arg_names"] = std::make_shared<nnvm::any>(new_arg_names);
+  g.attrs["new_args"] = std::make_shared<nnvm::any>(new_arg_vector);
+  return g;
+}
+
 NNVM_REGISTER_PASS(QuantizeGraph)
 .describe("")
 .set_body(QuantizeGraph)
@@ -589,5 +739,10 @@ NNVM_REGISTER_PASS(SetCalibTableToQuantizedGraph)
 .set_body(SetCalibTableToQuantizedGraph)
 .set_change_graph(true);
 
+NNVM_REGISTER_PASS(OneDNNShiftedQuantization)
+.describe("Enables shifted quantization.")
+.set_body(OneDNNShiftedQuantization)
+.set_change_graph(true);
+
 }  // namespace op
 }  // namespace mxnet
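
Summarizing the new pass: when it finds a Quantize(negative min) -> FC pattern
it marks the quantize node as shifted, promotes the int8 bias to an int32
argument, and folds the data shift into that bias. A numpy sketch of the bias
preparation (RescaleWeights without its overflow branch, followed by ShiftBias);
the names mirror the C++ code but the values and ranges are illustrative:

    import numpy as np

    def prepare_shifted_bias(bias_int8, min_bias, max_bias,
                             weight_int8, min_weight, max_weight,
                             min_data, max_data):
        data_scale = 255.5 / (max_data - min_data)                    # kUint8Range / data range
        weight_scale = 127.5 / max(abs(min_weight), abs(max_weight))  # GetQuantizeScale(int8, ...)
        bias_scale = 127.5 / max(abs(min_bias), abs(max_bias))
        bias_int32_rescale = data_scale * weight_scale / bias_scale

        bias_int32 = np.round(bias_int8 * bias_int32_rescale).astype(np.int32)
        shift = int(np.round(data_scale * -min_data))
        # ShiftBias: bias'[i] -= shift * sum_j weight[i, j]
        bias_int32 -= shift * weight_int8.astype(np.int32).sum(axis=1)
        return bias_int32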
diff --git a/src/operator/quantization/quantize_v2-inl.h b/src/operator/quantization/quantize_v2-inl.h
index d8814cc..abd26b1 100644
--- a/src/operator/quantization/quantize_v2-inl.h
+++ b/src/operator/quantization/quantize_v2-inl.h
@@ -41,12 +41,13 @@ struct QuantizeV2Param : public dmlc::Parameter<QuantizeV2Param> {
   int out_type;
   dmlc::optional<float> min_calib_range;
   dmlc::optional<float> max_calib_range;
+  dmlc::optional<bool> shifted;
   DMLC_DECLARE_PARAMETER(QuantizeV2Param) {
     DMLC_DECLARE_FIELD(out_type)
-      .add_enum("auto", QuantizeOutType::kAuto)
-      .add_enum("int8", QuantizeOutType::kInt8)
-      .add_enum("uint8", QuantizeOutType::kUint8)
-      .set_default(QuantizeOutType::kInt8)
+      .add_enum("auto", QuantizeOutType::qAuto)
+      .add_enum("int8", QuantizeOutType::qInt8)
+      .add_enum("uint8", QuantizeOutType::qUint8)
+      .set_default(QuantizeOutType::qInt8)
       .describe("Output data type. `auto` can be specified to automatically determine output type "
                 "according to min_calib_range.");
     DMLC_DECLARE_FIELD(min_calib_range)
@@ -57,6 +58,9 @@ struct QuantizeV2Param : public dmlc::Parameter<QuantizeV2Param> {
       .set_default(dmlc::optional<float>())
       .describe("The maximum scalar value in the form of float32. If present, it will be used to "
                 "quantize the fp32 data into int8 or uint8.");
+    DMLC_DECLARE_FIELD(shifted)
+      .set_default(dmlc::optional<bool>())
+      .describe("Whether quantization ouptut should be shifted.");
   }
 };
 
@@ -130,7 +134,7 @@ static inline bool QuantizeV2Type(const nnvm::NodeAttrs &attrs, std::vector<int>
   CHECK(in_attrs->at(0) == mshadow::kFloat32 || in_attrs->at(0) == mshadow::kUint8 ||
         in_attrs->at(0) == mshadow::kInt8);
   auto out_type = GetQuantizeOutputType(param);
-  if (out_type == mshadow::kUint8) {
+  if (out_type == mshadow::kUint8 || (param.shifted.has_value() && param.shifted.value())) {
     TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kUint8);
   } else if (out_type == mshadow::kInt8) {
     TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kInt8);
diff --git a/src/operator/quantization/requantize-inl.h b/src/operator/quantization/requantize-inl.h
index 2bdc3a7..ec604cb 100644
--- a/src/operator/quantization/requantize-inl.h
+++ b/src/operator/quantization/requantize-inl.h
@@ -43,10 +43,10 @@ struct RequantizeParam : public dmlc::Parameter<RequantizeParam> {
   dmlc::optional<float> max_calib_range;  // max float value calculated from calibration dataset
   DMLC_DECLARE_PARAMETER(RequantizeParam) {
     DMLC_DECLARE_FIELD(out_type)
-      .add_enum("auto", QuantizeOutType::kAuto)
-      .add_enum("int8", QuantizeOutType::kInt8)
-      .add_enum("uint8", QuantizeOutType::kUint8)
-      .set_default(QuantizeOutType::kInt8)
+      .add_enum("auto", QuantizeOutType::qAuto)
+      .add_enum("int8", QuantizeOutType::qInt8)
+      .add_enum("uint8", QuantizeOutType::qUint8)
+      .set_default(QuantizeOutType::qInt8)
       .describe("Output data type. `auto` can be specified to automatically determine output type "
                 "according to min_calib_range.");
     DMLC_DECLARE_FIELD(min_calib_range)
diff --git a/src/operator/subgraph/mkldnn/mkldnn_fc.cc b/src/operator/subgraph/mkldnn/mkldnn_fc.cc
index dbaffc3..0eff06a 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_fc.cc
+++ b/src/operator/subgraph/mkldnn/mkldnn_fc.cc
@@ -26,18 +26,21 @@
 
 #if MXNET_USE_MKLDNN == 1
 
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
 #include <utility>
 #include <vector>
-#include <string>
-#include "../common.h"
+#include "../../nn/mkldnn/mkldnn_act-inl.h"
 #include "../../nn/mkldnn/mkldnn_base-inl.h"
-#include "../../nn/mkldnn/mkldnn_ops-inl.h"
 #include "../../nn/mkldnn/mkldnn_fully_connected-inl.h"
-#include "../../nn/mkldnn/mkldnn_act-inl.h"
-#include "../../tensor/matrix_op-inl.h"
+#include "../../nn/mkldnn/mkldnn_ops-inl.h"
 #include "../../quantization/quantization_utils.h"
-#include "mkldnn_fc-inl.h"
+#include "../../tensor/matrix_op-inl.h"
+#include "../common.h"
 #include "mkldnn_common.h"
+#include "mkldnn_fc-inl.h"
 
 namespace mxnet {
 namespace op {
@@ -250,34 +253,38 @@ void SgMKLDNNFCOp::Forward(const OpContext &ctx,
         weight_scales_[0] =
           GetQuantizeScale(cached_weight_.dtype(), cached_min_weight_, cached_max_weight_);
         if (has_bias) {
-          float bias_scale = GetQuantizeScale(mshadow::kInt8, cached_min_bias_, cached_max_bias_);
-          float bias_int32_rescale = data_scale_ * weight_scales_[0] / bias_scale;
-          // TODO(zhennan): mkldnn has bug to handle INT_MAX in bias, so set the maximum value
-          // of bias to INT_MAX / 2.
-          float bias_max_rescale =
-              MaxValue<int32_t>() / 2 / MaxAbs(cached_min_bias_, cached_max_bias_) / bias_scale;
-          if (bias_int32_rescale > bias_max_rescale) {
-            // avoid overflow on bias
-            bias_int32_rescale = bias_max_rescale;
-            float weight_rescale =
-              bias_int32_rescale * bias_scale / data_scale_ / weight_scales_[0];
-            int8_t *weight_ptr = weight.data().dptr<int8_t>();
-            size_t weight_size = weight.shape().Size();
+          if (cached_bias_.dtype() == mshadow::kInt8) {
+            float bias_scale = GetQuantizeScale(mshadow::kInt8, cached_min_bias_, cached_max_bias_);
+
+            float bias_int32_rescale = data_scale_ * weight_scales_[0] / bias_scale;
+            // TODO(zhennan): mkldnn has bug to handle INT_MAX in bias, so set the maximum value
+            // of bias to INT_MAX / 2.
+            float bias_max_rescale =
+                MaxValue<int32_t>() / 2 / MaxAbs(cached_min_bias_, cached_max_bias_) / bias_scale;
+            if (bias_int32_rescale > bias_max_rescale) {
+              // avoid overflow on bias
+              bias_int32_rescale = bias_max_rescale;
+              float weight_rescale =
+                bias_int32_rescale * bias_scale / data_scale_ / weight_scales_[0];
+              int8_t *weight_ptr = weight.data().dptr<int8_t>();
+              size_t weight_size = weight.shape().Size();
+              #pragma omp parallel for num_threads(nthreads)
+              for (index_t i = 0; i < static_cast<index_t>(weight_size); ++i) {
+                weight_ptr[i] = std::round(weight_ptr[i] * weight_rescale);
+              }
+              weight_scales_[0] *= weight_rescale;
+            }
+            NDArray bias = in_data[fullc::kBias];
+            cached_bias_ =
+                NDArray(bias.storage_type(), bias.shape(), bias.ctx(), true, mshadow::kInt32);
+            int8_t *bias_ptr = bias.data().dptr<int8_t>();
+            int32_t *quantized_bias_ptr = cached_bias_.data().dptr<int32_t>();
+            size_t bias_size = bias.shape().Size();
+
             #pragma omp parallel for num_threads(nthreads)
-            for (index_t i = 0; i < static_cast<index_t>(weight_size); ++i) {
-              weight_ptr[i] = std::round(weight_ptr[i] * weight_rescale);
+            for (index_t i = 0; i < static_cast<index_t>(bias_size); ++i) {
+              quantized_bias_ptr[i] = std::round(bias_ptr[i] * bias_int32_rescale);
             }
-            weight_scales_[0] *= weight_rescale;
-          }
-          NDArray bias = in_data[fullc::kBias];
-          cached_bias_ =
-              NDArray(bias.storage_type(), bias.shape(), bias.ctx(), true, mshadow::kInt32);
-          int8_t *bias_ptr = bias.data().dptr<int8_t>();
-          int32_t *quantized_bias_ptr = cached_bias_.data().dptr<int32_t>();
-          size_t bias_size = bias.shape().Size();
-          #pragma omp parallel for num_threads(nthreads)
-          for (index_t i = 0; i < static_cast<index_t>(bias_size); ++i) {
-            quantized_bias_ptr[i] = std::round(bias_ptr[i] * bias_int32_rescale);
           }
         }
       }
@@ -522,16 +529,26 @@ static bool SgMKLDNNFCInferType(const nnvm::NodeAttrs &attrs,
           in_types->at(0) == mshadow::kUint8)
         << "QuantizedFullyConnected only supports int8/uint8 input, while "
         << in_types->at(0) << " is given.";
-    for (size_t i = 1; i < in_types->size(); ++i) {
-      if (channel_wise) {
+    if (channel_wise) {
+      for (size_t i = 1; i < in_types->size(); ++i) {
         TYPE_ASSIGN_CHECK(*in_types, i, mshadow::kFloat32);
-      } else {
-        if (i < base_num_inputs) {
-          TYPE_ASSIGN_CHECK(*in_types, i, mshadow::kInt8);
-        } else {
+      }
+    } else {
+        TYPE_ASSIGN_CHECK(*in_types, 1, mshadow::kInt8);
+        if (!full_param.default_param.no_bias) {
+          if (in_types->at(2) == -1) {
+            TYPE_ASSIGN_CHECK(*in_types, 2, mshadow::kInt32);
+          } else {
+            CHECK(in_types->at(2) == mshadow::kInt8 ||
+                  in_types->at(2) == mshadow::kInt32)
+                << "QuantizedFullyConnected only supports int8/int32 bias, "
+                   "while "
+                << in_types->at(2) << " is given.";
+          }
+        }
+        for (size_t i = base_num_inputs; i < in_types->size(); ++i) {
           TYPE_ASSIGN_CHECK(*in_types, i, mshadow::kFloat32);
         }
-      }
     }
 
     if (full_param.mkldnn_param.enable_float_output) {
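
The int8-bias branch kept above covers graphs where the shifted-quantization
pass did not run, so the bias still arrives as int8 and has to be rescaled to
int32 at the first forward. A numpy sketch of that rescale, including the
INT_MAX / 2 overflow guard from the TODO(zhennan) comment (names and ranges are
illustrative):

    import numpy as np

    def rescale_bias_to_int32(bias_int8, min_bias, max_bias,
                              data_scale, weight_scale, weight_int8):
        bias_scale = 127.5 / max(abs(min_bias), abs(max_bias))   # GetQuantizeScale(int8, ...)
        bias_int32_rescale = data_scale * weight_scale / bias_scale
        bias_max_rescale = (2**31 - 1) / 2 / max(abs(min_bias), abs(max_bias)) / bias_scale
        if bias_int32_rescale > bias_max_rescale:
            # cap the bias rescale and fold the difference into the weights instead
            bias_int32_rescale = bias_max_rescale
            weight_rescale = bias_int32_rescale * bias_scale / data_scale / weight_scale
            weight_int8[:] = np.round(weight_int8 * weight_rescale).astype(np.int8)
            weight_scale *= weight_rescale
        bias_int32 = np.round(bias_int8 * bias_int32_rescale).astype(np.int32)
        return bias_int32, weight_scale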
diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py
index 8c6100d..6c1878a 100644
--- a/tests/python/quantization/test_quantization.py
+++ b/tests/python/quantization/test_quantization.py
@@ -22,11 +22,11 @@ import os
 import mxnet as mx
 import numpy as np
 from mxnet.gluon.model_zoo import vision
-from mxnet.test_utils import assert_almost_equal, assert_exception, rand_ndarray, rand_shape_nd, same, DummyIter
+from mxnet.test_utils import assert_almost_equal, assert_almost_equal_with_err, assert_exception
+from mxnet.test_utils import rand_ndarray, rand_shape_nd, same, DummyIter, environment
 from common import with_seed
 from mxnet.module import Module
 from mxnet.io import NDArrayIter
-import unittest
 import operator
 
 def is_test_for_gpu():
@@ -38,7 +38,7 @@ def is_test_for_mkldnn():
 
 def is_test_for_native_cpu():
     return (mx.current_context().device_type == 'cpu'
-            and os.environ.get('ENABLE_MKLDNN_QUANTIZATION_TEST') == None)
+            and os.environ.get('ENABLE_MKLDNN_QUANTIZATION_TEST') is None)
 
 @with_seed()
 def test_quantize_float32_to_int8():
@@ -1156,7 +1156,7 @@ def test_quantize_gluon_with_forward():
         quantized_resnet18_v1(random_data)
 
         for mode in ['naive', 'entropy']:
-            qdtype = qdtype if mode is 'naive' else 'auto'
+            qdtype = qdtype if mode == 'naive' else 'auto'
             quantized_resnet18_v1 = mx.contrib.quant.quantize_net(resnet18_v1, quantized_dtype=qdtype,
                                                                   exclude_layers=None,
                                                                   exclude_layers_match=excluded_names_match,
@@ -1254,6 +1254,100 @@ def test_get_optimal_thresholds():
         assert_almost_equal(np.array([th_dict['layer1'][1]]), expected_threshold, rtol=1e-2, atol=1e-4)
 
 
+@with_seed()
+def test_onednn_shifted_quantization():
+    batch_size = 1
+    if not is_test_for_mkldnn():
+        print("Test only for mkldnn")
+        return
+
+    def collect_param(fc_layer, p):
+        param = fc_layer.collect_params(p)[p]._reduce()
+        return param
+
+    def quantize_param_to_int8(param):
+        p_int8, p_min, p_max = mx.ndarray.contrib.quantize(data=param,
+                                                           min_range=mx.ndarray.min(param),
+                                                           max_range=mx.ndarray.max(param),
+                                                           out_type='int8')
+        p_scale = 127.5 / np.max(np.abs([p_max.asnumpy(), p_min.asnumpy()]))
+        return p_int8.asnumpy(), p_scale
+
+    def get_shifted_bias(quantize_attrs, weights_int8, weights_scale, bias_int8, bias_scale):
+        max_data = float(quantize_attrs['max_calib_range'])
+        min_data = float(quantize_attrs['min_calib_range'])
+        data_scale = 255.5 / (max_data - min_data)
+        shift_value = np.array(np.round(data_scale * -min_data), np.uint8)
+        shift_matrix = shift_value * np.ones((1, weights_int8.shape[1]), np.int32)
+        shift_matrix = np.array(np.dot(shift_matrix, weights_int8.T).squeeze(), np.int32)
+        bias_int32_rescale = data_scale * weights_scale / bias_scale
+        bias_int32 = np.array(np.round(bias_int8 * bias_int32_rescale), dtype=np.int32)
+        bias_shifted = bias_int32 - shift_matrix
+        return bias_shifted
+
+    def quantize_fc_layer(fc_layer, qdtype, random_data):
+        calib_data = NDArrayIter(data=random_data, batch_size=batch_size)
+        calib_data = DummyIter(calib_data)
+        fc_layer = mx.contrib.quant.quantize_net(fc_layer, quantized_dtype=qdtype,
+                                                 exclude_layers=None,
+                                                 exclude_layers_match=[],
+                                                 calib_data=calib_data,
+                                                 calib_mode='naive',
+                                                 num_calib_examples=1,
+                                                 ctx=mx.current_context())
+        fc_layer.hybridize(static_alloc=True, static_shape=True)
+        fc_layer(random_data).wait_to_read()
+
+        _, sym = fc_layer._cached_graph
+        quantize_attrs = sym.attr_dict()['data_quantize']
+        return fc_layer, quantize_attrs
+
+    def get_fc_layer():
+        fc_layer = mx.gluon.nn.Dense(5, use_bias=True, flatten=True,
+                                     weight_initializer=mx.initializer.Normal(),
+                                     bias_initializer=mx.initializer.Normal())
+        fc_layer.initialize()
+        return fc_layer
+
+    # Shifted quantization should set new bias to FC and add shift to output of quantize
+    # b'=b-shift*w because FC(x+shift,w,b)=(x+shift)*w+b
+    def check(number, qdtype):
+        random_data = mx.nd.random_uniform(low=0 if qdtype == 'uint8' else -1, high=1, shape=(batch_size, 32))
+        fc_layer = get_fc_layer()
+        out = fc_layer(random_data)
+        out.wait_to_read()
+
+        if qdtype == 'auto':
+            bias_int8, bias_scale = quantize_param_to_int8(
+                collect_param(fc_layer, 'dense%d_bias' % number))
+            weights_int8, weights_scale = quantize_param_to_int8(
+                collect_param(fc_layer, 'dense%d_weight' % number))
+
+        fc_layer_quantized, quantize_attrs = quantize_fc_layer(fc_layer, qdtype, random_data)
+        out_q = fc_layer_quantized(random_data)
+        out_q.wait_to_read()
+
+        min_range = mx.nd.min(out).asscalar()
+        max_range = mx.nd.max(out).asscalar()
+        atol = 0.1 * max(abs(min_range), abs(max_range))
+        assert_almost_equal_with_err(out.asnumpy(), out_q.asnumpy(), rtol=0.1, atol=atol, etol=0.2)
+
+        if qdtype == 'auto':
+            assert quantize_attrs['shifted'] == 'True'
+            bias_s32 = collect_param(fc_layer_quantized, 'dense%d_bias_quantize_s32' % number)
+            assert bias_s32.dtype == np.int32
+            bias_shifted = get_shifted_bias(quantize_attrs, weights_int8, weights_scale, bias_int8, bias_scale)
+            assert_almost_equal(bias_s32, bias_shifted, rtol=1e-3, atol=1e-3)
+        else:
+            assert 'shifted' not in quantize_attrs
+            bias = collect_param(fc_layer_quantized, 'dense%d_bias_quantize' % number)
+            assert bias.dtype == np.int8
+
+    with environment({'MXNET_DISABLE_SHIFTED_QUANTIZATION_OPTIMIZATIONS': '0'}):
+        for i, qdtype in enumerate(['int8', 'uint8', 'auto']):
+            check(i, qdtype)
+
+
 if __name__ == "__main__":
     import nose
     nose.runmodule()