Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2020/12/10 03:11:53 UTC

[GitHub] [incubator-mxnet] TaoLv commented on a change in pull request #19587: [FEATURE] Restore Quantization API to MXNet

TaoLv commented on a change in pull request #19587:
URL: https://github.com/apache/incubator-mxnet/pull/19587#discussion_r539806803



##########
File path: example/quantization/README.md
##########
@@ -0,0 +1,194 @@
+<!--- Licensed to the Apache Software Foundation (ASF) under one -->
+<!--- or more contributor license agreements.  See the NOTICE file -->
+<!--- distributed with this work for additional information -->
+<!--- regarding copyright ownership.  The ASF licenses this file -->
+<!--- to you under the Apache License, Version 2.0 (the -->
+<!--- "License"); you may not use this file except in compliance -->
+<!--- with the License.  You may obtain a copy of the License at -->
+<!--- -->
+<!---   http://www.apache.org/licenses/LICENSE-2.0 -->
+<!--- -->
+<!--- Unless required by applicable law or agreed to in writing, -->
+<!--- software distributed under the License is distributed on an -->
+<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
+<!--- KIND, either express or implied.  See the License for the -->
+<!--- specific language governing permissions and limitations -->
+<!--- under the License. -->
+
+# Model Quantization with Calibration Examples
+
+This folder contains examples of quantizing a FP32 model with Intel® oneAPI Deep Neural Network Library (oneDNN) to (U)INT8 model.
+
+<h2 id="0">Contents</h2>
+
+* [1. Model Quantization with Intel® oneDNN](#1)
+<h2 id="1">Model Quantization with Intel® oneDNN</h2>
+
+Intel® oneDNN supports quantization with subgraph features on Intel® CPU Platform and can bring performance improvements on the [Intel® Xeon® Scalable Platform](https://www.intel.com/content/www/us/en/processors/xeon/scalable/xeon-scalable-platform.html). To apply quantization flow to your project directly, please refer [Optimize custom models with oneDNN backend](#TODO(agrygielski)).

Review comment:
       Remove the TODO?

##########
File path: example/quantization/README.md
##########
@@ -0,0 +1,194 @@
+
+```
+usage: python imagenet_gen_qsym_onednn.py [-h] [--model MODEL] [--epoch EPOCH]
+                                          [--no-pretrained] [--batch-size BATCH_SIZE]
+                                          [--calib-dataset CALIB_DATASET]
+                                          [--image-shape IMAGE_SHAPE]
+                                          [--data-nthreads DATA_NTHREADS]
+                                          [--num-calib-batches NUM_CALIB_BATCHES]
+                                          [--exclude-first-conv] [--shuffle-dataset]
+                                          [--calib-mode CALIB_MODE]
+                                          [--quantized-dtype {auto,int8,uint8}]
+                                          [--quiet]
+
+Generate a calibrated quantized model from a FP32 model with Intel oneDNN support
+
+optional arguments:
+  -h, --help            show this help message and exit
+  --model MODEL         model to be quantized. If no-pretrained is set then
+                        model must be provided to `model` directory in the same path
+                        as this python script, default is `resnet50_v1`
+  --epoch EPOCH         number of epochs, default is `0`
+  --no-pretrained       If enabled, will not download pretrained model from
+                        MXNet or Gluon-CV modelzoo, default is `False`
+  --batch-size BATCH_SIZE
+                        batch size to be used when calibrating model, default is `32`
+  --calib-dataset CALIB_DATASET
+                        path of the calibration dataset, default is `data/val_256_q90.rec`
+  --image-shape IMAGE_SHAPE
+                        number of channels, height and width of input image separated by comma,
+                        default is `3,224,224`
+  --data-nthreads DATA_NTHREADS
+                        number of threads for data loading, default is `0`
+  --num-calib-batches NUM_CALIB_BATCHES
+                        number of batches for calibration, default is `10`
+  --exclude-first-conv  excluding quantizing the first conv layer since the
+                        input data may have negative value which doesn't
+                        support at moment
+  --shuffle-dataset     shuffle the calibration dataset
+  --calib-mode CALIB_MODE
+                        calibration mode used for generating calibration table
+                        for the quantized symbol; supports 1. none: no
+                        calibration will be used. The thresholds for
+                        quantization will be calculated on the fly. This will
+                        result in inference speed slowdown and loss of
+                        accuracy in general. 2. naive: simply take min and max
+                        values of layer outputs as thresholds for
+                        quantization. In general, the inference accuracy
+                        worsens with more examples used in calibration. It is
+                        recommended to use `entropy` mode as it produces more
+                        accurate inference results. 3. entropy: calculate KL
+                        divergence of the fp32 output and quantized output for
+                        optimal thresholds. This mode is expected to produce
+                        the best inference accuracy of all three kinds of
+                        quantized models if the calibration dataset is
+                        representative enough of the inference dataset.
+                        default is `entropy`
+  --quantized-dtype {auto,int8,uint8}
+                        quantization destination data type for input data,
+                        default is `auto`
+  --quiet               suppress most of log
+```
+
+A new benchmark script `launch_inference_onednn.sh` has been designed to launch performance benchmark for float32 or int8 image-classification models with Intel® oneDNN.

Review comment:
       > `float32 or int8`
   
   I see `FP32` in the first paragraph. The terminology should be unified throughout the tutorial.
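For readers of the `--calib-mode` help text quoted in the hunk above, the following is a minimal sketch of what "naive" calibration means: the threshold is taken directly from the min and max of the observed layer outputs. It is plain Python written for illustration and is not MXNet's implementation; the function names, the sample values, and the hand-picked threshold of `0.1` are all assumptions.

```python
# Illustrative only: a pure-Python sketch of "naive" (min/max) calibration.
# This is NOT MXNet's implementation; every name below is hypothetical.

def naive_threshold(values):
    """Return the symmetric threshold max(|min|, |max|) of the observed values."""
    return max(abs(min(values)), abs(max(values)))


def quantize_int8(values, threshold):
    """Map floats in [-threshold, threshold] to integers in [-127, 127]."""
    scale = 127.0 / threshold
    quantized = []
    for v in values:
        q = int(round(v * scale))
        quantized.append(max(-127, min(127, q)))  # clamp out-of-range values
    return quantized


def dequantize_int8(quantized, threshold):
    """Map int8 values back to floats using the same threshold."""
    scale = threshold / 127.0
    return [q * scale for q in quantized]


# Hypothetical layer outputs seen during calibration: mostly small, one outlier.
calib_outputs = [0.02, -0.05, 0.10, -0.08, 0.04, 12.7]

# Naive mode: the outlier alone determines the threshold.
threshold = naive_threshold(calib_outputs)
naive_q = quantize_int8(calib_outputs, threshold)
restored = dequantize_int8(naive_q, threshold)

# A smaller, hand-picked threshold keeps resolution for the small values
# and saturates the outlier instead.
clipped_q = quantize_int8(calib_outputs, 0.1)
```

With the naive threshold, the small activations collapse to a handful of integer levels because one outlier stretches the range. A threshold search such as the `entropy` mode described in the help text aims to avoid this by choosing a tighter range, at the cost of saturating rare large values.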

##########
File path: example/quantization/imagenet_gen_qsym_onednn.py
##########
@@ -0,0 +1,274 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import argparse
+import logging
+import os
+import re
+import sys
+from inspect import currentframe, getframeinfo
+
+import mxnet as mx
+from mxnet import gluon
+from mxnet.contrib.quantization import quantize_net
+from mxnet.gluon.data import DataLoader
+from mxnet.gluon.data.vision import transforms
+from mxnet.gluon.model_zoo.vision import get_model
+
+sys.path.append('../..')
+from tools.rec2idx import IndexCreator
+
+
+def download_calib_dataset(dataset_url, calib_dataset, logger=None):
+    if logger is not None:
+        logger.info('Downloading calibration dataset from %s to %s' % (dataset_url, calib_dataset))
+    mx.test_utils.download(dataset_url, calib_dataset)
+
+def get_from_gluon(model_name, classes=1000, logger=None):
+    dir_path = os.path.dirname(os.path.realpath(__file__))
+    model_path = os.path.join(dir_path, 'model')
+    if logger is not None:
+        logger.info('Converting model from Gluon-CV ModelZoo %s... into path %s' % (model_name, model_path))
+    net = get_model(name=model_name, classes=classes, pretrained=True)
+    prefix = os.path.join(model_path, model_name)
+    return net, prefix
+
+def regex_find_excluded_symbols(patterns_dict, model_name):
+    for key, value in patterns_dict.items():
+        if re.search(key, model_name) is not None:
+            return value
+    return None
+
+def get_exclude_symbols(model_name, exclude_first_conv):
+    # Grouped supported models at the time of commit:
+    # alexnet
+    # densenet121, densenet161
+    # densenet169, densenet201
+    # inceptionv3
+    # mobilenet0.25, mobilenet0.5, mobilenet0.75, mobilenet1.0,
+    # mobilenetv2_0.25, mobilenetv2_0.5, mobilenetv2_0.75, mobilenetv2_1.0
+    # resnet101_v1, resnet152_v1, resnet18_v1, resnet34_v1, resnet50_v1
+    # resnet101_v2, resnet152_v2, resnet18_v2, resnet34_v2, resnet50_v2
+    # squeezenet1.0, squeezenet1.1
+    # vgg11, vgg11_bn, vgg13, vgg13_bn, vgg16, vgg16_bn, vgg19, vgg19_bn
+    exclude_symbol_regex = {
+        'mobilenet[^v]': ['mobilenet_hybridsequential0_flatten0_flatten0', 'mobilenet_hybridsequential0_globalavgpool2d0_fwd'],
+        'mobilenetv2': ['mobilenetv2_hybridsequential1_flatten0_flatten0'],
+        # resnetv2_hybridsequential0_hybridsequential0_bottleneckv20_batchnorm0_fwd is excluded for the sake of accuracy
+        'resnet.*v2': ['resnetv2_hybridsequential0_flatten0_flatten0', 'resnetv2_hybridsequential0_hybridsequential0_bottleneckv20_batchnorm0_fwd'],
+        'squeezenet1': ['squeezenet_hybridsequential1_flatten0_flatten0'],
+    }
+    excluded_sym_names = regex_find_excluded_symbols(exclude_symbol_regex, model_name)
+    if excluded_sym_names is None:
+        excluded_sym_names = []
+    if exclude_first_conv:
+        first_conv_regex = {
+            'alexnet': ['alexnet_hybridsequential0_conv2d0_fwd'],
+            'densenet': ['densenet_hybridsequential0_conv2d0_fwd'],
+            'inceptionv3': ['inception3_hybridsequential0_hybridsequential0_conv2d0_fwd'],
+            'mobilenet[^v]': ['mobilenet_hybridsequential0_conv2d0_fwd'],
+            'mobilenetv2': ['mobilenetv2_hybridsequential0_conv2d0_fwd'],
+            'resnet.*v1': ['resnetv1_hybridsequential0_conv2d0_fwd'],
+            'resnet.*v2': ['resnetv2_hybridsequential0_conv2d0_fwd'],
+            'squeezenet1': ['squeezenet_hybridsequential0_conv2d0_fwd'],
+            'vgg': ['vgg_hybridsequential0_conv2d0_fwd'],
+        }
+        excluded_first_conv_sym_names = regex_find_excluded_symbols(first_conv_regex, model_name)
+        if excluded_first_conv_sym_names is None:
+            raise ValueError('Currently, model %s is not supported in this script' % model_name)
+        excluded_sym_names += excluded_first_conv_sym_names
+    return excluded_sym_names
+
+
+

Review comment:
       There appear to be too many blank lines here.
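As a side note on the lookup used in the hunk above, the sketch below exercises `regex_find_excluded_symbols` on a few model names to show how the `mobilenet[^v]` and `mobilenetv2` patterns stay disjoint. The helper is copied from the diff with renamed loop variables; the shortened symbol names and the `efficientnet_b0` model name are placeholders invented for this example.

```python
import re


def regex_find_excluded_symbols(patterns_dict, model_name):
    """Return the value of the first pattern that matches model_name, else None."""
    for pattern, symbols in patterns_dict.items():
        if re.search(pattern, model_name) is not None:
            return symbols
    return None


# Shortened, hypothetical symbol names; the real lists are in the diff above.
first_conv_regex = {
    'mobilenet[^v]': ['mobilenet_conv0'],
    'mobilenetv2': ['mobilenetv2_conv0'],
    'resnet.*v1': ['resnetv1_conv0'],
    'resnet.*v2': ['resnetv2_conv0'],
}

# 'mobilenet[^v]' needs a non-'v' character after 'mobilenet', so it matches
# 'mobilenet1.0' but not 'mobilenetv2_1.0'.
v1_match = regex_find_excluded_symbols(first_conv_regex, 'mobilenet1.0')
v2_match = regex_find_excluded_symbols(first_conv_regex, 'mobilenetv2_1.0')
resnet_match = regex_find_excluded_symbols(first_conv_regex, 'resnet50_v2')

# A model outside the table yields None; with exclude_first_conv set, the
# script in the diff turns that into a ValueError.
unknown = regex_find_excluded_symbols(first_conv_regex, 'efficientnet_b0')
```

Because the first matching pattern wins, the result depends on the dictionary's insertion order whenever two patterns could match the same name. The patterns in the diff avoid that overlap for the listed models.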

##########
File path: example/quantization/imagenet_gen_qsym_onednn.py
##########
@@ -0,0 +1,274 @@
+def get_exclude_symbols(model_name, exclude_first_conv):
+    # Grouped supported models at the time of commit:

Review comment:
       Would a docstring be better here?
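One possible shape for the docstring suggested above is sketched here. The model list is taken from the comment block in the diff; the summary line and parameter descriptions are wording assumed for this example, and the function body is deliberately replaced by a stub so that the snippet stays short.

```python
def get_exclude_symbols(model_name, exclude_first_conv):
    """Return the symbol names to exclude from quantization for a model.

    Grouped supported models at the time of commit:
        alexnet
        densenet121, densenet161
        densenet169, densenet201
        inceptionv3
        mobilenet0.25, mobilenet0.5, mobilenet0.75, mobilenet1.0,
        mobilenetv2_0.25, mobilenetv2_0.5, mobilenetv2_0.75, mobilenetv2_1.0
        resnet101_v1, resnet152_v1, resnet18_v1, resnet34_v1, resnet50_v1
        resnet101_v2, resnet152_v2, resnet18_v2, resnet34_v2, resnet50_v2
        squeezenet1.0, squeezenet1.1
        vgg11, vgg11_bn, vgg13, vgg13_bn, vgg16, vgg16_bn, vgg19, vgg19_bn

    Parameters
    ----------
    model_name : str
        Name of the model to be quantized.
    exclude_first_conv : bool
        Whether to also exclude the first convolution layer.
    """
    # Stub for this sketch only; the real body is the one shown in the diff.
    raise NotImplementedError('function body omitted in this sketch')


# Unlike a comment, the docstring is available at runtime, e.g. via help().
summary_line = get_exclude_symbols.__doc__.splitlines()[0]
```

Moving the list into a docstring makes it visible through `help(get_exclude_symbols)` and to documentation tools, which a `#` comment block is not.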



