You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2017/12/12 22:13:30 UTC

[GitHub] piiswrong closed pull request #7929: Merge `mxbox` into `gluon.data`

piiswrong closed pull request #7929: Merge `mxbox` into `gluon.data`
URL: https://github.com/apache/incubator-mxnet/pull/7929
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/python/mxnet/gluon/data/__init__.py b/python/mxnet/gluon/data/__init__.py
index 23ae3e9b3b..5ade7ad1c9 100644
--- a/python/mxnet/gluon/data/__init__.py
+++ b/python/mxnet/gluon/data/__init__.py
@@ -25,4 +25,7 @@
 
 from .dataloader import *
 
-from . import vision
+# from . import vision
+from . import vision_dataset as vision
+
+from . import transforms
diff --git a/python/mxnet/gluon/data/transforms/README.md b/python/mxnet/gluon/data/transforms/README.md
new file mode 100644
index 0000000000..644c7335f1
--- /dev/null
+++ b/python/mxnet/gluon/data/transforms/README.md
@@ -0,0 +1,92 @@
+# Transforms
+
+## Generic
+
+### `Compose`
+Function that composes multiple transformations into one.
+
+```python
+from mxbox import transforms
+transform = transforms.Compose([
+    transforms.Scale(256), 
+    transforms.RandomSizedCrop(224),
+    transforms.RandomHorizontalFlip(),
+    transforms.mx.ToNdArray(),
+    transforms.mx.Normalize(mean = [ 0.485, 0.456, 0.406 ],
+                            std  = [ 0.229, 0.224, 0.225 ]),
+])
+```
+
+### `Lambda`
+Given a Python lambda, applies it to the input img and returns it. For example:
+
+```python
+transforms.Lambda(lambda x: x.add(10))
+```
+
+## Transformation on PIL.Image
+
+Note: This part is almost the same as the transformations provided in [torchvision](https://github.com/pytorch/vision#transforms-on-pilimage).
+
+### `Scale(size, interpolation=Image.BILINEAR)`
+
+Rescales the input PIL.Image to the given 'size'.
+
+If 'size' is a 2-element tuple or list in the order of (width, height), it is the exact size to scale to.
+
+If 'size' is a number, it is the size of the smaller edge and the aspect ratio is kept. For example, if height > width, the image is rescaled to width = size and height = size * height / width. Parameters: size - size of the smaller edge; interpolation - default: PIL.Image.BILINEAR.
+
+### `CenterCrop(size)` - center-crops the image to the given size
+
+Crops the given PIL.Image at the center to have a region of the given size. size can be a tuple (target_height, target_width) or an integer, in which case the target will be of a square shape (size, size)
+### `RandomCrop(size, padding=0)`
+
+Crops the given PIL.Image at a random location to have a region of the given size. size can be a tuple (target_height, target_width) or an integer, in which case the target will be of a square shape (size, size). If padding is non-zero, then the image is first zero-padded on each side with padding pixels.
+
+### `RandomHorizontalFlip()`
+
+Randomly horizontally flips the given PIL.Image with a probability of 0.5
+
+### `RandomSizedCrop(size, interpolation=Image.BILINEAR)`
+
+Randomly crops the given PIL.Image to a random size of (0.08 to 1.0) of the original size and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio.
+
+This is popularly used to train the Inception networks. Parameters: size - edge length of the square output (the crop is resized to size x size); interpolation - default: PIL.Image.BILINEAR.
+
+### `Pad(padding, fill=0)`
+
+Pads the given image on each side with padding number of pixels, and the padding pixels are filled with pixel value fill. If a 5x5 image is padded with padding=1 then it becomes 7x7
+
+
+## Transformation on mx.ndarray
+Under namespace `mxbox.transforms.mx`, e.g. `mxbox.transforms.mx.stack()`.
+
+### `stack(sequence, axis=0)`
+
+Stack a sequence of `mx.ndarray` along a new dimension inserted at the specified axis.
+
+```python
+seq = [mx.nd.array(np.zeros([3, 32, 32])) for i in range(10)]
+
+stack(seq, axis=0)  # results in a [10x3x32x32] ndarray
+stack(seq, axis=1)  # results in a [3x10x32x32] ndarray
+
+# sometimes appear in classification labels
+seq = [i for i in range(10)]
+stack(seq) # results in a [10] ndarray
+```
+
+### `ToNdArray(dtype=np.float32)`
+Convert `PIL.Image` or `numpy` to `mx.ndarray`. Default dtype is `np.float32`, which should be compatible with popular graphic cards. If you want to try higher precision, or your card does not support `float32`, you can set it by yourself.
+
+Note: `ToNdArray()` will automatically transpose the image from `HxWxC` to `CxHxW` to fit mxnet preference.
+
+
+```python
+img # [32x32x3]
+ToNdArray()(img) # [3x32x32]
+``` 
+### `Normalize(mean, std=[1, 1, 1])`
+
+Given mean: (R, G, B) and std: (R, G, B), will normalize each channel of the `mx.ndarray`, i.e. channel = (channel - mean) / std.
+
diff --git a/python/mxnet/gluon/data/transforms/__init__.py b/python/mxnet/gluon/data/transforms/__init__.py
new file mode 100644
index 0000000000..04dc79b63e
--- /dev/null
+++ b/python/mxnet/gluon/data/transforms/__init__.py
@@ -0,0 +1,5 @@
+from .general import *
+
+from . import numpyTool as np
+from . import mxnetTool as mx
+
diff --git a/python/mxnet/gluon/data/transforms/general.py b/python/mxnet/gluon/data/transforms/general.py
new file mode 100644
index 0000000000..9541d03044
--- /dev/null
+++ b/python/mxnet/gluon/data/transforms/general.py
@@ -0,0 +1,254 @@
+from __future__ import division
+
+import math
+import random
+
+from PIL import Image, ImageOps
+import PIL
+
+try:
+    import accimage
+except ImportError:
+    accimage = None
+
+from . import numpyTool as np
+import numbers
+import types
+import collections
+
+
+class Compose(object):
+    """Chain a list of transforms into a single callable.
+
+    Parameters
+    ----------
+    transforms : list of callables
+        Transforms to apply, in order; each receives the previous output.
+    """
+
+    def __init__(self, transforms):
+        self.transforms = transforms
+
+    def __call__(self, img):
+        result = img
+        for transform in self.transforms:
+            result = transform(result)
+        return result
+
+
+class Scale(object):
+    """Rescale the input PIL.Image to the given size.
+
+    Parameters
+    ----------
+    size : sequence or int
+        If size is a sequence like (w, h), image will be resized to match (w, h).
+        If size is an int, the smaller edge is resized to it and the aspect ratio is kept.
+    interpolation : Interpolation method (optional)
+        Preferred interpolation, default use ``PIL.Image.BILINEAR`` .
+    """
+
+    def __init__(self, size, interpolation=Image.BILINEAR):
+        # ``collections.Iterable`` is gone since Python 3.10; probing ``__iter__`` is equivalent.
+        assert isinstance(size, int) or (hasattr(size, '__iter__') and len(size) == 2)
+        self.size = size
+        self.interpolation = interpolation
+
+    def __call__(self, img):
+        """Rescale the input PIL.Image to the given size.
+
+        Parameters
+        ----------
+            img : PIL.Image
+                Image to be scaled.
+        """
+        if isinstance(self.size, int):
+            w, h = img.size
+            if (w <= h and w == self.size) or (h <= w and h == self.size):
+                return img
+            if w < h:
+                ow = self.size
+                oh = int(self.size * h / w)
+                return img.resize((ow, oh), self.interpolation)
+            else:
+                oh = self.size
+                ow = int(self.size * w / h)
+                return img.resize((ow, oh), self.interpolation)
+        else:
+            return img.resize(self.size, self.interpolation)
+
+
+class CenterCrop(object):
+    """Cut a centered region out of the given PIL.Image.
+
+    Parameters
+    ----------
+    size : sequence or int
+        Desired output size of the crop as a sequence (h, w).
+        If size is an int, a square crop (size, size) is made.
+    """
+
+    def __init__(self, size):
+        is_scalar = isinstance(size, numbers.Number)
+        self.size = (int(size), int(size)) if is_scalar else size
+
+    def __call__(self, img):
+        """Return the centered crop of ``img``.
+
+        Parameters
+        ----------
+        img : PIL.Image
+            Image to be cropped.
+        """
+        width, height = img.size
+        # the target is stored as (height, width), unlike PIL's (width, height)
+        crop_h, crop_w = self.size
+        left = int(round((width - crop_w) / 2.))
+        top = int(round((height - crop_h) / 2.))
+        return img.crop((left, top, left + crop_w, top + crop_h))
+
+
+class Pad(object):
+    """Pad the given PIL.Image on all sides with the given "pad" value.
+
+    Parameters
+    ----------
+    padding : int or tuple
+        Border width. A 4-tuple pads the left, top, right and bottom borders respectively.
+    fill : int, float, str or tuple
+        Pixel fill value. Default is 0.
+    """
+
+    def __init__(self, padding, fill=0):
+        assert isinstance(padding, (numbers.Number, tuple))  # tuple = per-border padding
+        assert isinstance(fill, (numbers.Number, str, tuple))
+        self.padding = padding
+        self.fill = fill
+
+    def __call__(self, img):
+        """
+        Parameters
+        ----------
+        img : PIL.Image
+            Image to be padded.
+        """
+        return ImageOps.expand(img, border=self.padding, fill=self.fill)
+
+
+class Lambda(object):
+    """Apply a user-defined callable as a transform.
+
+    Parameters
+    ----------
+    lambd : callable
+        Lambda, function or any other callable taking one argument.
+    """
+
+    def __init__(self, lambd):
+        assert callable(lambd)  # also admits builtins, partials and callable objects
+        self.lambd = lambd
+
+    def __call__(self, img):
+        return self.lambd(img)
+
+
+class RandomCrop(object):
+    """Crop the given PIL.Image at a random location.
+
+    Parameters
+    ----------
+    size : sequence or int
+        Desired output size as (h, w). An int yields a square crop (size, size).
+    padding : int or sequence (optional)
+        Zero padding added to each border before cropping. Default is 0, i.e. no padding.
+        A sequence of length 4 pads the left, top, right, bottom borders respectively.
+    """
+
+    def __init__(self, size, padding=0):
+        if isinstance(size, numbers.Number):
+            self.size = (int(size), int(size))
+        else:
+            self.size = size
+        self.padding = padding
+
+    def __call__(self, img):
+        """Return a random crop of ``img`` (a PIL.Image), padding it first if configured."""
+        border = self.padding
+        if isinstance(border, numbers.Number):
+            pad_needed = border > 0
+        else:
+            # A sequence cannot be ordered against 0 on Python 3, and ImageOps.expand only unpacks tuples.
+            border = tuple(border)
+            pad_needed = any(border)
+        if pad_needed:
+            img = ImageOps.expand(img, border=border, fill=0)
+
+        w, h = img.size
+        th, tw = self.size
+        if w == tw and h == th:
+            return img
+        x1 = random.randint(0, w - tw)
+        y1 = random.randint(0, h - th)
+        return img.crop((x1, y1, x1 + tw, y1 + th))
+
+
+class RandomHorizontalFlip(object):
+    """Mirror the given PIL.Image left-to-right with a probability of 0.5."""
+
+    def __call__(self, img):
+        """
+        Parameters
+        ----------
+        img : PIL.Image
+            Image that may be flipped.
+        """
+        if random.random() >= 0.5:
+            return img
+        return img.transpose(Image.FLIP_LEFT_RIGHT)
+
+
+class RandomSizedCrop(object):
+    """Crop the given PIL.Image to random size and aspect ratio.
+
+    A region covering 8% to 100% of the image area, with an aspect ratio
+    between 3/4 and 4/3, is cropped at a random position and then resized
+    to (size, size). If no such region fits within 10 attempts, the image
+    is scaled and center-cropped instead.
+    This is popularly used to train the Inception networks.
+
+    Parameters
+    ----------
+    size:
+        edge length of the square output
+    interpolation:
+        Default: PIL.Image.BILINEAR
+    """
+
+    def __init__(self, size, interpolation=Image.BILINEAR):
+        self.size = size
+        self.interpolation = interpolation
+
+    def __call__(self, img):
+        full_w, full_h = img.size
+        for _ in range(10):
+            target_area = random.uniform(0.08, 1.0) * (full_w * full_h)
+            aspect_ratio = random.uniform(3. / 4, 4. / 3)
+
+            w = int(round(math.sqrt(target_area * aspect_ratio)))
+            h = int(round(math.sqrt(target_area / aspect_ratio)))
+            if random.random() < 0.5:
+                w, h = h, w
+
+            if w > full_w or h > full_h:
+                continue
+
+            x1 = random.randint(0, full_w - w)
+            y1 = random.randint(0, full_h - h)
+            patch = img.crop((x1, y1, x1 + w, y1 + h))
+            assert (patch.size == (w, h))
+            return patch.resize((self.size, self.size), self.interpolation)
+
+        # Fallback
+        scale = Scale(self.size, interpolation=self.interpolation)
+        crop = CenterCrop(self.size)
+        return crop(scale(img))
diff --git a/python/mxnet/gluon/data/transforms/mxnetTool.py b/python/mxnet/gluon/data/transforms/mxnetTool.py
new file mode 100644
index 0000000000..00c3f60866
--- /dev/null
+++ b/python/mxnet/gluon/data/transforms/mxnetTool.py
@@ -0,0 +1,137 @@
+from __future__ import division
+
+import math
+import random
+
+from PIL import Image, ImageOps
+import PIL
+
+try:
+    import accimage
+except ImportError:
+    accimage = None
+
+import numpy as np
+import numbers
+import types
+import collections
+import mxnet as mx
+
+
+def unsequeeze(input, axis):
+    """Insert a length-1 axis at ``axis`` and return the result as an NDArray."""
+    try:
+        dims = input.shape
+        new_shape = dims[:axis] + (1,) + dims[axis:]
+        return mx.nd.array(input.reshape(new_shape))
+    except AttributeError:
+        # input is an integer, special judge for label
+        # (it has no ``shape``), so wrap it as a 1-element array
+        return mx.nd.array([input])
+
+
+def stack(sequence, axis=0):
+    """Stack a non-empty sequence of arrays (or scalar labels) along a new axis.
+
+    Each element gains a length-1 dimension at ``axis`` via ``unsequeeze``; the
+    results are then joined along that same axis. Scalars only support axis=0.
+    """
+    if len(sequence) == 0:
+        raise ValueError("stack expects a non-empty sequence of tensors")
+    seq = [unsequeeze(each, axis) for each in sequence]
+    return mx.nd.concatenate(seq, axis=axis)  # axis was previously ignored (always 0)
+
+
+class ToNdArray(object):
+    """Convert a ``PIL.Image`` or ``numpy.ndarray`` to mx.nd.array.
+
+    Converts a PIL.Image or numpy.ndarray (H x W x C) in the range
+    [0, 255] to a mx.nd.array of shape (C x H x W) in the range [0.0, 255.0].
+
+    Parameters
+    ----------
+    dtype : numpy dtype
+        Element type of the returned array. Default is ``np.float32``.
+    """
+    def __init__(self, dtype=np.float32):
+        self.dtype = dtype
+
+    def __call__(self, pic):
+        """
+        Parameters
+        ----------
+        pic (PIL.Image or numpy.ndarray):
+            Image to be converted to tensor.
+        """
+        if isinstance(pic, np.ndarray):
+            # handle numpy array (HWC -> CHW); honour the configured dtype
+            img = mx.nd.array((pic.transpose((2, 0, 1))))
+            return img.astype(dtype=self.dtype)
+
+        if accimage is not None and isinstance(pic, accimage.Image):
+            nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
+            pic.copyto(nppic)
+            return mx.nd.array((nppic)).astype(dtype=self.dtype)
+
+        # handle PIL Image; np.asarray copies only when it has to
+        if pic.mode == 'I':
+            img = mx.nd.array(np.asarray(pic, np.int32))
+        elif pic.mode == 'I;16':
+            img = mx.nd.array(np.asarray(pic, np.int16))
+        else:
+            img = mx.nd.array(np.array(pic))
+
+        # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
+        if pic.mode == 'YCbCr':
+            nchannel = 3
+        elif pic.mode == 'I;16':
+            nchannel = 1
+        else:
+            nchannel = len(pic.mode)
+
+        # now `img` is mx.nd.array
+        img = img.reshape((pic.size[1], pic.size[0], nchannel))
+        # put it from HWC to CHW format
+        # yikes, this transpose takes 80% of the loading time/CPU
+        img = mx.nd.transpose(img, axes=(2, 0, 1))
+        return img.astype(dtype=self.dtype)
+
+
+class Normalize(object):
+    """Normalize a tensor image with mean and standard deviation.
+
+    Given mean: (R, G, B) and std: (R, G, B),
+    will normalize each channel of the mx.nd.array, i.e.
+    channel = (channel - mean) / std
+
+    Parameters
+    ----------
+        mean (sequence): Sequence of means for R, G, B channels respectively.
+        std (sequence): Sequence of standard deviations for R, G, B channels
+            respectively.
+    """
+
+    def __init__(self, mean, std=(1, 1, 1)):
+        self.mean = mean
+        self.std = std
+
+    def __call__(self, tensor):
+        """
+        Parameters
+        ----------
+        tensor : ndarray
+            Tensor image of size (C, H, W) to be normalized.
+
+        Returns the same ``tensor`` object after updating its channels.
+        """
+        # Augmented assignment instead of explicit ``__isub__``/``__idiv__`` calls
+        # (``__idiv__`` is a Python 2 name); relies on iteration yielding writable
+        # per-channel views, as the original did.
+        for channel, m, s in zip(tensor, self.mean, self.std):
+            channel -= m
+            channel /= s
+        return tensor
+
+
+if __name__ == "__main__":
+    # smoke test: stack ten integer labels into a single NDArray
+    labels = list(range(10))
+    print(stack(labels))
diff --git a/python/mxnet/gluon/data/transforms/numpyTool.py b/python/mxnet/gluon/data/transforms/numpyTool.py
new file mode 100644
index 0000000000..7fb77ef69e
--- /dev/null
+++ b/python/mxnet/gluon/data/transforms/numpyTool.py
@@ -0,0 +1,87 @@
+from __future__ import division
+
+import math
+import random
+from PIL import Image, ImageOps
+import PIL
+
+try:
+    import accimage
+except ImportError:
+    accimage = None
+
+import numpy as np
+import numbers
+import types
+import collections
+
+
+class ToNumpy(object):
+    """Convert a ``PIL.Image`` to a ``numpy.ndarray`` of shape (H x W x C)."""
+
+    def __call__(self, pic):
+        """
+        Parameters
+        ----------
+        pic : PIL.Image
+            Image to be converted.
+
+        Returns
+        -------
+        numpy.ndarray in HWC layout; no channel transpose or rescaling is applied.
+        Raises
+        ------
+        TypeError if ``pic`` is not a PIL image.
+        """
+        if not isinstance(pic, PIL.Image.Image):
+            raise TypeError('Only support PIL image, type %s is invalid' % type(pic))
+
+        # np.asarray copies only when needed (``copy=False`` can raise on NumPy 2)
+        if pic.mode == 'I':
+            img = np.asarray(pic, np.int32)
+        elif pic.mode == 'I;16':
+            img = np.asarray(pic, np.int16)
+        else:
+            img = np.array(pic)
+
+        # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
+        if pic.mode == 'YCbCr':
+            nchannel = 3
+        elif pic.mode == 'I;16':
+            nchannel = 1
+        else:
+            nchannel = len(pic.mode)
+        return img.reshape(pic.size[1], pic.size[0], nchannel)
+
+
+class Normalize(object):
+    """Normalize a tensor image with mean and standard deviation.
+
+    Given mean: (R, G, B) and std: (R, G, B),
+    will normalize each channel of the numpy.ndarray in place, i.e.
+    channel = (channel - mean) / std
+
+    Parameters
+    ----------
+    mean (sequence): Sequence of means for R, G, B channels respectively.
+    std (sequence): Sequence of standard deviations for R, G, B channels
+        respectively.
+    axis (int): Index of the channel axis of the input. Default is 0 (C, H, W).
+    """
+
+    def __init__(self, mean, std=(1, 1, 1), axis=0):
+        self.mean = mean
+        self.std = std
+        self.axis = axis
+
+    def __call__(self, tensor):
+        """
+        Parameters
+        ----------
+        tensor (numpy.ndarray): Floating-point image to be normalized in place.
+        """
+        # np.moveaxis returns a view, so the in-place updates reach ``tensor``.
+        # ``-=``/``/=`` replace ``__idiv__``, which does not exist on Python 3.
+        for channel, m, s in zip(np.moveaxis(tensor, self.axis, 0), self.mean, self.std):
+            channel -= m
+            channel /= s
+        return tensor
diff --git a/python/mxnet/gluon/data/vision_dataset/CIFAR.py b/python/mxnet/gluon/data/vision_dataset/CIFAR.py
new file mode 100644
index 0000000000..c963a437d6
--- /dev/null
+++ b/python/mxnet/gluon/data/vision_dataset/CIFAR.py
@@ -0,0 +1,93 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=
+"""Dataset container."""
+
+import os
+import gzip
+import tarfile
+import struct
+import warnings
+import numpy as np
+
+from .. import dataset
+from ...utils import download, check_sha1
+from .... import nd, image, recordio
+
+from .utils import _DownloadedDataset
+
+# TODO: add cifar 100
+
+class CIFAR10(_DownloadedDataset):
+    """CIFAR10 image classification dataset from `https://www.cs.toronto.edu/~kriz/cifar.html`_.
+
+    Each sample is an image (in 3D NDArray) with shape (32, 32, 3).
+
+    Parameters
+    ----------
+    root : str
+        Path to temp folder for storing data.
+    train : bool
+        Whether to load the training or testing set.
+    transform : function
+        A user defined callback that transforms each instance. For example::
+
+            transform=lambda data, label: (data.astype(np.float32)/255, label)
+    """
+
+    def __init__(self, root='~/.mxnet/datasets/cifar10', train=True,
+                 transform=None):
+        self._file_hashes = {'data_batch_1.bin': 'aadd24acce27caa71bf4b10992e9e7b2d74c2540',
+                             'data_batch_2.bin': 'c0ba65cce70568cd57b4e03e9ac8d2a5367c1795',
+                             'data_batch_3.bin': '1dd00a74ab1d17a6e7d73e185b69dbf31242f295',
+                             'data_batch_4.bin': 'aab85764eb3584312d3c7f65fd2fd016e36a258e',
+                             'data_batch_5.bin': '26e2849e66a845b7f1e4614ae70f4889ae604628',
+                             'test_batch.bin': '67eb016db431130d61cd03c7ad570b013799c88c'}
+        super(CIFAR10, self).__init__(root, train, transform)
+
+    def _read_batch(self, filename):
+        """Parse one binary batch: each record is 1 label byte followed by 3072 pixel bytes."""
+        with open(filename, 'rb') as fin:
+            data = np.frombuffer(fin.read(), dtype=np.uint8).reshape(-1, 3072+1)
+        return data[:, 1:].reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1), \
+               data[:, 0].astype(np.int32)
+
+    def _get_data(self):
+        file_paths = [(name, os.path.join(self._root, 'cifar-10-batches-bin/', name))
+                      for name in self._file_hashes]
+        if any(not os.path.exists(path) or not check_sha1(path, self._file_hashes[name])
+               for name, path in file_paths):
+            url = 'https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz'
+            filename = download(url, self._root,
+                                sha1_hash='e8aa088b9774a44ad217101d2e2569f823d2d491')
+
+            with tarfile.open(filename) as tar:
+                tar.extractall(self._root)
+
+        if self._train:
+            filename = os.path.join(self._root, 'cifar-10-batches-bin/data_batch_%d.bin')
+            data, label = zip(*[self._read_batch(filename%i) for i in range(1, 6)])
+            data = np.concatenate(data)
+            label = np.concatenate(label)
+        else:
+            filename = os.path.join(self._root, 'cifar-10-batches-bin/test_batch.bin')
+            data, label = self._read_batch(filename)
+
+        self._data = nd.array(data, dtype=data.dtype)
+        self._label = label
\ No newline at end of file
diff --git a/python/mxnet/gluon/data/vision_dataset/ImageFolderDataset.py b/python/mxnet/gluon/data/vision_dataset/ImageFolderDataset.py
new file mode 100644
index 0000000000..f291c0df7f
--- /dev/null
+++ b/python/mxnet/gluon/data/vision_dataset/ImageFolderDataset.py
@@ -0,0 +1,99 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=
+"""Dataset container."""
+
+import os
+import gzip
+import tarfile
+import struct
+import warnings
+import numpy as np
+
+from .. import dataset
+from ...utils import download, check_sha1
+from .... import nd, image, recordio
+
+
+class ImageFolderDataset(dataset.Dataset):
+    """Dataset of image files arranged one sub-folder per class, e.g.::
+
+        root/car/0001.jpg
+        root/car/xxxa.jpg
+        root/car/yyyb.jpg
+        root/bus/123.jpg
+        root/bus/023.jpg
+        root/bus/wwww.jpg
+
+    Parameters
+    ----------
+    root : str
+        Path to root directory.
+    flag : {0, 1}, default 1
+        If 0, always convert loaded images to greyscale (1 channel).
+        If 1, always convert loaded images to colored (3 channels).
+    transform : callable
+        Called with (data, label) and must return the transformed sample::
+
+            transform = lambda data, label: (data.astype(np.float32)/255, label)
+
+    Attributes
+    ----------
+    synsets : list
+        Class names in sorted folder order; `synsets[i]` names integer label `i`.
+    items : list of tuples
+        Every accepted image as a (filename, label) pair.
+    """
+
+    def __init__(self, root, flag=1, transform=None):
+        self._root = os.path.expanduser(root)
+        self._flag = flag
+        self._transform = transform
+        self._exts = ['.jpg', '.jpeg', '.png']
+        self._list_iamges(self._root)
+
+    def _list_iamges(self, root):
+        self.synsets = []
+        self.items = []
+
+        for class_name in sorted(os.listdir(root)):
+            class_dir = os.path.join(root, class_name)
+            if not os.path.isdir(class_dir):
+                warnings.warn('Ignoring %s, which is not a directory.' % class_dir, stacklevel=3)
+                continue
+            label = len(self.synsets)
+            self.synsets.append(class_name)
+            for entry in sorted(os.listdir(class_dir)):
+                filename = os.path.join(class_dir, entry)
+                ext = os.path.splitext(filename)[1]
+                if ext.lower() in self._exts:
+                    self.items.append((filename, label))
+                else:
+                    warnings.warn('Ignoring %s of type %s. Only support %s' % (
+                        filename, ext, ', '.join(self._exts)))
+
+    def __getitem__(self, idx):
+        img = image.imread(self.items[idx][0], self._flag)
+        label = self.items[idx][1]
+        if self._transform is None:
+            return img, label
+        return self._transform(img, label)
+
+    def __len__(self):
+        return len(self.items)
diff --git a/python/mxnet/gluon/data/vision_dataset/ImageRecordDataset.py b/python/mxnet/gluon/data/vision_dataset/ImageRecordDataset.py
new file mode 100644
index 0000000000..7091f62ff1
--- /dev/null
+++ b/python/mxnet/gluon/data/vision_dataset/ImageRecordDataset.py
@@ -0,0 +1,64 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=
+"""Dataset container."""
+
+import os
+import gzip
+import tarfile
+import struct
+import warnings
+import numpy as np
+
+from .. import dataset
+from ...utils import download, check_sha1
+from .... import nd, image, recordio
+
+from .utils import _DownloadedDataset
+
+
+class ImageRecordDataset(dataset.RecordFileDataset):
+    """Dataset backed by a RecordIO file whose records are packed images.
+
+    Each sample is the decoded image together with the label from its record header.
+
+    Parameters
+    ----------
+    filename : str
+        Path to rec file.
+    flag : {0, 1}, default 1
+        If 0, always convert images to greyscale.
+        If 1, always convert images to colored (RGB).
+    transform : function
+        A user defined callback that transforms each instance. For example::
+
+            transform=lambda data, label: (data.astype(np.float32)/255, label)
+    """
+    def __init__(self, filename, flag=1, transform=None):
+        super(ImageRecordDataset, self).__init__(filename)
+        self._flag = flag
+        self._transform = transform
+
+    def __getitem__(self, idx):
+        record = super(ImageRecordDataset, self).__getitem__(idx)
+        header, img = recordio.unpack(record)
+        decoded = image.imdecode(img, self._flag)
+        if self._transform is None:
+            return decoded, header.label
+        return self._transform(decoded, header.label)
diff --git a/python/mxnet/gluon/data/vision_dataset/MNIST.py b/python/mxnet/gluon/data/vision_dataset/MNIST.py
new file mode 100644
index 0000000000..82d5a75ade
--- /dev/null
+++ b/python/mxnet/gluon/data/vision_dataset/MNIST.py
@@ -0,0 +1,122 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=
+"""Dataset container."""
+
+import os
+import gzip
+import tarfile
+import struct
+import warnings
+import numpy as np
+
+from .. import dataset
+from ...utils import download, check_sha1
+from .... import nd, image, recordio
+
+from .utils import _DownloadedDataset
+
+
+class MNIST(_DownloadedDataset):
+    """MNIST handwritten digits dataset from `http://yann.lecun.com/exdb/mnist`_.
+
+    Each sample is a uint8 image (3D NDArray of shape (28, 28, 1)) and its int32 label.
+
+    Parameters
+    ----------
+    root : str
+        Path to temp folder for storing data.
+    train : bool
+        Whether to load the training or testing set.
+    transform : function
+        A user defined callback that transforms each instance. For example::
+
+            transform=lambda data, label: (data.astype(np.float32)/255, label)
+    """
+
+    def __init__(self, root='~/.mxnet/datasets/mnist', train=True,
+                 transform=None):
+        self._base_url = 'https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com' \
+                         '/gluon/dataset/mnist/'
+        self._train_data = ('train-images-idx3-ubyte.gz',
+                            '6c95f4b05d2bf285e1bfb0e7960c31bd3b3f8a7d')
+        self._train_label = ('train-labels-idx1-ubyte.gz',
+                             '2a80914081dc54586dbdf242f9805a6b8d2a15fc')
+        self._test_data = ('t10k-images-idx3-ubyte.gz',
+                           'c3a25af1f52dad7f726cce8cacb138654b760d48')
+        self._test_label = ('t10k-labels-idx1-ubyte.gz',
+                            '763e7fa3757d93b0cdec073cef058b2004252c17')
+        super(MNIST, self).__init__(root, train, transform)
+
+    def _get_data(self):
+        """Download the selected split and load it into ``_data``/``_label``."""
+        if self._train:
+            data, label = self._train_data, self._train_label
+        else:
+            data, label = self._test_data, self._test_label
+        # Each entry is a (file name, SHA-1 hex digest) pair.
+        data_file = download(self._base_url + data[0], self._root,
+                             sha1_hash=data[1])
+        label_file = download(self._base_url + label[0], self._root,
+                              sha1_hash=label[1])
+        # np.frombuffer replaces np.fromstring, whose binary mode is deprecated.
+        with gzip.open(label_file, 'rb') as fin:
+            struct.unpack(">II", fin.read(8))  # consume the 8-byte header
+            label = np.frombuffer(fin.read(), dtype=np.uint8).astype(np.int32)
+
+        with gzip.open(data_file, 'rb') as fin:
+            struct.unpack(">IIII", fin.read(16))  # consume the 16-byte header
+            data = np.frombuffer(fin.read(), dtype=np.uint8)
+            data = data.reshape(len(label), 28, 28, 1)
+        self._data = nd.array(data, dtype=data.dtype)
+        self._label = label
+
+
+class FashionMNIST(MNIST):
+    """Fashion-MNIST, Zalando's dataset of article images of fashion products.
+    It is a drop-in replacement of the original MNIST dataset, taken from
+    `https://github.com/zalandoresearch/fashion-mnist`_.
+
+    Each sample is an image (in 3D NDArray) with shape (28, 28, 1).
+
+    Parameters
+    ----------
+    root : str
+        Folder in which the downloaded files are stored.
+    train : bool
+        True selects the training set, False the testing set.
+    transform : function
+        Optional callback applied to every (data, label) pair. For example::
+
+            transform=lambda data, label: (data.astype(np.float32)/255, label)
+    """
+    def __init__(self, root='~/.mxnet/datasets/fashion-mnist', train=True,
+                 transform=None):
+        self._test_label = ('t10k-labels-idx1-ubyte.gz',
+                            '17f9ab60e7257a1620f4ad76bbbaf857c3920701')
+        self._test_data = ('t10k-images-idx3-ubyte.gz',
+                           '626ed6a7c06dd17c0eec72fa3be1740f146a2863')
+        self._train_label = ('train-labels-idx1-ubyte.gz',
+                             '236021d52f1e40852b06a4c3008d8de8aef1e40b')
+        self._train_data = ('train-images-idx3-ubyte.gz',
+                            '0cf37b0d40ed5169c6b3aba31069a9770ac9043d')
+        self._base_url = ('https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com'
+                          '/gluon/dataset/fashion-mnist/')
+        # Bypass MNIST.__init__, which would replace the URL and hashes set above.
+        super(MNIST, self).__init__(root, train, transform)  # pylint: disable=bad-super-call
diff --git a/python/mxnet/gluon/data/vision_dataset/__init__.py b/python/mxnet/gluon/data/vision_dataset/__init__.py
new file mode 100644
index 0000000000..a3137f4313
--- /dev/null
+++ b/python/mxnet/gluon/data/vision_dataset/__init__.py
@@ -0,0 +1,4 @@
+from .CIFAR import CIFAR10
+from .MNIST import MNIST, FashionMNIST
+from .ImageFolderDataset import ImageFolderDataset
+from .ImageRecordDataset import ImageRecordDataset
diff --git a/python/mxnet/gluon/data/vision_dataset/utils.py b/python/mxnet/gluon/data/vision_dataset/utils.py
new file mode 100644
index 0000000000..fd60e253a6
--- /dev/null
+++ b/python/mxnet/gluon/data/vision_dataset/utils.py
@@ -0,0 +1,56 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=
+"""Dataset container."""
+
+import os
+import gzip
+import tarfile
+import struct
+import warnings
+import numpy as np
+
+from .. import dataset
+from ...utils import download, check_sha1
+from .... import nd, image, recordio
+
+
+class _DownloadedDataset(dataset.Dataset):
+    """Base class for downloadable datasets such as MNIST and cifar10."""
+    def __init__(self, root, train, transform):
+        self._data = None
+        self._label = None
+        self._transform = transform
+        self._train = train
+        self._root = os.path.expanduser(root)
+        # Make sure the storage folder exists before _get_data() runs.
+        if not os.path.isdir(self._root):
+            os.makedirs(self._root)
+        self._get_data()
+
+    def __getitem__(self, idx):
+        sample = self._data[idx], self._label[idx]
+        return sample if self._transform is None else self._transform(*sample)
+
+    def __len__(self):
+        return len(self._label)
+
+    def _get_data(self):
+        """Populate ``_data`` and ``_label``; must be implemented by subclasses."""
+        raise NotImplementedError
\ No newline at end of file
diff --git a/python/setup.py b/python/setup.py
index 029b3afa06..288115cc5e 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -28,7 +28,7 @@
 else:
     from setuptools import setup
     from setuptools.extension import Extension
-    kwargs = {'install_requires': ['numpy', 'requests', 'graphviz'], 'zip_safe': False}
+    kwargs = {'install_requires': ['numpy', 'requests', 'graphviz', 'Pillow'], 'zip_safe': False}
 from setuptools import find_packages
 
 with_cython = False


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services