You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2017/11/30 19:37:45 UTC

[GitHub] piiswrong closed pull request #8901: add comments and sanity check

piiswrong closed pull request #8901: add comments and sanity check
URL: https://github.com/apache/incubator-mxnet/pull/8901
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/python/mxnet/gluon/data/dataset.py b/python/mxnet/gluon/data/dataset.py
index 9b4d197906..f12584bca4 100644
--- a/python/mxnet/gluon/data/dataset.py
+++ b/python/mxnet/gluon/data/dataset.py
@@ -41,12 +41,53 @@ def __len__(self):
         raise NotImplementedError
 
     def transform(self, fn, lazy=True):
+        """Returns a new dataset with each sample transformed by the
+        transformer function `fn`.
+
+        Parameters
+        ----------
+        fn : callable
+            A transformer function that takes a sample as input and
+            returns the transformed sample.
+        lazy : bool, default True
+            If False, transforms all samples at once. Otherwise,
+            transforms each sample on demand. Note that if `fn`
+            is stochastic, you must set lazy to True or you will
+            get the same result on all epochs.
+
+        Returns
+        -------
+        Dataset
+            The transformed dataset.
+        """
         trans = _LazyTransformDataset(self, fn)
         if lazy:
             return trans
         return SimpleDataset([i for i in trans])
 
     def transform_first(self, fn, lazy=True):
+        """Returns a new dataset with the first element of each sample
+        transformed by the transformer function `fn`.
+
+        This is useful, for example, when you only want to transform data
+        while keeping the label as is.
+
+        Parameters
+        ----------
+        fn : callable
+            A transformer function that takes the first element of a sample
+            as input and returns the transformed element.
+        lazy : bool, default True
+            If False, transforms all samples at once. Otherwise,
+            transforms each sample on demand. Note that if `fn`
+            is stochastic, you must set lazy to True or you will
+            get the same result on all epochs.
+
+        Returns
+        -------
+        Dataset
+            The transformed dataset.
+        """
         def base_fn(x, *args):
             if args:
                 return (fn(x),) + args
@@ -55,6 +96,13 @@ def base_fn(x, *args):
 
 
 class SimpleDataset(Dataset):
+    """Simple Dataset wrapper for lists and arrays.
+
+    Parameters
+    ----------
+    data : dataset-like object
+        Any object that implements `len()` and `[]`.
+    """
     def __init__(self, data):
         self._data = data
 
@@ -66,6 +114,7 @@ def __getitem__(self, idx):
 
 
 class _LazyTransformDataset(Dataset):
+    """Lazily transformed dataset."""
     def __init__(self, data, fn):
         self._data = data
         self._fn = fn
@@ -81,13 +130,14 @@ def __getitem__(self, idx):
 
 
 class ArrayDataset(Dataset):
-    """A dataset of multiple arrays.
+    """A dataset that combines multiple dataset-like objects, e.g.
+    Datasets, lists, or arrays.
 
-    The i-th sample is `(x1[i], x2[i], ...)`.
+    The i-th sample is defined as `(x1[i], x2[i], ...)`.
 
     Parameters
     ----------
