You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by sk...@apache.org on 2019/02/04 22:59:30 UTC

[incubator-mxnet] branch master updated: Image ToTensor operator - GPU support, 3D/4D inputs (#13837)

This is an automated email from the ASF dual-hosted git repository.

skm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new f86f21e  Image ToTensor operator - GPU support, 3D/4D inputs (#13837)
f86f21e is described below

commit f86f21e08356e3299d1d34c29398a48fec0f5607
Author: Sandeep Krishnamurthy <sa...@gmail.com>
AuthorDate: Mon Feb 4 14:59:13 2019 -0800

    Image ToTensor operator - GPU support, 3D/4D inputs (#13837)
    
    * Add CPU implementation of ToTensor
    
    * Add tests for cpu
    
    * Add gpu implementation and tests
    
    * Fix lint issues
    
    * Cleanup includes
    
    * Move back changes to original image operators files
    
    * Add 4D example
    
    * resolve merge conflicts
    
    * Fix failing tests
    
    * parallelize on channel in kernel launch
---
 python/mxnet/gluon/data/vision/transforms.py    |  9 ++-
 src/operator/image/image_random-inl.h           | 97 +++++++++++++++++++------
 src/operator/image/image_random.cc              | 63 +++++++++++++++-
 src/operator/image/image_random.cu              | 44 +++++------
 tests/python/gpu/test_gluon_transforms.py       | 36 ++++++++-
 tests/python/unittest/test_gluon_data_vision.py | 14 +++-
 6 files changed, 213 insertions(+), 50 deletions(-)

diff --git a/python/mxnet/gluon/data/vision/transforms.py b/python/mxnet/gluon/data/vision/transforms.py
index aa4a3e3..9310e15 100644
--- a/python/mxnet/gluon/data/vision/transforms.py
+++ b/python/mxnet/gluon/data/vision/transforms.py
@@ -96,17 +96,20 @@ class Cast(HybridBlock):
 
 
 class ToTensor(HybridBlock):
-    """Converts an image NDArray to a tensor NDArray.
+    """Converts an image NDArray or a batch of image NDArrays to a tensor NDArray.
 
     Converts an image NDArray of shape (H x W x C) in the range
     [0, 255] to a float32 tensor NDArray of shape (C x H x W) in
     the range [0, 1).
 
+    For batch input, converts a batch image NDArray of shape (N x H x W x C) in the
+    range [0, 255] to a float32 tensor NDArray of shape (N x C x H x W), scaled likewise.
+
     Inputs:
-        - **data**: input tensor with (H x W x C) shape and uint8 type.
+        - **data**: input tensor with (H x W x C) or (N x H x W x C) shape and uint8 type.
 
     Outputs:
-        - **out**: output tensor with (C x H x W) shape and float32 type.
+        - **out**: output tensor with (C x H x W) or (N x C x H x W) shape and float32 type.
 
     Examples
     --------
diff --git a/src/operator/image/image_random-inl.h b/src/operator/image/image_random-inl.h
index aeea0bc..c9dd85a 100644
--- a/src/operator/image/image_random-inl.h
+++ b/src/operator/image/image_random-inl.h
@@ -43,16 +43,28 @@ namespace mxnet {
 namespace op {
 namespace image {
 
+// There are no parameters for this operator.
+// Hence, no parameter registration.
+
+// Shape inference: (H, W, C) -> (C, H, W) and (N, H, W, C) -> (N, C, H, W).
 inline bool ToTensorShape(const nnvm::NodeAttrs& attrs,
                           std::vector<TShape> *in_attrs,
                           std::vector<TShape> *out_attrs) {
   CHECK_EQ(in_attrs->size(), 1U);
   CHECK_EQ(out_attrs->size(), 1U);
+  // Presumably ndim() == 0 means the input shape is not known yet: report "not inferred".
   TShape &shp = (*in_attrs)[0];
   if (!shp.ndim()) return false;
-  CHECK_EQ(shp.ndim(), 3)
-      << "Input image must have shape (height, width, channels), but got " << shp;
-  SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape({shp[2], shp[0], shp[1]}));
+  // Only a single HWC image (3D) or an NHWC batch (4D) is accepted.
+  CHECK((shp.ndim() == 3) || (shp.ndim() == 4))
+      << "Input image must have shape (height, width, channels), or "
+      << "(N, height, width, channels) but got " << shp;
+  if (shp.ndim() == 3) {
+    SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape({shp[2], shp[0], shp[1]}));
+  } else if (shp.ndim() == 4) {
+    SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape({shp[0], shp[3], shp[1], shp[2]}));
+  }
+
   return true;
 }
 
@@ -65,31 +77,74 @@ inline bool ToTensorType(const nnvm::NodeAttrs& attrs,
   return (*in_attrs)[0] != -1;
 }
 
-inline void ToTensor(const nnvm::NodeAttrs &attrs,
-                     const OpContext &ctx,
-                     const std::vector<TBlob> &inputs,
-                     const std::vector<OpReqType> &req,
-                     const std::vector<TBlob> &outputs) {
-  CHECK_EQ(req[0], kWriteTo)
-    << "`to_tensor` does not support inplace";
+// Operator Implementation
 
-  int length = inputs[0].shape_[0] * inputs[0].shape_[1];
-  int channel = inputs[0].shape_[2];
+template<int req>
+struct totensor_forward {  // HWC -> CHW copy of one channel, scaled by 1/normalize_factor
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(uint32_t c, float* out_data, const DType* in_data,
+                                  const int length, const int channel, const int step,
+                                  const float normalize_factor = 255.0f) {
+      // Serial on purpose: the kernel launch already runs one Map() per channel c.
+      for (int i = 0; i < length; ++i) {
+        KERNEL_ASSIGN(out_data[step + c*length + i], req,
+                      (in_data[step + i*channel + c]) / normalize_factor);
+      }
+  }
+};
+
+template<typename xpu>
+void ToTensorImpl(const OpContext &ctx,
+                  const std::vector<TBlob> &inputs,
+                  const std::vector<TBlob> &outputs,
+                  const std::vector<OpReqType> &req,
+                  const int length,  // pixels per image (height * width)
+                  const uint32_t channel,  // number of channels c of the HWC input
+                  const int step = 0) {  // element offset of this image within a batch
+  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();

   MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
-    float* output = outputs[0].dptr<float>();
-    DType* input = inputs[0].dptr<DType>();
+    MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
+      float* output = outputs[0].dptr<float>();  // output is always read as float32
+      DType* input = inputs[0].dptr<DType>();
+      mxnet_op::Kernel<totensor_forward<req_type>, xpu>::Launch(  // one Map() per channel
+          s, channel, output, input, length, channel, step);
+    });
+  });
+}
 
