You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ak...@apache.org on 2021/08/10 07:00:37 UTC

[incubator-mxnet] branch v1.x updated: [FEATURE] Asymmetric fc fc (#20430)

This is an automated email from the ASF dual-hosted git repository.

akarbown pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.x by this push:
     new b1853e8  [FEATURE] Asymmetric fc fc (#20430)
b1853e8 is described below

commit b1853e8e305e2cb9f7d6f2cea8a6cec7fe6f01f7
Author: Sylwester Fraczek <sy...@intel.com>
AuthorDate: Tue Aug 10 08:58:58 2021 +0200

    [FEATURE] Asymmetric fc fc (#20430)
    
    * add fc->u8->fc fuse
    
    fix lint errors
    
    * Refactor shifted quantization
    
    * refactoring shifted quantization function
    
    * small refactoring and renaming
    
    * Code cleaning and reorganising in files
    
    * Sanity fixup
    
    * flag fixup
    
    * another ci fix
    
    remove unused function and change include path to relative
    
    * fix ci
    
    add a default: case to the switch because clang reports an error
    
    * enable shifted quant fc with_eltwise
    
    * fixed fc-fc after 'fuse fc+sum for quantization' change
    
    * lint fixes
    
    * add #if MXNET_USE_MKLDNN to RescaleWeights
    
    * move FCInputIndex constructor to header
    
    * add FCInputIndex and add MXNET_USE_MKLDNN
    
    * move functions from header to cc
    
    * LOG running shifted quantization only when enabled
    
    * review fixes and other fixes
    
    * formatting
    
    * review fixes
    
    * clang-formatted
    
    * fix CI
    
    Co-authored-by: DominikaJedynak <do...@gmail.com>
    Co-authored-by: Dominika Jedynak <do...@intel.com>
---
 python/mxnet/contrib/quantization.py               |   2 +-
 .../nn/mkldnn/mkldnn_fully_connected-inl.h         |  79 ++++++
 src/operator/nn/mkldnn/mkldnn_fully_connected.cc   |   9 +
 .../quantization/asymmetric_quantize_graph_pass.cc | 277 +++++++++++++++++++++
 src/operator/quantization/quantize_graph_pass.cc   | 183 +-------------
 src/operator/quantization/quantize_graph_pass.h    |  65 +++++
 src/operator/subgraph/mkldnn/mkldnn_fc-inl.h       |   3 +
 src/operator/subgraph/mkldnn/mkldnn_fc.cc          |  88 +------
 tests/python/quantization/test_quantization.py     |  84 ++++++-
 9 files changed, 529 insertions(+), 261 deletions(-)

diff --git a/python/mxnet/contrib/quantization.py b/python/mxnet/contrib/quantization.py
index 82d1b95..94ead97 100644
--- a/python/mxnet/contrib/quantization.py
+++ b/python/mxnet/contrib/quantization.py
@@ -900,7 +900,7 @@ def quantize_net_v2(network, quantized_dtype='auto', quantize_mode='full', quant
     while True:
         try:
             network(*data_nd)
-        except TypeError:
+        except (TypeError, ValueError):
             del data_nd[-1]
             del calib_data.provide_data[-1]
             continue
diff --git a/src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h b/src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h
index 352f7d9..d2ccdef 100644
--- a/src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h
@@ -30,6 +30,7 @@
 #if MXNET_USE_MKLDNN == 1
 
 #include <memory>
+#include <unordered_map>
 #include <string>
 #include <vector>
 
@@ -47,6 +48,7 @@ struct MKLDNNFCParam : public dmlc::Parameter<MKLDNNFCParam> {
   float sum_scale = 1.0f;
   dmlc::optional<float> min_calib_range;  // min float value calculated from calibration dataset
   dmlc::optional<float> max_calib_range;  // max float value calculated from calibration dataset
+  dmlc::optional<bool> shifted_output;
   dmlc::optional<bool> channel_wise_quantize;
 
   DMLC_DECLARE_PARAMETER(MKLDNNFCParam) {
@@ -73,6 +75,9 @@ struct MKLDNNFCParam : public dmlc::Parameter<MKLDNNFCParam> {
             "The maximum scalar value in the form of float32 obtained "
             "through calibration. If present, it will be used to by "
             "quantized fullyconnected op to calculate primitive scale");
+    DMLC_DECLARE_FIELD(shifted_output)
+        .set_default(dmlc::optional<bool>())
+        .describe("Whether quantized output should be shifted to u8.");
     DMLC_DECLARE_FIELD(channel_wise_quantize)
         .set_default(dmlc::optional<bool>())
         .describe("Whether support channel-wise-quantize for weight.");
@@ -86,6 +91,80 @@ struct MKLDNNFCFullParam {
   std::vector<float> output_scales = {0.0f};
 };
 
+static inline size_t GetInSumIndex(const MKLDNNFCFullParam& param) {
+  assert(param.mkldnn_param.with_sum);
+  return fullc::kWeight + 1 + (param.default_param.no_bias ? 0 : 1);
+}
+
+class FCInputIndex {
+ public:
+  explicit FCInputIndex(const MKLDNNFCFullParam& full_param) {
+    auto& mkldnn_param   = full_param.mkldnn_param;
+    const bool has_bias  = !full_param.default_param.no_bias;
+    const bool quantized = mkldnn_param.quantized;
+    const bool sum_input_quantized =
+        quantized && mkldnn_param.with_sum && !mkldnn_param.enable_float_output;
+    const bool channel_wise = quantized && mkldnn_param.channel_wise_quantize.has_value() &&
+                              mkldnn_param.channel_wise_quantize.value();
+
+    // Calculate position of particular input in the input vector:
+    int index     = 0;
+    data          = index++;
+    weight        = index++;
+    bias          = has_bias ? index++ : 0;
+    num_quantized = index + (sum_input_quantized ? 1 : 0);
+    sum           = mkldnn_param.with_sum ? index++ : 0;
+    num_base      = index;
+
+    data_min   = quantized ? index++ : 0;
+    data_max   = quantized ? index++ : 0;
+    weight_min = (quantized && !channel_wise) ? index++ : 0;
+    weight_max = (quantized && !channel_wise) ? index++ : 0;
+    bias_min   = (quantized && !channel_wise && has_bias) ? index++ : 0;
+    bias_max   = (quantized && !channel_wise && has_bias) ? index++ : 0;
+    sum_min    = sum_input_quantized ? index++ : 0;
+    sum_max    = sum_input_quantized ? index++ : 0;
+    num_total  = index;
+  }
+
+  // True if a sum input is used and it is a float (not quantized) tensor.
+  bool IsSumInputFloat() const {
+    return (sum && !sum_min);
+  }
+  int GetTotal() const {
+    return num_total;
+  }
+  int GetBase() const {
+    return num_base;
+  }
+
+  // Returns the number of standard inputs stored as integers: data, weight,
+  // bias (if present) and, when the sum input is quantized, sum.
+  int GetQuantized() const {
+    return num_quantized;
+  }
+
+  // Position of each input in the input vector (absent inputs are set to 0):
+  int data;
+  int weight;
+  int bias;
+  int sum;
+  int data_min;
+  int data_max;
+  int weight_min;
+  int weight_max;
+  int bias_min;
+  int bias_max;
+  int sum_min;
+  int sum_max;
+
+ private:
+  int num_base;       // Number of standard inputs
+  int num_total;      // Number of total inputs: standard + additional needed for
+                      // quantization
+  int num_quantized;  // Number of standard inputs which are quantized
+};
+
 mkldnn::inner_product_forward::primitive_desc GetFCFwdImpl(const MKLDNNFCFullParam& full_param,
                                                            const bool is_train,
                                                            const NDArray& data,
diff --git a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
index 814ca45..4e0aa5a 100644
--- a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
+++ b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
@@ -28,6 +28,7 @@
 #include <unordered_map>
 
 #include "mkldnn_fully_connected-inl.h"
+#include "../../quantization/quantization_utils.h"
 
 namespace mxnet {
 namespace op {
@@ -55,6 +56,14 @@ mkldnn::inner_product_forward::primitive_desc GetFCFwdImpl(const MKLDNNFCFullPar
                        full_param.eltwise_param.alpha,
                        full_param.eltwise_param.beta);
   }
+  if (full_param.mkldnn_param.shifted_output.has_value() &&
+      full_param.mkldnn_param.shifted_output.value()) {
+    auto max = full_param.mkldnn_param.max_calib_range.value();
+    auto min = full_param.mkldnn_param.min_calib_range.value();
+    float scale = GetQuantizeScale(mshadow::kUint8, 0, max - min);
+    float shift = -min * scale;
+    ops.append_eltwise(1.f, dnnl::algorithm::eltwise_linear, 1.f, shift);
+  }
   if (full_param.mkldnn_param.with_sum) {
     ops.append_sum(full_param.mkldnn_param.sum_scale);
   }
diff --git a/src/operator/quantization/asymmetric_quantize_graph_pass.cc b/src/operator/quantization/asymmetric_quantize_graph_pass.cc
new file mode 100644
index 0000000..185448b
--- /dev/null
+++ b/src/operator/quantization/asymmetric_quantize_graph_pass.cc
@@ -0,0 +1,277 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2021 by Contributors
+ * \file asymmetric_quantize_graph_pass.cc
+ * \brief
+ */
+#if MXNET_USE_MKLDNN == 1
+#include "quantize_graph_pass.h"
+
+namespace mxnet {
+namespace op {
+namespace asym_quant {
+
+using nnvm::Graph;
+using nnvm::ObjectPtr;
+
+template <bool require_bias>
+static bool IsOneDNNFullyConnected(const ObjectPtr& n) {
+  if (n->op() == Op::Get("_sg_mkldnn_fully_connected")) {
+    auto const& param = nnvm::get<MKLDNNFCFullParam>(n->attrs.parsed);
+    FCInputIndex idx(param);
+    if (!(param.mkldnn_param.channel_wise_quantize.has_value() &&
+          param.mkldnn_param.channel_wise_quantize.value())) {
+      return !require_bias ||
+             (param.default_param.no_bias == false && n->inputs[idx.bias].node->is_variable());
+    }
+  }
+  return false;
+}
+
+static bool IsQuantize(const ObjectPtr& n) {
+  if (n->op() == Op::Get("_contrib_quantize_v2")) {
+    auto const& param = nnvm::get<QuantizeV2Param>(n->attrs.parsed);
+    if (param.min_calib_range.has_value() && param.min_calib_range.value() < 0.0f) {
+      return true;
+    }
+  }
+  return false;
+}
+
+static NDArray* FindInArgByName(const Graph& g, const std::string& name) {
+  const std::vector<std::string>& in_arg_names =
+      g.GetAttr<std::vector<std::string>>("in_arg_names");
+  size_t i = std::distance(in_arg_names.begin(),
+                           std::find(in_arg_names.begin(), in_arg_names.end(), name));
+  if (i == in_arg_names.size()) {
+    LOG(FATAL) << name << " not found in in_arg_names";
+  }
+  return g.GetAttr<NDArray**>("in_args")[i];
+}
+
+// Returns bias_int32_rescale; rescales weights and min/max weight in place if bias would overflow.
+static float RescaleWeights(const Graph& g, const ObjectPtr& fc, NDArray* weight_tensor) {
+  FCInputIndex idx(nnvm::get<MKLDNNFCFullParam>(fc->attrs.parsed));
+
+  float* min_weight =
+      FindInArgByName(g, fc->inputs[idx.weight_min].node->attrs.name)->data().dptr<float>();
+  float* max_weight =
+      FindInArgByName(g, fc->inputs[idx.weight_max].node->attrs.name)->data().dptr<float>();
+  float min_bias =
+      *FindInArgByName(g, fc->inputs[idx.bias_min].node->attrs.name)->data().dptr<float>();
+  float max_bias =
+      *FindInArgByName(g, fc->inputs[idx.bias_max].node->attrs.name)->data().dptr<float>();
+
+  float min_data           = std::stof(fc->inputs[idx.data].node->attrs.dict.at("min_calib_range"));
+  float max_data           = std::stof(fc->inputs[idx.data].node->attrs.dict.at("max_calib_range"));
+  float data_scale         = kUint8Range / (max_data - min_data);
+  float weight_scale       = GetQuantizeScale(mshadow::kInt8, *min_weight, *max_weight);
+  float bias_scale         = GetQuantizeScale(mshadow::kInt8, min_bias, max_bias);
+  float bias_int32_rescale = data_scale * weight_scale / bias_scale;
+
+  // TODO(zhennan): mkldnn has a bug handling INT_MAX in bias, so cap the
+  // maximum bias value at INT_MAX / 2.
+  float bias_max_rescale =
+      mshadow::red::limits::MaxValue<int32_t>() / 2 / MaxAbs(min_bias, max_bias) / bias_scale;
+  if (bias_int32_rescale > bias_max_rescale) {
+    LOG(INFO) << "RESCALING WEIGHTS in shifted quantization because bias scale "
+                 "is too big in layer "
+              << fc->attrs.name;
+    // Cap the rescale to avoid int32 overflow in bias and fold the rest into the weights.
+    bias_int32_rescale   = bias_max_rescale;
+    float weight_rescale = bias_int32_rescale * bias_scale / data_scale / weight_scale;
+
+    size_t weight_size = weight_tensor->shape().Size();
+    int8_t* weight_ptr = weight_tensor->data().dptr<int8_t>();
+    for (size_t i = 0; i < weight_size; ++i) {
+      weight_ptr[i] = static_cast<int8_t>(std::round(weight_ptr[i] * weight_rescale));
+    }
+    *min_weight *= weight_rescale;
+    *max_weight *= weight_rescale;
+  }
+  return bias_int32_rescale;
+}
+
+static void ShiftBias(int32_t* bias_ptr_int32,
+                      size_t bias_size,
+                      NDArray* weight_tensor,
+                      int32_t shift_value) {
+  CHECK_EQ(static_cast<size_t>(weight_tensor->shape()[0]), bias_size);
+  int8_t* weight_ptr = weight_tensor->data().dptr<int8_t>();
+  for (dim_t i = 0; i < weight_tensor->shape()[0]; ++i) {
+    for (dim_t j = 0; j < weight_tensor->shape()[1]; j++) {
+      bias_ptr_int32[i] -= shift_value * (*weight_ptr++);
+    }
+  }
+}
+
+enum class Pattern { QuantizeFc, FcFc, None };
+
+static Pattern FindPattern(const ObjectPtr& node) {
+  if (IsOneDNNFullyConnected<true>(node)) {
+    if (IsQuantize(node->inputs[0].node)) {
+      return Pattern::QuantizeFc;
+    } else if (IsOneDNNFullyConnected<false>(node->inputs[0].node)) {
+      return Pattern::FcFc;
+    }
+  }
+  return Pattern::None;
+}
+
+static void QuantizeFcShiftedQuantization(const ObjectPtr& node,
+                                          Graph&& g,
+                                          std::vector<NDArray*>* new_arg_vector,
+                                          std::vector<std::string>* new_arg_names) {
+  ObjectPtr& quantize       = node->inputs[0].node;
+  ObjectPtr& bias_node      = node->inputs[2].node;
+  std::string bias_name_old = bias_node->attrs.name;
+  NDArray* bias_in_arg_ptr  = FindInArgByName(g, bias_name_old);
+  if (bias_in_arg_ptr->dtype() != mshadow::kInt8)
+    return;
+  std::string bias_name_s32 = bias_node->attrs.name + "_s32";
+  bias_node                 = CreateNode("nullptr", bias_name_s32);
+  new_arg_names->push_back(bias_name_s32);
+
+  quantize->attrs.dict["shifted"] = "True";
+  if (quantize->op()->attr_parser)
+    quantize->op()->attr_parser(&(quantize->attrs));
+
+  NDArray* weight_tensor = FindInArgByName(g, node->inputs[1].node->attrs.name);
+
+  float bias_int32_rescale = RescaleWeights(g, node, weight_tensor);
+
+  new_arg_vector->push_back(new NDArray(
+      kDefaultStorage, bias_in_arg_ptr->shape(), Context::CPU(), false, mshadow::kInt32));
+  int32_t* bias_ptr_int32 = new_arg_vector->back()->data().dptr<int32_t>();
+  size_t bias_size        = bias_in_arg_ptr->shape().Size();
+  int8_t* bias_ptr_old    = bias_in_arg_ptr->data().dptr<int8_t>();
+
+  for (size_t i = 0; i < bias_size; ++i) {
+    bias_ptr_int32[i] = static_cast<int32_t>(std::round(bias_ptr_old[i] * bias_int32_rescale));
+  }
+  float min_data      = std::stof(quantize->attrs.dict.at("min_calib_range"));
+  float max_data      = std::stof(quantize->attrs.dict.at("max_calib_range"));
+  float data_scale    = kUint8Range / (max_data - min_data);
+  int32_t shift_value = static_cast<int32_t>(std::round(data_scale * -min_data));
+  ShiftBias(bias_ptr_int32, bias_size, weight_tensor, shift_value);
+}
+
+static void FcFcShiftedQuantization(const ObjectPtr& node,
+                                    Graph&& g,
+                                    std::vector<NDArray*>* new_arg_vector,
+                                    std::vector<std::string>* new_arg_names) {
+  ObjectPtr& first_fc       = node->inputs[0].node;
+  ObjectPtr& bias_node      = node->inputs[2].node;
+  std::string bias_name_old = bias_node->attrs.name;
+  NDArray* bias_in_arg_ptr  = FindInArgByName(g, bias_name_old);
+  if (bias_in_arg_ptr->dtype() != mshadow::kInt8)
+    return;
+  std::string bias_name_s32 = bias_node->attrs.name + "_s32";
+  bias_node                 = CreateNode("nullptr", bias_name_s32);
+  new_arg_names->push_back(bias_name_s32);
+
+  first_fc->attrs.dict["shifted_output"] = "True";
+  if (first_fc->op()->attr_parser)
+    first_fc->op()->attr_parser(&(first_fc->attrs));
+
+  NDArray* weight_tensor = FindInArgByName(g, node->inputs[1].node->attrs.name);
+
+  float bias_int32_rescale = RescaleWeights(g, node, weight_tensor);
+
+  new_arg_vector->push_back(new NDArray(
+      kDefaultStorage, bias_in_arg_ptr->shape(), Context::CPU(), false, mshadow::kInt32));
+
+  int32_t* bias_ptr_int32 = new_arg_vector->back()->data().dptr<int32_t>();
+  size_t bias_size        = bias_in_arg_ptr->shape().Size();
+  int8_t* bias_ptr_old    = bias_in_arg_ptr->data().dptr<int8_t>();
+
+  for (size_t i = 0; i < bias_size; ++i) {
+    bias_ptr_int32[i] = static_cast<int32_t>(std::round(bias_ptr_old[i] * bias_int32_rescale));
+  }
+
+  float min_data      = std::stof(first_fc->attrs.dict.at("min_calib_range"));
+  float max_data      = std::stof(first_fc->attrs.dict.at("max_calib_range"));
+  float data_scale    = kUint8Range / (max_data - min_data);
+  int32_t shift_value = static_cast<int32_t>(std::round(data_scale * -min_data));
+  ShiftBias(bias_ptr_int32, bias_size, weight_tensor, shift_value);
+}
+
+static Graph OneDNNShiftedQuantization(Graph&& g) {
+  bool disable_shifted_quant =
+      dmlc::GetEnv("MXNET_DISABLE_SHIFTED_QUANTIZATION_OPTIMIZATIONS", true);
+  bool quantize_fc = !dmlc::GetEnv("MXNET_DISABLE_SHIFTED_QUANTIZE_FC_OPTIMIZATION", false);
+  bool fc_fc       = !dmlc::GetEnv("MXNET_DISABLE_SHIFTED_FC_FC_OPTIMIZATION", false);
+  if (!disable_shifted_quant) {
+    LOG(INFO) << "Running OneDNN shifted quantization";
+  }
+  // No change to aux params
+  g.attrs["new_aux_names"] = std::make_shared<nnvm::any>(std::vector<std::string>());
+  g.attrs["new_aux"]       = std::make_shared<nnvm::any>(std::vector<NDArray*>());
+
+  // New args to replace the old
+  std::vector<std::string> new_arg_names;
+  std::vector<NDArray*> new_arg_vector;
+
+  if (!disable_shifted_quant) {
+    unsigned quantize_fc_counter = 0;
+    unsigned fc_fc_counter       = 0;
+    DFSVisit(g.outputs, [&](const ObjectPtr& node) {
+      Pattern p = FindPattern(node);
+      switch (p) {
+        case Pattern::QuantizeFc:
+          if (quantize_fc) {
+            QuantizeFcShiftedQuantization(
+                node, std::forward<Graph>(g), &new_arg_vector, &new_arg_names);
+            ++quantize_fc_counter;
+          }
+          break;
+        case Pattern::FcFc:
+          if (fc_fc) {
+            FcFcShiftedQuantization(node, std::forward<Graph>(g), &new_arg_vector, &new_arg_names);
+            ++fc_fc_counter;
+          }
+          break;
+        default:
+          break;
+      }
+    });
+    if (quantize_fc_counter > 0) {
+      LOG(INFO) << "applied shifted quantization on QUANTIZE->FC " << quantize_fc_counter
+                << " times";
+    }
+    if (fc_fc_counter > 0) {
+      LOG(INFO) << "applied shifted quantization on FC->FC " << fc_fc_counter << " times";
+    }
+  }
+  g.attrs["new_arg_names"] = std::make_shared<nnvm::any>(new_arg_names);
+  g.attrs["new_args"]      = std::make_shared<nnvm::any>(new_arg_vector);
+  return g;
+}
+
+NNVM_REGISTER_PASS(OneDNNShiftedQuantization)
+    .describe("Enables shifted quantization.")
+    .set_body(OneDNNShiftedQuantization)
+    .set_change_graph(true);
+
+}  // namespace asym_quant
+}  // namespace op
+}  // namespace mxnet
+#endif  // MXNET_USE_MKLDNN == 1
diff --git a/src/operator/quantization/quantize_graph_pass.cc b/src/operator/quantization/quantize_graph_pass.cc
index 74da6e9..9a8cbc2 100644
--- a/src/operator/quantization/quantize_graph_pass.cc
+++ b/src/operator/quantization/quantize_graph_pass.cc
@@ -19,20 +19,13 @@
 
 /*!
  *  Copyright (c) 2016 by Contributors
- * \file quantization.cc
+ * \file quantize_graph_pass.cc
  * \brief
  */
 
-#include <mxnet/op_attr_types.h>
-#include <nnvm/graph.h>
-#include <nnvm/pass.h>
-#include <queue>
-#include <unordered_map>
-#include <unordered_set>
-#include <vector>
-#include "quantize_v2-inl.h"
-#include "../nn/mkldnn/mkldnn_fully_connected-inl.h"
-#include "../../common/utils.h"
+#include "quantize_graph_pass.h"
+#include <memory>
+#include <utility>
 
 namespace mxnet {
 namespace op {
@@ -56,20 +49,6 @@ static inline size_t GetNumOutputs(ObjectPtr node) {
   return num_outputs;
 }
 
-ObjectPtr CreateNode(std::string op_name, std::string node_name) {
-  ObjectPtr node = Node::Create();
-  node->attrs.name = node_name;
-  if (op_name == "nullptr") {
-    node->attrs.op = nullptr;
-    // ugly workaround because VariableParam is not exposed
-    node->attrs.parsed =
-      nnvm::Symbol::CreateVariable(node->attrs.name).outputs[0].node->attrs.parsed;
-  } else {
-    node->attrs.op = Op::Get(op_name);
-  }
-  return node;
-}
-
 /*!
  * \brief Insert a node named with node_name holding the op of op_name
  * before the node current and after the node previous.
@@ -580,155 +559,6 @@ Graph SetCalibTableToQuantizedGraph(Graph&& g) {
   return g;
 }
 
-static NDArray* FindInArgByName(const Graph &g, const std::string& name) {
-  const std::vector<std::string>& in_arg_names =
-      g.GetAttr<std::vector<std::string>>("in_arg_names");
-  size_t i = std::distance(in_arg_names.begin(),
-                           std::find(in_arg_names.begin(), in_arg_names.end(), name));
-  if (i == in_arg_names.size()) {
-    LOG(FATAL) << name << " not found in in_arg_names";
-  }
-  return g.GetAttr<NDArray **>("in_args")[i];
-}
-
-static inline bool IsOneDNNFullyConnected(const ObjectPtr& n) {
-#if MXNET_USE_MKLDNN == 1
-  if (n->op() == Op::Get("_sg_mkldnn_fully_connected")) {
-    auto const& param = nnvm::get<MKLDNNFCFullParam>(n->attrs.parsed);
-    if (param.default_param.no_bias == false &&
-        n->inputs[2].node->is_variable()) {
-      if (!(param.mkldnn_param.channel_wise_quantize.has_value() &&
-            param.mkldnn_param.channel_wise_quantize.value())) {
-        return true;
-      }
-    }
-  }
-#endif
-  return false;
-}
-
-static inline bool IsQuantize(const ObjectPtr& n) {
-  if (n->op() == Op::Get("_contrib_quantize_v2")) {
-    auto const &param = nnvm::get<QuantizeV2Param>(n->attrs.parsed);
-    if (param.min_calib_range.has_value() &&
-        param.min_calib_range.value() < 0.0f) {
-      return true;
-    }
-  }
-  return false;
-}
-
-// Rescales weights, min_weight and max_weight. Returns bias_int32_rescale.
-static inline float RescaleWeights(const Graph &g, const ObjectPtr &fc, NDArray* weight_tensor) {
-  ObjectPtr &quantize = fc->inputs[0].node;
-  auto min_data = std::stof(quantize->attrs.dict.at("min_calib_range"));
-  auto max_data = std::stof(quantize->attrs.dict.at("max_calib_range"));
-
-  float *min_weight = FindInArgByName(g, fc->inputs[5].node->attrs.name)->data().dptr<float>();
-  float *max_weight = FindInArgByName(g, fc->inputs[6].node->attrs.name)->data().dptr<float>();
-  float min_bias = *FindInArgByName(g, fc->inputs[7].node->attrs.name)->data().dptr<float>();
-  float max_bias = *FindInArgByName(g, fc->inputs[8].node->attrs.name)->data().dptr<float>();
-
-  float data_scale_ = kUint8Range / (max_data - min_data);
-  float weight_scale = GetQuantizeScale(mshadow::kInt8, *min_weight, *max_weight);
-  float bias_scale = GetQuantizeScale(mshadow::kInt8, min_bias, max_bias);
-  float bias_int32_rescale = data_scale_ * weight_scale / bias_scale;
-
-  // // TODO(zhennan): mkldnn has bug to handle INT_MAX in bias, so set the
-  // // maximum value of bias to INT_MAX / 2.
-  float bias_max_rescale = mshadow::red::limits::MaxValue<int32_t>() / 2 /
-                           MaxAbs(min_bias, max_bias) / bias_scale;
-  if (bias_int32_rescale > bias_max_rescale) {
-    LOG(INFO) << "RESCALING WEIGHTS in shifted quantization because bias scale "
-                 "is too big in layer " << fc->attrs.name;
-    // avoid overflow on bias
-    bias_int32_rescale = bias_max_rescale;
-    float weight_rescale =
-        bias_int32_rescale * bias_scale / data_scale_ / weight_scale;
-
-    size_t weight_size = weight_tensor->shape().Size();
-    int8_t *weight_ptr = weight_tensor->data().dptr<int8_t>();
-    for (int32_t i = 0; i < static_cast<int32_t>(weight_size); ++i) {
-      weight_ptr[i] = std::round(weight_ptr[i] * weight_rescale);
-    }
-    *min_weight *= weight_rescale;
-    *max_weight *= weight_rescale;
-  }
-  return bias_int32_rescale;
-}
-
-static inline void ShiftBias(int32_t* bias_ptr_int32, size_t bias_size,
-                             NDArray* weight_tensor, int32_t shift_value) {
-  CHECK_EQ(static_cast<size_t>(weight_tensor->shape()[0]), bias_size);
-  int8_t* weight_ptr = weight_tensor->data().dptr<int8_t>();
-  for (dim_t i = 0; i < weight_tensor->shape()[0]; ++i) {
-    for (dim_t j = 0; j < weight_tensor->shape()[1]; j++) {
-      bias_ptr_int32[i] -= shift_value * (*weight_ptr++);
-    }
-  }
-}
-
-Graph OneDNNShiftedQuantization(Graph&& g) {
-  bool disable_shifted_quant =
-      dmlc::GetEnv("MXNET_DISABLE_SHIFTED_QUANTIZATION_OPTIMIZATIONS", true);
-  LOG(INFO) << "Running OneDNN shifted quantization: " << !disable_shifted_quant;
-  // No change to aux params
-  g.attrs["new_aux_names"] = std::make_shared<nnvm::any>(std::vector<std::string>());
-  g.attrs["new_aux"] = std::make_shared<nnvm::any>(std::vector<NDArray *>());
-
-  // New args to replace the old
-  std::vector<std::string> new_arg_names;
-  std::vector<NDArray *> new_arg_vector;
-
-#if MXNET_USE_MKLDNN == 1
-  if (!disable_shifted_quant) {
-    DFSVisit(g.outputs, [&](const ObjectPtr &fc) {
-      // Find Quantize->FC pattern and rescale bias from int8 to int32 and shift
-      if (IsOneDNNFullyConnected(fc)) {
-        ObjectPtr &quantize = fc->inputs[0].node;
-        if (IsQuantize(quantize)) {
-          ObjectPtr& bias_node = fc->inputs[2].node;
-          std::string bias_name_old = bias_node->attrs.name;
-          NDArray* bias_in_arg_ptr = FindInArgByName(g, bias_name_old);
-          if (bias_in_arg_ptr->dtype() != mshadow::kInt8) return;
-          std::string bias_name_s32 = bias_node->attrs.name + "_s32";
-          bias_node = CreateNode("nullptr", bias_name_s32);
-          new_arg_names.push_back(bias_name_s32);
-
-          quantize->attrs.dict["shifted"] = "True";
-          if (quantize->op()->attr_parser) quantize->op()->attr_parser(&(quantize->attrs));
-
-          NDArray *weight_tensor = FindInArgByName(g, fc->inputs[1].node->attrs.name);
-
-          float bias_int32_rescale = RescaleWeights(g, fc, weight_tensor);
-
-          new_arg_vector.push_back(
-              new NDArray(kDefaultStorage, bias_in_arg_ptr->shape(),
-                          Context::CPU(), false, mshadow::kInt32));
-          int32_t *bias_ptr_int32 = new_arg_vector.back()->data().dptr<int32_t>();
-          size_t bias_size = bias_in_arg_ptr->shape().Size();
-          int8_t *bias_ptr_old = bias_in_arg_ptr->data().dptr<int8_t>();
-
-          for (size_t i = 0; i < bias_size; ++i) {
-            bias_ptr_int32[i] = static_cast<int32_t>(
-                std::round(bias_ptr_old[i] * bias_int32_rescale));
-          }
-          float min_data = std::stof(quantize->attrs.dict.at("min_calib_range"));
-          float max_data = std::stof(quantize->attrs.dict.at("max_calib_range"));
-          float data_scale = kUint8Range / (max_data - min_data);
-          int32_t shift_value = static_cast<int32_t>(std::round(data_scale * -min_data));
-          ShiftBias(bias_ptr_int32, bias_size, weight_tensor, shift_value);
-          LOG(INFO) << "applied shifted quantization on QUANTIZE->FC";
-        }
-      }
-    });
-  }
-#endif
-  g.attrs["new_arg_names"] = std::make_shared<nnvm::any>(new_arg_names);
-  g.attrs["new_args"] = std::make_shared<nnvm::any>(new_arg_vector);
-  return g;
-}
-
 NNVM_REGISTER_PASS(QuantizeGraph)
 .describe("")
 .set_body(QuantizeGraph)
@@ -740,10 +570,5 @@ NNVM_REGISTER_PASS(SetCalibTableToQuantizedGraph)
 .set_body(SetCalibTableToQuantizedGraph)
 .set_change_graph(true);
 
-NNVM_REGISTER_PASS(OneDNNShiftedQuantization)
-.describe("Enables shifted quantization.")
-.set_body(OneDNNShiftedQuantization)
-.set_change_graph(true);
-
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/quantization/quantize_graph_pass.h b/src/operator/quantization/quantize_graph_pass.h
new file mode 100644
index 0000000..cd24854
--- /dev/null
+++ b/src/operator/quantization/quantize_graph_pass.h
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2021 by Contributors
+ * \file quantize_graph_pass.h
+ * \brief
+ */
+#ifndef MXNET_OPERATOR_QUANTIZATION_QUANTIZE_GRAPH_PASS_H_
+#define MXNET_OPERATOR_QUANTIZATION_QUANTIZE_GRAPH_PASS_H_
+
+#include <mxnet/op_attr_types.h>
+#include <nnvm/graph.h>
+#include <nnvm/pass.h>
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+#include <string>
+#include "quantize_v2-inl.h"
+#include "../nn/mkldnn/mkldnn_fully_connected-inl.h"
+#include "../../common/utils.h"
+
+namespace mxnet {
+namespace op {
+
+using nnvm::Symbol;
+using nnvm::Node;
+using nnvm::ObjectPtr;
+using nnvm::NodeEntry;
+using nnvm::Graph;
+
+inline ObjectPtr CreateNode(std::string op_name, std::string node_name) {
+  ObjectPtr node = Node::Create();
+  node->attrs.name = node_name;
+  if (op_name == "nullptr") {
+    node->attrs.op = nullptr;
+    // ugly workaround because VariableParam is not exposed
+    node->attrs.parsed =
+      nnvm::Symbol::CreateVariable(node->attrs.name).outputs[0].node->attrs.parsed;
+  } else {
+    node->attrs.op = Op::Get(op_name);
+  }
+  return node;
+}
+
+}  // namespace op
+}  // namespace mxnet
+#endif  // MXNET_OPERATOR_QUANTIZATION_QUANTIZE_GRAPH_PASS_H_
diff --git a/src/operator/subgraph/mkldnn/mkldnn_fc-inl.h b/src/operator/subgraph/mkldnn/mkldnn_fc-inl.h
index 4a39bf0..806ecfe 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_fc-inl.h
+++ b/src/operator/subgraph/mkldnn/mkldnn_fc-inl.h
@@ -65,6 +65,9 @@ static inline bool IsOutputUint8(const MKLDNNFCFullParam& full_param) {
        alg == mkldnn::algorithm::eltwise_sqrt || alg == mkldnn::algorithm::eltwise_exp ||
        alg == mkldnn::algorithm::eltwise_abs)) {
     return true;
+  } else if (full_param.mkldnn_param.shifted_output.has_value() &&
+             full_param.mkldnn_param.shifted_output.value()) {
+               return true;
   }
 
   return false;
diff --git a/src/operator/subgraph/mkldnn/mkldnn_fc.cc b/src/operator/subgraph/mkldnn/mkldnn_fc.cc
index e4baa0c..5578106 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_fc.cc
+++ b/src/operator/subgraph/mkldnn/mkldnn_fc.cc
@@ -46,80 +46,6 @@
 namespace mxnet {
 namespace op {
 
-static inline size_t GetInSumIndex(const MKLDNNFCFullParam& param) {
-  assert(param.mkldnn_param.with_sum);
-  return fullc::kWeight + 1 + (param.default_param.no_bias ? 0 : 1);
-}
-
-class FCInputIndex {
- public:
-  explicit FCInputIndex(const MKLDNNFCFullParam full_param) {
-    auto& mkldnn_param   = full_param.mkldnn_param;
-    const bool has_bias  = !full_param.default_param.no_bias;
-    const bool quantized = mkldnn_param.quantized;
-    const bool sum_input_quantized =
-        quantized && mkldnn_param.with_sum && !mkldnn_param.enable_float_output;
-    const bool channel_wise = quantized && mkldnn_param.channel_wise_quantize.has_value() &&
-                              mkldnn_param.channel_wise_quantize.value();
-
-    // Calculate position of particular input in the input vector:
-    int index     = 0;
-    data          = index++;
-    weight        = index++;
-    bias          = has_bias ? index++ : 0;
-    num_quantized = index + (sum_input_quantized ? 1 : 0);
-    sum           = mkldnn_param.with_sum ? index++ : 0;
-    num_base      = index;
-
-    data_min   = quantized ? index++ : 0;
-    data_max   = quantized ? index++ : 0;
-    weight_min = (quantized && !channel_wise) ? index++ : 0;
-    weight_max = (quantized && !channel_wise) ? index++ : 0;
-    bias_min   = (quantized && !channel_wise && has_bias) ? index++ : 0;
-    bias_max   = (quantized && !channel_wise && has_bias) ? index++ : 0;
-    sum_min    = sum_input_quantized ? index++ : 0;
-    sum_max    = sum_input_quantized ? index++ : 0;
-    num_total  = index;
-  }
-
-  // true if sum input is used and it is float number
-  bool IsSumInputFloat() const {
-    return (sum && !sum_min);
-  }
-  int GetTotal() const {
-    return num_total;
-  }
-  int GetBase() const {
-    return num_base;
-  }
-
-  // return number of standard inputs which are quantized (represented as
-  // integer)
-  int GetQuantized() const {
-    return num_quantized;
-  }
-
-  // Represent index of particular input in the input vector:
-  int data;
-  int weight;
-  int bias;
-  int sum;
-  int data_min;
-  int data_max;
-  int weight_min;
-  int weight_max;
-  int bias_min;
-  int bias_max;
-  int sum_min;
-  int sum_max;
-
- private:
-  int num_base;       // Number of standard inputs
-  int num_total;      // Number of total inputs: standard + additional needed for
-                      // quantization
-  int num_quantized;  // Number of standard inputs which are quantized
-};
-
 class SgMKLDNNFCOp {
  public:
   explicit SgMKLDNNFCOp(const nnvm::NodeAttrs& attrs)
@@ -547,10 +473,16 @@ void SgMKLDNNFCOp::Forward(const OpContext& ctx,
   MKLDNNStream::Get()->Submit();
 
   if (mkldnn_param.quantized && !mkldnn_param.enable_float_output) {
-    float* min_output_ptr = out_data[out_min_index].data().dptr<float>();
-    float* max_output_ptr = out_data[out_max_index].data().dptr<float>();
-    *min_output_ptr       = cached_min_output_;
-    *max_output_ptr       = cached_max_output_;
+    float *min_output_ptr = out_data[out_min_index].data().dptr<float>();
+    float *max_output_ptr = out_data[out_max_index].data().dptr<float>();
+
+    if (mkldnn_param.shifted_output.has_value() && mkldnn_param.shifted_output.value()) {
+      *min_output_ptr = 0;
+      *max_output_ptr = cached_max_output_ - cached_min_output_;
+    } else {
+      *min_output_ptr = cached_min_output_;
+      *max_output_ptr = cached_max_output_;
+    }
   }
 }
 
diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py
index 6c1878a..637f1fe 100644
--- a/tests/python/quantization/test_quantization.py
+++ b/tests/python/quantization/test_quantization.py
@@ -1255,7 +1255,7 @@ def test_get_optimal_thresholds():
 
 
 @with_seed()
-def test_onednn_shifted_quantization():
+def test_onednn_shifted_quantize_fc():
     batch_size = 1
     if not is_test_for_mkldnn():
         print("Test only for mkldnn")
@@ -1303,7 +1303,7 @@ def test_onednn_shifted_quantization():
         return fc_layer, quantize_attrs
 
     def get_fc_layer():
-        fc_layer = mx.gluon.nn.Dense(5, use_bias=True, flatten=True,
+        fc_layer = mx.gluon.nn.Dense(20, use_bias=True, flatten=True,
                                      weight_initializer=mx.initializer.Normal(),
                                      bias_initializer=mx.initializer.Normal())
         fc_layer.initialize()
@@ -1330,7 +1330,7 @@ def test_onednn_shifted_quantization():
         min_range = mx.nd.min(out).asscalar()
         max_range = mx.nd.max(out).asscalar()
         atol = 0.1 * max(abs(min_range), abs(max_range))
-        assert_almost_equal_with_err(out.asnumpy(), out_q.asnumpy(), rtol=0.1, atol=atol, etol=0.2)
+        assert_almost_equal_with_err(out_q.asnumpy(), out.asnumpy(), rtol=0.1, atol=atol, etol=0.2)
 
         if qdtype == 'auto':
             assert quantize_attrs['shifted'] == 'True'
@@ -1348,6 +1348,84 @@ def test_onednn_shifted_quantization():
             check(i, qdtype)
 
 
+@with_seed()
+def test_onednn_shifted_quantize_fc_fc():
+    batch_size = 2
+    if not is_test_for_mkldnn():
+        print("Test only for mkldnn")
+        return
+
+    def get_fc_fc_layers(with_eltwise):
+        class Net(mx.gluon.nn.HybridBlock):
+            def __init__(self):
+                super(Net, self).__init__()
+                self.fc1 = mx.gluon.nn.Dense(20, use_bias=True, flatten=True,
+                                             weight_initializer=mx.initializer.Normal(),
+                                             bias_initializer=mx.initializer.Normal())
+                self.relu = mx.gluon.nn.Activation('relu') if with_eltwise else None
+                self.fc2 = mx.gluon.nn.Dense(20, use_bias=True, flatten=True,
+                                             weight_initializer=mx.initializer.Normal(),
+                                             bias_initializer=mx.initializer.Normal())
+
+            def hybrid_forward(self, F, x):
+                out = self.fc1(x)
+                if self.relu is not None:
+                    out = self.relu(out)
+                out = self.fc2(out)
+                return out
+
+        net = Net()
+        net.initialize()
+        return net
+
+    def quantize_net(with_eltwise, qdtype, net, random_data):
+        calib_data = NDArrayIter(data=random_data, batch_size=batch_size)
+        calib_data = DummyIter(calib_data)
+        net = mx.contrib.quant.quantize_net(net, quantize_mode='smart',
+                                            quantized_dtype=qdtype,
+                                            exclude_layers=None,
+                                            exclude_layers_match=[],
+                                            calib_data=calib_data,
+                                            calib_mode='naive',
+                                            num_calib_examples=1,
+                                            ctx=mx.current_context())
+        net.hybridize(static_alloc=True, static_shape=True)
+        out = net(random_data)
+        out.wait_to_read()
+
+        _, sym = net._cached_graph
+        fc0_name = "quantized_sg_mkldnn_fully_connected%s_0" %("_eltwise" if with_eltwise else "")
+        fc0_attrs = sym.attr_dict()[fc0_name]
+
+        if qdtype == 'auto':
+            assert fc0_attrs['shifted_output'] == 'True'
+        else:
+            assert 'shifted_output' not in fc0_attrs
+
+        return out
+
+    def check(with_eltwise, qdtype, random_data):
+        net_ref = get_fc_fc_layers(with_eltwise)
+        out_ref = net_ref(random_data)
+        out_ref.wait_to_read()
+
+        out_q = quantize_net(with_eltwise, qdtype, net_ref, random_data)
+
+        min_range = mx.nd.min(out_ref).asscalar()
+        max_range = mx.nd.max(out_ref).asscalar()
+        atol = 0.1 * max(abs(min_range), abs(max_range))
+        assert_almost_equal_with_err(out_q.asnumpy(), out_ref.asnumpy(), rtol=0.1, atol=atol, etol=0.2)
+
+    with environment({'MXNET_DISABLE_SHIFTED_QUANTIZATION_OPTIMIZATIONS': '0',
+                      'MXNET_DISABLE_SHIFTED_QUANTIZE_FC_OPTIMIZATION': '0'}):
+        for with_eltwise in [False, True]:
+            for qdtype in ['int8', 'uint8', 'auto']:
+                print("with_eltwise:", with_eltwise)
+                print("qdtype:", qdtype)
+                data = mx.nd.random_uniform(low=0 if qdtype == 'uint8' else -1, high=1, shape=(batch_size, 10))
+                check(with_eltwise, qdtype, data)
+
+
 if __name__ == "__main__":
     import nose
     nose.runmodule()