You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2018/01/22 20:43:45 UTC
[incubator-mxnet] 14/20: image flip op (#8759)
This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
commit a15a6e7c491912c84b588ae7aa606ae8d4f48df9
Author: Yizhi Liu <ja...@gmail.com>
AuthorDate: Sun Nov 26 17:20:55 2017 -0800
image flip op (#8759)
* image flip op
* rm image_common.h
* fix
* lint code
* flip optimize
---
src/operator/image/image_random-inl.h | 66 +++++++++++++++++++++++++++++++++--
src/operator/image/image_random.cc | 16 +++++++++
2 files changed, 79 insertions(+), 3 deletions(-)
diff --git a/src/operator/image/image_random-inl.h b/src/operator/image/image_random-inl.h
index ebbf60a..5c552b2 100644
--- a/src/operator/image/image_random-inl.h
+++ b/src/operator/image/image_random-inl.h
@@ -28,14 +28,20 @@
#include <mxnet/base.h>
#include <algorithm>
#include <vector>
-#include <opencv2/opencv.hpp>
-#include <opencv2/core/mat.hpp>
+#include <algorithm>
+#include <utility>
#include "../mxnet_op.h"
#include "../operator_common.h"
namespace mxnet {
namespace op {
+/*!
+ * \brief Assert that a blob has the layout the image operators expect.
+ *
+ * Three properties are checked: the element type is uint8, the blob is
+ * 3-dimensional, and the last dimension (the channel count) is 1 or 3.
+ *
+ * \param image blob to validate.
+ * \return true when every check passed. A violated check is reported through
+ *         CHECK / CHECK_EQ with the messages below and does not yield `false`.
+ */
+inline bool CheckIsImage(const TBlob &image) {
+  CHECK_EQ(image.type_flag_, mshadow::kUint8) << "input type is not an image.";
+  CHECK_EQ(image.ndim(), 3) << "input dimension is not 3.";
+  CHECK(image.shape_[2] == 1 || image.shape_[2] == 3) << "image channel should be 1 or 3.";
+  // The function is declared to return bool; flowing off the end without a
+  // return statement is undefined behaviour, so report success explicitly.
+  return true;
+}
+
static void RandomFlip(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
@@ -76,6 +82,7 @@ static void ToTensor(const nnvm::NodeAttrs &attrs,
const std::vector<TBlob> &outputs) {
CHECK_EQ(req[0], kWriteTo)
<< "`to_tensor` does not support inplace";
+ CheckIsImage(inputs[0]);
int length = inputs[0].shape_[0] * inputs[0].shape_[1];
int channel = inputs[0].shape_[2];
@@ -101,7 +108,6 @@ struct NormalizeParam : public dmlc::Parameter<NormalizeParam> {
}
};
-
inline bool NormalizeShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
@@ -145,6 +151,60 @@ static void Normalize(const nnvm::NodeAttrs &attrs,
});
}
+/*!
+ * \brief Parameters of the image flip operator.
+ *
+ * `axis` selects the flip direction: 0 mirrors the image horizontally,
+ * 1 mirrors it vertically.
+ */
+struct FlipParam : public dmlc::Parameter<FlipParam> {
+  int axis;
+  DMLC_DECLARE_PARAMETER(FlipParam) {
+    // Restrict the field to [0, 1] so an unsupported axis is rejected while
+    // the parameters are parsed instead of only inside the compute function.
+    DMLC_DECLARE_FIELD(axis)
+    .set_range(0, 1)
+    .describe("0 or 1. 0 for horizontal flip, 1 for vertical flip.");
+  }
+};
+
+// Moves one element from `src[src_idx]` to `dst[dst_idx]`.
+// When `dst` and `src` are the same pointer the two elements are exchanged
+// (an in-place operation must not lose the overwritten value); otherwise the
+// source element is simply copied.
+//
+// The body is wrapped in do { } while (0) so the macro behaves as a single
+// statement: `if (c) SWAP_IF_INPLACE(...); else ...` parses as intended.
+// The pointer arguments are parenthesised so expressions can be passed safely.
+#define SWAP_IF_INPLACE(dst, dst_idx, src, src_idx)  \
+  do {                                               \
+    if ((dst) == (src)) {                            \
+      std::swap((dst)[dst_idx], (src)[src_idx]);     \
+    } else {                                         \
+      (dst)[dst_idx] = (src)[src_idx];               \
+    }                                                \
+  } while (0)
+
+/*!
+ * \brief Mirror an image stored as a contiguous (height, width, channel) array.
+ *
+ * \param shape image shape; shape[0] is the height, shape[1] the width and
+ *              shape[2] the number of channels.
+ * \param src   source pixels.
+ * \param dst   destination pixels; may be the same pointer as `src`, in which
+ *              case the flip is performed in place by swapping.
+ * \param axis  0 mirrors every row left-to-right, any other value mirrors the
+ *              rows top-to-bottom (callers pass 0 or 1).
+ */
+template<typename DType>
+static void FlipImpl(const TShape &shape, DType *src, DType *dst, int axis) {
+  const int height = shape[0];
+  const int width = shape[1];
+  const int nchannel = shape[2];
+
+  // Number of elements in one image row.
+  const int row_len = width * nchannel;
+  const bool inplace = (src == dst);
+  // In place, every swap settles two pixels at once, so only the first half of
+  // the flipped axis is visited; going further would undo the swaps. With an
+  // odd extent the middle row/column maps onto itself and needs no work.
+  const int h_end = (inplace && axis == 1) ? (height >> 1) : height;
+  const int w_end = (inplace && axis == 0) ? (width >> 1) : width;
+
+  for (int h = 0; h < h_end; ++h) {
+    // The mirror of index i in an extent of n is n - 1 - i. Using n - i would
+    // address one row/pixel past the end and never write index 0.
+    const int h_dst = (axis == 0) ? h : (height - 1 - h);
+    for (int w = 0; w < w_end; ++w) {
+      const int w_dst = (axis == 0) ? (width - 1 - w) : w;
+      DType *out = dst + h_dst * row_len + w_dst * nchannel;
+      DType *in = src + h * row_len + w * nchannel;
+      // Move every channel of the pixel, whatever the channel count is.
+      for (int c = 0; c < nchannel; ++c) {
+        if (inplace) {
+          std::swap(out[c], in[c]);
+        } else {
+          out[c] = in[c];
+        }
+      }
+    }
+  }
+}
+
+/*!
+ * \brief Compute function of the image flip operator.
+ *
+ * Reads the flip axis from the parsed FlipParam, validates that it is 0 or 1
+ * and that the first input passes CheckIsImage, then runs FlipImpl from
+ * inputs[0] into outputs[0], dispatching on the output element type.
+ *
+ * \param attrs   node attributes; attrs.parsed holds the FlipParam.
+ * \param ctx     operator context (not used here).
+ * \param inputs  inputs[0] is the image to flip.
+ * \param req     write request types (not inspected here).
+ * \param outputs outputs[0] receives the flipped image.
+ */
+static void Flip(const nnvm::NodeAttrs &attrs,
+                 const OpContext &ctx,
+                 const std::vector<TBlob> &inputs,
+                 const std::vector<OpReqType> &req,
+                 const std::vector<TBlob> &outputs) {
+  const FlipParam &param = nnvm::get<FlipParam>(attrs.parsed);
+  CHECK(param.axis == 0 || param.axis == 1) << "flip axis must be 0 or 1.";
+  CheckIsImage(inputs[0]);
+  const TShape& ishape = inputs[0].shape_;
+  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+    FlipImpl(ishape, inputs[0].dptr<DType>(), outputs[0].dptr<DType>(), param.axis);
+  });
+}
+
struct RandomBrightnessParam : public dmlc::Parameter<RandomBrightnessParam> {
float max_brightness;
DMLC_DECLARE_PARAMETER(RandomBrightnessParam) {
diff --git a/src/operator/image/image_random.cc b/src/operator/image/image_random.cc
index 5b47f50..4184382 100644
--- a/src/operator/image/image_random.cc
+++ b/src/operator/image/image_random.cc
@@ -59,6 +59,22 @@ NNVM_REGISTER_OP(_image_normalize)
.add_argument("data", "NDArray-or-Symbol", "The input.")
.add_arguments(NormalizeParam::__FIELDS__());
+// Registers FlipParam and the CPU `_image_flip` operator: one input, one
+// output of the same shape and type, optionally computed in place on input 0.
+DMLC_REGISTER_PARAMETER(FlipParam);
+NNVM_REGISTER_OP(_image_flip)
+.describe(R"code(Flip an image along one axis.
+
+The input must be a 3-D uint8 array laid out as (height, width, channel) with
+1 or 3 channels. `axis=0` flips horizontally, `axis=1` flips vertically.
+)code" ADD_FILELINE)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<FlipParam>)
+.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+  [](const NodeAttrs& attrs){
+    return std::vector<std::pair<int, int> >{{0, 0}};
+  })
+.set_attr<FCompute>("FCompute<cpu>", Flip)
+// A flip only permutes pixels and is its own inverse, so the gradient with
+// respect to the input is the output gradient flipped along the same axis,
+// not an unpermuted copy of it.
+// NOTE(review): relies on ElemwiseGradUseNone forwarding this node's
+// attributes (the `axis` value) to the gradient node - confirm.
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{ "_image_flip" })
+.add_argument("data", "NDArray-or-Symbol", "The input.")
+.add_arguments(FlipParam::__FIELDS__());
DMLC_REGISTER_PARAMETER(RandomBrightnessParam);
NNVM_REGISTER_OP(_image_random_brightness)
--
To stop receiving notification emails like this one, please contact
jxie@apache.org.