Posted to commits@mxnet.apache.org by jx...@apache.org on 2018/01/22 20:43:49 UTC

[incubator-mxnet] 18/20: add comments and sanity check (#8901)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit b9569ee7175c21faca6bf82a5d31a5b4a829c03f
Author: Eric Junyuan Xie <pi...@users.noreply.github.com>
AuthorDate: Thu Nov 30 11:37:24 2017 -0800

    add comments and sanity check (#8901)
---
 python/mxnet/gluon/data/dataset.py           |  56 ++++++-
 python/mxnet/gluon/data/vision/transforms.py | 229 +++++++++++++++++++++------
 tests/python/unittest/test_gluon_data.py     |  21 +++
 3 files changed, 251 insertions(+), 55 deletions(-)

diff --git a/python/mxnet/gluon/data/dataset.py b/python/mxnet/gluon/data/dataset.py
index f7ab395..4b97e43 100644
--- a/python/mxnet/gluon/data/dataset.py
+++ b/python/mxnet/gluon/data/dataset.py
@@ -41,12 +41,53 @@ class Dataset(object):
         raise NotImplementedError
 
     def transform(self, fn, lazy=True):
+        """Returns a new dataset with each sample transformed by the
+        transformer function `fn`.
+
+        Parameters
+        ----------
+        fn : callable
+            A transformer function that takes a sample as input and
+            returns the transformed sample.
+        lazy : bool, default True
+            If False, transforms all samples at once. Otherwise,
+            transforms each sample on demand. Note that if `fn`
+            is stochastic, you must set lazy to True or you will
+            get the same result on all epochs.
+
+        Returns
+        -------
+        Dataset
+            The transformed dataset.
+        """
         trans = _LazyTransformDataset(self, fn)
         if lazy:
             return trans
         return SimpleDataset([i for i in trans])
 
     def transform_first(self, fn, lazy=True):
+        """Returns a new dataset with the first element of each sample
+        transformed by the transformer function `fn`.
+
+        This is useful, for example, when you only want to transform data
+        while keeping the label as is.
+
+        Parameters
+        ----------
+        fn : callable
+            A transformer function that takes the first element of a sample
+            as input and returns the transformed element.
+        lazy : bool, default True
+            If False, transforms all samples at once. Otherwise,
+            transforms each sample on demand. Note that if `fn`
+            is stochastic, you must set lazy to True or you will
+            get the same result on all epochs.
+
+        Returns
+        -------
+        Dataset
+            The transformed dataset.
+        """
         def base_fn(x, *args):
             if args:
                 return (fn(x),) + args
@@ -55,6 +96,13 @@ class Dataset(object):
 
 
 class SimpleDataset(Dataset):
+    """Simple Dataset wrapper for lists and arrays.
+
+    Parameters
+    ----------
+    data : dataset-like object
+        Any object that implements `len()` and `[]`.
+    """
     def __init__(self, data):
         self._data = data
 
@@ -66,6 +114,7 @@ class SimpleDataset(Dataset):
 
 
 class _LazyTransformDataset(Dataset):
+    """Lazily transformed dataset."""
     def __init__(self, data, fn):
         self._data = data
         self._fn = fn
@@ -81,13 +130,14 @@ class _LazyTransformDataset(Dataset):
 
 
 class ArrayDataset(Dataset):
-    """A dataset of multiple arrays.
+    """A dataset that combines multiple dataset-like objects, e.g.
+    Datasets, lists, and arrays.
 
-    The i-th sample is `(x1[i], x2[i], ...)`.
+    The i-th sample is defined as `(x1[i], x2[i], ...)`.
 
     Parameters
     ----------
-    *args : one or more arrays
+    *args : one or more dataset-like objects
         The data arrays.
     """
     def __init__(self, *args):
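
For reference, a minimal sketch of how the `ArrayDataset`, `transform`, and
`transform_first` APIs documented above fit together; the arrays, shapes, and
the doubling function below are illustrative only, not part of this patch:

    import mxnet as mx
    from mxnet.gluon.data import ArrayDataset

    # Combine two dataset-like objects; the i-th sample is (data[i], label[i]).
    data = mx.nd.random.uniform(shape=(10, 3, 32, 32))
    label = mx.nd.arange(10)
    dataset = ArrayDataset(data, label)

    # Transform only the data element of each sample, keeping the label as is.
    # lazy=True (the default) applies fn on demand; a stochastic fn requires it,
    # otherwise every epoch would see the same pre-computed result.
    scaled = dataset.transform_first(lambda x: x * 2, lazy=True)
    sample_data, sample_label = scaled[0]
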
diff --git a/python/mxnet/gluon/data/vision/transforms.py b/python/mxnet/gluon/data/vision/transforms.py
index 931d644..8daf88e 100644
--- a/python/mxnet/gluon/data/vision/transforms.py
+++ b/python/mxnet/gluon/data/vision/transforms.py
@@ -20,11 +20,18 @@
 from .. import dataset
 from ...block import Block, HybridBlock
 from ...nn import Sequential, HybridSequential
-from .... import ndarray, initializer
-from ....base import _Null
+from .... import ndarray, initializer, image
+from ....base import _Null, numeric_types
 
 
 class Compose(Sequential):
+    """Sequentially composes multiple transforms.
+
+    Parameters
+    ----------
+    transforms : list of transform Blocks.
+        The list of transforms to be composed.
+    """
     def __init__(self, transforms):
         super(Compose, self).__init__()
         transforms.append(None)
@@ -34,18 +41,25 @@ class Compose(Sequential):
                 hybrid.append(i)
                 continue
             elif len(hybrid) == 1:
-                self.register_child(hybrid[0])
+                self.add(hybrid[0])
             elif len(hybrid) > 1:
                 hblock = HybridSequential()
                 for j in hybrid:
                     hblock.add(j)
-                self.register_child(hblock)
+                self.add(hblock)
             if i is not None:
-                self.register_child(i)
+                self.add(i)
         self.hybridize()
 
 
 class Cast(HybridBlock):
+    """Cast input to a specific data type
+
+    Parameters
+    ----------
+    dtype : str, default 'float32'
+        The target data type, in string or `numpy.dtype`.
+    """
     def __init__(self, dtype='float32'):
         super(Cast, self).__init__()
         self._dtype = dtype
@@ -55,6 +69,12 @@ class Cast(HybridBlock):
 
 
 class ToTensor(HybridBlock):
+    """Converts an image NDArray to a tensor NDArray.
+
+    Converts an image NDArray of shape (H x W x C) in the range
+    [0, 255] to a float32 tensor NDArray of shape (C x H x W) in
+    the range [0, 1).
+    """
     def __init__(self):
         super(ToTensor, self).__init__()
 
@@ -63,6 +83,23 @@ class ToTensor(HybridBlock):
 
 
 class Normalize(HybridBlock):
+    """Normalize a tensor of shape (C x H x W) with mean and
+    standard deviation.
+
+    Given mean `(m1, ..., mn)` and std `(s1, ..., sn)` for `n` channels,
+    this transform normalizes each channel of the input tensor with::
+
+        output[i] = (input[i] - mi) / si
+
+    If mean or std is a scalar, the same value will be applied to all channels.
+
+    Parameters
+    ----------
+    mean : float or tuple of floats
+        The mean values.
+    std : float or tuple of floats
+        The standard deviation values.
+    """
     def __init__(self, mean, std):
         super(Normalize, self).__init__()
         self._mean = mean
@@ -72,101 +109,189 @@ class Normalize(HybridBlock):
         return F.image.normalize(x, self._mean, self._std)
 
 
-class RandomResizedCrop(HybridBlock):
-    def __init__(self, size, area=(0.08, 1.0), ratio=(3.0/4.0, 4.0/3.0),
+class RandomResizedCrop(Block):
+    """Crop the input image with random scale and aspect ratio.
+
+    Makes a crop of the original image with random size (default: 0.08
+    to 1.0 of the original image size) and random aspect ratio (default:
+    3/4 to 4/3), then resizes it to the specified size.
+
+    Parameters
+    ----------
+    size : int or tuple of (W, H)
+        Size of the final output.
+    scale : tuple of two floats
+        If scale is `(min_area, max_area)`, the cropped image's area will
+        range from min_area to max_area of the original image's area.
+    ratio : tuple of two floats
+        Range of aspect ratio of the cropped image before resizing.
+    interpolation : int
+        Interpolation method for resizing. By default uses bilinear
+        interpolation. See OpenCV's resize function for available choices.
+    """
+    def __init__(self, size, scale=(0.08, 1.0), ratio=(3.0/4.0, 4.0/3.0),
                  interpolation=2):
         super(RandomResizedCrop, self).__init__()
-        self._args = (size, area, ratio, interpolation)
-
-    def hybrid_forward(self, F, x):
-        return F.image.random_resized_crop(x, *self._args)
-
-
-class CenterCrop(HybridBlock):
-    def __init__(self, size):
+        if isinstance(size, numeric_types):
+            size = (size, size)
+        self._args = (size, scale[0], ratio, interpolation)
+
+    def forward(self, x):
+        return image.random_size_crop(x, *self._args)[0]
+
+
+class CenterCrop(Block):
+    """Crops the image `src` to the given `size` by trimming on all four
+    sides and preserving the center of the image. Upsamples if `src` is
+    smaller than `size`.
+
+    Parameters
+    ----------
+    size : int or tuple of (W, H)
+        Size of output image.
+    interpolation : int
+        Interpolation method for resizing. By default uses bilinear
+        interpolation. See OpenCV's resize function for available choices.
+    """
+    def __init__(self, size, interpolation=2):
         super(CenterCrop, self).__init__()
-        self._size = size
+        if isinstance(size, numeric_types):
+            size = (size, size)
+        self._args = (size, interpolation)
+
+    def forward(self, x):
+        return image.center_crop(x, *self._args)[0]
 
-    def hybrid_forward(self, F, x):
-        return F.image.center_crop(x, size)
 
+class Resize(Block):
+    """Resize an image to the given size.
 
-class Resize(HybridBlock):
+    Parameters
+    ----------
+    size : int or tuple of (W, H)
+        Size of output image.
+    interpolation : int
+        Interpolation method for resizing. By default uses bilinear
+        interpolation. See OpenCV's resize function for available choices.
+    """
     def __init__(self, size, interpolation=2):
         super(Resize, self).__init__()
-        self._args = (size, interpolation)
+        if isinstance(size, numeric_types):
+            size = (size, size)
+        self._args = tuple(size) + (interpolation,)
+
+    def forward(self, x):
+        return image.imresize(x, *self._args)
+
+
+class RandomHorizontalFlip(HybridBlock):
+    """Randomly flip the input image horizontally with a probability
+    of 0.5.
+    """
+    def __init__(self):
+        super(RandomHorizontalFlip, self).__init__()
 
     def hybrid_forward(self, F, x):
-        return F.image.resize(x, *self._args)
+        return F.image.random_horizontal_flip(x)
 
 
-class RandomFlip(HybridBlock):
-    def __init__(self, axis=1):
-        super(RandomFlip, self).__init__()
-        self._axis = axis
+class RandomVerticalFlip(HybridBlock):
+    """Randomly flip the input image vertically with a probability
+    of 0.5.
+    """
+    def __init__(self):
+        super(RandomVerticalFlip, self).__init__()
 
     def hybrid_forward(self, F, x):
-        return F.image.random_flip(x, self._axis)
+        return F.image.random_vertical_flip(x)
 
 
 class RandomBrightness(HybridBlock):
-    def __init__(self, max_brightness):
+    """Randomly jitters image brightness with a factor
+    chosen from `[max(0, 1 - brightness), 1 + brightness]`.
+    """
+    def __init__(self, brightness):
         super(RandomBrightness, self).__init__()
-        self._max_brightness = max_brightness
+        self._args = (max(0, 1-brightness), 1+brightness)
 
     def hybrid_forward(self, F, x):
-        return F.image.random_brightness(x, self._max_brightness)
+        return F.image.random_brightness(x, *self._args)
 
 
 class RandomContrast(HybridBlock):
-    def __init__(self, max_contrast):
+    """Randomly jitters image contrast with a factor
+    chosen from `[max(0, 1 - contrast), 1 + contrast]`.
+    """
+    def __init__(self, contrast):
         super(RandomContrast, self).__init__()
-        self._max_contrast = max_contrast
+        self._args = (max(0, 1-contrast), 1+contrast)
 
     def hybrid_forward(self, F, x):
-        return F.image.random_contrast(x, self._max_contrast)
+        return F.image.random_contrast(x, *self._args)
 
 
 class RandomSaturation(HybridBlock):
-    def __init__(self, max_saturation):
+    """Randomly jitters image saturation with a factor
+    chosen from `[max(0, 1 - saturation), 1 + saturation]`.
+    """
+    def __init__(self, saturation):
         super(RandomSaturation, self).__init__()
-        self._max_saturation = max_saturation
+        self._args = (max(0, 1-saturation), 1+saturation)
 
     def hybrid_forward(self, F, x):
-        return F.image.random_saturation(x, self._max_saturation)
+        return F.image.random_saturation(x, *self._args)
 
 
 class RandomHue(HybridBlock):
-    def __init__(self, max_hue):
+    """Randomly jitters image hue with a factor
+    chosen from `[max(0, 1 - hue), 1 + hue]`.
+    """
+    def __init__(self, hue):
         super(RandomHue, self).__init__()
-        self._max_hue = max_hue
+        self._args = (max(0, 1-hue), 1+hue)
 
     def hybrid_forward(self, F, x):
-        return F.image.random_hue(x, self._max_hue)
+        return F.image.random_hue(x, *self._args)
 
 
 class RandomColorJitter(HybridBlock):
-    def __init__(self, max_brightness=0, max_contrast=0, max_saturation=0, max_hue=0):
+    """Randomly jitters the brightness, contrast, saturation, and hue
+    of an image.
+
+    Parameters
+    ----------
+    brightness : float
+        How much to jitter brightness. The brightness factor is randomly
+        chosen from `[max(0, 1 - brightness), 1 + brightness]`.
+    contrast : float
+        How much to jitter contrast. The contrast factor is randomly
+        chosen from `[max(0, 1 - contrast), 1 + contrast]`.
+    saturation : float
+        How much to jitter saturation. The saturation factor is randomly
+        chosen from `[max(0, 1 - saturation), 1 + saturation]`.
+    hue : float
+        How much to jitter hue. The hue factor is randomly
+        chosen from `[max(0, 1 - hue), 1 + hue]`.
+    """
+    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
         super(RandomColorJitter, self).__init__()
-        self._args = (max_brightness, max_contrast, max_saturation, max_hue)
+        self._args = (brightness, contrast, saturation, hue)
 
     def hybrid_forward(self, F, x):
         return F.image.random_color_jitter(x, *self._args)
 
 
-class AdjustLighting(HybridBlock):
-    def __init__(self, alpha_rgb=_Null, eigval=_Null, eigvec=_Null):
-        super(AdjustLighting, self).__init__()
-        self._args = (alpha_rgb, eigval, eigvec)
-
-    def hybrid_forward(self, F, x):
-        return F.image.adjust_lighting(x, *self._args)
-
-
 class RandomLighting(HybridBlock):
-    def __init__(self, alpha_std=_Null, eigval=_Null, eigvec=_Null):
+    """Add AlexNet-style PCA-based noise to an image.
+
+    Parameters
+    ----------
+    alpha : float
+        Intensity of the PCA-based noise.
+    """
+    def __init__(self, alpha):
         super(RandomLighting, self).__init__()
-        self._args = (alpha_std, eigval, eigvec)
+        self._alpha = alpha
 
     def hybrid_forward(self, F, x):
-        return F.image.random_lighting(x, *self._args)
\ No newline at end of file
+        return F.image.random_lighting(x, self._alpha)
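
As a hedged illustration of the sanity checks and documented conversions above;
the image shape, crop sizes, and normalization statistics are made up for this
sketch and are not part of the patch:

    import mxnet as mx
    from mxnet.gluon.data.vision import transforms

    # A fake uint8 image in HWC layout with values in [0, 255]; H=64, W=48.
    img = mx.nd.random.uniform(0, 255, shape=(64, 48, 3)).astype('uint8')

    # An int size is expanded to (size, size) by the new sanity check.
    resized = transforms.Resize(32)(img)            # shape (32, 32, 3)
    cropped = transforms.CenterCrop((24, 16))(img)  # size is (W, H) -> shape (16, 24, 3)

    # ToTensor: HWC uint8 in [0, 255] -> CHW float32 in [0, 1).
    tensor = transforms.ToTensor()(img)             # shape (3, 64, 48)

    # Normalize: output[i] = (input[i] - mean[i]) / std[i] per channel.
    normalized = transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                      std=(0.229, 0.224, 0.225))(tensor)
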
diff --git a/tests/python/unittest/test_gluon_data.py b/tests/python/unittest/test_gluon_data.py
index 63c5d28..c72ef7c 100644
--- a/tests/python/unittest/test_gluon_data.py
+++ b/tests/python/unittest/test_gluon_data.py
@@ -107,6 +107,27 @@ def test_multi_worker():
         assert (batch.asnumpy() == i).all()
 
 
+def test_transformer():
+    from mxnet.gluon.data.vision import transforms
+
+    transform = transforms.Compose([
+        transforms.Resize(300),
+        transforms.CenterCrop(256),
+        transforms.RandomResizedCrop(224),
+        transforms.RandomHorizontalFlip(),
+        transforms.RandomColorJitter(0.1, 0.1, 0.1, 0.1),
+        transforms.RandomBrightness(0.1),
+        transforms.RandomContrast(0.1),
+        transforms.RandomSaturation(0.1),
+        transforms.RandomHue(0.1),
+        transforms.RandomLighting(0.1),
+        transforms.ToTensor(),
+        transforms.Normalize([0, 0, 0], [1, 1, 1])])
+
+    transform(mx.nd.ones((245, 480, 3), dtype='uint8')).wait_to_read()
+
+
 if __name__ == '__main__':
+    test_transformer()
     import nose
     nose.runmodule()
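
As a usage note, a composed transform like the one in the test above is
typically attached to a dataset with `transform_first` and batched with a
`DataLoader`; the shapes, labels, and batch size below are illustrative:

    import mxnet as mx
    from mxnet.gluon.data import ArrayDataset, DataLoader
    from mxnet.gluon.data.vision import transforms

    transform = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor()])

    # Eight fake uint8 HWC images with integer labels.
    images = mx.nd.random.uniform(0, 255, shape=(8, 64, 64, 3)).astype('uint8')
    labels = list(range(8))
    dataset = ArrayDataset(images, labels).transform_first(transform)

    loader = DataLoader(dataset, batch_size=4)
    for data, label in loader:
        print(data.shape, label.shape)  # (4, 3, 32, 32) (4,)
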
