You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this plain-text copy.)
Posted to commits@mxnet.apache.org by re...@apache.org on 2018/08/12 04:51:25 UTC

[incubator-mxnet] branch master updated: Fix quantized graphpass bug (#11937)

This is an automated email from the ASF dual-hosted git repository.

reminisce pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new d076d10  Fix quantized graphpass bug (#11937)
d076d10 is described below

commit d076d10b1372308336d1008858401d1c59ed5826
Author: Xinyu Chen <xi...@intel.com>
AuthorDate: Sun Aug 12 12:51:13 2018 +0800

    Fix quantized graphpass bug (#11937)
    
    * fix quantized graphpass bug
    
    * add residual quantization testcase
    
    * handle dtype and backend issues
---
 src/operator/quantization/quantize_graph_pass.cc |   3 +
 tests/python/quantization/test_quantization.py   | 106 +++++++++++++++++++++++
 2 files changed, 109 insertions(+)

diff --git a/src/operator/quantization/quantize_graph_pass.cc b/src/operator/quantization/quantize_graph_pass.cc
index 5376a0e..1083486 100644
--- a/src/operator/quantization/quantize_graph_pass.cc
+++ b/src/operator/quantization/quantize_graph_pass.cc
@@ -221,6 +221,9 @@ Graph QuantizeGraph(Graph &&src) {
 
           new_node->inputs.emplace_back(NodeEntry{dequantize_node, 0, 0});
           mirror_map[e.node.get()] = std::move(dequantize_node);
+        } else if (mirror_node->op() != nullptr
+                   && mirror_node->op()->name == "_contrib_quantize") {
+          new_node->inputs.emplace_back(NodeEntry{mirror_node->inputs[0].node, e.index, e.version});
         } else {
           new_node->inputs.emplace_back(NodeEntry{mirror_node, e.index, e.version});
         }
diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py
index b73a2a4..369a923 100644
--- a/tests/python/quantization/test_quantization.py
+++ b/tests/python/quantization/test_quantization.py
@@ -396,6 +396,17 @@ def get_fp32_sym():
                                out_grad=False, preserve_shape=False, use_ignore=False, name='softmax')
     return sym
 
+def get_fp32_residual():
+    data = mx.sym.Variable('data')
+    conv = mx.sym.Convolution(data=data, num_filter=4, kernel=(1,1), pad=(0,0),
+                              no_bias=True, name='conv')
+    bn = mx.sym.BatchNorm(data=conv, fix_gamma=False, eps=2e-5, momentum=0.9, name='bn')
+    act = mx.sym.Activation(data=bn + data, act_type='relu', name='relu')
+    pool = mx.sym.Pooling(act, kernel=(4, 4), pool_type='avg', name='pool')
+    fc = mx.sym.FullyConnected(pool, num_hidden=10, flatten=True, name='fc')
+    sym = mx.sym.SoftmaxOutput(fc, grad_scale=1, ignore_label=-1, multi_output=False,
+                               out_grad=False, preserve_shape=False, use_ignore=False, name='softmax')
+    return sym
 
 @with_seed()
 def test_quantize_model():
@@ -463,6 +474,101 @@ def test_quantize_model():
     for qdtype in ['int8', 'uint8']:
         check_quantize_model(qdtype)
 
+@with_seed()
+def test_quantize_residual_unit():
+    def check_quantize_model(qdtype):
+        if is_test_for_native_cpu():
+            print('skipped testing quantized_residual_unit for native cpu since it is not supported yet')
+            return
+        elif qdtype == 'int8' and is_test_for_mkldnn():
+            print('skipped testing quantized_residual_unit for mkldnn cpu int8 since it is not supported yet')
+            return
+        elif qdtype == 'uint8' and is_test_for_gpu():
+            print('skipped testing quantized_residual_unit for gpu uint8 since it is not supported yet')
+            return
+
+        def check_params(params, qparams, qsym=None):
+            if qsym is None:
+                assert len(params) == len(qparams)
+                for k, v in params.items():
+                    assert k in qparams
+                    assert same(v.asnumpy(), qparams[k].asnumpy())
+            else:
+                qparams_ground_truth = mx.contrib.quant._quantize_params(qsym, params)
+                assert len(qparams) == len(qparams_ground_truth)
+                for k, v in qparams_ground_truth.items():
+                    assert k in qparams
+                    assert same(v.asnumpy(), qparams[k].asnumpy())
+
+        def check_qsym_calibrated(qsym):
+            attrs = qsym.attr_dict()
+            for k, v in attrs.items():
+                if k.find('requantize_') != -1:
+                    assert 'min_calib_range' in v
+                    assert 'max_calib_range' in v
+
+        def check_qsym_qdtype(qsym, qdtype):
+            attrs = qsym.attr_dict()
+            for k, v in attrs.items():
+                if k.find('_quantize') != -1:
+                    assert 'out_type' in v
+                    assert v['out_type'] == qdtype
+
+        def check_qsym_forward(qsym, qarg_params, qaux_params, data_shape, label_shape):
+            mod = mx.mod.Module(symbol=qsym, context=mx.current_context())
+            mod.bind(for_training=False,
+                     data_shapes=[('data', data_shape)],
+                     label_shapes=[('softmax_label', label_shape)])
+            mod.set_params(qarg_params, qaux_params)
+            data = [mx.random.uniform(-1.0, 1.0, shape=shape) for _, shape in mod.data_shapes]
+            batch = mx.io.DataBatch(data, [])
+            mod.forward(batch, is_train=False)
+            for output in mod.get_outputs():
+                output.wait_to_read()
+
+        # Build the FP32 residual model, then quantize it with calib_mode='none' and calib_mode='naive'.
+        sym = get_fp32_residual()
+        mod = Module(symbol=sym)
+        batch_size = 4
+        data_shape = (batch_size, 4, 10, 10)
+        label_shape = (batch_size, 10)
+        mod.bind(data_shapes=[('data', data_shape)], label_shapes=[('softmax_label', label_shape)])
+        mod.init_params()
+        arg_params, aux_params = mod.get_params()
+        excluded_sym_names = []
+        if mx.current_context() == mx.cpu():
+            excluded_sym_names += ['fc']
+        qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=sym,
+                                                                         arg_params=arg_params,
+                                                                         aux_params=aux_params,
+                                                                         excluded_sym_names=excluded_sym_names,
+                                                                         ctx=mx.current_context(),
+                                                                         quantized_dtype=qdtype,
+                                                                         calib_mode='none')
+        check_params(arg_params, qarg_params, qsym)
+        check_params(aux_params, qaux_params)
+        check_qsym_forward(qsym, qarg_params, qaux_params, data_shape, label_shape)
+
+        calib_data = mx.nd.random.uniform(shape=data_shape)
+        calib_data = NDArrayIter(data=calib_data)
+        calib_data = DummyIter(calib_data)
+        qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=sym,
+                                                                         arg_params=arg_params,
+                                                                         aux_params=aux_params,
+                                                                         excluded_sym_names=excluded_sym_names,
+                                                                         ctx=mx.current_context(),
+                                                                         quantized_dtype=qdtype,
+                                                                         calib_mode='naive',
+                                                                         calib_data=calib_data,
+                                                                         num_calib_examples=20)
+        check_params(arg_params, qarg_params, qsym)
+        check_params(aux_params, qaux_params)
+        check_qsym_calibrated(qsym)
+        check_qsym_qdtype(qsym, qdtype)
+        check_qsym_forward(qsym, qarg_params, qaux_params, data_shape, label_shape)
+
+    for qdtype in ['int8', 'uint8']:
+        check_quantize_model(qdtype)
 
 @with_seed()
 def test_quantize_sym_with_calib():