You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2018/06/14 21:50:48 UTC

[incubator-mxnet] branch master updated: Improve data transform for gluon data loader (#11183)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new b819fd9  Improve data transform for gluon data loader (#11183)
b819fd9 is described below

commit b819fd9dc2435a582831ffad3b1668e58664ee5d
Author: Tong He <he...@gmail.com>
AuthorDate: Thu Jun 14 14:50:39 2018 -0700

    Improve data transform for gluon data loader (#11183)
    
    * improve transforms.Resize
    
    * fix
    
    * Trigger CI
    
    * Trigger CI
    
    * improve
    
    * Trigger CI
    
    * Trigger CI
    
    * fix unittest
    
    * keep_ratio is false by default, to keep consistency
---
 python/mxnet/gluon/data/vision/transforms.py    | 31 +++++++++++++++++++------
 tests/python/unittest/test_gluon_data_vision.py | 25 ++++++++++----------
 2 files changed, 37 insertions(+), 19 deletions(-)

diff --git a/python/mxnet/gluon/data/vision/transforms.py b/python/mxnet/gluon/data/vision/transforms.py
index 7ec1c32..2e35a40 100644
--- a/python/mxnet/gluon/data/vision/transforms.py
+++ b/python/mxnet/gluon/data/vision/transforms.py
@@ -196,7 +196,7 @@ class RandomResizedCrop(Block):
         - **out**: output tensor with (H x W x C) shape.
     """
     def __init__(self, size, scale=(0.08, 1.0), ratio=(3.0/4.0, 4.0/3.0),
-                 interpolation=2):
+                 interpolation=1):
         super(RandomResizedCrop, self).__init__()
         if isinstance(size, numeric_types):
             size = (size, size)
@@ -233,7 +233,7 @@ class CenterCrop(Block):
     >>> transformer(image)
     <NDArray 500x1000x3 @cpu(0)>
     """
-    def __init__(self, size, interpolation=2):
+    def __init__(self, size, interpolation=1):
         super(CenterCrop, self).__init__()
         if isinstance(size, numeric_types):
             size = (size, size)
@@ -250,6 +250,9 @@ class Resize(Block):
     ----------
     size : int or tuple of (W, H)
         Size of output image.
+    keep_ratio : bool
+        Whether to resize only the short edge (keeping the aspect ratio)
+        or both edges to `size`, if size is given as an integer.
     interpolation : int
         Interpolation method for resizing. By default uses bilinear
         interpolation. See OpenCV's resize function for available choices.
@@ -268,14 +271,28 @@ class Resize(Block):
     >>> transformer(image)
     <NDArray 500x1000x3 @cpu(0)>
     """
-    def __init__(self, size, interpolation=2):
+    def __init__(self, size, keep_ratio=False, interpolation=1):
         super(Resize, self).__init__()
-        if isinstance(size, numeric_types):
-            size = (size, size)
-        self._args = tuple(size) + (interpolation,)
+        self._keep = keep_ratio
+        self._size = size
+        self._interpolation = interpolation
 
     def forward(self, x):
-        return image.imresize(x, *self._args)
+        if isinstance(self._size, numeric_types):
+            if not self._keep:
+                wsize = self._size
+                hsize = self._size
+            else:
+                h, w, _ = x.shape
+                if h > w:
+                    wsize = self._size
+                    hsize = int(h * wsize / w)
+                else:
+                    hsize = self._size
+                    wsize = int(w * hsize / h)
+        else:
+            wsize, hsize = self._size
+        return image.imresize(x, wsize, hsize, self._interpolation)
 
 
 class RandomFlipLeftRight(HybridBlock):
diff --git a/tests/python/unittest/test_gluon_data_vision.py b/tests/python/unittest/test_gluon_data_vision.py
index a15a7e9..2ff9c5c 100644
--- a/tests/python/unittest/test_gluon_data_vision.py
+++ b/tests/python/unittest/test_gluon_data_vision.py
@@ -66,18 +66,19 @@ def test_transformer():
     from mxnet.gluon.data.vision import transforms
 
     transform = transforms.Compose([
-		transforms.Resize(300),
-		transforms.CenterCrop(256),
-		transforms.RandomResizedCrop(224),
-		transforms.RandomFlipLeftRight(),
-		transforms.RandomColorJitter(0.1, 0.1, 0.1, 0.1),
-		transforms.RandomBrightness(0.1),
-		transforms.RandomContrast(0.1),
-		transforms.RandomSaturation(0.1),
-		transforms.RandomHue(0.1),
-		transforms.RandomLighting(0.1),
-		transforms.ToTensor(),
-		transforms.Normalize([0, 0, 0], [1, 1, 1])])
+        transforms.Resize(300),
+        transforms.Resize(300, keep_ratio=True),
+        transforms.CenterCrop(256),
+        transforms.RandomResizedCrop(224),
+        transforms.RandomFlipLeftRight(),
+        transforms.RandomColorJitter(0.1, 0.1, 0.1, 0.1),
+        transforms.RandomBrightness(0.1),
+        transforms.RandomContrast(0.1),
+        transforms.RandomSaturation(0.1),
+        transforms.RandomHue(0.1),
+        transforms.RandomLighting(0.1),
+        transforms.ToTensor(),
+        transforms.Normalize([0, 0, 0], [1, 1, 1])])
 
     transform(mx.nd.ones((245, 480, 3), dtype='uint8')).wait_to_read()
 

-- 
To stop receiving notification emails like this one, please contact
jxie@apache.org.