You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2017/12/12 21:28:45 UTC

[GitHub] piiswrong closed pull request #8919: [Vision] add test cases for flip, normalize, to_tensor

piiswrong closed pull request #8919: [Vision] add test cases for flip, normalize, to_tensor
URL: https://github.com/apache/incubator-mxnet/pull/8919
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

Because this pull request comes from a fork, GitHub does not show its diff
after the merge, so the diff is reproduced below:

diff --git a/python/mxnet/gluon/data/vision/transforms.py b/python/mxnet/gluon/data/vision/transforms.py
index 8daf88e6f4..38eb6902dc 100644
--- a/python/mxnet/gluon/data/vision/transforms.py
+++ b/python/mxnet/gluon/data/vision/transforms.py
@@ -184,26 +184,26 @@ def forward(self, x):
         return image.imresize(x, *self._args)
 
 
-class RandomHorizontalFlip(HybridBlock):
-    """Randomly flip the input image horizontally with a probability
+class RandomFlipLeftRight(HybridBlock):
+    """Randomly flip the input image left to right with a probability
     of 0.5.
     """
     def __init__(self):
-        super(RandomHorizontalFlip, self).__init__()
+        super(RandomFlipLeftRight, self).__init__()
 
     def hybrid_forward(self, F, x):
-        return F.image.random_horizontal_flip(x)
+        return F.image.random_flip_left_right(x)
 
 
-class RandomVerticalFlip(HybridBlock):
-    """Randomly flip the input image vertically with a probability
+class RandomFlipTopBottom(HybridBlock):
+    """Randomly flip the input image top to bottom with a probability
     of 0.5.
     """
     def __init__(self):
-        super(RandomVerticalFlip, self).__init__()
+        super(RandomFlipTopBottom, self).__init__()
 
     def hybrid_forward(self, F, x):
-        return F.image.random_vertical_flip(x)
+        return F.image.random_flip_top_bottom(x)
 
 
 class RandomBrightness(HybridBlock):
diff --git a/src/operator/image/image_random-inl.h b/src/operator/image/image_random-inl.h
index cbc7f4012b..ec961492f9 100644
--- a/src/operator/image/image_random-inl.h
+++ b/src/operator/image/image_random-inl.h
@@ -85,14 +85,6 @@ void ToTensor(const nnvm::NodeAttrs &attrs,
   });
 }
 
