You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/11/14 00:32:50 UTC

[incubator-mxnet] branch vision updated: [WIP]Image Augmenter (#8633)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch vision
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/vision by this push:
     new c1de00a  [WIP]Image Augmenter (#8633)
c1de00a is described below

commit c1de00a1f5e7d87fc409b2c933e3a2b7841734dd
Author: Hu Shiwen <ya...@gmail.com>
AuthorDate: Tue Nov 14 08:32:47 2017 +0800

    [WIP]Image Augmenter (#8633)
    
    * add file
    
    * add random_brightness
    add python mx.sym/nd.image
    
    * fix lint
    
    * add image/image_common.h
    
    * add RandomContrast
    
    * change name
    
    * fix
---
 python/mxnet/base.py                               |  8 +-
 python/mxnet/ndarray/__init__.py                   |  4 +-
 .../mxnet/{symbol/__init__.py => ndarray/image.py} | 15 ++--
 python/mxnet/symbol/__init__.py                    |  4 +-
 python/mxnet/symbol/{__init__.py => image.py}      | 15 ++--
 src/operator/batch_norm_v1-inl.h                   |  2 +-
 src/operator/image/image_common.h                  | 88 +++++++++++++++++++
 src/operator/image/image_random-inl.h              | 99 ++++++++++++++++++++++
 src/operator/image/image_random.cc                 | 50 +++++++++++
 src/operator/random/multisample_op.h               |  2 +-
 src/operator/tensor/broadcast_reduce_op_index.cc   |  2 +-
 .../elemwise_binary_broadcast_op_extended.cc       |  2 +-
 .../tensor/elemwise_binary_broadcast_op_logic.cc   |  2 +-
 13 files changed, 261 insertions(+), 32 deletions(-)

diff --git a/python/mxnet/base.py b/python/mxnet/base.py
index 5882a50..cbc36d3 100644
--- a/python/mxnet/base.py
+++ b/python/mxnet/base.py
@@ -366,7 +366,7 @@ def _as_list(obj):
         return [obj]
 
 
-_OP_NAME_PREFIX_LIST = ['_contrib_', '_linalg_', '_sparse_']
+_OP_NAME_PREFIX_LIST = ['_contrib_', '_linalg_', '_sparse_', '_image_']
 
 
 def _get_op_name_prefix(op_name):
@@ -420,10 +420,11 @@ def _init_op_module(root_namespace, module_name, make_op_func):
         hdl = OpHandle()
         check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl)))
         op_name_prefix = _get_op_name_prefix(name)
+        module_name_local = module_name
         if len(op_name_prefix) > 0:
             func_name = name[len(op_name_prefix):]
             cur_module = submodule_dict[op_name_prefix]
-            module_name = "%s.%s.%s" % (root_namespace, module_name, op_name_prefix[1:-1])
+            module_name_local = "%s.%s.%s" % (root_namespace, module_name, op_name_prefix[1:-1])
         elif name.startswith('_'):
             func_name = name
             cur_module = module_internal
@@ -432,10 +433,11 @@ def _init_op_module(root_namespace, module_name, make_op_func):
             cur_module = module_op
 
         function = make_op_func(hdl, name, func_name)
-        function.__module__ = module_name
+        function.__module__ = module_name_local
         setattr(cur_module, function.__name__, function)
         cur_module.__all__.append(function.__name__)
 
+
         if op_name_prefix == '_contrib_':
             hdl = OpHandle()
             check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl)))
diff --git a/python/mxnet/ndarray/__init__.py b/python/mxnet/ndarray/__init__.py
index 586dc9e..86a3a20 100644
--- a/python/mxnet/ndarray/__init__.py
+++ b/python/mxnet/ndarray/__init__.py
@@ -17,7 +17,7 @@
 
 """NDArray API of MXNet."""
 
-from . import _internal, contrib, linalg, op, random, sparse, utils
+from . import _internal, contrib, linalg, op, random, sparse, utils, image
 # pylint: disable=wildcard-import, redefined-builtin
 try:
     from .gen_op import * # pylint: disable=unused-wildcard-import
