Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2019/01/08 21:24:54 UTC

[GitHub] sandeep-krishnamurthy closed pull request #13614: Make to_tensor and normalize to accept 3D or 4D tensor inputs

sandeep-krishnamurthy closed pull request #13614: Make to_tensor and normalize to accept 3D or 4D tensor inputs
URL: https://github.com/apache/incubator-mxnet/pull/13614
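
For context, a minimal usage sketch of the behaviour this PR adds, based on the
docstrings and tests in the diff below (the shapes and the mean/std values are
illustrative, taken from the new unit tests):

    import mxnet as mx
    from mxnet.gluon.data.vision import transforms

    # Single image: (H x W x C) uint8 in [0, 255] -> (C x H x W) float32 in [0, 1)
    img = mx.nd.random.uniform(0, 255, (300, 300, 3)).astype('uint8')
    chw = transforms.ToTensor()(img)             # shape (3, 300, 300)

    # Batch of images: (N x H x W x C) -> (N x C x H x W), newly supported by this PR
    batch = mx.nd.random.uniform(0, 255, (5, 300, 300, 3)).astype('uint8')
    nchw = transforms.ToTensor()(batch)          # shape (5, 3, 300, 300)

    # Normalize now accepts both (C x H x W) and (N x C x H x W) inputs
    normalize = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))
    out = normalize(nchw)                        # same shape as its input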
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:


diff --git a/python/mxnet/gluon/data/vision/transforms.py b/python/mxnet/gluon/data/vision/transforms.py
index 3523be4d054..f5618292f1a 100644
--- a/python/mxnet/gluon/data/vision/transforms.py
+++ b/python/mxnet/gluon/data/vision/transforms.py
@@ -96,17 +96,20 @@ def hybrid_forward(self, F, x):
 
 
 class ToTensor(HybridBlock):
-    """Converts an image NDArray to a tensor NDArray.
+    """Converts an image NDArray or batch of image NDArray to a tensor NDArray.
 
     Converts an image NDArray of shape (H x W x C) in the range
     [0, 255] to a float32 tensor NDArray of shape (C x H x W) in
     the range [0, 1).
 
+    For batch input, converts a batch of image NDArrays of shape (N x H x W x C) in the
+    range [0, 255] to a float32 tensor NDArray of shape (N x C x H x W).
+
     Inputs:
-        - **data**: input tensor with (H x W x C) shape and uint8 type.
+        - **data**: input tensor with (H x W x C) or (N x H x W x C) shape and uint8 type.
 
     Outputs:
-        - **out**: output tensor with (C x H x W) shape and float32 type.
+        - **out**: output tensor with (C x H x W) or (N x C x H x W) shape and float32 type.
 
     Examples
     --------
@@ -135,7 +138,7 @@ def hybrid_forward(self, F, x):
 
 
 class Normalize(HybridBlock):
-    """Normalize an tensor of shape (C x H x W) with mean and
+    """Normalize an tensor of shape (C x H x W) or (N x C x H x W) with mean and
     standard deviation.
 
     Given mean `(m1, ..., mn)` and std `(s1, ..., sn)` for `n` channels,
@@ -154,10 +157,29 @@ class Normalize(HybridBlock):
 
 
     Inputs:
-        - **data**: input tensor with (C x H x W) shape.
+        - **data**: input tensor with (C x H x W) or (N x C x H x W) shape.
 
     Outputs:
         - **out**: output tensor with the shape as `data`.
+    
+    Examples
+    --------
+    >>> transformer = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))
+    >>> image = mx.nd.random.uniform(0, 1, (3, 4, 2))
+    >>> transformer(image)
+    [[[ 0.18293785  0.19761486]
+      [ 0.23839645  0.28142193]
+      [ 0.20092112  0.28598186]
+      [ 0.18162774  0.28241724]]
+     [[-0.2881726  -0.18821815]
+      [-0.17705294 -0.30780914]
+      [-0.2812064  -0.3512327 ]
+      [-0.05411351 -0.4716435 ]]
+     [[-1.0363373  -1.7273437 ]
+      [-1.6165586  -1.5223348 ]
+      [-1.208275   -1.1878313 ]
+      [-1.4711051  -1.5200229 ]]]
+    <NDArray 3x4x2 @cpu(0)>
     """
     def __init__(self, mean, std):
         super(Normalize, self).__init__()
diff --git a/src/operator/image/image_random-inl.h b/src/operator/image/image_random-inl.h
index c64ed28ecc2..a7e1161ff9e 100644
--- a/src/operator/image/image_random-inl.h
+++ b/src/operator/image/image_random-inl.h
@@ -47,9 +47,14 @@ inline bool ToTensorShape(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(out_attrs->size(), 1U);
   TShape &shp = (*in_attrs)[0];
   if (!shp.ndim()) return false;
-  CHECK_EQ(shp.ndim(), 3)
-      << "Input image must have shape (height, width, channels), but got " << shp;
-  SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape({shp[2], shp[0], shp[1]}));
+  CHECK((shp.ndim() == 3) || (shp.ndim() == 4))
+      << "Input image must have shape (height, width, channels), or "
+      << "(N, height, width, channels) but got " << shp;
+  if (shp.ndim() == 3) {
+    SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape({shp[2], shp[0], shp[1]}));
+  } else if (shp.ndim() == 4) {
+    SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape({shp[0], shp[3], shp[1], shp[2]}));
+  }
   return true;
 }
 
@@ -62,6 +67,23 @@ inline bool ToTensorType(const nnvm::NodeAttrs& attrs,
   return (*in_attrs)[0] != -1;
 }
 
+void ToTensorImpl(const std::vector<TBlob> &inputs,
+                        const std::vector<TBlob> &outputs,
+                        const int length,
+                        const int channel,
+                        const int step = 0) {
+  MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+      float* output = outputs[0].dptr<float>();
+      DType* input = inputs[0].dptr<DType>();
+
+      for (int l = 0; l < length; ++l) {
+        for (int c = 0; c < channel; ++c) {
+          output[step + c*length + l] = static_cast<float>(input[step + l*channel + c]) / 255.0f;
+        }
+      }
+    });
+}
+
 void ToTensor(const nnvm::NodeAttrs &attrs,
                      const OpContext &ctx,
                      const std::vector<TBlob> &inputs,
@@ -70,19 +92,23 @@ void ToTensor(const nnvm::NodeAttrs &attrs,
   CHECK_EQ(req[0], kWriteTo)
     << "`to_tensor` does not support inplace";
 
-  int length = inputs[0].shape_[0] * inputs[0].shape_[1];
-  int channel = inputs[0].shape_[2];
-
-  MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
-    float* output = outputs[0].dptr<float>();
-    DType* input = inputs[0].dptr<DType>();
-
-    for (int l = 0; l < length; ++l) {
-      for (int c = 0; c < channel; ++c) {
-        output[c*length + l] = static_cast<float>(input[l*channel + c]) / 255.0f;
-      }
+  // 3D Input - 1 image
+  if (inputs[0].ndim() == 3) {
+    const int length = inputs[0].shape_[0] * inputs[0].shape_[1];
+    const int channel = inputs[0].shape_[2];
+    ToTensorImpl(inputs, outputs, length, channel);
+  } else if (inputs[0].ndim() == 4) {
+    // 4D input batch of images
+    const int batch_size = inputs[0].shape_[0];
+    const int length = inputs[0].shape_[1] * inputs[0].shape_[2];
+    const int channel = inputs[0].shape_[3];
+    const int step = channel * length;
+
+    #pragma omp parallel for
+    for (auto n = 0; n < batch_size; ++n) {
+      ToTensorImpl(inputs, outputs, length, channel, n*step);
     }
-  });
+  }
 }
 
 struct NormalizeParam : public dmlc::Parameter<NormalizeParam> {
@@ -103,14 +129,24 @@ inline bool NormalizeShape(const nnvm::NodeAttrs& attrs,
   const auto& dshape = (*in_attrs)[0];
   if (!dshape.ndim()) return false;
 
-  CHECK_EQ(dshape.ndim(), 3)
-      << "Input tensor must have shape (channels, height, width), but got "
-      << dshape;
-  auto nchannels = dshape[0];
-  CHECK(nchannels == 3 || nchannels == 1)
+  CHECK((dshape.ndim() == 3) || (dshape.ndim() == 4))
+      << "Input tensor must have shape (channels, height, width), or "
+      << "(N, channels, height, width), but got " << dshape;
+
+  uint32_t nchannels;
+  if (dshape.ndim() == 3) {
+    nchannels = dshape[0];
+    CHECK(nchannels == 3 || nchannels == 1)
       << "The first dimension of input tensor must be the channel dimension with "
       << "either 1 or 3 elements, but got input with shape " << dshape;
-  CHECK(param.mean.ndim() == 1 || param.mean.ndim() == nchannels)
+  } else if (dshape.ndim() == 4) {
+    nchannels = dshape[1];
+    CHECK(nchannels == 3 || nchannels == 1)
+      << "The second dimension of input tensor must be the channel dimension with "
+      << "either 1 or 3 elements, but got input with shape " << dshape;
+  }
+
+  CHECK((param.mean.ndim() == 1) || (param.mean.ndim() == nchannels))
       << "Invalid mean for input with shape " << dshape
       << ". mean must have either 1 or " << nchannels
       << " elements, but got " << param.mean;
@@ -123,6 +159,26 @@ inline bool NormalizeShape(const nnvm::NodeAttrs& attrs,
   return true;
 }
 
+void NormalizeImpl(const std::vector<TBlob> &inputs,
+                          const std::vector<TBlob> &outputs,
+                          const NormalizeParam &param,
+                          const int length,
+                          const int channel,
+                          const int step = 0) {
+    MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+      DType* input = inputs[0].dptr<DType>();
+      DType* output = outputs[0].dptr<DType>();
+
+      for (int i = 0; i < channel; ++i) {
+        DType mean = param.mean[param.mean.ndim() > 1 ? i : 0];
+        DType std_dev = param.std[param.std.ndim() > 1 ? i : 0];
+        for (int j = 0; j < length; ++j) {
+          output[step + i*length + j] = (input[step + i*length + j] - mean) / std_dev;
+        }
+      }
+    });
+}
+
 void Normalize(const nnvm::NodeAttrs &attrs,
                       const OpContext &ctx,
                       const std::vector<TBlob> &inputs,
@@ -130,21 +186,23 @@ void Normalize(const nnvm::NodeAttrs &attrs,
                       const std::vector<TBlob> &outputs) {
   const NormalizeParam &param = nnvm::get<NormalizeParam>(attrs.parsed);
 
-  int nchannels = inputs[0].shape_[0];
-  int length = inputs[0].shape_[1] * inputs[0].shape_[2];
-
-  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
-    DType* input = inputs[0].dptr<DType>();
-    DType* output = outputs[0].dptr<DType>();
-
-    for (int i = 0; i < nchannels; ++i) {
-      DType mean = param.mean[param.mean.ndim() > 1 ? i : 0];
-      DType std = param.std[param.std.ndim() > 1 ? i : 0];
-      for (int j = 0; j < length; ++j) {
-        output[i*length + j] = (input[i*length + j] - mean) / std;
-      }
+  // 3D input (c, h, w)
+  if (inputs[0].ndim() == 3) {
+    const int length = inputs[0].shape_[1] * inputs[0].shape_[2];
+    const int channel = inputs[0].shape_[0];
+    NormalizeImpl(inputs, outputs, param, length, channel);
+  } else if (inputs[0].ndim() == 4) {
+    // 4D input (n, c, h, w)
+    const int batch_size = inputs[0].shape_[0];
+    const int length = inputs[0].shape_[2] * inputs[0].shape_[3];
+    const int channel = inputs[0].shape_[1];
+    const int step = channel*length;
+
+    #pragma omp parallel for
+    for (auto n = 0; n < batch_size; ++n) {
+      NormalizeImpl(inputs, outputs, param, length, channel, n*step);
     }
-  });
+  }
 }
 
 template<typename DType>
diff --git a/tests/python/unittest/test_gluon_data_vision.py b/tests/python/unittest/test_gluon_data_vision.py
index 2ff9c5cb2a1..b88bc09c6a3 100644
--- a/tests/python/unittest/test_gluon_data_vision.py
+++ b/tests/python/unittest/test_gluon_data_vision.py
@@ -19,30 +19,66 @@
 import mxnet.ndarray as nd
 import numpy as np
 from mxnet import gluon
+from mxnet.base import MXNetError
 from mxnet.gluon.data.vision import transforms
 from mxnet.test_utils import assert_almost_equal
 from mxnet.test_utils import almost_equal
-from common import setup_module, with_seed, teardown
-
+from common import assertRaises, setup_module, with_seed, teardown
 
 @with_seed()
 def test_to_tensor():
+    # 3D Input
     data_in = np.random.uniform(0, 255, (300, 300, 3)).astype(dtype=np.uint8)
-    out_nd = transforms.ToTensor()(nd.array(data_in, dtype='uint8'))
+    out_nd = transforms.ToTensor()(nd.array(data_in))
     assert_almost_equal(out_nd.asnumpy(), np.transpose(
         data_in.astype(dtype=np.float32) / 255.0, (2, 0, 1)))
 
+    # 4D Input
+    data_in = np.random.uniform(0, 255, (5, 300, 300, 3)).astype(dtype=np.uint8)
+    out_nd = transforms.ToTensor()(nd.array(data_in, dtype='uint8'))
+    assert_almost_equal(out_nd.asnumpy(), np.transpose(
+        data_in.astype(dtype=np.float32) / 255.0, (0, 3, 1, 2)))
+    
+    # Invalid Input
+    invalid_data_in = nd.random.uniform(0, 255, (5, 5, 300, 300, 3)).astype(dtype=np.uint8)
+    transformer = transforms.ToTensor()
+    assertRaises(MXNetError, transformer, invalid_data_in)
+
 
 @with_seed()
 def test_normalize():
-    data_in = np.random.uniform(0, 255, (300, 300, 3)).astype(dtype=np.uint8)
-    data_in = transforms.ToTensor()(nd.array(data_in, dtype='uint8'))
-    out_nd = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))(data_in)
-    data_expected = data_in.asnumpy()
-    data_expected[:][:][0] = data_expected[:][:][0] / 3.0
-    data_expected[:][:][1] = (data_expected[:][:][1] - 1.0) / 2.0
-    data_expected[:][:][2] = data_expected[:][:][2] - 2.0
-    assert_almost_equal(data_expected, out_nd.asnumpy())
+    # 3D Input
+    data_in_3d = np.random.uniform(0, 255, (300, 300, 3)).astype(dtype=np.uint8)
+    data_in_3d = transforms.ToTensor()(nd.array(data_in_3d, dtype='uint8'))
+    out_nd_3d = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))(data_in_3d)
+    data_expected_3d = data_in_3d.asnumpy()
+    data_expected_3d[:][:][0] = data_expected_3d[:][:][0] / 3.0
+    data_expected_3d[:][:][1] = (data_expected_3d[:][:][1] - 1.0) / 2.0
+    data_expected_3d[:][:][2] = data_expected_3d[:][:][2] - 2.0
+    assert_almost_equal(data_expected_3d, out_nd_3d.asnumpy())
+
+    # 4D Input
+    data_in_4d = np.random.uniform(0, 255, (2, 300, 300, 3)).astype(dtype=np.uint8)
+    data_in_4d = transforms.ToTensor()(nd.array(data_in_4d, dtype='uint8'))
+    out_nd_4d = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))(data_in_4d)
+    data_expected_4d = data_in_4d.asnumpy()
+    data_expected_4d[0][:][:][0] = data_expected_4d[0][:][:][0] / 3.0
+    data_expected_4d[0][:][:][1] = (data_expected_4d[0][:][:][1] - 1.0) / 2.0
+    data_expected_4d[0][:][:][2] = data_expected_4d[0][:][:][2] - 2.0
+    data_expected_4d[1][:][:][0] = data_expected_4d[1][:][:][0] / 3.0
+    data_expected_4d[1][:][:][1] = (data_expected_4d[1][:][:][1] - 1.0) / 2.0
+    data_expected_4d[1][:][:][2] = data_expected_4d[1][:][:][2] - 2.0
+    assert_almost_equal(data_expected_4d, out_nd_4d.asnumpy())
+
+    # Invalid Input - Neither 3D nor 4D input
+    invalid_data_in = nd.random.uniform(0, 1, (5, 5, 3, 300, 300)).astype(dtype=np.float32)
+    normalize_transformer = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))
+    assertRaises(MXNetError, normalize_transformer, invalid_data_in)
+
+    # Invalid Input - Channel is neither 1 nor 3
+    invalid_data_in = nd.random.uniform(0, 1, (5, 4, 300, 300)).astype(dtype=np.float32)
+    normalize_transformer = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))
+    assertRaises(MXNetError, normalize_transformer, invalid_data_in)
 
 
 @with_seed()
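
For reference, a rough NumPy equivalent of what the ToTensorImpl and NormalizeImpl
kernels above compute for a 4D batch (an illustrative sketch mirroring the unit
tests, not part of the merged code):

    import numpy as np

    def to_tensor_ref(batch_nhwc):
        # (N x H x W x C) uint8 in [0, 255] -> (N x C x H x W) float32 in [0, 1)
        return batch_nhwc.astype(np.float32).transpose(0, 3, 1, 2) / 255.0

    def normalize_ref(batch_nchw, mean, std):
        # Broadcast per-channel mean/std over the (N x C x H x W) tensor
        mean = np.asarray(mean, dtype=np.float32).reshape(1, -1, 1, 1)
        std = np.asarray(std, dtype=np.float32).reshape(1, -1, 1, 1)
        return (batch_nchw - mean) / std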


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services