You are viewing a plain-text version of this content; the hyperlink to its canonical version was not preserved in this copy.
Posted to commits@mxnet.apache.org by an...@apache.org on 2019/03/04 21:50:47 UTC

[incubator-mxnet] branch master updated: Add int8 data loader (#14123)

This is an automated email from the ASF dual-hosted git repository.

anirudh2290 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new df7771b  Add int8 data loader (#14123)
df7771b is described below

commit df7771b4d29f6283f9d7d9d721378fc6be5cf9d8
Author: Zhennan Qin <zh...@intel.com>
AuthorDate: Tue Mar 5 05:50:27 2019 +0800

    Add int8 data loader (#14123)
    
    * Enable int8 data layer
    
    Change-Id: I3d97ef80b7466d7555f4970e24f02e8dfba6be2b
    
    * Fix lint errors
    
    * Add parameter description
    
    * Fix imagenet_inference.py
    
    * Allow quantize_v2 to accept int8 and uint8 input
    
    * Make float32 the default data-layer type
---
 docs/api/perl/io.md                                |   1 +
 docs/api/python/io/io.md                           |   1 +
 example/quantization/imagenet_inference.py         |  84 +++++++++++++----
 include/mxnet/c_api.h                              |   4 +-
 perl-package/AI-MXNet/lib/AI/MXNet/IO.pm           |   1 +
 src/io/iter_image_recordio_2.cc                    |  55 +++++++++--
 .../quantization/mkldnn/mkldnn_quantize_v2-inl.h   |  31 ++++--
 src/operator/quantization/quantize_v2-inl.h        | 105 ++++++++++++---------
 src/operator/quantization/quantize_v2.cc           |   6 ++
 tests/python/train/test_dtype.py                   |  29 ++++++
 10 files changed, 236 insertions(+), 81 deletions(-)

diff --git a/docs/api/perl/io.md b/docs/api/perl/io.md
index be49764..ca3b0f1 100644
--- a/docs/api/perl/io.md
+++ b/docs/api/perl/io.md
@@ -69,6 +69,7 @@ Then we can call `$mod->fit($nd_iter, num_epoch=>2)` to train `loss` by 2 epochs
 mx->io->NDArrayIter
 mx->io->CSVIter
 mx->io->ImageRecordIter
+mx->io->ImageRecordInt8Iter
 mx->io->ImageRecordUInt8Iter
 mx->io->MNISTIter
 mx->recordio->MXRecordIO
diff --git a/docs/api/python/io/io.md b/docs/api/python/io/io.md
index c0dc8d1..13a6121 100644
--- a/docs/api/python/io/io.md
+++ b/docs/api/python/io/io.md
@@ -75,6 +75,7 @@ A detailed tutorial is available at
     io.CSVIter
     io.LibSVMIter
     io.ImageRecordIter
+    io.ImageRecordInt8Iter
     io.ImageRecordUInt8Iter
     io.MNISTIter
     recordio.MXRecordIO
diff --git a/example/quantization/imagenet_inference.py b/example/quantization/imagenet_inference.py
index 0725165..47e2063 100644
--- a/example/quantization/imagenet_inference.py
+++ b/example/quantization/imagenet_inference.py
@@ -19,6 +19,7 @@ import argparse
 import logging
 import os
 import time
+import numpy as np
 import mxnet as mx
 from mxnet import nd
 from mxnet.contrib.quantization import *
@@ -98,7 +99,7 @@ def score(sym, arg_params, aux_params, data, devs, label_name, max_num_examples,
             logger.info(m.get())
 
 
-def benchmark_score(symbol_file, ctx, batch_size, num_batches, logger=None):
+def benchmark_score(symbol_file, ctx, batch_size, num_batches, data_layer_type, logger=None):
     # get mod
     cur_path = os.path.dirname(os.path.realpath(__file__))
     symbol_file_path = os.path.join(cur_path, symbol_file)
@@ -106,14 +107,28 @@ def benchmark_score(symbol_file, ctx, batch_size, num_batches, logger=None):
         logger.info('Loading symbol from file %s' % symbol_file_path)
     sym = mx.sym.load(symbol_file_path)
     mod = mx.mod.Module(symbol=sym, context=ctx)
-    mod.bind(for_training     = False,
-             inputs_need_grad = False,
-             data_shapes      = [('data', (batch_size,)+data_shape)])
+    if data_layer_type == "int8":
+        dshape = mx.io.DataDesc(name='data', shape=(
+            batch_size,) + data_shape, dtype=np.int8)
+    elif data_layer_type == 'uint8':
+        dshape = mx.io.DataDesc(name='data', shape=(
+            batch_size,) + data_shape, dtype=np.uint8)
+    else:  # float32
+        dshape = mx.io.DataDesc(name='data', shape=(
+            batch_size,) + data_shape, dtype=np.float32)
+    mod.bind(for_training=False,
+             inputs_need_grad=False,
+             data_shapes=[dshape])
     mod.init_params(initializer=mx.init.Xavier(magnitude=2.))
 
     # get data
-    data = [mx.random.uniform(-1.0, 1.0, shape=shape, ctx=ctx) for _, shape in mod.data_shapes]
-    batch = mx.io.DataBatch(data, []) # empty label
+    if data_layer_type == "float32":
+        data = [mx.random.uniform(-1.0, 1.0, shape=shape, ctx=ctx, dtype=data_layer_type)
+                for _, shape in mod.data_shapes]
+    else:
+        data = [mx.nd.full(shape=shape, val=127, ctx=ctx, dtype=data_layer_type)
+                for _, shape in mod.data_shapes]
+    batch = mx.io.DataBatch(data, [])  # empty label
 
     # run
     dry_run = 5                 # use 5 iterations to warm up
@@ -152,6 +167,9 @@ if __name__ == '__main__':
                         help='shuffling seed, see'
                              ' https://mxnet.incubator.apache.org/api/python/io/io.html?highlight=imager#mxnet.io.ImageRecordIter'
                              ' for more details')
+    parser.add_argument('--data-layer-type', type=str, default="float32",
+                        choices=['float32', 'int8', 'uint8'],
+                        help='data type for data layer')
 
     args = parser.parse_args()
 
@@ -192,24 +210,52 @@ if __name__ == '__main__':
     data_shape = tuple([int(i) for i in image_shape.split(',')])
     logger.info('Input data shape = %s' % str(data_shape))
 
+    data_layer_type = args.data_layer_type
     if args.benchmark == False:
         dataset = args.dataset
         download_dataset('http://data.mxnet.io/data/val_256_q90.rec', dataset)
         logger.info('Dataset for inference: %s' % dataset)
 
         # creating data iterator
-        data = mx.io.ImageRecordIter(path_imgrec=dataset,
-                                    label_width=1,
-                                    preprocess_threads=data_nthreads,
-                                    batch_size=batch_size,
-                                    data_shape=data_shape,
-                                    label_name=label_name,
-                                    rand_crop=False,
-                                    rand_mirror=False,
-                                    shuffle=True,
-                                    shuffle_chunk_seed=3982304,
-                                    seed=48564309,
-                                    **combine_mean_std)
+        if data_layer_type == 'int8':
+            data = mx.io.ImageRecordInt8Iter(path_imgrec=dataset,
+                                             label_width=1,
+                                             preprocess_threads=data_nthreads,
+                                             batch_size=batch_size,
+                                             data_shape=data_shape,
+                                             label_name=label_name,
+                                             rand_crop=False,
+                                             rand_mirror=False,
+                                             shuffle=args.shuffle_dataset,
+                                             shuffle_chunk_seed=args.shuffle_chunk_seed,
+                                             seed=args.shuffle_seed,
+                                             **combine_mean_std)
+        elif data_layer_type == 'uint8':
+            data = mx.io.ImageRecordUInt8Iter(path_imgrec=dataset,
+                                              label_width=1,
+                                              preprocess_threads=data_nthreads,
+                                              batch_size=batch_size,
+                                              data_shape=data_shape,
+                                              label_name=label_name,
+                                              rand_crop=False,
+                                              rand_mirror=False,
+                                              shuffle=args.shuffle_dataset,
+                                              shuffle_chunk_seed=args.shuffle_chunk_seed,
+                                              seed=args.shuffle_seed,
+                                              **combine_mean_std)
+        else:  #float32
+            data = mx.io.ImageRecordIter(path_imgrec=dataset,
+                                         label_width=1,
+                                         preprocess_threads=data_nthreads,
+                                         batch_size=batch_size,
+                                         data_shape=data_shape,
+                                         label_name=label_name,
+                                         rand_crop=False,
+                                         rand_mirror=False,
+                                         shuffle=args.shuffle_dataset,
+                                         shuffle_chunk_seed=args.shuffle_chunk_seed,
+                                         seed=args.shuffle_seed,
+                                         **combine_mean_std)
 
         # loading model
         sym, arg_params, aux_params = load_model(symbol_file, param_file, logger)
@@ -224,5 +270,5 @@ if __name__ == '__main__':
             max_num_examples=num_inference_images, logger=logger)
     else:
         logger.info('Running model %s for inference' % symbol_file)
-        speed = benchmark_score(symbol_file, ctx, batch_size, args.num_inference_batches, logger)
+        speed = benchmark_score(symbol_file, ctx, batch_size, args.num_inference_batches, data_layer_type, logger)
         logger.info('batch size %2d, image/sec: %f', batch_size, speed)
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 76a4995..9a24b75 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -1602,8 +1602,8 @@ MXNET_DLL int MXSymbolInferTypePartial(SymbolHandle sym,
  * \param excluded_symbols op names to be excluded from being quantized
  * \param num_offline number of parameters that are quantized offline
  * \param offline_params array of c strings representing the names of params quantized offline
- * \param quantized_dtype the quantized destination type for input data.
- * \param calib_quantize **Deprecated**. quantize op will always be calibrated if could.
+ * \param quantized_dtype the quantized destination type for input data
+ * \param calib_quantize **Deprecated**. quantize op will always be calibrated if could
  */
 MXNET_DLL int MXQuantizeSymbol(SymbolHandle sym_handle, SymbolHandle *ret_sym_handle,
                                const mx_uint num_excluded_symbols,
diff --git a/perl-package/AI-MXNet/lib/AI/MXNet/IO.pm b/perl-package/AI-MXNet/lib/AI/MXNet/IO.pm
index 297ceb8..19e7cfd 100644
--- a/perl-package/AI-MXNet/lib/AI/MXNet/IO.pm
+++ b/perl-package/AI-MXNet/lib/AI/MXNet/IO.pm
@@ -642,6 +642,7 @@ extends 'AI::MXNet::DataIter';
     mx->io->CSVIter                     Returns the CSV file iterator.
     mx->io->LibSVMIter                  Returns the LibSVM iterator which returns data with csr storage type.
     mx->io->ImageRecordIter             Iterates on image RecordIO files
+    mx->io->ImageRecordInt8Iter         Iterating on image RecordIO files
     mx->io->ImageRecordUInt8Iter        Iterating on image RecordIO files
     mx->io->MNISTIter                   Iterating on the MNIST dataset.
     mx->recordio->MXRecordIO            Reads/writes RecordIO data format, supporting sequential read and write.
diff --git a/src/io/iter_image_recordio_2.cc b/src/io/iter_image_recordio_2.cc
index 5d5261b..0834dd7 100644
--- a/src/io/iter_image_recordio_2.cc
+++ b/src/io/iter_image_recordio_2.cc
@@ -372,6 +372,7 @@ void ImageRecordIOParser2<DType>::ProcessImage(const cv::Mat& res,
   float RGBA_MULT[4] = { 0 };
   float RGBA_BIAS[4] = { 0 };
   float RGBA_MEAN[4] = { 0 };
+  int16_t RGBA_MEAN_INT[4] = {0};
   mshadow::Tensor<cpu, 3, DType>& data = (*data_ptr);
   if (!std::is_same<DType, uint8_t>::value) {
     RGBA_MULT[0] = contrast_scaled / normalize_param_.std_r;
@@ -387,6 +388,10 @@ void ImageRecordIOParser2<DType>::ProcessImage(const cv::Mat& res,
       RGBA_MEAN[1] = normalize_param_.mean_g;
       RGBA_MEAN[2] = normalize_param_.mean_b;
       RGBA_MEAN[3] = normalize_param_.mean_a;
+      RGBA_MEAN_INT[0] = std::round(normalize_param_.mean_r);
+      RGBA_MEAN_INT[1] = std::round(normalize_param_.mean_g);
+      RGBA_MEAN_INT[2] = std::round(normalize_param_.mean_b);
+      RGBA_MEAN_INT[3] = std::round(normalize_param_.mean_a);
     }
   }
 
@@ -408,17 +413,30 @@ void ImageRecordIOParser2<DType>::ProcessImage(const cv::Mat& res,
   for (int i = 0; i < res.rows; ++i) {
     const uchar* im_data = res.ptr<uchar>(i);
     for (int j = 0; j < res.cols; ++j) {
-      for (int k = 0; k < n_channels; ++k) {
-        RGBA[k] = im_data[swap_indices[k]];
-      }
-      if (!std::is_same<DType, uint8_t>::value) {
-        // normalize/mirror here to avoid memory copies
-        // logic from iter_normalize.h, function SetOutImg
+      if (std::is_same<DType, int8_t>::value) {
+        if (meanfile_ready_) {
+          for (int k = 0; k < n_channels; ++k) {
+            RGBA[k] = cv::saturate_cast<int8_t>(im_data[swap_indices[k]] -
+                                    static_cast<int16_t>(std::round(meanimg_[k][i][j])));
+          }
+        } else {
+          for (int k = 0; k < n_channels; ++k) {
+            RGBA[k] = cv::saturate_cast<int8_t>(im_data[swap_indices[k]] - RGBA_MEAN_INT[k]);
+          }
+        }
+      } else {
         for (int k = 0; k < n_channels; ++k) {
-          if (meanfile_ready_) {
-            RGBA[k] = (RGBA[k] - meanimg_[k][i][j]) * RGBA_MULT[k] + RGBA_BIAS[k];
-          } else {
-            RGBA[k] = (RGBA[k] - RGBA_MEAN[k]) * RGBA_MULT[k] + RGBA_BIAS[k];
+          RGBA[k] = im_data[swap_indices[k]];
+        }
+        if (!std::is_same<DType, uint8_t>::value) {
+          // normalize/mirror here to avoid memory copies
+          // logic from iter_normalize.h, function SetOutImg
+          for (int k = 0; k < n_channels; ++k) {
+            if (meanfile_ready_) {
+              RGBA[k] = (RGBA[k] - meanimg_[k][i][j]) * RGBA_MULT[k] + RGBA_BIAS[k];
+            } else {
+              RGBA[k] = (RGBA[k] - RGBA_MEAN[k]) * RGBA_MULT[k] + RGBA_BIAS[k];
+            }
           }
         }
       }
@@ -795,5 +813,22 @@ the data type instead of ``float``.
 .set_body([]() {
     return new ImageRecordIter2<uint8_t>();
   });
+
+MXNET_REGISTER_IO_ITER(ImageRecordInt8Iter)
+.describe(R"code(Iterating on image RecordIO files
+
+This iterator is identical to ``ImageRecordIter`` except for using ``int8`` as
+the data type instead of ``float``.
+
+)code" ADD_FILELINE)
+.add_arguments(ImageRecParserParam::__FIELDS__())
+.add_arguments(ImageRecordParam::__FIELDS__())
+.add_arguments(BatchParam::__FIELDS__())
+.add_arguments(PrefetcherParam::__FIELDS__())
+.add_arguments(ListDefaultAugParams())
+.set_body([]() {
+    return new ImageRecordIter2<int8_t>();
+  });
+
 }  // namespace io
 }  // namespace mxnet
diff --git a/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h b/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h
index e201d29..d6060e5 100644
--- a/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h
+++ b/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h
@@ -123,13 +123,32 @@ static void MKLDNNQuantizeV2Compute(const nnvm::NodeAttrs& attrs, const OpContex
                                     const std::vector<OpReqType>& req,
                                     const std::vector<NDArray>& outputs) {
   const QuantizeV2Param& param = nnvm::get<QuantizeV2Param>(attrs.parsed);
-  auto out_type = GetOutputType(param);
-  if (out_type == mshadow::kUint8) {
-    MKLDNNQuantizeComputeKer<float, uint8_t>(inputs, outputs, param, req);
-  } else if (out_type == mshadow::kInt8) {
-    MKLDNNQuantizeComputeKer<float, int8_t>(inputs, outputs, param, req);
+  if (inputs[0].dtype() == mshadow::kUint8 || inputs[0].dtype() == mshadow::kInt8) {
+    if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
+      *outputs[1].data().dptr<float>() = param.min_calib_range.value();
+      *outputs[2].data().dptr<float>() = param.max_calib_range.value();
+    } else {
+      if (inputs[0].dtype() == mshadow::kUint8) {
+        *outputs[1].data().dptr<float>() = 0;
+        *outputs[2].data().dptr<float>() = 255;
+      } else {
+        *outputs[1].data().dptr<float>() = -127;
+        *outputs[2].data().dptr<float>() = 127;
+      }
+    }
+    if (req[0] != kWriteInplace) {
+      const_cast<NDArray&>(outputs[0]).CopyFrom(*inputs[0].GetMKLDNNData());
+      MKLDNNStream::Get()->Submit();
+    }
   } else {
-    LOG(FATAL) << "mkldnn quantize op only supports int8 and uint8 as output type";
+    auto out_type = GetOutputType(param);
+    if (out_type == mshadow::kUint8) {
+      MKLDNNQuantizeComputeKer<float, uint8_t>(inputs, outputs, param, req);
+    } else if (out_type == mshadow::kInt8) {
+      MKLDNNQuantizeComputeKer<float, int8_t>(inputs, outputs, param, req);
+    } else {
+      LOG(FATAL) << "mkldnn quantize op only supports int8 and uint8 as output type";
+    }
   }
 }
 
diff --git a/src/operator/quantization/quantize_v2-inl.h b/src/operator/quantization/quantize_v2-inl.h
index 7a09983..e3c4119 100644
--- a/src/operator/quantization/quantize_v2-inl.h
+++ b/src/operator/quantization/quantize_v2-inl.h
@@ -137,51 +137,67 @@ void QuantizeV2Compute(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
   Stream<xpu> *s = ctx.get_stream<xpu>();
   const QuantizeV2Param &param = nnvm::get<QuantizeV2Param>(attrs.parsed);
   auto out_type = GetOutputType(param);
-  if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
-    if (out_type == mshadow::kUint8) {
-      Kernel<quantize_v2_unsigned, xpu>::Launch(
-          s, outputs[0].Size(), outputs[0].dptr<uint8_t>(), outputs[1].dptr<float>(),
-          outputs[2].dptr<float>(), inputs[0].dptr<SrcDType>(), param.min_calib_range.value(),
-          param.max_calib_range.value(), MinValue<uint8_t>(), MaxValue<uint8_t>());
-    } else if (out_type == mshadow::kInt8) {  // zero-centered quantization
-      Kernel<quantize_v2_zero_centered, xpu>::Launch(
-          s, outputs[0].Size(), outputs[0].dptr<int8_t>(), outputs[1].dptr<float>(),
-          outputs[2].dptr<float>(), inputs[0].dptr<SrcDType>(), param.min_calib_range.value(),
-          param.max_calib_range.value(), MinAbs(MaxValue<int8_t>(), MinValue<int8_t>()));
+
+  if (inputs[0].type_flag_ == mshadow::kUint8 || inputs[0].type_flag_ == mshadow::kInt8) {
+    if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
+      *outputs[1].dptr<float>() = param.min_calib_range.value();
+      *outputs[2].dptr<float>() = param.max_calib_range.value();
     } else {
-      LOG(FATAL) << "quantize op only supports int8 and uint8 as output type";
+      if (inputs[0].type_flag_ == mshadow::kUint8) {
+        *outputs[1].dptr<float>() = 0;
+        *outputs[2].dptr<float>() = 255;
+      } else {
+        *outputs[1].dptr<float>() = -127;
+        *outputs[2].dptr<float>() = 127;
+      }
     }
-  } else {  // model is not calibrated
-    mxnet::TShape src_shape, dst_shape;
-    const size_t actual_float_size = sizeof(float);
-    const size_t temp_reduce_size =
-        ConfigReduce<xpu, SrcDType>(s, inputs[0].shape_, mxnet::TShape({1}),
-                                    &src_shape, &dst_shape);
-    Tensor<xpu, 1, char> temp_space = ctx.requested[0].get_space_typed<xpu, 1, char>(
-        Shape1(2 * actual_float_size + temp_reduce_size), s);
-    const int dev_id = ctx.run_ctx.ctx.dev_id;
-    TBlob in_min_t(reinterpret_cast<SrcDType *>(temp_space.dptr_), Shape1(1), xpu::kDevMask,
-                   dev_id);
-    TBlob in_max_t(reinterpret_cast<SrcDType *>(temp_space.dptr_) + 1, Shape1(1), xpu::kDevMask,
-                   dev_id);
-    Tensor<xpu, 1, char> workspace(temp_space.dptr_ + 2 * actual_float_size,
-                                   Shape1(temp_reduce_size), s);
-    broadcast::Reduce<red::minimum, 2, SrcDType, mshadow::op::identity>(
-        s, in_min_t.reshape(dst_shape), kWriteTo, workspace, inputs[0].reshape(src_shape));
-    broadcast::Reduce<red::maximum, 2, SrcDType, mshadow::op::identity>(
-        s, in_max_t.reshape(dst_shape), kWriteTo, workspace, inputs[0].reshape(src_shape));
-    if (out_type == mshadow::kUint8) {
-      Kernel<quantize_v2_unsigned, xpu>::Launch(
-          s, outputs[0].Size(), outputs[0].dptr<uint8_t>(), outputs[1].dptr<float>(),
-          outputs[2].dptr<float>(), inputs[0].dptr<SrcDType>(), in_min_t.dptr<float>(),
-          in_max_t.dptr<float>(), MinValue<uint8_t>(), MaxValue<uint8_t>());
-    } else if (out_type == mshadow::kInt8) {  // zero-centered quantization
-      Kernel<quantize_v2_zero_centered, xpu>::Launch(
-          s, outputs[0].Size(), outputs[0].dptr<int8_t>(), outputs[1].dptr<float>(),
-          outputs[2].dptr<float>(), inputs[0].dptr<SrcDType>(), in_min_t.dptr<float>(),
-          in_max_t.dptr<float>(), MinAbs(MaxValue<int8_t>(), MinValue<int8_t>()));
-    } else {
-      LOG(FATAL) << "quantize op only supports int8 and uint8 as output type";
+    UnaryOp::IdentityCompute<xpu>(attrs, ctx, {inputs[0]}, req, outputs);
+  } else {
+    if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
+      if (out_type == mshadow::kUint8) {
+        Kernel<quantize_v2_unsigned, xpu>::Launch(
+            s, outputs[0].Size(), outputs[0].dptr<uint8_t>(), outputs[1].dptr<float>(),
+            outputs[2].dptr<float>(), inputs[0].dptr<SrcDType>(), param.min_calib_range.value(),
+            param.max_calib_range.value(), MinValue<uint8_t>(), MaxValue<uint8_t>());
+      } else if (out_type == mshadow::kInt8) {  // zero-centered quantization
+        Kernel<quantize_v2_zero_centered, xpu>::Launch(
+            s, outputs[0].Size(), outputs[0].dptr<int8_t>(), outputs[1].dptr<float>(),
+            outputs[2].dptr<float>(), inputs[0].dptr<SrcDType>(), param.min_calib_range.value(),
+            param.max_calib_range.value(), MinAbs(MaxValue<int8_t>(), MinValue<int8_t>()));
+      } else {
+        LOG(FATAL) << "quantize op only supports int8 and uint8 as output type";
+      }
+    } else {  // model is not calibrated
+      mxnet::TShape src_shape, dst_shape;
+      const size_t actual_float_size = sizeof(float);
+      const size_t temp_reduce_size = ConfigReduce<xpu, SrcDType>(
+          s, inputs[0].shape_, mxnet::TShape({1}), &src_shape, &dst_shape);
+      Tensor<xpu, 1, char> temp_space = ctx.requested[0].get_space_typed<xpu, 1, char>(
+          Shape1(2 * actual_float_size + temp_reduce_size), s);
+      const int dev_id = ctx.run_ctx.ctx.dev_id;
+      TBlob in_min_t(reinterpret_cast<SrcDType *>(temp_space.dptr_), Shape1(1), xpu::kDevMask,
+                    dev_id);
+      TBlob in_max_t(reinterpret_cast<SrcDType *>(temp_space.dptr_) + 1, Shape1(1), xpu::kDevMask,
+                    dev_id);
+      Tensor<xpu, 1, char> workspace(temp_space.dptr_ + 2 * actual_float_size,
+                                    Shape1(temp_reduce_size), s);
+      broadcast::Reduce<red::minimum, 2, SrcDType, mshadow::op::identity>(
+          s, in_min_t.reshape(dst_shape), kWriteTo, workspace, inputs[0].reshape(src_shape));
+      broadcast::Reduce<red::maximum, 2, SrcDType, mshadow::op::identity>(
+          s, in_max_t.reshape(dst_shape), kWriteTo, workspace, inputs[0].reshape(src_shape));
+      if (out_type == mshadow::kUint8) {
+        Kernel<quantize_v2_unsigned, xpu>::Launch(
+            s, outputs[0].Size(), outputs[0].dptr<uint8_t>(), outputs[1].dptr<float>(),
+            outputs[2].dptr<float>(), inputs[0].dptr<SrcDType>(), in_min_t.dptr<float>(),
+            in_max_t.dptr<float>(), MinValue<uint8_t>(), MaxValue<uint8_t>());
+      } else if (out_type == mshadow::kInt8) {  // zero-centered quantization
+        Kernel<quantize_v2_zero_centered, xpu>::Launch(
+            s, outputs[0].Size(), outputs[0].dptr<int8_t>(), outputs[1].dptr<float>(),
+            outputs[2].dptr<float>(), inputs[0].dptr<SrcDType>(), in_min_t.dptr<float>(),
+            in_max_t.dptr<float>(), MinAbs(MaxValue<int8_t>(), MinValue<int8_t>()));
+      } else {
+        LOG(FATAL) << "quantize op only supports int8 and uint8 as output type";
+      }
     }
   }
 }
@@ -202,7 +218,8 @@ static inline bool QuantizeV2Type(const nnvm::NodeAttrs &attrs, std::vector<int>
   CHECK_EQ(in_attrs->size(), 1U);
   CHECK_EQ(out_attrs->size(), 3U);
   const QuantizeV2Param &param = nnvm::get<QuantizeV2Param>(attrs.parsed);
-  TYPE_ASSIGN_CHECK(*in_attrs, 0, mshadow::kFloat32);
+  CHECK(in_attrs->at(0) == mshadow::kFloat32 || in_attrs->at(0) == mshadow::kUint8 ||
+        in_attrs->at(0) == mshadow::kInt8);
   auto out_type = GetOutputType(param);
   if (out_type == mshadow::kUint8) {
     TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kUint8);
diff --git a/src/operator/quantization/quantize_v2.cc b/src/operator/quantization/quantize_v2.cc
index e221d58..300cdfe 100644
--- a/src/operator/quantization/quantize_v2.cc
+++ b/src/operator/quantization/quantize_v2.cc
@@ -88,6 +88,12 @@ If min_calib_range isn't presented, the output type will be int8.
 .set_attr<FComputeEx>("FComputeEx<cpu>", MKLDNNQuantizeV2Compute)
 #endif
 .set_attr<FCompute>("FCompute<cpu>", QuantizeV2Compute<cpu>)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption", [](const NodeAttrs& attrs) {
+  return std::vector<std::pair<int, int> >{{0, 0}};
+})
+.set_attr<nnvm::FInplaceIdentity>("FInplaceIdentity", [](const NodeAttrs& attrs){
+  return std::vector<bool>{true};
+})
 .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs) {
   const QuantizeV2Param &param = nnvm::get<QuantizeV2Param>(attrs.parsed);
   if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
diff --git a/tests/python/train/test_dtype.py b/tests/python/train/test_dtype.py
index 2e3ff06..39bfbcd 100644
--- a/tests/python/train/test_dtype.py
+++ b/tests/python/train/test_dtype.py
@@ -65,6 +65,30 @@ def get_iterator_uint8(kv):
 
     return (train, val)
 
+def get_iterator_int8(kv):
+    data_shape = (3, 28, 28)
+
+    train = mx.io.ImageRecordInt8Iter(
+        path_imgrec = "data/cifar/train.rec",
+        data_shape  = data_shape,
+        batch_size  = batch_size,
+        rand_crop   = True,
+        rand_mirror = True,
+        num_parts   = kv.num_workers,
+        part_index  = kv.rank)
+    train = mx.io.PrefetchingIter(train)
+
+    val = mx.io.ImageRecordInt8Iter(
+        path_imgrec = "data/cifar/test.rec",
+        rand_crop   = False,
+        rand_mirror = False,
+        data_shape  = data_shape,
+        batch_size  = batch_size,
+        num_parts   = kv.num_workers,
+        part_index  = kv.rank)
+
+    return (train, val)
+
 def get_iterator_float32(kv):
     data_shape = (3, 28, 28)
 
@@ -190,5 +214,10 @@ def test_cifar10():
     run_cifar10(train, val, use_module=False)
     run_cifar10(train, val, use_module=True)
 
+    # test int8 input
+    (train, val) = get_iterator_int8(kv)
+    run_cifar10(train, val, use_module=False)
+    run_cifar10(train, val, use_module=True)
+
 if __name__ == "__main__":
     test_cifar10()