You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ma...@apache.org on 2018/06/14 04:58:39 UTC

[incubator-mxnet] branch master updated: [MXNET-290] MKLDNN support for model quantization (#10433)

This is an automated email from the ASF dual-hosted git repository.

marcoabreu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new d79e1ad  [MXNET-290] MKLDNN support for model quantization (#10433)
d79e1ad is described below

commit d79e1ad3294837cac653478045023fd312ceed78
Author: wentingj <we...@intel.com>
AuthorDate: Thu Jun 14 12:58:33 2018 +0800

    [MXNET-290] MKLDNN support for model quantization (#10433)
    
    * mkldnn support for quantization
    
    * fix output number in graph
    
    * update license
    
    * modify Jenkinsfile
    
    * modify Jenkinsfile
    
    * mkldnn has no int8 fc api, excluded_sym_names includes fc for cpu
    
    * add mkldnn uint8 pass for quantization graph
    
    * update ut
    
    * retrigger CI
    
    * remove no mkldnn quantization test temp
    
    * separate mkldnn quantization unit tests from gpu quantization unit tests
    
    * rm dev_id check for cpu
    
    * add mkl tests directory
    
    * resolve review comments
    
    * simplify DequantizeStorageType() logic
    
    * simplify quantize/quantized_conv storage type logic
    
    * Add mkldnn_OIhw4i16o4i type case (needed by int8)
    
    * INT8 conv/pooling: share with FP32 convolution/pooling class/function
    
    * minor indent changes
    
    * Remove unnecessary mkldnn_quantized_pooling-inl.h
    
    * Fix minor issue
    
    * Fix lint
    
    * delete duplicated data type
    
    * fix bugs and convert requantize data to NDArray
    
    * fix lint
    
    * fix requantize storage type
    
    * fix requantize storage type
    
    * Fix coding style comments
    
    * Fix compile issue
    
    * Change to use quantized_dtype option to support uint8/int8 scenarios
    
    * fix gpu test quantization failure
    
    * Fix indent
    
    * fix quantized pooling param parser
    
    * Fix imagenet_gen_qsym.py option style
    
    * retrigger jenkins
    
    * retrigger again
    
    * trigger jenkins
    
    * Resolve further comments
    
    * share test code
    
    * remove unnecessary test code
    
    * add test_quantize_model for cpu
    
    * add comments in quantize_graph_pass.cc
    
    * jenkins
    
    * jenkins
    
    * improve coding style
    
    * improve coding style
    
    * Add naive CPU quantization test back and share quantization code between naive-CPU/MKLDNN/GPU
    
    * rename test_quantization_cpu.py to test_quantization_mkldnn.py
    
    * code style
    
    * trigger
    
    * Adjust variable naming for test quantization
    
    * add qdtype for quantized op test case to test/bypass all cases explicitly
    
    * change expressions to be consistent
    
    * revert unnecessary change
---
 ci/docker/runtime_functions.sh                     |   3 +-
 example/quantization/imagenet_gen_qsym.py          |  44 +-
 example/quantization/imagenet_inference.py         |  10 +-
 include/mxnet/c_api.h                              |   4 +-
 python/mxnet/contrib/quantization.py               |  21 +-
 src/c_api/c_api_symbolic.cc                        |   5 +-
 src/operator/nn/convolution-inl.h                  |   2 +
 src/operator/nn/convolution.cc                     |   2 +-
 src/operator/nn/mkldnn/mkldnn_convolution-inl.h    |  77 ++++
 src/operator/nn/mkldnn/mkldnn_convolution.cc       | 109 ++---
 src/operator/nn/mkldnn/mkldnn_pooling-inl.h        |   4 +
 src/operator/nn/pooling-inl.h                      |   2 +
 src/operator/nn/pooling.cc                         |   2 +-
 src/operator/quantization/dequantize.cc            |  24 +
 .../quantization/mkldnn/mkldnn_dequantize-inl.h    | 105 +++++
 .../quantization/mkldnn/mkldnn_quantize-inl.h      | 112 +++++
 .../quantization/mkldnn/mkldnn_quantized_conv.cc   |  89 ++++
 .../mkldnn/mkldnn_quantized_pooling.cc             |  54 +++
 .../quantization/mkldnn/mkldnn_requantize-inl.h    | 158 +++++++
 src/operator/quantization/quantize.cc              |  24 +
 src/operator/quantization/quantize_graph_pass.cc   |  20 +-
 src/operator/quantization/quantized_conv.cc        |  27 +-
 src/operator/quantization/quantized_flatten-inl.h  |  23 +-
 src/operator/quantization/quantized_pooling.cc     |  31 +-
 src/operator/quantization/requantize.cc            |  25 +
 tests/python/mkl/test_quantization_mkldnn.py       |  28 ++
 tests/python/quantization/test_quantization.py     | 506 ++++++++++++---------
 27 files changed, 1185 insertions(+), 326 deletions(-)

diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh
index 36e2387..293ac64 100755
--- a/ci/docker/runtime_functions.sh
+++ b/ci/docker/runtime_functions.sh
@@ -466,13 +466,12 @@ unittest_ubuntu_python3_cpu() {
 
 unittest_ubuntu_python3_cpu_mkldnn() {
     set -ex
-    export PYTHONPATH=./python/ 
+    export PYTHONPATH=./python/
     # MXNET_MKLDNN_DEBUG is buggy and produces false positives
     # https://github.com/apache/incubator-mxnet/issues/10026
     #export MXNET_MKLDNN_DEBUG=1  # Ignored if not present
     export MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
     nosetests-3.4 --verbose tests/python/unittest
-    nosetests-3.4 --verbose tests/python/quantization
     nosetests-3.4 --verbose tests/python/mkl
 }
 
diff --git a/example/quantization/imagenet_gen_qsym.py b/example/quantization/imagenet_gen_qsym.py
index 045ce62..85474b6 100644
--- a/example/quantization/imagenet_gen_qsym.py
+++ b/example/quantization/imagenet_gen_qsym.py
@@ -53,6 +53,7 @@ def save_params(fname, arg_params, aux_params, logger=None):
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Generate a calibrated quantized model from a FP32 model')
+    parser.add_argument('--ctx', type=str, default='gpu')
     parser.add_argument('--model', type=str, choices=['imagenet1k-resnet-152', 'imagenet1k-inception-bn'],
                         help='currently only supports imagenet1k-resnet-152 or imagenet1k-inception-bn')
     parser.add_argument('--batch-size', type=int, default=32)
@@ -91,8 +92,18 @@ if __name__ == '__main__':
                              ' thresholds. This mode is expected to produce the best inference accuracy of all three'
                              ' kinds of quantized models if the calibration dataset is representative enough of the'
                              ' inference dataset.')
+    parser.add_argument('--quantized-dtype', type=str, default='int8', 
+                        choices=['int8', 'uint8'],
+                        help='quantization destination data type for input data')
     args = parser.parse_args()
 
+    if args.ctx == 'gpu':
+        ctx = mx.gpu(0)
+    elif args.ctx == 'cpu':
+        ctx = mx.cpu(0)
+    else:
+        raise ValueError('ctx %s is not supported in this script' % args.ctx)
+
     logging.basicConfig()
     logger = logging.getLogger('logger')
     logger.setLevel(logging.INFO)
@@ -129,17 +140,26 @@ if __name__ == '__main__':
     excluded_sym_names = []
     if args.model == 'imagenet1k-resnet-152':
         rgb_mean = '0,0,0'
-        calib_layer = lambda name: name.endswith('_output') and (name.find('conv') != -1
-                                                                 or name.find('sc') != -1
-                                                                 or name.find('fc') != -1)
+        if args.ctx == 'gpu':
+            calib_layer = lambda name: name.endswith('_output') and (name.find('conv') != -1
+                                                                     or name.find('sc') != -1
+                                                                     or name.find('fc') != -1)
+        else:
+            calib_layer = lambda name: name.endswith('_output') and (name.find('conv') != -1
+                                                                     or name.find('sc') != -1)
+            excluded_sym_names += ['flatten0', 'fc1']
         if exclude_first_conv:
-            excluded_sym_names = ['conv0']
+            excluded_sym_names += ['conv0']
     elif args.model == 'imagenet1k-inception-bn':
         rgb_mean = '123.68,116.779,103.939'
-        calib_layer = lambda name: name.endswith('_output') and (name.find('conv') != -1
-                                                                 or name.find('fc') != -1)
+        if args.ctx == 'gpu':
+            calib_layer = lambda name: name.endswith('_output') and (name.find('conv') != -1
+                                                                     or name.find('fc') != -1)
+        else:
+            calib_layer = lambda name: name.endswith('_output') and (name.find('conv') != -1)
+            excluded_sym_names += ['flatten', 'fc1']
         if exclude_first_conv:
-            excluded_sym_names = ['conv_1']
+            excluded_sym_names += ['conv_1']
     else:
         raise ValueError('model %s is not supported in this script' % args.model)
 
@@ -156,8 +176,9 @@ if __name__ == '__main__':
     if calib_mode == 'none':
         logger.info('Quantizing FP32 model %s' % args.model)
         qsym, qarg_params, aux_params = quantize_model(sym=sym, arg_params=arg_params, aux_params=aux_params,
-                                                       excluded_sym_names=excluded_sym_names,
-                                                       calib_mode=calib_mode, logger=logger)
+                                                       ctx=ctx, excluded_sym_names=excluded_sym_names,
+                                                       calib_mode=calib_mode, quantized_dtype=args.quantized_dtype,
+                                                       logger=logger)
         sym_name = '%s-symbol.json' % (prefix + '-quantized')
         save_symbol(sym_name, qsym, logger)
     else:
@@ -176,10 +197,11 @@ if __name__ == '__main__':
                                      **mean_args)
 
         cqsym, qarg_params, aux_params = quantize_model(sym=sym, arg_params=arg_params, aux_params=aux_params,
-                                                        ctx=mx.gpu(0), excluded_sym_names=excluded_sym_names,
+                                                        ctx=ctx, excluded_sym_names=excluded_sym_names,
                                                         calib_mode=calib_mode, calib_data=data,
                                                         num_calib_examples=num_calib_batches * batch_size,
-                                                        calib_layer=calib_layer, logger=logger)
+                                                        calib_layer=calib_layer, quantized_dtype=args.quantized_dtype,
+                                                        logger=logger)
         if calib_mode == 'entropy':
             suffix = '-quantized-%dbatches-entropy' % num_calib_batches
         elif calib_mode == 'naive':
