Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2022/02/10 11:43:49 UTC

[GitHub] [incubator-mxnet] PawelGlomski-Intel opened a new pull request #20889: INC quantization tutorial

PawelGlomski-Intel opened a new pull request #20889:
URL: https://github.com/apache/incubator-mxnet/pull/20889


   ## Description ##
   This change extends the MKLDNN quantization tutorial with quantization and auto-tuning using [Intel® Neural Compressor](https://github.com/intel/neural-compressor).


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@mxnet.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-mxnet] bgawrych commented on pull request #20889: INC quantization tutorial

Posted by GitBox <gi...@apache.org>.
bgawrych commented on pull request #20889:
URL: https://github.com/apache/incubator-mxnet/pull/20889#issuecomment-1044012636


   @mxnet-bot run ci [unix-cpu, website]





[GitHub] [incubator-mxnet] bgawrych merged pull request #20889: INC quantization tutorial

Posted by GitBox <gi...@apache.org>.
bgawrych merged pull request #20889:
URL: https://github.com/apache/incubator-mxnet/pull/20889


   





[GitHub] [incubator-mxnet] mxnet-bot commented on pull request #20889: INC quantization tutorial

Posted by GitBox <gi...@apache.org>.
mxnet-bot commented on pull request #20889:
URL: https://github.com/apache/incubator-mxnet/pull/20889#issuecomment-1042935703


   Jenkins CI successfully triggered : [unix-gpu, website, unix-cpu]





[GitHub] [incubator-mxnet] mxnet-bot commented on pull request #20889: INC quantization tutorial

Posted by GitBox <gi...@apache.org>.
mxnet-bot commented on pull request #20889:
URL: https://github.com/apache/incubator-mxnet/pull/20889#issuecomment-1044012796


   Jenkins CI successfully triggered : [unix-cpu, website]





[GitHub] [incubator-mxnet] bgawrych commented on pull request #20889: INC quantization tutorial

Posted by GitBox <gi...@apache.org>.
bgawrych commented on pull request #20889:
URL: https://github.com/apache/incubator-mxnet/pull/20889#issuecomment-1042935602


   @mxnet-bot run ci [unix-cpu, unix-gpu, website]





[GitHub] [incubator-mxnet] mxnet-bot commented on pull request #20889: INC quantization tutorial

Posted by GitBox <gi...@apache.org>.
mxnet-bot commented on pull request #20889:
URL: https://github.com/apache/incubator-mxnet/pull/20889#issuecomment-1034823141


   Hey @PawelGlomski-Intel , Thanks for submitting the PR 
   All tests are already queued to run once. If tests fail, you can trigger one or more tests again with the following commands: 
   - To trigger all jobs: @mxnet-bot run ci [all] 
   - To trigger specific jobs: @mxnet-bot run ci [job1, job2] 
   *** 
   **CI supported jobs**: [unix-gpu, unix-cpu, windows-cpu, sanity, centos-gpu, website, miscellaneous, clang, centos-cpu, edge, windows-gpu]
   *** 
   _Note_: 
    Only the following 3 categories can trigger CI: PR Author, MXNet Committer, Jenkins Admin. 
   All CI tests must pass before the PR can be merged. 
   





[GitHub] [incubator-mxnet] bartekkuncer commented on a change in pull request #20889: INC quantization tutorial

Posted by GitBox <gi...@apache.org>.
bartekkuncer commented on a change in pull request #20889:
URL: https://github.com/apache/incubator-mxnet/pull/20889#discussion_r804654986



##########
File path: example/quantization_inc/BERT_MRPC/details.py
##########
@@ -0,0 +1,328 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import time
+import logging
+import random
+import itertools
+import collections.abc
+import numpy as np
+import numpy.ma as ma
+import gluonnlp as nlp
+import mxnet as mx
+
+from mxnet.contrib.quantization import quantize_net_v2
+from gluonnlp.model import BERTClassifier as BERTModel
+from gluonnlp.data import BERTTokenizer
+from gluonnlp.data import GlueMRPC
+from functools import partial
+
+nlp.utils.check_version('0.9', warning_only=True)
+logging.basicConfig()
+logging = logging.getLogger()
+
+CTX = mx.cpu()
+
+TASK_NAME = 'MRPC'
+MODEL_NAME = 'bert_12_768_12'
+DATASET_NAME = 'book_corpus_wiki_en_uncased'
+BACKBONE, VOCAB = nlp.model.get_model(name=MODEL_NAME,
+                                      dataset_name=DATASET_NAME,
+                                      pretrained=True,
+                                      ctx=CTX,
+                                      use_decoder=False,
+                                      use_classifier=False)
+TOKENIZER = BERTTokenizer(VOCAB, lower=('uncased' in DATASET_NAME))
+MAX_LEN = int(512)
+
+LABEL_DTYPE = 'int32'
+CLASS_LABELS = ['0', '1']
+NUM_CLASSES = len(CLASS_LABELS)
+LABEL_MAP = {l: i for (i, l) in enumerate(CLASS_LABELS)}
+
+BATCH_SIZE = int(32)
+LR = 3e-5
+EPSILON = 1e-6
+LOSS_FUNCTION = mx.gluon.loss.SoftmaxCELoss()
+EPOCH_NUMBER = int(4)
+TRAINING_STEPS = None  # if specified, epochs will be ignored
+ACCUMULATE = int(1)  # >= 1
+WARMUP_RATIO = 0.1
+EARLY_STOP = None
+TRAINING_LOG_INTERVAL = 10*ACCUMULATE
+
+METRIC = mx.metric.Accuracy
+
+
+class FixedDataset:
+    def __init__(self, dataset):
+        self.dataset = dataset
+
+    def __getitem__(self, idx):
+        input_ids, segment_ids, valid_length, label = self.dataset[idx]
+        return input_ids, segment_ids, np.float32(valid_length), label
+
+    def __len__(self):
+        return len(self.dataset)
+
+
+def truncate_seqs_equal(seqs, max_len):
+    assert isinstance(seqs, list)
+    lens = list(map(len, seqs))
+    if sum(lens) <= max_len:
+        return seqs
+
+    lens = ma.masked_array(lens, mask=[0] * len(lens))
+    while True:
+        argmin = lens.argmin()
+        minval = lens[argmin]
+        quotient, remainder = divmod(max_len, len(lens) - sum(lens.mask))
+        if minval <= quotient:  # Ignore values that don't need truncation
+            lens.mask[argmin] = 1
+            max_len -= minval
+        else:  # Truncate all
+            lens.data[~lens.mask] = [
+                quotient + 1 if i < remainder else quotient for i in range(lens.count())
+            ]
+            break
+    seqs = [seq[:length] for (seq, length) in zip(seqs, lens.data.tolist())]
+    return seqs
+
+
+def concat_sequences(seqs, separators, seq_mask=0, separator_mask=1):
+    assert isinstance(seqs, collections.abc.Iterable) and len(seqs) > 0
+    assert isinstance(seq_mask, (list, int))
+    assert isinstance(separator_mask, (list, int))
+    concat = sum((seq + sep for sep, seq in itertools.zip_longest(separators, seqs, fillvalue=[])),
+                 [])
+    segment_ids = sum(
+        ([i] * (len(seq) + len(sep))
+         for i, (sep, seq) in enumerate(itertools.zip_longest(separators, seqs, fillvalue=[]))),
+        [])
+    if isinstance(seq_mask, int):
+        seq_mask = [[seq_mask] * len(seq) for seq in seqs]
+    if isinstance(separator_mask, int):
+        separator_mask = [[separator_mask] * len(sep) for sep in separators]
+
+    p_mask = sum((s_mask + mask for sep, seq, s_mask, mask in itertools.zip_longest(
+        separators, seqs, seq_mask, separator_mask, fillvalue=[])), [])
+    return concat, segment_ids, p_mask
+
+
+def convert_examples_to_features(example, is_test):
+    truncate_length = MAX_LEN if is_test else MAX_LEN - 3
+    if not is_test:
+        example, label = example[:-1], example[-1]
+        label = np.array([LABEL_MAP[label]], dtype=LABEL_DTYPE)
+
+    tokens_raw = [TOKENIZER(l) for l in example]
+    tokens_trun = truncate_seqs_equal(tokens_raw, truncate_length)
+    tokens_trun[0] = [VOCAB.cls_token] + tokens_trun[0]
+    tokens, segment_ids, _ = concat_sequences(tokens_trun, [[VOCAB.sep_token]] * len(tokens_trun))
+    input_ids = VOCAB[tokens]
+    valid_length = len(input_ids)
+    if not is_test:
+        return input_ids, segment_ids, valid_length, label
+    else:
+        return input_ids, segment_ids, valid_length
+
+
+def preprocess_data():
+    def preprocess_dataset(segment):
+        is_calib = segment == 'calib'
+        is_test = segment == 'test'
+        segment = 'train' if is_calib else segment
+        trans = partial(convert_examples_to_features, is_test=is_test)
+        batchify = [nlp.data.batchify.Pad(axis=0, pad_val=VOCAB[VOCAB.padding_token]),  # 0. input
+                    nlp.data.batchify.Pad(axis=0, pad_val=0),                           # 1. segment
+                    nlp.data.batchify.Stack()]                                          # 2. length
+        batchify += [] if is_test else [nlp.data.batchify.Stack(LABEL_DTYPE)]           # 3. label
+        batchify_fn = nlp.data.batchify.Tuple(*batchify)
+
+        dataset = list(map(trans, GlueMRPC(segment)))
+        random.shuffle(dataset)
+        dataset = mx.gluon.data.SimpleDataset(dataset)
+
+        batch_arg = {}
+        if segment == 'train' and not is_calib:
+            seq_len = dataset.transform(lambda *args: args[2], lazy=False)
+            sampler = nlp.data.sampler.FixedBucketSampler(seq_len, BATCH_SIZE, num_buckets=10,
+                                                          ratio=0, shuffle=True)
+            batch_arg['batch_sampler'] = sampler
+        else:
+            batch_arg['batch_size'] = BATCH_SIZE
+
+        dataset = FixedDataset(dataset)
+        return mx.gluon.data.DataLoader(dataset, num_workers=0, shuffle=False,
+                                        batchify_fn=batchify_fn, **batch_arg)
+
+    return (preprocess_dataset(seg) for seg in ['train', 'dev', 'calib'])
+
+
+def log_train(batch_id, batch_num, metric, step_loss, epoch_id, learning_rate):
+    """Generate and print out the log message for training. """
+    metric_nm, metric_val = metric.get()
+    if not isinstance(metric_nm, list):
+        metric_nm, metric_val = [metric_nm], [metric_val]
+
+    train_str = '[Epoch %d Batch %d/%d] loss=%.4f, lr=%.7f, metrics:' + \
+                ','.join([i + ':%.4f' for i in metric_nm])
+    logging.info(train_str, epoch_id, batch_id, batch_num, step_loss / TRAINING_LOG_INTERVAL,
+                 learning_rate, *metric_val)
+
+
+def finetune(model, train_dataloader, dev_dataloader, output_dir_path):
+    model.classifier.initialize(init=mx.init.Normal(0.02), ctx=CTX)
+
+    all_model_params = model.collect_params()
+    optimizer_params = {'learning_rate': LR, 'epsilon': EPSILON, 'wd': 0.01}
+    trainer = mx.gluon.Trainer(all_model_params, 'bertadam', optimizer_params,
+                               update_on_kvstore=False)
+    epochs = 9999 if TRAINING_STEPS else EPOCH_NUMBER
+    batches_in_epoch = int(len(train_dataloader) / ACCUMULATE)
+    # TRAINING_STEPS, when set, is the total step budget and overrides the epoch count
+    num_train_steps = TRAINING_STEPS if TRAINING_STEPS else batches_in_epoch * epochs
+
+    logging.info('training steps=%d', num_train_steps)
+    num_warmup_steps = int(num_train_steps * WARMUP_RATIO)
+
+    # Do not apply weight decay on LayerNorm and bias terms
+    for _, v in model.collect_params('.*beta|.*gamma|.*bias').items():
+        v.wd_mult = 0.0
+
+    # Collect differentiable parameters
+    params = [p for p in all_model_params.values() if p.grad_req != 'null']
+
+    # Set grad_req if gradient accumulation is required
+    if ACCUMULATE > 1:
+        for p in params:
+            p.grad_req = 'add'
+
+    # track best eval score
+    metric = METRIC()
+    metric_history = []
+    best_metric = None
+    patience = EARLY_STOP
+
+    step_num = 0
+    epoch_id = 0
+    finish_flag = False
+    while epoch_id < epochs and not finish_flag and (not EARLY_STOP or patience > 0):
+        epoch_id += 1
+        metric.reset()
+        step_loss = 0
+        tic = time.time()
+        all_model_params.zero_grad()
+
+        for batch_id, batch in enumerate(train_dataloader):
+            batch_id += 1
+            # learning rate schedule
+            if step_num < num_warmup_steps:
+                new_lr = LR * step_num / num_warmup_steps
+            else:
+                non_warmup_steps = step_num - num_warmup_steps
+                offset = non_warmup_steps / (num_train_steps - num_warmup_steps)
+                new_lr = LR - offset * LR
+            trainer.set_learning_rate(new_lr)
+
+            # forward and backward
+            with mx.autograd.record():
+                input_ids, segment_ids, valid_length, label = batch
+                input_ids = input_ids.as_in_context(CTX)
+                valid_length = valid_length.as_in_context(CTX).astype('float32')
+                label = label.as_in_context(CTX)
+                out = model(input_ids, segment_ids.as_in_context(CTX), valid_length)
+                ls = LOSS_FUNCTION(out, label).mean()
+                ls.backward()
+
+            # update
+            if ACCUMULATE <= 1 or batch_id % ACCUMULATE == 0:
+                trainer.allreduce_grads()
+                nlp.utils.clip_grad_global_norm(params, 1)
+                trainer.update(ACCUMULATE)
+                step_num += 1
+                if ACCUMULATE > 1:
+                    # set grad to zero for gradient accumulation
+                    all_model_params.zero_grad()
+
+            step_loss += ls.asscalar()
+            label = label.reshape((-1))
+            metric.update([label], [out])
+            if batch_id % TRAINING_LOG_INTERVAL == 0:
+                log_train(batch_id, batches_in_epoch, metric, step_loss, epoch_id,
+                          trainer.learning_rate)
+                step_loss = 0
+            if step_num >= num_train_steps:
+                logging.info('Finish training step: %d', step_num)
+                finish_flag = True
+                break
+        mx.nd.waitall()
+
+        # inference on dev data
+        metric_val = evaluate(model, dev_dataloader)
+        if best_metric is None or metric_val >= best_metric:
+            best_metric = metric_val
+            patience = EARLY_STOP
+        else:
+            if EARLY_STOP is not None:
+                patience -= 1
+        metric_history.append((epoch_id, METRIC().name, metric_val))
+        print('Results of evaluation on dev dataset: {}:{}'.format(METRIC().name, metric_val))
+
+        # save params
+        ckpt_name = 'model_bert_{}_{}.params'.format(TASK_NAME, epoch_id)
+        params_path = (output_dir_path / ckpt_name)
+
+        model.save_parameters(str(params_path))
+        logging.info('params saved in: %s', str(params_path))
+        toc = time.time()
+        logging.info('Time cost=%.2fs', toc - tic)
+
+    # we choose the best model assuming higher score stands for better model quality

Review comment:
       Some comments start with a capital letter and some with a lowercase one in both *.py files.
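
       For context on the training loop above: the schedule in `finetune` ramps the learning rate linearly from 0 up to `LR` over the warmup steps and then decays it linearly back to 0. A minimal standalone sketch of that arithmetic follows. The function name `linear_warmup_decay` and its parameters are hypothetical and are not part of the PR.

```python
def linear_warmup_decay(step, base_lr, num_train_steps, warmup_ratio=0.1):
    """Hypothetical helper mirroring the schedule used in `finetune`.

    Linear warmup from 0 to base_lr, then linear decay from base_lr to 0.
    """
    num_warmup_steps = int(num_train_steps * warmup_ratio)
    if step < num_warmup_steps:
        # Warmup phase: scale the base rate by the fraction of warmup completed
        return base_lr * step / num_warmup_steps
    # Decay phase: subtract the fraction of the remaining schedule already used
    non_warmup_steps = step - num_warmup_steps
    offset = non_warmup_steps / (num_train_steps - num_warmup_steps)
    return base_lr - offset * base_lr


if __name__ == "__main__":
    # With 100 total steps and a 0.1 warmup ratio, the peak is reached at step 10
    for step in (0, 5, 10, 55, 100):
        print(step, linear_warmup_decay(step, 3e-5, 100))
```

       The sketch assumes `num_train_steps` is strictly greater than the number of warmup steps; otherwise the decay phase would divide by zero.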

##########
File path: docs/python_docs/python/tutorials/performance/backend/mkldnn/mkldnn_quantization.md
##########
@@ -255,4 +258,295 @@ BTW, You can also modify the `min_calib_range` and `max_calib_range` in the JSON
 
 MXNet also supports deploying quantized models with C++. Refer to the [MXNet C++ Package](https://github.com/apache/incubator-mxnet/blob/master/cpp-package/README.md) for more details.
 
+# Improving accuracy with Intel® Neural Compressor
+
+The accuracy of a model can decrease as a result of quantization. When the accuracy drop is significant, we can try to manually find a better quantization configuration (exclude some layers, try different calibration methods, etc.), but for bigger models this might prove to be a difficult and time-consuming task. [Intel® Neural Compressor](https://github.com/intel/neural-compressor) (INC) tries to automate this process using several tuning heuristics, which aim to find the quantization configuration that satisfies the specified accuracy requirement.
+
+**NOTE:**
+
+Most tuning strategies will try different configurations on an evaluation dataset in order to find out how each layer affects the accuracy of the model. This means that for larger models, it may take a long time to find a solution (as the tuning space is usually larger and the evaluation itself takes longer).
+
+## Installation and Prerequisites
+
+- Install MXNet with MKLDNN enabled as described in the [previous section](#Installation-and-Prerequisites).
+
+- Install Intel® Neural Compressor:
+  
+  Use one of the commands below to install INC (supported Python versions are 3.6, 3.7, 3.8, and 3.9):
+
+  ```bash
+  # install stable version from pip
+  pip install neural-compressor
+
+  # install nightly version from pip
+  pip install -i https://test.pypi.org/simple/ neural-compressor
+
+  # install stable version from conda
+  conda install neural-compressor -c conda-forge -c intel
+  ```
+
+## Configuration file
+
+The quantization tuning process can be customized in a YAML configuration file. Below is a simple example:
+
+```yaml
+# cnn.yaml
+
+version: 1.0
+
+model:
+  name: cnn
+  framework: mxnet
+
+quantization:
+  calibration:
+    sampling_size: 160 # number of samples for calibration
+
+tuning:
+  strategy:
+    name: basic
+  accuracy_criterion:
+    relative: 0.01
+  exit_policy:
+    timeout: 0
+  random_seed: 9527
+```
+
+We are using the `basic` strategy, but you could also try out different ones. [Here](https://github.com/intel/neural-compressor/blob/master/docs/tuning_strategies.md) you can find a list of strategies available in INC and details of how they work. You can also add your own strategy if the existing ones do not suit your needs.
+
+Since the value of `timeout` is 0, INC will run until it finds a configuration that satisfies the accuracy criterion and then exit. Depending on the strategy this may not be ideal, as sometimes it would be better to further explore the tuning space to find a superior configuration both in terms of accuracy and speed. To achieve this, we can set a specific `timeout` value, which will tell INC how long (in seconds) it should run.
+
+For more information about the configuration file, see the [template](https://github.com/intel/neural-compressor/blob/master/neural_compressor/template/ptq.yaml) from the official INC repo. Keep in mind that only post-training quantization is currently supported for MXNet.
+
+## Model quantization and tuning
+
+In general, Intel® Neural Compressor requires 4 elements in order to run: 
+1. Config file - like the example above
+2. Model to be quantized
+3. Calibration dataloader
+4. Evaluation function - a function that takes a model as an argument and returns the accuracy it achieves on a certain evaluation dataset.
+
+### Quantizing ResNet
+
+The previous sections described how to quantize ResNet using the native MXNet quantization. This example shows how we can achieve the same result, with auto-tuning, using INC.
+
+1. Get the model:
+
+```python
+import logging
+import mxnet as mx
+from mxnet.gluon.model_zoo import vision
+
+logging.basicConfig()
+logger = logging.getLogger('logger')
+logger.setLevel(logging.INFO)
+
+batch_shape = (1, 3, 224, 224)
+resnet18 = vision.resnet18_v1(pretrained=True)
+```
+
+2. Prepare the dataset:
+
+```python
+mx.test_utils.download('http://data.mxnet.io/data/val_256_q90.rec', 'data/val_256_q90.rec')
+
+batch_size = 16
+mean_std = {'mean_r': 123.68, 'mean_g': 116.779, 'mean_b': 103.939,
+            'std_r': 58.393, 'std_g': 57.12, 'std_b': 57.375}
+
+data = mx.io.ImageRecordIter(path_imgrec='data/val_256_q90.rec', 
+                             batch_size=batch_size,
+                             data_shape=batch_shape[1:], 
+                             rand_crop=False, 
+                             rand_mirror=False,
+                             shuffle=False, 
+                             **mean_std)
+data.batch_size = batch_size
+```
+
+3. Prepare the evaluation function:
+
+```python
+eval_samples = batch_size*10
+
+def eval_func(model):
+    data.reset()
+    metric = mx.metric.Accuracy()
+    for i, batch in enumerate(data):
+        if i * batch_size >= eval_samples:
+            break
+        x = batch.data[0].as_in_context(mx.cpu())
+        label = batch.label[0].as_in_context(mx.cpu())
+        outputs = model.forward(x)
+        metric.update(label, outputs)
+    return metric.get()[1]
+```
+
+4. Run Intel® Neural Compressor:
+
+```python
+from neural_compressor.experimental import Quantization
+quantizer = Quantization("./cnn.yaml")
+quantizer.model = resnet18
+quantizer.calib_dataloader = data
+quantizer.eval_func = eval_func
+qnet = quantizer.fit().model
+```
+
+Since this model already achieves good accuracy using native quantization (less than 1% accuracy drop), for the given configuration file, INC will end on the first configuration, quantizing all layers using `naive` calibration mode for each. To see the true potential of INC, we need a model which suffers from a larger accuracy drop after quantization.
+
+### Quantizing BERT
+
+This example shows how to use INC to quantize BERT-base for MRPC. In this case, the native MXNet quantization usually introduces a significant accuracy drop (2% - 5% using `naive` calibration mode).  To simplify the code, model and task specific boilerplate has been moved to the `details.py` file.

Review comment:
       ```suggestion
   This example shows how to use INC to quantize BERT-base for MRPC. In this case, the native MXNet quantization usually introduces a significant accuracy drop (2% - 5% using `naive` calibration mode). To simplify the code, model and task specific boilerplate has been moved to the `details.py` file.
   ```
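
       As a side note on the `accuracy_criterion` in `cnn.yaml`: assuming `relative: 0.01` means the tuned model may lose at most 1% of the baseline accuracy in relative terms, the check can be sketched as below. The helper `meets_relative_criterion` is hypothetical and illustrates only the arithmetic, not INC's internal implementation.

```python
def meets_relative_criterion(baseline_acc, quantized_acc, relative=0.01):
    """Hypothetical check: is the relative accuracy drop within tolerance?

    Assumes higher accuracy is better and that the drop is measured
    relative to the baseline (e.g. FP32) accuracy.
    """
    if baseline_acc <= 0:
        raise ValueError("baseline accuracy must be positive")
    drop = (baseline_acc - quantized_acc) / baseline_acc
    return drop <= relative


if __name__ == "__main__":
    # A drop from 0.85 to 0.845 is about 0.6% relative, so it passes at 1%
    print(meets_relative_criterion(0.85, 0.845))
    # A drop from 0.85 to 0.82 is about 3.5% relative, so it fails at 1%
    print(meets_relative_criterion(0.85, 0.82))
```

       Under this reading, the 2% - 5% drop quoted for BERT with `naive` calibration would not satisfy the criterion, which is why tuning is expected to try further configurations.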



