You are viewing a plain text version of this content. The canonical link for it was a hyperlink on the original page and is not preserved in this plain-text version.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/06/14 21:50:40 UTC

[GitHub] piiswrong closed pull request #11183: Improve data transform for gluon data loader

piiswrong closed pull request #11183: Improve data transform for gluon data loader
URL: https://github.com/apache/incubator-mxnet/pull/11183
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below, because GitHub would not otherwise display it:

diff --git a/python/mxnet/gluon/data/vision/transforms.py b/python/mxnet/gluon/data/vision/transforms.py
index 7ec1c32d5e3..2e35a404b00 100644
--- a/python/mxnet/gluon/data/vision/transforms.py
+++ b/python/mxnet/gluon/data/vision/transforms.py
@@ -196,7 +196,7 @@ class RandomResizedCrop(Block):
         - **out**: output tensor with (H x W x C) shape.
     """
     def __init__(self, size, scale=(0.08, 1.0), ratio=(3.0/4.0, 4.0/3.0),
-                 interpolation=2):
+                 interpolation=1):
         super(RandomResizedCrop, self).__init__()
         if isinstance(size, numeric_types):
             size = (size, size)
@@ -233,7 +233,7 @@ class CenterCrop(Block):
     >>> transformer(image)
     <NDArray 500x1000x3 @cpu(0)>
     """
-    def __init__(self, size, interpolation=2):
+    def __init__(self, size, interpolation=1):
         super(CenterCrop, self).__init__()
         if isinstance(size, numeric_types):
             size = (size, size)
@@ -250,6 +250,9 @@ class Resize(Block):
     ----------
     size : int or tuple of (W, H)
         Size of output image.
+    keep_ratio : bool
+        Whether to resize the short edge or both edges to `size`,
+        if size is give as an integer.
     interpolation : int
         Interpolation method for resizing. By default uses bilinear
         interpolation. See OpenCV's resize function for available choices.
@@ -268,14 +271,28 @@ class Resize(Block):
     >>> transformer(image)
     <NDArray 500x1000x3 @cpu(0)>
     """
-    def __init__(self, size, interpolation=2):
+    def __init__(self, size, keep_ratio=False, interpolation=1):
         super(Resize, self).__init__()
-        if isinstance(size, numeric_types):
-            size = (size, size)
-        self._args = tuple(size) + (interpolation,)
+        self._keep = keep_ratio
+        self._size = size
+        self._interpolation = interpolation
 
     def forward(self, x):
-        return image.imresize(x, *self._args)
+        if isinstance(self._size, numeric_types):
+            if not self._keep:
+                wsize = self._size
+                hsize = self._size
+            else:
+                h, w, _ = x.shape
+                if h > w:
+                    wsize = self._size
+                    hsize = int(h * wsize / w)
+                else:
+                    hsize = self._size
+                    wsize = int(w * hsize / h)
+        else:
+            wsize, hsize = self._size
+        return image.imresize(x, wsize, hsize, self._interpolation)
 
 
 class RandomFlipLeftRight(HybridBlock):
diff --git a/tests/python/unittest/test_gluon_data_vision.py b/tests/python/unittest/test_gluon_data_vision.py
index fe360ac9708..ecc941edd83 100644
--- a/tests/python/unittest/test_gluon_data_vision.py
+++ b/tests/python/unittest/test_gluon_data_vision.py
@@ -66,18 +66,19 @@ def test_transformer():
     from mxnet.gluon.data.vision import transforms
 
     transform = transforms.Compose([
-		transforms.Resize(300),
-		transforms.CenterCrop(256),
-		transforms.RandomResizedCrop(224),
-		transforms.RandomFlipLeftRight(),
-		transforms.RandomColorJitter(0.1, 0.1, 0.1, 0.1),
-		transforms.RandomBrightness(0.1),
-		transforms.RandomContrast(0.1),
-		transforms.RandomSaturation(0.1),
-		transforms.RandomHue(0.1),
-		transforms.RandomLighting(0.1),
-		transforms.ToTensor(),
-		transforms.Normalize([0, 0, 0], [1, 1, 1])])
+        transforms.Resize(300),
+        transforms.Resize(300, keep_ratio=True),
+        transforms.CenterCrop(256),
+        transforms.RandomResizedCrop(224),
+        transforms.RandomFlipLeftRight(),
+        transforms.RandomColorJitter(0.1, 0.1, 0.1, 0.1),
+        transforms.RandomBrightness(0.1),
+        transforms.RandomContrast(0.1),
+        transforms.RandomSaturation(0.1),
+        transforms.RandomHue(0.1),
+        transforms.RandomLighting(0.1),
+        transforms.ToTensor(),
+        transforms.Normalize([0, 0, 0], [1, 1, 1])])
 
     transform(mx.nd.ones((245, 480, 3), dtype='uint8')).wait_to_read()
 


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services