You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/06/14 21:50:40 UTC
[GitHub] piiswrong closed pull request #11183: Improve data transform for
gluon data loader
piiswrong closed pull request #11183: Improve data transform for gluon data loader
URL: https://github.com/apache/incubator-mxnet/pull/11183
This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:
As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):
diff --git a/python/mxnet/gluon/data/vision/transforms.py b/python/mxnet/gluon/data/vision/transforms.py
index 7ec1c32d5e3..2e35a404b00 100644
--- a/python/mxnet/gluon/data/vision/transforms.py
+++ b/python/mxnet/gluon/data/vision/transforms.py
@@ -196,7 +196,7 @@ class RandomResizedCrop(Block):
- **out**: output tensor with (H x W x C) shape.
"""
def __init__(self, size, scale=(0.08, 1.0), ratio=(3.0/4.0, 4.0/3.0),
- interpolation=2):
+ interpolation=1):
super(RandomResizedCrop, self).__init__()
if isinstance(size, numeric_types):
size = (size, size)
@@ -233,7 +233,7 @@ class CenterCrop(Block):
>>> transformer(image)
<NDArray 500x1000x3 @cpu(0)>
"""
- def __init__(self, size, interpolation=2):
+ def __init__(self, size, interpolation=1):
super(CenterCrop, self).__init__()
if isinstance(size, numeric_types):
size = (size, size)
@@ -250,6 +250,9 @@ class Resize(Block):
----------
size : int or tuple of (W, H)
Size of output image.
+ keep_ratio : bool
+ Whether to resize only the short edge (preserving the aspect ratio)
+ or both edges to `size`, if size is given as an integer.
interpolation : int
Interpolation method for resizing. By default uses bilinear
interpolation. See OpenCV's resize function for available choices.
@@ -268,14 +271,28 @@ class Resize(Block):
>>> transformer(image)
<NDArray 500x1000x3 @cpu(0)>
"""
- def __init__(self, size, interpolation=2):
+ def __init__(self, size, keep_ratio=False, interpolation=1):
super(Resize, self).__init__()
- if isinstance(size, numeric_types):
- size = (size, size)
- self._args = tuple(size) + (interpolation,)
+ self._keep = keep_ratio
+ self._size = size
+ self._interpolation = interpolation
def forward(self, x):
- return image.imresize(x, *self._args)
+ if isinstance(self._size, numeric_types):
+ if not self._keep:
+ wsize = self._size
+ hsize = self._size
+ else:
+ h, w, _ = x.shape
+ if h > w:
+ wsize = self._size
+ hsize = int(h * wsize / w)
+ else:
+ hsize = self._size
+ wsize = int(w * hsize / h)
+ else:
+ wsize, hsize = self._size
+ return image.imresize(x, wsize, hsize, self._interpolation)
class RandomFlipLeftRight(HybridBlock):
diff --git a/tests/python/unittest/test_gluon_data_vision.py b/tests/python/unittest/test_gluon_data_vision.py
index fe360ac9708..ecc941edd83 100644
--- a/tests/python/unittest/test_gluon_data_vision.py
+++ b/tests/python/unittest/test_gluon_data_vision.py
@@ -66,18 +66,19 @@ def test_transformer():
from mxnet.gluon.data.vision import transforms
transform = transforms.Compose([
- transforms.Resize(300),
- transforms.CenterCrop(256),
- transforms.RandomResizedCrop(224),
- transforms.RandomFlipLeftRight(),
- transforms.RandomColorJitter(0.1, 0.1, 0.1, 0.1),
- transforms.RandomBrightness(0.1),
- transforms.RandomContrast(0.1),
- transforms.RandomSaturation(0.1),
- transforms.RandomHue(0.1),
- transforms.RandomLighting(0.1),
- transforms.ToTensor(),
- transforms.Normalize([0, 0, 0], [1, 1, 1])])
+ transforms.Resize(300),
+ transforms.Resize(300, keep_ratio=True),
+ transforms.CenterCrop(256),
+ transforms.RandomResizedCrop(224),
+ transforms.RandomFlipLeftRight(),
+ transforms.RandomColorJitter(0.1, 0.1, 0.1, 0.1),
+ transforms.RandomBrightness(0.1),
+ transforms.RandomContrast(0.1),
+ transforms.RandomSaturation(0.1),
+ transforms.RandomHue(0.1),
+ transforms.RandomLighting(0.1),
+ transforms.ToTensor(),
+ transforms.Normalize([0, 0, 0], [1, 1, 1])])
transform(mx.nd.ones((245, 480, 3), dtype='uint8')).wait_to_read()
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
With regards,
Apache Git Services