Posted to commits@mxnet.apache.org by re...@apache.org on 2018/10/19 01:53:18 UTC

[incubator-mxnet] branch master updated: MKL-DNN Quantization Examples and README (#12808)

This is an automated email from the ASF dual-hosted git repository.

reminisce pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 441fdb7  MKL-DNN Quantization Examples and README (#12808)
441fdb7 is described below

commit 441fdb773019e45b289d9049a8aca88f8593ef10
Author: Xinyu Chen <xi...@intel.com>
AuthorDate: Fri Oct 19 09:53:07 2018 +0800

    MKL-DNN Quantization Examples and README (#12808)
    
    * add gluoncv support
    
    * add ssd readme
    
    * improve ssd readme
    
    * add custom readme
    
    * add ssd model link
    
    * add squeezenet
    
    * add ssd quantization script
    
    * fix topo of args
    
    * improve custom readme
    
    * fix topo bug
    
    * fix squeezenet
    
    * add squeezenet accuracy
    
    * Add initializer for min max to support quantization
    
    * add dummy data inference
    
    * add test case for init_param
    
    * add subgraph docs
    
    * improve docs
    
    * add two models and fix default rgb_std to 1
    
    * fix doc link
    
    * improve MKLDNN_README
    
    * add quantization for mobilenetv1
    
    * fix ssd benchmark_score label shapes
    
    * add resnet101_v1 and inceptionv3 support
    
    * Refine some descriptions in the MKLDNN_README
    
    * improve docs
    
    * improve link in perf.md
---
 MKLDNN_README.md                                   |  55 +++-
 docs/faq/perf.md                                   |  13 +-
 example/quantization/README.md                     | 301 +++++++++++++++++++++
 example/quantization/imagenet_gen_qsym_mkldnn.py   | 131 ++++++++-
 example/quantization/imagenet_inference.py         | 103 ++++---
 example/ssd/benchmark_score.py                     |  25 +-
 example/ssd/config/config.py                       |   2 +-
 example/ssd/evaluate.py                            |   4 +-
 example/ssd/evaluate/evaluate_net.py               |  28 +-
 .../quantization.py}                               | 137 ++++------
 python/mxnet/initializer.py                        |  10 +
 tests/python/mkl/test_subgraph.py                  |  12 +
 12 files changed, 667 insertions(+), 154 deletions(-)

diff --git a/MKLDNN_README.md b/MKLDNN_README.md
index 9da2eee..2618d23 100644
--- a/MKLDNN_README.md
+++ b/MKLDNN_README.md
@@ -1,6 +1,10 @@
 # Build/Install MXNet with MKL-DNN
 
-Building MXNet with [Intel MKL-DNN](https://github.com/intel/mkl-dnn) will gain better performance when using Intel Xeon CPUs for training and inference. The improvement of performance can be seen in this [page](https://mxnet.incubator.apache.org/faq/perf.html#intel-cpu). Below are instructions for linux, MacOS and Windows platform.
+Better training and inference performance is expected on Intel-Architecture CPUs when MXNet is built with [Intel MKL-DNN](https://github.com/intel/mkl-dnn) on multiple operating systems, including Linux, Windows and MacOS.
+In the following sections, you will find build instructions for MXNet with Intel MKL-DNN on Linux, MacOS and Windows.
+
+Detailed performance data collected on Intel Xeon CPUs with MXNet built with Intel MKL-DNN can be found [here](https://mxnet.incubator.apache.org/faq/perf.html#intel-cpu).
+
 
 <h2 id="0">Contents</h2>
 
@@ -9,7 +13,9 @@ Building MXNet with [Intel MKL-DNN](https://github.com/intel/mkl-dnn) will gain
 * [3. Windows](#3)
 * [4. Verify MXNet with python](#4)
 * [5. Enable MKL BLAS](#5)
-* [6. Support](#6)
+* [6. Enable graph optimization](#6)
+* [7. Quantization](#7)
+* [8. Support](#8)
 
 <h2 id="1">Linux</h2>
 
@@ -36,7 +42,7 @@ cd incubator-mxnet
 make -j $(nproc) USE_OPENCV=1 USE_MKLDNN=1 USE_BLAS=mkl USE_INTEL_PATH=/opt/intel
 ```
 
-If you don't have full [MKL](https://software.intel.com/en-us/intel-mkl) library installed, you can use OpenBLAS by setting `USE_BLAS=openblas`.
+If you don't have the full [MKL](https://software.intel.com/en-us/intel-mkl) library installed, you can use OpenBLAS as the BLAS library by setting `USE_BLAS=openblas`.
 
 <h2 id="2">MacOS</h2>
 
@@ -77,7 +83,8 @@ LIBRARY_PATH=$(brew --prefix llvm)/lib/ make -j $(sysctl -n hw.ncpu) CC=$(brew -
 
 <h2 id="3">Windows</h2>
 
-We recommend to build and install MXNet yourself using [Microsoft Visual Studio 2015](https://www.visualstudio.com/vs/older-downloads/), or you can also try experimentally the latest [Microsoft Visual Studio 2017](https://www.visualstudio.com/downloads/).
+On Windows, you can use [Microsoft Visual Studio 2015](https://www.visualstudio.com/vs/older-downloads/) or [Microsoft Visual Studio 2017](https://www.visualstudio.com/downloads/) to compile MXNet with Intel MKL-DNN.
+[Microsoft Visual Studio 2015](https://www.visualstudio.com/vs/older-downloads/) is recommended.
 
 **Visual Studio 2015**
 
@@ -211,11 +218,11 @@ o = exe.outputs[0]
 t = o.asnumpy()
 ```
 
-You can open the `MKLDNN_VERBOSE` flag by setting environment variable:
+More detailed debugging and profiling information can be logged by setting the environment variable `MKLDNN_VERBOSE`:
 ```
 export MKLDNN_VERBOSE=1
 ```
-Then by running above code snippet, you probably will get the following output message which means `convolution` and `reorder` primitive from MKL-DNN are called. Layout information and primitive execution performance are also demonstrated in the log message.
+For example, running the code snippet above produces the following debugging logs, which give more insight into the MKL-DNN primitives `convolution` and `reorder`, including memory layout, inferred shape and the time cost of primitive execution.
 ```
 mkldnn_verbose,exec,reorder,jit:uni,undef,in:f32_nchw out:f32_nChw16c,num:1,32x32x256x256,6.47681
 mkldnn_verbose,exec,reorder,jit:uni,undef,in:f32_oihw out:f32_OIhw16i16o,num:1,32x32x3x3,0.0429688
@@ -226,9 +233,9 @@ mkldnn_verbose,exec,reorder,jit:uni,undef,in:f32_nChw16c out:f32_nchw,num:1,32x3
 
 <h2 id="5">Enable MKL BLAS</h2>
 
-To make it convenient for customers, Intel introduced a new license called [Intel® Simplified license](https://software.intel.com/en-us/license/intel-simplified-software-license) that allows to redistribute not only dynamic libraries but also headers, examples and static libraries.
-
-Installing and enabling the full MKL installation enables MKL support for all operators under the linalg namespace.
+With MKL BLAS, performance is expected to improve further, by an amount that depends on the computational load of the model.
+By accepting the [Intel® Simplified license](https://software.intel.com/en-us/license/intel-simplified-software-license), you can redistribute not only dynamic libraries but also headers, examples and static libraries.
+Installing the full MKL package enables MKL support for all operators under the linalg namespace.
 
   1. Download and install the latest full MKL version following instructions on the [intel website.](https://software.intel.com/en-us/mkl)
 
@@ -275,10 +282,32 @@ MKL_VERBOSE Intel(R) MKL 2018.0 Update 1 Product build 20171007 for Intel(R) 64
 MKL_VERBOSE SGEMM(T,N,12,10,8,0x7f7f927b1378,0x1bc2140,8,0x1ba8040,8,0x7f7f927b1380,0x7f7f7400a280,12) 8.93ms CNR:OFF Dyn:1 FastMM:1 TID:0  NThr:40 WDiv:HOST:+0.000
 ```
 
-<h2 id="6">Next Steps and Support</h2>
+<h2 id="6">Enable graph optimization</h2>
+
+Graph optimization through the subgraph feature is available in the master branch. You can build from source and then use the command below to enable this *experimental* feature for better performance:
+
+```
+export MXNET_SUBGRAPH_BACKEND=MKLDNN
+```
+
+The limitations of this experimental feature are:
+
+- Use this feature only for inference. When training, be sure to turn the feature off by unsetting the `MXNET_SUBGRAPH_BACKEND` environment variable.
+
+- This feature will only run on the CPU, even if you're using a GPU-enabled build of MXNet. 
+
+- [MXNet Graph Optimization and Quantization Technical Information and Performance Details](https://cwiki.apache.org/confluence/display/MXNET/MXNet+Graph+Optimization+and+Quantization+based+on+subgraph+and+MKL-DNN).
+
+<h2 id="7">Quantization and Inference with INT8</h2>
+
+MXNet built with Intel® MKL-DNN brings outstanding performance improvements for quantization and INT8 inference on Intel® CPU platforms, in particular on the Intel® Xeon® Scalable Platform.
+
+- [CNN Quantization Examples](https://github.com/apache/incubator-mxnet/tree/master/example/quantization).
+
+<h2 id="8">Next Steps and Support</h2>
 
-- For questions or support specific to MKL, visit the [Intel MKL](https://software.intel.com/en-us/mkl)
+- For questions or support specific to MKL, visit the [Intel MKL](https://software.intel.com/en-us/mkl) website.
 
-- For questions or support specific to MKL, visit the [Intel MKLDNN](https://github.com/intel/mkl-dnn)
+- For questions or support specific to MKL-DNN, visit the [Intel MKL-DNN](https://github.com/intel/mkl-dnn) website.
 
-- If you find bugs, please open an issue on GitHub for [MXNet with MKL](https://github.com/apache/incubator-mxnet/labels/MKL) or [MXNet with MKLDNN](https://github.com/apache/incubator-mxnet/labels/MKLDNN)
+- If you find bugs, please open an issue on GitHub for [MXNet with MKL](https://github.com/apache/incubator-mxnet/labels/MKL) or [MXNet with MKLDNN](https://github.com/apache/incubator-mxnet/labels/MKLDNN).
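The README changes above note that `MXNET_SUBGRAPH_BACKEND` should be set only for inference and unset for training. One way to keep the setting scoped is sketched below in Python. This is a minimal sketch, not part of the commit: `subgraph_backend` is a hypothetical helper name, and it assumes MXNet reads the variable from the process environment at the point where the model is bound, which this commit does not confirm.

```python
import os
from contextlib import contextmanager

ENV_KEY = "MXNET_SUBGRAPH_BACKEND"


@contextmanager
def subgraph_backend(name="MKLDNN"):
    """Temporarily set MXNET_SUBGRAPH_BACKEND and restore the old state on exit.

    Hypothetical helper: it only manipulates the environment variable named
    in the README and does not call into MXNet itself.
    """
    previous = os.environ.get(ENV_KEY)
    os.environ[ENV_KEY] = name
    try:
        yield
    finally:
        # Restore the exact previous state so training code that runs
        # afterwards does not see the backend enabled.
        if previous is None:
            os.environ.pop(ENV_KEY, None)
        else:
            os.environ[ENV_KEY] = previous
```

Inference code would run inside `with subgraph_backend():`, and any training code that runs afterwards would see the variable in its previous state.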
diff --git a/docs/faq/perf.md b/docs/faq/perf.md
index ad81b5d..f116ede 100644
--- a/docs/faq/perf.md
+++ b/docs/faq/perf.md
@@ -18,12 +18,15 @@ Performance is mainly affected by the following 4 factors:
 ## Intel CPU
 
 For using Intel Xeon CPUs for training and inference, we suggest enabling
-`USE_MKLDNN = 1` in`config.mk`. 
+`USE_MKLDNN = 1` in `config.mk`. 
 
-We also find that setting the following two environment variables can help:
-- `export KMP_AFFINITY=granularity=fine,compact,1,0` if there are two physical CPUs
-- `export OMP_NUM_THREADS=vCPUs / 2` in which `vCPUs` is the number of virtual CPUs.
-  Whe using Linux, we can access this information by running `cat /proc/cpuinfo  | grep processor | wc -l`
+We also find that setting the following environment variables can help:
+
+| Variable  | Description |
+| :-------- | :---------- |
+| `OMP_NUM_THREADS`            | Suggested value: `vCPUs / 2` in which `vCPUs` is the number of virtual CPUs. For more information, please see the guide for [setting the number of threads using an OpenMP environment variable](https://software.intel.com/en-us/mkl-windows-developer-guide-setting-the-number-of-threads-using-an-openmp-environment-variable). |
+| `KMP_AFFINITY`               | Suggested value: `granularity=fine,compact,1,0`.  For more information, please see the guide for [Thread Affinity Interface (Linux* and Windows*)](https://software.intel.com/en-us/node/522691). |
+| `MXNET_SUBGRAPH_BACKEND` | Set to `MKLDNN` to enable the [subgraph feature](https://cwiki.apache.org/confluence/display/MXNET/MXNet+Graph+Optimization+and+Quantization+based+on+subgraph+and+MKL-DNN) for better performance. For more information, please see [Build/Install MXNet with MKL-DNN](https://github.com/apache/incubator-mxnet/blob/master/MKLDNN_README.md). |
 
 Note that _MXNet_ treats all CPUs on a single machine as a single device.
 So whether you specify `cpu(0)` or `cpu()`, _MXNet_ will use all CPU cores on the machine.
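The suggested values in the table above can be derived programmatically. The following is a minimal Python sketch, not part of the commit: `suggested_omp_env` is a hypothetical helper, and it assumes that `os.cpu_count()` reports the number of virtual CPUs, which is the same figure the removed `cat /proc/cpuinfo | grep processor | wc -l` command was counting on Linux.

```python
import os


def suggested_omp_env(vcpus=None):
    """Return the environment settings suggested in docs/faq/perf.md.

    vcpus: number of virtual CPUs; defaults to os.cpu_count() (assumption:
    this matches the logical processor count of the machine).
    """
    if vcpus is None:
        vcpus = os.cpu_count() or 1
    return {
        # vCPUs / 2, but never fewer than one thread.
        "OMP_NUM_THREADS": str(max(1, vcpus // 2)),
        "KMP_AFFINITY": "granularity=fine,compact,1,0",
    }


if __name__ == "__main__":
    # Print shell export lines that can be pasted into a terminal.
    for key, value in suggested_omp_env().items():
        print("export %s=%s" % (key, value))
```

These are starting points only; the best thread count still depends on the model and the machine.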
diff --git a/example/quantization/README.md b/example/quantization/README.md
index 63b6557..7f68dd7 100644
--- a/example/quantization/README.md
+++ b/example/quantization/README.md
@@ -1,4 +1,305 @@
 # Model Quantization with Calibration Examples
+
+This folder contains examples of quantizing an FP32 model with Intel® MKL-DNN or CUDNN.
+
+<h2 id="0">Contents</h2>
+
+* [1. Model Quantization with Intel® MKL-DNN](#1)
+* [2. Model Quantization with CUDNN](#2)
+
+<h2 id="1">Model Quantization with Intel® MKL-DNN</h2>
+
+Intel® MKL-DNN supports quantization with subgraph features on Intel® CPU Platform and can bring performance improvements on the [Intel® Xeon® Scalable Platform](https://www.intel.com/content/www/us/en/processors/xeon/scalable/xeon-scalable-platform.html). A new quantization script `imagenet_gen_qsym_mkldnn.py` has been designed to launch quantization for CNN models with Intel® MKL-DNN. This script integrates with [Gluon-CV modelzoo](https://gluon-cv.mxnet.io/model_zoo/classification.htm [...]
+
+Calibration is used for generating a calibration table for the quantized symbol. The quantization script supports three methods:
+
+- **none:** No calibration will be used. The thresholds for quantization will be calculated on the fly. This will result in inference speed slowdown and loss of accuracy in general.
+- **naive:** Simply take min and max values of layer outputs as thresholds for quantization. In general, the inference accuracy worsens with more examples used in calibration. It is recommended to use `entropy` mode as it produces more accurate inference results.
+- **entropy:** Calculate KL divergence of the fp32 output and quantized output for optimal thresholds. This mode is expected to produce the best inference accuracy of all three kinds of quantized models if the calibration dataset is representative enough of the inference dataset.
+
+Use the following command to install [Gluon-CV](https://gluon-cv.mxnet.io/):
+
+```
+pip install gluoncv
+```
+
+The following models have been tested on Linux systems.
+
+| Model | Source | Dataset | FP32 Accuracy (top-1/top-5)| INT8 Accuracy (top-1/top-5)|
+|:---|:---|---|:---:|:---:|
+| [ResNet50-V1](#3)  | [Gluon-CV](https://gluon-cv.mxnet.io/model_zoo/classification.html)  | [Validation Dataset](http://data.mxnet.io/data/val_256_q90.rec)  | 75.87%/92.72%  |  75.71%/92.65% |
+| [ResNet101-V1](#4)  | [Gluon-CV](https://gluon-cv.mxnet.io/model_zoo/classification.html)  | [Validation Dataset](http://data.mxnet.io/data/val_256_q90.rec)  | 77.3%/93.58%  | 77.09%/93.41%  |
+|[SqueezeNet 1.0](#5)|[Gluon-CV](https://gluon-cv.mxnet.io/model_zoo/classification.html)|[Validation Dataset](http://data.mxnet.io/data/val_256_q90.rec)|57.01%/79.71%|56.62%/79.55%|
+|[MobileNet 1.0](#6)|[Gluon-CV](https://gluon-cv.mxnet.io/model_zoo/classification.html)|[Validation Dataset](http://data.mxnet.io/data/val_256_q90.rec)|69.76%/89.32%|69.61%/89.09%|
+|[Inception-V3](#7)|[Gluon-CV](https://gluon-cv.mxnet.io/model_zoo/classification.html)|[Validation Dataset](http://data.mxnet.io/data/val_256_q90.rec)|76.49%/93.10% |76.38%/93% |
+|[ResNet152-V2](#8)|[MXNet ModelZoo](http://data.mxnet.io/models/imagenet/resnet/152-layers/)|[Validation Dataset](http://data.mxnet.io/data/val_256_q90.rec)|76.76%/93.03%|76.48%/92.96%|
+|[Inception-BN](#9)|[MXNet ModelZoo](http://data.mxnet.io/models/imagenet/inception-bn/)|[Validation Dataset](http://data.mxnet.io/data/val_256_q90.rec)|72.09%/90.60%|72.00%/90.53%|
+| [SSD-VGG](#10) | [example/ssd](https://github.com/apache/incubator-mxnet/tree/master/example/ssd)  | VOC2007/2012  | 0.83 mAP  | 0.82 mAP  |
+
+<h3 id='3'>ResNet50-V1</h3>
+
+The following command downloads the pre-trained model from Gluon-CV and converts it into a symbolic model, which is then quantized. The [validation dataset](http://data.mxnet.io/data/val_256_q90.rec) is available for testing the pre-trained models:
+
+```
+python imagenet_gen_qsym_mkldnn.py --model=resnet50_v1 --num-calib-batches=5 --calib-mode=naive
+```
+
+The model is automatically converted to a fused, quantized format and saved as quantized symbol and parameter files in the `./model` directory. The following commands launch inference.
+
+```
+# USE MKLDNN AS SUBGRAPH BACKEND
+export MXNET_SUBGRAPH_BACKEND=MKLDNN
+
+# Launch FP32 Inference 
+python imagenet_inference.py --symbol-file=./model/resnet50_v1-symbol.json --param-file=./model/resnet50_v1-0000.params --rgb-mean=123.68,116.779,103.939 --rgb-std=58.393,57.12,57.375 --num-skipped-batches=50 --batch-size=64 --num-inference-batches=500 --dataset=./data/val_256_q90.rec --ctx=cpu --data-nthreads=1
+
+# Launch INT8 Inference
+python imagenet_inference.py --symbol-file=./model/resnet50_v1-quantized-5batches-naive-symbol.json --param-file=./model/resnet50_v1-quantized-0000.params --rgb-mean=123.68,116.779,103.939 --rgb-std=58.393,57.12,57.375 --num-skipped-batches=50 --batch-size=64 --num-inference-batches=500 --dataset=./data/val_256_q90.rec --ctx=cpu  --data-nthreads=1
+
+# Launch dummy data Inference
+python imagenet_inference.py --symbol-file=./model/resnet50_v1-symbol.json --batch-size=64 --num-inference-batches=500 --ctx=cpu --benchmark=True
+python imagenet_inference.py --symbol-file=./model/resnet50_v1-quantized-5batches-naive-symbol.json --batch-size=64 --num-inference-batches=500 --ctx=cpu --benchmark=True
+```
+
+<h3 id='4'>ResNet101-V1</h3>
+
+The following command downloads the pre-trained model from Gluon-CV and converts it into a symbolic model, which is then quantized. The [validation dataset](http://data.mxnet.io/data/val_256_q90.rec) is available for testing the pre-trained models:
+
+```
+python imagenet_gen_qsym_mkldnn.py --model=resnet101_v1 --num-calib-batches=5 --calib-mode=naive
+```
+
+The model is automatically converted to a fused, quantized format and saved as quantized symbol and parameter files in the `./model` directory. The following commands launch inference.
+
+```
+# USE MKLDNN AS SUBGRAPH BACKEND
+export MXNET_SUBGRAPH_BACKEND=MKLDNN
+
+# Launch FP32 Inference 
+python imagenet_inference.py --symbol-file=./model/resnet101_v1-symbol.json --param-file=./model/resnet101_v1-0000.params --rgb-mean=123.68,116.779,103.939 --rgb-std=58.393,57.12,57.375 --num-skipped-batches=50 --batch-size=64 --num-inference-batches=500 --dataset=./data/val_256_q90.rec --ctx=cpu --data-nthreads=1
+
+# Launch INT8 Inference
+python imagenet_inference.py --symbol-file=./model/resnet101_v1-quantized-5batches-naive-symbol.json --param-file=./model/resnet101_v1-quantized-0000.params --rgb-mean=123.68,116.779,103.939 --rgb-std=58.393,57.12,57.375 --num-skipped-batches=50 --batch-size=64 --num-inference-batches=500 --dataset=./data/val_256_q90.rec --ctx=cpu  --data-nthreads=1
+
+# Launch dummy data Inference
+python imagenet_inference.py --symbol-file=./model/resnet101_v1-symbol.json --batch-size=64 --num-inference-batches=500 --ctx=cpu --benchmark=True
+python imagenet_inference.py --symbol-file=./model/resnet101_v1-quantized-5batches-naive-symbol.json --batch-size=64 --num-inference-batches=500 --ctx=cpu --benchmark=True
+```
+
+<h3 id='5'>SqueezeNet 1.0</h3>
+
+The following command downloads the pre-trained model from Gluon-CV and converts it into a symbolic model, which is then quantized. The [validation dataset](http://data.mxnet.io/data/val_256_q90.rec) is available for testing the pre-trained models:
+
+```
+python imagenet_gen_qsym_mkldnn.py --model=squeezenet1.0 --num-calib-batches=5 --calib-mode=naive
+```
+The model is automatically converted to a fused, quantized format and saved as quantized symbol and parameter files in the `./model` directory. The following commands launch inference.
+
+```
+# USE MKLDNN AS SUBGRAPH BACKEND
+export MXNET_SUBGRAPH_BACKEND=MKLDNN
+
+# Launch FP32 Inference
+python imagenet_inference.py --symbol-file=./model/squeezenet1.0-symbol.json --param-file=./model/squeezenet1.0-0000.params --rgb-mean=123.68,116.779,103.939 --rgb-std=58.393,57.12,57.375 --num-skipped-batches=50 --batch-size=64 --num-inference-batches=500 --dataset=./data/val_256_q90.rec --ctx=cpu --data-nthreads=1
+
+# Launch INT8 Inference
+python imagenet_inference.py --symbol-file=./model/squeezenet1.0-quantized-5batches-naive-symbol.json --param-file=./model/squeezenet1.0-quantized-0000.params --rgb-mean=123.68,116.779,103.939 --rgb-std=58.393,57.12,57.375 --num-skipped-batches=50 --batch-size=64 --num-inference-batches=500 --dataset=./data/val_256_q90.rec --ctx=cpu  --data-nthreads=1
+
+# Launch dummy data Inference
+python imagenet_inference.py --symbol-file=./model/squeezenet1.0-symbol.json --batch-size=64 --num-inference-batches=500 --ctx=cpu  --benchmark=True
+python imagenet_inference.py --symbol-file=./model/squeezenet1.0-quantized-5batches-naive-symbol.json --batch-size=64 --num-inference-batches=500 --ctx=cpu --benchmark=True
+```
+
+<h3 id='6'>MobileNet 1.0</h3>
+
+The following command downloads the pre-trained model from Gluon-CV and converts it into a symbolic model, which is then quantized. The [validation dataset](http://data.mxnet.io/data/val_256_q90.rec) is available for testing the pre-trained models:
+
+```
+python imagenet_gen_qsym_mkldnn.py --model=mobilenet1.0 --num-calib-batches=5 --calib-mode=naive
+```
+The model is automatically converted to a fused, quantized format and saved as quantized symbol and parameter files in the `./model` directory. The following commands launch inference.
+
+```
+# USE MKLDNN AS SUBGRAPH BACKEND
+export MXNET_SUBGRAPH_BACKEND=MKLDNN
+
+# Launch FP32 Inference
+python imagenet_inference.py --symbol-file=./model/mobilenet1.0-symbol.json --param-file=./model/mobilenet1.0-0000.params --rgb-mean=123.68,116.779,103.939 --rgb-std=58.393,57.12,57.375 --num-skipped-batches=50 --batch-size=64 --num-inference-batches=500 --dataset=./data/val_256_q90.rec --ctx=cpu --data-nthreads=1
+
+# Launch INT8 Inference
+python imagenet_inference.py --symbol-file=./model/mobilenet1.0-quantized-5batches-naive-symbol.json --param-file=./model/mobilenet1.0-quantized-0000.params --rgb-mean=123.68,116.779,103.939 --rgb-std=58.393,57.12,57.375 --num-skipped-batches=50 --batch-size=64 --num-inference-batches=500 --dataset=./data/val_256_q90.rec --ctx=cpu  --data-nthreads=1
+
+# Launch dummy data Inference
+python imagenet_inference.py --symbol-file=./model/mobilenet1.0-symbol.json --batch-size=64 --num-inference-batches=500 --ctx=cpu  --benchmark=True
+python imagenet_inference.py --symbol-file=./model/mobilenet1.0-quantized-5batches-naive-symbol.json --batch-size=64 --num-inference-batches=500 --ctx=cpu --benchmark=True
+```
+
+<h3 id='7'>Inception-V3</h3>
+
+The following command downloads the pre-trained model from Gluon-CV and converts it into a symbolic model, which is then quantized. The [validation dataset](http://data.mxnet.io/data/val_256_q90.rec) is available for testing the pre-trained models:
+
+```
+python imagenet_gen_qsym_mkldnn.py --model=inceptionv3 --image-shape=3,299,299 --num-calib-batches=5 --calib-mode=naive
+```
+The model is automatically converted to a fused, quantized format and saved as quantized symbol and parameter files in the `./model` directory. The following commands launch inference.
+
+```
+# USE MKLDNN AS SUBGRAPH BACKEND
+export MXNET_SUBGRAPH_BACKEND=MKLDNN
+
+# Launch FP32 Inference
+python imagenet_inference.py --symbol-file=./model/inceptionv3-symbol.json --param-file=./model/inceptionv3-0000.params --image-shape=3,299,299 --rgb-mean=123.68,116.779,103.939 --rgb-std=58.393,57.12,57.375 --num-skipped-batches=50 --batch-size=64 --num-inference-batches=500 --dataset=./data/val_256_q90.rec --ctx=cpu --data-nthreads=1
+
+# Launch INT8 Inference
+python imagenet_inference.py --symbol-file=./model/inceptionv3-quantized-5batches-naive-symbol.json --param-file=./model/inceptionv3-quantized-0000.params --image-shape=3,299,299 --rgb-mean=123.68,116.779,103.939 --rgb-std=58.393,57.12,57.375 --num-skipped-batches=50 --batch-size=64 --num-inference-batches=500 --dataset=./data/val_256_q90.rec --ctx=cpu  --data-nthreads=1
+
+# Launch dummy data Inference
+python imagenet_inference.py --symbol-file=./model/inceptionv3-symbol.json --image-shape=3,299,299 --batch-size=64 --num-inference-batches=500 --ctx=cpu  --benchmark=True
+python imagenet_inference.py --symbol-file=./model/inceptionv3-quantized-5batches-naive-symbol.json --image-shape=3,299,299 --batch-size=64 --num-inference-batches=500 --ctx=cpu --benchmark=True
+```
+
+<h3 id='8'>ResNet152-V2</h3>
+
+The following command downloads the pre-trained model from the [MXNet ModelZoo](http://data.mxnet.io/models/imagenet/resnet/152-layers/), which is then quantized. The [validation dataset](http://data.mxnet.io/data/val_256_q90.rec) is available for testing the pre-trained models:
+
+```
+python imagenet_gen_qsym_mkldnn.py --model=imagenet1k-resnet-152 --num-calib-batches=5 --calib-mode=naive
+```
+
+The model is automatically converted to a fused, quantized format and saved as quantized symbol and parameter files in the `./model` directory. The following commands launch inference.
+
+```
+# USE MKLDNN AS SUBGRAPH BACKEND
+export MXNET_SUBGRAPH_BACKEND=MKLDNN
+
+# Launch FP32 Inference 
+python imagenet_inference.py --symbol-file=./model/imagenet1k-resnet-152-symbol.json --param-file=./model/imagenet1k-resnet-152-0000.params --num-skipped-batches=50 --batch-size=64 --num-inference-batches=500 --dataset=./data/val_256_q90.rec --ctx=cpu --data-nthreads=1
+
+# Launch INT8 Inference
+python imagenet_inference.py --symbol-file=./model/imagenet1k-resnet-152-quantized-5batches-naive-symbol.json --param-file=./model/imagenet1k-resnet-152-quantized-0000.params --num-skipped-batches=50 --batch-size=64 --num-inference-batches=500 --dataset=./data/val_256_q90.rec --ctx=cpu  --data-nthreads=1
+
+# Launch dummy data Inference
+python imagenet_inference.py --symbol-file=./model/imagenet1k-resnet-152-symbol.json --batch-size=64 --num-inference-batches=500 --ctx=cpu --benchmark=True
+python imagenet_inference.py --symbol-file=./model/imagenet1k-resnet-152-quantized-5batches-naive-symbol.json --batch-size=64 --num-inference-batches=500 --ctx=cpu --benchmark=True
+```
+
+<h3 id='9'>Inception-BN</h3>
+
+The following command downloads the pre-trained model from the [MXNet ModelZoo](http://data.mxnet.io/models/imagenet/inception-bn/), which is then quantized. The [validation dataset](http://data.mxnet.io/data/val_256_q90.rec) is available for testing the pre-trained models:
+
+```
+python imagenet_gen_qsym_mkldnn.py --model=imagenet1k-inception-bn --num-calib-batches=5 --calib-mode=naive
+```
+
+The model is automatically converted to a fused, quantized format and saved as quantized symbol and parameter files in the `./model` directory. The following commands launch inference.
+
+```
+# USE MKLDNN AS SUBGRAPH BACKEND
+export MXNET_SUBGRAPH_BACKEND=MKLDNN
+
+# Launch FP32 Inference 
+python imagenet_inference.py --symbol-file=./model/imagenet1k-inception-bn-symbol.json --param-file=./model/imagenet1k-inception-bn-0000.params --rgb-mean=123.68,116.779,103.939 --num-skipped-batches=50 --batch-size=64 --num-inference-batches=500 --dataset=./data/val_256_q90.rec --ctx=cpu --data-nthreads=1
+
+# Launch INT8 Inference
+python imagenet_inference.py --symbol-file=./model/imagenet1k-inception-bn-quantized-5batches-naive-symbol.json --param-file=./model/imagenet1k-inception-bn-quantized-0000.params --rgb-mean=123.68,116.779,103.939 --num-skipped-batches=50 --batch-size=64 --num-inference-batches=500 --dataset=./data/val_256_q90.rec --ctx=cpu  --data-nthreads=1
+
+# Launch dummy data Inference
+python imagenet_inference.py --symbol-file=./model/imagenet1k-inception-bn-symbol.json --batch-size=64 --num-inference-batches=500 --ctx=cpu --benchmark=True
+python imagenet_inference.py --symbol-file=./model/imagenet1k-inception-bn-quantized-5batches-naive-symbol.json --batch-size=64 --num-inference-batches=500 --ctx=cpu --benchmark=True
+```
+
+<h3 id='10'>SSD-VGG</h3>
+
+Follow the [SSD example's instructions](https://github.com/apache/incubator-mxnet/tree/master/example/ssd#train-the-model) in [example/ssd](https://github.com/apache/incubator-mxnet/tree/master/example/ssd) to train a FP32 `SSD-VGG16_reduced_300x300` model based on Pascal VOC dataset. You can also download our [SSD-VGG16 pre-trained model](http://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/models/ssd_vgg16_reduced_300-dd479559.zip) and [packed binary data](http://apache-mxne [...]
+
+```
+data/
+|---val.rec
+|---val.lst
+|---val.idx
+model/
+|---ssd_vgg16_reduced_300.params
+|---ssd_vgg16_reduced_300-symbol.json
+```
+
+Then, use the following command for quantization. By default, this script uses 5 batches (32 samples per batch) for naive calibration:
+
+```
+python quantization.py
+```
+
+After quantization, INT8 models will be saved in the `model/` directory. Use the following commands to launch inference.
+
+```
+# USE MKLDNN AS SUBGRAPH BACKEND
+export MXNET_SUBGRAPH_BACKEND=MKLDNN
+
+# Launch FP32 Inference 
+python evaluate.py --cpu --num-batch 10 --batch-size 224 --deploy --prefix=./model/ssd_
+
+# Launch INT8 Inference
+python evaluate.py --cpu --num-batch 10 --batch-size 224 --deploy --prefix=./model/cqssd_
+
+# Launch dummy data Inference
+python benchmark_score.py --deploy --prefix=./model/ssd_
+python benchmark_score.py --deploy --prefix=./model/cqssd_
+```
+
+<h3 id='11'>Custom Model</h3>
+
+This script also supports custom symbolic models. You can easily add some quantization layer configs in `imagenet_gen_qsym_mkldnn.py` like below:
+
+```
+elif args.model == 'custom':
+    # add rgb mean/std of your model.
+    rgb_mean = '0,0,0'
+    rgb_std = '0,0,0'
+    calib_layer = lambda name: name.endswith('_output')
+    # add layer names you do not want to quantize.
+    # add conv/pool layer names that have negative inputs,
+    # since Intel® MKL-DNN only supports uint8 quantization for now.
+    # add all fc layer names, since Intel® MKL-DNN does not support them yet.
+    excluded_sym_names += ['layers']
+    # add your first conv layer names, since Intel® MKL-DNN only supports uint8 quantization for now.
+    if exclude_first_conv:
+        excluded_sym_names += ['layers']
+```
+
+Some tips on quantization configs:
+
+1. First, you should prepare your data, symbol file (custom-symbol.json) and parameter file (custom-0000.params) of your fp32 symbolic model.
+2. Then, you should run the following command and verify that your fp32 symbolic model runs inference as expected.
+
+```
+# USE MKLDNN AS SUBGRAPH BACKEND
+export MXNET_SUBGRAPH_BACKEND=MKLDNN
+
+# Launch FP32 Inference 
+python imagenet_inference.py --symbol-file=./model/custom-symbol.json --param-file=./model/custom-0000.params --rgb-mean=* --rgb-std=* --num-skipped-batches=* --batch-size=* --num-inference-batches=* --dataset=./data/* --ctx=cpu --data-nthreads=1
+```
+
+3. Then, you should add `rgb_mean`, `rgb_std` and `excluded_sym_names` in this script. Note that you should exclude conv/pool layers that have negative input data, since Intel® MKL-DNN currently supports only `uint8` quantization. You should also exclude all fc layers in your model.
+
+4. Then, you can run the following command for quantization:
+
+```
+python imagenet_gen_qsym_mkldnn.py --model=custom --num-calib-batches=5 --calib-mode=naive
+```
+
+5. After quantization, the quantized symbol and parameter files will be saved in the `model/` directory.
+
+6. Finally, you can run INT8 inference:
+
+```
+# Launch INT8 Inference 
+python imagenet_inference.py --symbol-file=./model/*.json --param-file=./model/*.params --rgb-mean=* --rgb-std=* --num-skipped-batches=* --batch-size=* --num-inference-batches=* --dataset=./data/* --ctx=cpu --data-nthreads=1
+
+# Launch dummy data Inference
+python imagenet_inference.py --symbol-file=./model/*.json --batch-size=* --num-inference-batches=500 --ctx=cpu --benchmark=True
+```
+
+<h2 id="2">Model Quantization with CUDNN</h2>
+
 This folder contains examples of quantizing a FP32 model with or without calibration and using the calibrated
 quantized for inference. Two pre-trained imagenet models are taken as examples for quantization. One is
 [Resnet-152](http://data.mxnet.io/models/imagenet/resnet/152-layers/), and the other one is
diff --git a/example/quantization/imagenet_gen_qsym_mkldnn.py b/example/quantization/imagenet_gen_qsym_mkldnn.py
index e062761..9056f79 100644
--- a/example/quantization/imagenet_gen_qsym_mkldnn.py
+++ b/example/quantization/imagenet_gen_qsym_mkldnn.py
@@ -20,6 +20,10 @@ import os
 import logging
 from common import modelzoo
 import mxnet as mx
+import gluoncv
+from mxnet import gluon, nd, image
+from gluoncv import utils
+from gluoncv.model_zoo import get_model
 from mxnet.contrib.quantization import *
 from mxnet.base import SymbolHandle, check_call, _LIB, mx_uint, c_str_array
 import ctypes
@@ -38,6 +42,39 @@ def download_model(model_name, logger=None):
         logger.info('Downloading model %s... into path %s' % (model_name, model_path))
     return modelzoo.download_model(args.model, os.path.join(dir_path, 'model'))
 
+def convert_from_gluon(model_name, image_shape, classes=1000, logger=None):
+    dir_path = os.path.dirname(os.path.realpath(__file__))
+    model_path = os.path.join(dir_path, 'model')
+    if logger is not None:
+        logger.info('Converting model from Gluon-CV ModelZoo %s... into path %s' % (model_name, model_path))
+    net = get_model(name=model_name, classes=classes, pretrained=True)
+    net.hybridize()
+    x = mx.sym.var('data')
+    y = net(x)
+    y = mx.sym.SoftmaxOutput(data=y, name='softmax')
+    symnet = mx.symbol.load_json(y.tojson())
+    params = net.collect_params()
+    args = {}
+    auxs = {}    
+    for param in params.values():
+        v = param._reduce()
+        k = param.name
+        if 'running' in k:
+            auxs[k] = v
+        else:
+            args[k] = v            
+    mod = mx.mod.Module(symbol=symnet, context=mx.cpu(),
+                        label_names = ['softmax_label'])
+    mod.bind(for_training=False, 
+             data_shapes=[('data', (1,) + 
+                          tuple([int(i) for i in image_shape.split(',')]))])
+    mod.set_params(arg_params=args, aux_params=auxs)
+    dst_dir = os.path.join(dir_path, 'model')
+    prefix = os.path.join(dir_path, 'model', model_name)
+    if not os.path.isdir(dst_dir):
+        os.mkdir(dst_dir)       
+    mod.save_checkpoint(prefix, 0)
+    return prefix
 
 def save_symbol(fname, sym, logger=None):
     if logger is not None:
@@ -54,9 +91,20 @@ def save_params(fname, arg_params, aux_params, logger=None):
 
 
 if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description='Generate a calibrated quantized model from a FP32 model with MKL-DNN support')
-    parser.add_argument('--model', type=str, choices=['imagenet1k-resnet-152', 'imagenet1k-inception-bn'],
-                        help='currently only supports imagenet1k-resnet-152 or imagenet1k-inception-bn')
+    parser = argparse.ArgumentParser(description='Generate a calibrated quantized model from a FP32 model with Intel MKL-DNN support')
+    parser.add_argument('--model', type=str, choices=['resnet50_v1',
+                                                      'resnet101_v1',
+                                                      'inceptionv3',
+                                                      'squeezenet1.0',
+                                                      'mobilenet1.0',
+                                                      'imagenet1k-resnet-152',
+                                                      'imagenet1k-inception-bn',
+                                                      'custom'],
+                        help='model to quantize: one of the listed GluonCV or imagenet1k models; '
+                             'set to custom to load your own pre-trained model.')
+    parser.add_argument('--use-gluon-model', type=bool, default=False,
+                        help='If enabled, will download pretrained model from Gluon-CV '
+                             'and convert to symbolic model ')    
     parser.add_argument('--batch-size', type=int, default=32)
     parser.add_argument('--label-name', type=str, default='softmax_label')
     parser.add_argument('--calib-dataset', type=str, default='data/val_256_q90.rec',
@@ -115,8 +163,21 @@ if __name__ == '__main__':
         download_calib_dataset('http://data.mxnet.io/data/val_256_q90.rec', args.calib_dataset)
 
     # download model
-    prefix, epoch = download_model(model_name=args.model, logger=logger)
-    sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
+    if args.model in ['resnet50_v1', 'resnet101_v1', 'squeezenet1.0', 'mobilenet1.0', 'inceptionv3']:
+        logger.info('model %s is converted from GluonCV' % args.model)
+        args.use_gluon_model = True
+    if args.use_gluon_model:
+        prefix = convert_from_gluon(model_name=args.model, image_shape=args.image_shape, classes=1000, logger=logger)
+        epoch = 0
+        sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
+    elif args.model == 'custom':
+        dir_path = os.path.dirname(os.path.realpath(__file__))
+        prefix = os.path.join(dir_path, 'model', args.model)
+        epoch = 0
+        sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
+    else:
+        prefix, epoch = download_model(model_name=args.model, logger=logger)
+        sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
 
     sym = sym.get_backend_symbol('MKLDNN')
 
@@ -141,16 +202,66 @@ if __name__ == '__main__':
     excluded_sym_names = []
     if args.model == 'imagenet1k-resnet-152':
         rgb_mean = '0,0,0'
+        rgb_std = '1,1,1'
         calib_layer = lambda name: name.endswith('_output')
-        excluded_sym_names += ['flatten0', 'fc1']
+        excluded_sym_names += ['flatten0', 'fc1', 'pooling0']
         if exclude_first_conv:
-            excluded_sym_names += ['conv0', 'pooling0']
+            excluded_sym_names += ['conv0']
     elif args.model == 'imagenet1k-inception-bn':
         rgb_mean = '123.68,116.779,103.939'
+        rgb_std = '1,1,1'
         calib_layer = lambda name: name.endswith('_output')
         excluded_sym_names += ['flatten', 'fc1']
         if exclude_first_conv:
             excluded_sym_names += ['conv_1']
+    elif args.model in ['resnet50_v1', 'resnet101_v1']:
+        rgb_mean = '123.68,116.779,103.939'
+        rgb_std = '58.393, 57.12, 57.375'
+        calib_layer = lambda name: name.endswith('_output')
+        excluded_sym_names += ['resnetv10_dense0_fwd', 'resnetv10_pool0_fwd']
+        if exclude_first_conv:
+            excluded_sym_names += ['resnetv10_conv0_fwd']
+    elif args.model == 'squeezenet1.0':
+        rgb_mean = '123.68,116.779,103.939'
+        rgb_std = '58.393, 57.12, 57.375'
+        calib_layer = lambda name: name.endswith('_output')
+        excluded_sym_names += ['squeezenet0_flatten0_flatten0',
+                               'squeezenet0_pool0_fwd',
+                               'squeezenet0_pool1_fwd',
+                               'squeezenet0_pool2_fwd',
+                               'squeezenet0_pool3_fwd']
+        if exclude_first_conv:
+            excluded_sym_names += ['squeezenet0_conv0_fwd']
+    elif args.model == 'mobilenet1.0':
+        rgb_mean = '123.68,116.779,103.939'
+        rgb_std = '58.393, 57.12, 57.375'
+        calib_layer = lambda name: name.endswith('_output')
+        excluded_sym_names += ['mobilenet0_flatten0_flatten0',
+                               'mobilenet0_dense0_fwd',
+                               'mobilenet0_pool0_fwd']
+        if exclude_first_conv:
+            excluded_sym_names += ['mobilenet0_conv0_fwd']
+    elif args.model == 'inceptionv3':
+        rgb_mean = '123.68,116.779,103.939'
+        rgb_std = '58.393, 57.12, 57.375'
+        calib_layer = lambda name: name.endswith('_output')
+        excluded_sym_names += ['inception30_dense0_fwd',
+                               'inception30_pool0_fwd']
+        if exclude_first_conv:
+            excluded_sym_names += ['inception30_conv0_fwd']
+    elif args.model == 'custom':
+        # set the rgb mean/std of your model; a std of 1 leaves the scale unchanged.
+        rgb_mean = '0,0,0'
+        rgb_std = '1,1,1'
+        calib_layer = lambda name: name.endswith('_output')
+        # add the names of layers you do not want to quantize.
+        # add the names of conv/pool layers that have negative inputs,
+        # since Intel MKL-DNN currently only supports uint8 quantization.
+        # add all fc layer names, since Intel MKL-DNN does not support quantizing them yet.
+        excluded_sym_names += ['layers']
+        # add your first conv layer names, since Intel MKL-DNN currently only supports uint8 quantization.
+        if exclude_first_conv:
+            excluded_sym_names += ['layers']
     else:
         raise ValueError('model %s is not supported in this script' % args.model)
 
@@ -163,6 +274,9 @@ if __name__ == '__main__':
     logger.info('rgb_mean = %s' % rgb_mean)
     rgb_mean = [float(i) for i in rgb_mean.split(',')]
     mean_args = {'mean_r': rgb_mean[0], 'mean_g': rgb_mean[1], 'mean_b': rgb_mean[2]}
+    logger.info('rgb_std = %s' % rgb_std)
+    rgb_std = [float(i) for i in rgb_std.split(',')]
+    std_args = {'std_r': rgb_std[0], 'std_g': rgb_std[1], 'std_b': rgb_std[2]}    
 
     if calib_mode == 'none':
         logger.info('Quantizing FP32 model %s' % args.model)
@@ -184,7 +298,8 @@ if __name__ == '__main__':
                                      shuffle=args.shuffle_dataset,
                                      shuffle_chunk_seed=args.shuffle_chunk_seed,
                                      seed=args.shuffle_seed,
-                                     **mean_args)
+                                     **mean_args,
+                                     **std_args)
 
         qsym, qarg_params, aux_params = quantize_model(sym=sym, arg_params=arg_params, aux_params=aux_params,
                                                         ctx=ctx, excluded_sym_names=excluded_sym_names,
diff --git a/example/quantization/imagenet_inference.py b/example/quantization/imagenet_inference.py
index 286e49e..7d380d3 100644
--- a/example/quantization/imagenet_inference.py
+++ b/example/quantization/imagenet_inference.py
@@ -97,15 +97,46 @@ def score(sym, arg_params, aux_params, data, devs, label_name, max_num_examples,
             logger.info(m.get())
 
 
+def benchmark_score(symbol_file, ctx, batch_size, num_batches, logger=None):
+    # get mod
+    cur_path = os.path.dirname(os.path.realpath(__file__))
+    symbol_file_path = os.path.join(cur_path, symbol_file)
+    if logger is not None:
+        logger.info('Loading symbol from file %s' % symbol_file_path)
+    sym = mx.sym.load(symbol_file_path)
+    mod = mx.mod.Module(symbol=sym, context=ctx)
+    mod.bind(for_training     = False,
+             inputs_need_grad = False,
+             data_shapes      = [('data', (batch_size,)+data_shape)])
+    mod.init_params(initializer=mx.init.Xavier(magnitude=2.))
+
+    # get data
+    data = [mx.random.uniform(-1.0, 1.0, shape=shape, ctx=ctx) for _, shape in mod.data_shapes]
+    batch = mx.io.DataBatch(data, []) # empty label
+
+    # run
+    dry_run = 5                 # use 5 iterations to warm up
+    for i in range(dry_run+num_batches):
+        if i == dry_run:
+            tic = time.time()
+        mod.forward(batch, is_train=False)
+        for output in mod.get_outputs():
+            output.wait_to_read()
+
+    # return num images per second
+    return num_batches*batch_size/(time.time() - tic)
+
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Score a model on a dataset')
     parser.add_argument('--ctx', type=str, default='gpu')
+    parser.add_argument('--benchmark', type=bool, default=False, help='dummy data benchmark')
     parser.add_argument('--symbol-file', type=str, required=True, help='symbol file path')
-    parser.add_argument('--param-file', type=str, required=True, help='param file path')
+    parser.add_argument('--param-file', type=str, required=False, help='param file path')
     parser.add_argument('--batch-size', type=int, default=32)
     parser.add_argument('--label-name', type=str, default='softmax_label')
-    parser.add_argument('--dataset', type=str, required=True, help='dataset path')
+    parser.add_argument('--dataset', type=str, required=False, help='dataset path')
     parser.add_argument('--rgb-mean', type=str, default='0,0,0')
+    parser.add_argument('--rgb-std', type=str, default='1,1,1')
     parser.add_argument('--image-shape', type=str, default='3,224,224')
     parser.add_argument('--data-nthreads', type=int, default=60, help='number of threads for data decoding')
     parser.add_argument('--num-skipped-batches', type=int, default=0, help='skip the number of batches for inference')
@@ -145,6 +176,10 @@ if __name__ == '__main__':
     logger.info('rgb_mean = %s' % rgb_mean)
     rgb_mean = [float(i) for i in rgb_mean.split(',')]
     mean_args = {'mean_r': rgb_mean[0], 'mean_g': rgb_mean[1], 'mean_b': rgb_mean[2]}
+    rgb_std = args.rgb_std
+    logger.info('rgb_std = %s' % rgb_std)
+    rgb_std = [float(i) for i in rgb_std.split(',')]
+    std_args = {'std_r': rgb_std[0], 'std_g': rgb_std[1], 'std_b': rgb_std[2]}
 
     label_name = args.label_name
     logger.info('label_name = %s' % label_name)
@@ -153,32 +188,38 @@ if __name__ == '__main__':
     data_shape = tuple([int(i) for i in image_shape.split(',')])
     logger.info('Input data shape = %s' % str(data_shape))
 
-    dataset = args.dataset
-    download_dataset('http://data.mxnet.io/data/val_256_q90.rec', dataset)
-    logger.info('Dataset for inference: %s' % dataset)
-
-    # creating data iterator
-    data = mx.io.ImageRecordIter(path_imgrec=dataset,
-                                 label_width=1,
-                                 preprocess_threads=data_nthreads,
-                                 batch_size=batch_size,
-                                 data_shape=data_shape,
-                                 label_name=label_name,
-                                 rand_crop=False,
-                                 rand_mirror=False,
-                                 shuffle=True,
-                                 shuffle_chunk_seed=3982304,
-                                 seed=48564309,
-                                 **mean_args)
-
-    # loading model
-    sym, arg_params, aux_params = load_model(symbol_file, param_file, logger)
-
-    # make sure that fp32 inference works on the same images as calibrated quantized model
-    logger.info('Skipping the first %d batches' % args.num_skipped_batches)
-    data = advance_data_iter(data, args.num_skipped_batches)
-
-    num_inference_images = args.num_inference_batches * batch_size
-    logger.info('Running model %s for inference' % symbol_file)
-    score(sym, arg_params, aux_params, data, [ctx], label_name,
-          max_num_examples=num_inference_images, logger=logger)
+    if not args.benchmark:
+        dataset = args.dataset
+        download_dataset('http://data.mxnet.io/data/val_256_q90.rec', dataset)
+        logger.info('Dataset for inference: %s' % dataset)
+
+        # creating data iterator
+        data = mx.io.ImageRecordIter(path_imgrec=dataset,
+                                    label_width=1,
+                                    preprocess_threads=data_nthreads,
+                                    batch_size=batch_size,
+                                    data_shape=data_shape,
+                                    label_name=label_name,
+                                    rand_crop=False,
+                                    rand_mirror=False,
+                                    shuffle=True,
+                                    shuffle_chunk_seed=3982304,
+                                    seed=48564309,
+                                    **mean_args,
+                                    **std_args)
+
+        # loading model
+        sym, arg_params, aux_params = load_model(symbol_file, param_file, logger)
+
+        # make sure that fp32 inference works on the same images as calibrated quantized model
+        logger.info('Skipping the first %d batches' % args.num_skipped_batches)
+        data = advance_data_iter(data, args.num_skipped_batches)
+
+        num_inference_images = args.num_inference_batches * batch_size
+        logger.info('Running model %s for inference' % symbol_file)
+        score(sym, arg_params, aux_params, data, [ctx], label_name,
+            max_num_examples=num_inference_images, logger=logger)
+    else:
+        logger.info('Running model %s for inference' % symbol_file)
+        speed = benchmark_score(symbol_file, ctx, batch_size, args.num_inference_batches, logger)
+        logger.info('batch size %2d, image/sec: %f', batch_size, speed)
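A note on the `type=bool` flags in this change (`--benchmark` here and `--use-gluon-model` in `imagenet_gen_qsym_mkldnn.py`): argparse passes the raw string to `bool()`, so any non-empty value, including `False`, parses as `True`. The sketch below shows the behaviour and one possible workaround. The `str2bool` helper is hypothetical and is not part of this commit.

```python
import argparse


def str2bool(value):
    """Hypothetical helper: map common true/false spellings to a bool."""
    text = value.strip().lower()
    if text in ('true', '1', 'yes'):
        return True
    if text in ('false', '0', 'no'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean, got %r' % value)


def build_parsers():
    # Parser mirroring the committed flag definition (type=bool).
    naive = argparse.ArgumentParser()
    naive.add_argument('--benchmark', type=bool, default=False)
    # Parser using the explicit converter instead.
    strict = argparse.ArgumentParser()
    strict.add_argument('--benchmark', type=str2bool, default=False)
    return naive, strict


naive, strict = build_parsers()
# bool('False') is True because the string is non-empty.
naive_value = naive.parse_args(['--benchmark=False']).benchmark
strict_value = strict.parse_args(['--benchmark=False']).benchmark
```

With a converter like this, `--benchmark=True` as used in the README keeps working, while `--benchmark=False` actually disables the dummy-data path.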
diff --git a/example/ssd/benchmark_score.py b/example/ssd/benchmark_score.py
index caeb208..01a0eb9 100644
--- a/example/ssd/benchmark_score.py
+++ b/example/ssd/benchmark_score.py
@@ -34,12 +34,19 @@ parser.add_argument('--network', '-n', type=str, default='vgg16_reduced')
 parser.add_argument('--batch_size', '-b', type=int, default=0)
 parser.add_argument('--shape', '-w', type=int, default=300)
 parser.add_argument('--class_num', '-class', type=int, default=20)
+parser.add_argument('--prefix', dest='prefix', help='load model prefix',
+                    default=os.path.join(os.getcwd(), 'model', 'ssd_'), type=str)
+parser.add_argument('--deploy', dest='deploy', help='Load network from model',
+                    action='store_true', default=False)
 
 
 def get_data_shapes(batch_size):
     image_shape = (3, 300, 300)
     return [('data', (batch_size,)+image_shape)]
 
+def get_label_shapes(batch_size):
+    return [('label', (batch_size,) + (42, 6))]
+
 def get_data(batch_size):
     data_shapes = get_data_shapes(batch_size)
     data = [mx.random.uniform(-1.0, 1.0, shape=shape, ctx=mx.cpu()) for _, shape in data_shapes]
@@ -53,6 +60,7 @@ if __name__ == '__main__':
     image_shape = args.shape
     num_classes = args.class_num
     b = args.batch_size
+    prefix = args.prefix
     supported_image_shapes = [300, 512]
     supported_networks = ['vgg16_reduced', 'inceptionv3', 'resnet50']
 
@@ -68,18 +76,27 @@ if __name__ == '__main__':
         batch_sizes = [b]
 
     data_shape = (3, image_shape, image_shape)
-    net = get_symbol(network, data_shape[1], num_classes=num_classes,
-                     nms_thresh=0.4, force_suppress=True)
+
+    if args.deploy:
+        prefix += network + '_' + str(data_shape[1]) + '-symbol.json'
+        net = mx.sym.load(prefix)
+    else:
+        net = get_symbol(network, data_shape[1], num_classes=num_classes,
+                         nms_thresh=0.4, force_suppress=True)
+    if 'label' not in net.list_arguments():
+        label = mx.sym.Variable(name='label')
+        net = mx.sym.Group([net, label])
     
     num_batches = 100
     dry_run = 5   # use 5 iterations to warm up
     
     for bs in batch_sizes:
         batch = get_data(bs)
-        mod = mx.mod.Module(net, label_names=None, context=mx.cpu())
+        mod = mx.mod.Module(net, label_names=('label',), context=mx.cpu())
         mod.bind(for_training = False,
                  inputs_need_grad = False,
-                 data_shapes = get_data_shapes(bs))
+                 data_shapes = get_data_shapes(bs),
+                 label_shapes = get_label_shapes(bs))
         mod.init_params(initializer=mx.init.Xavier(magnitude=2.))
 
         # get data
diff --git a/example/ssd/config/config.py b/example/ssd/config/config.py
index b084888..8d44a0d 100644
--- a/example/ssd/config/config.py
+++ b/example/ssd/config/config.py
@@ -16,7 +16,7 @@
 # under the License.
 
 import os
-from config.utils import DotDict, namedtuple_with_defaults, zip_namedtuple, config_as_dict
+from .utils import DotDict, namedtuple_with_defaults, zip_namedtuple, config_as_dict
 
 RandCropper = namedtuple_with_defaults('RandCropper',
     'min_crop_scales, max_crop_scales, \
diff --git a/example/ssd/evaluate.py b/example/ssd/evaluate.py
index d1a83cc..bbe9fea 100644
--- a/example/ssd/evaluate.py
+++ b/example/ssd/evaluate.py
@@ -30,6 +30,8 @@ def parse_args():
                         default="", type=str)
     parser.add_argument('--network', dest='network', type=str, default='vgg16_reduced',
                         help='which network to use')
+    parser.add_argument('--num-batch', dest='num_batch', type=int, default=5,
+                        help='number of batches to evaluate')
     parser.add_argument('--batch-size', dest='batch_size', type=int, default=32,
                         help='evaluation batch size')
     parser.add_argument('--num-class', dest='num_class', type=int, default=20,
@@ -97,7 +99,7 @@ if __name__ == '__main__':
         prefix = args.prefix + args.network
     else:
         prefix = args.prefix
-    evaluate_net(network, args.rec_path, num_class,
+    evaluate_net(network, args.rec_path, num_class, args.num_batch,
                  (args.mean_r, args.mean_g, args.mean_b), args.data_shape,
                  prefix, args.epoch, ctx, batch_size=args.batch_size,
                  path_imglist=args.list_path, nms_thresh=args.nms_thresh,
diff --git a/example/ssd/evaluate/evaluate_net.py b/example/ssd/evaluate/evaluate_net.py
index fabe54f..35e253d 100644
--- a/example/ssd/evaluate/evaluate_net.py
+++ b/example/ssd/evaluate/evaluate_net.py
@@ -24,10 +24,15 @@ from dataset.iterator import DetRecordIter
 from config.config import cfg
 from evaluate.eval_metric import MApMetric, VOC07MApMetric
 import logging
+import time
 from symbol.symbol_factory import get_symbol
+from symbol import symbol_builder
+from mxnet.base import SymbolHandle, check_call, _LIB, mx_uint, c_str_array
+import ctypes
+from mxnet.contrib.quantization import *
 
-def evaluate_net(net, path_imgrec, num_classes, mean_pixels, data_shape,
-                 model_prefix, epoch, ctx=mx.cpu(), batch_size=1,
+def evaluate_net(net, path_imgrec, num_classes, num_batch, mean_pixels, data_shape,
+                 model_prefix, epoch, ctx=mx.cpu(), batch_size=32,
                  path_imglist="", nms_thresh=0.45, force_nms=False,
                  ovp_thresh=0.5, use_difficult=False, class_names=None,
                  voc07_metric=False):
@@ -106,6 +111,23 @@ def evaluate_net(net, path_imgrec, num_classes, mean_pixels, data_shape,
         metric = VOC07MApMetric(ovp_thresh, use_difficult, class_names)
     else:
         metric = MApMetric(ovp_thresh, use_difficult, class_names)
-    results = mod.score(eval_iter, metric, num_batch=None)
+
+    num = num_batch * batch_size
+    data = [mx.random.uniform(-1.0, 1.0, shape=shape, ctx=ctx) for _, shape in mod.data_shapes]
+    batch = mx.io.DataBatch(data, [])  # empty label
+
+    dry_run = 5                 # use 5 iterations to warm up
+    for i in range(dry_run):
+        mod.forward(batch, is_train=False)
+        for output in mod.get_outputs():
+            output.wait_to_read()
+
+    tic = time.time()
+    results = mod.score(eval_iter, metric, num_batch=num_batch)
+    speed = num / (time.time() - tic)
+    if logger is not None:
+        logger.info('Finished inference with %d images' % num)
+        logger.info('Finished with %f images per second', speed)
+
     for k, v in results:
         print("{}: {}".format(k, v))
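The warm-up-then-time pattern used in `evaluate_net` above and in `benchmark_score` can be sketched without MXNet as follows. This is a minimal illustration, not the committed code: `measure_throughput` and `fake_forward` are hypothetical names, and `fake_forward` stands in for a `mod.forward(...)` call followed by `wait_to_read()` on the outputs.

```python
import time


def measure_throughput(run_batch, batch_size, num_batches, dry_run=5):
    """Return images per second, excluding the warm-up iterations."""
    if num_batches < 1:
        raise ValueError('num_batches must be at least 1')
    tic = None
    for i in range(dry_run + num_batches):
        if i == dry_run:
            # Start the clock only after the warm-up iterations.
            tic = time.perf_counter()
        run_batch()
    elapsed = time.perf_counter() - tic
    return num_batches * batch_size / elapsed


calls = []


def fake_forward():
    # Stand-in workload (assumption): sleep briefly instead of running a model.
    calls.append(1)
    time.sleep(0.001)


speed = measure_throughput(fake_forward, batch_size=32, num_batches=10)
```

Note that `evaluate_net` divides `num_batch * batch_size` by the elapsed time, so the reported speed is only meaningful when the iterator actually yields that many full batches.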
diff --git a/example/quantization/imagenet_gen_qsym_mkldnn.py b/example/ssd/quantization.py
similarity index 60%
copy from example/quantization/imagenet_gen_qsym_mkldnn.py
copy to example/ssd/quantization.py
index e062761..5cb74ba 100644
--- a/example/quantization/imagenet_gen_qsym_mkldnn.py
+++ b/example/ssd/quantization.py
@@ -15,29 +15,22 @@
 # specific language governing permissions and limitations
 # under the License.
 
-import argparse
+from __future__ import print_function
 import os
-import logging
-from common import modelzoo
+import sys
+import importlib
 import mxnet as mx
-from mxnet.contrib.quantization import *
+from dataset.iterator import DetRecordIter
+from config.config import cfg
+from evaluate.eval_metric import MApMetric, VOC07MApMetric
+import argparse
+import logging
+import time
+from symbol.symbol_factory import get_symbol
+from symbol import symbol_builder
 from mxnet.base import SymbolHandle, check_call, _LIB, mx_uint, c_str_array
 import ctypes
-
-
-def download_calib_dataset(dataset_url, calib_dataset, logger=None):
-    if logger is not None:
-        logger.info('Downloading calibration dataset from %s to %s' % (dataset_url, calib_dataset))
-    mx.test_utils.download(dataset_url, calib_dataset)
-
-
-def download_model(model_name, logger=None):
-    dir_path = os.path.dirname(os.path.realpath(__file__))
-    model_path = os.path.join(dir_path, 'model')
-    if logger is not None:
-        logger.info('Downloading model %s... into path %s' % (model_name, model_path))
-    return modelzoo.download_model(args.model, os.path.join(dir_path, 'model'))
-
+from mxnet.contrib.quantization import *
 
 def save_symbol(fname, sym, logger=None):
     if logger is not None:
@@ -54,21 +47,14 @@ def save_params(fname, arg_params, aux_params, logger=None):
 
 
 if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description='Generate a calibrated quantized model from a FP32 model with MKL-DNN support')
-    parser.add_argument('--model', type=str, choices=['imagenet1k-resnet-152', 'imagenet1k-inception-bn'],
-                        help='currently only supports imagenet1k-resnet-152 or imagenet1k-inception-bn')
+    parser = argparse.ArgumentParser(description='Generate a calibrated quantized SSD model from a FP32 model')
     parser.add_argument('--batch-size', type=int, default=32)
-    parser.add_argument('--label-name', type=str, default='softmax_label')
-    parser.add_argument('--calib-dataset', type=str, default='data/val_256_q90.rec',
-                        help='path of the calibration dataset')
-    parser.add_argument('--image-shape', type=str, default='3,224,224')
-    parser.add_argument('--data-nthreads', type=int, default=60,
-                        help='number of threads for data decoding')
-    parser.add_argument('--num-calib-batches', type=int, default=10,
+    parser.add_argument('--num-calib-batches', type=int, default=5,
                         help='number of batches for calibration')
     parser.add_argument('--exclude-first-conv', action='store_true', default=True,
                         help='excluding quantizing the first conv layer since the'
-                             ' input data may have negative value which doesn\'t support at moment' )
+                             ' number of channels is usually not a multiple of 4 in that layer'
+                             ' which does not satisfy the requirement of cuDNN')
     parser.add_argument('--shuffle-dataset', action='store_true', default=True,
                         help='shuffle the calibration dataset')
     parser.add_argument('--shuffle-chunk-seed', type=int, default=3982304,
@@ -79,7 +65,7 @@ if __name__ == '__main__':
                         help='shuffling seed, see'
                              ' https://mxnet.incubator.apache.org/api/python/io/io.html?highlight=imager#mxnet.io.ImageRecordIter'
                              ' for more details')
-    parser.add_argument('--calib-mode', type=str, default='entropy',
+    parser.add_argument('--calib-mode', type=str, default='naive',
                         help='calibration mode used for generating calibration table for the quantized symbol; supports'
                              ' 1. none: no calibration will be used. The thresholds for quantization will be calculated'
                              ' on the fly. This will result in inference speed slowdown and loss of accuracy'
@@ -95,10 +81,7 @@ if __name__ == '__main__':
     parser.add_argument('--quantized-dtype', type=str, default='uint8',
                         choices=['int8', 'uint8'],
                         help='quantization destination data type for input data')
-    parser.add_argument('--enable-calib-quantize', type=bool, default=True,
-                        help='If enabled, the quantize op will '
-                             'be calibrated offline if calibration mode is '
-                             'enabled')
+
     args = parser.parse_args()
     ctx = mx.cpu(0)
     logging.basicConfig()
@@ -110,13 +93,13 @@ if __name__ == '__main__':
     calib_mode = args.calib_mode
     logger.info('calibration mode set to %s' % calib_mode)
 
-    # download calibration dataset
-    if calib_mode != 'none':
-        download_calib_dataset('http://data.mxnet.io/data/val_256_q90.rec', args.calib_dataset)
+    # load FP32 models
+    prefix, epoch = "./model/ssd_vgg16_reduced_300", 0
+    sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
 
-    # download model
-    prefix, epoch = download_model(model_name=args.model, logger=logger)
-    sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
+    if 'label' not in sym.list_arguments():
+        label = mx.sym.Variable(name='label')
+        sym = mx.sym.Group([sym, label])
 
     sym = sym.get_backend_symbol('MKLDNN')
 
@@ -126,35 +109,26 @@ if __name__ == '__main__':
 
     # get number of batches for calibration
     num_calib_batches = args.num_calib_batches
-    if calib_mode == 'none':
-        logger.info('skip calibration step as calib_mode is none')
-    else:
+    if calib_mode != 'none':
         logger.info('number of batches = %d for calibration' % num_calib_batches)
 
-    # get number of threads for decoding the dataset
-    data_nthreads = args.data_nthreads
-
     # get image shape
-    image_shape = args.image_shape
+    image_shape = '3,300,300'
 
+    # Quantization layer configs
     exclude_first_conv = args.exclude_first_conv
     excluded_sym_names = []
-    if args.model == 'imagenet1k-resnet-152':
-        rgb_mean = '0,0,0'
-        calib_layer = lambda name: name.endswith('_output')
-        excluded_sym_names += ['flatten0', 'fc1']
-        if exclude_first_conv:
-            excluded_sym_names += ['conv0', 'pooling0']
-    elif args.model == 'imagenet1k-inception-bn':
-        rgb_mean = '123.68,116.779,103.939'
-        calib_layer = lambda name: name.endswith('_output')
-        excluded_sym_names += ['flatten', 'fc1']
-        if exclude_first_conv:
-            excluded_sym_names += ['conv_1']
-    else:
-        raise ValueError('model %s is not supported in this script' % args.model)
-
-    label_name = args.label_name
+    rgb_mean = '123,117,104'
+    calib_layer = lambda name: name.endswith('_output')
+    for i in range(1,19):
+        excluded_sym_names += ['flatten'+str(i)]
+    excluded_sym_names += ['relu4_3_cls_pred_conv',
+                            'relu7_cls_pred_conv',
+                            'relu4_3_loc_pred_conv']
+    if exclude_first_conv:
+        excluded_sym_names += ['conv1_1']
+
+    label_name = 'label'
     logger.info('label_name = %s' % label_name)
 
     data_shape = tuple([int(i) for i in image_shape.split(',')])
@@ -170,38 +144,25 @@ if __name__ == '__main__':
                                                        ctx=ctx, excluded_sym_names=excluded_sym_names,
                                                        calib_mode=calib_mode, quantized_dtype=args.quantized_dtype,
                                                        logger=logger)
-        sym_name = '%s-symbol.json' % (prefix + '-quantized')
+        sym_name = '%s-symbol.json' % ('./model/qssd_vgg16_reduced_300')
+        param_name = '%s-%04d.params' % ('./model/qssd_vgg16_reduced_300', epoch)
+        save_symbol(sym_name, qsym, logger)
     else:
         logger.info('Creating ImageRecordIter for reading calibration dataset')
-        data = mx.io.ImageRecordIter(path_imgrec=args.calib_dataset,
-                                     label_width=1,
-                                     preprocess_threads=data_nthreads,
-                                     batch_size=batch_size,
-                                     data_shape=data_shape,
-                                     label_name=label_name,
-                                     rand_crop=False,
-                                     rand_mirror=False,
-                                     shuffle=args.shuffle_dataset,
-                                     shuffle_chunk_seed=args.shuffle_chunk_seed,
-                                     seed=args.shuffle_seed,
-                                     **mean_args)
+        eval_iter = DetRecordIter(os.path.join(os.getcwd(), 'data', 'val.rec'),
+                                  batch_size, data_shape, mean_pixels=(123, 117, 104),
+                                  path_imglist="", **cfg.valid)
 
         qsym, qarg_params, aux_params = quantize_model(sym=sym, arg_params=arg_params, aux_params=aux_params,
                                                         ctx=ctx, excluded_sym_names=excluded_sym_names,
-                                                        calib_mode=calib_mode, calib_data=data,
+                                                        calib_mode=calib_mode, calib_data=eval_iter,
                                                         num_calib_examples=num_calib_batches * batch_size,
                                                         calib_layer=calib_layer, quantized_dtype=args.quantized_dtype,
-                                                        label_names=(label_name,), calib_quantize_op = True,
+                                                        label_names=(label_name,),
+                                                        calib_quantize_op = True,
                                                         logger=logger)
-        if calib_mode == 'entropy':
-            suffix = '-quantized-%dbatches-entropy' % num_calib_batches
-        elif calib_mode == 'naive':
-            suffix = '-quantized-%dbatches-naive' % num_calib_batches
-        else:
-            raise ValueError('unknow calibration mode %s received, only supports `none`, `naive`, and `entropy`'
-                             % calib_mode)
-        sym_name = '%s-symbol.json' % (prefix + suffix)
+        sym_name = '%s-symbol.json' % ('./model/cqssd_vgg16_reduced_300')
+        param_name = '%s-%04d.params' % ('./model/cqssd_vgg16_reduced_300', epoch)
     qsym = qsym.get_backend_symbol('MKLDNN_POST_QUANTIZE')
     save_symbol(sym_name, qsym, logger)
-    param_name = '%s-%04d.params' % (prefix + '-quantized', epoch)
     save_params(param_name, qarg_params, aux_params, logger)
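
Note on the hunk above: the SSD script no longer switches on args.model. It always excludes the flatten layers numbered 1 through 18 plus three prediction convolutions from quantization, and optionally the first convolution. The sketch below rebuilds that list in isolation; the function name build_excluded_sym_names is hypothetical and not part of the commit, while the layer names are copied from the diff.

```python
# Minimal sketch of how the patched script assembles excluded_sym_names.
# build_excluded_sym_names is a hypothetical helper, not part of the commit.
def build_excluded_sym_names(exclude_first_conv):
    # range(1, 19) yields 1..18, so this covers flatten1 .. flatten18.
    names = ['flatten' + str(i) for i in range(1, 19)]
    # Prediction convolutions that the script keeps unquantized.
    names += ['relu4_3_cls_pred_conv',
              'relu7_cls_pred_conv',
              'relu4_3_loc_pred_conv']
    if exclude_first_conv:
        names += ['conv1_1']
    return names


default_names = build_excluded_sym_names(False)
with_first_conv = build_excluded_sym_names(True)
print(len(default_names), len(with_first_conv))
```

The upper bound of range is exclusive, so a layer named flatten19 would not be excluded by this loop.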
diff --git a/python/mxnet/initializer.py b/python/mxnet/initializer.py
index 8ae729f..357e75b 100755
--- a/python/mxnet/initializer.py
+++ b/python/mxnet/initializer.py
@@ -152,6 +152,12 @@ class Initializer(object):
             elif desc.endswith('beta'):
                 self._init_beta(desc, arr)
                 self._verbose_print(desc, 'beta', arr)
+            elif desc.endswith('min'):
+                self._init_zero(desc, arr)
+                self._verbose_print(desc, 'min', arr)
+            elif desc.endswith('max'):
+                self._init_one(desc, arr)
+                self._verbose_print(desc, 'max', arr)
             else:
                 self._init_default(desc, arr)
 
@@ -196,6 +202,10 @@ class Initializer(object):
             self._init_zero(name, arr)
         elif name.endswith("moving_avg"):
             self._init_zero(name, arr)
+        elif name.endswith('min'):
+            self._init_zero(name, arr)
+        elif name.endswith('max'):
+            self._init_one(name, arr)
         else:
             self._init_default(name, arr)
 
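
Note on the initializer.py hunks above: parameters whose names end in 'min' are now zero-initialized and those ending in 'max' are set to one, so a quantized symbol's range parameters get usable defaults when init_params is called without calibrated values. The sketch below mimics only the two new suffix branches in plain Python. The function fill_value_for and the parameter names are hypothetical, and the real class checks other suffixes (such as weight, bias, gamma, beta) before these branches are reached.

```python
# Hypothetical stand-in for the two branches added to Initializer.
# It returns the fill value instead of writing into an NDArray.
def fill_value_for(name):
    if name.endswith('min'):
        return 0.0   # mirrors self._init_zero(name, arr)
    if name.endswith('max'):
        return 1.0   # mirrors self._init_one(name, arr)
    return None      # not handled by the new branches


# Illustrative parameter names; real names depend on the quantized graph.
names = ['quantized_conv0_min', 'quantized_conv0_max', 'fc_weight']
values = {n: fill_value_for(n) for n in names}
print(values)
```

Because the match is a plain suffix test, any parameter whose name happens to end in 'min' or 'max' takes these branches unless an earlier branch claims it first.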
diff --git a/tests/python/mkl/test_subgraph.py b/tests/python/mkl/test_subgraph.py
index 5b70821..71784dc 100644
--- a/tests/python/mkl/test_subgraph.py
+++ b/tests/python/mkl/test_subgraph.py
@@ -55,6 +55,17 @@ def check_qsym_forward(qsym, qarg_params, qaux_params, batch, data_shape, label_
     output.wait_to_read()
   return mod.get_outputs()
 
+def check_qsym_dummy_forward(qsym, batch, data_shape, label_shape):
+  mod = mx.mod.Module(symbol=qsym, context=mx.current_context())
+  mod.bind(for_training=False,
+           data_shapes=[('data', data_shape)],
+           label_shapes=[('softmax_label', label_shape)])
+  mod.init_params(initializer=mx.init.Xavier(magnitude=2.))
+  mod.forward(batch, is_train=False)
+  for output in mod.get_outputs():
+    output.wait_to_read()
+  return mod.get_outputs()
+
 def check_quantize(sym, data_shape):
   fc = mx.sym.FullyConnected(data=sym, num_hidden=10, flatten=True, name='fc')
   sym = mx.sym.SoftmaxOutput(data=fc, name='softmax')
@@ -99,6 +110,7 @@ def check_quantize(sym, data_shape):
   quantized_out = check_qsym_forward(qsym, qarg_params, qaux_params, batch, data_shape, label_shape)
   for i in range(len(ref_out)):
     assert_almost_equal(ref_out[i].asnumpy(), quantized_out[i].asnumpy(), atol = 1)
+  check_qsym_dummy_forward(qsym, batch, data_shape, label_shape)
 
 
 @with_seed()