-    *args : one or more arrays
+    *args : one or more dataset-like objects
         The data arrays.
     """
     def __init__(self, *args):
diff --git a/python/mxnet/gluon/data/vision/transforms.py b/python/mxnet/gluon/data/vision/transforms.py
index 931d644b17..8daf88e6f4 100644
--- a/python/mxnet/gluon/data/vision/transforms.py
+++ b/python/mxnet/gluon/data/vision/transforms.py
@@ -20,11 +20,18 @@
 from .. import dataset
 from ...block import Block, HybridBlock
 from ...nn import Sequential, HybridSequential
-from .... import ndarray, initializer
-from ....base import _Null
+from .... import ndarray, initializer, image
+from ....base import _Null, numeric_types
 
 
 class Compose(Sequential):
+    """Sequentially composes multiple transforms.
+
+    Parameters
+    ----------
+    transforms : list of transform Blocks
+        The list of transforms to be composed.
+    """
     def __init__(self, transforms):
         super(Compose, self).__init__()
         transforms.append(None)
@@ -34,18 +41,25 @@ def __init__(self, transforms):
                 hybrid.append(i)
                 continue
             elif len(hybrid) == 1:
-                self.register_child(hybrid[0])
+                self.add(hybrid[0])
             elif len(hybrid) > 1:
                 hblock = HybridSequential()
                 for j in hybrid:
                     hblock.add(j)
-                self.register_child(hblock)
+                self.add(hblock)
             if i is not None:
-                self.register_child(i)
+                self.add(i)
         self.hybridize()
 
 
 class Cast(HybridBlock):
+    """Cast input to a specific data type.
+
+    Parameters
+    ----------
+    dtype : str or numpy.dtype, default 'float32'
+        The target data type, in string or `numpy.dtype`.
+    """
     def __init__(self, dtype='float32'):
         super(Cast, self).__init__()
         self._dtype = dtype
@@ -55,6 +69,12 @@ def hybrid_forward(self, F, x):
 
 
 class ToTensor(HybridBlock):
+    """Converts an image NDArray to a tensor NDArray.
+
+    Converts an image NDArray of shape (H x W x C) in the range
+    [0, 255] to a float32 tensor NDArray of shape (C x H x W) in
+    the range [0, 1].
+    """
     def __init__(self):
         super(ToTensor, self).__init__()
 
@@ -63,6 +83,23 @@ def hybrid_forward(self, F, x):
 
 
 class Normalize(HybridBlock):
+    """Normalize a tensor of shape (C x H x W) with mean and
+    standard deviation.
+
+    Given mean `(m1, ..., mn)` and std `(s1, ..., sn)` for `n` channels,
+    this transform normalizes each channel of the input tensor with::
+
+        output[i] = (input[i] - mi) / si
+
+    If mean or std is scalar, the same value will be applied to all channels.
+
+    Parameters
+    ----------
+    mean : float or tuple of floats
+        The mean values.
+    std : float or tuple of floats
+        The standard deviation values.
+    """
     def __init__(self, mean, std):
         super(Normalize, self).__init__()
         self._mean = mean
@@ -72,101 +109,189 @@ def hybrid_forward(self, F, x):
         return F.image.normalize(x, self._mean, self._std)
 
 
-class RandomResizedCrop(HybridBlock):
-    def __init__(self, size, area=(0.08, 1.0), ratio=(3.0/4.0, 4.0/3.0),
+class RandomResizedCrop(Block):
+    """Crop the input image with random scale and aspect ratio.
+
+    Makes a crop of the original image with random size (default: 0.08
+    to 1.0 of the original image size) and random aspect ratio (default:
+    3/4 to 4/3), then resize it to the specified size.
+
+    Parameters
+    ----------
+    size : int or tuple of (W, H)
+        Size of the final output.
+    scale : tuple of two floats
+        If scale is `(min_area, max_area)`, the cropped image's area is at
+        least min_area of the original image's area (max_area is ignored).
+    ratio : tuple of two floats
+        Range of aspect ratio of the cropped image before resizing.
+    interpolation : int
+        Interpolation method for resizing. By default uses bilinear
+        interpolation. See OpenCV's resize function for available choices.
+    """
+    def __init__(self, size, scale=(0.08, 1.0), ratio=(3.0/4.0, 4.0/3.0),
                  interpolation=2):
         super(RandomResizedCrop, self).__init__()
-        self._args = (size, area, ratio, interpolation)
-
-    def hybrid_forward(self, F, x):
-        return F.image.random_resized_crop(x, *self._args)
-
-
-class CenterCrop(HybridBlock):
-    def __init__(self, size):
+        if isinstance(size, numeric_types):
+            size = (size, size)
+        self._args = (size, scale[0], ratio, interpolation)
+
+    def forward(self, x):
+        return image.random_size_crop(x, *self._args)[0]
+
+
+class CenterCrop(Block):
+    """Crops the image `src` to the given `size` by trimming on all four
+    sides and preserving the center of the image. Upsamples if `src` is
+    smaller than `size`.
+
+    Parameters
+    ----------
+    size : int or tuple of (W, H)
+        Size of output image.
+    interpolation : int
+        Interpolation method for resizing. By default uses bilinear
+        interpolation. See OpenCV's resize function for available choices.
+    """
+    def __init__(self, size, interpolation=2):
         super(CenterCrop, self).__init__()
-        self._size = size
+        if isinstance(size, numeric_types):
+            size = (size, size)
+        self._args = (size, interpolation)
+
+    def forward(self, x):
+        return image.center_crop(x, *self._args)[0]
 
-    def hybrid_forward(self, F, x):
-        return F.image.center_crop(x, size)
 
+class Resize(Block):
+    """Resize an image to the given size.
 
