You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by pa...@apache.org on 2019/05/05 02:13:46 UTC

[incubator-mxnet] branch master updated: Refactor ImageRecordIter (#14824)

This is an automated email from the ASF dual-hosted git repository.

patriczhao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 621b391  Refactor ImageRecordIter (#14824)
621b391 is described below

commit 621b391c1445b9abba7f226abcbc55ab6aee5a0c
Author: Zhennan Qin <zh...@intel.com>
AuthorDate: Sun May 5 10:13:26 2019 +0800

    Refactor ImageRecordIter (#14824)
    
    * cpu optimized data loader
    
    * Fix CI
    
    * Fix CI
    
    * Fix ci
    
    * Fix doc
---
 docs/architecture/note_data_loading.md     |   8 +-
 example/quantization/README.md             |  32 ++++----
 example/quantization/imagenet_inference.py |  54 ++++---------
 src/io/image_iter_common.h                 |  10 +++
 src/io/iter_image_recordio_2.cc            | 119 ++++++++++++++++++++++++++++-
 tests/python/train/test_dtype.py           |  66 ++++++++++++++++
 6 files changed, 228 insertions(+), 61 deletions(-)

diff --git a/docs/architecture/note_data_loading.md b/docs/architecture/note_data_loading.md
index 293b675..a60bf90 100644
--- a/docs/architecture/note_data_loading.md
+++ b/docs/architecture/note_data_loading.md
@@ -83,7 +83,7 @@ In MXNet, we rely on the binary recordIO format implemented in dmlc-core.
 In MXNet's binary RecordIO, we store each data instance as a record.
 **kMagic** is a *magic number* indicating the start of a record.
 **Lrecord** encodes length and a continue flag.
-In lrecord,  
+In lrecord,
 - cflag == 0: this is a complete record
 - cflag == 1: start of a multiple-records
 - cflag == 2: middle of multiple-records
@@ -228,7 +228,11 @@ dataiter = mx.io.ImageRecordIter(
     # Backend Parameter, preprocessing thread number
     preprocess_threads=4,
     # Backend Parameter, prefetch buffer size
-    prefetch_buffer=1)
+    prefetch_buffer=1,
+    # Optional, the device context the data loader is optimized for: 'gpu' or 'cpu'
+    ctx="gpu",
+    # The output data type: 'float32', 'int8' or 'uint8'
+    dtype="float32")
 ```
 
 Generally, to create a data iterator, you need to provide five kinds of parameters:
diff --git a/example/quantization/README.md b/example/quantization/README.md
index c39d027..93a14cf 100644
--- a/example/quantization/README.md
+++ b/example/quantization/README.md
@@ -54,10 +54,10 @@ The model would be automatically replaced in fusion and quantization format. It
 export MXNET_SUBGRAPH_BACKEND=MKLDNN
 
 # Launch FP32 Inference 
-python imagenet_inference.py --symbol-file=./model/resnet50_v1-symbol.json --param-file=./model/resnet50_v1-0000.params --rgb-mean=123.68,116.779,103.939 --rgb-std=58.393,57.12,57.375 --num-skipped-batches=50 --batch-size=64 --num-inference-batches=500 --dataset=./data/val_256_q90.rec --ctx=cpu --data-nthreads=1
+python imagenet_inference.py --symbol-file=./model/resnet50_v1-symbol.json --param-file=./model/resnet50_v1-0000.params --rgb-mean=123.68,116.779,103.939 --rgb-std=58.393,57.12,57.375 --num-skipped-batches=50 --batch-size=64 --num-inference-batches=500 --dataset=./data/val_256_q90.rec --ctx=cpu
 
 # Launch INT8 Inference
-python imagenet_inference.py --symbol-file=./model/resnet50_v1-quantized-5batches-naive-symbol.json --param-file=./model/resnet50_v1-quantized-0000.params --rgb-mean=123.68,116.779,103.939 --rgb-std=58.393,57.12,57.375 --num-skipped-batches=50 --batch-size=64 --num-inference-batches=500 --dataset=./data/val_256_q90.rec --ctx=cpu  --data-nthreads=1
+python imagenet_inference.py --symbol-file=./model/resnet50_v1-quantized-5batches-naive-symbol.json --param-file=./model/resnet50_v1-quantized-0000.params --rgb-mean=123.68,116.779,103.939 --rgb-std=58.393,57.12,57.375 --num-skipped-batches=50 --batch-size=64 --num-inference-batches=500 --dataset=./data/val_256_q90.rec --ctx=cpu
 
 # Launch dummy data Inference
 python imagenet_inference.py --symbol-file=./model/resnet50_v1-symbol.json --batch-size=64 --num-inference-batches=500 --ctx=cpu --benchmark=True
@@ -78,10 +78,10 @@ The model would be automatically replaced in fusion and quantization format. It
 export MXNET_SUBGRAPH_BACKEND=MKLDNN
 
 # Launch FP32 Inference
-python imagenet_inference.py --symbol-file=./model/squeezenet1.0-symbol.json --param-file=./model/squeezenet1.0-0000.params --rgb-mean=123.68,116.779,103.939 --rgb-std=58.393,57.12,57.375 --num-skipped-batches=50 --batch-size=64 --num-inference-batches=500 --dataset=./data/val_256_q90.rec --ctx=cpu --data-nthreads=1
+python imagenet_inference.py --symbol-file=./model/squeezenet1.0-symbol.json --param-file=./model/squeezenet1.0-0000.params --rgb-mean=123.68,116.779,103.939 --rgb-std=58.393,57.12,57.375 --num-skipped-batches=50 --batch-size=64 --num-inference-batches=500 --dataset=./data/val_256_q90.rec --ctx=cpu
 
 # Launch INT8 Inference
-python imagenet_inference.py --symbol-file=./model/squeezenet1.0-quantized-5batches-naive-symbol.json --param-file=./model/squeezenet1.0-quantized-0000.params --rgb-mean=123.68,116.779,103.939 --rgb-std=58.393,57.12,57.375 --num-skipped-batches=50 --batch-size=64 --num-inference-batches=500 --dataset=./data/val_256_q90.rec --ctx=cpu  --data-nthreads=1
+python imagenet_inference.py --symbol-file=./model/squeezenet1.0-quantized-5batches-naive-symbol.json --param-file=./model/squeezenet1.0-quantized-0000.params --rgb-mean=123.68,116.779,103.939 --rgb-std=58.393,57.12,57.375 --num-skipped-batches=50 --batch-size=64 --num-inference-batches=500 --dataset=./data/val_256_q90.rec --ctx=cpu
 
 # Launch dummy data Inference
 python imagenet_inference.py --symbol-file=./model/squeezenet1.0-symbol.json --batch-size=64 --num-inference-batches=500 --ctx=cpu  --benchmark=True
@@ -102,10 +102,10 @@ The model would be automatically replaced in fusion and quantization format. It
 export MXNET_SUBGRAPH_BACKEND=MKLDNN
 
 # Launch FP32 Inference
-python imagenet_inference.py --symbol-file=./model/mobilenet1.0-symbol.json --param-file=./model/mobilenet1.0-0000.params --rgb-mean=123.68,116.779,103.939 --rgb-std=58.393,57.12,57.375 --num-skipped-batches=50 --batch-size=64 --num-inference-batches=500 --dataset=./data/val_256_q90.rec --ctx=cpu --data-nthreads=1
+python imagenet_inference.py --symbol-file=./model/mobilenet1.0-symbol.json --param-file=./model/mobilenet1.0-0000.params --rgb-mean=123.68,116.779,103.939 --rgb-std=58.393,57.12,57.375 --num-skipped-batches=50 --batch-size=64 --num-inference-batches=500 --dataset=./data/val_256_q90.rec --ctx=cpu
 
 # Launch INT8 Inference
-python imagenet_inference.py --symbol-file=./model/mobilenet1.0-quantized-5batches-naive-symbol.json --param-file=./model/mobilenet1.0-quantized-0000.params --rgb-mean=123.68,116.779,103.939 --rgb-std=58.393,57.12,57.375 --num-skipped-batches=50 --batch-size=64 --num-inference-batches=500 --dataset=./data/val_256_q90.rec --ctx=cpu  --data-nthreads=1
+python imagenet_inference.py --symbol-file=./model/mobilenet1.0-quantized-5batches-naive-symbol.json --param-file=./model/mobilenet1.0-quantized-0000.params --rgb-mean=123.68,116.779,103.939 --rgb-std=58.393,57.12,57.375 --num-skipped-batches=50 --batch-size=64 --num-inference-batches=500 --dataset=./data/val_256_q90.rec --ctx=cpu
 
 # Launch dummy data Inference
 python imagenet_inference.py --symbol-file=./model/mobilenet1.0-symbol.json --batch-size=64 --num-inference-batches=500 --ctx=cpu  --benchmark=True
@@ -126,10 +126,10 @@ The model would be automatically replaced in fusion and quantization format. It
 export MXNET_SUBGRAPH_BACKEND=MKLDNN
 
 # Launch FP32 Inference
-python imagenet_inference.py --symbol-file=./model/mobilenetv2_1.0-symbol.json --param-file=./model/mobilenetv2_1.0-0000.params --rgb-mean=123.68,116.779,103.939 --rgb-std=58.393,57.12,57.375 --num-skipped-batches=50 --batch-size=64 --num-inference-batches=500 --dataset=./data/val_256_q90.rec --ctx=cpu --data-nthreads=1
+python imagenet_inference.py --symbol-file=./model/mobilenetv2_1.0-symbol.json --param-file=./model/mobilenetv2_1.0-0000.params --rgb-mean=123.68,116.779,103.939 --rgb-std=58.393,57.12,57.375 --num-skipped-batches=50 --batch-size=64 --num-inference-batches=500 --dataset=./data/val_256_q90.rec --ctx=cpu
 
 # Launch INT8 Inference
-python imagenet_inference.py --symbol-file=./model/mobilenetv2_1.0-quantized-5batches-naive-symbol.json --param-file=./model/mobilenetv2_1.0-quantized-0000.params --rgb-mean=123.68,116.779,103.939 --rgb-std=58.393,57.12,57.375 --num-skipped-batches=50 --batch-size=64 --num-inference-batches=500 --dataset=./data/val_256_q90.rec --ctx=cpu  --data-nthreads=1
+python imagenet_inference.py --symbol-file=./model/mobilenetv2_1.0-quantized-5batches-naive-symbol.json --param-file=./model/mobilenetv2_1.0-quantized-0000.params --rgb-mean=123.68,116.779,103.939 --rgb-std=58.393,57.12,57.375 --num-skipped-batches=50 --batch-size=64 --num-inference-batches=500 --dataset=./data/val_256_q90.rec --ctx=cpu
 
 # Launch dummy data Inference
 python imagenet_inference.py --symbol-file=./model/mobilenetv2_1.0-symbol.json --batch-size=64 --num-inference-batches=500 --ctx=cpu  --benchmark=True
@@ -150,10 +150,10 @@ The model would be automatically replaced in fusion and quantization format. It
 export MXNET_SUBGRAPH_BACKEND=MKLDNN
 
 # Launch FP32 Inference
-python imagenet_inference.py --symbol-file=./model/inceptionv3-symbol.json --param-file=./model/inceptionv3-0000.params --image-shape=3,299,299 --rgb-mean=123.68,116.779,103.939 --rgb-std=58.393,57.12,57.375 --num-skipped-batches=50 --batch-size=64 --num-inference-batches=500 --dataset=./data/val_256_q90.rec --ctx=cpu --data-nthreads=1
+python imagenet_inference.py --symbol-file=./model/inceptionv3-symbol.json --param-file=./model/inceptionv3-0000.params --image-shape=3,299,299 --rgb-mean=123.68,116.779,103.939 --rgb-std=58.393,57.12,57.375 --num-skipped-batches=50 --batch-size=64 --num-inference-batches=500 --dataset=./data/val_256_q90.rec --ctx=cpu
 
 # Launch INT8 Inference
-python imagenet_inference.py --symbol-file=./model/inceptionv3-quantized-5batches-naive-symbol.json --param-file=./model/inceptionv3-quantized-0000.params --image-shape=3,299,299 --rgb-mean=123.68,116.779,103.939 --rgb-std=58.393,57.12,57.375 --num-skipped-batches=50 --batch-size=64 --num-inference-batches=500 --dataset=./data/val_256_q90.rec --ctx=cpu  --data-nthreads=1
+python imagenet_inference.py --symbol-file=./model/inceptionv3-quantized-5batches-naive-symbol.json --param-file=./model/inceptionv3-quantized-0000.params --image-shape=3,299,299 --rgb-mean=123.68,116.779,103.939 --rgb-std=58.393,57.12,57.375 --num-skipped-batches=50 --batch-size=64 --num-inference-batches=500 --dataset=./data/val_256_q90.rec --ctx=cpu
 
 # Launch dummy data Inference
 python imagenet_inference.py --symbol-file=./model/inceptionv3-symbol.json --image-shape=3,299,299 --batch-size=64 --num-inference-batches=500 --ctx=cpu  --benchmark=True
@@ -175,10 +175,10 @@ The model would be automatically replaced in fusion and quantization format. It
 export MXNET_SUBGRAPH_BACKEND=MKLDNN
 
 # Launch FP32 Inference 
-python imagenet_inference.py --symbol-file=./model/imagenet1k-resnet-152-symbol.json --param-file=./model/imagenet1k-resnet-152-0000.params --num-skipped-batches=50 --batch-size=64 --num-inference-batches=500 --dataset=./data/val_256_q90.rec --ctx=cpu --data-nthreads=1
+python imagenet_inference.py --symbol-file=./model/imagenet1k-resnet-152-symbol.json --param-file=./model/imagenet1k-resnet-152-0000.params --num-skipped-batches=50 --batch-size=64 --num-inference-batches=500 --dataset=./data/val_256_q90.rec --ctx=cpu
 
 # Launch INT8 Inference
-python imagenet_inference.py --symbol-file=./model/imagenet1k-resnet-152-quantized-5batches-naive-symbol.json --param-file=./model/imagenet1k-resnet-152-quantized-0000.params --num-skipped-batches=50 --batch-size=64 --num-inference-batches=500 --dataset=./data/val_256_q90.rec --ctx=cpu  --data-nthreads=1
+python imagenet_inference.py --symbol-file=./model/imagenet1k-resnet-152-quantized-5batches-naive-symbol.json --param-file=./model/imagenet1k-resnet-152-quantized-0000.params --num-skipped-batches=50 --batch-size=64 --num-inference-batches=500 --dataset=./data/val_256_q90.rec --ctx=cpu
 
 # Launch dummy data Inference
 python imagenet_inference.py --symbol-file=./model/imagenet1k-resnet-152-symbol.json --batch-size=64 --num-inference-batches=500 --ctx=cpu --benchmark=True
@@ -200,10 +200,10 @@ The model would be automatically replaced in fusion and quantization format. It
 export MXNET_SUBGRAPH_BACKEND=MKLDNN
 
 # Launch FP32 Inference 
-python imagenet_inference.py --symbol-file=./model/imagenet1k-inception-bn-symbol.json --param-file=./model/imagenet1k-inception-bn-0000.params --rgb-mean=123.68,116.779,103.939 --num-skipped-batches=50 --batch-size=64 --num-inference-batches=500 --dataset=./data/val_256_q90.rec --ctx=cpu --data-nthreads=1
+python imagenet_inference.py --symbol-file=./model/imagenet1k-inception-bn-symbol.json --param-file=./model/imagenet1k-inception-bn-0000.params --rgb-mean=123.68,116.779,103.939 --num-skipped-batches=50 --batch-size=64 --num-inference-batches=500 --dataset=./data/val_256_q90.rec --ctx=cpu
 
 # Launch INT8 Inference
-python imagenet_inference.py --symbol-file=./model/imagenet1k-inception-bn-quantized-5batches-naive-symbol.json --param-file=./model/imagenet1k-inception-bn-quantized-0000.params --rgb-mean=123.68,116.779,103.939 --num-skipped-batches=50 --batch-size=64 --num-inference-batches=500 --dataset=./data/val_256_q90.rec --ctx=cpu  --data-nthreads=1
+python imagenet_inference.py --symbol-file=./model/imagenet1k-inception-bn-quantized-5batches-naive-symbol.json --param-file=./model/imagenet1k-inception-bn-quantized-0000.params --rgb-mean=123.68,116.779,103.939 --num-skipped-batches=50 --batch-size=64 --num-inference-batches=500 --dataset=./data/val_256_q90.rec --ctx=cpu
 
 # Launch dummy data Inference
 python imagenet_inference.py --symbol-file=./model/imagenet1k-inception-bn-symbol.json --batch-size=64 --num-inference-batches=500 --ctx=cpu --benchmark=True
@@ -244,7 +244,7 @@ Some tips on quantization configs:
 export MXNET_SUBGRAPH_BACKEND=MKLDNN
 
 # Launch FP32 Inference 
-python imagenet_inference.py --symbol-file=./model/custom-symbol.json --param-file=./model/custom-0000.params --rgb-mean=* --rgb-std=* --num-skipped-batches=* --batch-size=* --num-inference-batches=*--dataset=./data/* --ctx=cpu --data-nthreads=1
+python imagenet_inference.py --symbol-file=./model/custom-symbol.json --param-file=./model/custom-0000.params --rgb-mean=* --rgb-std=* --num-skipped-batches=* --batch-size=* --num-inference-batches=* --dataset=./data/* --ctx=cpu
 ```
 
 3. Then, you should add `rgb_mean`, `rgb_std` and `excluded_sym_names` in this script. Notice that you should exclude conv/pool layers that have negative data since Intel® MKL-DNN only supports `uint8` quantization temporarily. You should also exclude all fc layers in your model.
@@ -261,7 +261,7 @@ python imagenet_gen_qsym_mkldnn.py --model=custom --num-calib-batches=5 --calib-
 
 ```
 # Launch INT8 Inference 
-python imagenet_inference.py --symbol-file=./model/*.json --param-file=./model/*.params --rgb-mean=* --rgb-std=* --num-skipped-batches=* --batch-size=* --num-inference-batches=*--dataset=./data/* --ctx=cpu --data-nthreads=1
+python imagenet_inference.py --symbol-file=./model/*.json --param-file=./model/*.params --rgb-mean=* --rgb-std=* --num-skipped-batches=* --batch-size=* --num-inference-batches=* --dataset=./data/* --ctx=cpu
 
 # Launch dummy data Inference
 python imagenet_inference.py --symbol-file=./model/*.json --batch-size=* --num-inference-batches=500 --ctx=cpu --benchmark=True
diff --git a/example/quantization/imagenet_inference.py b/example/quantization/imagenet_inference.py
index 47e2063..e785461 100644
--- a/example/quantization/imagenet_inference.py
+++ b/example/quantization/imagenet_inference.py
@@ -217,45 +217,21 @@ if __name__ == '__main__':
         logger.info('Dataset for inference: %s' % dataset)
 
         # creating data iterator
-        if data_layer_type == 'int8':
-            data = mx.io.ImageRecordInt8Iter(path_imgrec=dataset,
-                                             label_width=1,
-                                             preprocess_threads=data_nthreads,
-                                             batch_size=batch_size,
-                                             data_shape=data_shape,
-                                             label_name=label_name,
-                                             rand_crop=False,
-                                             rand_mirror=False,
-                                             shuffle=args.shuffle_dataset,
-                                             shuffle_chunk_seed=args.shuffle_chunk_seed,
-                                             seed=args.shuffle_seed,
-                                             **combine_mean_std)
-        elif data_layer_type == 'uint8':
-            data = mx.io.ImageRecordUInt8Iter(path_imgrec=dataset,
-                                              label_width=1,
-                                              preprocess_threads=data_nthreads,
-                                              batch_size=batch_size,
-                                              data_shape=data_shape,
-                                              label_name=label_name,
-                                              rand_crop=False,
-                                              rand_mirror=False,
-                                              shuffle=args.shuffle_dataset,
-                                              shuffle_chunk_seed=args.shuffle_chunk_seed,
-                                              seed=args.shuffle_seed,
-                                              **combine_mean_std)
-        else:  #float32
-            data = mx.io.ImageRecordIter(path_imgrec=dataset,
-                                         label_width=1,
-                                         preprocess_threads=data_nthreads,
-                                         batch_size=batch_size,
-                                         data_shape=data_shape,
-                                         label_name=label_name,
-                                         rand_crop=False,
-                                         rand_mirror=False,
-                                         shuffle=args.shuffle_dataset,
-                                         shuffle_chunk_seed=args.shuffle_chunk_seed,
-                                         seed=args.shuffle_seed,
-                                         **combine_mean_std)
+        data = mx.io.ImageRecordIter(
+            path_imgrec=dataset,
+            label_width=1,
+            preprocess_threads=data_nthreads,
+            batch_size=batch_size,
+            data_shape=data_shape,
+            label_name=label_name,
+            rand_crop=False,
+            rand_mirror=False,
+            shuffle=args.shuffle_dataset,
+            shuffle_chunk_seed=args.shuffle_chunk_seed,
+            seed=args.shuffle_seed,
+            dtype=data_layer_type,
+            ctx=args.ctx,
+            **combine_mean_std)
 
         # loading model
         sym, arg_params, aux_params = load_model(symbol_file, param_file, logger)
diff --git a/src/io/image_iter_common.h b/src/io/image_iter_common.h
index 4bbcb9d..4d4b373 100644
--- a/src/io/image_iter_common.h
+++ b/src/io/image_iter_common.h
@@ -346,8 +346,13 @@ struct ImageDetNormalizeParam :  public dmlc::Parameter<ImageDetNormalizeParam>
 
 // Define prefetcher parameters
 struct PrefetcherParam : public dmlc::Parameter<PrefetcherParam> {
+  enum CtxType { kGPU = 0, kCPU};
   /*! \brief number of prefetched batches */
   size_t prefetch_buffer;
+
+  /*! \brief Context data loader optimized for */
+  int ctx;
+
   /*! \brief data type */
   dmlc::optional<int> dtype;
 
@@ -355,6 +360,10 @@ struct PrefetcherParam : public dmlc::Parameter<PrefetcherParam> {
   DMLC_DECLARE_PARAMETER(PrefetcherParam) {
     DMLC_DECLARE_FIELD(prefetch_buffer).set_default(4)
         .describe("Maximum number of batches to prefetch.");
+    DMLC_DECLARE_FIELD(ctx).set_default(kGPU)
+        .add_enum("cpu", kCPU)
+        .add_enum("gpu", kGPU)
+        .describe("Context data loader optimized for.");
     DMLC_DECLARE_FIELD(dtype)
       .add_enum("float32", mshadow::kFloat32)
       .add_enum("float64", mshadow::kFloat64)
@@ -362,6 +371,7 @@ struct PrefetcherParam : public dmlc::Parameter<PrefetcherParam> {
       .add_enum("int64", mshadow::kInt64)
       .add_enum("int32", mshadow::kInt32)
       .add_enum("uint8", mshadow::kUint8)
+      .add_enum("int8", mshadow::kInt8)
       .set_default(dmlc::optional<int>())
       .describe("Output data type. ``None`` means no change.");
   }
diff --git a/src/io/iter_image_recordio_2.cc b/src/io/iter_image_recordio_2.cc
index 0834dd7..5d9e81d 100644
--- a/src/io/iter_image_recordio_2.cc
+++ b/src/io/iter_image_recordio_2.cc
@@ -44,6 +44,7 @@
 #include "../common/utils.h"
 
 namespace mxnet {
+
 namespace io {
 // parser to parse image recordio
 template<typename DType>
@@ -87,7 +88,7 @@ class ImageRecordIOParser2 {
   ImageRecordParam record_param_;
   BatchParam batch_param_;
   ImageNormalizeParam normalize_param_;
-  PrefetcherParam prefetch_param_;
+
   #if MXNET_USE_OPENCV
   /*! \brief augmenters */
   std::vector<std::vector<std::unique_ptr<ImageAugmenter> > > augmenters_;
@@ -133,7 +134,6 @@ inline void ImageRecordIOParser2<DType>::Init(
   record_param_.InitAllowUnknown(kwargs);
   batch_param_.InitAllowUnknown(kwargs);
   normalize_param_.InitAllowUnknown(kwargs);
-  prefetch_param_.InitAllowUnknown(kwargs);
   n_parsed_ = 0;
   overflow = false;
   rnd_.seed(kRandMagic + record_param_.seed);
@@ -141,7 +141,7 @@ inline void ImageRecordIOParser2<DType>::Init(
   #pragma omp parallel
   {
     // be conservative, set number of real cores
-    maxthread = std::max(omp_get_num_procs() / 2 - 1, 1);
+    maxthread = std::max(omp_get_num_procs(), 1);
   }
   param_.preprocess_threads = std::min(maxthread, param_.preprocess_threads);
   #pragma omp parallel num_threads(param_.preprocess_threads)
@@ -763,6 +763,113 @@ class ImageRecordIter2 : public IIterator<DataBatch> {
     ImageRecordIOParser2<DType> parser_;
 };
 
+template<typename DType = real_t>
+class ImageRecordIter2CPU : public IIterator<DataBatch> {
+ public:
+  ImageRecordIter2CPU() {
+    out_ = new DataBatch();
+    var_ = Engine::Get()->NewVariable();
+  }
+
+  virtual ~ImageRecordIter2CPU(void) {
+    Engine::Get()->DeleteVariable([](mxnet::RunContext ctx) {}, Context::CPU(), var_);
+    delete out_;
+  }
+
+  virtual void Init(const std::vector<std::pair<std::string, std::string>>& kwargs) {
+    parser_.Init(kwargs);
+  }
+
+  virtual void BeforeFirst(void) { parser_.BeforeFirst(); }
+
+  // From iter_prefetcher.h
+  virtual bool Next(void) {
+    bool result = false;
+    const auto engine = Engine::Get();
+    engine->PushSync(
+        [this, &result](RunContext ctx) {
+          result = this->parser_.ParseNext(out_);
+        },
+        Context::CPU(), {}, {var_}, FnProperty::kNormal, 0, "DataLoader");
+    engine->WaitForVar(var_);
+    return result;
+  }
+
+  virtual const DataBatch& Value(void) const { return *out_; }
+
+ private:
+  /*! \brief Backend thread */
+  dmlc::ThreadedIter<DataBatch> iter_;
+  /*! \brief output data */
+  DataBatch* out_;
+  Engine::VarHandle var_;
+  /*! \brief queue to be recycled */
+  std::queue<DataBatch*> recycle_queue_;
+  /* \brief parser */
+  ImageRecordIOParser2<DType> parser_;
+};
+
+class ImageRecordIter2Wrapper : public IIterator<DataBatch> {
+ public:
+  ~ImageRecordIter2Wrapper(void) override {
+    if (record_iter_) delete record_iter_;
+  }
+  void Init(const std::vector<std::pair<std::string, std::string>>& kwargs) override {
+    PrefetcherParam prefetch_param;
+    prefetch_param.InitAllowUnknown(kwargs);
+    int dtype = mshadow::kFloat32;
+    if (prefetch_param.dtype.has_value()) {
+      dtype = prefetch_param.dtype.value();
+    }
+    if (prefetch_param.ctx == PrefetcherParam::CtxType::kCPU) {
+      LOG(INFO) << "Create ImageRecordIter2 optimized for CPU backend.";
+      switch (dtype) {
+        case mshadow::kFloat32:
+          record_iter_ = new ImageRecordIter2CPU<float>();
+          break;
+        case mshadow::kUint8:
+          record_iter_ = new ImageRecordIter2CPU<uint8_t>();
+          break;
+        case mshadow::kInt8:
+          record_iter_ = new ImageRecordIter2CPU<int8_t>();
+          break;
+        default:
+          LOG(FATAL) << "unknown dtype for ImageRecordIter2.";
+      }
+    } else {
+      // For gpu
+      switch (dtype) {
+        case mshadow::kFloat32:
+          record_iter_ = new ImageRecordIter2<float>();
+          break;
+        case mshadow::kUint8:
+          record_iter_ = new ImageRecordIter2<uint8_t>();
+          break;
+        case mshadow::kInt8:
+          record_iter_ = new ImageRecordIter2<int8_t>();
+          break;
+        default:
+          LOG(FATAL) << "unknown dtype for ImageRecordIter2.";
+      }
+    }
+    record_iter_->Init(kwargs);
+    }
+
+    void BeforeFirst(void) override {
+      record_iter_->BeforeFirst();
+    }
+
+    // From iter_prefetcher.h
+    bool Next(void) override { return record_iter_->Next(); }
+
+    const DataBatch &Value(void) const override {
+      return record_iter_->Value();
+    }
+
+ private:
+  IIterator<DataBatch>* record_iter_ = nullptr;
+};
+
 MXNET_REGISTER_IO_ITER(ImageRecordIter)
 .describe(R"code(Iterates on image RecordIO files
 
@@ -795,12 +902,14 @@ Example::
 .add_arguments(ListDefaultAugParams())
 .add_arguments(ImageNormalizeParam::__FIELDS__())
 .set_body([]() {
-    return new ImageRecordIter2<real_t>();
+    return new ImageRecordIter2Wrapper();
     });
 
 MXNET_REGISTER_IO_ITER(ImageRecordUInt8Iter)
 .describe(R"code(Iterating on image RecordIO files
 
+.. note:: ImageRecordUInt8Iter is deprecated. Use ImageRecordIter(dtype='uint8') instead.
+
 This iterator is identical to ``ImageRecordIter`` except for using ``uint8`` as
 the data type instead of ``float``.
 
@@ -817,6 +926,8 @@ the data type instead of ``float``.
 MXNET_REGISTER_IO_ITER(ImageRecordInt8Iter)
 .describe(R"code(Iterating on image RecordIO files
 
+.. note:: ``ImageRecordInt8Iter`` is deprecated. Use ImageRecordIter(dtype='int8') instead.
+
 This iterator is identical to ``ImageRecordIter`` except for using ``int8`` as
 the data type instead of ``float``.
 
diff --git a/tests/python/train/test_dtype.py b/tests/python/train/test_dtype.py
index 39bfbcd..47b785c 100644
--- a/tests/python/train/test_dtype.py
+++ b/tests/python/train/test_dtype.py
@@ -65,6 +65,34 @@ def get_iterator_uint8(kv):
 
     return (train, val)
 
+def get_iterator_uint8_with_param(kv, ctx):
+    data_shape = (3, 28, 28)
+
+    train = mx.io.ImageRecordIter(
+        path_imgrec = "data/cifar/train.rec",
+        data_shape  = data_shape,
+        batch_size  = batch_size,
+        rand_crop   = True,
+        rand_mirror = True,
+        num_parts   = kv.num_workers,
+        part_index  = kv.rank,
+        dtype       ='uint8',
+        ctx         = ctx)
+    train = mx.io.PrefetchingIter(train)
+
+    val = mx.io.ImageRecordIter(
+        path_imgrec = "data/cifar/test.rec",
+        rand_crop   = False,
+        rand_mirror = False,
+        data_shape  = data_shape,
+        batch_size  = batch_size,
+        num_parts   = kv.num_workers,
+        part_index  = kv.rank,
+        dtype       ='uint8',
+        ctx         = ctx)
+
+    return (train, val)
+
 def get_iterator_int8(kv):
     data_shape = (3, 28, 28)
 
@@ -89,6 +117,34 @@ def get_iterator_int8(kv):
 
     return (train, val)
 
+def get_iterator_int8_with_param(kv, ctx):
+    data_shape = (3, 28, 28)
+
+    train = mx.io.ImageRecordIter(
+        path_imgrec = "data/cifar/train.rec",
+        data_shape  = data_shape,
+        batch_size  = batch_size,
+        rand_crop   = True,
+        rand_mirror = True,
+        num_parts   = kv.num_workers,
+        part_index  = kv.rank,
+        dtype       ='int8',
+        ctx         = ctx)
+    train = mx.io.PrefetchingIter(train)
+
+    val = mx.io.ImageRecordIter(
+        path_imgrec = "data/cifar/test.rec",
+        rand_crop   = False,
+        rand_mirror = False,
+        data_shape  = data_shape,
+        batch_size  = batch_size,
+        num_parts   = kv.num_workers,
+        part_index  = kv.rank,
+        dtype       = 'int8',
+        ctx         = ctx)
+
+    return (train, val)
+
 def get_iterator_float32(kv):
     data_shape = (3, 28, 28)
 
@@ -214,10 +270,20 @@ def test_cifar10():
     run_cifar10(train, val, use_module=False)
     run_cifar10(train, val, use_module=True)
 
+    for ctx in ("gpu", "cpu"):
+        (train, val) = get_iterator_uint8_with_param(kv, ctx)
+        run_cifar10(train, val, use_module=False)
+        run_cifar10(train, val, use_module=True)
+
     # test int8 input
     (train, val) = get_iterator_int8(kv)
     run_cifar10(train, val, use_module=False)
     run_cifar10(train, val, use_module=True)
 
+    for ctx in ("gpu", "cpu"):
+        (train, val) = get_iterator_int8_with_param(kv, ctx)
+        run_cifar10(train, val, use_module=False)
+        run_cifar10(train, val, use_module=True)
+
 if __name__ == "__main__":
     test_cifar10()