You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2018/01/19 00:35:21 UTC

[incubator-mxnet] 14/19: image flip op (#8759)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch vision
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit e97ef363bbae809f39c07415a6c6740b884c76f6
Author: Yizhi Liu <ja...@gmail.com>
AuthorDate: Sun Nov 26 17:20:55 2017 -0800

    image flip op (#8759)
    
    * image flip op
    
    * rm image_common.h
    
    * fix
    
    * lint code
    
    * flip optimize
---
 src/operator/image/image_random-inl.h | 66 +++++++++++++++++++++++++++++++++--
 src/operator/image/image_random.cc    | 16 +++++++++
 2 files changed, 79 insertions(+), 3 deletions(-)

diff --git a/src/operator/image/image_random-inl.h b/src/operator/image/image_random-inl.h
index ebbf60a..5c552b2 100644
--- a/src/operator/image/image_random-inl.h
+++ b/src/operator/image/image_random-inl.h
@@ -28,14 +28,27 @@
 #include <mxnet/base.h>
 #include <algorithm>
 #include <vector>
-#include <opencv2/opencv.hpp>
-#include <opencv2/core/mat.hpp>
+#include <utility>
 #include "../mxnet_op.h"
 #include "../operator_common.h"
 
 namespace mxnet {
 namespace op {
 
+/*!
+ * \brief Check that a blob holds a single image.
+ *
+ * Requires a 3-D uint8 blob whose last dimension (channels) is 1 or 3.
+ * \param image the blob to validate.
+ * \return true. Invalid input fails a CHECK instead of returning false.
+ */
+inline bool CheckIsImage(const TBlob &image) {
+  CHECK_EQ(image.type_flag_, mshadow::kUint8) << "input type is not an image.";
+  CHECK_EQ(image.ndim(), 3) << "input dimension is not 3.";
+  CHECK(image.shape_[2] == 1 || image.shape_[2] == 3) << "image channel should be 1 or 3.";
+  return true;
+}
+
 static void RandomFlip(const nnvm::NodeAttrs &attrs,
                        const OpContext &ctx,
                        const std::vector<TBlob> &inputs,
@@ -76,6 +82,7 @@ static void ToTensor(const nnvm::NodeAttrs &attrs,
                      const std::vector<TBlob> &outputs) {
   CHECK_EQ(req[0], kWriteTo)
     << "`to_tensor` does not support inplace";
+  CheckIsImage(inputs[0]);
 
   int length = inputs[0].shape_[0] * inputs[0].shape_[1];
   int channel = inputs[0].shape_[2];
@@ -101,7 +108,6 @@ struct NormalizeParam : public dmlc::Parameter<NormalizeParam> {
   }
 };
 
-
 inline bool NormalizeShape(const nnvm::NodeAttrs& attrs,
                           std::vector<TShape> *in_attrs,
                           std::vector<TShape> *out_attrs) {
@@ -145,6 +151,88 @@ static void Normalize(const nnvm::NodeAttrs &attrs,
   });
 }
 
+/*! \brief Parameters of the `_image_flip` operator. */
+struct FlipParam : public dmlc::Parameter<FlipParam> {
+  int axis;
+  DMLC_DECLARE_PARAMETER(FlipParam) {
+    DMLC_DECLARE_FIELD(axis)
+    .describe("0 or 1. 0 for horizontal flip, 1 for vertical flip.");
+  }
+};
+
+/*!
+ * \brief Store one element of a flip.
+ *
+ * In place (dst == src) the two elements are swapped, so the caller must
+ * visit each mirrored pair exactly once; otherwise the element is copied.
+ */
+template<typename DType>
+inline void FlipAssign(DType *dst, int dst_idx, DType *src, int src_idx) {
+  if (dst == src) {
+    std::swap(dst[dst_idx], src[src_idx]);
+  } else {
+    dst[dst_idx] = src[src_idx];
+  }
+}
+
+/*!
+ * \brief Mirror an image stored as (height, width, channels).
+ *
+ * \param shape (height, width, channels) of both src and dst.
+ * \param src input pixels.
+ * \param dst output pixels; may alias src for an in-place flip.
+ * \param axis 0 mirrors columns (horizontal flip), 1 mirrors rows
+ *             (vertical flip).
+ */
+template<typename DType>
+static void FlipImpl(const TShape &shape, DType *src, DType *dst, int axis) {
+  const int height = shape[0];
+  const int width = shape[1];
+  const int nchannel = shape[2];
+
+  const int length = width * nchannel;  // elements per row
+  // In place, only half of the flipped dimension is walked: each step swaps
+  // a mirrored pair, and the middle row/column of an odd size stays put.
+  const int height_stride = (src == dst && axis == 1) ? (height >> 1) : height;
+  const int width_stride = (src == dst && axis == 0) ? (width >> 1) : width;
+
+  for (int h = 0; h < height_stride; ++h) {
+    // The mirror of index i is (size - 1 - i); (size - i) is out of bounds.
+    const int h_dst = (axis == 0) ? h : (height - 1 - h);
+    for (int w = 0; w < width_stride; ++w) {
+      const int w_dst = (axis == 0) ? (width - 1 - w) : w;
+      const int idx_dst = h_dst * length + w_dst * nchannel;
+      const int idx_src = h * length + w * nchannel;
+      // Move every channel of the pixel rather than assuming 1 or 3.
+      for (int c = 0; c < nchannel; ++c) {
+        FlipAssign(dst, idx_dst + c, src, idx_src + c);
+      }
+    }
+  }
+}
+
+/*!
+ * \brief CPU FCompute kernel of the `_image_flip` operator.
+ *
+ * Writes inputs[0] flipped along FlipParam::axis into outputs[0]; the
+ * input must pass CheckIsImage. The output may alias the input.
+ */
+static void Flip(const nnvm::NodeAttrs &attrs,
+                  const OpContext &ctx,
+                  const std::vector<TBlob> &inputs,
+                  const std::vector<OpReqType> &req,
+                  const std::vector<TBlob> &outputs) {
+  if (req[0] == kNullOp) return;
+  CHECK_NE(req[0], kAddTo) << "`flip` does not support kAddTo";
+  const FlipParam &param = nnvm::get<FlipParam>(attrs.parsed);
+  CHECK(param.axis == 0 || param.axis == 1) << "flip axis must be 0 or 1.";
+  CheckIsImage(inputs[0]);
+  const TShape& ishape = inputs[0].shape_;
+  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+    FlipImpl(ishape, inputs[0].dptr<DType>(), outputs[0].dptr<DType>(), param.axis);
+  });
+}
+
 struct RandomBrightnessParam : public dmlc::Parameter<RandomBrightnessParam> {
   float max_brightness;
   DMLC_DECLARE_PARAMETER(RandomBrightnessParam) {
diff --git a/src/operator/image/image_random.cc b/src/operator/image/image_random.cc
index 5b47f50..4184382 100644
--- a/src/operator/image/image_random.cc
+++ b/src/operator/image/image_random.cc
@@ -59,6 +59,31 @@ NNVM_REGISTER_OP(_image_normalize)
 .add_argument("data", "NDArray-or-Symbol", "The input.")
 .add_arguments(NormalizeParam::__FIELDS__());
 
+DMLC_REGISTER_PARAMETER(FlipParam);
+NNVM_REGISTER_OP(_image_flip)
+.describe(R"code(Flip an image of shape (height, width, channels).
+
+axis=0 mirrors the width dimension (horizontal flip); axis=1 mirrors the
+height dimension (vertical flip).
+)code" ADD_FILELINE)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<FlipParam>)
+.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
+// Output may overwrite the input; Flip handles aliasing by swapping pairs.
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+                                [](const NodeAttrs& attrs){
+                                  return std::vector<std::pair<int, int> >{{0, 0}};
+                                })
+.set_attr<FCompute>("FCompute<cpu>", Flip)
+// A flip is a permutation of elements, so its gradient is the output
+// gradient flipped along the same axis, not an identity copy.
+// NOTE(review): relies on ElemwiseGradUseNone forwarding this node's attrs
+// (axis) to the gradient node -- confirm.
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{ "_image_flip" })
+.add_argument("data", "NDArray-or-Symbol", "The input.")
+.add_arguments(FlipParam::__FIELDS__());
 
 DMLC_REGISTER_PARAMETER(RandomBrightnessParam);
 NNVM_REGISTER_OP(_image_random_brightness)

-- 
To stop receiving notification emails like this one, please contact
"commits@mxnet.apache.org" <co...@mxnet.apache.org>.