@@ -31,4 +31,4 @@ from .utils import load, save, zeros, empty, array
 from .sparse import _ndarray_cls
 from .ndarray import _GRAD_REQ_MAP
 
-__all__ = op.__all__ + ndarray.__all__ + utils.__all__ + ['contrib', 'linalg', 'random', 'sparse']
+__all__ = op.__all__ + ndarray.__all__ + utils.__all__ + ['contrib', 'linalg', 'random', 'sparse', 'image']
diff --git a/python/mxnet/symbol/__init__.py b/python/mxnet/ndarray/image.py
similarity index 67%
copy from python/mxnet/symbol/__init__.py
copy to python/mxnet/ndarray/image.py
index a07025e..0afab24 100644
--- a/python/mxnet/symbol/__init__.py
+++ b/python/mxnet/ndarray/image.py
@@ -15,17 +15,12 @@
 # specific language governing permissions and limitations
 # under the License.
 
-"""Symbol API of MXNet."""
-
-from . import _internal, contrib, linalg, op, random, sparse
-# pylint: disable=wildcard-import, redefined-builtin
+# coding: utf-8
+# pylint: disable=wildcard-import, unused-wildcard-import
+"""Image NDArray API of MXNet."""
 try:
-    from .gen_op import * # pylint: disable=unused-wildcard-import
+    from .gen_iamge import *  # NOTE(review): 'gen_iamge' looks like a typo for 'gen_image'; the ImportError handler below would hide a wrong module name - confirm
 except ImportError:
     pass
-from . import register
-from .op import *
-from .symbol import *
-# pylint: enable=wildcard-import
 
-__all__ = op.__all__ + symbol.__all__ + ['contrib', 'linalg', 'random', 'sparse']
+__all__ = []
diff --git a/python/mxnet/symbol/__init__.py b/python/mxnet/symbol/__init__.py
index a07025e..a10b64e 100644
--- a/python/mxnet/symbol/__init__.py
+++ b/python/mxnet/symbol/__init__.py
@@ -17,7 +17,7 @@
 
 """Symbol API of MXNet."""
 
-from . import _internal, contrib, linalg, op, random, sparse
+from . import _internal, contrib, linalg, op, random, sparse, image
 # pylint: disable=wildcard-import, redefined-builtin
 try:
     from .gen_op import * # pylint: disable=unused-wildcard-import
@@ -28,4 +28,4 @@ from .op import *
 from .symbol import *
 # pylint: enable=wildcard-import
 
-__all__ = op.__all__ + symbol.__all__ + ['contrib', 'linalg', 'random', 'sparse']
+__all__ = op.__all__ + symbol.__all__ + ['contrib', 'linalg', 'random', 'sparse', 'image']
diff --git a/python/mxnet/symbol/__init__.py b/python/mxnet/symbol/image.py
similarity index 67%
copy from python/mxnet/symbol/__init__.py
copy to python/mxnet/symbol/image.py
index a07025e..7624bcc 100644
--- a/python/mxnet/symbol/__init__.py
+++ b/python/mxnet/symbol/image.py
@@ -15,17 +15,12 @@
 # specific language governing permissions and limitations
 # under the License.
 
-"""Symbol API of MXNet."""
-
-from . import _internal, contrib, linalg, op, random, sparse
-# pylint: disable=wildcard-import, redefined-builtin
+# coding: utf-8
+# pylint: disable=wildcard-import, unused-wildcard-import
+"""Image Symbol API of MXNet."""
 try:
-    from .gen_op import * # pylint: disable=unused-wildcard-import
+    from .gen_iamge import *  # NOTE(review): 'gen_iamge' looks like a typo for 'gen_image'; the ImportError handler below would hide a wrong module name - confirm
 except ImportError:
     pass
-from . import register
-from .op import *
-from .symbol import *
-# pylint: enable=wildcard-import
 
-__all__ = op.__all__ + symbol.__all__ + ['contrib', 'linalg', 'random', 'sparse']
+__all__ = []
diff --git a/src/operator/batch_norm_v1-inl.h b/src/operator/batch_norm_v1-inl.h
index ebfc469..e613b21 100644
--- a/src/operator/batch_norm_v1-inl.h
+++ b/src/operator/batch_norm_v1-inl.h
@@ -18,7 +18,7 @@
  */
 
 /*!
- * \file batch_norm-inl_v1.h
+ * \file batch_norm_v1-inl.h
  * \brief
  * \author Bing Xu
 */
