You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by sk...@apache.org on 2019/02/01 00:10:22 UTC

[incubator-mxnet] branch master updated: Export resize and support batch size (#14014)

This is an automated email from the ASF dual-hosted git repository.

skm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 2a4634b  Export resize and support batch size (#14014)
2a4634b is described below

commit 2a4634b983be26cecdd5018b29b4f78e602dba2f
Author: Jake Lee <gs...@gmail.com>
AuthorDate: Fri Feb 1 08:09:59 2019 +0800

    Export resize and support batch size (#14014)
    
    * add image resize operator and unit test
    
    * refactor the resize operator and address lint issues
    
    * address comment and add doc
    
    * assert size is more than 2
    
    * add test case of 4D input
    
    * use ndarray datatype
    
    * add inline to Shape
    
    * add 4D input example
    
    * refactor the duplicate code and separate the resize from image_random
    
    * clean up the code
    
    * add resize implementation
    
    * delete the variable not used
    
    * refactor the code with structure and enum to make code more understandable
    
    * fix the lint
    
    * address comments
    
    * address comment 1. add description 2. refactor unit test and add dtype
    
    * update data type check
    
    * lint
    
    * move the common utility to image_utils
    
    * add default value for keep_ratio
    
    * change the operator doc
    
    * update the image utility function
    
    * fix lint
    
    * use Hang's implementation to implement the image resize operator on GPU
    
    * update the check and doc
    
    * refactor the caffe_gpu_interp2_kernel
    
    * update doc and fix the cpu compile error
    
    * update the comment
    
    * fix lint
    
    * add unit test for gpu
    
    * address comments
    
    * remove the crop and centercrop utility functions to keep the PR focused
    
    * fix the syntax error
    
    * delete the warning
    
    * add unit test with 4D
    
    * fix typo
    
    * add more unit test
    
    * fix unit test
    
    * set atol = 1
    
    * fix missing numpy import
    
    * fix the unit test
    
    * delete test case
    
    * fix unit test missing dependency
    
    * fix error data type
    
    * unify the style and add invalid interp
    
    * update the doc
---
 python/mxnet/gluon/data/vision/transforms.py    |  34 ++--
 src/io/image_io.cc                              |  14 +-
 src/operator/contrib/bilinear_resize-inl.cuh    | 184 ++++++++++++++++++++
 src/operator/contrib/bilinear_resize.cu         |  79 +--------
 src/operator/image/image_random-inl.h           |   8 +-
 src/operator/image/image_utils.h                |  59 +++++++
 src/operator/image/resize-inl.h                 | 218 ++++++++++++++++++++++++
 src/operator/image/resize.cc                    |  83 +++++++++
 src/operator/image/resize.cu                    |  77 +++++++++
 tests/python/gpu/test_gluon_transforms.py       |  61 ++++++-
 tests/python/unittest/test_gluon_data_vision.py |  40 ++++-
 11 files changed, 744 insertions(+), 113 deletions(-)

diff --git a/python/mxnet/gluon/data/vision/transforms.py b/python/mxnet/gluon/data/vision/transforms.py
index 2f557f5..aa4a3e3 100644
--- a/python/mxnet/gluon/data/vision/transforms.py
+++ b/python/mxnet/gluon/data/vision/transforms.py
@@ -262,8 +262,8 @@ class CenterCrop(Block):
         return image.center_crop(x, *self._args)[0]
 
 
-class Resize(Block):
-    """Resize an image to the given size.
+class Resize(HybridBlock):
+    """Resize an image or a batch of image NDArray to the given size.
     Should be applied before `mxnet.gluon.data.vision.transforms.ToTensor`.
 
     Parameters
@@ -276,13 +276,17 @@ class Resize(Block):
     interpolation : int
         Interpolation method for resizing. By default uses bilinear
         interpolation. See OpenCV's resize function for available choices.
+        Note that the Resize on gpu use contrib.bilinearResize2D operator
+        which only support bilinear interpolation(1). The result would be slightly
+        different on gpu compared to cpu. OpenCV tend to align center while bilinearResize2D
+        use algorithm which aligns corner.
 
 
     Inputs:
-        - **data**: input tensor with (Hi x Wi x C) shape.
+        - **data**: input tensor with (H x W x C) or (N x H x W x C) shape.
 
     Outputs:
-        - **out**: output tensor with (H x W x C) shape.
+        - **out**: output tensor with (H x W x C) or (N x H x W x C) shape.
 
     Examples
     --------
@@ -290,6 +294,9 @@ class Resize(Block):
     >>> image = mx.nd.random.uniform(0, 255, (224, 224, 3)).astype(dtype=np.uint8)
     >>> transformer(image)
     <NDArray 500x1000x3 @cpu(0)>
+    >>> image = mx.nd.random.uniform(0, 255, (3, 224, 224, 3)).astype(dtype=np.uint8)
+    >>> transformer(image)
+    <NDArray 3x500x1000x3 @cpu(0)>
     """
     def __init__(self, size, keep_ratio=False, interpolation=1):
         super(Resize, self).__init__()
@@ -297,23 +304,8 @@ class Resize(Block):
         self._size = size
         self._interpolation = interpolation
 
-    def forward(self, x):
-        if isinstance(self._size, numeric_types):
-            if not self._keep:
-                wsize = self._size
-                hsize = self._size
-            else:
-                h, w, _ = x.shape
-                if h > w:
-                    wsize = self._size
-                    hsize = int(h * wsize / w)
-                else:
-                    hsize = self._size
-                    wsize = int(w * hsize / h)
-        else:
-            wsize, hsize = self._size
-        return image.imresize(x, wsize, hsize, self._interpolation)
-
+    def hybrid_forward(self, F, x):
+        return F.image.resize(x, size=self._size, keep_ratio=self._keep, interp=self._interpolation)
 
 class RandomFlipLeftRight(HybridBlock):
     """Randomly flip the input image left to right with a probability
diff --git a/src/io/image_io.cc b/src/io/image_io.cc
index b3f7c40..44fcdb8 100644
--- a/src/io/image_io.cc
+++ b/src/io/image_io.cc
@@ -38,6 +38,7 @@
 #include <cstring>
 
 #include "../operator/elemwise_op_common.h"
+#include "../operator/image/resize-inl.h"
 
 #if MXNET_USE_OPENCV
   #include <opencv2/opencv.hpp>
@@ -285,19 +286,8 @@ inline void Imresize(const nnvm::NodeAttrs& attrs,
                      const std::vector<TBlob> &inputs,
                      const std::vector<OpReqType> &req,
                      const std::vector<TBlob> &outputs) {
-#if MXNET_USE_OPENCV
-  CHECK_NE(inputs[0].type_flag_, mshadow::kFloat16) << "imresize doesn't support fp16";
-  const int DTYPE[] = {CV_32F, CV_64F, -1, CV_8U, CV_32S};
-  int cv_type = CV_MAKETYPE(DTYPE[inputs[0].type_flag_], inputs[0].shape_[2]);
   const auto& param = nnvm::get<ResizeParam>(attrs.parsed);
-  cv::Mat buf(inputs[0].shape_[0], inputs[0].shape_[1], cv_type, inputs[0].dptr_);
-  cv::Mat dst(outputs[0].shape_[0], outputs[0].shape_[1], cv_type, outputs[0].dptr_);
-  cv::resize(buf, dst, cv::Size(param.w, param.h), 0, 0, param.interp);
-  CHECK(!dst.empty());
-  CHECK_EQ(static_cast<void*>(dst.ptr()), outputs[0].dptr_);
-#else
-  LOG(FATAL) << "Build with USE_OPENCV=1 for image io.";
-#endif  // MXNET_USE_OPENCV
+  op::image::ResizeImpl(inputs, outputs, param.h, param.w, param.interp);
 }
 
 
diff --git a/src/operator/contrib/bilinear_resize-inl.cuh b/src/operator/contrib/bilinear_resize-inl.cuh
new file mode 100644
index 0000000..b8dacb1
--- /dev/null
+++ b/src/operator/contrib/bilinear_resize-inl.cuh
@@ -0,0 +1,184 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file bilinear_resize-inl.cuh
+ * \brief bilinear resize operator cuda implementation
+ * \author Hang Zhang, Jake Lee
+*/
+
+#ifndef MXNET_OPERATOR_CONTRIB_BILINEAR_RESIZE_CUH_
+#define MXNET_OPERATOR_CONTRIB_BILINEAR_RESIZE_CUH_
+
+#include <cuda_runtime_api.h>
+#include <algorithm>
+
+namespace mxnet {
+namespace op {
+
+using namespace mshadow;
+
+enum ImageLayout {  // layout tags passed to caffe_gpu_interp2_kernel
+  HWC,   // single image, channels last
+  NHWC,  // batch of images, channels last
+  NCHW   // batch of images, channels first
+};
+
+template<typename In, typename Out>
+struct ScalarConvert {  // plain C-style cast (no rounding), callable from host and device
+  static __host__ __device__ __forceinline__ Out to(const In v) { return (Out) v; }
+};
+
+// Largest block size getNumThreads can return
+static const unsigned MAX_BLOCK_SIZE = 512U;
+
+// Smallest of {32, 64, 128, 256, 512} that is >= nElem; with `smaller` the cap is 256
+static unsigned getNumThreads(int nElem, const bool smaller) {
+  unsigned threadSizes[5] = {32, 64, 128, 256, MAX_BLOCK_SIZE};
+  const int maxi = smaller ? 4 : 5;
+  for (int i = 0; i != maxi; ++i) {
+    if (static_cast<unsigned>(nElem) <= threadSizes[i]) {
+      return threadSizes[i];
+    }
+  }
+  return smaller ? (MAX_BLOCK_SIZE >> 1) : MAX_BLOCK_SIZE;
+}
+
+// Bilinear resize of one (H, W, C) tensor: one thread per output pixel, looping over channels.
+template<typename xpu, typename Dtype, typename Acctype>
+__global__ void caffe_gpu_interp2_kernel(const int n,
+    const Acctype rheight, const Acctype rwidth,  // source-coordinate step per output row/col
+    const Tensor<xpu, 3, Dtype> data1,
+    Tensor<xpu, 3, Dtype> data2,
+    ImageLayout layout) {  // unused by this 3-D overload
+  int index = threadIdx.x + blockIdx.x * blockDim.x;
+  const int channels = data1.size(2);
+  const int height1 = data1.size(0);
+  const int width1 = data1.size(1);
+  const int height2 = data2.size(0);
+  const int width2 = data2.size(1);
+
+  if (index < n) {
+    const int w2 = index % width2;  // 0:width2-1
+    const int h2 = index / width2;  // 0:height2-1
+    // special case: just copy
+    if (height1 == height2 && width1 == width2) {
+      const int h1 = h2;
+      const int w1 = w2;
+      for (int c = 0; c < channels; ++c) {
+        const Dtype val = data1[h1][w1][c];
+        data2[h2][w2][c] = val;
+      }
+      return;
+    }
+    // source row h1, the next row h1 + h1p (h1p = 0 on the last row) and blend weights
+    const Acctype h1r = rheight * h2;
+    const int h1 = h1r;
+    const int h1p = (h1 < height1 - 1) ? 1 : 0;
+    const Acctype h1lambda = h1r - h1;
+    const Acctype h0lambda = Acctype(1) - h1lambda;
+    // same for the source column w1 / w1 + w1p
+    const Acctype w1r = rwidth * w2;
+    const int w1 = w1r;
+    const int w1p = (w1 < width1 - 1) ? 1 : 0;
+    const Acctype w1lambda = w1r - w1;
+    const Acctype w0lambda = Acctype(1) - w1lambda;
+    for (int c = 0; c < channels; ++c) {
+      const Acctype val = h0lambda * (w0lambda * data1[h1][w1][c]
+                            + w1lambda * data1[h1][w1+w1p][c])
+                            + h1lambda * (w0lambda * data1[h1+h1p][w1][c]
+                            + w1lambda * data1[h1+h1p][w1+w1p][c]);
+      data2[h2][w2][c] = ScalarConvert<Acctype, Dtype>::to(val);
+    }
+  }
+}
+
+// Bilinear resize of an NHWC or NCHW batch; one thread per output (h, w) pixel.
+template<typename xpu, typename Dtype, typename Acctype>
+__global__ void caffe_gpu_interp2_kernel(const int n,
+    const Acctype rheight, const Acctype rwidth,
+    const Tensor<xpu, 4, Dtype> data1,
+    Tensor<xpu, 4, Dtype> data2,
+    ImageLayout layout) {
+  const int index = threadIdx.x + blockIdx.x * blockDim.x;
+  const int batch_size = data1.size(0);  // batch axis is 0 in both NHWC and NCHW
+  const int channels = (layout == NHWC) ? data1.size(3) : data1.size(1);
+  const int height1 = (layout == NHWC) ? data1.size(1) : data1.size(2);
+  const int width1 = (layout == NHWC) ? data1.size(2) : data1.size(3);
+  const int height2 = (layout == NHWC) ? data2.size(1) : data2.size(2);
+  const int width2 = (layout == NHWC) ? data2.size(2) : data2.size(3);
+
+  if (index < n) {
+    const int w2 = index % width2;  // 0:width2-1
+    const int h2 = index / width2;  // 0:height2-1
+    // special case: just copy
+    if (height1 == height2 && width1 == width2) {
+      const int h1 = h2;
+      const int w1 = w2;
+      // `b` indexes the batch; `n` is already the thread-count parameter
+      for (int b = 0; b < batch_size; ++b) {
+        for (int c = 0; c < channels; ++c) {
+          if (layout == NHWC) {
+            const Dtype val = data1[b][h1][w1][c];
+            data2[b][h2][w2][c] = val;
+          } else {
+            const Dtype val = data1[b][c][h1][w1];
+            data2[b][c][h2][w2] = val;
+          }
+        }
+      }
+      return;
+    }
+    // source row h1, the next row h1 + h1p (h1p = 0 on the last row) and blend weights
+    const Acctype h1r = rheight * h2;
+    const int h1 = h1r;
+    const int h1p = (h1 < height1 - 1) ? 1 : 0;
+    const Acctype h1lambda = h1r - h1;
+    const Acctype h0lambda = Acctype(1) - h1lambda;
+    // same for the source column w1 / w1 + w1p
+    const Acctype w1r = rwidth * w2;
+    const int w1 = w1r;
+    const int w1p = (w1 < width1 - 1) ? 1 : 0;
+    const Acctype w1lambda = w1r - w1;
+    const Acctype w0lambda = Acctype(1) - w1lambda;
+    // this thread's output pixel is written for every image and channel
+    for (int b = 0; b < batch_size; ++b) {
+      for (int c = 0; c < channels; ++c) {
+        if (layout == NHWC) {
+          const Acctype val = h0lambda * (w0lambda * data1[b][h1][w1][c]
+                            + w1lambda * data1[b][h1][w1+w1p][c])
+                            + h1lambda * (w0lambda * data1[b][h1+h1p][w1][c]
+                            + w1lambda * data1[b][h1+h1p][w1+w1p][c]);
+          data2[b][h2][w2][c] = ScalarConvert<Acctype, Dtype>::to(val);
+        } else {
+          const Acctype val = h0lambda * (w0lambda * data1[b][c][h1][w1]
+                            + w1lambda * data1[b][c][h1][w1+w1p])
+                            + h1lambda * (w0lambda * data1[b][c][h1+h1p][w1]
+                            + w1lambda * data1[b][c][h1+h1p][w1+w1p]);
+          data2[b][c][h2][w2] = ScalarConvert<Acctype, Dtype>::to(val);
+        }
+      }
+    }
+  }
+}
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_CONTRIB_BILINEAR_RESIZE_CUH_
\ No newline at end of file
diff --git a/src/operator/contrib/bilinear_resize.cu b/src/operator/contrib/bilinear_resize.cu
index f01c9c2..b0a4c4b 100644
--- a/src/operator/contrib/bilinear_resize.cu
+++ b/src/operator/contrib/bilinear_resize.cu
@@ -25,86 +25,13 @@
 #include <cuda_runtime_api.h>
 #include <algorithm>
 #include "bilinear_resize-inl.h"
+#include "bilinear_resize-inl.cuh"
 
 namespace mxnet {
 namespace op {
 
 using namespace mshadow;
 
-template<typename In, typename Out>
-struct ScalarConvert {
-  static __host__ __device__ __forceinline__ Out to(const In v) { return (Out) v; }
-};
-
-
-// The maximum number of threads in a block
-static const unsigned MAX_BLOCK_SIZE = 512U;
-
-// Number of threads in a block given an input size up to MAX_BLOCK_SIZE
-static unsigned getNumThreads(int nElem, const bool smaller) {
-  unsigned threadSizes[5] = {32, 64, 128, 256, MAX_BLOCK_SIZE};
-  const int maxi = smaller ? 4 : 5;
-  for (int i = 0; i != maxi; ++i) {
-    if (static_cast<unsigned>(nElem) <= threadSizes[i]) {
-      return threadSizes[i];
-    }
-  }
-  return smaller ? (MAX_BLOCK_SIZE >> 1) : MAX_BLOCK_SIZE;
-}
-
-template<typename xpu, typename Dtype, typename Acctype>
-__global__ void caffe_gpu_interp2_kernel(const int n,
-    const Acctype rheight, const Acctype rwidth,
-    const Tensor<xpu, 4, Dtype> data1,
-    Tensor<xpu, 4, Dtype> data2) {
-  int index = threadIdx.x + blockIdx.x * blockDim.x;
-  const int batchsize = data1.size(0);
-  const int channels = data1.size(1);
-  const int height1 = data1.size(2);
-  const int width1 = data1.size(3);
-  const int height2 = data2.size(2);
-  const int width2 = data2.size(3);
-
-  if (index < n) {
-    const int w2 = index % width2;  // 0:width2-1
-    const int h2 = index / width2;  // 0:height2-1
-    // special case: just copy
-    if (height1 == height2 && width1 == width2) {
-      const int h1 = h2;
-      const int w1 = w2;
-      for (int n = 0; n < batchsize ; n++) {
-        for (int c = 0; c < channels; ++c) {
-          const Dtype val = data1[n][c][h1][w1];
-          data2[n][c][h2][w2] = val;
-        }
-      }
-      return;
-    }
-    //
-    const Acctype h1r = rheight * h2;
-    const int h1 = h1r;
-    const int h1p = (h1 < height1 - 1) ? 1 : 0;
-    const Acctype h1lambda = h1r - h1;
-    const Acctype h0lambda = Acctype(1) - h1lambda;
-    //
-    const Acctype w1r = rwidth * w2;
-    const int w1 = w1r;
-    const int w1p = (w1 < width1 - 1) ? 1 : 0;
-    const Acctype w1lambda = w1r - w1;
-    const Acctype w0lambda = Acctype(1) - w1lambda;
-    //
-    for (int n = 0; n < batchsize ; n++) {
-        for (int c = 0; c < channels; ++c) {
-        const Acctype val = h0lambda * (w0lambda * data1[n][c][h1][w1]
-                            + w1lambda * data1[n][c][h1][w1+w1p])
-                            + h1lambda * (w0lambda * data1[n][c][h1+h1p][w1]
-                            + w1lambda * data1[n][c][h1+h1p][w1+w1p]);
-        data2[n][c][h2][w2] = ScalarConvert<Acctype, Dtype>::to(val);
-      }
-    }
-  }
-}
-
 // Backward (adjoint) operation 1 <- 2 (accumulates)
 template<typename xpu, typename Dtype, typename Acctype>
 __global__ void caffe_gpu_interp2_kernel_backward(const int n,
@@ -181,9 +108,10 @@ void SpatialUpSamplingBilinearUpdateOutput(mshadow::Stream<gpu> *s,
   dim3 blocks(static_cast<int>(num_kernels / num_threads) + 1);
   dim3 threads(num_threads);
   cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s);
+  ImageLayout layout = NCHW;
   caffe_gpu_interp2_kernel<xpu, DType, AccReal>
   <<<blocks, threads , 0, stream>>>(
-    num_kernels, rheight, rwidth, idata, odata);
+    num_kernels, rheight, rwidth, idata, odata, layout);
   MSHADOW_CUDA_POST_KERNEL_CHECK(SpatialUpSamplingBilinearUpdateOutput);
 }
 
@@ -215,6 +143,5 @@ NNVM_REGISTER_OP(_contrib_BilinearResize2D)
 
 NNVM_REGISTER_OP(_backward_contrib_BilinearResize2D)
 .set_attr<FCompute>("FCompute<gpu>", BilinearSampleOpBackward<gpu>);
-
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/image/image_random-inl.h b/src/operator/image/image_random-inl.h
index 74807b9..aeea0bc 100644
--- a/src/operator/image/image_random-inl.h
+++ b/src/operator/image/image_random-inl.h
@@ -26,14 +26,18 @@
 #define MXNET_OPERATOR_IMAGE_IMAGE_RANDOM_INL_H_
 
 
-#include <mxnet/base.h>
 #include <algorithm>
-#include <vector>
 #include <cmath>
 #include <limits>
+#include <tuple>
 #include <utility>
+#include <vector>
+#include "mxnet/base.h"
 #include "../mxnet_op.h"
 #include "../operator_common.h"
+#if MXNET_USE_OPENCV
+  #include <opencv2/opencv.hpp>
+#endif  // MXNET_USE_OPENCV
 
 namespace mxnet {
 namespace op {
diff --git a/src/operator/image/image_utils.h b/src/operator/image/image_utils.h
new file mode 100644
index 0000000..a715534
--- /dev/null
+++ b/src/operator/image/image_utils.h
@@ -0,0 +1,59 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*/
+
+/*!
+ *  Copyright (c) 2019 by Contributors
+ * \file image_utils.h
+ * \brief the image operator utility function implementation
+ * \author Jake Lee
+ */
+
+#ifndef MXNET_OPERATOR_IMAGE_IMAGE_UTILS_H_
+#define MXNET_OPERATOR_IMAGE_IMAGE_UTILS_H_
+
+#include <vector>
+#if MXNET_USE_OPENCV
+  #include <opencv2/opencv.hpp>
+#endif  // MXNET_USE_OPENCV
+
+namespace mxnet {
+namespace op {
+namespace image {
+
+enum ImageLayout {H, W, C};  // axis indices of a single (H, W, C) image
+enum BatchImageLayout {N, kH, kW, kC};  // axis indices of an (N, H, W, C) batch
+
+// Height/width pair describing the target size of a resize.
+struct SizeParam {
+  // number of output rows
+  int height;
+  // number of output columns
+  int width;
+  // a default-constructed size is 0 x 0
+  SizeParam() : height(0), width(0) {}
+  // build from an explicit (height, width) pair
+  SizeParam(int height_, int width_)
+      : height(height_), width(width_) {}
+};  // struct SizeParam
+
+}  // namespace image
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_IMAGE_IMAGE_UTILS_H_
diff --git a/src/operator/image/resize-inl.h b/src/operator/image/resize-inl.h
new file mode 100644
index 0000000..3e13100
--- /dev/null
+++ b/src/operator/image/resize-inl.h
@@ -0,0 +1,218 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*/
+/*!
+* \file resize-inl.h
+* \brief image resize operator using opencv and only support bilinear resize
+* \author Jake Lee
+*/
+#ifndef MXNET_OPERATOR_IMAGE_RESIZE_INL_H_
+#define MXNET_OPERATOR_IMAGE_RESIZE_INL_H_
+
+#include <mxnet/base.h>
+#include <vector>
+
+#include "../mxnet_op.h"
+#include "../operator_common.h"
+#include "image_utils.h"
+
+#if MXNET_USE_OPENCV
+  #include <opencv2/opencv.hpp>
+#endif  // MXNET_USE_OPENCV
+
+namespace mxnet {
+namespace op {
+namespace image {
+
+using namespace mshadow;
+
+#if MXNET_USE_CUDA
+template<typename DType, typename T, typename Acctype>  // defined in resize.cu
+void ResizeImplCUDA(Stream<gpu> *s,  // T is Tensor<gpu, 3 or 4, DType>
+                      const T input,
+                      const T output);
+#endif  // MXNET_USE_CUDA
+
+struct ResizeParam : public dmlc::Parameter<ResizeParam> {
+  nnvm::Tuple<int> size;  // (width, height), or a single edge length
+  bool keep_ratio;  // only consulted when size has a single element
+  int interp;  // interpolation flag forwarded to cv::resize on cpu
+  DMLC_DECLARE_PARAMETER(ResizeParam) {
+    DMLC_DECLARE_FIELD(size)
+    .set_default(nnvm::Tuple<int>())
+    .describe("Size of new image. Could be (width, height) or (size)");
+    DMLC_DECLARE_FIELD(keep_ratio)
+    .describe("Whether to resize the short edge or both edges to `size`, "
+      "if size is given as an integer.")
+    .set_default(false);
+    DMLC_DECLARE_FIELD(interp)
+    .set_default(1)
+    .describe("Interpolation method for resizing. By default uses bilinear interpolation. "
+        "Options are INTER_NEAREST (0) - a nearest-neighbor interpolation, "
+        "INTER_LINEAR (1) - a bilinear interpolation, "
+        "INTER_AREA (3) - resampling using pixel area relation, "
+        "INTER_CUBIC (2) - a bicubic interpolation over 4x4 pixel neighborhood, "
+        "INTER_LANCZOS4 (4) - a Lanczos interpolation over 8x8 pixel neighborhood. "
+        "Note that the GPU version only supports bilinear interpolation (1)"
+        " and the result on cpu would be slightly different from gpu. "
+        "It uses opencv resize function which tends to align center on cpu "
+        "while using contrib.bilinearResize2D which aligns corner on gpu.");
+  }
+};
+// Output (height, width) for a data_h x data_w image: `size` is (width, height), or one edge
+// length for both edges or, with keep_ratio, for the short edge (needs data_h, data_w > 0).
+inline SizeParam GetHeightAndWidth(int data_h,
+                                    int data_w,
+                                    const ResizeParam& param) {
+  CHECK((param.size.ndim() == 1) || (param.size.ndim() == 2))
+      << "Input size dimension must be 1 or 2, but got " << param.size.ndim();
+  int resized_h, resized_w;
+  if (param.size.ndim() == 1) {
+    CHECK_GT(param.size[0], 0)
+      << "Input size should be greater than 0, but got " << param.size[0];
+    if (!param.keep_ratio) {
+      resized_h = param.size[0];
+      resized_w = param.size[0];
+    } else {
+      CHECK(data_h > 0 && data_w > 0) << "Input height and width must be positive to keep "
+          "the aspect ratio, but got " << data_h << " x " << data_w;
+      if (data_h > data_w) {
+        resized_w = param.size[0];
+        resized_h = static_cast<int>(data_h * resized_w / data_w);
+      } else {
+        resized_h = param.size[0];
+        resized_w = static_cast<int>(data_w * resized_h / data_h);
+      }
+    }
+  } else {
+    CHECK_GT(param.size[0], 0)
+        << "Input width should be greater than 0, but got "
+        << param.size[0];
+    CHECK_GT(param.size[1], 0)
+        << "Input height should be greater than 0, but got "
+        << param.size[1];
+    resized_h = param.size[1];
+    resized_w = param.size[0];
+  }
+  return SizeParam(resized_h, resized_w);
+}
+
+// Infer (h, w, c) -> (h', w', c) or (n, h, w, c) -> (n, h', w', c); false if input unknown.
+inline bool ResizeShape(const nnvm::NodeAttrs& attrs,
+                             std::vector<TShape> *in_attrs,
+                             std::vector<TShape> *out_attrs) {
+  const auto& ishape = (*in_attrs)[0];
+  if (ishape.ndim() == 0) return false;  // input shape not known yet
+  CHECK((ishape.ndim() == 3U) || (ishape.ndim() == 4U))
+    << "Input image dimension should be 3 or 4 but got " << ishape.ndim();
+  const ResizeParam& param = nnvm::get<ResizeParam>(attrs.parsed);
+  SizeParam size;
+  if (ishape.ndim() == 3) {
+    size = GetHeightAndWidth(ishape[H], ishape[W], param);
+    SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape({size.height, size.width, ishape[C]}));
+  } else {
+    size = GetHeightAndWidth(ishape[kH], ishape[kW], param);
+    SHAPE_ASSIGN_CHECK(*out_attrs, 0,
+      TShape({ishape[N], size.height, size.width, ishape[kC]}));
+  }
+  return true;
+}
+
+// cv::resize one H x W x C image; for N x H x W x C data the offsets select the batch item.
+inline void ResizeImpl(const std::vector<TBlob> &inputs,
+                      const std::vector<TBlob> &outputs,
+                      const int height,
+                      const int width,
+                      const int interp,
+                      const int input_index = 0,
+                      const int output_index = 0) {
+#if MXNET_USE_OPENCV
+  const int type_flag = inputs[0].type_flag_;
+  CHECK_NE(type_flag, mshadow::kFloat16) << "opencv image mat doesn't support fp16";
+  CHECK(type_flag != mshadow::kInt8 && type_flag != mshadow::kInt32 && type_flag != mshadow::kInt64)
+      << "opencv resize doesn't support int8, int32, int64";
+  const int DTYPE[] = {CV_32F, CV_64F, -1, CV_8U, CV_32S};  // opencv depth by mshadow type flag
+  const int cv_type = CV_MAKETYPE(DTYPE[type_flag], inputs[0].shape_[inputs[0].ndim() - 1]);
+  if (inputs[0].ndim() == 3) {
+    cv::Mat buf(inputs[0].shape_[H], inputs[0].shape_[W], cv_type, inputs[0].dptr_);
+    cv::Mat dst(outputs[0].shape_[H], outputs[0].shape_[W], cv_type, outputs[0].dptr_);
+    cv::resize(buf, dst, cv::Size(width, height), 0, 0, interp);
+    CHECK(!dst.empty());
+    CHECK_EQ(static_cast<void*>(dst.ptr()), outputs[0].dptr_);
+  } else {
+    MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+      cv::Mat buf(inputs[0].shape_[kH], inputs[0].shape_[kW], cv_type,
+        inputs[0].dptr<DType>() + input_index);
+      cv::Mat dst(outputs[0].shape_[kH], outputs[0].shape_[kW], cv_type,
+        outputs[0].dptr<DType>() + output_index);
+      cv::resize(buf, dst, cv::Size(width, height), 0, 0, interp);
+      CHECK(!dst.empty());
+      CHECK_EQ(static_cast<void*>(dst.ptr()), outputs[0].dptr<DType>() + output_index);
+    });
+  }
+#else
+  LOG(FATAL) << "Build with USE_OPENCV=1 for image resize operator.";
+#endif  // MXNET_USE_OPENCV
+}
+
+template <typename xpu>  // FCompute of _image_resize: CUDA kernel on gpu, OpenCV on cpu
+inline void Resize(const nnvm::NodeAttrs &attrs,
+                   const OpContext &ctx,
+                   const std::vector<TBlob> &inputs,
+                   const std::vector<OpReqType> &req,
+                   const std::vector<TBlob> &outputs) {
+  CHECK_EQ(outputs.size(), 1U);  // NOTE(review): req (e.g. kAddTo) is never consulted
+  const ResizeParam& param = nnvm::get<ResizeParam>(attrs.parsed);
+  SizeParam size;
+  if (std::is_same<xpu, gpu>::value) {  // gpu: target size is read from the output tensor
+#if MXNET_USE_CUDA
+    CHECK(param.interp == 1) << "interp should be 1 for using Resize on GPU.";
+    mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
+    MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+      if (inputs[0].ndim() == 3) {
+        Tensor<gpu, 3, DType> input = inputs[0].get<gpu, 3, DType>(s);
+        Tensor<gpu, 3, DType> output = outputs[0].get<gpu, 3, DType>(s);
+        ResizeImplCUDA<DType, Tensor<gpu, 3, DType>, float>(s, input, output);
+      } else {
+        Tensor<gpu, 4, DType> input = inputs[0].get<gpu, 4, DType>(s);
+        Tensor<gpu, 4, DType> output = outputs[0].get<gpu, 4, DType>(s);
+        ResizeImplCUDA<DType, Tensor<gpu, 4, DType>, float>(s, input, output);
+      }
+    });
+#endif  // MXNET_USE_CUDA; without CUDA this branch is a no-op
+  } else if (inputs[0].ndim() == 3) {  // cpu, single (H, W, C) image
+    size = GetHeightAndWidth(inputs[0].shape_[H], inputs[0].shape_[W], param);
+    ResizeImpl(inputs, outputs, size.height, size.width, param.interp);
+  } else {  // cpu, (N, H, W, C) batch: one ResizeImpl call per image at element offsets
+    size = GetHeightAndWidth(inputs[0].shape_[kH], inputs[0].shape_[kW], param);
+    const auto batch_size = inputs[0].shape_[N];
+    const auto input_step = inputs[0].shape_[kH] * inputs[0].shape_[kW] * inputs[0].shape_[kC];
+    const auto output_step = size.height * size.width * inputs[0].shape_[kC];
+    #pragma omp parallel for
+    for (auto i = 0; i < batch_size; ++i) {  // NOTE(review): ResizeImpl takes int offsets
+      ResizeImpl(inputs, outputs, size.height, size.width,
+        param.interp, i * input_step, i * output_step);
+    }
+  }
+}
+
+}  // namespace image
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_IMAGE_RESIZE_INL_H_
diff --git a/src/operator/image/resize.cc b/src/operator/image/resize.cc
new file mode 100644
index 0000000..d3b28f0
--- /dev/null
+++ b/src/operator/image/resize.cc
@@ -0,0 +1,83 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*/
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file resize.cc
+ * \brief resize operator cpu
+ * \author Jake Lee
+*/
+#include <mxnet/base.h>
+#include "./resize-inl.h"
+#include "../operator_common.h"
+#include "../elemwise_op_common.h"
+
+namespace mxnet {
+namespace op {
+namespace image {
+
+DMLC_REGISTER_PARAMETER(ResizeParam);  // ResizeParam is declared in resize-inl.h
+
+NNVM_REGISTER_OP(_image_resize)
+.describe(R"code(Resize an image NDArray of shape (H x W x C) or (N x H x W x C) 
+to the given size
+Example:
+    .. code-block:: python
+        image = mx.nd.random.uniform(0, 255, (4, 2, 3)).astype(dtype=np.uint8)
+        mx.nd.image.resize(image, (3, 3))
+            [[[124 111 197]
+              [158  80 155]
+              [193  50 112]]
+
+             [[110 100 113]
+              [134 165 148]
+              [157 231 182]]
+
+             [[202 176 134]
+              [174 191 149]
+              [147 207 164]]]
+            <NDArray 3x3x3 @cpu(0)>
+        image = mx.nd.random.uniform(0, 255, (2, 4, 2, 3)).astype(dtype=np.uint8)
+        mx.nd.image.resize(image, (2, 2))            
+            [[[[ 59 133  80]
+               [187 114 153]]
+
+              [[ 38 142  39]
+               [207 131 124]]]
+
+
+              [[[117 125 136]
+               [191 166 150]]
+
+              [[129  63 113]
+               [182 109  48]]]]
+            <NDArray 2x2x2x3 @cpu(0)>
+)code" ADD_FILELINE)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<ResizeParam>)
+.set_attr<nnvm::FInferShape>("FInferShape", ResizeShape)  // 3-D or 4-D input
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)  // dtype follows the input
+.set_attr<FCompute>("FCompute<cpu>", Resize<cpu>)  // OpenCV-based cpu path
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{ "_copy" })
+.add_argument("data", "NDArray-or-Symbol", "The input.")
+.add_arguments(ResizeParam::__FIELDS__());
+
+}  // namespace image
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/image/resize.cu b/src/operator/image/resize.cu
new file mode 100644
index 0000000..f045f3b
--- /dev/null
+++ b/src/operator/image/resize.cu
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file resize.cu
+ * \brief bilinear resize operator
+ * \author Hang Zhang, Jake Lee
+*/
+#include <algorithm>
+#include "./resize-inl.h"
+#include "../contrib/bilinear_resize-inl.cuh"
+
+namespace mxnet {
+namespace op {
+namespace image {
+
+using namespace mshadow;
+
+/*!
+ * \brief Enqueue the bilinear-interpolation kernel that resizes `input` into `output`.
+ * \tparam DType   element type of both tensors
+ * \tparam T       tensor type; Tensor<gpu, 3, DType> is handled as a single HWC
+ *                 image, any other type as an NHWC batch
+ * \tparam AccReal type used for the scale factors
+ * \param s      GPU stream the kernel is launched on
+ * \param input  source tensor
+ * \param output destination tensor; its height/width are the target size
+ */
+template<typename DType, typename T, typename AccReal>
+void ResizeImplCUDA(mshadow::Stream<gpu> *s,
+                      const T input,
+                      const T output) {
+  // A rank-3 tensor has no batch axis, so height/width sit at axes 0/1;
+  // otherwise the leading batch axis shifts them to axes 1/2.
+  const bool is_single_image = std::is_same<T, Tensor<gpu, 3, DType>>::value;
+  const mxnet::op::ImageLayout layout = is_single_image ? HWC : NHWC;
+  const int h_axis = is_single_image ? 0 : 1;
+  const int w_axis = h_axis + 1;
+
+  const int inputHeight = input.size(h_axis);
+  const int inputWidth = input.size(w_axis);
+  const int outputHeight = output.size(h_axis);
+  const int outputWidth = output.size(w_axis);
+
+  // Map output index 0 to input index 0 and the last output index to the
+  // last input index; a length-1 output axis gets a scale of 0.
+  auto corner_scale = [](int in_len, int out_len) -> AccReal {
+    return (out_len > 1) ? static_cast<AccReal>(in_len - 1) / (out_len - 1)
+                         : AccReal(0);
+  };
+  const AccReal rheight = corner_scale(inputHeight, outputHeight);
+  const AccReal rwidth = corner_scale(inputWidth, outputWidth);
+
+  // num_kernels counts output (row, column) positions; the grid is sized to
+  // cover them with num_threads threads per block.
+  const int num_kernels = outputHeight * outputWidth;
+  const int num_threads = getNumThreads(inputHeight * inputWidth, false);
+  dim3 threads(num_threads);
+  dim3 blocks(static_cast<int>(num_kernels / num_threads) + 1);
+
+  cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s);
+  caffe_gpu_interp2_kernel<gpu, DType, AccReal>
+      <<<blocks, threads, 0, stream>>>(
+          num_kernels, rheight, rwidth, input, output, layout);
+  MSHADOW_CUDA_POST_KERNEL_CHECK(caffe_gpu_interp2_kernel);
+}
+
+// GPU forward for _image_resize; Resize<gpu> comes from an included header
+// (presumably resize-inl.h).
+NNVM_REGISTER_OP(_image_resize)
+.set_attr<FCompute>("FCompute<gpu>", Resize<gpu>);
+
+}  // namespace image
+}  // namespace op
+}  // namespace mxnet
diff --git a/tests/python/gpu/test_gluon_transforms.py b/tests/python/gpu/test_gluon_transforms.py
index c7afc76..4a1017b 100644
--- a/tests/python/gpu/test_gluon_transforms.py
+++ b/tests/python/gpu/test_gluon_transforms.py
@@ -69,4 +69,63 @@ def test_normalize():
     # Invalid Input - Channel neither 1 or 3
     invalid_data_in = nd.random.uniform(0, 1, (5, 4, 300, 300))
     normalize_transformer = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))
-    assertRaises(MXNetError, normalize_transformer, invalid_data_in)
\ No newline at end of file
+    assertRaises(MXNetError, normalize_transformer, invalid_data_in)
+
+
+@with_seed()
+def test_resize():
+    # Test with normal case 3D input float type
+    data_in_3d = nd.random.uniform(0, 255, (300, 300, 3))
+    out_nd_3d = transforms.Resize((100, 100))(data_in_3d)
+    data_in_4d_nchw = nd.moveaxis(nd.expand_dims(data_in_3d, axis=0), 3, 1)
+    data_expected_3d = (nd.moveaxis(nd.contrib.BilinearResize2D(data_in_4d_nchw, 100, 100), 1, 3))[0]
+    assert_almost_equal(out_nd_3d.asnumpy(), data_expected_3d.asnumpy())
+
+    # Test with normal case 4D input float type
+    data_in_4d = nd.random.uniform(0, 255, (2, 300, 300, 3))
+    out_nd_4d = transforms.Resize((100, 100))(data_in_4d)
+    data_in_4d_nchw = nd.moveaxis(data_in_4d, 3, 1)
+    data_expected_4d = nd.moveaxis(nd.contrib.BilinearResize2D(data_in_4d_nchw, 100, 100), 1, 3)
+    assert_almost_equal(out_nd_4d.asnumpy(), data_expected_4d.asnumpy())
+
+    # Test invalid interp
+    data_in_3d = nd.random.uniform(0, 255, (300, 300, 3))
+    invalid_transform = transforms.Resize(-150, keep_ratio=False, interpolation=2)
+    assertRaises(MXNetError, invalid_transform, data_in_3d)
+
+    # Credited to Hang Zhang
+    def py_bilinear_resize_nhwc(x, outputHeight, outputWidth):
+        batch, inputHeight, inputWidth, channel = x.shape
+        if outputHeight == inputHeight and outputWidth == inputWidth:
+            return x
+        y = np.empty([batch, outputHeight, outputWidth, channel]).astype('uint8')
+        rheight = 1.0 * (inputHeight - 1) / (outputHeight - 1) if outputHeight > 1 else 0.0
+        rwidth = 1.0 * (inputWidth - 1) / (outputWidth - 1) if outputWidth > 1 else 0.0
+        for h2 in range(outputHeight):
+            h1r = 1.0 * h2 * rheight
+            h1 = int(np.floor(h1r))
+            h1lambda = h1r - h1
+            h1p = 1 if h1 < (inputHeight - 1) else 0
+            for w2 in range(outputWidth):
+                w1r = 1.0 * w2 * rwidth
+                w1 = int(np.floor(w1r))
+                w1lambda = w1r - w1
+                w1p = 1 if w1 < (inputHeight - 1) else 0
+                for b in range(batch):
+                    for c in range(channel):
+                        y[b][h2][w2][c] = (1-h1lambda)*((1-w1lambda)*x[b][h1][w1][c] + \
+                            w1lambda*x[b][h1][w1+w1p][c]) + \
+                            h1lambda*((1-w1lambda)*x[b][h1+h1p][w1][c] + \
+                            w1lambda*x[b][h1+h1p][w1+w1p][c])
+        return y
+
+    # Test with normal case 3D input int8 type
+    data_in_4d = nd.random.uniform(0, 255, (1, 300, 300, 3)).astype('uint8')
+    out_nd_3d = transforms.Resize((100, 100))(data_in_4d[0])
+    assert_almost_equal(out_nd_3d.asnumpy(), py_bilinear_resize_nhwc(data_in_4d.asnumpy(), 100, 100)[0], atol=1.0)
+
+    # Test with normal case 4D input int8 type
+    data_in_4d = nd.random.uniform(0, 255, (2, 300, 300, 3)).astype('uint8')
+    out_nd_4d = transforms.Resize((100, 100))(data_in_4d)
+    assert_almost_equal(out_nd_4d.asnumpy(), py_bilinear_resize_nhwc(data_in_4d.asnumpy(), 100, 100), atol=1.0)
+
diff --git a/tests/python/unittest/test_gluon_data_vision.py b/tests/python/unittest/test_gluon_data_vision.py
index c83778f..f10f0ae 100644
--- a/tests/python/unittest/test_gluon_data_vision.py
+++ b/tests/python/unittest/test_gluon_data_vision.py
@@ -17,7 +17,7 @@
 from __future__ import print_function
 import mxnet as mx
 import mxnet.ndarray as nd
-import numpy as np
+from mxnet.base import MXNetError
 from mxnet import gluon
 from mxnet.base import MXNetError
 from mxnet.gluon.data.vision import transforms
@@ -25,6 +25,7 @@ from mxnet.test_utils import assert_almost_equal
 from mxnet.test_utils import almost_equal
 from common import assertRaises, setup_module, with_seed, teardown
 
+import numpy as np
 
 @with_seed()
 def test_to_tensor():
@@ -69,6 +70,43 @@ def test_normalize():
 
 
 @with_seed()
+def test_resize():
+    def _test_resize_with_diff_type(dtype):
+        # test normal case
+        data_in = nd.random.uniform(0, 255, (300, 200, 3)).astype(dtype)
+        out_nd = transforms.Resize(200)(data_in)
+        data_expected = mx.image.imresize(data_in, 200, 200, 1)
+        assert_almost_equal(out_nd.asnumpy(), data_expected.asnumpy())
+        # test 4D input
+        data_bath_in = nd.random.uniform(0, 255, (3, 300, 200, 3)).astype(dtype)
+        out_batch_nd = transforms.Resize(200)(data_bath_in)
+        for i in range(len(out_batch_nd)):
+            assert_almost_equal(mx.image.imresize(data_bath_in[i], 200, 200, 1).asnumpy(),
+                out_batch_nd[i].asnumpy())
+        # test interp = 2
+        out_nd = transforms.Resize(200, interpolation=2)(data_in)
+        data_expected = mx.image.imresize(data_in, 200, 200, 2)
+        assert_almost_equal(out_nd.asnumpy(), data_expected.asnumpy())
+        # test height not equals to width
+        out_nd = transforms.Resize((200, 100))(data_in)
+        data_expected = mx.image.imresize(data_in, 200, 100, 1)
+        assert_almost_equal(out_nd.asnumpy(), data_expected.asnumpy())
+        # test keep_ratio
+        out_nd = transforms.Resize(150, keep_ratio=True)(data_in)
+        data_expected = mx.image.imresize(data_in, 150, 225, 1)
+        assert_almost_equal(out_nd.asnumpy(), data_expected.asnumpy())
+        # test size below zero
+        invalid_transform = transforms.Resize(-150, keep_ratio=True)
+        assertRaises(MXNetError, invalid_transform, data_in)
+        # test size more than 2:
+        invalid_transform = transforms.Resize((100, 100, 100), keep_ratio=True)
+        assertRaises(MXNetError, invalid_transform, data_in)
+
+    for dtype in ['uint8', 'float32', 'float64']:
+        _test_resize_with_diff_type(dtype)    
+
+
+@with_seed()
 def test_flip_left_right():
     data_in = np.random.uniform(0, 255, (300, 300, 3)).astype(dtype=np.uint8)
     flip_in = data_in[:, ::-1, :]