Posted to commits@mxnet.apache.org by an...@apache.org on 2019/03/04 21:50:47 UTC
[incubator-mxnet] branch master updated: Add int8 data loader (#14123)
This is an automated email from the ASF dual-hosted git repository.
anirudh2290 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new df7771b Add int8 data loader (#14123)
df7771b is described below
commit df7771b4d29f6283f9d7d9d721378fc6be5cf9d8
Author: Zhennan Qin <zh...@intel.com>
AuthorDate: Tue Mar 5 05:50:27 2019 +0800
Add int8 data loader (#14123)
* Enable int8 data layer
Change-Id: I3d97ef80b7466d7555f4970e24f02e8dfba6be2b
* fix lint
* Add parameter description
* Fix imagenet_inference.py
* Allow quantize_v2 to accept int8
* make float32 default
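
A minimal Python usage sketch of the iterator this commit introduces (the .rec path is a placeholder; the remaining arguments mirror the ones used in the example and test code below):

import mxnet as mx

# ImageRecordInt8Iter behaves like ImageRecordIter but yields int8 data.
data_iter = mx.io.ImageRecordInt8Iter(
    path_imgrec='data/val_256_q90.rec',  # placeholder RecordIO file
    data_shape=(3, 224, 224),
    batch_size=32,
    label_width=1,
    rand_crop=False,
    rand_mirror=False)

for batch in data_iter:
    print(batch.data[0].dtype)  # numpy.int8
    break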
---
docs/api/perl/io.md | 1 +
docs/api/python/io/io.md | 1 +
example/quantization/imagenet_inference.py | 84 +++++++++++++----
include/mxnet/c_api.h | 4 +-
perl-package/AI-MXNet/lib/AI/MXNet/IO.pm | 1 +
src/io/iter_image_recordio_2.cc | 55 +++++++++--
.../quantization/mkldnn/mkldnn_quantize_v2-inl.h | 31 ++++--
src/operator/quantization/quantize_v2-inl.h | 105 ++++++++++++---------
src/operator/quantization/quantize_v2.cc | 6 ++
tests/python/train/test_dtype.py | 29 ++++++
10 files changed, 236 insertions(+), 81 deletions(-)
diff --git a/docs/api/perl/io.md b/docs/api/perl/io.md
index be49764..ca3b0f1 100644
--- a/docs/api/perl/io.md
+++ b/docs/api/perl/io.md
@@ -69,6 +69,7 @@ Then we can call `$mod->fit($nd_iter, num_epoch=>2)` to train `loss` by 2 epochs
mx->io->NDArrayIter
mx->io->CSVIter
mx->io->ImageRecordIter
+mx->io->ImageRecordInt8Iter
mx->io->ImageRecordUInt8Iter
mx->io->MNISTIter
mx->recordio->MXRecordIO
diff --git a/docs/api/python/io/io.md b/docs/api/python/io/io.md
index c0dc8d1..13a6121 100644
--- a/docs/api/python/io/io.md
+++ b/docs/api/python/io/io.md
@@ -75,6 +75,7 @@ A detailed tutorial is available at
io.CSVIter
io.LibSVMIter
io.ImageRecordIter
+ io.ImageRecordInt8Iter
io.ImageRecordUInt8Iter
io.MNISTIter
recordio.MXRecordIO
diff --git a/example/quantization/imagenet_inference.py b/example/quantization/imagenet_inference.py
index 0725165..47e2063 100644
--- a/example/quantization/imagenet_inference.py
+++ b/example/quantization/imagenet_inference.py
@@ -19,6 +19,7 @@ import argparse
import logging
import os
import time
+import numpy as np
import mxnet as mx
from mxnet import nd
from mxnet.contrib.quantization import *
@@ -98,7 +99,7 @@ def score(sym, arg_params, aux_params, data, devs, label_name, max_num_examples,
logger.info(m.get())
-def benchmark_score(symbol_file, ctx, batch_size, num_batches, logger=None):
+def benchmark_score(symbol_file, ctx, batch_size, num_batches, data_layer_type, logger=None):
# get mod
cur_path = os.path.dirname(os.path.realpath(__file__))
symbol_file_path = os.path.join(cur_path, symbol_file)
@@ -106,14 +107,28 @@ def benchmark_score(symbol_file, ctx, batch_size, num_batches, logger=None):
logger.info('Loading symbol from file %s' % symbol_file_path)
sym = mx.sym.load(symbol_file_path)
mod = mx.mod.Module(symbol=sym, context=ctx)
- mod.bind(for_training = False,
- inputs_need_grad = False,
- data_shapes = [('data', (batch_size,)+data_shape)])
+ if data_layer_type == "int8":
+ dshape = mx.io.DataDesc(name='data', shape=(
+ batch_size,) + data_shape, dtype=np.int8)
+ elif data_layer_type == 'uint8':
+ dshape = mx.io.DataDesc(name='data', shape=(
+ batch_size,) + data_shape, dtype=np.uint8)
+ else: # float32
+ dshape = mx.io.DataDesc(name='data', shape=(
+ batch_size,) + data_shape, dtype=np.float32)
+ mod.bind(for_training=False,
+ inputs_need_grad=False,
+ data_shapes=[dshape])
mod.init_params(initializer=mx.init.Xavier(magnitude=2.))
# get data
- data = [mx.random.uniform(-1.0, 1.0, shape=shape, ctx=ctx) for _, shape in mod.data_shapes]
- batch = mx.io.DataBatch(data, []) # empty label
+ if data_layer_type == "float32":
+ data = [mx.random.uniform(-1.0, 1.0, shape=shape, ctx=ctx, dtype=data_layer_type)
+ for _, shape in mod.data_shapes]
+ else:
+ data = [mx.nd.full(shape=shape, val=127, ctx=ctx, dtype=data_layer_type)
+ for _, shape in mod.data_shapes]
+ batch = mx.io.DataBatch(data, []) # empty label
# run
dry_run = 5 # use 5 iterations to warm up
@@ -152,6 +167,9 @@ if __name__ == '__main__':
help='shuffling seed, see'
' https://mxnet.incubator.apache.org/api/python/io/io.html?highlight=imager#mxnet.io.ImageRecordIter'
' for more details')
+ parser.add_argument('--data-layer-type', type=str, default="float32",
+ choices=['float32', 'int8', 'uint8'],
+ help='data type for data layer')
args = parser.parse_args()
@@ -192,24 +210,52 @@ if __name__ == '__main__':
data_shape = tuple([int(i) for i in image_shape.split(',')])
logger.info('Input data shape = %s' % str(data_shape))
+ data_layer_type = args.data_layer_type
if args.benchmark == False:
dataset = args.dataset
download_dataset('http://data.mxnet.io/data/val_256_q90.rec', dataset)
logger.info('Dataset for inference: %s' % dataset)
# creating data iterator
- data = mx.io.ImageRecordIter(path_imgrec=dataset,
- label_width=1,
- preprocess_threads=data_nthreads,
- batch_size=batch_size,
- data_shape=data_shape,
- label_name=label_name,
- rand_crop=False,
- rand_mirror=False,
- shuffle=True,
- shuffle_chunk_seed=3982304,
- seed=48564309,
- **combine_mean_std)
+ if data_layer_type == 'int8':
+ data = mx.io.ImageRecordInt8Iter(path_imgrec=dataset,
+ label_width=1,
+ preprocess_threads=data_nthreads,
+ batch_size=batch_size,
+ data_shape=data_shape,
+ label_name=label_name,
+ rand_crop=False,
+ rand_mirror=False,
+ shuffle=args.shuffle_dataset,
+ shuffle_chunk_seed=args.shuffle_chunk_seed,
+ seed=args.shuffle_seed,
+ **combine_mean_std)
+ elif data_layer_type == 'uint8':
+ data = mx.io.ImageRecordUInt8Iter(path_imgrec=dataset,
+ label_width=1,
+ preprocess_threads=data_nthreads,
+ batch_size=batch_size,
+ data_shape=data_shape,
+ label_name=label_name,
+ rand_crop=False,
+ rand_mirror=False,
+ shuffle=args.shuffle_dataset,
+ shuffle_chunk_seed=args.shuffle_chunk_seed,
+ seed=args.shuffle_seed,
+ **combine_mean_std)
+ else: # float32
+ data = mx.io.ImageRecordIter(path_imgrec=dataset,
+ label_width=1,
+ preprocess_threads=data_nthreads,
+ batch_size=batch_size,
+ data_shape=data_shape,
+ label_name=label_name,
+ rand_crop=False,
+ rand_mirror=False,
+ shuffle=args.shuffle_dataset,
+ shuffle_chunk_seed=args.shuffle_chunk_seed,
+ seed=args.shuffle_seed,
+ **combine_mean_std)
# loading model
sym, arg_params, aux_params = load_model(symbol_file, param_file, logger)
@@ -224,5 +270,5 @@ if __name__ == '__main__':
max_num_examples=num_inference_images, logger=logger)
else:
logger.info('Running model %s for inference' % symbol_file)
- speed = benchmark_score(symbol_file, ctx, batch_size, args.num_inference_batches, logger)
+ speed = benchmark_score(symbol_file, ctx, batch_size, args.num_inference_batches, data_layer_type, logger)
logger.info('batch size %2d, image/sec: %f', batch_size, speed)
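
For reference, a self-contained sketch of the typed binding used in benchmark_score above: an int8 DataDesc plus a constant int8 dummy batch (batch size and image shape are illustrative values):

import numpy as np
import mxnet as mx

batch_size, data_shape = 32, (3, 224, 224)
# Describe the data layer with an explicit int8 dtype, as benchmark_score now does.
dshape = mx.io.DataDesc(name='data', shape=(batch_size,) + data_shape, dtype=np.int8)
# Dummy int8 input; val=127 matches the benchmark's synthetic data.
data = [mx.nd.full(shape=dshape.shape, val=127, ctx=mx.cpu(), dtype='int8')]
batch = mx.io.DataBatch(data, [])  # empty label
print(dshape, data[0].dtype)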
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 76a4995..9a24b75 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -1602,8 +1602,8 @@ MXNET_DLL int MXSymbolInferTypePartial(SymbolHandle sym,
* \param excluded_symbols op names to be excluded from being quantized
* \param num_offline number of parameters that are quantized offline
* \param offline_params array of c strings representing the names of params quantized offline
- * \param quantized_dtype the quantized destination type for input data.
- * \param calib_quantize **Deprecated**. quantize op will always be calibrated if could.
+ * \param quantized_dtype the quantized destination type for input data
+ * \param calib_quantize **Deprecated**. quantize op will always be calibrated when possible
*/
MXNET_DLL int MXQuantizeSymbol(SymbolHandle sym_handle, SymbolHandle *ret_sym_handle,
const mx_uint num_excluded_symbols,
diff --git a/perl-package/AI-MXNet/lib/AI/MXNet/IO.pm b/perl-package/AI-MXNet/lib/AI/MXNet/IO.pm
index 297ceb8..19e7cfd 100644
--- a/perl-package/AI-MXNet/lib/AI/MXNet/IO.pm
+++ b/perl-package/AI-MXNet/lib/AI/MXNet/IO.pm
@@ -642,6 +642,7 @@ extends 'AI::MXNet::DataIter';
mx->io->CSVIter Returns the CSV file iterator.
mx->io->LibSVMIter Returns the LibSVM iterator which returns data with csr storage type.
mx->io->ImageRecordIter Iterates on image RecordIO files
+ mx->io->ImageRecordInt8Iter Iterating on image RecordIO files
mx->io->ImageRecordUInt8Iter Iterating on image RecordIO files
mx->io->MNISTIter Iterating on the MNIST dataset.
mx->recordio->MXRecordIO Reads/writes RecordIO data format, supporting sequential read and write.
diff --git a/src/io/iter_image_recordio_2.cc b/src/io/iter_image_recordio_2.cc
index 5d5261b..0834dd7 100644
--- a/src/io/iter_image_recordio_2.cc
+++ b/src/io/iter_image_recordio_2.cc
@@ -372,6 +372,7 @@ void ImageRecordIOParser2<DType>::ProcessImage(const cv::Mat& res,
float RGBA_MULT[4] = { 0 };
float RGBA_BIAS[4] = { 0 };
float RGBA_MEAN[4] = { 0 };
+ int16_t RGBA_MEAN_INT[4] = {0};
mshadow::Tensor<cpu, 3, DType>& data = (*data_ptr);
if (!std::is_same<DType, uint8_t>::value) {
RGBA_MULT[0] = contrast_scaled / normalize_param_.std_r;
@@ -387,6 +388,10 @@ void ImageRecordIOParser2<DType>::ProcessImage(const cv::Mat& res,
RGBA_MEAN[1] = normalize_param_.mean_g;
RGBA_MEAN[2] = normalize_param_.mean_b;
RGBA_MEAN[3] = normalize_param_.mean_a;
+ RGBA_MEAN_INT[0] = std::round(normalize_param_.mean_r);
+ RGBA_MEAN_INT[1] = std::round(normalize_param_.mean_g);
+ RGBA_MEAN_INT[2] = std::round(normalize_param_.mean_b);
+ RGBA_MEAN_INT[3] = std::round(normalize_param_.mean_a);
}
}
@@ -408,17 +413,30 @@ void ImageRecordIOParser2<DType>::ProcessImage(const cv::Mat& res,
for (int i = 0; i < res.rows; ++i) {
const uchar* im_data = res.ptr<uchar>(i);
for (int j = 0; j < res.cols; ++j) {
- for (int k = 0; k < n_channels; ++k) {
- RGBA[k] = im_data[swap_indices[k]];
- }
- if (!std::is_same<DType, uint8_t>::value) {
- // normalize/mirror here to avoid memory copies
- // logic from iter_normalize.h, function SetOutImg
+ if (std::is_same<DType, int8_t>::value) {
+ if (meanfile_ready_) {
+ for (int k = 0; k < n_channels; ++k) {
+ RGBA[k] = cv::saturate_cast<int8_t>(im_data[swap_indices[k]] -
+ static_cast<int16_t>(std::round(meanimg_[k][i][j])));
+ }
+ } else {
+ for (int k = 0; k < n_channels; ++k) {
+ RGBA[k] = cv::saturate_cast<int8_t>(im_data[swap_indices[k]] - RGBA_MEAN_INT[k]);
+ }
+ }
+ } else {
for (int k = 0; k < n_channels; ++k) {
- if (meanfile_ready_) {
- RGBA[k] = (RGBA[k] - meanimg_[k][i][j]) * RGBA_MULT[k] + RGBA_BIAS[k];
- } else {
- RGBA[k] = (RGBA[k] - RGBA_MEAN[k]) * RGBA_MULT[k] + RGBA_BIAS[k];
+ RGBA[k] = im_data[swap_indices[k]];
+ }
+ if (!std::is_same<DType, uint8_t>::value) {
+ // normalize/mirror here to avoid memory copies
+ // logic from iter_normalize.h, function SetOutImg
+ for (int k = 0; k < n_channels; ++k) {
+ if (meanfile_ready_) {
+ RGBA[k] = (RGBA[k] - meanimg_[k][i][j]) * RGBA_MULT[k] + RGBA_BIAS[k];
+ } else {
+ RGBA[k] = (RGBA[k] - RGBA_MEAN[k]) * RGBA_MULT[k] + RGBA_BIAS[k];
+ }
}
}
}
@@ -795,5 +813,22 @@ the data type instead of ``float``.
.set_body([]() {
return new ImageRecordIter2<uint8_t>();
});
+
+MXNET_REGISTER_IO_ITER(ImageRecordInt8Iter)
+.describe(R"code(Iterating on image RecordIO files
+
+This iterator is identical to ``ImageRecordIter`` except for using ``int8`` as
+the data type instead of ``float``.
+
+)code" ADD_FILELINE)
+.add_arguments(ImageRecParserParam::__FIELDS__())
+.add_arguments(ImageRecordParam::__FIELDS__())
+.add_arguments(BatchParam::__FIELDS__())
+.add_arguments(PrefetcherParam::__FIELDS__())
+.add_arguments(ListDefaultAugParams())
+.set_body([]() {
+ return new ImageRecordIter2<int8_t>();
+ });
+
} // namespace io
} // namespace mxnet
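
The int8 branch above subtracts the per-channel (or per-pixel mean-image) value, rounded to an integer, and saturates the result to the int8 range. A rough NumPy analogue of that arithmetic, not the actual C++ path (123.68 is only an illustrative mean):

import numpy as np

pixels = np.array([0, 3, 127, 200, 255], dtype=np.int16)  # raw uint8 values, widened
mean = np.int16(round(123.68))                            # e.g. a typical mean_r
# Saturate into [-128, 127], as cv::saturate_cast<int8_t> does.
out = np.clip(pixels - mean, -128, 127).astype(np.int8)
print(out)  # [-124 -121    3   76  127]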
diff --git a/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h b/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h
index e201d29..d6060e5 100644
--- a/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h
+++ b/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h
@@ -123,13 +123,32 @@ static void MKLDNNQuantizeV2Compute(const nnvm::NodeAttrs& attrs, const OpContex
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
const QuantizeV2Param& param = nnvm::get<QuantizeV2Param>(attrs.parsed);
- auto out_type = GetOutputType(param);
- if (out_type == mshadow::kUint8) {
- MKLDNNQuantizeComputeKer<float, uint8_t>(inputs, outputs, param, req);
- } else if (out_type == mshadow::kInt8) {
- MKLDNNQuantizeComputeKer<float, int8_t>(inputs, outputs, param, req);
+ if (inputs[0].dtype() == mshadow::kUint8 || inputs[0].dtype() == mshadow::kInt8) {
+ if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
+ *outputs[1].data().dptr<float>() = param.min_calib_range.value();
+ *outputs[2].data().dptr<float>() = param.max_calib_range.value();
+ } else {
+ if (inputs[0].dtype() == mshadow::kUint8) {
+ *outputs[1].data().dptr<float>() = 0;
+ *outputs[2].data().dptr<float>() = 255;
+ } else {
+ *outputs[1].data().dptr<float>() = -127;
+ *outputs[2].data().dptr<float>() = 127;
+ }
+ }
+ if (req[0] != kWriteInplace) {
+ const_cast<NDArray&>(outputs[0]).CopyFrom(*inputs[0].GetMKLDNNData());
+ MKLDNNStream::Get()->Submit();
+ }
} else {
- LOG(FATAL) << "mkldnn quantize op only supports int8 and uint8 as output type";
+ auto out_type = GetOutputType(param);
+ if (out_type == mshadow::kUint8) {
+ MKLDNNQuantizeComputeKer<float, uint8_t>(inputs, outputs, param, req);
+ } else if (out_type == mshadow::kInt8) {
+ MKLDNNQuantizeComputeKer<float, int8_t>(inputs, outputs, param, req);
+ } else {
+ LOG(FATAL) << "mkldnn quantize op only supports int8 and uint8 as output type";
+ }
}
}
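
The branch added above boils down to: if the input is already int8/uint8, copy it through and report either the calibrated range or the dtype's default range. A small Python sketch of that selection logic (illustration only, not MXNet API code):

def passthrough_range(dtype, min_calib=None, max_calib=None):
    # Mirrors the min/max outputs written by the pass-through branch.
    if min_calib is not None and max_calib is not None:
        return min_calib, max_calib
    return (0.0, 255.0) if dtype == 'uint8' else (-127.0, 127.0)

print(passthrough_range('int8'))             # (-127.0, 127.0)
print(passthrough_range('uint8'))            # (0.0, 255.0)
print(passthrough_range('int8', -2.5, 2.5))  # (-2.5, 2.5)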
diff --git a/src/operator/quantization/quantize_v2-inl.h b/src/operator/quantization/quantize_v2-inl.h
index 7a09983..e3c4119 100644
--- a/src/operator/quantization/quantize_v2-inl.h
+++ b/src/operator/quantization/quantize_v2-inl.h
@@ -137,51 +137,67 @@ void QuantizeV2Compute(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
Stream<xpu> *s = ctx.get_stream<xpu>();
const QuantizeV2Param ¶m = nnvm::get<QuantizeV2Param>(attrs.parsed);
auto out_type = GetOutputType(param);
- if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
- if (out_type == mshadow::kUint8) {
- Kernel<quantize_v2_unsigned, xpu>::Launch(
- s, outputs[0].Size(), outputs[0].dptr<uint8_t>(), outputs[1].dptr<float>(),
- outputs[2].dptr<float>(), inputs[0].dptr<SrcDType>(), param.min_calib_range.value(),
- param.max_calib_range.value(), MinValue<uint8_t>(), MaxValue<uint8_t>());
- } else if (out_type == mshadow::kInt8) { // zero-centered quantization
- Kernel<quantize_v2_zero_centered, xpu>::Launch(
- s, outputs[0].Size(), outputs[0].dptr<int8_t>(), outputs[1].dptr<float>(),
- outputs[2].dptr<float>(), inputs[0].dptr<SrcDType>(), param.min_calib_range.value(),
- param.max_calib_range.value(), MinAbs(MaxValue<int8_t>(), MinValue<int8_t>()));
+
+ if (inputs[0].type_flag_ == mshadow::kUint8 || inputs[0].type_flag_ == mshadow::kInt8) {
+ if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
+ *outputs[1].dptr<float>() = param.min_calib_range.value();
+ *outputs[2].dptr<float>() = param.max_calib_range.value();
} else {
- LOG(FATAL) << "quantize op only supports int8 and uint8 as output type";
+ if (inputs[0].type_flag_ == mshadow::kUint8) {
+ *outputs[1].dptr<float>() = 0;
+ *outputs[2].dptr<float>() = 255;
+ } else {
+ *outputs[1].dptr<float>() = -127;
+ *outputs[2].dptr<float>() = 127;
+ }
}
- } else { // model is not calibrated
- mxnet::TShape src_shape, dst_shape;
- const size_t actual_float_size = sizeof(float);
- const size_t temp_reduce_size =
- ConfigReduce<xpu, SrcDType>(s, inputs[0].shape_, mxnet::TShape({1}),
- &src_shape, &dst_shape);
- Tensor<xpu, 1, char> temp_space = ctx.requested[0].get_space_typed<xpu, 1, char>(
- Shape1(2 * actual_float_size + temp_reduce_size), s);
- const int dev_id = ctx.run_ctx.ctx.dev_id;
- TBlob in_min_t(reinterpret_cast<SrcDType *>(temp_space.dptr_), Shape1(1), xpu::kDevMask,
- dev_id);
- TBlob in_max_t(reinterpret_cast<SrcDType *>(temp_space.dptr_) + 1, Shape1(1), xpu::kDevMask,
- dev_id);
- Tensor<xpu, 1, char> workspace(temp_space.dptr_ + 2 * actual_float_size,
- Shape1(temp_reduce_size), s);
- broadcast::Reduce<red::minimum, 2, SrcDType, mshadow::op::identity>(
- s, in_min_t.reshape(dst_shape), kWriteTo, workspace, inputs[0].reshape(src_shape));
- broadcast::Reduce<red::maximum, 2, SrcDType, mshadow::op::identity>(
- s, in_max_t.reshape(dst_shape), kWriteTo, workspace, inputs[0].reshape(src_shape));
- if (out_type == mshadow::kUint8) {
- Kernel<quantize_v2_unsigned, xpu>::Launch(
- s, outputs[0].Size(), outputs[0].dptr<uint8_t>(), outputs[1].dptr<float>(),
- outputs[2].dptr<float>(), inputs[0].dptr<SrcDType>(), in_min_t.dptr<float>(),
- in_max_t.dptr<float>(), MinValue<uint8_t>(), MaxValue<uint8_t>());
- } else if (out_type == mshadow::kInt8) { // zero-centered quantization
- Kernel<quantize_v2_zero_centered, xpu>::Launch(
- s, outputs[0].Size(), outputs[0].dptr<int8_t>(), outputs[1].dptr<float>(),
- outputs[2].dptr<float>(), inputs[0].dptr<SrcDType>(), in_min_t.dptr<float>(),
- in_max_t.dptr<float>(), MinAbs(MaxValue<int8_t>(), MinValue<int8_t>()));
- } else {
- LOG(FATAL) << "quantize op only supports int8 and uint8 as output type";
+ UnaryOp::IdentityCompute<xpu>(attrs, ctx, {inputs[0]}, req, outputs);
+ } else {
+ if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
+ if (out_type == mshadow::kUint8) {
+ Kernel<quantize_v2_unsigned, xpu>::Launch(
+ s, outputs[0].Size(), outputs[0].dptr<uint8_t>(), outputs[1].dptr<float>(),
+ outputs[2].dptr<float>(), inputs[0].dptr<SrcDType>(), param.min_calib_range.value(),
+ param.max_calib_range.value(), MinValue<uint8_t>(), MaxValue<uint8_t>());
+ } else if (out_type == mshadow::kInt8) { // zero-centered quantization
+ Kernel<quantize_v2_zero_centered, xpu>::Launch(
+ s, outputs[0].Size(), outputs[0].dptr<int8_t>(), outputs[1].dptr<float>(),
+ outputs[2].dptr<float>(), inputs[0].dptr<SrcDType>(), param.min_calib_range.value(),
+ param.max_calib_range.value(), MinAbs(MaxValue<int8_t>(), MinValue<int8_t>()));
+ } else {
+ LOG(FATAL) << "quantize op only supports int8 and uint8 as output type";
+ }
+ } else { // model is not calibrated
+ mxnet::TShape src_shape, dst_shape;
+ const size_t actual_float_size = sizeof(float);
+ const size_t temp_reduce_size = ConfigReduce<xpu, SrcDType>(
+ s, inputs[0].shape_, mxnet::TShape({1}), &src_shape, &dst_shape);
+ Tensor<xpu, 1, char> temp_space = ctx.requested[0].get_space_typed<xpu, 1, char>(
+ Shape1(2 * actual_float_size + temp_reduce_size), s);
+ const int dev_id = ctx.run_ctx.ctx.dev_id;
+ TBlob in_min_t(reinterpret_cast<SrcDType *>(temp_space.dptr_), Shape1(1), xpu::kDevMask,
+ dev_id);
+ TBlob in_max_t(reinterpret_cast<SrcDType *>(temp_space.dptr_) + 1, Shape1(1), xpu::kDevMask,
+ dev_id);
+ Tensor<xpu, 1, char> workspace(temp_space.dptr_ + 2 * actual_float_size,
+ Shape1(temp_reduce_size), s);
+ broadcast::Reduce<red::minimum, 2, SrcDType, mshadow::op::identity>(
+ s, in_min_t.reshape(dst_shape), kWriteTo, workspace, inputs[0].reshape(src_shape));
+ broadcast::Reduce<red::maximum, 2, SrcDType, mshadow::op::identity>(
+ s, in_max_t.reshape(dst_shape), kWriteTo, workspace, inputs[0].reshape(src_shape));
+ if (out_type == mshadow::kUint8) {
+ Kernel<quantize_v2_unsigned, xpu>::Launch(
+ s, outputs[0].Size(), outputs[0].dptr<uint8_t>(), outputs[1].dptr<float>(),
+ outputs[2].dptr<float>(), inputs[0].dptr<SrcDType>(), in_min_t.dptr<float>(),
+ in_max_t.dptr<float>(), MinValue<uint8_t>(), MaxValue<uint8_t>());
+ } else if (out_type == mshadow::kInt8) { // zero-centered quantization
+ Kernel<quantize_v2_zero_centered, xpu>::Launch(
+ s, outputs[0].Size(), outputs[0].dptr<int8_t>(), outputs[1].dptr<float>(),
+ outputs[2].dptr<float>(), inputs[0].dptr<SrcDType>(), in_min_t.dptr<float>(),
+ in_max_t.dptr<float>(), MinAbs(MaxValue<int8_t>(), MinValue<int8_t>()));
+ } else {
+ LOG(FATAL) << "quantize op only supports int8 and uint8 as output type";
+ }
}
}
}
@@ -202,7 +218,8 @@ static inline bool QuantizeV2Type(const nnvm::NodeAttrs &attrs, std::vector<int>
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 3U);
const QuantizeV2Param ¶m = nnvm::get<QuantizeV2Param>(attrs.parsed);
- TYPE_ASSIGN_CHECK(*in_attrs, 0, mshadow::kFloat32);
+ CHECK(in_attrs->at(0) == mshadow::kFloat32 || in_attrs->at(0) == mshadow::kUint8 ||
+ in_attrs->at(0) == mshadow::kInt8);
auto out_type = GetOutputType(param);
if (out_type == mshadow::kUint8) {
TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kUint8);
diff --git a/src/operator/quantization/quantize_v2.cc b/src/operator/quantization/quantize_v2.cc
index e221d58..300cdfe 100644
--- a/src/operator/quantization/quantize_v2.cc
+++ b/src/operator/quantization/quantize_v2.cc
@@ -88,6 +88,12 @@ If min_calib_range isn't presented, the output type will be int8.
.set_attr<FComputeEx>("FComputeEx<cpu>", MKLDNNQuantizeV2Compute)
#endif
.set_attr<FCompute>("FCompute<cpu>", QuantizeV2Compute<cpu>)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption", [](const NodeAttrs& attrs) {
+ return std::vector<std::pair<int, int> >{{0, 0}};
+})
+.set_attr<nnvm::FInplaceIdentity>("FInplaceIdentity", [](const NodeAttrs& attrs){
+ return std::vector<bool>{true};
+})
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs) {
const QuantizeV2Param ¶m = nnvm::get<QuantizeV2Param>(attrs.parsed);
if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
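
Taken together with the FInplaceOption/FInplaceIdentity hints above, quantize_v2 now acts as a pass-through when its input is already quantized. A hedged usage sketch, assuming the operator is exposed in Python as mx.nd.contrib.quantize_v2 with out_type/min_calib_range/max_calib_range parameters:

import mxnet as mx

x_int8 = mx.nd.array([-5, 0, 7], dtype='int8')
# With int8 input the data passes through unchanged; without a calibration
# range the reported range defaults to [-127, 127].
out, out_min, out_max = mx.nd.contrib.quantize_v2(x_int8, out_type='int8')
print(out.asnumpy(), out_min.asscalar(), out_max.asscalar())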
diff --git a/tests/python/train/test_dtype.py b/tests/python/train/test_dtype.py
index 2e3ff06..39bfbcd 100644
--- a/tests/python/train/test_dtype.py
+++ b/tests/python/train/test_dtype.py
@@ -65,6 +65,30 @@ def get_iterator_uint8(kv):
return (train, val)
+def get_iterator_int8(kv):
+ data_shape = (3, 28, 28)
+
+ train = mx.io.ImageRecordInt8Iter(
+ path_imgrec = "data/cifar/train.rec",
+ data_shape = data_shape,
+ batch_size = batch_size,
+ rand_crop = True,
+ rand_mirror = True,
+ num_parts = kv.num_workers,
+ part_index = kv.rank)
+ train = mx.io.PrefetchingIter(train)
+
+ val = mx.io.ImageRecordInt8Iter(
+ path_imgrec = "data/cifar/test.rec",
+ rand_crop = False,
+ rand_mirror = False,
+ data_shape = data_shape,
+ batch_size = batch_size,
+ num_parts = kv.num_workers,
+ part_index = kv.rank)
+
+ return (train, val)
+
def get_iterator_float32(kv):
data_shape = (3, 28, 28)
@@ -190,5 +214,10 @@ def test_cifar10():
run_cifar10(train, val, use_module=False)
run_cifar10(train, val, use_module=True)
+ # test int8 input
+ (train, val) = get_iterator_int8(kv)
+ run_cifar10(train, val, use_module=False)
+ run_cifar10(train, val, use_module=True)
+
if __name__ == "__main__":
test_cifar10()