diff --git a/src/operator/image/image_common.h b/src/operator/image/image_common.h
new file mode 100644
index 0000000..7cf3f96
--- /dev/null
+++ b/src/operator/image/image_common.h
@@ -0,0 +1,88 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*/
+
+/*!
+* \file image_common.h
+* \brief Shared helper for image operators: builds a cv::Mat over the data of a TBlob.
+* \author
+*/
+#ifndef MXNET_OPERATOR_IMAGE_IMAGE_COMMON_H_
+#define MXNET_OPERATOR_IMAGE_IMAGE_COMMON_H_
+
+#include <mxnet/base.h>
+
+namespace mxnet {
+namespace op {
+
+/**
+* @brief Build a cv::Mat header over the data pointer of a TBlob, picking the OpenCV depth from its dtype.
+* @param input blob whose type_flag_ selects the depth and whose dptr<DType>() supplies the data pointer
+* @param hight passed to cv::Mat as the number of rows (image height)
+* @param weight passed to cv::Mat as the number of columns (image width)
+* @param channel channel count, combined with the depth through CV_MAKETYPE
+* @return the cv::Mat built for float32/float64/uint8/int8/int32; every other dtype goes to LOG(FATAL)
+*/
+static cv::Mat mat_convert(TBlob input, int hight, int weight, int channel) {  // NOTE(review): uses cv::Mat but this header only includes <mxnet/base.h>; relies on the includer for OpenCV
+  cv::Mat m;  // stays default-constructed on the branches that only call LOG(FATAL)
+  switch (input.type_flag_) {
+    case mshadow::kFloat32: {
+      typedef float DType;
+      m = cv::Mat(hight, weight, CV_MAKETYPE(CV_32F, channel), input.dptr<DType>());
+    }
+    break;
+    case mshadow::kFloat64: {
+      typedef double DType;
+      m = cv::Mat(hight, weight, CV_MAKETYPE(CV_64F, channel), input.dptr<DType>());
+    }
+    break;
+    case mshadow::kFloat16: {
+      typedef mshadow::half::half_t DType;  // typedef is unused: no cv::Mat is built for this dtype
+      LOG(FATAL) << "not support type enum " << input.type_flag_;
+    }
+    break;
+    case mshadow::kUint8: {
+      typedef uint8_t DType;
+      m = cv::Mat(hight, weight, CV_MAKETYPE(CV_8U, channel), input.dptr<DType>());
+    }
+    break;
+    case mshadow::kInt8: {
+      typedef int8_t DType;
+      m = cv::Mat(hight, weight, CV_MAKETYPE(CV_8S, channel), input.dptr<DType>());
+    }
+    break;
+    case mshadow::kInt32: {
+      typedef int32_t DType;
+      m = cv::Mat(hight, weight, CV_MAKETYPE(CV_32S, channel), input.dptr<DType>());
+    }
+    break;
+    case mshadow::kInt64: {
+      typedef int64_t DType;  // typedef is unused: no cv::Mat is built for this dtype
+      LOG(FATAL) << "not support type enum " << input.type_flag_;
+    }
+    break;
+    default:
+      LOG(FATAL) << "Unknown type enum " << input.type_flag_;
+  }
+  return m;
+}
+} // namespace op
+} // namespace mxnet
+
+
+#endif // MXNET_OPERATOR_IMAGE_IMAGE_COMMON_H_
\ No newline at end of file
diff --git a/src/operator/image/image_random-inl.h b/src/operator/image/image_random-inl.h
new file mode 100644
index 0000000..027d587
--- /dev/null
+++ b/src/operator/image/image_random-inl.h
@@ -0,0 +1,99 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*/
+
+/*!
+* \file image_random-inl.h
+* \brief Parameter struct and compute functions for random brightness and contrast image operators.
+* \author
+*/
+#ifndef MXNET_OPERATOR_IMAGE_IMAGE_RANDOM_INL_H_
+#define MXNET_OPERATOR_IMAGE_IMAGE_RANDOM_INL_H_
+
+#include <vector>
+#include <mxnet/base.h>
+#include <opencv2/opencv.hpp>
+#include <opencv2/core/mat.hpp>
+#include "mxnet/op_attr_types.h"
+#include "image_common.h"
+
+
+namespace mxnet {
+namespace op {
+struct RandomBrightnessParam : public dmlc::Parameter<RandomBrightnessParam> {  // parameters read from attrs.parsed by the compute functions in this file
+  float max_brightness;  // half-width of the uniform range sampled around 1.0 for the scale factor
+  DMLC_DECLARE_PARAMETER(RandomBrightnessParam) {
+    DMLC_DECLARE_FIELD(max_brightness)
+    .set_default(0.0)
+    .describe("Max Contrast.");  // NOTE(review): help text says "Contrast" but the field is max_brightness - confirm wording
+  }
+};
+
+
+template<typename xpu>  // compute function: scales the output buffer by a factor drawn uniformly around 1.0
+static void RandomBrightness(const nnvm::NodeAttrs &attrs,
+                             const OpContext &ctx,
+                             const std::vector<TBlob> &inputs,
+                             const std::vector<OpReqType> &req,
+                             const std::vector<TBlob> &outputs) {
+  auto input = inputs[0];
+  auto output = outputs[0];
+  int hight = input.shape_[0];  // dims 0/1/2 are used as rows/cols/channels - assumes an HWC image, TODO confirm
+  int weight = input.shape_[1];
+  int channel = input.shape_[2];
+
+  auto input_mat = mat_convert(input, hight, weight, channel);  // NOTE(review): input_mat is never read in this function
+  auto output_mat = mat_convert(output, hight, weight, channel);
+  //input_mat.convertTo(output_mat, -1, 1/255.0, 0);
+  std::default_random_engine generator;  // NOTE(review): default-constructed on every call with no seed supplied, so each call starts from the same engine state - confirm intended
+  const RandomBrightnessParam &param = nnvm::get<RandomBrightnessParam>(attrs.parsed);
+  float alpha_b = 1.0 + std::uniform_real_distribution<float>(-param.max_brightness, param.max_brightness)(generator);
+  output_mat.convertTo(output_mat, -1, alpha_b, 0);  // in place: output_mat is source and destination; NOTE(review): nothing here copies the input into the output first - confirm
+}
+
+
+template<typename xpu>  // compute function: blends the input with its mean gray level using a factor drawn uniformly around 1.0
+static void RandomContrast(const nnvm::NodeAttrs &attrs,
+  const OpContext &ctx,
+  const std::vector<TBlob> &inputs,
+  const std::vector<OpReqType> &req,
+  const std::vector<TBlob> &outputs) {
+  auto input = inputs[0];
+  auto output = outputs[0];
+  int hight = input.shape_[0];  // dims 0/1/2 are used as rows/cols/channels - assumes an HWC image, TODO confirm
+  int weight = input.shape_[1];
+  int channel = input.shape_[2];
+
+  auto input_mat = mat_convert(input, hight, weight, channel);
+  auto output_mat = mat_convert(output, hight, weight, channel);
+  //input_mat.convertTo(output_mat, -1, 1/255.0, 0);
+  std::default_random_engine generator;  // NOTE(review): default-constructed on every call with no seed supplied, so each call starts from the same engine state - confirm intended
+  const RandomBrightnessParam &param = nnvm::get<RandomBrightnessParam>(attrs.parsed);  // NOTE(review): reuses the brightness parameter struct; max_brightness sets the contrast range here
+  float alpha_c = 1.0 + std::uniform_real_distribution<float>(-param.max_brightness, param.max_brightness)(generator);
+  cv::Mat temp_;  // receives the grayscale conversion; only its mean is used
+  cv::cvtColor(input_mat, temp_,  CV_RGB2GRAY);  // assumes a 3-channel RGB input - TODO confirm
+  float gray_mean = cv::mean(temp_)[0];  // element 0 of the per-channel mean, i.e. the single gray channel
+  input_mat.convertTo(output_mat, -1, alpha_c, (1 - alpha_c) * gray_mean);  // output = alpha_c * input + (1 - alpha_c) * gray_mean, keeping the input depth (rtype -1)
+
+}
+
+
+} // namespace op
+} // namespace mxnet
+
+#endif  // MXNET_OPERATOR_IMAGE_IMAGE_RANDOM_INL_H_
diff --git a/src/operator/image/image_random.cc b/src/operator/image/image_random.cc
new file mode 100644
index 0000000..3777e43
--- /dev/null
+++ b/src/operator/image/image_random.cc
@@ -0,0 +1,50 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*/
+
+/*!
+* \file image_random.cc
+* \brief CPU registration of the _image_random_brightness operator.
+* \author
+*/
+
+#include <mxnet/base.h>
+#include "./image_random-inl.h"
+#include "operator/operator_common.h"
+#include "operator/elemwise_op_common.h"
+
+
+
+namespace mxnet {
+namespace op {
+
+DMLC_REGISTER_PARAMETER(RandomBrightnessParam);
+NNVM_REGISTER_OP(_image_random_brightness)  // the '_image_' prefix matches the entry added to _OP_NAME_PREFIX_LIST in python/mxnet/base.py
+.describe(R"code()code" ADD_FILELINE)  // NOTE(review): the operator description string is empty
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<RandomBrightnessParam>)
+.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<FCompute>("FCompute<cpu>", RandomBrightness<cpu>)  // only a cpu compute function is registered here
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{ "_copy" })  // NOTE(review): gradient is "_copy" although the forward pass scales by a sampled factor - confirm
+.add_argument("data", "NDArray-or-Symbol", "The input.")
+.add_arguments(RandomBrightnessParam::__FIELDS__());
+
+}
+}
diff --git a/src/operator/random/multisample_op.h b/src/operator/random/multisample_op.h
index f0851da..a2382f6 100644
--- a/src/operator/random/multisample_op.h
+++ b/src/operator/random/multisample_op.h
@@ -18,7 +18,7 @@
  */
 
 /*!
- * \file sampling_op.h
+ * \file multisample_op.h
  * \brief Function definitions of operators for sampling from multiple distributions
  */
 #ifndef MXNET_OPERATOR_RANDOM_MULTISAMPLE_OP_H_