diff --git a/example/quantization/imagenet_inference.py b/example/quantization/imagenet_inference.py
index fe3f266..8564953 100644
--- a/example/quantization/imagenet_inference.py
+++ b/example/quantization/imagenet_inference.py
@@ -99,6 +99,7 @@ def score(sym, arg_params, aux_params, data, devs, label_name, max_num_examples,
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Score a model on a dataset')
+    parser.add_argument('--ctx', type=str, default='gpu')
     parser.add_argument('--symbol-file', type=str, required=True, help='symbol file path')
     parser.add_argument('--param-file', type=str, required=True, help='param file path')
     parser.add_argument('--batch-size', type=int, default=32)
@@ -122,6 +123,13 @@ if __name__ == '__main__':
 
     args = parser.parse_args()
 
+    if args.ctx == 'gpu':
+        ctx = mx.gpu(0)
+    elif args.ctx == 'cpu':
+        ctx = mx.cpu(0)
+    else:
+        raise ValueError('ctx %s is not supported in this script' % args.ctx)
+    
     logging.basicConfig()
     logger = logging.getLogger('logger')
     logger.setLevel(logging.INFO)
@@ -172,5 +180,5 @@ if __name__ == '__main__':
 
     num_inference_images = args.num_inference_batches * batch_size
     logger.info('Running model %s for inference' % symbol_file)
-    score(sym, arg_params, aux_params, data, [mx.gpu(0)], label_name,
+    score(sym, arg_params, aux_params, data, [ctx], label_name,
           max_num_examples=num_inference_images, logger=logger)
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 6b7cf44..4dd858a 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -1431,13 +1431,15 @@ MXNET_DLL int MXSymbolInferType(SymbolHandle sym,
  * \param excluded_symbols array of symbols to be excluded from being quantized
  * \param num_offline number of parameters that are quantized offline
  * \param offline_params array of c strings representing the names of params quantized offline
+ * \param quantized_dtype the quantized destination type for input data.
  */
 MXNET_DLL int MXQuantizeSymbol(SymbolHandle sym_handle,
                                SymbolHandle *ret_sym_handle,
                                const mx_uint num_excluded_symbols,
                                const SymbolHandle *excluded_symbols,
                                const mx_uint num_offline,
-                               const char **offline_params);
+                               const char **offline_params,
+                               const char *quantized_dtype);
 
 /*!
  * \brief Set calibration table to node attributes in the sym
diff --git a/python/mxnet/contrib/quantization.py b/python/mxnet/contrib/quantization.py
index c9c58a9..1314b97 100644
--- a/python/mxnet/contrib/quantization.py
+++ b/python/mxnet/contrib/quantization.py
@@ -72,7 +72,8 @@ def _quantize_params(qsym, params):
     return quantized_params
 
 
-def _quantize_symbol(sym, excluded_symbols=None, offline_params=None):
+def _quantize_symbol(sym, excluded_symbols=None, offline_params=None,
+                     quantized_dtype='int8'):
     """Given a symbol object representing a neural network of data type FP32,
     quantize it into a INT8 network.
 
@@ -86,6 +87,8 @@ def _quantize_symbol(sym, excluded_symbols=None, offline_params=None):
         Names of the parameters that users want to quantize offline. It's always recommended to
         quantize parameters offline so that quantizing parameters during the inference can be
         avoided.
+    quantized_dtype: str
+        The quantized destination type for input data.
     """
     num_excluded_symbols = 0
     excluded_handles = []
@@ -108,7 +111,8 @@ def _quantize_symbol(sym, excluded_symbols=None, offline_params=None):
                                      mx_uint(num_excluded_symbols),
                                      c_array(SymbolHandle, excluded_handles),
                                      mx_uint(num_offline),
-                                     c_array(ctypes.c_char_p, offline)))
+                                     c_array(ctypes.c_char_p, offline),
+                                     c_str(quantized_dtype)))
     return Symbol(out)
 
 
@@ -401,7 +405,8 @@ def _load_params(params, logger=logging):
 def quantize_model(sym, arg_params, aux_params,
                    data_names=('data',), label_names=('softmax_label',),
                    ctx=cpu(), excluded_sym_names=None, calib_mode='entropy',
-                   calib_data=None, num_calib_examples=None, calib_layer=None, logger=logging):
+                   calib_data=None, num_calib_examples=None, calib_layer=None,
+                   quantized_dtype='int8', logger=logging):
     """User-level API for generating a quantized model from a FP32 model w/ or w/o calibration.
     The backend quantized operators are only enabled for Linux systems. Please do not run
     inference using the quantized models on Windows for now.
@@ -451,6 +456,9 @@ def quantize_model(sym, arg_params, aux_params,
         calibrate this layer. If yes, the statistics of the layer's output will be collected;
         otherwise, no information of the layer's output will be collected. If not provided,
         all the layers' outputs that need requantization will be collected.
+    quantized_dtype : str
+        The quantized destination type for input data. Currently support 'int8'
+        and 'uint8', default value is 'int8'.
     logger : Object
         A logging object for printing information during the process of quantization.
 
@@ -473,8 +481,13 @@ def quantize_model(sym, arg_params, aux_params,
             idx = nodes.list_outputs().index(sym_name + '_output')
             excluded_syms.append(nodes[idx])
     logger.info('Quantizing symbol')
+
+    if quantized_dtype != 'int8' and quantized_dtype != 'uint8':
+        raise ValueError('unknown quantized_dtype %s received,'
+                         ' expected `int8` or `uint8`' % quantized_dtype)
     qsym = _quantize_symbol(sym, excluded_symbols=excluded_syms,
-                            offline_params=list(arg_params.keys()))
+                            offline_params=list(arg_params.keys()),
+                            quantized_dtype=quantized_dtype)
 
     logger.info('Quantizing parameters')
     qarg_params = _quantize_params(qsym, arg_params)
diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc
index 4666b6a..e5e9b52 100644
--- a/src/c_api/c_api_symbolic.cc
+++ b/src/c_api/c_api_symbolic.cc
@@ -577,7 +577,8 @@ int MXQuantizeSymbol(SymbolHandle sym_handle,
                      const mx_uint num_excluded_symbols,
                      const SymbolHandle *excluded_symbols,
                      const mx_uint num_offline,
-                     const char **offline_params) {
+                     const char **offline_params,
+                     const char *quantized_dtype) {
   nnvm::Symbol *s = new nnvm::Symbol();
   API_BEGIN();
   nnvm::Symbol *sym = static_cast<nnvm::Symbol*>(sym_handle);
@@ -594,7 +595,9 @@ int MXQuantizeSymbol(SymbolHandle sym_handle,
   for (size_t i = 0; i < num_offline; ++i) {
     offline.emplace(offline_params[i]);
   }
+  std::string quantized_type(quantized_dtype);
   g.attrs["offline_params"] = std::make_shared<nnvm::any>(std::move(offline));
+  g.attrs["quantized_dtype"] = std::make_shared<nnvm::any>(std::move(quantized_type));
   g = ApplyPass(std::move(g), "QuantizeGraph");
   s->outputs = g.outputs;
   *ret_sym_handle = s;
diff --git a/src/operator/nn/convolution-inl.h b/src/operator/nn/convolution-inl.h
index 5632d73..d40abaf 100644
--- a/src/operator/nn/convolution-inl.h
+++ b/src/operator/nn/convolution-inl.h
@@ -125,6 +125,8 @@ struct ConvolutionParam : public dmlc::Parameter<ConvolutionParam> {
   }
 };
 
+void ConvolutionParamParser(nnvm::NodeAttrs* attrs);
+
 typedef ParamOpSign<ConvolutionParam> ConvSignature;
 
 }  // namespace op
diff --git a/src/operator/nn/convolution.cc b/src/operator/nn/convolution.cc
index 0e8a929..ef70ccd 100644
--- a/src/operator/nn/convolution.cc
+++ b/src/operator/nn/convolution.cc
@@ -331,7 +331,7 @@ inline static bool BackwardConvStorageType(const nnvm::NodeAttrs& attrs,
                              dispatch_mode, wanted_mode);
 }
 
-static void ConvolutionParamParser(nnvm::NodeAttrs* attrs) {
+void ConvolutionParamParser(nnvm::NodeAttrs* attrs) {
   using namespace mshadow;
   ConvolutionParam param_;
   try {
diff --git a/src/operator/nn/mkldnn/mkldnn_convolution-inl.h b/src/operator/nn/mkldnn/mkldnn_convolution-inl.h
new file mode 100644
index 0000000..23f2fe6
--- /dev/null
+++ b/src/operator/nn/mkldnn/mkldnn_convolution-inl.h
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file mkldnn_convolution-inl.h
+ * \brief
+*/
+
+#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_CONVOLUTION_INL_H_
+#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_CONVOLUTION_INL_H_
+
+#if MXNET_USE_MKLDNN == 1
+
+#include <utility>
+#include "../convolution-inl.h"
+#include "./mkldnn_ops-inl.h"
+#include "./mkldnn_base-inl.h"
+
+namespace mxnet {
+namespace op {
+
+mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(
+    const ConvolutionParam& param, const bool is_train, const NDArray &data,
+    const NDArray &weights, const NDArray *bias, const NDArray &output);
+
+class MKLDNNConvForward {
+ public:
+  mkldnn::convolution_forward::primitive_desc fwd_pd;
+
+  MKLDNNConvForward(const ConvolutionParam& param, const bool is_train,
+                    const NDArray &data, const NDArray &weights,
+                    const NDArray *bias, const NDArray &output): fwd_pd(
+                        GetConvFwdImpl(param, is_train, data, weights, bias, output)) {
+  }
+
+  void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &weight,
+                 const mkldnn::memory *bias, const mkldnn::memory &output);
+
+  const mkldnn::convolution_forward &GetFwd() const {
+    return *fwd_;
+  }
+
+ private:
+  std::shared_ptr<mkldnn::convolution_forward> fwd_;
+  std::shared_ptr<mkldnn::memory> data_;
+  std::shared_ptr<mkldnn::memory> weight_;
+  std::shared_ptr<mkldnn::memory> bias_;
+  std::shared_ptr<mkldnn::memory> out_;
+};
+
+typedef ParamOpSign<ConvolutionParam> MKLDNNConvSignature;
+
+MKLDNNConvForward &GetConvFwd(const nnvm::NodeAttrs& attrs,
+    const bool is_train, const NDArray &data, const NDArray &weights,
+    const NDArray *bias, const NDArray &output);
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_USE_MKLDNN == 1
+#endif  // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_CONVOLUTION_INL_H_
diff --git a/src/operator/nn/mkldnn/mkldnn_convolution.cc b/src/operator/nn/mkldnn/mkldnn_convolution.cc
index f851a6d..cf04ea8 100644
--- a/src/operator/nn/mkldnn/mkldnn_convolution.cc
+++ b/src/operator/nn/mkldnn/mkldnn_convolution.cc
@@ -23,11 +23,14 @@
  * \author Da Zheng
 */
 
+
+#if MXNET_USE_MKLDNN == 1
+
 #include "../convolution-inl.h"
 #include "./mkldnn_ops-inl.h"
 #include "./mkldnn_base-inl.h"
+#include "./mkldnn_convolution-inl.h"
 
-#if MXNET_USE_MKLDNN == 1
 namespace mxnet {
 namespace op {
 
@@ -37,8 +40,8 @@ bool SupportMKLDNNConv(const ConvolutionParam& params, const NDArray &input) {
   return input.dtype() == mshadow::kFloat32 && input.shape().ndim() == 4;
 }
 
-static mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(
-    const ConvolutionParam& param, bool is_train, const NDArray &data,
+mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(
+    const ConvolutionParam& param, const bool is_train, const NDArray &data,
     const NDArray &weights, const NDArray *bias, const NDArray &output) {
   auto prop = is_train ? mkldnn::prop_kind::forward_training : mkldnn::prop_kind::forward_scoring;
   auto data_md = GetMemDesc(data);
@@ -162,73 +165,51 @@ static mkldnn::convolution_backward_weights::primitive_desc GetConvBwdWeights(
   }
 }
 
-class MKLDNNConvForward {
-  std::shared_ptr<mkldnn::convolution_forward> fwd;
-  std::shared_ptr<mkldnn::memory> data;
-  std::shared_ptr<mkldnn::memory> weight;
-  std::shared_ptr<mkldnn::memory> bias;
-  std::shared_ptr<mkldnn::memory> out;
+void MKLDNNConvForward::SetNewMem(const mkldnn::memory &data,
+                                  const mkldnn::memory &weight,
+                                  const mkldnn::memory *bias,
+                                  const mkldnn::memory &output) {
+  if (this->data_ == nullptr)
+    this->data_ = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+            fwd_pd.src_primitive_desc(), data.get_data_handle()));
+  else
+    this->data_->set_data_handle(data.get_data_handle());
 
- public:
-  mkldnn::convolution_forward::primitive_desc fwd_pd;
+  if (this->weight_ == nullptr)
+    this->weight_ = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+            fwd_pd.weights_primitive_desc(), weight.get_data_handle()));
+  else
+    this->weight_->set_data_handle(weight.get_data_handle());
 
-  MKLDNNConvForward(const ConvolutionParam& param, bool is_train,
-                    const NDArray &data, const NDArray &weights,
-                    const NDArray *bias, const NDArray &output): fwd_pd(
-                        GetConvFwdImpl(param, is_train, data, weights, bias, output)) {
-  }
-
-  void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &weight,
-                 const mkldnn::memory *bias, const mkldnn::memory &output) {
-    if (this->data == nullptr)
-      this->data = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
-              fwd_pd.src_primitive_desc(), data.get_data_handle()));
-    else
-      this->data->set_data_handle(data.get_data_handle());
+  if (this->out_ == nullptr)
+    this->out_ = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+            fwd_pd.dst_primitive_desc(), output.get_data_handle()));
+  else
+    this->out_->set_data_handle(output.get_data_handle());
 
-    if (this->weight == nullptr)
-      this->weight = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
-              fwd_pd.weights_primitive_desc(), weight.get_data_handle()));
+  if (bias != nullptr) {
+    if (this->bias_ == nullptr)
+      this->bias_ = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+              fwd_pd.bias_primitive_desc(), bias->get_data_handle()));
     else
-      this->weight->set_data_handle(weight.get_data_handle());
-
-    if (this->out == nullptr)
-      this->out = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
-              fwd_pd.dst_primitive_desc(), output.get_data_handle()));
-    else
-      this->out->set_data_handle(output.get_data_handle());
-
-    if (bias != nullptr) {
-      if (this->bias == nullptr)
-        this->bias = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
-                fwd_pd.bias_primitive_desc(), bias->get_data_handle()));
-      else
-        this->bias->set_data_handle(bias->get_data_handle());
-      if (this->fwd == nullptr)
-        this->fwd = std::shared_ptr<mkldnn::convolution_forward>(
-            new mkldnn::convolution_forward(fwd_pd, mkldnn::primitive::at(*this->data),
-                                            mkldnn::primitive::at(*this->weight),
-                                            mkldnn::primitive::at(*this->bias),
-                                            *this->out));
-    } else if (this->fwd == nullptr) {
-      this->fwd = std::shared_ptr<mkldnn::convolution_forward>(
-          new mkldnn::convolution_forward(fwd_pd, mkldnn::primitive::at(*this->data),
-                                          mkldnn::primitive::at(*this->weight),
-                                          *this->out));
-    }
+      this->bias_->set_data_handle(bias->get_data_handle());
+    if (this->fwd_ == nullptr)
+      this->fwd_ = std::shared_ptr<mkldnn::convolution_forward>(
+          new mkldnn::convolution_forward(fwd_pd, mkldnn::primitive::at(*this->data_),
+                                          mkldnn::primitive::at(*this->weight_),
+                                          mkldnn::primitive::at(*this->bias_),
+                                          *this->out_));
+  } else if (this->fwd_ == nullptr) {
+    this->fwd_ = std::shared_ptr<mkldnn::convolution_forward>(
+        new mkldnn::convolution_forward(fwd_pd, mkldnn::primitive::at(*this->data_),
+                                        mkldnn::primitive::at(*this->weight_),
+                                        *this->out_));
   }
+}
 
-  const mkldnn::convolution_forward &GetFwd() const {
-    return *fwd;
-  }
-};
-
-typedef ParamOpSign<ConvolutionParam> MKLDNNConvSignature;
-
-static inline MKLDNNConvForward &GetConvFwd(
-    const nnvm::NodeAttrs& attrs, bool is_train,
-    const NDArray &data, const NDArray &weights,
-    const NDArray *bias, const NDArray &output) {
+MKLDNNConvForward &GetConvFwd(const nnvm::NodeAttrs& attrs, const bool is_train,
+                              const NDArray &data, const NDArray &weights,
+                              const NDArray *bias, const NDArray &output) {
 #if DMLC_CXX11_THREAD_LOCAL
   static thread_local std::unordered_map<MKLDNNConvSignature, MKLDNNConvForward, OpHash> fwds;
 #else
diff --git a/src/operator/nn/mkldnn/mkldnn_pooling-inl.h b/src/operator/nn/mkldnn/mkldnn_pooling-inl.h
index 4b6235e..691e1d3 100644
--- a/src/operator/nn/mkldnn/mkldnn_pooling-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_pooling-inl.h
@@ -119,6 +119,10 @@ void MKLDNNPoolingGradCompute(const OpContext &ctx, const PoolingParam &param,
                               const NDArray &out_grad, const NDArray &in_data,
                               const NDArray *workspace, const OpReqType req,
                               const NDArray &in_grad);
+MKLDNNPoolingFwd &GetPoolingFwd(const PoolingParam &param,
+                                const bool is_train,
+                                const NDArray &data,
+                                const NDArray &output);
 }  // namespace op
 }  // namespace mxnet
 #endif  // MXNET_USE_MKLDNN == 1
