Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/03/03 20:22:46 UTC

[incubator-mxnet] branch master updated: Gluon image-classification example improvement (#9633)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 8780096  Gluon image-classification example improvement (#9633)
8780096 is described below

commit 87800967037f711644216076dd8404f6402ff69c
Author: Joshua Z. Zhang <ch...@gmail.com>
AuthorDate: Sat Mar 3 14:22:41 2018 -0600

    Gluon image-classification example improvement (#9633)
    
    * backup
    
    * backup
    
    * finish
    
    * fix multiple
    
    * fix
    
    * fix
    
    * fix padding
    
    * add more tests
    
    * fix expanduser
---
 example/gluon/data.py                    |  82 ++++++++--------
 example/gluon/image_classification.py    | 162 ++++++++++++++++++++-----------
 python/mxnet/contrib/__init__.py         |   2 +
 python/mxnet/contrib/io.py               |  95 ++++++++++++++++++
 tests/python/unittest/test_contrib_io.py |  46 +++++++++
 5 files changed, 289 insertions(+), 98 deletions(-)

diff --git a/example/gluon/data.py b/example/gluon/data.py
index dc8f12e..c996c9a 100644
--- a/example/gluon/data.py
+++ b/example/gluon/data.py
@@ -19,8 +19,14 @@
 """ data iterator for mnist """
 import os
 import random
+import logging
+logging.basicConfig(level=logging.INFO)
+
 import mxnet as mx
 from mxnet.test_utils import get_cifar10
+from mxnet.gluon.data.vision import ImageFolderDataset
+from mxnet.gluon.data import DataLoader
+from mxnet.contrib.io import DataLoaderIter
 
 def get_cifar10_iterator(batch_size, data_shape, resize=-1, num_parts=1, part_index=0):
     get_cifar10()
@@ -49,50 +55,38 @@ def get_cifar10_iterator(batch_size, data_shape, resize=-1, num_parts=1, part_in
 
     return train, val
 
-
-def get_imagenet_iterator(train_data, val_data, batch_size, data_shape, resize=-1, num_parts=1, part_index=0):
-    train = mx.io.ImageRecordIter(
-        path_imgrec             = train_data,
-        data_shape              = data_shape,
-        mean_r                  = 123.68,
-        mean_g                  = 116.779,
-        mean_b                  = 103.939,
-        std_r                   = 58.395,
-        std_g                   = 57.12,
-        std_b                   = 57.375,
-        preprocess_threads      = 32,
-        shuffle                 = True,
-        batch_size              = batch_size,
-        rand_crop               = True,
-        resize                  = resize,
-        random_mirror           = True,
-        max_random_h            = 36,
-        max_random_s            = 50,
-        max_random_l            = 50,
-        max_random_rotate_angle = 10,
-        max_random_shear_ratio  = 0.1,
-        max_random_aspect_ratio = 0.25,
-        fill_value              = 127,
-        min_random_scale        = 0.533,
-        num_parts               = num_parts,
-        part_index              = part_index)
-
-    val = mx.io.ImageRecordIter(
-        path_imgrec        = val_data,
-        data_shape         = data_shape,
-        mean_r             = 123.68,
-        mean_g             = 116.779,
-        mean_b             = 103.939,
-        std_r              = 58.395,
-        std_g              = 57.12,
-        std_b              = 57.375,
-        preprocess_threads = 32,
-        batch_size         = batch_size,
-        resize             = resize,
-        num_parts          = num_parts,
-        part_index         = part_index)
-
-    return train, val
+def get_imagenet_transforms(data_shape=224, dtype='float32'):
+    def train_transform(image, label):
+        image, _ = mx.image.random_size_crop(image, (data_shape, data_shape), 0.08, (3/4., 4/3.))
+        image = mx.nd.image.random_flip_left_right(image)
+        image = mx.nd.image.to_tensor(image)
+        image = mx.nd.image.normalize(image, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
+        return mx.nd.cast(image, dtype), label
+
+    def val_transform(image, label):
+        image = mx.image.resize_short(image, data_shape + 32)
+        image, _ = mx.image.center_crop(image, (data_shape, data_shape))
+        image = mx.nd.image.to_tensor(image)
+        image = mx.nd.image.normalize(image, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
+        return mx.nd.cast(image, dtype), label
+    return train_transform, val_transform
+
+def get_imagenet_iterator(root, batch_size, num_workers, data_shape=224, dtype='float32'):
+    """Dataset loader with preprocessing."""
+    train_dir = os.path.join(root, 'train')
+    train_transform, val_transform = get_imagenet_transforms(data_shape, dtype)
+    logging.info("Loading image folder %s, this may take a bit long...", train_dir)
+    train_dataset = ImageFolderDataset(train_dir, transform=train_transform)
+    train_data = DataLoader(train_dataset, batch_size, shuffle=True,
+                            last_batch='discard', num_workers=num_workers)
+    val_dir = os.path.join(root, 'val')
+    if not os.path.isdir(os.path.join(os.path.expanduser(root), 'val', 'n01440764')):
+        user_warning = 'Make sure validation images are stored in one subdir per category; a helper script is available at https://git.io/vNQv1'
+        raise ValueError(user_warning)
+    logging.info("Loading image folder %s, this may take a bit long...", val_dir)
+    val_dataset = ImageFolderDataset(val_dir, transform=val_transform)
+    val_data = DataLoader(val_dataset, batch_size, last_batch='keep', num_workers=num_workers)
+    return DataLoaderIter(train_data, dtype), DataLoaderIter(val_data, dtype)
 
 
 class DummyIter(mx.io.DataIter):
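
For readers trying the new pipeline, a minimal usage sketch (the data root path here
is hypothetical; ImageFolderDataset expands "~", and the layout must contain train/
and val/ subdirs with one subdir per class):

    from data import get_imagenet_iterator  # the module patched above

    train_iter, val_iter = get_imagenet_iterator('~/data/imagenet', batch_size=32,
                                                 num_workers=4, data_shape=224,
                                                 dtype='float32')
    batch = next(train_iter)    # an mx.io.DataBatch of NCHW tensors
    print(batch.data[0].shape)  # (32, 3, 224, 224)
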
diff --git a/example/gluon/image_classification.py b/example/gluon/image_classification.py
index 529b977..9acfda5 100644
--- a/example/gluon/image_classification.py
+++ b/example/gluon/image_classification.py
@@ -17,9 +17,8 @@
 
 from __future__ import division
 
-import argparse, time
+import argparse, time, os
 import logging
-logging.basicConfig(level=logging.INFO)
 
 import mxnet as mx
 from mxnet import gluon
@@ -27,26 +26,40 @@ from mxnet.gluon import nn
 from mxnet.gluon.model_zoo import vision as models
 from mxnet import autograd as ag
 from mxnet.test_utils import get_mnist_iterator
+from mxnet.metric import Accuracy, TopKAccuracy, CompositeEvalMetric
+import numpy as np
 
 from data import *
 
+# logging
+logging.basicConfig(level=logging.INFO)
+fh = logging.FileHandler('image-classification.log')
+logger = logging.getLogger()
+logger.addHandler(fh)
+formatter = logging.Formatter('%(message)s')
+fh.setFormatter(formatter)
+fh.setLevel(logging.DEBUG)
+logging.debug('\n%s', '-' * 100)
+formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
+fh.setFormatter(formatter)
+
 # CLI
 parser = argparse.ArgumentParser(description='Train a model for image classification.')
 parser.add_argument('--dataset', type=str, default='cifar10',
-                    help='dataset to use. options are mnist, cifar10, and dummy.')
-parser.add_argument('--train-data', type=str, default='',
-                    help='training record file to use, required for imagenet.')
-parser.add_argument('--val-data', type=str, default='',
-                    help='validation record file to use, required for imagenet.')
+                    help='dataset to use. options are mnist, cifar10, imagenet and dummy.')
+parser.add_argument('--data-dir', type=str, default='',
+                    help='training directory of imagenet images, contains train/val subdirs.')
 parser.add_argument('--batch-size', type=int, default=32,
                     help='training batch size per device (CPU/GPU).')
-parser.add_argument('--num-gpus', type=int, default=0,
-                    help='number of gpus to use.')
-parser.add_argument('--epochs', type=int, default=3,
+parser.add_argument('--num-worker', '-j', dest='num_workers', default=4, type=int,
+                    help='number of dataloader workers.')
+parser.add_argument('--gpus', type=str, default='',
+                    help='IDs of the GPUs to use, e.g. "0,1,2", or empty for CPU only.')
+parser.add_argument('--epochs', type=int, default=120,
                     help='number of training epochs.')
-parser.add_argument('--lr', type=float, default=0.01,
-                    help='learning rate. default is 0.01.')
-parser.add_argument('-momentum', type=float, default=0.9,
+parser.add_argument('--lr', type=float, default=0.1,
+                    help='learning rate. default is 0.1.')
+parser.add_argument('--momentum', type=float, default=0.9,
                     help='momentum value for optimizer, default is 0.9.')
 parser.add_argument('--wd', type=float, default=0.0001,
                     help='weight decay rate. default is 0.0001.')
@@ -62,39 +75,64 @@ parser.add_argument('--batch-norm', action='store_true',
                     help='enable batch normalization or not in vgg. default is false.')
 parser.add_argument('--use-pretrained', action='store_true',
                     help='enable using pretrained model from gluon.')
+parser.add_argument('--prefix', default='', type=str,
+                    help='path to checkpoint prefix, default is current working dir')
+parser.add_argument('--start-epoch', default=0, type=int,
+                    help='starting epoch, 0 for fresh training, > 0 to resume')
+parser.add_argument('--resume', type=str, default='',
+                    help='path to saved weights to resume from')
+parser.add_argument('--lr-factor', default=0.1, type=float,
+                    help='learning rate decay ratio')
+parser.add_argument('--lr-steps', default='30,60,90', type=str,
+                    help='comma-separated list of epochs at which to decay the learning rate')
+parser.add_argument('--dtype', default='float32', type=str,
+                    help='data type, float32 or float16 if applicable')
+parser.add_argument('--save-frequency', default=10, type=int,
+                    help='how often (in epochs) to save the model; the best model is always saved')
 parser.add_argument('--kvstore', type=str, default='device',
                     help='kvstore to use for trainer/module.')
-parser.add_argument('--log-interval', type=int, default=50, help='Number of batches to wait before logging.')
+parser.add_argument('--log-interval', type=int, default=50,
+                    help='Number of batches to wait before logging.')
 parser.add_argument('--profile', action='store_true',
                     help='Option to turn on memory profiling for front-end, '\
                          'and prints out the memory usage by python function at the end.')
 opt = parser.parse_args()
 
-logging.info(opt)
-
+# global variables
+logger.info('Starting new image-classification task: %s', opt)
 mx.random.seed(opt.seed)
-
+model_name = opt.model
 dataset_classes = {'mnist': 10, 'cifar10': 10, 'imagenet': 1000, 'dummy': 1000}
-
 batch_size, dataset, classes = opt.batch_size, opt.dataset, dataset_classes[opt.dataset]
-
-num_gpus = opt.num_gpus
-
+context = [mx.gpu(int(i)) for i in opt.gpus.split(',')] if opt.gpus.strip() else [mx.cpu()]
+num_gpus = len(context)
 batch_size *= max(1, num_gpus)
-context = [mx.gpu(i) for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]
+lr_steps = [int(x) for x in opt.lr_steps.split(',') if x.strip()]
+metric = CompositeEvalMetric([Accuracy(), TopKAccuracy(5)])
 
-model_name = opt.model
+def get_model(model, ctx, opt):
+    """Model initialization."""
+    kwargs = {'ctx': ctx, 'pretrained': opt.use_pretrained, 'classes': classes}
+    if model.startswith('resnet'):
+        kwargs['thumbnail'] = opt.use_thumbnail
+    elif model.startswith('vgg'):
+        kwargs['batch_norm'] = opt.batch_norm
 
-kwargs = {'ctx': context, 'pretrained': opt.use_pretrained, 'classes': classes}
-if model_name.startswith('resnet'):
-    kwargs['thumbnail'] = opt.use_thumbnail
-elif model_name.startswith('vgg'):
-    kwargs['batch_norm'] = opt.batch_norm
+    net = models.get_model(model, **kwargs)
+    if opt.resume:
+        net.load_params(opt.resume)
+    elif not opt.use_pretrained:
+        if model in ['alexnet']:
+            net.initialize(mx.init.Normal())
+        else:
+            net.initialize(mx.init.Xavier(magnitude=2))
+    net.cast(opt.dtype)
+    return net
 
-net = models.get_model(opt.model, **kwargs)
+net = get_model(opt.model, context, opt)
 
 def get_data_iters(dataset, batch_size, num_workers=1, rank=0):
-    # get dataset iterators
+    """get dataset iterators"""
     if dataset == 'mnist':
         train_data, val_data = get_mnist_iterator(batch_size, (1, 28, 28),
                                                   num_parts=num_workers, part_index=rank)
@@ -102,14 +140,12 @@ def get_data_iters(dataset, batch_size, num_workers=1, rank=0):
         train_data, val_data = get_cifar10_iterator(batch_size, (3, 32, 32),
                                                     num_parts=num_workers, part_index=rank)
     elif dataset == 'imagenet':
+        if not opt.data_dir:
+            raise ValueError('A directory containing raw images in train/val subdirs is required for imagenet; please specify "--data-dir"')
         if model_name == 'inceptionv3':
-            train_data, val_data = get_imagenet_iterator(opt.train_data, opt.val_data,
-                                                         batch_size, (3, 299, 299),
-                                                         num_parts=num_workers, part_index=rank)
+            train_data, val_data = get_imagenet_iterator(opt.data_dir, batch_size, opt.num_workers, 299, opt.dtype)
         else:
-            train_data, val_data = get_imagenet_iterator(opt.train_data, opt.val_data,
-                                                         batch_size, (3, 224, 224),
-                                                         num_parts=num_workers, part_index=rank)
+            train_data, val_data = get_imagenet_iterator(opt.data_dir, batch_size, opt.num_workers, 224, opt.dtype)
     elif dataset == 'dummy':
         if model_name == 'inceptionv3':
             train_data, val_data = dummy_iterator(batch_size, (3, 299, 299))
@@ -118,7 +154,7 @@ def get_data_iters(dataset, batch_size, num_workers=1, rank=0):
     return train_data, val_data
 
 def test(ctx, val_data):
-    metric = mx.metric.Accuracy()
+    metric.reset()
     val_data.reset()
     for batch in val_data:
         data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
@@ -129,27 +165,45 @@ def test(ctx, val_data):
         metric.update(label, outputs)
     return metric.get()
 
+def update_learning_rate(lr, trainer, epoch, ratio, steps):
+    """Set the learning rate to the initial value decayed by ratio every N epochs."""
+    new_lr = lr * (ratio ** int(np.sum(np.array(steps) < epoch)))
+    trainer.set_learning_rate(new_lr)
+    return trainer
+
+def save_checkpoint(epoch, top1, best_acc):
+    if opt.save_frequency and (epoch + 1) % opt.save_frequency == 0:
+        fname = os.path.join(opt.prefix, '%s_%d_acc_%.4f.params' % (opt.model, epoch, top1))
+        net.save_params(fname)
+        logger.info('[Epoch %d] Saving checkpoint to %s with Accuracy: %.4f', epoch, fname, top1)
+    if top1 > best_acc[0]:
+        best_acc[0] = top1
+        fname = os.path.join(opt.prefix, '%s_best.params' % (opt.model))
+        net.save_params(fname)
+        logger.info('[Epoch %d] Saving checkpoint to %s with Accuracy: %.4f', epoch, fname, top1)
 
-def train(epochs, ctx):
+def train(opt, ctx):
     if isinstance(ctx, mx.Context):
         ctx = [ctx]
-    net.initialize(mx.init.Xavier(magnitude=2), ctx=ctx)
     kv = mx.kv.create(opt.kvstore)
     train_data, val_data = get_data_iters(dataset, batch_size, kv.num_workers, kv.rank)
+    net.collect_params().reset_ctx(ctx)
     trainer = gluon.Trainer(net.collect_params(), 'sgd',
-                            {'learning_rate': opt.lr, 'wd': opt.wd, 'momentum': opt.momentum},
+                            {'learning_rate': opt.lr, 'wd': opt.wd, 'momentum': opt.momentum,
+                             'multi_precision': True},
                             kvstore = kv)
-    metric = mx.metric.Accuracy()
     loss = gluon.loss.SoftmaxCrossEntropyLoss()
 
-    for epoch in range(epochs):
+    best_acc = [0]
+    for epoch in range(opt.start_epoch, opt.epochs):
+        trainer = update_learning_rate(opt.lr, trainer, epoch, opt.lr_factor, lr_steps)
         tic = time.time()
         train_data.reset()
         metric.reset()
         btic = time.time()
         for i, batch in enumerate(train_data):
-            data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
-            label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
+            data = gluon.utils.split_and_load(batch.data[0].astype(opt.dtype), ctx_list=ctx, batch_axis=0)
+            label = gluon.utils.split_and_load(batch.label[0].astype(opt.dtype), ctx_list=ctx, batch_axis=0)
             outputs = []
             Ls = []
             with ag.record():
@@ -160,23 +214,23 @@ def train(epochs, ctx):
                     # on all GPUs for better speed on multiple GPUs.
                     Ls.append(L)
                     outputs.append(z)
-                for L in Ls:
-                    L.backward()
+                ag.backward(Ls)
             trainer.step(batch.data[0].shape[0])
             metric.update(label, outputs)
             if opt.log_interval and not (i+1)%opt.log_interval:
                 name, acc = metric.get()
-                logging.info('Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f'%(
-                               epoch, i, batch_size/(time.time()-btic), name, acc))
+                logger.info('Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f, %s=%f'%(
+                               epoch, i, batch_size/(time.time()-btic), name[0], acc[0], name[1], acc[1]))
             btic = time.time()
 
         name, acc = metric.get()
-        logging.info('[Epoch %d] training: %s=%f'%(epoch, name, acc))
-        logging.info('[Epoch %d] time cost: %f'%(epoch, time.time()-tic))
+        logger.info('[Epoch %d] training: %s=%f, %s=%f'%(epoch, name[0], acc[0], name[1], acc[1]))
+        logger.info('[Epoch %d] time cost: %f'%(epoch, time.time()-tic))
         name, val_acc = test(ctx, val_data)
-        logging.info('[Epoch %d] validation: %s=%f'%(epoch, name, val_acc))
+        logger.info('[Epoch %d] validation: %s=%f, %s=%f'%(epoch, name[0], val_acc[0], name[1], val_acc[1]))
 
-    net.save_params('image-classifier-%s-%d.params'%(opt.model, epochs))
+        # save the checkpoint if it is due or improves on the best accuracy
+        save_checkpoint(epoch, val_acc[0], best_acc)
 
 def main():
     if opt.mode == 'symbolic':
@@ -193,13 +247,13 @@ def main():
                 batch_end_callback = mx.callback.Speedometer(batch_size, max(1, opt.log_interval)),
                 epoch_end_callback = mx.callback.do_checkpoint('image-classifier-%s'% opt.model),
                 optimizer = 'sgd',
-                optimizer_params = {'learning_rate': opt.lr, 'wd': opt.wd, 'momentum': opt.momentum},
+                optimizer_params = {'learning_rate': opt.lr, 'wd': opt.wd, 'momentum': opt.momentum, 'multi_precision': True},
                 initializer = mx.init.Xavier(magnitude=2))
         mod.save_params('image-classifier-%s-%d-final.params'%(opt.model, opt.epochs))
     else:
         if opt.mode == 'hybrid':
             net.hybridize()
-        train(opt.epochs, context)
+        train(opt, context)
 
 if __name__ == '__main__':
     if opt.profile:
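
The schedule in update_learning_rate multiplies the base rate by lr_factor once for
every entry of --lr-steps the current epoch has passed; the comparison is strict, so
each drop takes effect the epoch after the listed step. A quick standalone check with
the defaults (--lr 0.1, --lr-factor 0.1, --lr-steps 30,60,90):

    import numpy as np

    lr, ratio, steps = 0.1, 0.1, [30, 60, 90]
    for epoch in (0, 30, 31, 65, 100):
        print(epoch, lr * (ratio ** int(np.sum(np.array(steps) < epoch))))
    # 0 -> 0.1, 30 -> 0.1, 31 -> 0.01, 65 -> 0.001, 100 -> 0.0001
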
diff --git a/python/mxnet/contrib/__init__.py b/python/mxnet/contrib/__init__.py
index 21c7771..36ee213 100644
--- a/python/mxnet/contrib/__init__.py
+++ b/python/mxnet/contrib/__init__.py
@@ -28,3 +28,5 @@ from . import autograd
 from . import tensorboard
 
 from . import text
+
+from . import io
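
With this change the new module ships under the contrib namespace, so the iterator
is importable as:

    from mxnet.contrib.io import DataLoaderIter
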
diff --git a/python/mxnet/contrib/io.py b/python/mxnet/contrib/io.py
new file mode 100644
index 0000000..6020b3e
--- /dev/null
+++ b/python/mxnet/contrib/io.py
@@ -0,0 +1,95 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+"""Contrib data iterators for common data formats."""
+from __future__ import absolute_import
+from ..io import DataIter, DataDesc
+from .. import ndarray as nd
+
+
+class DataLoaderIter(DataIter):
+    """Returns an iterator for ``mx.gluon.data.Dataloader`` so gluon dataloader
+    can be used in symbolic module.
+
+    Parameters
+    ----------
+    loader : mxnet.gluon.data.DataLoader
+        The Gluon DataLoader instance to wrap.
+    data_name : str, optional
+        The data name.
+    label_name : str, optional
+        The label name.
+    dtype : str, optional
+        The dtype specifier; can be 'float32' or 'float16'.
+
+    Examples
+    --------
+    >>> import mxnet as mx
+    >>> from mxnet.gluon.data.vision import MNIST
+    >>> from mxnet.gluon.data import DataLoader
+    >>> train_dataset = MNIST(train=True)
+    >>> train_data = mx.gluon.data.DataLoader(train_dataset, 32, shuffle=True, num_workers=4)
+    >>> dataiter = mx.contrib.io.DataLoaderIter(train_data)
+    >>> for batch in dataiter:
+    ...     batch.data[0].shape
+    ...
+    (32L, 28L, 28L, 1L)
+    """
+    def __init__(self, loader, data_name='data', label_name='softmax_label', dtype='float32'):
+        super(DataLoaderIter, self).__init__()
+        self._loader = loader
+        self._iter = iter(self._loader)
+        data, label = next(self._iter)
+        self.batch_size = data.shape[0]
+        self.dtype = dtype
+        self.provide_data = [DataDesc(data_name, data.shape, dtype)]
+        self.provide_label = [DataDesc(label_name, label.shape, dtype)]
+        self._current_batch = None
+        self.reset()
+
+    def reset(self):
+        self._iter = iter(self._loader)
+
+    def iter_next(self):
+        try:
+            self._current_batch = next(self._iter)
+        except StopIteration:
+            self._current_batch = None
+        return self._current_batch is not None
+
+    def getdata(self):
+        if self.getpad():
+            dshape = self._current_batch[0].shape
+            ret = nd.empty(shape=([self.batch_size] + list(dshape[1:])))
+            ret[:dshape[0]] = self._current_batch[0].astype(self.dtype)
+            return [ret]
+        return [self._current_batch[0].astype(self.dtype)]
+
+    def getlabel(self):
+        if self.getpad():
+            lshape = self._current_batch[1].shape
+            ret = nd.empty(shape=([self.batch_size] + list(lshape[1:])))
+            ret[:lshape[0]] = self._current_batch[1].astype(self.dtype)
+            return [ret]
+        return [self._current_batch[1].astype(self.dtype)]
+
+    def getpad(self):
+        return self.batch_size - self._current_batch[0].shape[0]
+
+    def getindex(self):
+        return None
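
Since the point of DataLoaderIter is to feed a Gluon DataLoader into the symbolic
Module API, here is a minimal end-to-end sketch (the small softmax model is
illustrative only, not part of this commit):

    import mxnet as mx
    from mxnet.gluon.data.vision import MNIST
    from mxnet.gluon.data import DataLoader
    from mxnet.contrib.io import DataLoaderIter

    # a toy symbolic model; DataLoaderIter provides 'data' and 'softmax_label'
    data = mx.sym.var('data')
    fc = mx.sym.FullyConnected(mx.sym.flatten(data), num_hidden=10)
    sym = mx.sym.SoftmaxOutput(fc, name='softmax')

    train_iter = DataLoaderIter(DataLoader(MNIST(train=True), 32, last_batch='discard'))
    mod = mx.mod.Module(sym, context=mx.cpu())
    mod.fit(train_iter, num_epoch=1, optimizer_params={'learning_rate': 0.1})
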
diff --git a/tests/python/unittest/test_contrib_io.py b/tests/python/unittest/test_contrib_io.py
new file mode 100644
index 0000000..dbae69f
--- /dev/null
+++ b/tests/python/unittest/test_contrib_io.py
@@ -0,0 +1,46 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet.ndarray as nd
+from mxnet.gluon.data.vision.datasets import *
+from mxnet.gluon.data.dataloader import *
+from mxnet.contrib.io import *
+from mxnet.test_utils import *
+
+def test_contrib_DataLoaderIter():
+    def test_mnist_batches(batch_size, expected, last_batch='discard'):
+        dataset = MNIST(train=False)
+        dataloader = DataLoader(dataset, batch_size, last_batch=last_batch)
+        test_iter = DataLoaderIter(dataloader)
+        batch = next(test_iter)
+        assert batch.data[0].shape == (batch_size, 28, 28, 1)
+        assert batch.label[0].shape == (batch_size,)
+        count = 0
+        test_iter.reset()
+        for batch in test_iter:
+            count += 1
+        assert count == expected, "expected {} batches, got {}".format(expected, count)
+
+    num_examples = 10000
+    test_mnist_batches(50, num_examples // 50, 'discard')
+    test_mnist_batches(31, num_examples // 31, 'discard')
+    test_mnist_batches(31, num_examples // 31, 'rollover')
+    test_mnist_batches(31, num_examples // 31 + 1, 'keep')
+
+
+if __name__ == "__main__":
+    test_contrib_DataLoaderIter()
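
The last_batch='keep' case is the one that exercises the padding path in
DataLoaderIter: with 10000 test images and batch size 31, the final batch holds
10000 % 31 = 18 real samples, so it is padded and reports batch.pad == 13. A quick
check (same imports as the test above):

    it = DataLoaderIter(DataLoader(MNIST(train=False), 31, last_batch='keep'))
    pads = [batch.pad for batch in it]
    assert pads[-1] == 31 - 10000 % 31  # 13 padded rows in the final batch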

-- 
To stop receiving notification emails like this one, please contact
zhasheng@apache.org.