-inline bool TensorShape(const nnvm::NodeAttrs& attrs,
-                       std::vector<TShape> *in_attrs,
-                       std::vector<TShape> *out_attrs) {
-  TShape& dshape = (*in_attrs)[0];
-  SHAPE_ASSIGN_CHECK(*out_attrs, 0, dshape);
-  return true;
-}
-
 struct NormalizeParam : public dmlc::Parameter<NormalizeParam> {
   nnvm::Tuple<float> mean;
   nnvm::Tuple<float> std;
@@ -179,16 +171,16 @@ inline bool ImageShape(const nnvm::NodeAttrs& attrs,
   return true;
 }
 
-template<typename DType>
-void FlipImpl(const TShape &shape, DType *src, DType *dst, int axis) {
+template<typename DType, int axis>
+void FlipImpl(const TShape &shape, DType *src, DType *dst) {
   int head = 1, mid = shape[axis], tail = 1;
   for (int i = 0; i < axis; ++i) head *= shape[i];
   for (int i = axis+1; i < shape.ndim(); ++i) tail *= shape[i];
 
   for (int i = 0; i < head; ++i) {
-    for (int j = 0; j < (mid >>2); ++j) {
-      int idx1 = (i*mid + j)*tail;
-      int idx2 = idx1 + (mid - (j<<2))*tail;
+    for (int j = 0; j < (mid >> 1); ++j) {
+      int idx1 = (i*mid + j) * tail;
+      int idx2 = idx1 + (mid-(j << 1)-1) * tail;
       for (int k = 0; k < tail; ++k, ++idx1, ++idx2) {
         DType tmp = src[idx1];
         dst[idx1] = src[idx2];
@@ -198,7 +190,31 @@ void FlipImpl(const TShape &shape, DType *src, DType *dst, int axis) {
   }
 }
 
-void RandomHorizontalFlip(
+void FlipLeftRight(const nnvm::NodeAttrs &attrs,
+                   const OpContext &ctx,
+                   const std::vector<TBlob> &inputs,
+                   const std::vector<OpReqType> &req,
+                   const std::vector<TBlob> &outputs) {
+  using namespace mshadow;
+  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+    FlipImpl<DType, 1>(inputs[0].shape_, inputs[0].dptr<DType>(),
+                       outputs[0].dptr<DType>());
+  });
+}
+
+void FlipTopBottom(const nnvm::NodeAttrs &attrs,
+                   const OpContext &ctx,
+                   const std::vector<TBlob> &inputs,
+                   const std::vector<OpReqType> &req,
+                   const std::vector<TBlob> &outputs) {
+  using namespace mshadow;
+  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+    FlipImpl<DType, 0>(inputs[0].shape_, inputs[0].dptr<DType>(),
+                       outputs[0].dptr<DType>());
+  });
+}
+
+void RandomFlipLeftRight(
     const nnvm::NodeAttrs &attrs,
     const OpContext &ctx,
     const std::vector<TBlob> &inputs,
@@ -207,14 +223,19 @@ void RandomHorizontalFlip(
   using namespace mshadow;
   Stream<cpu> *s = ctx.get_stream<cpu>();
   Random<cpu> *prnd = ctx.requested[0].get_random<cpu, float>(s);
-  if (std::bernoulli_distribution()(prnd->GetRndEngine())) return;
   MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
-    FlipImpl(inputs[0].shape_, inputs[0].dptr<DType>(),
-             outputs[0].dptr<DType>(), 1);
+    if (std::bernoulli_distribution()(prnd->GetRndEngine())) {
+      if (outputs[0].dptr_ != inputs[0].dptr_) {
+        std::memcpy(outputs[0].dptr_, inputs[0].dptr_, inputs[0].Size() * sizeof(DType));
+      }
+    } else {
+      FlipImpl<DType, 1>(inputs[0].shape_, inputs[0].dptr<DType>(),
+                         outputs[0].dptr<DType>());
+    }
   });
 }
 
-void RandomVerticalFlip(
+void RandomFlipTopBottom(
     const nnvm::NodeAttrs &attrs,
     const OpContext &ctx,
     const std::vector<TBlob> &inputs,
@@ -223,10 +244,15 @@ void RandomVerticalFlip(
   using namespace mshadow;
   Stream<cpu> *s = ctx.get_stream<cpu>();
   Random<cpu> *prnd = ctx.requested[0].get_random<cpu, float>(s);
-  if (std::bernoulli_distribution()(prnd->GetRndEngine())) return;
   MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
-    FlipImpl(inputs[0].shape_, inputs[0].dptr<DType>(),
-             outputs[0].dptr<DType>(), 0);
+    if (std::bernoulli_distribution()(prnd->GetRndEngine())) {
+      if (outputs[0].dptr_ != inputs[0].dptr_) {
+        std::memcpy(outputs[0].dptr_, inputs[0].dptr_, inputs[0].Size() * sizeof(DType));
+      }
+    } else {
+      FlipImpl<DType, 0>(inputs[0].shape_, inputs[0].dptr<DType>(),
+                         outputs[0].dptr<DType>());
+    }
   });
 }
 
diff --git a/src/operator/image/image_random.cc b/src/operator/image/image_random.cc
index 481dfce010..26f520bb8c 100644
--- a/src/operator/image/image_random.cc
+++ b/src/operator/image/image_random.cc
@@ -48,7 +48,6 @@ NNVM_REGISTER_OP(_image_to_tensor)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{ "_copy" })
 .add_argument("data", "NDArray-or-Symbol", "The input.");
 
-
 NNVM_REGISTER_OP(_image_normalize)
 .describe(R"code()code" ADD_FILELINE)
 .set_num_inputs(1)
@@ -65,11 +64,21 @@ NNVM_REGISTER_OP(_image_normalize)
 .add_argument("data", "NDArray-or-Symbol", "The input.")
 .add_arguments(NormalizeParam::__FIELDS__());
 
+MXNET_REGISTER_IMAGE_AUG_OP(_image_flip_left_right)
+.describe(R"code()code" ADD_FILELINE)
+.set_attr<FCompute>("FCompute<cpu>", FlipLeftRight);
+
+MXNET_REGISTER_IMAGE_RND_AUG_OP(_image_random_flip_left_right)
+.describe(R"code()code" ADD_FILELINE)
+.set_attr<FCompute>("FCompute<cpu>", RandomFlipLeftRight);
 
-MXNET_REGISTER_IMAGE_RND_AUG_OP(_image_random_horizontal_flip)
+MXNET_REGISTER_IMAGE_AUG_OP(_image_flip_top_bottom)
 .describe(R"code()code" ADD_FILELINE)
-.set_attr<FCompute>("FCompute<cpu>", RandomHorizontalFlip);
+.set_attr<FCompute>("FCompute<cpu>", FlipTopBottom);
 
+MXNET_REGISTER_IMAGE_RND_AUG_OP(_image_random_flip_top_bottom)
+.describe(R"code()code" ADD_FILELINE)
+.set_attr<FCompute>("FCompute<cpu>", RandomFlipTopBottom);
 
 MXNET_REGISTER_IMAGE_RND_AUG_OP(_image_random_brightness)
 .describe(R"code()code" ADD_FILELINE)
@@ -77,7 +86,6 @@ MXNET_REGISTER_IMAGE_RND_AUG_OP(_image_random_brightness)
 .set_attr<FCompute>("FCompute<cpu>", RandomBrightness)
 .add_arguments(RandomEnhanceParam::__FIELDS__());
 
-
 MXNET_REGISTER_IMAGE_RND_AUG_OP(_image_random_contrast)
 .describe(R"code()code" ADD_FILELINE)
 .set_attr_parser(ParamParser<RandomEnhanceParam>)
diff --git a/tests/python/unittest/test_gluon_data_vision.py b/tests/python/unittest/test_gluon_data_vision.py
index 0c9e5c1c11..5e9ff879fb 100644
--- a/tests/python/unittest/test_gluon_data_vision.py
+++ b/tests/python/unittest/test_gluon_data_vision.py
@@ -18,22 +18,39 @@
 import mxnet as mx
 import mxnet.ndarray as nd
 import numpy as np
+from PIL import Image
 from mxnet import gluon
-from mxnet.gluon.data.vision.transforms import AdjustLighting
+from mxnet.gluon.data.vision import transforms
 from mxnet.test_utils import assert_almost_equal
+from mxnet.test_utils import almost_equal
 
-def test_adjust_lighting():
+def test_to_tensor():
     data_in = np.random.uniform(0, 255, (300, 300, 3)).astype(dtype=np.uint8)
-    alpha_rgb = [0.05, 0.06, 0.07]
-    eigval = np.array([55.46, 4.794, 1.148])
-    eigvec = np.array([[-0.5675, 0.7192, 0.4009],
-                       [-0.5808, -0.0045, -0.8140],
-                       [-0.5808, -0.0045, -0.8140]])
-    f = AdjustLighting(alpha_rgb=alpha_rgb, eigval=eigval.ravel().tolist(), eigvec=eigvec.ravel().tolist())
-    out_nd = f(nd.array(data_in, dtype=np.uint8))
-    out_gt = np.clip(data_in.astype(np.float32)
-                     + np.dot(eigvec * alpha_rgb, eigval.reshape((3, 1))).reshape((1, 1, 3)), 0, 255).astype(np.uint8)
-    assert_almost_equal(out_nd.asnumpy(), out_gt)
+    out_nd = transforms.ToTensor()(nd.array(data_in, dtype='uint8'))
+    assert_almost_equal(out_nd.asnumpy(), np.transpose(
+        data_in.astype(dtype=np.float32) / 255.0, (2, 0, 1)))
+
+def test_normalize():
+    data_in = np.random.uniform(0, 255, (300, 300, 3)).astype(dtype=np.uint8)
+    data_in = transforms.ToTensor()(nd.array(data_in, dtype='uint8'))
+    out_nd = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))(data_in)
+    data_expected = data_in.asnumpy()
+    data_expected[:][:][0] = data_expected[:][:][0] / 3.0
+    data_expected[:][:][1] = (data_expected[:][:][1] - 1.0) / 2.0
+    data_expected[:][:][2] = data_expected[:][:][2] - 2.0
+    assert_almost_equal(data_expected, out_nd.asnumpy())
+
+def test_flip_left_right():
+    data_in = np.random.uniform(0, 255, (300, 300, 3)).astype(dtype=np.uint8)
+    pil_img = Image.fromarray(data_in).transpose(Image.FLIP_LEFT_RIGHT)
+    data_trans = nd.image.flip_left_right(nd.array(data_in, dtype='uint8'))
+    assert_almost_equal(np.array(pil_img), data_trans.asnumpy())
+
+def test_flip_top_bottom():
+    data_in = np.random.uniform(0, 255, (300, 300, 3)).astype(dtype=np.uint8)
+    pil_img = Image.fromarray(data_in).transpose(Image.FLIP_TOP_BOTTOM)
+    data_trans = nd.image.flip_top_bottom(nd.array(data_in, dtype='uint8'))
+    assert_almost_equal(np.array(pil_img), data_trans.asnumpy())
 
 if __name__ == '__main__':
     import nose


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services