diff --git a/src/operator/tensor/broadcast_reduce_op_index.cc b/src/operator/tensor/broadcast_reduce_op_index.cc
index 98cd736..3d4dd79 100644
--- a/src/operator/tensor/broadcast_reduce_op_index.cc
+++ b/src/operator/tensor/broadcast_reduce_op_index.cc
@@ -18,7 +18,7 @@
  */
 
 /*!
- * \file broadcast_reduce_op.cc
+ * \file broadcast_reduce_op_index.cc
  * \brief CPU Implementation of broadcast and reduce functions.
  */
 #include "./broadcast_reduce_op.h"
diff --git a/src/operator/tensor/elemwise_binary_broadcast_op_extended.cc b/src/operator/tensor/elemwise_binary_broadcast_op_extended.cc
index 42da191..d9111c3 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op_extended.cc
+++ b/src/operator/tensor/elemwise_binary_broadcast_op_extended.cc
@@ -18,7 +18,7 @@
  */
 
 /*!
- * \file elemwise_binary_scalar_op.cc
+ * \file elemwise_binary_broadcast_op_extended.cc
  * \brief CPU Implementation of unary function.
  */
 #include "./elemwise_unary_op.h"
diff --git a/src/operator/tensor/elemwise_binary_broadcast_op_logic.cc b/src/operator/tensor/elemwise_binary_broadcast_op_logic.cc
index 957b00b..1ead6a2 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op_logic.cc
+++ b/src/operator/tensor/elemwise_binary_broadcast_op_logic.cc
@@ -18,7 +18,7 @@
  */
 
 /*!
- * \file elemwise_binary_scalar_op.cc
+ * \file elemwise_binary_broadcast_op_logic.cc
  * \brief CPU Implementation of unary function.
  */
 #include "./elemwise_unary_op.h"

-- 
To stop receiving notification emails like this one, please contact
['"commits@mxnet.apache.org" <co...@mxnet.apache.org>'].