-    for (int l = 0; l < length; ++l) {
-      for (int c = 0; c < channel; ++c) {
-        output[c*length + l] = static_cast<float>(input[l*channel + c]) / 255.0f;
-      }
+template<typename xpu>
+void ToTensorOpForward(const nnvm::NodeAttrs &attrs,
+                       const OpContext &ctx,
+                       const std::vector<TBlob> &inputs,
+                       const std::vector<OpReqType> &req,
+                       const std::vector<TBlob> &outputs) {
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  // The float32 output is a transposed copy of the input, so it must not alias it.
+  CHECK_EQ(req[0], kWriteTo)
+    << "`to_tensor` does not support inplace updates";
+
+  // 3D input (h, w, c) -> output (c, h, w)
+  if (inputs[0].ndim() == 3) {
+    const int length = inputs[0].shape_[0] * inputs[0].shape_[1];
+    const uint32_t channel = inputs[0].shape_[2];
+    ToTensorImpl<xpu>(ctx, inputs, outputs, req, length, channel);
+  } else if (inputs[0].ndim() == 4) {
+    // 4D input (n, h, w, c) -> output (n, c, h, w), converted one image at a time
+    const int batch_size = inputs[0].shape_[0];
+    const int length = inputs[0].shape_[1] * inputs[0].shape_[2];
+    const uint32_t channel = inputs[0].shape_[3];
+    const int step = channel * length;
+    // Spread the batch over OpenMP threads only on CPU; GPU launches stay on one thread.
+    #pragma omp parallel for if (xpu::kDevCPU)
+    for (auto n = 0; n < batch_size; ++n) {
+      ToTensorImpl<xpu>(ctx, inputs, outputs, req, length, channel, n*step);
     }
-  });
+  }
 }
 
