You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2018/01/22 20:43:37 UTC

[incubator-mxnet] 06/20: Add Gluon data transform (#8672)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit 7999c43eb5046e1f741806581a830de7ac8ae87d
Author: Eric Junyuan Xie <pi...@users.noreply.github.com>
AuthorDate: Wed Nov 15 15:16:51 2017 -0800

    Add Gluon data transform (#8672)
    
    * fix
    
    * fix
    
    * fix
---
 python/mxnet/gluon/data/dataset.py                 |  43 +++++-
 python/mxnet/gluon/data/vision/__init__.py         |  22 +++
 .../gluon/data/{vision.py => vision/datasets.py}   |   0
 python/mxnet/gluon/data/vision/transforms.py       | 153 +++++++++++++++++++++
 src/operator/image/image_aug_op.h                  |  70 ++++++++++
 src/operator/image/image_random.cc                 |   4 +-
 6 files changed, 289 insertions(+), 3 deletions(-)

diff --git a/python/mxnet/gluon/data/dataset.py b/python/mxnet/gluon/data/dataset.py
index 2c46f1e..35d4c5c 100644
--- a/python/mxnet/gluon/data/dataset.py
+++ b/python/mxnet/gluon/data/dataset.py
@@ -18,12 +18,14 @@
 # coding: utf-8
 # pylint: disable=
 """Dataset container."""
-__all__ = ['Dataset', 'ArrayDataset', 'RecordFileDataset']
+__all__ = ['Dataset', 'SimpleDataset', 'ArrayDataset', 'LabeledDataset',
+           'RecordFileDataset']
 
 import os
 
 from ... import recordio, ndarray
 
+
 class Dataset(object):
     """Abstract dataset class. All datasets should have this interface.
 
@@ -38,6 +40,45 @@ class Dataset(object):
     def __len__(self):
         raise NotImplementedError
 
+    def transform(self, fn, lazy=True):
+        """Returns a new dataset with each sample transformed by `fn`.
+
+        Parameters
+        ----------
+        fn : callable
+            Function applied to each sample. Tuple samples are unpacked into
+            positional arguments, i.e. ``fn(*sample)``.
+        lazy : bool, default True
+            If True, `fn` is applied each time a sample is accessed. If False,
+            `fn` is applied to every sample right away and the results are
+            stored in a `SimpleDataset`.
+
+        Returns
+        -------
+        Dataset
+            The transformed dataset.
+        """
+        trans = _LazyTransformDataset(self, fn)
+        if lazy:
+            return trans
+        # Index explicitly over range(len(...)) rather than relying on the
+        # legacy __getitem__ iteration protocol, which only stops when
+        # __getitem__ raises IndexError -- not every Dataset guarantees that.
+        return SimpleDataset([trans[i] for i in range(len(trans))])
+
+    def transform_first(self, fn, lazy=True):
+        """Returns a new dataset with `fn` applied to the first element only.
+
+        When a sample is a tuple with two or more elements (for example
+        ``(data, label)``), `fn` is applied to the first element and the
+        remaining elements are returned unchanged in a tuple. Otherwise the
+        result of `fn` is returned on its own.
+
+        Parameters
+        ----------
+        fn : callable
+            Transform applied to the first element of each sample.
+        lazy : bool, default True
+            Forwarded to :meth:`transform`.
+        """
+        def _apply_first(first, *rest):
+            head = fn(first)
+            return (head,) + rest if rest else head
+        return self.transform(_apply_first, lazy)
+
+
+class SimpleDataset(Dataset):
+    """Dataset wrapper around any object with ``len()`` and integer indexing.
+
+    Parameters
+    ----------
+    data : sequence
+        Container of samples, for example a list.
+    """
+    def __init__(self, data):
+        self._data = data
+
+    def __len__(self):
+        container = self._data
+        return len(container)
+
+    def __getitem__(self, idx):
+        container = self._data
+        return container[idx]
+
+
+class _LazyTransformDataset(Dataset):
+    """Dataset view that applies `fn` to samples of `data` on access.
+
+    Tuple samples are unpacked into positional arguments of `fn`; any other
+    sample is passed as the single argument.
+    """
+    def __init__(self, data, fn):
+        self._data = data
+        self._fn = fn
+
+    def __len__(self):
+        return len(self._data)
+
+    def __getitem__(self, idx):
+        sample = self._data[idx]
+        call_args = sample if isinstance(sample, tuple) else (sample,)
+        return self._fn(*call_args)
+
 
 class ArrayDataset(Dataset):
     """A dataset of multiple arrays.
diff --git a/python/mxnet/gluon/data/vision/__init__.py b/python/mxnet/gluon/data/vision/__init__.py
new file mode 100644
index 0000000..8837984
--- /dev/null
+++ b/python/mxnet/gluon/data/vision/__init__.py
@@ -0,0 +1,22 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+
+from .datasets import *
+
+from . import transforms
diff --git a/python/mxnet/gluon/data/vision.py b/python/mxnet/gluon/data/vision/datasets.py
similarity index 100%
rename from python/mxnet/gluon/data/vision.py
rename to python/mxnet/gluon/data/vision/datasets.py
diff --git a/python/mxnet/gluon/data/vision/transforms.py b/python/mxnet/gluon/data/vision/transforms.py
new file mode 100644
index 0000000..fa7c0f2
--- /dev/null
+++ b/python/mxnet/gluon/data/vision/transforms.py
@@ -0,0 +1,153 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+
+from .. import dataset
+from ...block import Block, HybridBlock
+from ...nn import Sequential, HybridSequential
+from .... import ndarray, initializer
+
+
+class Compose(Sequential):
+    """Sequentially composes multiple transforms.
+
+    Runs of consecutive `HybridBlock` transforms are grouped into a single
+    `HybridSequential` (a run of one is registered directly); all other
+    transforms are registered individually, preserving the original order.
+    The composed block is hybridized at the end of construction.
+
+    Parameters
+    ----------
+    transforms : list of Block
+        Transforms to apply, in order. The sequence is not modified.
+    """
+    def __init__(self, transforms):
+        super(Compose, self).__init__()
+        hybrid = []
+        # A trailing None sentinel flushes any pending hybrid run at the end.
+        # Build a new list so the caller's sequence is never mutated.
+        for i in list(transforms) + [None]:
+            if isinstance(i, HybridBlock):
+                hybrid.append(i)
+                continue
+            elif len(hybrid) == 1:
+                self.register_child(hybrid[0])
+            elif len(hybrid) > 1:
+                hblock = HybridSequential()
+                for j in hybrid:
+                    hblock.add(j)
+                self.register_child(hblock)
+            # Start a fresh run so already-registered hybrid blocks are not
+            # registered (and executed) a second time by a later flush.
+            hybrid = []
+            if i is not None:
+                self.register_child(i)
+        self.hybridize()
+
+
+class Cast(HybridBlock):
+    """Casts the input to another data type.
+
+    Parameters
+    ----------
+    dtype : str, default 'float32'
+        Target data type, forwarded to ``F.cast``.
+    """
+    def __init__(self, dtype='float32'):
+        super(Cast, self).__init__()
+        self._dtype = dtype
+
+    def hybrid_forward(self, F, x):
+        target_dtype = self._dtype
+        return F.cast(x, target_dtype)
+
+
+class ToTensor(HybridBlock):
+    """Casts the input to float32 and moves its last axis to the front.
+
+    Implemented as ``transpose((2, 0, 1))``, so a 3-D input is expected
+    (e.g. HWC -> CHW). Values are not rescaled.
+    """
+    def __init__(self):
+        super(ToTensor, self).__init__()
+
+    def hybrid_forward(self, F, x):
+        as_float = F.cast(x, 'float32')
+        return as_float.transpose((2, 0, 1))
+
+
+class Normalize(HybridBlock):
+    """Block wrapper around the ``image.normalize`` operator.
+
+    Parameters
+    ----------
+    mean
+        Mean value(s), forwarded unchanged to ``F.image.normalize``.
+    std
+        Standard deviation value(s), forwarded unchanged to
+        ``F.image.normalize``.
+    """
+    def __init__(self, mean, std):
+        super(Normalize, self).__init__()
+        self._mean = mean
+        self._std = std
+
+    def hybrid_forward(self, F, x):
+        mean, std = self._mean, self._std
+        return F.image.normalize(x, mean, std)
+
+
+class RandomResizedCrop(HybridBlock):
+    """Block wrapper around the ``image.random_resized_crop`` operator.
+
+    All constructor arguments are forwarded positionally, in order, to
+    ``F.image.random_resized_crop``. Their exact semantics (e.g. the meaning
+    of ``interpolation=2``) are defined by that operator -- TODO confirm.
+    """
+    def __init__(self, size, area=(0.08, 1.0), ratio=(3.0/4.0, 4.0/3.0),
+                 interpolation=2):
+        super(RandomResizedCrop, self).__init__()
+        self._args = (size, area, ratio, interpolation)
+
+    def hybrid_forward(self, F, x):
+        size, area, ratio, interpolation = self._args
+        return F.image.random_resized_crop(x, size, area, ratio, interpolation)
+
+
+class CenterCrop(HybridBlock):
+    """Block wrapper around the ``image.center_crop`` operator.
+
+    Parameters
+    ----------
+    size
+        Crop size, forwarded unchanged to ``F.image.center_crop``.
+    """
+    def __init__(self, size):
+        super(CenterCrop, self).__init__()
+        self._size = size
+
+    def hybrid_forward(self, F, x):
+        # Use the stored size; the bare name `size` is not defined here.
+        return F.image.center_crop(x, self._size)
+
+
+class Resize(HybridBlock):
+    """Block wrapper around the ``image.resize`` operator.
+
+    `size` and `interpolation` are forwarded positionally to
+    ``F.image.resize``; the meaning of ``interpolation=2`` is defined by
+    that operator -- TODO confirm.
+    """
+    def __init__(self, size, interpolation=2):
+        super(Resize, self).__init__()
+        self._args = (size, interpolation)
+
+    def hybrid_forward(self, F, x):
+        size, interpolation = self._args
+        return F.image.resize(x, size, interpolation)
+
+
+class RandomFlip(HybridBlock):
+    """Block wrapper around the ``image.random_flip`` operator.
+
+    Parameters
+    ----------
+    axis : int, default 1
+        Forwarded unchanged to ``F.image.random_flip`` (presumably the axis
+        to flip along -- confirm against the operator).
+    """
+    def __init__(self, axis=1):
+        super(RandomFlip, self).__init__()
+        self._axis = axis
+
+    def hybrid_forward(self, F, x):
+        flip_axis = self._axis
+        return F.image.random_flip(x, flip_axis)
+
+
+class RandomBrightness(HybridBlock):
+    """Block wrapper around the ``image.random_brightness`` operator.
+
+    Parameters
+    ----------
+    max_brightness
+        Forwarded unchanged to ``F.image.random_brightness`` (presumably the
+        maximum brightness jitter -- confirm against the operator).
+    """
+    def __init__(self, max_brightness):
+        super(RandomBrightness, self).__init__()
+        self._max_brightness = max_brightness
+
+    def hybrid_forward(self, F, x):
+        limit = self._max_brightness
+        return F.image.random_brightness(x, limit)
+
+
+class RandomContrast(HybridBlock):
+    """Block wrapper around the ``image.random_contrast`` operator.
+
+    Parameters
+    ----------
+    max_contrast
+        Forwarded unchanged to ``F.image.random_contrast`` (presumably the
+        maximum contrast jitter -- confirm against the operator).
+    """
+    def __init__(self, max_contrast):
+        super(RandomContrast, self).__init__()
+        self._max_contrast = max_contrast
+
+    def hybrid_forward(self, F, x):
+        limit = self._max_contrast
+        return F.image.random_contrast(x, limit)
+
+
+class RandomSaturation(HybridBlock):
+    """Block wrapper around the ``image.random_saturation`` operator.
+
+    Parameters
+    ----------
+    max_saturation
+        Forwarded unchanged to ``F.image.random_saturation`` (presumably the
+        maximum saturation jitter -- confirm against the operator).
+    """
+    def __init__(self, max_saturation):
+        super(RandomSaturation, self).__init__()
+        self._max_saturation = max_saturation
+
+    def hybrid_forward(self, F, x):
+        limit = self._max_saturation
+        return F.image.random_saturation(x, limit)
+
+
+class RandomHue(HybridBlock):
+    """Block wrapper around the ``image.random_hue`` operator.
+
+    Parameters
+    ----------
+    max_hue
+        Forwarded unchanged to ``F.image.random_hue`` (presumably the maximum
+        hue jitter -- confirm against the operator).
+    """
+    def __init__(self, max_hue):
+        super(RandomHue, self).__init__()
+        self._max_hue = max_hue
+
+    def hybrid_forward(self, F, x):
+        limit = self._max_hue
+        return F.image.random_hue(x, limit)
+
+
+class RandomColorJitter(HybridBlock):
+    """Block wrapper around the ``image.random_color_jitter`` operator.
+
+    The four limits are forwarded positionally, in the order brightness,
+    contrast, saturation, hue, to ``F.image.random_color_jitter``.
+    """
+    def __init__(self, max_brightness=0, max_contrast=0, max_saturation=0, max_hue=0):
+        super(RandomColorJitter, self).__init__()
+        self._args = (max_brightness, max_contrast, max_saturation, max_hue)
+
+    def hybrid_forward(self, F, x):
+        brightness, contrast, saturation, hue = self._args
+        return F.image.random_color_jitter(x, brightness, contrast, saturation, hue)
diff --git a/src/operator/image/image_aug_op.h b/src/operator/image/image_aug_op.h
new file mode 100644
index 0000000..40315ec
--- /dev/null
+++ b/src/operator/image/image_aug_op.h
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef MXNET_OPERATOR_IMAGE_IMAGE_AUG_OP_H_
+#define MXNET_OPERATOR_IMAGE_IMAGE_AUG_OP_H_
+
+#include <mxnet/operator_util.h>
+#include <vector>
+#include <utility>
+#include <algorithm>
+#include "../mshadow_op.h"
+#include "../elemwise_op_common.h"
+#include "../mxnet_op.h"
+
+namespace mxnet {
+namespace op {
+
+/*! \brief Parameters of the normalize operator. */
+struct NormalizeParam : public dmlc::Parameter<NormalizeParam> {
+  nnvm::Tuple<float> mean;  // value(s) subtracted from the input
+  nnvm::Tuple<float> std;   // value(s) the mean-subtracted input is divided by
+  DMLC_DECLARE_PARAMETER(NormalizeParam) {
+    DMLC_DECLARE_FIELD(mean).set_default(nnvm::Tuple<float>({0.f}))
+      .describe("Mean value(s) subtracted from the input, one per channel.");
+    DMLC_DECLARE_FIELD(std).set_default(nnvm::Tuple<float>({1.f}))
+      .describe("Standard deviation value(s) the mean-subtracted input is "
+                "divided by, one per channel.");
+  }
+};
+
+
+/*!
+ * \brief Normalizes the input channel by channel:
+ *        out[c, ...] = (in[c, ...] - mean[c]) / std[c].
+ *
+ * The first input dimension is treated as the channel axis and all remaining
+ * dimensions are flattened. A mean/std tuple with a single element is
+ * broadcast to every channel; otherwise it must have one element per channel.
+ * Declared `inline` because it is defined in a header that may be included
+ * by several translation units.
+ */
+inline void NormalizeCompute(const nnvm::NodeAttrs& attrs,
+                             const OpContext& ctx,
+                             const std::vector<NDArray>& inputs,
+                             const std::vector<OpReqType>& req,
+                             const std::vector<NDArray>& outputs) {
+  using namespace mxnet_op;
+  if (req[0] == kNullOp) return;
+  CHECK_NE(req[0], kAddTo) << "Normalize does not support kAddTo";
+  const auto& params = dmlc::get<NormalizeParam>(attrs.parsed);
+  // Read dtype, shape and raw pointers through the TBlob views of the arrays.
+  const TBlob& in_data = inputs[0].data();
+  const TBlob& out_data = outputs[0].data();
+  CHECK_GE(in_data.ndim(), 1) << "Normalize expects an input with a channel axis";
+  const size_t num_channel = static_cast<size_t>(in_data.shape_[0]);
+  const size_t size = in_data.shape_.ProdShape(1, in_data.ndim());
+  const size_t n_mean = params.mean.ndim();
+  const size_t n_std = params.std.ndim();
+  // Guard against reading past the end of mean/std (the defaults hold one value).
+  CHECK(n_mean == 1 || n_mean == num_channel)
+      << "mean must have 1 or " << num_channel << " elements, got " << n_mean;
+  CHECK(n_std == 1 || n_std == num_channel)
+      << "std must have 1 or " << num_channel << " elements, got " << n_std;
+  MSHADOW_TYPE_SWITCH(in_data.type_flag_, DType, {
+    const DType* src = in_data.dptr<DType>();
+    DType* dst = out_data.dptr<DType>();
+    for (size_t i = 0; i < num_channel; ++i) {
+      const DType mean = static_cast<DType>(params.mean[n_mean == 1 ? 0 : i]);
+      const DType stdev = static_cast<DType>(params.std[n_std == 1 ? 0 : i]);
+      for (size_t j = 0; j < size; ++j, ++src, ++dst) {
+        *dst = (*src - mean) / stdev;
+      }
+    }
+  });
+}
+
+}  // namespace op
+}  // namespace mxnet
+#endif  // MXNET_OPERATOR_IMAGE_IMAGE_AUG_OP_H_
diff --git a/src/operator/image/image_random.cc b/src/operator/image/image_random.cc
index 83abc17..63f7904 100644
--- a/src/operator/image/image_random.cc
+++ b/src/operator/image/image_random.cc
@@ -25,8 +25,8 @@
 
 #include <mxnet/base.h>
 #include "./image_random-inl.h"
-#include "../../operator/operator_common.h"
-#include "../../operator/elemwise_op_common.h"
+#include "../operator_common.h"
+#include "../elemwise_op_common.h"
 
 
 

-- 
To stop receiving notification emails like this one, please contact
jxie@apache.org.