You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/11/21 00:31:51 UTC
[incubator-mxnet] 06/08: Add Gluon data transform (#8672)
This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch vision
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
commit e741d8f290e9f490c2db0b08e8b5021db641535b
Author: Eric Junyuan Xie <pi...@users.noreply.github.com>
AuthorDate: Wed Nov 15 15:16:51 2017 -0800
Add Gluon data transform (#8672)
* fix
* fix
* fix
---
python/mxnet/gluon/data/dataset.py | 43 +++++-
python/mxnet/gluon/data/vision/__init__.py | 22 +++
.../gluon/data/{vision.py => vision/datasets.py} | 0
python/mxnet/gluon/data/vision/transforms.py | 153 +++++++++++++++++++++
src/operator/image/image_aug_op.h | 70 ++++++++++
src/operator/image/image_random.cc | 4 +-
6 files changed, 289 insertions(+), 3 deletions(-)
diff --git a/python/mxnet/gluon/data/dataset.py b/python/mxnet/gluon/data/dataset.py
index 059c2a6..740a2a4 100644
--- a/python/mxnet/gluon/data/dataset.py
+++ b/python/mxnet/gluon/data/dataset.py
@@ -18,12 +18,14 @@
# coding: utf-8
# pylint: disable=
"""Dataset container."""
-__all__ = ['Dataset', 'ArrayDataset', 'RecordFileDataset']
+__all__ = ['Dataset', 'SimpleDataset', 'ArrayDataset', 'LabeledDataset',
+ 'RecordFileDataset']
import os
from ... import recordio, ndarray
+
class Dataset(object):
"""Abstract dataset class. All datasets should have this interface.
@@ -38,6 +40,45 @@ class Dataset(object):
def __len__(self):
raise NotImplementedError
+ def transform(self, fn, lazy=True):
+ trans = _LazyTransformDataset(self, fn)
+ if lazy:
+ return trans
+ return SimpleDataset([i for i in trans])
+
+ def transform_first(self, fn, lazy=True):
+ def base_fn(x, *args):
+ if args:
+ return (fn(x),) + args
+ return fn(x)
+ return self.transform(base_fn, lazy)
+
+
+class SimpleDataset(Dataset):
+ def __init__(self, data):
+ self._data = data
+
+ def __len__(self):
+ return len(self._data)
+
+ def __getitem__(self, idx):
+ return self._data[idx]
+
+
+class _LazyTransformDataset(Dataset):
+ def __init__(self, data, fn):
+ self._data = data
+ self._fn = fn
+
+ def __len__(self):
+ return len(self._data)
+
+ def __getitem__(self, idx):
+ item = self._data[idx]
+ if isinstance(item, tuple):
+ return self._fn(*item)
+ return self._fn(item)
+
class ArrayDataset(Dataset):
"""A dataset of multiple arrays.
diff --git a/python/mxnet/gluon/data/vision/__init__.py b/python/mxnet/gluon/data/vision/__init__.py
new file mode 100644
index 0000000..8837984
--- /dev/null
+++ b/python/mxnet/gluon/data/vision/__init__.py
@@ -0,0 +1,22 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+
+from .datasets import *
+
+from . import transforms
diff --git a/python/mxnet/gluon/data/vision.py b/python/mxnet/gluon/data/vision/datasets.py
similarity index 100%
rename from python/mxnet/gluon/data/vision.py
rename to python/mxnet/gluon/data/vision/datasets.py
diff --git a/python/mxnet/gluon/data/vision/transforms.py b/python/mxnet/gluon/data/vision/transforms.py
new file mode 100644
index 0000000..fa7c0f2
--- /dev/null
+++ b/python/mxnet/gluon/data/vision/transforms.py
@@ -0,0 +1,153 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+
+from .. import dataset
+from ...block import Block, HybridBlock
+from ...nn import Sequential, HybridSequential
+from .... import ndarray, initializer
+
+
+class Compose(Sequential):
+ def __init__(self, transforms):
+ super(Compose, self).__init__()
+ transforms.append(None)
+ hybrid = []
+ for i in transforms:
+ if isinstance(i, HybridBlock):
+ hybrid.append(i)
+ continue
+ elif len(hybrid) == 1:
+ self.register_child(hybrid[0])
+ elif len(hybrid) > 1:
+ hblock = HybridSequential()
+ for j in hybrid:
+ hblock.add(j)
+ self.register_child(hblock)
+ if i is not None:
+ self.register_child(i)
+ self.hybridize()
+
+
+class Cast(HybridBlock):
+ def __init__(self, dtype='float32'):
+ super(Cast, self).__init__()
+ self._dtype = dtype
+
+ def hybrid_forward(self, F, x):
+ return F.cast(x, self._dtype)
+
+
+class ToTensor(HybridBlock):
+ def __init__(self):
+ super(ToTensor, self).__init__()
+
+ def hybrid_forward(self, F, x):
+ return F.cast(x, 'float32').transpose((2, 0, 1))
+
+
+class Normalize(HybridBlock):
+ def __init__(self, mean, std):
+ super(Normalize, self).__init__()
+ self._mean = mean
+ self._std = std
+
+ def hybrid_forward(self, F, x):
+ return F.image.normalize(x, self._mean, self._std)
+
+
+class RandomResizedCrop(HybridBlock):
+ def __init__(self, size, area=(0.08, 1.0), ratio=(3.0/4.0, 4.0/3.0),
+ interpolation=2):
+ super(RandomResizedCrop, self).__init__()
+ self._args = (size, area, ratio, interpolation)
+
+ def hybrid_forward(self, F, x):
+ return F.image.random_resized_crop(x, *self._args)
+
+
+class CenterCrop(HybridBlock):
+ def __init__(self, size):
+ super(CenterCrop, self).__init__()
+ self._size = size
+
+ def hybrid_forward(self, F, x):
+ return F.image.center_crop(x, size)
+
+
+class Resize(HybridBlock):
+ def __init__(self, size, interpolation=2):
+ super(Resize, self).__init__()
+ self._args = (size, interpolation)
+
+ def hybrid_forward(self, F, x):
+ return F.image.resize(x, *self._args)
+
+
+class RandomFlip(HybridBlock):
+ def __init__(self, axis=1):
+ super(RandomFlip, self).__init__()
+ self._axis = axis
+
+ def hybrid_forward(self, F, x):
+ return F.image.random_flip(x, self._axis)
+
+
+class RandomBrightness(HybridBlock):
+ def __init__(self, max_brightness):
+ super(RandomBrightness, self).__init__()
+ self._max_brightness = max_brightness
+
+ def hybrid_forward(self, F, x):
+ return F.image.random_brightness(x, self._max_brightness)
+
+
+class RandomContrast(HybridBlock):
+ def __init__(self, max_contrast):
+ super(RandomContrast, self).__init__()
+ self._max_contrast = max_contrast
+
+ def hybrid_forward(self, F, x):
+ return F.image.random_contrast(x, self._max_contrast)
+
+
+class RandomSaturation(HybridBlock):
+ def __init__(self, max_saturation):
+ super(RandomSaturation, self).__init__()
+ self._max_saturation = max_saturation
+
+ def hybrid_forward(self, F, x):
+ return F.image.random_saturation(x, self._max_saturation)
+
+
+class RandomHue(HybridBlock):
+ def __init__(self, max_hue):
+ super(RandomHue, self).__init__()
+ self._max_hue = max_hue
+
+ def hybrid_forward(self, F, x):
+ return F.image.random_hue(x, self._max_hue)
+
+
+class RandomColorJitter(HybridBlock):
+ def __init__(self, max_brightness=0, max_contrast=0, max_saturation=0, max_hue=0):
+ super(RandomColorJitter, self).__init__()
+ self._args = (max_brightness, max_contrast, max_saturation, max_hue)
+
+ def hybrid_forward(self, F, x):
+ return F.image.random_color_jitter(x, *self._args)
diff --git a/src/operator/image/image_aug_op.h b/src/operator/image/image_aug_op.h
new file mode 100644
index 0000000..40315ec
--- /dev/null
+++ b/src/operator/image/image_aug_op.h
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef MXNET_OPERATOR_IMAGE_IMAGE_AUG_OP_H_
+#define MXNET_OPERATOR_IMAGE_IMAGE_AUG_OP_H_
+
+#include <mxnet/operator_util.h>
+#include <vector>
+#include <utility>
+#include <algorithm>
+#include "../mshadow_op.h"
+#include "../elemwise_op_common.h"
+#include "../mxnet_op.h"
+
+namespace mxnet {
+namespace op {
+
+struct NormalizeParam : public dmlc::Parameter<NormalizeParam> {
+ nnvm::Tuple<float> mean, std;
+ DMLC_DECLARE_PARAMETER(NormalizeParam) {
+ DMLC_DECLARE_FIELD(mean).set_default(nnvm::Tuple<float>({0.f}))
+ .describe("");
+ DMLC_DECLARE_FIELD(std).set_default(nnvm::Tuple<float>({1.f}))
+ .describe("");
+ }
+};
+
+
+void NormalizeCompute(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs) {
+ using namespace mxnet_op;
+ const auto& params = dmlc::get<NormalizeParam>(attrs.parsed);
+ CHECK_NE(req[0], kAddTo);
+ MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+ auto num_channel = inputs[0].shape_[0];
+ auto size = inputs[0].Size(1, inputs[0].ndim());
+ nnvm::Tuple<DType> mean(params.mean.begin(), params.mean.end());
+ nnvm::Tuple<DType> std(params.std.begin(), params.std.end());
+ DType* src = inputs[0].dptr<DType>();
+ DType* dst = outputs[0].dptr<DType>();
+ for (int i = 0; i < num_channel; ++i) {
+ for (int j = 0; j < size; ++j, ++out, ++src) {
+ *out = (*src - mean[i]) / std[i];
+ }
+ }
+ });
+}
+
+} // namespace op
+} // namespace mxnet
+#endif // MXNET_OPERATOR_IMAGE_IMAGE_AUG_OP_H_
diff --git a/src/operator/image/image_random.cc b/src/operator/image/image_random.cc
index 83abc17..63f7904 100644
--- a/src/operator/image/image_random.cc
+++ b/src/operator/image/image_random.cc
@@ -25,8 +25,8 @@
#include <mxnet/base.h>
#include "./image_random-inl.h"
-#include "../../operator/operator_common.h"
-#include "../../operator/elemwise_op_common.h"
+#include "../operator_common.h"
+#include "../elemwise_op_common.h"
--
To stop receiving notification emails like this one, please contact
"commits@mxnet.apache.org" <co...@mxnet.apache.org>.