-// Normalize Operator
-// Parameter registration for image Normalize operator
 struct NormalizeParam : public dmlc::Parameter<NormalizeParam> {
   nnvm::Tuple<float> mean;
   nnvm::Tuple<float> std;
diff --git a/src/operator/image/image_random.cc b/src/operator/image/image_random.cc
index 7901747..fc6b17c 100644
--- a/src/operator/image/image_random.cc
+++ b/src/operator/image/image_random.cc
@@ -39,14 +39,71 @@ DMLC_REGISTER_PARAMETER(RandomLightingParam);
 DMLC_REGISTER_PARAMETER(RandomColorJitterParam);
 
 NNVM_REGISTER_OP(_image_to_tensor)
-.describe(R"code()code" ADD_FILELINE)
+.describe(R"code(Converts an image NDArray of shape (H x W x C) or (N x H x W x C)
+with values in the range [0, 255] to a tensor NDArray of shape (C x H x W) or (N x C x H x W)
+with values in the range [0, 1].
+
+Example:
+    .. code-block:: python
+        image = mx.nd.random.uniform(0, 255, (4, 2, 3)).astype(dtype=np.uint8)
+        to_tensor(image)
+            [[[ 0.85490197  0.72156864]
+              [ 0.09019608  0.74117649]
+              [ 0.61960787  0.92941177]
+              [ 0.96470588  0.1882353 ]]
+             [[ 0.6156863   0.73725492]
+              [ 0.46666667  0.98039216]
+              [ 0.44705883  0.45490196]
+              [ 0.01960784  0.8509804 ]]
+             [[ 0.39607844  0.03137255]
+              [ 0.72156864  0.52941179]
+              [ 0.16470589  0.7647059 ]
+              [ 0.05490196  0.70588237]]]
+             <NDArray 3x4x2 @cpu(0)>
+
+        image = mx.nd.random.uniform(0, 255, (2, 4, 2, 3)).astype(dtype=np.uint8)
+        to_tensor(image)
+            [[[[0.11764706 0.5803922 ]
+               [0.9411765  0.10588235]
+               [0.2627451  0.73333335]
+               [0.5647059  0.32156864]]
+              [[0.7176471  0.14117648]
+               [0.75686276 0.4117647 ]
+               [0.18431373 0.45490196]
+               [0.13333334 0.6156863 ]]
+              [[0.6392157  0.5372549 ]
+               [0.52156866 0.47058824]
+               [0.77254903 0.21568628]
+               [0.01568628 0.14901961]]]
+             [[[0.6117647  0.38431373]
+               [0.6784314  0.6117647 ]
+               [0.69411767 0.96862745]
+               [0.67058825 0.35686275]]
+              [[0.21960784 0.9411765 ]
+               [0.44705883 0.43529412]
+               [0.09803922 0.6666667 ]
+               [0.16862746 0.1254902 ]]
+              [[0.6156863  0.9019608 ]
+               [0.35686275 0.9019608 ]
+               [0.05882353 0.6509804 ]
+               [0.20784314 0.7490196 ]]]]
+            <NDArray 2x3x4x2 @cpu(0)>
+)code" ADD_FILELINE)
 .set_num_inputs(1)
 .set_num_outputs(1)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"data"};
+  })
 .set_attr<nnvm::FInferShape>("FInferShape", ToTensorShape)
 .set_attr<nnvm::FInferType>("FInferType", ToTensorType)
-.set_attr<FCompute>("FCompute<cpu>", ToTensor)
+.set_attr<FCompute>("FCompute<cpu>", ToTensorOpForward<cpu>)
+// No FInplaceOption: the float32 output is a transposed copy of the input, so
+// writing it in place would overwrite input elements before they are read, and
+// ToTensorOpForward CHECKs req[0] == kWriteTo, so an in-place request from the
+// memory planner would fail at runtime.
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{ "_copy" })
-.add_argument("data", "NDArray-or-Symbol", "The input.");
+.add_argument("data", "NDArray-or-Symbol", "Input ndarray");
 
 NNVM_REGISTER_OP(_image_normalize)
 .describe(R"code(Normalize an tensor of shape (C x H x W) or (N x C x H x W) with mean and
diff --git a/src/operator/image/image_random.cu b/src/operator/image/image_random.cu
index 404c3d2..5f9aff2 100644
--- a/src/operator/image/image_random.cu
+++ b/src/operator/image/image_random.cu
@@ -1,26 +1,26 @@
 /*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
 
 /*!
- * \file image_random.cu
- * \brief GPU Implementation of image transformation operators
- */
+ * \file image_random.cu
+ * \brief GPU Implementation of image transformation operators
+ */
 #include "./image_random-inl.h"
 #include "../elemwise_op_common.h"
 
@@ -28,13 +28,15 @@ namespace mxnet {
 namespace op {
 namespace image {
 
+NNVM_REGISTER_OP(_image_to_tensor)
+.set_attr<FCompute>("FCompute<gpu>", ToTensorOpForward<gpu>);  // same template as the CPU FCompute
+
 NNVM_REGISTER_OP(_image_normalize)
 .set_attr<FCompute>("FCompute<gpu>", NormalizeOpForward<gpu>);
 
 NNVM_REGISTER_OP(_backward_image_normalize)
 .set_attr<FCompute>("FCompute<gpu>", NormalizeOpBackward<gpu>);
 
-
 }  // namespace image
 }  // namespace op
 }  // namespace mxnet
diff --git a/tests/python/gpu/test_gluon_transforms.py b/tests/python/gpu/test_gluon_transforms.py
index 4a1017b..3927d4c 100644
--- a/tests/python/gpu/test_gluon_transforms.py
+++ b/tests/python/gpu/test_gluon_transforms.py
@@ -71,6 +71,41 @@ def test_normalize():
     normalize_transformer = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))
     assertRaises(MXNetError, normalize_transformer, invalid_data_in)
 