diff --git a/src/operator/nn/pooling-inl.h b/src/operator/nn/pooling-inl.h
index a4770b4..9c7b1af 100644
--- a/src/operator/nn/pooling-inl.h
+++ b/src/operator/nn/pooling-inl.h
@@ -41,6 +41,8 @@
 namespace mxnet {
 namespace op {
 
+void PoolingParamParser(nnvm::NodeAttrs *attrs);
+
 struct PoolingParam : public dmlc::Parameter<PoolingParam> {
   TShape kernel;
   TShape stride;
diff --git a/src/operator/nn/pooling.cc b/src/operator/nn/pooling.cc
index 3ff94da..3200a51 100644
--- a/src/operator/nn/pooling.cc
+++ b/src/operator/nn/pooling.cc
@@ -35,7 +35,7 @@
 namespace mxnet {
 namespace op {
 
-static void PoolingParamParser(nnvm::NodeAttrs *attrs) {
+void PoolingParamParser(nnvm::NodeAttrs *attrs) {
   using namespace mshadow;
   PoolingParam param;
   param.Init(attrs->dict);
diff --git a/src/operator/quantization/dequantize.cc b/src/operator/quantization/dequantize.cc
index 92b808d..bbd7941 100644
--- a/src/operator/quantization/dequantize.cc
+++ b/src/operator/quantization/dequantize.cc
@@ -23,11 +23,31 @@
  * \brief
  */
 #include "./dequantize-inl.h"
+#if MXNET_USE_MKLDNN == 1
+#include "./mkldnn/mkldnn_dequantize-inl.h"
+#endif
 
 namespace mxnet {
 namespace op {
 DMLC_REGISTER_PARAMETER(DequantizeParam);
 
+bool DequantizeStorageType(const nnvm::NodeAttrs& attrs,
+                           const int dev_mask,
+                           DispatchMode* dispatch_mode,
+                           std::vector<int> *in_attrs,
+                           std::vector<int> *out_attrs) {
+  *dispatch_mode = DispatchMode::kFCompute;
+#if MXNET_USE_MKLDNN == 1
+  if (dev_mask == mshadow::cpu::kDevMask) {
+    *dispatch_mode = DispatchMode::kFComputeEx;
+  }
+#endif
+  (*out_attrs)[0] = kDefaultStorage;
+  (*out_attrs)[1] = kDefaultStorage;
+  (*out_attrs)[2] = kDefaultStorage;
+  return true;
+}
+
 NNVM_REGISTER_OP(_contrib_dequantize)
 .describe(R"code(Dequantize the input tensor into a float tensor.
 min_range and max_range are scalar floats that specify the range for
@@ -50,6 +70,10 @@ by keep zero centered for the quantized value:
 .set_num_outputs(1)
 .set_attr<nnvm::FInferShape>("FInferShape", DequantizeShape)
 .set_attr<nnvm::FInferType>("FInferType", DequantizeType)
+.set_attr<FInferStorageType>("FInferStorageType", DequantizeStorageType)
+#if MXNET_USE_MKLDNN == 1
+.set_attr<FComputeEx>("FComputeEx<cpu>", MKLDNNDequantizeCompute)
+#endif
 .set_attr<FCompute>("FCompute<cpu>", DequantizeCompute<cpu>)
 .add_argument("data", "NDArray-or-Symbol", "A ndarray/symbol of type `uint8`")
 .add_argument("min_range", "NDArray-or-Symbol", "The minimum scalar value "
diff --git a/src/operator/quantization/mkldnn/mkldnn_dequantize-inl.h b/src/operator/quantization/mkldnn/mkldnn_dequantize-inl.h
new file mode 100644
index 0000000..89c3c19
--- /dev/null
+++ b/src/operator/quantization/mkldnn/mkldnn_dequantize-inl.h
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file mkldnn_dequantize-inl.h
+ * \author Wenting Jiang, Xinyu Chen
+ * \brief MKLDNN implementation of _contrib_dequantize (int8/uint8 -> float32).
+ */
+
+#ifndef MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_DEQUANTIZE_INL_H_
+#define MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_DEQUANTIZE_INL_H_
+#if MXNET_USE_MKLDNN == 1
+#include <string>
+#include <algorithm>
+#include <vector>
+#include "../dequantize-inl.h"
+#include "../../nn/mkldnn/mkldnn_base-inl.h"
+
+namespace mxnet {
+namespace op {
+
+/*!
+ * \brief Dequantize inputs[0] into outputs[0] with a single MKLDNN reorder
+ *        that applies one per-tensor output scale.
+ *
+ * inputs[1] and inputs[2] are read as the scalar float min/max of the real
+ * range; each quantized value is multiplied by real_range / quantized_range.
+ * \tparam SrcType quantized element type (uint8_t or int8_t)
+ * \tparam DstType dequantized element type (float)
+ */
+template<typename SrcType, typename DstType>
+static void MKLDNNDequantizeComputeKer(const std::vector<NDArray> &inputs,
+                                       const std::vector<NDArray> &outputs,
+                                       const std::vector<OpReqType> &req) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  using red::limits::MaxValue;
+  using red::limits::MinValue;
+  // Largest magnitude of the recorded float range; same rule for both dtypes.
+  const float real_range = MaxAbs(*inputs[1].data().dptr<DstType>(),
+                                  *inputs[2].data().dptr<DstType>());
+  float quantized_range = 0.0;
+  if (inputs[0].dtype() == mshadow::kUint8) {
+    // uint8 spans [0, 255]; its largest magnitude (255) maps to real_range.
+    quantized_range = MaxAbs(MaxValue<SrcType>(), MinValue<SrcType>());
+  } else if (inputs[0].dtype() == mshadow::kInt8) {
+    // int8 uses its smaller magnitude (127), i.e. the symmetric range [-127, 127].
+    quantized_range = MinAbs(MaxValue<SrcType>(), MinValue<SrcType>());
+  } else {
+    LOG(FATAL) << "mkldnn dequantize op only supports int8 and uint8 as input type";
+  }
+  // One common scale for the whole tensor (mask 0), rounded to nearest.
+  float scale = real_range / quantized_range;
+  primitive_attr attr;
+  const int mask = 0;
+  std::vector<float> scales = {scale};
+  attr.set_output_scales(mask, scales);
+  attr.set_int_output_round_mode(round_nearest);
+  mkldnn::engine cpu_engine = mxnet::CpuEngine::Get()->get_engine();
+
+  // A view of an MKLDNN-layout array is converted to the default layout
+  // before its MKLDNN memory is taken.
+  NDArray in_buffer = inputs[0];
+  if (inputs[0].IsView() && inputs[0].IsMKLDNNData())
+    in_buffer = inputs[0].Reorder2Default();
+
+  auto i_mem = in_buffer.GetMKLDNNData();
+  auto i_mpd = i_mem->get_primitive_desc();
+  auto i_desc = i_mpd.desc();
+  size_t i_ndim = in_buffer.shape().ndim();
+  mkldnn::memory::dims i_dims = mkldnn::memory::dims(i_ndim);
+  for (size_t i = 0; i < i_ndim; i++) {
+    i_dims[i] = static_cast<int>(in_buffer.shape()[i]);
+  }
+  // The output keeps the input's dims and memory format; only the element
+  // type changes, so the reorder does both the conversion and the scaling.
+  mkldnn::memory::format i_fmt = static_cast<mkldnn::memory::format>(i_desc.data.format);
+  auto o_desc = mkldnn::memory::desc(i_dims,
+                                    (mkldnn::memory::data_type)data_type_enum<DstType>::type,
+                                    i_fmt);
+  auto o_mpd = memory::primitive_desc(o_desc, cpu_engine);
+  auto reorder_pd  = reorder::primitive_desc(i_mpd, o_mpd, attr);
+  auto o_mem = CreateMKLDNNMem(outputs[0], o_mpd, req[0]);
+  MKLDNNStream::Get()->RegisterPrim(mkldnn::reorder(reorder_pd, *i_mem, *o_mem.second));
+  CommitOutput(outputs[0], o_mem);
+  MKLDNNStream::Get()->Submit();
+}
+
+/*!
+ * \brief FComputeEx<cpu> implementation of _contrib_dequantize: dispatches on
+ *        the quantized input dtype (uint8 or int8) and produces float32.
+ */
+static void MKLDNNDequantizeCompute(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
+                                    const std::vector<NDArray> &inputs,
+                                    const std::vector<OpReqType> &req,
+                                    const std::vector<NDArray> &outputs) {
+  if (inputs[0].dtype() == mshadow::kUint8) {
+    MKLDNNDequantizeComputeKer<uint8_t, float>(inputs, outputs, req);
+  } else if (inputs[0].dtype() == mshadow::kInt8) {
+    MKLDNNDequantizeComputeKer<int8_t, float>(inputs, outputs, req);
+  } else {
+    LOG(FATAL) << "mkldnn dequantize op only supports int8 and uint8 as input type";
+  }
+}
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_USE_MKLDNN == 1
+#endif  // MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_DEQUANTIZE_INL_H_
diff --git a/src/operator/quantization/mkldnn/mkldnn_quantize-inl.h b/src/operator/quantization/mkldnn/mkldnn_quantize-inl.h
new file mode 100644
index 0000000..f770931
--- /dev/null
+++ b/src/operator/quantization/mkldnn/mkldnn_quantize-inl.h
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file mkldnn_quantize-inl.h
+ * \brief MKLDNN implementation of _contrib_quantize (float32 -> int8/uint8).
+ * \author Wenting Jiang, Xinyu Chen
+ */
+
+#ifndef MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_QUANTIZE_INL_H_
+#define MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_QUANTIZE_INL_H_
+#if MXNET_USE_MKLDNN == 1
+#include <string>
+#include <algorithm>
+#include <vector>
+#include "../quantize-inl.h"
+#include "../../nn/mkldnn/mkldnn_base-inl.h"
+
+namespace mxnet {
+namespace op {
+
+/*!
+ * \brief Quantize the float tensor inputs[0] into outputs[0] with one MKLDNN
+ *        reorder scaled by quantized_range / real_range.
+ *
+ * inputs[1]/inputs[2] are read as the scalar float min/max of the data;
+ * outputs[1]/outputs[2] receive the float range the quantized values cover:
+ * the input range unchanged for uint8, +/- real_range for int8.
+ * NOTE(review): uint8 scales by the largest magnitude yet reports the input
+ * range verbatim, so it presumably expects min_range >= 0 -- confirm.
+ * \tparam SrcType input element type (the min/max are read as float)
+ * \tparam DstType quantized element type (uint8_t or int8_t)
+ */
+template<typename SrcType, typename DstType>
+static void MKLDNNQuantizeComputeKer(const std::vector<NDArray>& inputs,
+                                     const std::vector<NDArray>& outputs,
+                                     const QuantizeParam& param,
+                                     const std::vector<OpReqType> &req) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  using red::limits::MaxValue;
+  using red::limits::MinValue;
+  const float data_min = *inputs[1].data().dptr<float>();
+  const float data_max = *inputs[2].data().dptr<float>();
+  float *range_min = outputs[1].data().dptr<float>();
+  float *range_max = outputs[2].data().dptr<float>();
+  float real_range = 0.0;
+  float quantized_range = 0.0;
+  switch (param.out_type) {
+    case mshadow::kUint8:
+      real_range = MaxAbs(data_min, data_max);
+      quantized_range = MaxAbs(MaxValue<DstType>(), MinValue<DstType>());
+      *range_min = data_min;
+      *range_max = data_max;
+      break;
+    case mshadow::kInt8:
+      real_range = MaxAbs(data_min, data_max);
+      quantized_range = MinAbs(MaxValue<DstType>(), MinValue<DstType>());
+      *range_min = -real_range;
+      *range_max = real_range;
+      break;
+    default:
+      LOG(FATAL) << "mkldnn quantize op only supports int8 and uint8 as output type";
+  }
+  // A single scale shared by every element (mask 0), rounded to nearest.
+  const int common_mask = 0;
+  std::vector<float> scales(1, quantized_range / real_range);
+  primitive_attr attr;
+  attr.set_output_scales(common_mask, scales);
+  attr.set_int_output_round_mode(round_nearest);
+  mkldnn::engine cpu_engine = mxnet::CpuEngine::Get()->get_engine();
+
+  // Views of MKLDNN-layout arrays are brought back to the default layout.
+  const NDArray &src = inputs[0];
+  NDArray in_buffer = (src.IsView() && src.IsMKLDNNData()) ? src.Reorder2Default() : src;
+
+  auto i_mem = in_buffer.GetMKLDNNData();
+  auto i_mpd = i_mem->get_primitive_desc();
+  const TShape &in_shape = in_buffer.shape();
+  mkldnn::memory::dims i_dims(in_shape.ndim());
+  for (size_t axis = 0; axis < in_shape.ndim(); ++axis)
+    i_dims[axis] = static_cast<int>(in_shape[axis]);
+  // Same dims and memory format as the input; only the element type changes.
+  auto i_fmt = static_cast<mkldnn::memory::format>(i_mpd.desc().data.format);
+  auto o_dtype = static_cast<mkldnn::memory::data_type>(data_type_enum<DstType>::type);
+  auto o_mpd = memory::primitive_desc(mkldnn::memory::desc(i_dims, o_dtype, i_fmt),
+                                      cpu_engine);
+  auto reorder_pd = reorder::primitive_desc(i_mpd, o_mpd, attr);
+  auto o_mem = CreateMKLDNNMem(outputs[0], o_mpd, req[0]);
+  MKLDNNStream *stream = MKLDNNStream::Get();
+  stream->RegisterPrim(mkldnn::reorder(reorder_pd, *i_mem, *o_mem.second));
+  CommitOutput(outputs[0], o_mem);
+  stream->Submit();
+}
+
+/*!
+ * \brief FComputeEx<cpu> implementation of _contrib_quantize: instantiates the
+ *        kernel for the requested out_type (uint8 or int8).
+ */
+static void MKLDNNQuantizeCompute(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
+                                  const std::vector<NDArray> &inputs,
+                                  const std::vector<OpReqType> &req,
+                                  const std::vector<NDArray> &outputs) {
+  const QuantizeParam& param = nnvm::get<QuantizeParam>(attrs.parsed);
+  switch (param.out_type) {
+    case mshadow::kUint8:
+      MKLDNNQuantizeComputeKer<float, uint8_t>(inputs, outputs, param, req);
+      break;
+    case mshadow::kInt8:
+      MKLDNNQuantizeComputeKer<float, int8_t>(inputs, outputs, param, req);
+      break;
+    default:
+      LOG(FATAL) << "mkldnn quantize op only supports int8 and uint8 as output type";
+  }
+}
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_USE_MKLDNN == 1
+#endif  // MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_QUANTIZE_INL_H_
diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc
new file mode 100644
index 0000000..fa6a32a
--- /dev/null
+++ b/src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file mkldnn_quantized_conv.cc
+ * \brief MKLDNN FComputeEx<cpu> kernel for _contrib_quantized_conv
+ *        (uint8 data, int8 weight, int32 accumulation).
+ * \author Wenting Jiang, Xinyu Chen
+*/
+
+#if MXNET_USE_MKLDNN == 1
+#include "../../nn/mkldnn/mkldnn_base-inl.h"
+#include "../../nn/mkldnn/mkldnn_convolution-inl.h"
+#include "../../nn/convolution-inl.h"
+#include "../quantization_utils.h"
+#include "../../tensor/matrix_op-inl.h"
+#include "../../elemwise_op_common.h"
+namespace mxnet {
+namespace op {
+
+/*!
+ * \brief Run the quantized convolution through the MKLDNN forward primitive
+ *        returned by GetConvFwd, then compute the output's float range.
+ *
+ * in_data is data, weight[, bias] followed by the min/max scalars of each
+ * (6 entries without bias, 9 with bias); out_data is the convolution output
+ * followed by its min and max.
+ */
+static void MKLDNNQuantizedConvForward(const nnvm::NodeAttrs& attrs,
+                                       const OpContext &ctx,
+                                       const std::vector<NDArray> &in_data,
+                                       const std::vector<OpReqType> &req,
+                                       const std::vector<NDArray> &out_data) {
+  CHECK_EQ(in_data[0].dtype(), mshadow::kUint8)
+    << "mkldnn_quantized_conv op only supports uint8 as input type";
+  TmpMemMgr::Get()->Init(ctx.requested[conv::kTempSpace]);
+  const ConvolutionParam& param = nnvm::get<ConvolutionParam>(attrs.parsed);
+  NDArray weight = in_data[conv::kWeight];
+  MKLDNNConvForward &fwd = GetConvFwd(attrs, ctx.is_train,
+      in_data[conv::kData], weight,
+      param.no_bias ? nullptr : &in_data[conv::kBias], out_data[conv::kOut]);
+
+  auto data_mem = in_data[conv::kData].GetMKLDNNDataReorder(fwd.fwd_pd.src_primitive_desc());
+  const mkldnn::memory *weight_mem;
+  // For inference, we want to reorder the weight array so we don't need to
+  // reorder data every time.
+  if (weight.IsDefaultData()) {
+    weight_mem = GetWeights(weight, fwd.fwd_pd.weights_primitive_desc(), param.num_group);
+    // We also need to modify the layout on the original weight array. The
+    // data conversion happens after the weight array is used.
+    weight.MKLDNNDataReorderAsync(fwd.fwd_pd.weights_primitive_desc());
+  } else {
+    weight_mem = weight.GetMKLDNNData();
+    CHECK(weight_mem->get_primitive_desc() == fwd.fwd_pd.weights_primitive_desc())
+      << "mkldnn_quantized_conv: weight array is not in the layout the primitive expects";
+  }
+  auto out_mem = CreateMKLDNNMem(out_data[conv::kOut], fwd.fwd_pd.dst_primitive_desc(),
+                                 req[conv::kOut]);
+  const mkldnn::memory *bias_mem = nullptr;
+  if (!param.no_bias)
+    bias_mem = in_data[conv::kBias].GetMKLDNNDataReorder(fwd.fwd_pd.bias_primitive_desc());
+  fwd.SetNewMem(*data_mem, *weight_mem, bias_mem, *out_mem.second);
+  MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd());
+
+  CommitOutput(out_data[conv::kOut], out_mem);
+  MKLDNNStream::Get()->Submit();
+
+  // The min/max scalars start right after the tensor inputs; the next four
+  // (presumably min/max of data, then of weight -- confirm against
+  // FListInputNames) give the float range of the int32 output.
+  // NOTE(review): confirm QuantizationRangeForMultiplicationStruct accounts
+  // for the data being uint8 (not int8) on this path.
+  Stream<cpu> *s = ctx.get_stream<cpu>();
+  const size_t min_data_idx = param.no_bias ? 2 : 3;
+  mxnet_op::Kernel<QuantizationRangeForMultiplicationStruct, cpu>::Launch(s, 1,
+           out_data[1].data().dptr<float>(), out_data[2].data().dptr<float>(),
+           in_data[min_data_idx].data().dptr<float>(),
+           in_data[min_data_idx+1].data().dptr<float>(),
+           in_data[min_data_idx+2].data().dptr<float>(),
+           in_data[min_data_idx+3].data().dptr<float>());
+}
+
+NNVM_REGISTER_OP(_contrib_quantized_conv)
+.set_attr<FComputeEx>("FComputeEx<cpu>", MKLDNNQuantizedConvForward);
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_USE_MKLDNN == 1
diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_pooling.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_pooling.cc
new file mode 100644
index 0000000..83177ad
--- /dev/null
+++ b/src/operator/quantization/mkldnn/mkldnn_quantized_pooling.cc
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file mkldnn_quantized_pooling.cc
+ * \brief MKLDNN FComputeEx<cpu> kernel for _contrib_quantized_pooling
+ * \author Tao Lv, Xinyu Chen
+*/
+
+#if MXNET_USE_MKLDNN == 1
+
+#include "../../nn/mkldnn/mkldnn_pooling-inl.h"
+
+namespace mxnet {
+namespace op {
+
+/*!
+ * \brief Pool uint8/int8 data with the MKLDNN pooling forward object from
+ *        GetPoolingFwd and pass the input's min/max through to the output.
+ */
+static void MKLDNNQuantizedPoolingForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
+                                          const std::vector<NDArray> &in_data,
+                                          const std::vector<OpReqType> &req,
+                                          const std::vector<NDArray> &out_data) {
+  CHECK(in_data[0].dtype() == mshadow::kUint8
+    || in_data[0].dtype() == mshadow::kInt8)
+    << "mkldnn_quantized_pooling op only supports uint8 and int8 as input type";
+  const PoolingParam& param = nnvm::get<PoolingParam>(attrs.parsed);
+  // Bind by reference instead of copying the forward object on every call
+  // (assumes GetPoolingFwd returns a reference to a cached object, as
+  // GetConvFwd does for convolution -- confirm).
+  auto &fwd = GetPoolingFwd(param, ctx.is_train, in_data[0], out_data[0]);
+  fwd.SetDataHandle(in_data[0], out_data[0]);
+  fwd.Execute();
+  // The output is reported with the same float range as the input.
+  out_data[1].data().dptr<float>()[0] = in_data[1].data().dptr<float>()[0];
+  out_data[2].data().dptr<float>()[0] = in_data[2].data().dptr<float>()[0];
+}
+
+NNVM_REGISTER_OP(_contrib_quantized_pooling)
+.set_attr<FComputeEx>("FComputeEx<cpu>", MKLDNNQuantizedPoolingForward);
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_USE_MKLDNN == 1
diff --git a/src/operator/quantization/mkldnn/mkldnn_requantize-inl.h b/src/operator/quantization/mkldnn/mkldnn_requantize-inl.h
new file mode 100644
index 0000000..409c53d
--- /dev/null
+++ b/src/operator/quantization/mkldnn/mkldnn_requantize-inl.h
@@ -0,0 +1,180 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file mkldnn_requantize-inl.h
+ * \brief MKLDNN implementation of _contrib_requantize (int32 -> int8).
+ * \author Jin Huang, Xinyu Chen
+ */
+
+#ifndef MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_REQUANTIZE_INL_H_
+#define MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_REQUANTIZE_INL_H_
+#if MXNET_USE_MKLDNN == 1
+#include <string>
+#include <algorithm>
+#include <vector>
+#include "../requantize-inl.h"
+#include "../../nn/mkldnn/mkldnn_base-inl.h"
+
+namespace mxnet {
+namespace op {
+
+/*!
+ * \brief Requantize the int32 tensor inputs[0] into int8 outputs[0] with a
+ *        single MKLDNN reorder.
+ *
+ * The int32 range is max(|inputs[1]|, |inputs[2]|) and the int8 output
+ * covers [-real_range, real_range], which is also written to outputs[1] and
+ * outputs[2]; the reorder scale converts from the first to the second.
+ * \param real_range float range the int8 output represents
+ */
+static void MKLDNNRequantizeForwardKer(const nnvm::NodeAttrs& attrs,
+                                       const OpContext& ctx,
+                                       const std::vector<NDArray>& inputs,
+                                       const std::vector<OpReqType>& req,
+                                       const std::vector<NDArray>& outputs,
+                                       const float real_range) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  using red::limits::MaxValue;
+  using red::limits::MinValue;
+  typedef int32_t SrcDType;
+  typedef int8_t  DstDType;
+  // check shapes
+  size_t i_dim = inputs[0].shape().ndim();
+  size_t o_dim = outputs[0].shape().ndim();
+  CHECK_EQ(i_dim, o_dim);
+  // first_scale: float value of one int32 step.
+  float first_quantized_range = MinAbs(MinValue<SrcDType>(),
+                                       MaxValue<SrcDType>());
+  float first_real_range = MaxAbs(*inputs[1].data().dptr<float>(),
+                                  *inputs[2].data().dptr<float>());
+  float first_scale = first_real_range / first_quantized_range;
+  // second_scale: int8 steps per float unit of the requested output range.
+  float second_real_range = real_range;
+  float second_quantized_range = MinAbs(MaxValue<DstDType>(),
+                                        MinValue<DstDType>());
+  float second_scale = second_quantized_range / second_real_range;
+  float scale = first_scale * second_scale;
+  *outputs[1].data().dptr<float>() = -second_real_range;
+  *outputs[2].data().dptr<float>() = second_real_range;
+  // One common scale for the whole tensor (mask 0), rounded to nearest.
+  primitive_attr attr;
+  const int mask = 0;
+  std::vector<float> scales = {scale};
+  attr.set_output_scales(mask, scales);
+  attr.set_int_output_round_mode(round_nearest);
+  mkldnn::engine cpu_engine = mxnet::CpuEngine::Get()->get_engine();
+
+  // A view of an MKLDNN-layout array is converted to the default layout
+  // before its MKLDNN memory is taken.
+  NDArray in_buffer = inputs[0];
+  if (inputs[0].IsView() && inputs[0].IsMKLDNNData())
+    in_buffer = inputs[0].Reorder2Default();
+
+  auto i_mem = in_buffer.GetMKLDNNData();
+  auto i_mpd = i_mem->get_primitive_desc();
+  auto i_desc = i_mpd.desc();
+  mkldnn::memory::format i_fmt = static_cast<mkldnn::memory::format>(i_desc.data.format);
+  mkldnn::memory::dims i_dims = mkldnn::memory::dims(i_dim);
+  for (size_t i = 0; i < i_dim; i++) {
+    i_dims[i] = static_cast<int>(in_buffer.shape()[i]);
+  }
+  auto o_desc = mkldnn::memory::desc(i_dims,
+                                    (mkldnn::memory::data_type)data_type_enum<DstDType>::type,
+                                    i_fmt);
+  auto o_mpd = memory::primitive_desc(o_desc, cpu_engine);
+  auto reorder_pd  = reorder::primitive_desc(i_mpd, o_mpd, attr);
+  auto o_mem = CreateMKLDNNMem(outputs[0], o_mpd, req[0]);
+  MKLDNNStream::Get()->RegisterPrim(mkldnn::reorder(reorder_pd, *i_mem, *o_mem.second));
+  CommitOutput(outputs[0], o_mem);
+  MKLDNNStream::Get()->Submit();
+}
+
+/*!
+ * \brief FComputeEx<cpu> implementation of _contrib_requantize.
+ *
+ * If RequantizeParam carries both min_calib_range and max_calib_range, the
+ * output range is their largest magnitude; otherwise the actual min/max of
+ * the int32 data are reduced, converted to float and used instead.
+ */
+static void MKLDNNRequantizeForward(const nnvm::NodeAttrs& attrs,
+                                    const OpContext& ctx,
+                                    const std::vector<NDArray>& inputs,
+                                    const std::vector<OpReqType>& req,
+                                    const std::vector<NDArray>& outputs) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  typedef int32_t SrcDType;
+  Stream<cpu> *s = ctx.get_stream<cpu>();
+  const RequantizeParam& param = nnvm::get<RequantizeParam>(attrs.parsed);
+  float real_range;
+  if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
+    // Model is calibrated: use the recorded range.
+    real_range = MaxAbs(param.min_calib_range.value(), param.max_calib_range.value());
+  } else {
+    // Model is not calibrated: reduce the int32 data to its actual min/max.
+    // Reorder2Default() returns an NDArray (presumably a default-layout copy
+    // when the data is in an MKLDNN layout), so call it once per input.
+    NDArray in_default = inputs[0].Reorder2Default();
+    NDArray min_default = inputs[1].Reorder2Default();
+    NDArray max_default = inputs[2].Reorder2Default();
+    TShape src_shape, dst_shape;
+    const size_t actual_float_size = sizeof(float);
+    const size_t actual_quantized_size = sizeof(SrcDType);
+    const size_t temp_reduce_size = ConfigReduce<cpu, SrcDType>(s,
+                         inputs[0].shape(), TShape({1}), &src_shape, &dst_shape);
+    // Scratch layout: [min float | max float | min int32 | max int32 | reduce workspace].
+    Tensor<cpu, 1, char> temp_space =
+      ctx.requested[0].get_space_typed<cpu, 1, char>(
+      Shape1(2*actual_float_size+2*actual_quantized_size+temp_reduce_size), s);
+    Tensor<cpu, 1, float> actual_min_float(
+                 reinterpret_cast<float*>(temp_space.dptr_), Shape1(1), s);
+    Tensor<cpu, 1, float> actual_max_float(
+                 reinterpret_cast<float*>(temp_space.dptr_) + 1, Shape1(1), s);
+    const int dev_id = ctx.run_ctx.ctx.dev_id;
+    SrcDType *quantized_base =
+        reinterpret_cast<SrcDType*>(temp_space.dptr_ + 2*actual_float_size);
+    TBlob actual_min_quantized(quantized_base, Shape1(1), cpu::kDevMask, dev_id);
+    TBlob actual_max_quantized(quantized_base + 1, Shape1(1), cpu::kDevMask, dev_id);
+    Tensor<cpu, 1, char> workspace(
+            temp_space.dptr_+2*actual_float_size+2*actual_quantized_size,
+            Shape1(temp_reduce_size), s);
+    broadcast::Reduce<red::minimum, 2, SrcDType, mshadow::op::identity>(
+        s, actual_min_quantized.reshape(dst_shape), kWriteTo,
+        workspace, in_default.data().reshape(src_shape));
+    Kernel<QuantizedToFloatStruct, cpu>::Launch(s, 1,
+        actual_min_float.dptr_, actual_min_quantized.dptr<SrcDType>(),
+        min_default.data().dptr<float>(), max_default.data().dptr<float>());
+    broadcast::Reduce<red::maximum, 2, SrcDType, mshadow::op::identity>(
+        s, actual_max_quantized.reshape(dst_shape), kWriteTo,
+        workspace, in_default.data().reshape(src_shape));
+    Kernel<QuantizedToFloatStruct, cpu>::Launch(s, 1,
+        actual_max_float.dptr_, actual_max_quantized.dptr<SrcDType>(),
+        min_default.data().dptr<float>(), max_default.data().dptr<float>());
+    real_range = MaxAbs(*actual_min_float.dptr_, *actual_max_float.dptr_);
+  }
+  MKLDNNRequantizeForwardKer(attrs, ctx, inputs, req, outputs, real_range);
+}
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_USE_MKLDNN == 1
+#endif  // MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_REQUANTIZE_INL_H_
diff --git a/src/operator/quantization/quantize.cc b/src/operator/quantization/quantize.cc
index 32eb952..25fb19d 100644
--- a/src/operator/quantization/quantize.cc
+++ b/src/operator/quantization/quantize.cc
@@ -23,11 +23,31 @@
  * \brief
  */
 #include "./quantize-inl.h"
+#if MXNET_USE_MKLDNN == 1
+#include "./mkldnn/mkldnn_quantize-inl.h"
+#endif
 
 namespace mxnet {
 namespace op {
 DMLC_REGISTER_PARAMETER(QuantizeParam);
 
+bool QuantizeStorageType(const nnvm::NodeAttrs& attrs,
+                         const int dev_mask,
+                         DispatchMode* dispatch_mode,
+                         std::vector<int> *in_attrs,
+                         std::vector<int> *out_attrs) {
+  // Quantized data and its min/max outputs are all dense.
+  for (size_t i = 0; i < 3; ++i) (*out_attrs)[i] = kDefaultStorage;
+#if MXNET_USE_MKLDNN == 1
+  // MKLDNN builds run CPU work through FComputeEx, everything else FCompute.
+  const bool use_ex = dev_mask == mshadow::cpu::kDevMask;
+#else
+  const bool use_ex = false;
+#endif
+  *dispatch_mode = use_ex ? DispatchMode::kFComputeEx : DispatchMode::kFCompute;
+  return true;
+}
+
 NNVM_REGISTER_OP(_contrib_quantize)
 .describe(R"code(Quantize a input tensor from float to `out_type`,
 with user-specified `min_range` and `max_range`.
@@ -61,6 +81,10 @@ where
   })
 .set_attr<nnvm::FInferShape>("FInferShape", QuantizeShape)
 .set_attr<nnvm::FInferType>("FInferType", QuantizeType)
+.set_attr<FInferStorageType>("FInferStorageType", QuantizeStorageType)
+#if MXNET_USE_MKLDNN == 1
+.set_attr<FComputeEx>("FComputeEx<cpu>", MKLDNNQuantizeCompute)
+#endif
 .set_attr<FCompute>("FCompute<cpu>", QuantizeCompute<cpu>)
 .add_argument("data", "NDArray-or-Symbol", "A ndarray/symbol of type `float32`")
 .add_argument("min_range", "NDArray-or-Symbol", "The minimum scalar value "
diff --git a/src/operator/quantization/quantize_graph_pass.cc b/src/operator/quantization/quantize_graph_pass.cc
index 5ec745c..5376a0e 100644
--- a/src/operator/quantization/quantize_graph_pass.cc
+++ b/src/operator/quantization/quantize_graph_pass.cc
@@ -99,6 +99,7 @@ Graph QuantizeGraph(Graph &&src) {
   static auto& need_requantize_map = Op::GetAttr<mxnet::FNeedRequantize>("FNeedRequantize");
   auto offline_params = src.GetAttr<std::unordered_set<std::string>>("offline_params");
   auto excluded_nodes = src.GetAttr<std::unordered_set<NodePtr>>("excluded_nodes");
+  auto quantized_dtype = src.GetAttr<std::string>("quantized_dtype");
 
   // mirror_map stores the mapping from the currently visited graph to the newly created quantized
   // graph. Key is the currently visited graph's node pointer, and value is a copied node of the key
@@ -129,7 +130,7 @@ Graph QuantizeGraph(Graph &&src) {
              mirror_node->op()->name != "_contrib_quantize")) {
           NodePtr quantize_node = InsertNode("_contrib_quantize",
             e.node->attrs.name + "_quantize", new_node, mirror_entry);
-          quantize_node->attrs.dict["out_type"] = "int8";
+          quantize_node->attrs.dict["out_type"] = quantized_dtype;
           quantize_node->op()->attr_parser(&(quantize_node->attrs));
 
           NodePtr min_node = InsertNode("min",
@@ -159,7 +160,11 @@ Graph QuantizeGraph(Graph &&src) {
         uint32_t min_index = 1;
         uint32_t max_index = 2;
         if (quantized_op_map.count(e.node->op())) {
-          size_t  num_outputs = e.node->num_outputs();
+          // here we calculate the output number (exclude min/max, in order to
+          // calculate min/max index from mirror node) based on assumption that
+          // there is only 1min and 1max output from mirror node (which is
+          // currently true)
+          size_t  num_outputs = mirror_node->num_outputs() - 2;
           min_index = num_outputs + 2 * e.index;
           max_index = num_outputs + 2 * e.index + 1;
         } else {
@@ -198,12 +203,15 @@ Graph QuantizeGraph(Graph &&src) {
         NodePtr mirror_node = mirror_map.at(e.node.get());
         NodeEntry mirror_entry = NodeEntry{
           mirror_node, e.index, e.version};
-        size_t num_outputs = e.node->num_outputs();
-        uint32_t min_index = num_outputs + 2 * e.index;
-        uint32_t max_index = num_outputs + 2 * e.index + 1;
-
         // if input node is quantized operator, add dequantize node
         if (NeedQuantize(e.node, excluded_nodes)) {
+          // here we calculate the output number (exclude min/max, in order to
+          // calculate min/max index from mirror node) based on assumption that
+          // there is only 1min and 1max output from mirror node (which is
+          // currently true)
+          size_t num_outputs = mirror_node->num_outputs() - 2;
+          uint32_t min_index = num_outputs + 2 * e.index;
+          uint32_t max_index = num_outputs + 2 * e.index + 1;
           NodePtr dequantize_node = CreateNode("_contrib_dequantize",
             e.node->attrs.name + "_dequantize");
           dequantize_node->inputs.emplace_back(mirror_entry);
diff --git a/src/operator/quantization/quantized_conv.cc b/src/operator/quantization/quantized_conv.cc
index d7dc9fe..ed62228 100644
--- a/src/operator/quantization/quantized_conv.cc
+++ b/src/operator/quantization/quantized_conv.cc
@@ -24,6 +24,9 @@
  * \author Ziheng Jiang, Jun Wu
 */
 #include "../nn/convolution-inl.h"
+#if MXNET_USE_MKLDNN == 1
+#include "../nn/mkldnn/mkldnn_ops-inl.h"
+#endif
 
 namespace mxnet {
 namespace op {
@@ -86,12 +89,13 @@ bool QuantizedConvType(const nnvm::NodeAttrs& attrs,
   const ConvolutionParam& param = nnvm::get<ConvolutionParam>(attrs.parsed);
   CHECK_EQ(in_type->size(), param.no_bias? 6U : 9U);
   CHECK_EQ(out_type->size(), 3U);
+#if MXNET_USE_MKLDNN != 1  // with MKLDNN the data is uint8, so don't pin it to int8
   TYPE_ASSIGN_CHECK(*in_type, 0, mshadow::kInt8);
+#endif
   TYPE_ASSIGN_CHECK(*in_type, 1, mshadow::kInt8);
   if (!param.no_bias) {
     TYPE_ASSIGN_CHECK(*in_type, 2, mshadow::kInt8);
   }
-
   const size_t start = param.no_bias? 2 : 3;
   const size_t end = param.no_bias? 6 : 9;
   for (size_t i = start; i < end; ++i) {
@@ -104,6 +108,24 @@ bool QuantizedConvType(const nnvm::NodeAttrs& attrs,
   return true;
 }
 
+bool QuantizedConvStorageType(const nnvm::NodeAttrs& attrs,
+                              const int dev_mask,
+                              DispatchMode* dispatch_mode,
+                              std::vector<int> *in_attrs,
+                              std::vector<int> *out_attrs) {
+  // The int32 output and its min/max scalars (3 outputs) are all dense.
+  for (size_t i = 0; i < 3; ++i) (*out_attrs)[i] = kDefaultStorage;
+
+#if MXNET_USE_MKLDNN == 1
+  // With MKLDNN the CPU kernel is registered as FComputeEx; otherwise FCompute.
+  const bool use_ex = dev_mask == mshadow::cpu::kDevMask;
+#else
+  const bool use_ex = false;
+#endif
+  *dispatch_mode = use_ex ? DispatchMode::kFComputeEx : DispatchMode::kFCompute;
+  return true;
+}
+
 NNVM_REGISTER_OP(_contrib_quantized_conv)
 .describe(R"code(Convolution operator for input, weight and bias data type of int8,
 and accumulates in type int32 for the output. For each argument, two more arguments of type
@@ -119,7 +141,7 @@ and max thresholds representing the threholds for quantizing the float32 output
     return param.no_bias? 6 : 9;
   })
 .set_num_outputs(3)
-.set_attr_parser(ParamParser<ConvolutionParam>)
+.set_attr_parser(ConvolutionParamParser)
 .set_attr<nnvm::FListInputNames>("FListInputNames",
   [](const NodeAttrs& attrs) {
     const ConvolutionParam& param = nnvm::get<ConvolutionParam>(attrs.parsed);
@@ -137,6 +159,7 @@ and max thresholds representing the threholds for quantizing the float32 output
   })
 .set_attr<nnvm::FInferShape>("FInferShape", QuantizedConvShape)
 .set_attr<nnvm::FInferType>("FInferType", QuantizedConvType)
+.set_attr<FInferStorageType>("FInferStorageType", QuantizedConvStorageType)
 .set_attr<FResourceRequest>("FResourceRequest",
   [](const NodeAttrs& attrs) {
     return std::vector<ResourceRequest>(1, ResourceRequest::kTempSpace);
diff --git a/src/operator/quantization/quantized_flatten-inl.h b/src/operator/quantization/quantized_flatten-inl.h
index 95f3661..b7209fd 100644
--- a/src/operator/quantization/quantized_flatten-inl.h
+++ b/src/operator/quantization/quantized_flatten-inl.h
@@ -62,11 +62,21 @@ void QuantizedFlattenCompute(const nnvm::NodeAttrs& attrs,
   using namespace mxnet_op;
   Stream<xpu> *s = ctx.get_stream<xpu>();
 
-  typedef int8_t DstDType;
-  typedef int8_t  SrcDType;
-  Kernel<quantized_flatten, xpu>::Launch(s, outputs[0].Size(),
-    outputs[0].dptr<DstDType>(), outputs[1].dptr<float>(), outputs[2].dptr<float>(),
-    inputs[0].dptr<SrcDType>(), inputs[1].dptr<float>(), inputs[2].dptr<float>());
+  if (inputs[0].type_flag_ == mshadow::kUint8) {
+    typedef uint8_t SrcDType;
+    typedef uint8_t DstDType;
+    Kernel<quantized_flatten, xpu>::Launch(s, outputs[0].Size(),
+      outputs[0].dptr<DstDType>(), outputs[1].dptr<float>(), outputs[2].dptr<float>(),
+      inputs[0].dptr<SrcDType>(), inputs[1].dptr<float>(), inputs[2].dptr<float>());
+  } else if (inputs[0].type_flag_ == mshadow::kInt8) {
+    typedef int8_t SrcDType;
+    typedef int8_t DstDType;
+    Kernel<quantized_flatten, xpu>::Launch(s, outputs[0].Size(),
+      outputs[0].dptr<DstDType>(), outputs[1].dptr<float>(), outputs[2].dptr<float>(),
+      inputs[0].dptr<SrcDType>(), inputs[1].dptr<float>(), inputs[2].dptr<float>());
+  } else {
+    LOG(FATAL) << "quantized_flatten op only supports int8 and uint8 as input and output type";
+  }
 }
 
 inline bool QuantizedFlattenShape(const nnvm::NodeAttrs& attrs,
@@ -96,10 +106,9 @@ inline bool QuantizedFlattenType(const nnvm::NodeAttrs& attrs,
                                  std::vector<int> *out_attrs) {
   CHECK_EQ(in_attrs->size(), 3U);
   CHECK_EQ(out_attrs->size(), 3U);
-  TYPE_ASSIGN_CHECK(*in_attrs, 0, mshadow::kInt8);
   TYPE_ASSIGN_CHECK(*in_attrs, 1, mshadow::kFloat32);
   TYPE_ASSIGN_CHECK(*in_attrs, 2, mshadow::kFloat32);
-  TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kInt8);
+  TYPE_ASSIGN_CHECK(*out_attrs, 0, (*in_attrs)[0]);
   TYPE_ASSIGN_CHECK(*out_attrs, 1, mshadow::kFloat32);
   TYPE_ASSIGN_CHECK(*out_attrs, 2, mshadow::kFloat32);
   return (*in_attrs)[0] != -1;
diff --git a/src/operator/quantization/quantized_pooling.cc b/src/operator/quantization/quantized_pooling.cc
index a3105eb..779e244 100644
--- a/src/operator/quantization/quantized_pooling.cc
+++ b/src/operator/quantization/quantized_pooling.cc
@@ -23,6 +23,9 @@
 */
 #include <mxnet/op_attr_types.h>
 #include "../nn/pooling-inl.h"
+#if MXNET_USE_MKLDNN == 1
+#include "../nn/mkldnn/mkldnn_pooling-inl.h"
+#endif
 
 namespace mxnet {
 namespace op {
@@ -79,8 +82,12 @@ bool QuantizedPoolingType(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(in_type->size(), 3U);
   CHECK_EQ(out_type->size(), 3U);
   if (param.pool_type == pool_enum::kMaxPooling || param.pool_type == pool_enum::kAvgPooling) {
+#if MXNET_USE_MKLDNN  == 1
+    TYPE_ASSIGN_CHECK(*out_type, 0, (*in_type)[0]);
+#else
     TYPE_ASSIGN_CHECK(*in_type, 0, mshadow::kInt8);
     TYPE_ASSIGN_CHECK(*out_type, 0, mshadow::kInt8);
+#endif
   } else {
     LOG(FATAL) << "QuantizedPoolingOp only supports pool_type=max/avg for now";
   }
@@ -91,6 +98,27 @@ bool QuantizedPoolingType(const nnvm::NodeAttrs& attrs,
   return true;
 }
 
+inline static bool QuantizedPoolingStorageType(const nnvm::NodeAttrs &attrs,
+                                               const int dev_mask,
+                                               DispatchMode *dispatch_mode,
+                                               std::vector<int> *in_attrs,
+                                               std::vector<int> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 3U);   // data, min_data, max_data
+  CHECK_EQ(out_attrs->size(), 3U);  // pooled data + min/max thresholds (set_num_outputs(3))
+  *dispatch_mode = DispatchMode::kFCompute;  // default: FCompute dispatch
+#if MXNET_USE_MKLDNN == 1
+  const PoolingParam &param = nnvm::get<PoolingParam>(attrs.parsed);
+  // Use the MKLDNN FComputeEx path only on CPU and when SupportMKLDNNPooling accepts param.
+  if (dev_mask == mshadow::cpu::kDevMask && SupportMKLDNNPooling(param)) {
+    *dispatch_mode = DispatchMode::kFComputeEx;
+  }
+#endif
+  // All outputs use dense default storage.
+  for (size_t i = 0; i < out_attrs->size(); i++)
+    (*out_attrs)[i] = kDefaultStorage;
+  return true;
+}
+
 NNVM_REGISTER_OP(_contrib_quantized_pooling)
 .describe(R"code(Pooling operator for input and output data type of int8.
 The input and output data comes with min and max thresholds for quantizing
@@ -101,7 +129,7 @@ the float32 data into int8.
     This operator only supports `pool_type` of `avg` or `max`.)code" ADD_FILELINE)
 .set_num_inputs(3)
 .set_num_outputs(3)
-.set_attr_parser(ParamParser<PoolingParam>)
+.set_attr_parser(PoolingParamParser)
 .set_attr<nnvm::FListInputNames>("FListInputNames",
   [](const NodeAttrs& attrs) {
     return std::vector<std::string>{"data", "min_data", "max_data"};
@@ -112,6 +140,7 @@ the float32 data into int8.
   })
 .set_attr<nnvm::FInferShape>("FInferShape", QuantizedPoolingShape)
 .set_attr<nnvm::FInferType>("FInferType", QuantizedPoolingType)
+.set_attr<FInferStorageType>("FInferStorageType", QuantizedPoolingStorageType)
 .set_attr<FNeedRequantize>("FNeedRequantize",
   [](const NodeAttrs& attrs) {
     const PoolingParam& param = nnvm::get<PoolingParam>(attrs.parsed);
diff --git a/src/operator/quantization/requantize.cc b/src/operator/quantization/requantize.cc
index 83ea37b..5ce0ff0 100644
--- a/src/operator/quantization/requantize.cc
+++ b/src/operator/quantization/requantize.cc
@@ -24,11 +24,31 @@
  */
 #include "./requantize-inl.h"
 #include "./quantize-inl.h"
+#if MXNET_USE_MKLDNN == 1
+#include "./mkldnn/mkldnn_requantize-inl.h"
+#endif
 
 namespace mxnet {
 namespace op {
 DMLC_REGISTER_PARAMETER(RequantizeParam);
 
+bool RequantizeStorageType(const nnvm::NodeAttrs& attrs,  // _contrib_requantize storage types
+                           const int dev_mask,
+                           DispatchMode* dispatch_mode,
+                           std::vector<int> *in_attrs,
+                           std::vector<int> *out_attrs) {
+  *dispatch_mode = DispatchMode::kFCompute;  // default: FCompute dispatch
+#if MXNET_USE_MKLDNN == 1
+  if (dev_mask == mshadow::cpu::kDevMask) {
+    *dispatch_mode = DispatchMode::kFComputeEx;  // CPU: MKLDNNRequantizeForward
+  }
+#endif
+  CHECK_EQ(out_attrs->size(), 3U);  // same output count as set_num_outputs(3)
+  for (size_t i = 0; i < out_attrs->size(); ++i)
+    (*out_attrs)[i] = kDefaultStorage;  // all outputs use dense default storage
+  return true;
+}
+
 NNVM_REGISTER_OP(_contrib_requantize)
 .describe(R"code(Given data that is quantized in int32 and the corresponding thresholds,
 requantize the data into int8 using min and max thresholds either calculated at runtime
@@ -43,7 +63,12 @@ inference accuracy.
 .set_num_outputs(3)
 .set_attr<nnvm::FInferShape>("FInferShape", QuantizeShape)
 .set_attr<nnvm::FInferType>("FInferType", RequantizeType)
+.set_attr<FInferStorageType>("FInferStorageType", RequantizeStorageType)
+#if MXNET_USE_MKLDNN == 1
+.set_attr<FComputeEx>("FComputeEx<cpu>", MKLDNNRequantizeForward)
+#else
 .set_attr<FCompute>("FCompute<cpu>", RequantizeForward<cpu>)
+#endif
 .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs) {
     const RequantizeParam& param =
       nnvm::get<RequantizeParam>(attrs.parsed);
diff --git a/tests/python/mkl/test_quantization_mkldnn.py b/tests/python/mkl/test_quantization_mkldnn.py
new file mode 100644
index 0000000..290f1a1
--- /dev/null
+++ b/tests/python/mkl/test_quantization_mkldnn.py
@@ -0,0 +1,28 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import os
+import sys
+import mxnet as mx
+
+os.environ['ENABLE_MKLDNN_QUANTIZATION_TEST'] = '1'  # read by is_test_for_mkldnn()
+curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
+sys.path.insert(0, os.path.join(curr_path, '../quantization'))  # locate test_quantization
+from test_quantization import *  # pull the shared quantization tests into this module
+
+if __name__ == '__main__':
+    import nose
+    nose.runmodule()
diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py
index 7b08f46..15e8582 100644
--- a/tests/python/quantization/test_quantization.py
+++ b/tests/python/quantization/test_quantization.py
@@ -18,6 +18,7 @@
 """Some of the tests using CUDNN require a special GPU instruction called dp4a.
 Ref: http://images.nvidia.com/content/pdf/tesla/184457-Tesla-P4-Datasheet-NV-Final-Letter-Web.pdf
 """
+import os
 import mxnet as mx
 import numpy as np
 from mxnet.test_utils import assert_almost_equal, rand_ndarray, rand_shape_nd, same, DummyIter
@@ -25,6 +26,16 @@ from common import with_seed
 from mxnet.module import Module
 from mxnet.io import NDArrayIter
 
+def is_test_for_gpu():  # True when the current context is a GPU
+    return mx.current_context().device_type == 'gpu'
+
+def is_test_for_mkldnn():  # CPU run with ENABLE_MKLDNN_QUANTIZATION_TEST=1
+    return (mx.current_context().device_type == 'cpu'
+            and os.environ.get('ENABLE_MKLDNN_QUANTIZATION_TEST') == '1')
+
+def is_test_for_native_cpu():  # CPU run not in MKLDNN mode (complement on CPU)
+    return (mx.current_context().device_type == 'cpu'
+            and not is_test_for_mkldnn())
 
 @with_seed()
 def test_quantize_float32_to_int8():
@@ -120,187 +131,220 @@ def test_requantize_int32_to_int8():
 
 @with_seed()
 def test_quantized_conv():
-    if mx.current_context().device_type != 'gpu':
-        print('skipped testing quantized_conv on cpu since it is not implemented yet')
-        return
-
-    def check_quantized_conv(data_shape, kernel, num_filter, pad, stride, no_bias):
-        with mx.Context('gpu', 0):
-            # run fp32 conv
-            data = mx.sym.Variable(name='data', shape=data_shape, dtype='float32')
-            conv2d = mx.sym.Convolution(data=data, kernel=kernel, num_filter=num_filter, pad=pad, stride=stride,
-                                        no_bias=no_bias, cudnn_off=False, name='conv2d')
-            arg_shapes, _, _ = conv2d.infer_shape(data=data_shape)
-            arg_names = conv2d.list_arguments()
-            conv_exe_fp32 = conv2d.simple_bind(ctx=mx.current_context(), grad_req='null')
-            conv_exe_fp32.arg_dict[arg_names[0]][:] = mx.nd.random.uniform(low=-127.0, high=127.0,
-                                                                           shape=data_shape).astype('int32')
-            conv_exe_fp32.arg_dict[arg_names[1]][:] = mx.nd.random.uniform(low=-127.0, high=127.0,
-                                                                           shape=arg_shapes[1]).astype('int32')
-            if not no_bias:
-                conv_exe_fp32.arg_dict[arg_names[2]][:] = mx.nd.random.uniform(low=-127.0, high=127.0,
-                                                                               shape=arg_shapes[2]).astype('int32')
-            output = conv_exe_fp32.forward()[0]
-
-            # run quantized conv
-            qdata = mx.sym.Variable(name='qdata', shape=data_shape, dtype='int8')
-            qweight = mx.sym.Variable(name='qweight', dtype='int8')
-            min_data = mx.sym.Variable(name='min_data')
-            max_data = mx.sym.Variable(name='max_data')
-            min_weight = mx.sym.Variable(name='min_weight')
-            max_weight = mx.sym.Variable(name='max_weight')
-            quantized_conv2d = mx.sym.contrib.quantized_conv(data=qdata, weight=qweight, min_data=min_data,
-                                                             max_data=max_data, min_weight=min_weight,
-                                                             max_weight=max_weight, kernel=kernel,
-                                                             num_filter=num_filter, pad=pad, stride=stride,
-                                                             no_bias=no_bias)
-            qarg_names = quantized_conv2d.list_arguments()
-            type_dict = None
-            if not no_bias:
-                type_dict = {qarg_names[2]: 'int8'}
-            conv_exe_int8 = quantized_conv2d.simple_bind(ctx=mx.current_context(), type_dict=type_dict, grad_req='null')
-            conv_exe_int8.arg_dict[qarg_names[0]][:] = conv_exe_fp32.arg_dict[arg_names[0]].astype('int8')
-            conv_exe_int8.arg_dict[qarg_names[1]][:] = conv_exe_fp32.arg_dict[arg_names[1]].astype('int8')
-            quantized_range = 127.0
-            if no_bias:
-                conv_exe_int8.arg_dict[qarg_names[2]][:] = -quantized_range
-                conv_exe_int8.arg_dict[qarg_names[3]][:] = quantized_range
-                conv_exe_int8.arg_dict[qarg_names[4]][:] = -quantized_range
-                conv_exe_int8.arg_dict[qarg_names[5]][:] = quantized_range
-            else:
-                conv_exe_int8.arg_dict[qarg_names[2]][:] = conv_exe_fp32.arg_dict[arg_names[2]].astype('int8')
-                conv_exe_int8.arg_dict[qarg_names[3]][:] = -quantized_range
-                conv_exe_int8.arg_dict[qarg_names[4]][:] = quantized_range
-                conv_exe_int8.arg_dict[qarg_names[5]][:] = -quantized_range
-                conv_exe_int8.arg_dict[qarg_names[6]][:] = quantized_range
-                conv_exe_int8.arg_dict[qarg_names[7]][:] = -quantized_range
-                conv_exe_int8.arg_dict[qarg_names[8]][:] = quantized_range
-            qoutput, min_range, max_range = conv_exe_int8.forward()
-
-            if no_bias:
-                assert_almost_equal(output.asnumpy(), qoutput.asnumpy())
-            else:
-                # with adding bias, accuracy loss should not be greater than one
-                diff = mx.nd.abs(output - qoutput.astype(output.dtype))
-                cond = mx.nd.lesser(2, diff).sum().asscalar()
-                assert cond == 0
-
-    check_quantized_conv((3, 4, 28, 28), (3, 3), 128, (1, 1), (1, 1), True)
-    check_quantized_conv((3, 4, 28, 28), (3, 3), 128, (1, 1), (1, 1), False)
+    def check_quantized_conv(data_shape, kernel, num_filter, pad, stride, no_bias, qdtype):
+        if is_test_for_native_cpu():
+            print('skipped testing quantized_conv for native cpu since it is not supported yet')
+            return
+        elif qdtype == 'int8' and is_test_for_mkldnn():
+            print('skipped testing quantized_conv for mkldnn cpu int8 since it is not supported yet')
+            return
+        elif qdtype == 'uint8' and is_test_for_gpu():
+            print('skipped testing quantized_conv for gpu uint8 since it is not supported yet')
+            return
+
+        # run fp32 conv on integer-valued inputs; this is the reference result
+        data = mx.sym.Variable(name='data', shape=data_shape, dtype='float32')
+        conv2d = mx.sym.Convolution(data=data, kernel=kernel, num_filter=num_filter, pad=pad, stride=stride,
+                                    no_bias=no_bias, cudnn_off=False, name='conv2d')
+        arg_shapes, _, _ = conv2d.infer_shape(data=data_shape)
+        arg_names = conv2d.list_arguments()
+        conv_exe_fp32 = conv2d.simple_bind(ctx=mx.current_context(), grad_req='null')
+        if qdtype == 'uint8':  # uint8 cannot represent negative values
+            data_low = 0.0
+            data_high = 127.0
+        else:
+            data_low = -127.0
+            data_high = 127.0
+        conv_exe_fp32.arg_dict[arg_names[0]][:] = mx.nd.random.uniform(low=data_low, high=data_high,
+                                                                        shape=data_shape).astype('int32')
+        conv_exe_fp32.arg_dict[arg_names[1]][:] = mx.nd.random.uniform(low=-127.0, high=127.0,
+                                                                        shape=arg_shapes[1]).astype('int32')
+        if not no_bias:
+            conv_exe_fp32.arg_dict[arg_names[2]][:] = mx.nd.random.uniform(low=-127.0, high=127.0,
+                                                                            shape=arg_shapes[2]).astype('int32')
+        output = conv_exe_fp32.forward()[0]
+
+        # run quantized conv on the same values: data cast to qdtype, weight/bias to int8
+        qdata = mx.sym.Variable(name='qdata', shape=data_shape, dtype=qdtype)
+        qweight = mx.sym.Variable(name='qweight', dtype='int8')
+        min_data = mx.sym.Variable(name='min_data')
+        max_data = mx.sym.Variable(name='max_data')
+        min_weight = mx.sym.Variable(name='min_weight')
+        max_weight = mx.sym.Variable(name='max_weight')
+        quantized_conv2d = mx.sym.contrib.quantized_conv(data=qdata, weight=qweight, min_data=min_data,
+                                                            max_data=max_data, min_weight=min_weight,
+                                                            max_weight=max_weight, kernel=kernel,
+                                                            num_filter=num_filter, pad=pad, stride=stride,
+                                                            no_bias=no_bias)
+        qarg_names = quantized_conv2d.list_arguments()
+        type_dict = None
+        if not no_bias:
+            type_dict = {qarg_names[2]: 'int8'}
+        conv_exe_int8 = quantized_conv2d.simple_bind(ctx=mx.current_context(), type_dict=type_dict, grad_req='null')
+        conv_exe_int8.arg_dict[qarg_names[0]][:] = conv_exe_fp32.arg_dict[arg_names[0]].astype(qdtype)
+        conv_exe_int8.arg_dict[qarg_names[1]][:] = conv_exe_fp32.arg_dict[arg_names[1]].astype('int8')
+        quantized_range = 127.0  # +/- value used for every min/max threshold argument
+        if no_bias:
+            conv_exe_int8.arg_dict[qarg_names[2]][:] = -quantized_range
+            conv_exe_int8.arg_dict[qarg_names[3]][:] = quantized_range
+            conv_exe_int8.arg_dict[qarg_names[4]][:] = -quantized_range
+            conv_exe_int8.arg_dict[qarg_names[5]][:] = quantized_range
+        else:
+            conv_exe_int8.arg_dict[qarg_names[2]][:] = conv_exe_fp32.arg_dict[arg_names[2]].astype('int8')
+            conv_exe_int8.arg_dict[qarg_names[3]][:] = -quantized_range
+            conv_exe_int8.arg_dict[qarg_names[4]][:] = quantized_range
+            conv_exe_int8.arg_dict[qarg_names[5]][:] = -quantized_range
+            conv_exe_int8.arg_dict[qarg_names[6]][:] = quantized_range
+            conv_exe_int8.arg_dict[qarg_names[7]][:] = -quantized_range
+            conv_exe_int8.arg_dict[qarg_names[8]][:] = quantized_range
+        qoutput, min_range, max_range = conv_exe_int8.forward()
+
+        if no_bias:
+            assert_almost_equal(output.asnumpy(), qoutput.asnumpy())
+        else:
+            # with bias added, tolerate an absolute error of at most 2 per element
+            diff = mx.nd.abs(output - qoutput.astype(output.dtype))
+            cond = mx.nd.lesser(2, diff).sum().asscalar()
+            assert cond == 0
 
+    for qdtype in ['int8', 'uint8']:
+        check_quantized_conv((3, 4, 28, 28), (3, 3), 128, (1, 1), (1, 1), True, qdtype)
+        check_quantized_conv((3, 4, 28, 28), (3, 3), 128, (1, 1), (1, 1), False, qdtype)
 
 @with_seed()
 def test_quantized_pooling():
-    if mx.current_context().device_type != 'gpu':
-        print('skipped testing quantized_pooling on cpu since it is not implemented yet')
-        return
-
-    def check_quantized_pooling(data_shape, kernel, pool_type, pad, stride, global_pool):
-        with mx.Context('gpu', 0):
-            data = mx.sym.Variable(name='data', shape=data_shape, dtype='float32')
-            pooling_fp32 = mx.sym.Pooling(data=data, kernel=kernel, pad=pad, stride=stride,
-                                          pool_type=pool_type, global_pool=global_pool, cudnn_off=False)
-            arg_shapes, _, _ = pooling_fp32.infer_shape(data=data_shape)
-            arg_names = pooling_fp32.list_arguments()
-            pooling_fp32_exe = pooling_fp32.simple_bind(ctx=mx.current_context(), grad_req='null')
-            pooling_fp32_exe.arg_dict[arg_names[0]][:] = mx.nd.random.uniform(low=-127.0, high=127.0,
-                                                                              shape=data_shape).astype('int32')
-            output = pooling_fp32_exe.forward()[0]
-
-            qdata = mx.sym.Variable(name='qdata', shape=data_shape, dtype='int8')
-            min_data = mx.sym.Variable(name='min_data')
-            max_data = mx.sym.Variable(name='max_data')
-            quantized_pooling = mx.sym.contrib.quantized_pooling(data=qdata, min_data=min_data,
-                                                                 max_data=max_data, kernel=kernel,
-                                                                 pad=pad, stride=stride, pool_type=pool_type,
-                                                                 global_pool=global_pool)
-            pooling_int8_exe = quantized_pooling.simple_bind(ctx=mx.current_context(), grad_req='null')
-            qarg_names = quantized_pooling.list_arguments()
-            pooling_int8_exe.arg_dict[qarg_names[0]][:] = pooling_fp32_exe.arg_dict[arg_names[0]].astype('int8')
-            quantized_range = 127.0
-            pooling_int8_exe.arg_dict[qarg_names[1]][:] = -quantized_range
-            pooling_int8_exe.arg_dict[qarg_names[2]][:] = quantized_range
-            qoutput, min_range, max_range = pooling_int8_exe.forward()
-
-            if pool_type == 'max':
-                assert_almost_equal(output.asnumpy(), qoutput.asnumpy())
-            elif pool_type == 'avg':  # for avg pooling, fp32 and int8 may be different due to rounding errors
-                diff = mx.nd.abs(output - qoutput.astype(output.dtype))
-                cond = mx.nd.lesser(2, diff).sum().asscalar()
-                assert cond == 0
-
-    check_quantized_pooling((3, 4, 56, 56), (3, 3), 'max', (0, 0), (2, 2), False)
-    check_quantized_pooling((3, 4, 56, 56), (3, 3), 'max', (0, 0), (2, 2), True)
-    check_quantized_pooling((3, 512, 7, 7), (7, 7), 'avg', (0, 0), (1, 1), False)
-    check_quantized_pooling((3, 512, 7, 7), (7, 7), 'avg', (0, 0), (1, 1), True)
-
+    def check_quantized_pooling(data_shape, kernel, pool_type, pad, stride, global_pool, qdtype):
+        if is_test_for_native_cpu():
+            print('skipped testing quantized_pooling for native cpu since it is not supported yet')
+            return
+        elif qdtype == 'uint8' and is_test_for_gpu():
+            print('skipped testing quantized_pooling for gpu uint8 since it is not supported yet')
+            return
+        # fp32 reference pooling on integer-valued inputs
+        data = mx.sym.Variable(name='data', shape=data_shape, dtype='float32')
+        pooling_fp32 = mx.sym.Pooling(data=data, kernel=kernel, pad=pad, stride=stride,
+                                        pool_type=pool_type, global_pool=global_pool, cudnn_off=False)
+        arg_shapes, _, _ = pooling_fp32.infer_shape(data=data_shape)
+        arg_names = pooling_fp32.list_arguments()
+        pooling_fp32_exe = pooling_fp32.simple_bind(ctx=mx.current_context(), grad_req='null')
+        if qdtype == 'uint8':  # uint8 cannot represent negative values
+            data_low = 0.0
+            data_high = 127.0
+        else:
+            data_low = -127.0
+            data_high = 127.0
+        pooling_fp32_exe.arg_dict[arg_names[0]][:] = mx.nd.random.uniform(low=data_low, high=data_high,
+                                                                            shape=data_shape).astype('int32')
+        output = pooling_fp32_exe.forward()[0]
+        # quantized pooling on the same values cast to qdtype
+        qdata = mx.sym.Variable(name='qdata', shape=data_shape, dtype=qdtype)
+        min_data = mx.sym.Variable(name='min_data')
+        max_data = mx.sym.Variable(name='max_data')
+        quantized_pooling = mx.sym.contrib.quantized_pooling(data=qdata, min_data=min_data,
+                                                                max_data=max_data, kernel=kernel,
+                                                                pad=pad, stride=stride, pool_type=pool_type,
+                                                                global_pool=global_pool)
+        pooling_int8_exe = quantized_pooling.simple_bind(ctx=mx.current_context(), grad_req='null')
+        qarg_names = quantized_pooling.list_arguments()
+        pooling_int8_exe.arg_dict[qarg_names[0]][:] = pooling_fp32_exe.arg_dict[arg_names[0]].astype(qdtype)
+        quantized_range = 127.0  # +/- thresholds passed as min_data/max_data
+        pooling_int8_exe.arg_dict[qarg_names[1]][:] = -quantized_range
+        pooling_int8_exe.arg_dict[qarg_names[2]][:] = quantized_range
+        qoutput, min_range, max_range = pooling_int8_exe.forward()
+
+        if pool_type == 'max':
+            assert_almost_equal(output.asnumpy(), qoutput.asnumpy())
+        elif pool_type == 'avg':  # for avg pooling, fp32 and int8 may be different due to rounding errors
+            diff = mx.nd.abs(output - qoutput.astype(output.dtype))
+            cond = mx.nd.lesser(2, diff).sum().asscalar()  # elements off by more than 2
+            assert cond == 0
+
+    for qdtype in ['int8', 'uint8']:
+        check_quantized_pooling((3, 4, 56, 56), (3, 3), 'max', (0, 0), (2, 2), False, qdtype)
+        check_quantized_pooling((3, 4, 56, 56), (3, 3), 'max', (0, 0), (2, 2), True, qdtype)
+        check_quantized_pooling((3, 512, 7, 7), (7, 7), 'avg', (0, 0), (1, 1), False, qdtype)
+        check_quantized_pooling((3, 512, 7, 7), (7, 7), 'avg', (0, 0), (1, 1), True, qdtype)
 
 @with_seed()
 def test_quantized_fc():
-    if mx.current_context().device_type != 'gpu':
-        print('skipped testing quantized_fc on cpu since it is not implemented yet')
-        return
-
-    def check_quantized_fc(data_shape, num_hidden, no_bias, flatten=True):
-        with mx.Context('gpu', 0):
-            data = mx.sym.Variable(name='data', shape=data_shape, dtype='float32')
-            fc_fp32 = mx.sym.FullyConnected(data=data, num_hidden=num_hidden, no_bias=no_bias, flatten=flatten)
-            arg_shapes, _, _ = fc_fp32.infer_shape(data=data_shape)
-            arg_names = fc_fp32.list_arguments()
-            fc_fp32_exe = fc_fp32.simple_bind(ctx=mx.current_context(), grad_req='null')
-            fc_fp32_exe.arg_dict[arg_names[0]][:] = mx.nd.random.uniform(low=-127.0, high=127.0,
-                                                                         shape=data_shape).astype('int32')
-            fc_fp32_exe.arg_dict[arg_names[1]][:] = mx.nd.random.uniform(low=-127.0, high=127.0,
-                                                                         shape=arg_shapes[1]).astype('int32')
-            if not no_bias:
-                fc_fp32_exe.arg_dict[arg_names[2]][:] = mx.nd.random.uniform(low=-127.0, high=127.0,
-                                                                             shape=arg_shapes[2]).astype('int32')
-            output = fc_fp32_exe.forward()[0]
-
-            qdata = mx.sym.Variable(name='qdata', shape=data_shape, dtype='int8')
-            fc_int8 = mx.sym.contrib.quantized_fully_connected(data=qdata, num_hidden=num_hidden,
-                                                               no_bias=no_bias, flatten=flatten)
-            qarg_names = fc_int8.list_arguments()
-            type_dict = {qarg_names[1]: 'int8'}
-            if not no_bias:
-                type_dict.update({qarg_names[2]: 'int8'})
-            fc_int8_exe = fc_int8.simple_bind(ctx=mx.current_context(), type_dict=type_dict, grad_req='null')
-            fc_int8_exe.arg_dict[qarg_names[0]][:] = fc_fp32_exe.arg_dict[arg_names[0]].astype('int8')
-            fc_int8_exe.arg_dict[qarg_names[1]][:] = fc_fp32_exe.arg_dict[arg_names[1]].astype('int8')
-            quantized_range = 127.0
-            if no_bias:
-                fc_int8_exe.arg_dict[qarg_names[2]][:] = -quantized_range
-                fc_int8_exe.arg_dict[qarg_names[3]][:] = quantized_range
-                fc_int8_exe.arg_dict[qarg_names[4]][:] = -quantized_range
-                fc_int8_exe.arg_dict[qarg_names[5]][:] = quantized_range
-            else:
-                fc_int8_exe.arg_dict[qarg_names[2]][:] = fc_fp32_exe.arg_dict[arg_names[2]].astype('int8')
-                fc_int8_exe.arg_dict[qarg_names[3]][:] = -quantized_range
-                fc_int8_exe.arg_dict[qarg_names[4]][:] = quantized_range
-                fc_int8_exe.arg_dict[qarg_names[5]][:] = -quantized_range
-                fc_int8_exe.arg_dict[qarg_names[6]][:] = quantized_range
-                fc_int8_exe.arg_dict[qarg_names[7]][:] = -quantized_range
-                fc_int8_exe.arg_dict[qarg_names[8]][:] = quantized_range
-            qoutput, min_range, max_range = fc_int8_exe.forward()
-
-            if no_bias:
-                assert_almost_equal(output.asnumpy(), qoutput.asnumpy())
-            else:
-                # with adding bias, accuracy loss should not be greater than one
-                diff = mx.nd.abs(output - qoutput.astype(output.dtype))
-                cond = mx.nd.lesser(2, diff).sum().asscalar()
-                assert cond == 0
-
-    check_quantized_fc((32, 512, 2, 2), 100, True)
-    check_quantized_fc((32, 111, 2, 2), 100, True)
-    check_quantized_fc((32, 512, 2, 2), 100, False)
-    check_quantized_fc((32, 111, 2, 2), 100, False)
+    def check_quantized_fc(data_shape, num_hidden, no_bias, qdtype, flatten=True):
+        if not is_test_for_gpu():  # quantized_fc is only exercised on GPU for now
+            print('skipped testing quantized_fc on cpu since it is not supported yet')
+            return
+        elif qdtype == 'uint8' and is_test_for_gpu():
+            print('skipped testing quantized_fc for gpu uint8 since it is not supported yet')
+            return
+        # fp32 reference FullyConnected on integer-valued inputs
+        data = mx.sym.Variable(name='data', shape=data_shape, dtype='float32')
+        fc_fp32 = mx.sym.FullyConnected(data=data, num_hidden=num_hidden, no_bias=no_bias, flatten=flatten)
+        arg_shapes, _, _ = fc_fp32.infer_shape(data=data_shape)
+        arg_names = fc_fp32.list_arguments()
+        fc_fp32_exe = fc_fp32.simple_bind(ctx=mx.current_context(), grad_req='null')
+        if qdtype == 'uint8':  # uint8 cannot represent negative values
+            data_low = 0.0
+            data_high = 127.0
+        else:
+            data_low = -127.0
+            data_high = 127.0
+        fc_fp32_exe.arg_dict[arg_names[0]][:] = mx.nd.random.uniform(low=data_low, high=data_high,
+                                                                     shape=data_shape).astype('int32')
+        fc_fp32_exe.arg_dict[arg_names[1]][:] = mx.nd.random.uniform(low=-127.0, high=127.0,
+                                                                     shape=arg_shapes[1]).astype('int32')
+        if not no_bias:
+            fc_fp32_exe.arg_dict[arg_names[2]][:] = mx.nd.random.uniform(low=-127.0, high=127.0,
+                                                                         shape=arg_shapes[2]).astype('int32')
+        output = fc_fp32_exe.forward()[0]
+        # quantized FC on the same values: data declared and cast as qdtype, weight/bias as int8
+        qdata = mx.sym.Variable(name='qdata', shape=data_shape, dtype=qdtype)
+        fc_int8 = mx.sym.contrib.quantized_fully_connected(data=qdata, num_hidden=num_hidden,
+                                                           no_bias=no_bias, flatten=flatten)
+        qarg_names = fc_int8.list_arguments()
+        type_dict = {qarg_names[1]: 'int8'}
+        if not no_bias:
+            type_dict.update({qarg_names[2]: 'int8'})
+        fc_int8_exe = fc_int8.simple_bind(ctx=mx.current_context(), type_dict=type_dict, grad_req='null')
+        fc_int8_exe.arg_dict[qarg_names[0]][:] = fc_fp32_exe.arg_dict[arg_names[0]].astype(qdtype)
+        fc_int8_exe.arg_dict[qarg_names[1]][:] = fc_fp32_exe.arg_dict[arg_names[1]].astype('int8')
+        quantized_range = 127.0  # +/- value used for every min/max threshold argument
+        if no_bias:
+            fc_int8_exe.arg_dict[qarg_names[2]][:] = -quantized_range
+            fc_int8_exe.arg_dict[qarg_names[3]][:] = quantized_range
+            fc_int8_exe.arg_dict[qarg_names[4]][:] = -quantized_range
+            fc_int8_exe.arg_dict[qarg_names[5]][:] = quantized_range
+        else:
+            fc_int8_exe.arg_dict[qarg_names[2]][:] = fc_fp32_exe.arg_dict[arg_names[2]].astype('int8')
+            fc_int8_exe.arg_dict[qarg_names[3]][:] = -quantized_range
+            fc_int8_exe.arg_dict[qarg_names[4]][:] = quantized_range
+            fc_int8_exe.arg_dict[qarg_names[5]][:] = -quantized_range
+            fc_int8_exe.arg_dict[qarg_names[6]][:] = quantized_range
+            fc_int8_exe.arg_dict[qarg_names[7]][:] = -quantized_range
+            fc_int8_exe.arg_dict[qarg_names[8]][:] = quantized_range
+        qoutput, min_range, max_range = fc_int8_exe.forward()
+
+        if no_bias:
+            assert_almost_equal(output.asnumpy(), qoutput.asnumpy())
+        else:
+            # with bias added, tolerate an absolute error of at most 2 per element
+            diff = mx.nd.abs(output - qoutput.astype(output.dtype))
+            cond = mx.nd.lesser(2, diff).sum().asscalar()
+            assert cond == 0
 
+    for qdtype in ['int8', 'uint8']:
+        check_quantized_fc((32, 512, 2, 2), 100, True, qdtype)
+        check_quantized_fc((32, 111, 2, 2), 100, True, qdtype)
+        check_quantized_fc((32, 512, 2, 2), 100, False, qdtype)
+        check_quantized_fc((32, 111, 2, 2), 100, False, qdtype)
 
 @with_seed()
 def test_quantized_flatten():
-    def check_quantized_flatten(shape):
-        qdata = mx.nd.random.uniform(low=-127, high=127, shape=shape).astype('int8')
+    def check_quantized_flatten(shape, qdtype):
+        if qdtype == 'uint8':
+            data_low = 0.0
+            data_high = 127.0
+        else:
+            data_low = -127.0
+            data_high = 127.0
+        qdata = mx.nd.random.uniform(low=data_low, high=data_high, shape=shape).astype(qdtype)
         min_data = mx.nd.array([-1023.343], dtype='float32')
         max_data = mx.nd.array([2343.324275], dtype='float32')
         qoutput, min_output, max_output = mx.nd.contrib.quantized_flatten(qdata, min_data, max_data)
@@ -311,10 +355,11 @@ def test_quantized_flatten():
         assert same(min_data.asnumpy(), min_output.asnumpy())
         assert same(max_data.asnumpy(), max_output.asnumpy())
 
-    check_quantized_flatten((10,))
-    check_quantized_flatten((10, 15))
-    check_quantized_flatten((10, 15, 18))
-    check_quantized_flatten((3, 4, 23, 23))
+    for qdtype in ['int8', 'uint8']:
+        check_quantized_flatten((10,), qdtype)
+        check_quantized_flatten((10, 15), qdtype)
+        check_quantized_flatten((10, 15, 18), qdtype)
+        check_quantized_flatten((3, 4, 23, 23), qdtype)
 
 
 @with_seed()
@@ -353,56 +398,69 @@ def get_fp32_sym():
 
 @with_seed()
 def test_quantize_model():
-    def check_params(params, qparams, qsym=None):
-        if qsym is None:
-            assert len(params) == len(qparams)
-            for k, v in params.items():
-                assert k in qparams
-                assert same(v.asnumpy(), qparams[k].asnumpy())
-        else:
-            qparams_ground_truth = mx.contrib.quant._quantize_params(qsym, params)
-            assert len(qparams) == len(qparams_ground_truth)
-            for k, v in qparams_ground_truth.items():
-                assert k in qparams
-                assert same(v.asnumpy(), qparams[k].asnumpy())
-
-    def check_qsym_calibrated(qsym):
-        attrs = qsym.attr_dict()
-        for k, v in attrs.items():
-            if k.find('requantize_') != -1:
-                assert 'min_calib_range' in v
-                assert 'max_calib_range' in v
-
-    sym = get_fp32_sym()
-    mod = Module(symbol=sym)
-    batch_size = 4
-    data_shape = (batch_size, 4, 10, 10)
-    label_shape = (batch_size, 10)
-    mod.bind(data_shapes=[('data', data_shape)], label_shapes=[('softmax_label', label_shape)])
-    mod.init_params()
-    arg_params, aux_params = mod.get_params()
-    qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=sym,
-                                                                     arg_params=arg_params,
-                                                                     aux_params=aux_params,
-                                                                     ctx=mx.current_context(),
-                                                                     calib_mode='none')
-    check_params(arg_params, qarg_params, qsym)
-    check_params(aux_params, qaux_params)
-
-    calib_data = mx.nd.random.uniform(shape=data_shape)
-    calib_data = NDArrayIter(data=calib_data)
-    calib_data = DummyIter(calib_data)
-    qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=sym,
-                                                                     arg_params=arg_params,
-                                                                     aux_params=aux_params,
-                                                                     ctx=mx.current_context(),
-                                                                     calib_mode='naive',
-                                                                     calib_data=calib_data,
-                                                                     num_calib_examples=20)
-    check_params(arg_params, qarg_params, qsym)
-    check_params(aux_params, qaux_params)
-    check_qsym_calibrated(qsym)
-
+    def check_quantize_model(qdtype):  # quantize the test model to qdtype and verify
+        def check_params(params, qparams, qsym=None):  # exact copy unless qsym is given
+            if qsym is None:
+                assert len(params) == len(qparams)
+                for k, v in params.items():
+                    assert k in qparams
+                    assert same(v.asnumpy(), qparams[k].asnumpy())
+            else:
+                qparams_ground_truth = mx.contrib.quant._quantize_params(qsym, params)
+                assert len(qparams) == len(qparams_ground_truth)
+                for k, v in qparams_ground_truth.items():
+                    assert k in qparams
+                    assert same(v.asnumpy(), qparams[k].asnumpy())
+
+        def check_qsym_calibrated(qsym):  # 'requantize_' nodes need calib ranges
+            attrs = qsym.attr_dict()
+            for k, v in attrs.items():
+                if k.find('requantize_') != -1:
+                    assert 'min_calib_range' in v
+                    assert 'max_calib_range' in v
+
+        def check_qsym_qdtype(qsym, qdtype):  # '_quantize' nodes must have out_type == qdtype
+            attrs = qsym.attr_dict()
+            for k, v in attrs.items():
+                if k.find('_quantize') != -1:
+                    assert 'out_type' in v
+                    assert v['out_type'] == qdtype
+
+        sym = get_fp32_sym()  # FP32 model to be quantized
+        mod = Module(symbol=sym)
+        batch_size = 4
+        data_shape = (batch_size, 4, 10, 10)
+        label_shape = (batch_size, 10)
+        mod.bind(data_shapes=[('data', data_shape)], label_shapes=[('softmax_label', label_shape)])
+        mod.init_params()
+        arg_params, aux_params = mod.get_params()
+        qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=sym,
+                                                                         arg_params=arg_params,
+                                                                         aux_params=aux_params,
+                                                                         ctx=mx.current_context(),
+                                                                         quantized_dtype=qdtype,
+                                                                         calib_mode='none')
+        check_params(arg_params, qarg_params, qsym)  # args must match _quantize_params output
+        check_params(aux_params, qaux_params)  # aux params: expected unchanged
+
+        calib_data = mx.nd.random.uniform(shape=data_shape)  # inputs for naive calibration
+        calib_data = NDArrayIter(data=calib_data)
+        calib_data = DummyIter(calib_data)
+        qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=sym,
+                                                                         arg_params=arg_params,
+                                                                         aux_params=aux_params,
+                                                                         ctx=mx.current_context(),
+                                                                         quantized_dtype=qdtype,
+                                                                         calib_mode='naive',
+                                                                         calib_data=calib_data,
+                                                                         num_calib_examples=20)
+        check_params(arg_params, qarg_params, qsym)
+        check_params(aux_params, qaux_params)
+        check_qsym_calibrated(qsym)
+        check_qsym_qdtype(qsym, qdtype)
+
+    for qdtype in ['int8', 'uint8']:  # run the checks for both 8-bit dtypes
+        check_quantize_model(qdtype)
 
 @with_seed()
 def test_quantize_sym_with_calib():

-- 
To stop receiving notification emails like this one, please contact
marcoabreu@apache.org.