You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2018/06/14 21:50:48 UTC
[incubator-mxnet] branch master updated: Improve data transform for
gluon data loader (#11183)
This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new b819fd9 Improve data transform for gluon data loader (#11183)
b819fd9 is described below
commit b819fd9dc2435a582831ffad3b1668e58664ee5d
Author: Tong He <he...@gmail.com>
AuthorDate: Thu Jun 14 14:50:39 2018 -0700
Improve data transform for gluon data loader (#11183)
* improve transforms.Resize
* fix
* Trigger CI
* Trigger CI
* improve
* Trigger CI
* Trigger CI
* fix unittest
* keep_ratio is False by default, for backward compatibility
---
python/mxnet/gluon/data/vision/transforms.py | 31 +++++++++++++++++++------
tests/python/unittest/test_gluon_data_vision.py | 25 ++++++++++----------
2 files changed, 37 insertions(+), 19 deletions(-)
diff --git a/python/mxnet/gluon/data/vision/transforms.py b/python/mxnet/gluon/data/vision/transforms.py
index 7ec1c32..2e35a40 100644
--- a/python/mxnet/gluon/data/vision/transforms.py
+++ b/python/mxnet/gluon/data/vision/transforms.py
@@ -196,7 +196,7 @@ class RandomResizedCrop(Block):
- **out**: output tensor with (H x W x C) shape.
"""
def __init__(self, size, scale=(0.08, 1.0), ratio=(3.0/4.0, 4.0/3.0),
- interpolation=2):
+ interpolation=1):
super(RandomResizedCrop, self).__init__()
if isinstance(size, numeric_types):
size = (size, size)
@@ -233,7 +233,7 @@ class CenterCrop(Block):
>>> transformer(image)
<NDArray 500x1000x3 @cpu(0)>
"""
- def __init__(self, size, interpolation=2):
+ def __init__(self, size, interpolation=1):
super(CenterCrop, self).__init__()
if isinstance(size, numeric_types):
size = (size, size)
@@ -250,6 +250,9 @@ class Resize(Block):
----------
size : int or tuple of (W, H)
Size of output image.
+ keep_ratio : bool
+ If True, resize the shorter edge to `size` and scale the longer edge to keep
+ the aspect ratio; if False, resize both edges to `size`. Only used when `size` is an integer.
interpolation : int
Interpolation method for resizing. By default uses bilinear
interpolation. See OpenCV's resize function for available choices.
@@ -268,14 +271,28 @@ class Resize(Block):
>>> transformer(image)
<NDArray 500x1000x3 @cpu(0)>
"""
- def __init__(self, size, interpolation=2):
+ def __init__(self, size, keep_ratio=False, interpolation=1):
super(Resize, self).__init__()
- if isinstance(size, numeric_types):
- size = (size, size)
- self._args = tuple(size) + (interpolation,)
+ self._keep = keep_ratio
+ self._size = size
+ self._interpolation = interpolation
def forward(self, x):
- return image.imresize(x, *self._args)
+ if isinstance(self._size, numeric_types):
+ if not self._keep:
+ wsize = self._size
+ hsize = self._size
+ else:
+ h, w, _ = x.shape
+ if h > w:
+ wsize = self._size
+ hsize = int(h * wsize / w)
+ else:
+ hsize = self._size
+ wsize = int(w * hsize / h)
+ else:
+ wsize, hsize = self._size
+ return image.imresize(x, wsize, hsize, self._interpolation)
class RandomFlipLeftRight(HybridBlock):
diff --git a/tests/python/unittest/test_gluon_data_vision.py b/tests/python/unittest/test_gluon_data_vision.py
index a15a7e9..2ff9c5c 100644
--- a/tests/python/unittest/test_gluon_data_vision.py
+++ b/tests/python/unittest/test_gluon_data_vision.py
@@ -66,18 +66,19 @@ def test_transformer():
from mxnet.gluon.data.vision import transforms
transform = transforms.Compose([
- transforms.Resize(300),
- transforms.CenterCrop(256),
- transforms.RandomResizedCrop(224),
- transforms.RandomFlipLeftRight(),
- transforms.RandomColorJitter(0.1, 0.1, 0.1, 0.1),
- transforms.RandomBrightness(0.1),
- transforms.RandomContrast(0.1),
- transforms.RandomSaturation(0.1),
- transforms.RandomHue(0.1),
- transforms.RandomLighting(0.1),
- transforms.ToTensor(),
- transforms.Normalize([0, 0, 0], [1, 1, 1])])
+ transforms.Resize(300),
+ transforms.Resize(300, keep_ratio=True),
+ transforms.CenterCrop(256),
+ transforms.RandomResizedCrop(224),
+ transforms.RandomFlipLeftRight(),
+ transforms.RandomColorJitter(0.1, 0.1, 0.1, 0.1),
+ transforms.RandomBrightness(0.1),
+ transforms.RandomContrast(0.1),
+ transforms.RandomSaturation(0.1),
+ transforms.RandomHue(0.1),
+ transforms.RandomLighting(0.1),
+ transforms.ToTensor(),
+ transforms.Normalize([0, 0, 0], [1, 1, 1])])
transform(mx.nd.ones((245, 480, 3), dtype='uint8')).wait_to_read()
--
To stop receiving notification emails like this one, please contact
jxie@apache.org.