-class Resize(HybridBlock):
+    Parameters
+    ----------
+    size : int or tuple of (W, H)
+        Size of output image.
+    interpolation : int
+        Interpolation method for resizing. By default uses bilinear
+        interpolation. See OpenCV's resize function for available choices.
+    """
     def __init__(self, size, interpolation=2):
         super(Resize, self).__init__()
-        self._args = (size, interpolation)
+        if isinstance(size, numeric_types):
+            size = (size, size)
+        self._args = tuple(size) + (interpolation,)
+
+    def forward(self, x):
+        return image.imresize(x, *self._args)
+
+
+class RandomHorizontalFlip(HybridBlock):
+    """Randomly flip the input image horizontally with a probability
+    of 0.5.
+    """
+    def __init__(self):
+        super(RandomHorizontalFlip, self).__init__()
 
     def hybrid_forward(self, F, x):
-        return F.image.resize(x, *self._args)
+        return F.image.random_horizontal_flip(x)
 
 
-class RandomFlip(HybridBlock):
-    def __init__(self, axis=1):
-        super(RandomFlip, self).__init__()
-        self._axis = axis
+class RandomVerticalFlip(HybridBlock):
+    """Randomly flip the input image vertically with a probability
+    of 0.5.
+    """
+    def __init__(self):
+        super(RandomVerticalFlip, self).__init__()
 
     def hybrid_forward(self, F, x):
-        return F.image.random_flip(x, self._axis)
+        return F.image.random_vertical_flip(x)
 
 
 class RandomBrightness(HybridBlock):
-    def __init__(self, max_brightness):
+    """Randomly jitters image brightness with a factor
+    chosen from `[max(0, 1 - brightness), 1 + brightness]`.
+    """
+    def __init__(self, brightness):
         super(RandomBrightness, self).__init__()
-        self._max_brightness = max_brightness
+        self._args = (max(0, 1-brightness), 1+brightness)
 
     def hybrid_forward(self, F, x):
-        return F.image.random_brightness(x, self._max_brightness)
+        return F.image.random_brightness(x, *self._args)
 
 
 class RandomContrast(HybridBlock):
-    def __init__(self, max_contrast):
+    """Randomly jitters image contrast with a factor
+    chosen from `[max(0, 1 - contrast), 1 + contrast]`.
+    """
+    def __init__(self, contrast):
         super(RandomContrast, self).__init__()
-        self._max_contrast = max_contrast
+        self._args = (max(0, 1-contrast), 1+contrast)
 
     def hybrid_forward(self, F, x):
-        return F.image.random_contrast(x, self._max_contrast)
+        return F.image.random_contrast(x, *self._args)
 
 
 class RandomSaturation(HybridBlock):
-    def __init__(self, max_saturation):
+    """Randomly jitters image saturation with a factor
+    chosen from `[max(0, 1 - saturation), 1 + saturation]`.
+    """
+    def __init__(self, saturation):
         super(RandomSaturation, self).__init__()
-        self._max_saturation = max_saturation
+        self._args = (max(0, 1-saturation), 1+saturation)
 
     def hybrid_forward(self, F, x):
-        return F.image.random_saturation(x, self._max_saturation)
+        return F.image.random_saturation(x, *self._args)
 
 
 class RandomHue(HybridBlock):
-    def __init__(self, max_hue):
+    """Randomly jitters image hue with a factor
+    chosen from `[max(0, 1 - hue), 1 + hue]`.
+    """
+    def __init__(self, hue):
         super(RandomHue, self).__init__()
-        self._max_hue = max_hue
+        self._args = (max(0, 1-hue), 1+hue)
 
     def hybrid_forward(self, F, x):
-        return F.image.random_hue(x, self._max_hue)
+        return F.image.random_hue(x, *self._args)
 
 
 class RandomColorJitter(HybridBlock):
-    def __init__(self, max_brightness=0, max_contrast=0, max_saturation=0, max_hue=0):
+    """Randomly jitters the brightness, contrast, saturation, and hue
+    of an image.
+
+    Parameters
+    ----------
+    brightness : float
+        How much to jitter brightness. brightness factor is randomly
+        chosen from `[max(0, 1 - brightness), 1 + brightness]`.
+    contrast : float
+        How much to jitter contrast. contrast factor is randomly
+        chosen from `[max(0, 1 - contrast), 1 + contrast]`.
+    saturation : float
+        How much to jitter saturation. saturation factor is randomly
+        chosen from `[max(0, 1 - saturation), 1 + saturation]`.
+    hue : float
+        How much to jitter hue. hue factor is randomly
+        chosen from `[max(0, 1 - hue), 1 + hue]`.
+    """
+    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
         super(RandomColorJitter, self).__init__()
-        self._args = (max_brightness, max_contrast, max_saturation, max_hue)
+        self._args = (brightness, contrast, saturation, hue)
 
     def hybrid_forward(self, F, x):
         return F.image.random_color_jitter(x, *self._args)
 
 
-class AdjustLighting(HybridBlock):
-    def __init__(self, alpha_rgb=_Null, eigval=_Null, eigvec=_Null):
-        super(AdjustLighting, self).__init__()
-        self._args = (alpha_rgb, eigval, eigvec)
-
-    def hybrid_forward(self, F, x):
-        return F.image.adjust_lighting(x, *self._args)
-
-
 class RandomLighting(HybridBlock):
-    def __init__(self, alpha_std=_Null, eigval=_Null, eigvec=_Null):
+    """Add AlexNet-style PCA-based noise to an image.
+
+    Parameters
+    ----------
+    alpha : float
+        Standard deviation of the random PCA-based lighting noise.
+    """
+    def __init__(self, alpha):
         super(RandomLighting, self).__init__()
-        self._args = (alpha_std, eigval, eigvec)
+        self._alpha = alpha
 
     def hybrid_forward(self, F, x):
-        return F.image.random_lighting(x, *self._args)
\ No newline at end of file
+        return F.image.random_lighting(x, self._alpha)
diff --git a/tests/python/unittest/test_gluon_data.py b/tests/python/unittest/test_gluon_data.py
index 63c5d28b7c..c72ef7c8c1 100644
--- a/tests/python/unittest/test_gluon_data.py
+++ b/tests/python/unittest/test_gluon_data.py
@@ -107,6 +107,27 @@ def test_multi_worker():
         assert (batch.asnumpy() == i).all()
 
 
+def test_transformer():
+    from mxnet.gluon.data.vision import transforms
+
+    transform = transforms.Compose([
+		transforms.Resize(300),
+		transforms.CenterCrop(256),
+		transforms.RandomResizedCrop(224),
+		transforms.RandomHorizontalFlip(),
+		transforms.RandomColorJitter(0.1, 0.1, 0.1, 0.1),
+		transforms.RandomBrightness(0.1),
+		transforms.RandomContrast(0.1),
+		transforms.RandomSaturation(0.1),
+		transforms.RandomHue(0.1),
+		transforms.RandomLighting(0.1),
+		transforms.ToTensor(),
+		transforms.Normalize([0, 0, 0], [1, 1, 1])])
+
+    transform(mx.nd.ones((245, 480, 3), dtype='uint8')).wait_to_read()
+
+
 if __name__ == '__main__':
+    test_transformer()
     import nose
     nose.runmodule()


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services