+@with_seed()
+def test_to_tensor():
+    # 3D Input
+    data_in = np.random.uniform(0, 255, (300, 300, 3)).astype(dtype=np.uint8)
+    out_nd = transforms.ToTensor()(nd.array(data_in, dtype='uint8'))
+    assert_almost_equal(out_nd.asnumpy(), np.transpose(
+        data_in.astype(dtype=np.float32) / 255.0, (2, 0, 1)))
+
+    # 4D Input
+    data_in_4d = np.random.uniform(0, 255, (5, 300, 300, 3)).astype(dtype=np.uint8)
+    out_nd_4d = transforms.ToTensor()(nd.array(data_in_4d, dtype='uint8'))
+    assert_almost_equal(out_nd_4d.asnumpy(), np.transpose(
+        data_in_4d.astype(dtype=np.float32) / 255.0, (0, 3, 1, 2)))
+
+    # 3D Input with a single channel (grayscale)
+    data_in_gray = np.random.uniform(0, 255, (300, 300, 1)).astype(dtype=np.uint8)
+    out_nd_gray = transforms.ToTensor()(nd.array(data_in_gray, dtype='uint8'))
+    assert_almost_equal(out_nd_gray.asnumpy(), np.transpose(
+        data_in_gray.astype(dtype=np.float32) / 255.0, (2, 0, 1)))
+
+    # 4D Input with a batch of one image
+    data_in_one = np.random.uniform(0, 255, (1, 300, 300, 3)).astype(dtype=np.uint8)
+    out_nd_one = transforms.ToTensor()(nd.array(data_in_one, dtype='uint8'))
+    assert_almost_equal(out_nd_one.asnumpy(), np.transpose(
+        data_in_one.astype(dtype=np.float32) / 255.0, (0, 3, 1, 2)))
+
+    # Boundary values: 0 maps to 0.0 and 255 maps to 1.0
+    data_in_edge = np.array([[[0, 255, 128]]], dtype=np.uint8)
+    out_nd_edge = transforms.ToTensor()(nd.array(data_in_edge, dtype='uint8'))
+    assert_almost_equal(out_nd_edge.asnumpy(), np.array([[[0.0]], [[1.0]], [[128 / 255.0]]]))
+
+    # Invalid Input - Neither 3D or 4D input
+    invalid_data_in = nd.random.uniform(0, 255, (5, 5, 300, 300, 3)).astype(dtype=np.uint8)
+    transformer = transforms.ToTensor()
+    assertRaises(MXNetError, transformer, invalid_data_in)
 
 @with_seed()
 def test_resize():
@@ -128,4 +163,3 @@ def test_resize():
     data_in_4d = nd.random.uniform(0, 255, (2, 300, 300, 3)).astype('uint8')
     out_nd_4d = transforms.Resize((100, 100))(data_in_4d)
     assert_almost_equal(out_nd_4d.asnumpy(), py_bilinear_resize_nhwc(data_in_4d.asnumpy(), 100, 100), atol=1.0)
-
diff --git a/tests/python/unittest/test_gluon_data_vision.py b/tests/python/unittest/test_gluon_data_vision.py
index f10f0ae..a855fc8 100644
--- a/tests/python/unittest/test_gluon_data_vision.py
+++ b/tests/python/unittest/test_gluon_data_vision.py
@@ -29,10 +29,22 @@ import numpy as np
 
 @with_seed()
 def test_to_tensor():
+    # 3D Input
     data_in = np.random.uniform(0, 255, (300, 300, 3)).astype(dtype=np.uint8)
     out_nd = transforms.ToTensor()(nd.array(data_in, dtype='uint8'))
     assert_almost_equal(out_nd.asnumpy(), np.transpose(
-        data_in.astype(dtype=np.float32) / 255.0, (2, 0, 1)))
+        data_in.astype(dtype=np.float32) / 255.0, (2, 0, 1)))
+
+    # 4D Input
+    data_in = np.random.uniform(0, 255, (5, 300, 300, 3)).astype(dtype=np.uint8)
+    out_nd = transforms.ToTensor()(nd.array(data_in, dtype='uint8'))
+    assert_almost_equal(out_nd.asnumpy(), np.transpose(
+        data_in.astype(dtype=np.float32) / 255.0, (0, 3, 1, 2)))
+
+    # Invalid Input
+    invalid_data_in = nd.random.uniform(0, 255, (5, 5, 300, 300, 3)).astype(dtype=np.uint8)
+    transformer = transforms.ToTensor()
+    assertRaises(MXNetError, transformer, invalid_data_in)
 
 
 @with_seed()