Posted to commits@mxnet.apache.org by re...@apache.org on 2018/08/12 04:51:25 UTC
[incubator-mxnet] branch master updated: Fix quantized graphpass bug (#11937)
This is an automated email from the ASF dual-hosted git repository.
reminisce pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new d076d10 Fix quantized graphpass bug (#11937)
d076d10 is described below
commit d076d10b1372308336d1008858401d1c59ed5826
Author: Xinyu Chen <xi...@intel.com>
AuthorDate: Sun Aug 12 12:51:13 2018 +0800
Fix quantized graphpass bug (#11937)
* fix quantized graphpass bug
* add residual quantization testcase
* handle dtype and backend issues
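For context (inferred from the diff below): the quantize graph pass keeps a mirror map from original nodes to their counterparts in the quantized graph. When one tensor feeds both a quantized operator and an operator left in fp32, as happens with the skip connection of a residual unit, the fp32 consumer's mirrored input can resolve to the inserted _contrib_quantize node, so the fp32 op was handed an int8/uint8 tensor. The fix reattaches such consumers to the quantize node's original fp32 input. A minimal sketch of the triggering pattern, using the plain symbol API (variable names here are illustrative, not from the patch):

    import mxnet as mx

    data = mx.sym.Variable('data')
    # 'data' is consumed twice: by a convolution that gets quantized ...
    conv = mx.sym.Convolution(data=data, num_filter=4, kernel=(1, 1),
                              no_bias=True, name='conv')
    bn = mx.sym.BatchNorm(data=conv, fix_gamma=False, name='bn')
    # ... and by an elementwise add that stays in fp32; before this fix the
    # add could be wired to the int8/uint8 output of _contrib_quantize.
    act = mx.sym.Activation(data=bn + data, act_type='relu', name='relu')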
---
src/operator/quantization/quantize_graph_pass.cc | 3 +
tests/python/quantization/test_quantization.py | 106 +++++++++++++++++++++++
2 files changed, 109 insertions(+)
diff --git a/src/operator/quantization/quantize_graph_pass.cc b/src/operator/quantization/quantize_graph_pass.cc
index 5376a0e..1083486 100644
--- a/src/operator/quantization/quantize_graph_pass.cc
+++ b/src/operator/quantization/quantize_graph_pass.cc
@@ -221,6 +221,9 @@ Graph QuantizeGraph(Graph &&src) {
         new_node->inputs.emplace_back(NodeEntry{dequantize_node, 0, 0});
         mirror_map[e.node.get()] = std::move(dequantize_node);
+      } else if (mirror_node->op() != nullptr
+                 && mirror_node->op()->name == "_contrib_quantize") {
+        new_node->inputs.emplace_back(NodeEntry{mirror_node->inputs[0].node, e.index, e.version});
       } else {
         new_node->inputs.emplace_back(NodeEntry{mirror_node, e.index, e.version});
       }
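Read as pseudocode, the new branch says: if an input's mirror is a _contrib_quantize op, connect the non-quantized consumer to that node's first input, i.e. the original fp32 tensor, rather than to the quantize output. A rough Python paraphrase, purely illustrative (the Node type and helper below are stand-ins, not MXNet API; the existing dequantize branch is omitted):

    class Node:
        """Stand-in for an NNVM node: an op name (None for variables) and
        a list of (node, output_index, version) input entries."""
        def __init__(self, op_name=None, inputs=()):
            self.op_name = op_name
            self.inputs = list(inputs)

    def wire_input(new_node, entry, mirror_map):
        node, index, version = entry
        mirror = mirror_map[node]
        if mirror.op_name == '_contrib_quantize':
            # the fix: reuse the quantize node's fp32 input so that fp32
            # consumers are not handed a quantized tensor
            new_node.inputs.append((mirror.inputs[0][0], index, version))
        else:
            new_node.inputs.append((mirror, index, version))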
diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py
index b73a2a4..369a923 100644
--- a/tests/python/quantization/test_quantization.py
+++ b/tests/python/quantization/test_quantization.py
@@ -396,6 +396,17 @@ def get_fp32_sym():
                                out_grad=False, preserve_shape=False, use_ignore=False, name='softmax')
     return sym
+def get_fp32_residual():
+    data = mx.sym.Variable('data')
+    conv = mx.sym.Convolution(data=data, num_filter=4, kernel=(1,1), pad=(0,0),
+                              no_bias=True, name='conv')
+    bn = mx.sym.BatchNorm(data=conv, fix_gamma=False, eps=2e-5, momentum=0.9, name='bn')
+    act = mx.sym.Activation(data=bn + data, act_type='relu', name='relu')
+    pool = mx.sym.Pooling(act, kernel=(4, 4), pool_type='avg', name='pool')
+    fc = mx.sym.FullyConnected(pool, num_hidden=10, flatten=True, name='fc')
+    sym = mx.sym.SoftmaxOutput(fc, grad_scale=1, ignore_label=-1, multi_output=False,
+                               out_grad=False, preserve_shape=False, use_ignore=False, name='softmax')
+    return sym
@with_seed()
def test_quantize_model():
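A quick sanity check of the helper added above, as a sketch (shapes match the new test in the next hunk; the 1x1 convolution and the residual add preserve the input shape, so shape inference flows through cleanly):

    import mxnet as mx

    sym = get_fp32_residual()   # helper from the hunk above
    print(sym.list_arguments())
    # ['data', 'conv_weight', 'bn_gamma', 'bn_beta',
    #  'fc_weight', 'fc_bias', 'softmax_label']
    _, out_shapes, _ = sym.infer_shape(data=(4, 4, 10, 10))
    print(out_shapes)           # [(4, 10)]: a ten-way softmax per sample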
@@ -463,6 +474,101 @@ def test_quantize_model():
     for qdtype in ['int8', 'uint8']:
         check_quantize_model(qdtype)
+@with_seed()
+def test_quantize_residual_unit():
+    def check_quantize_model(qdtype):
+        if is_test_for_native_cpu():
+            print('skipped testing quantized_residual_unit for native cpu since it is not supported yet')
+            return
+        elif qdtype == 'int8' and is_test_for_mkldnn():
+            print('skipped testing quantized_residual_unit for mkldnn cpu int8 since it is not supported yet')
+            return
+        elif qdtype == 'uint8' and is_test_for_gpu():
+            print('skipped testing quantized_residual_unit for gpu uint8 since it is not supported yet')
+            return
+
+        def check_params(params, qparams, qsym=None):
+            if qsym is None:
+                assert len(params) == len(qparams)
+                for k, v in params.items():
+                    assert k in qparams
+                    assert same(v.asnumpy(), qparams[k].asnumpy())
+            else:
+                qparams_ground_truth = mx.contrib.quant._quantize_params(qsym, params)
+                assert len(qparams) == len(qparams_ground_truth)
+                for k, v in qparams_ground_truth.items():
+                    assert k in qparams
+                    assert same(v.asnumpy(), qparams[k].asnumpy())
+
+        def check_qsym_calibrated(qsym):
+            attrs = qsym.attr_dict()
+            for k, v in attrs.items():
+                if k.find('requantize_') != -1:
+                    assert 'min_calib_range' in v
+                    assert 'max_calib_range' in v
+
+        def check_qsym_qdtype(qsym, qdtype):
+            attrs = qsym.attr_dict()
+            for k, v in attrs.items():
+                if k.find('_quantize') != -1:
+                    assert 'out_type' in v
+                    assert v['out_type'] == qdtype
+
+        def check_qsym_forward(qsym, qarg_params, qaux_params, data_shape, label_shape):
+            mod = mx.mod.Module(symbol=qsym, context=mx.current_context())
+            mod.bind(for_training=False,
+                     data_shapes=[('data', data_shape)],
+                     label_shapes=[('softmax_label', label_shape)])
+            mod.set_params(qarg_params, qaux_params)
+            data = [mx.random.uniform(-1.0, 1.0, shape=shape) for _, shape in mod.data_shapes]
+            batch = mx.io.DataBatch(data, [])
+            mod.forward(batch, is_train=False)
+            for output in mod.get_outputs():
+                output.wait_to_read()
+
+
+        sym = get_fp32_residual()
+        mod = Module(symbol=sym)
+        batch_size = 4
+        data_shape = (batch_size, 4, 10, 10)
+        label_shape = (batch_size, 10)
+        mod.bind(data_shapes=[('data', data_shape)], label_shapes=[('softmax_label', label_shape)])
+        mod.init_params()
+        arg_params, aux_params = mod.get_params()
+        excluded_sym_names = []
+        if mx.current_context() == mx.cpu():
+            excluded_sym_names += ['fc']
+        qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=sym,
+                                                                         arg_params=arg_params,
+                                                                         aux_params=aux_params,
+                                                                         excluded_sym_names=excluded_sym_names,
+                                                                         ctx=mx.current_context(),
+                                                                         quantized_dtype=qdtype,
+                                                                         calib_mode='none')
+        check_params(arg_params, qarg_params, qsym)
+        check_params(aux_params, qaux_params)
+        check_qsym_forward(qsym, qarg_params, qaux_params, data_shape, label_shape)
+
+        calib_data = mx.nd.random.uniform(shape=data_shape)
+        calib_data = NDArrayIter(data=calib_data)
+        calib_data = DummyIter(calib_data)
+        qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=sym,
+                                                                         arg_params=arg_params,
+                                                                         aux_params=aux_params,
+                                                                         excluded_sym_names=excluded_sym_names,
+                                                                         ctx=mx.current_context(),
+                                                                         quantized_dtype=qdtype,
+                                                                         calib_mode='naive',
+                                                                         calib_data=calib_data,
+                                                                         num_calib_examples=20)
+        check_params(arg_params, qarg_params, qsym)
+        check_params(aux_params, qaux_params)
+        check_qsym_calibrated(qsym)
+        check_qsym_qdtype(qsym, qdtype)
+        check_qsym_forward(qsym, qarg_params, qaux_params, data_shape, label_shape)
+
+    for qdtype in ['int8', 'uint8']:
+        check_quantize_model(qdtype)
@with_seed()
def test_quantize_sym_with_calib():
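For reference, the calibrated path exercised by the new test reduces to the following standalone sketch (hedged: assumes an MXNet build of this era with the mx.contrib.quant API and an MKL-DNN-enabled CPU backend; the 'fc' exclusion mirrors the test's CPU branch):

    import mxnet as mx

    batch_size = 4
    data_shape = (batch_size, 4, 10, 10)
    label_shape = (batch_size, 10)

    sym = get_fp32_residual()   # the helper added by this commit
    mod = mx.mod.Module(symbol=sym, context=mx.cpu())
    mod.bind(data_shapes=[('data', data_shape)],
             label_shapes=[('softmax_label', label_shape)])
    mod.init_params()
    arg_params, aux_params = mod.get_params()

    # naive calibration: run a few batches through the fp32 graph to collect
    # min/max ranges, then emit the quantized symbol and parameters
    calib_data = mx.io.NDArrayIter(data=mx.nd.random.uniform(shape=data_shape),
                                   batch_size=batch_size)
    qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(
        sym=sym, arg_params=arg_params, aux_params=aux_params,
        excluded_sym_names=['fc'], ctx=mx.cpu(), quantized_dtype='uint8',
        calib_mode='naive', calib_data=calib_data,
        num_calib_examples=batch_size)

With the nose-based test setup MXNet used at the time, the new case can be run directly, e.g. nosetests tests/python/quantization/test_quantization.py:test_quantize_residual_unit (exact invocation depends on the environment).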