Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/05/23 18:30:38 UTC

[GitHub] piiswrong closed pull request #10852: [MXNET-411] Add ROI Align

URL: https://github.com/apache/incubator-mxnet/pull/10852

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:


diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index 5f5302a45a4..4bfafb60cba 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -170,4 +170,4 @@ List of Contributors
 * [Sina Afrooze](https://github.com/safrooze)
 * [Sergey Sokolov](https://github.com/Ishitori)
 * [Thomas Delteil](https://github.com/ThomasDelteil)
-
+* [Hang Zhang](http://hangzh.com)
diff --git a/docs/api/python/ndarray/contrib.md b/docs/api/python/ndarray/contrib.md
index 25cabed808e..b017c601208 100644
--- a/docs/api/python/ndarray/contrib.md
+++ b/docs/api/python/ndarray/contrib.md
@@ -45,6 +45,7 @@ In the rest of this document, we list routines provided by the `ndarray.contrib`
     MultiProposal
     PSROIPooling
     Proposal
+    ROIAlign
     count_sketch
     ctc_loss
     dequantize
diff --git a/docs/api/python/symbol/contrib.md b/docs/api/python/symbol/contrib.md
index 1af18bbf86d..f2bb3f15dee 100644
--- a/docs/api/python/symbol/contrib.md
+++ b/docs/api/python/symbol/contrib.md
@@ -45,6 +45,7 @@ In the rest of this document, we list routines provided by the `symbol.contrib`
     MultiProposal
     PSROIPooling
     Proposal
+    ROIAlign
     count_sketch
     ctc_loss
     dequantize
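
Both front ends expose the same registered operator. A quick sanity check,
assuming an MXNet build that includes this PR (a hypothetical smoke test, not
part of the change itself):

    import mxnet as mx

    # Both namespaces should resolve the newly registered operator.
    assert hasattr(mx.nd.contrib, 'ROIAlign')
    assert hasattr(mx.symbol.contrib, 'ROIAlign')
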
diff --git a/src/operator/contrib/roi_align-inl.h b/src/operator/contrib/roi_align-inl.h
new file mode 100644
index 00000000000..5ac420cc3d4
--- /dev/null
+++ b/src/operator/contrib/roi_align-inl.h
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file roi_align-inl.h
+ * \brief roi align operator and symbol
+ * \author Hang Zhang
+ * modified from Caffe2
+*/
+#ifndef MXNET_OPERATOR_CONTRIB_ROI_ALIGN_INL_H_
+#define MXNET_OPERATOR_CONTRIB_ROI_ALIGN_INL_H_
+
+#include <vector>
+#include <utility>
+#include "../mshadow_op.h"
+#include "../tensor/init_op.h"
+
+
+namespace mxnet {
+namespace op {
+
+
+// Declare enumeration of input order to make code more intuitive.
+// These enums are only visible within this header
+namespace roialign {
+enum ROIAlignOpInputs {kData, kBox};
+enum ROIAlignOpOutputs {kOut};
+}  // namespace roialign
+
+
+struct ROIAlignParam : public dmlc::Parameter<ROIAlignParam> {
+  TShape pooled_size;
+  float spatial_scale;
+  DMLC_DECLARE_PARAMETER(ROIAlignParam) {
+    DMLC_DECLARE_FIELD(pooled_size)
+    .set_expect_ndim(2).enforce_nonzero()
+    .describe("ROI Align output roi feature map height and width: (h, w)");
+    DMLC_DECLARE_FIELD(spatial_scale).set_range(0.0, 1.0)
+    .describe("Ratio of input feature map height (or w) to raw image height (or w). "
+    "Equals the reciprocal of total stride in convolutional layers");
+  }
+};
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_CONTRIB_ROI_ALIGN_INL_H_
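
The two fields above are the operator's only hyper-parameters and map directly
onto the Python call. A minimal usage sketch (mirroring the unit test added at
the end of this PR; the stride-16 backbone is an illustrative assumption, not
something the PR prescribes):

    import mxnet as mx

    # A backbone with a total convolutional stride of 16 shrinks a 224x224
    # image to a 14x14 feature map, so spatial_scale = 1/16.
    spatial_scale = 1.0 / 16
    data = mx.nd.random.uniform(shape=(1, 256, 14, 14))  # (N, C, H, W)
    # Each ROI row is [batch_index, x1, y1, x2, y2] in raw-image coordinates.
    rois = mx.nd.array([[0, 32.5, 48.0, 160.2, 200.9]])

    out = mx.nd.contrib.ROIAlign(data, rois, pooled_size=(7, 7),
                                 spatial_scale=spatial_scale)
    print(out.shape)  # (1, 256, 7, 7): one fixed-size map per ROI
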
diff --git a/src/operator/contrib/roi_align.cc b/src/operator/contrib/roi_align.cc
new file mode 100644
index 00000000000..c2cb929966a
--- /dev/null
+++ b/src/operator/contrib/roi_align.cc
@@ -0,0 +1,584 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file roi_align.cc
+ * \brief roi align operator
+ * \author Hang Zhang
+ * Adapted from Caffe2
+*/
+#include "./roi_align-inl.h"
+
+
+namespace mxnet {
+namespace op {
+
+template <typename T>
+struct PreCalc {
+  int pos1;
+  int pos2;
+  int pos3;
+  int pos4;
+  T w1;
+  T w2;
+  T w3;
+  T w4;
+};
+
+template <typename T>
+void pre_calc_for_bilinear_interpolate(
+    const int height,
+    const int width,
+    const int pooled_height,
+    const int pooled_width,
+    const int iy_upper,
+    const int ix_upper,
+    T roi_start_h,
+    T roi_start_w,
+    T bin_size_h,
+    T bin_size_w,
+    int roi_bin_grid_h,
+    int roi_bin_grid_w,
+    std::vector<PreCalc<T>>* pre_calc) {
+  int pre_calc_index = 0;
+  for (int ph = 0; ph < pooled_height; ph++) {
+    for (int pw = 0; pw < pooled_width; pw++) {
+      for (int iy = 0; iy < iy_upper; iy++) {
+        const T yy = roi_start_h + ph * bin_size_h +
+            static_cast<T>(iy + .5f) * bin_size_h /
+                static_cast<T>(roi_bin_grid_h);  // e.g., 0.5, 1.5
+        for (int ix = 0; ix < ix_upper; ix++) {
+          const T xx = roi_start_w + pw * bin_size_w +
+              static_cast<T>(ix + .5f) * bin_size_w /
+                  static_cast<T>(roi_bin_grid_w);
+
+          T x = xx;
+          T y = yy;
+          // deal with cases where the inverse-mapped point is outside the feature map
+          if (y < -1.0 || y > height || x < -1.0 || x > width) {
+            // empty
+            PreCalc<T> pc;
+            pc.pos1 = 0;
+            pc.pos2 = 0;
+            pc.pos3 = 0;
+            pc.pos4 = 0;
+            pc.w1 = 0;
+            pc.w2 = 0;
+            pc.w3 = 0;
+            pc.w4 = 0;
+            pre_calc->at(pre_calc_index) = pc;
+            pre_calc_index += 1;
+            continue;
+          }
+
+          if (y <= 0) {
+            y = 0;
+          }
+          if (x <= 0) {
+            x = 0;
+          }
+
+          int y_low = static_cast<int>(y);
+          int x_low = static_cast<int>(x);
+          int y_high;
+          int x_high;
+
+          if (y_low >= height - 1) {
+            y_high = y_low = height - 1;
+            y = (T)y_low;
+          } else {
+            y_high = y_low + 1;
+          }
+
+          if (x_low >= width - 1) {
+            x_high = x_low = width - 1;
+            x = (T)x_low;
+          } else {
+            x_high = x_low + 1;
+          }
+
+          T ly = y - y_low;
+          T lx = x - x_low;
+          T hy = 1. - ly, hx = 1. - lx;
+          T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
+
+          // save weights and indices
+          PreCalc<T> pc;
+          pc.pos1 = y_low * width + x_low;
+          pc.pos2 = y_low * width + x_high;
+          pc.pos3 = y_high * width + x_low;
+          pc.pos4 = y_high * width + x_high;
+          pc.w1 = w1;
+          pc.w2 = w2;
+          pc.w3 = w3;
+          pc.w4 = w4;
+          pre_calc->at(pre_calc_index) = pc;
+
+          pre_calc_index += 1;
+        }
+      }
+    }
+  }
+}
+
+template <typename T>
+void ROIAlignForward(
+    const int nthreads,
+    const T* bottom_data,
+    const T& spatial_scale,
+    const int channels,
+    const int height,
+    const int width,
+    const int pooled_height,
+    const int pooled_width,
+    const int sampling_ratio,
+    const T* bottom_rois,
+    int roi_cols,
+    T* top_data) {
+  DCHECK(roi_cols == 4 || roi_cols == 5);
+
+  int n_rois = nthreads / channels / pooled_width / pooled_height;
+  // (n, c, ph, pw) is an element in the pooled output
+  // can be parallelized using omp
+  for (int n = 0; n < n_rois; n++) {
+    int index_n = n * channels * pooled_width * pooled_height;
+
+    // roi could have 4 or 5 columns
+    const T* offset_bottom_rois = bottom_rois + n * roi_cols;
+    int roi_batch_ind = 0;
+    if (roi_cols == 5) {
+      roi_batch_ind = offset_bottom_rois[0];
+      offset_bottom_rois++;
+    }
+
+    // Do not use rounding; this implementation detail is critical
+    T roi_start_w = offset_bottom_rois[0] * spatial_scale;
+    T roi_start_h = offset_bottom_rois[1] * spatial_scale;
+    T roi_end_w = offset_bottom_rois[2] * spatial_scale;
+    T roi_end_h = offset_bottom_rois[3] * spatial_scale;
+
+    // Force malformed ROIs to be 1x1
+    T roi_width = std::max(roi_end_w - roi_start_w, (T)1.);
+    T roi_height = std::max(roi_end_h - roi_start_h, (T)1.);
+    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
+    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
+
+    // We use roi_bin_grid to sample the grid and mimic the integral over the bin
+    int roi_bin_grid_h = (sampling_ratio > 0)
+        ? sampling_ratio
+        : ceil(roi_height / pooled_height);  // e.g., = 2
+    int roi_bin_grid_w =
+        (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
+
+    // We do average (integral) pooling inside a bin
+    const T count = roi_bin_grid_h * roi_bin_grid_w;  // e.g. = 4
+
+    // we want to precalculate indices and weights shared by all channels;
+    // this is the key point of the optimization
+    std::vector<PreCalc<T>> pre_calc(
+        roi_bin_grid_h * roi_bin_grid_w * pooled_width * pooled_height);
+    pre_calc_for_bilinear_interpolate(
+        height,
+        width,
+        pooled_height,
+        pooled_width,
+        roi_bin_grid_h,
+        roi_bin_grid_w,
+        roi_start_h,
+        roi_start_w,
+        bin_size_h,
+        bin_size_w,
+        roi_bin_grid_h,
+        roi_bin_grid_w,
+        &pre_calc);
+
+    int c;
+#pragma omp parallel for private(c) \
+num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
+    for (c = 0; c < channels; c++) {
+      int index_n_c = index_n + c * pooled_width * pooled_height;
+      const T* offset_bottom_data =
+          bottom_data + (roi_batch_ind * channels + c) * height * width;
+      int pre_calc_index = 0;
+
+      for (int ph = 0; ph < pooled_height; ph++) {
+        for (int pw = 0; pw < pooled_width; pw++) {
+          int index = index_n_c + ph * pooled_width + pw;
+
+          T output_val = 0.;
+          for (int iy = 0; iy < roi_bin_grid_h; iy++) {
+            for (int ix = 0; ix < roi_bin_grid_w; ix++) {
+              PreCalc<T> pc = pre_calc[pre_calc_index];
+              output_val += pc.w1 * offset_bottom_data[pc.pos1] +
+                  pc.w2 * offset_bottom_data[pc.pos2] +
+                  pc.w3 * offset_bottom_data[pc.pos3] +
+                  pc.w4 * offset_bottom_data[pc.pos4];
+
+              pre_calc_index += 1;
+            }
+          }
+          output_val /= count;
+
+          top_data[index] = output_val;
+        }  // for pw
+      }  // for ph
+    }  // for c
+  }  // for n
+}
+
+
+template <typename T>
+void bilinear_interpolate_gradient(
+    const int height,
+    const int width,
+    T y,
+    T x,
+    T* w1,
+    T* w2,
+    T* w3,
+    T* w4,
+    int* x_low,
+    int* x_high,
+    int* y_low,
+    int* y_high,
+    const int /*index*/ /* index for debug only*/) {
+  // deal with cases where the inverse-mapped point is outside the feature map
+  if (y < -1.0 || y > height || x < -1.0 || x > width) {
+    // empty
+    *w1 = *w2 = *w3 = *w4 = 0.;
+    *x_low = *x_high = *y_low = *y_high = -1;
+    return;
+  }
+
+  if (y <= 0) {
+    y = 0;
+  }
+  if (x <= 0) {
+    x = 0;
+  }
+
+  *y_low = static_cast<int>(y);
+  *x_low = static_cast<int>(x);
+
+  if (*y_low >= height - 1) {
+    *y_high = *y_low = height - 1;
+    y = (T)*y_low;
+  } else {
+    *y_high = *y_low + 1;
+  }
+
+  if (*x_low >= width - 1) {
+    *x_high = *x_low = width - 1;
+    x = (T)*x_low;
+  } else {
+    *x_high = *x_low + 1;
+  }
+
+  T ly = y - *y_low;
+  T lx = x - *x_low;
+  T hy = 1. - ly, hx = 1. - lx;
+
+  *w1 = hy * hx, *w2 = hy * lx, *w3 = ly * hx, *w4 = ly * lx;
+
+  return;
+}
+
+template <class T>
+inline void add(const T& val, T* address) {
+  *address += val;
+}
+
+template <typename T>
+void ROIAlignBackward(
+    const int nthreads,
+    const T* top_diff,
+    const int /*num_rois*/,
+    const T& spatial_scale,
+    const int channels,
+    const int height,
+    const int width,
+    const int pooled_height,
+    const int pooled_width,
+    const int sampling_ratio,
+    T* bottom_diff,
+    const T* bottom_rois,
+    int rois_cols) {
+  DCHECK(rois_cols == 4 || rois_cols == 5);
+
+  for (int index = 0; index < nthreads; index++) {
+    // (n, c, ph, pw) is an element in the pooled output
+    int pw = index % pooled_width;
+    int ph = (index / pooled_width) % pooled_height;
+    int c = (index / pooled_width / pooled_height) % channels;
+    int n = index / pooled_width / pooled_height / channels;
+
+    const T* offset_bottom_rois = bottom_rois + n * rois_cols;
+    int roi_batch_ind = 0;
+    if (rois_cols == 5) {
+      roi_batch_ind = offset_bottom_rois[0];
+      offset_bottom_rois++;
+    }
+
+    // Do not use rounding; this implementation detail is critical
+    T roi_start_w = offset_bottom_rois[0] * spatial_scale;
+    T roi_start_h = offset_bottom_rois[1] * spatial_scale;
+    T roi_end_w = offset_bottom_rois[2] * spatial_scale;
+    T roi_end_h = offset_bottom_rois[3] * spatial_scale;
+
+    // Force malformed ROIs to be 1x1
+    T roi_width = std::max(roi_end_w - roi_start_w, (T)1.);
+    T roi_height = std::max(roi_end_h - roi_start_h, (T)1.);
+    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
+    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
+
+    T* offset_bottom_diff =
+        bottom_diff + (roi_batch_ind * channels + c) * height * width;
+
+    int top_offset = (n * channels + c) * pooled_height * pooled_width;
+    const T* offset_top_diff = top_diff + top_offset;
+    const T top_diff_this_bin = offset_top_diff[ph * pooled_width + pw];
+
+    // We use roi_bin_grid to sample the grid and mimic the integral over the bin
+    int roi_bin_grid_h = (sampling_ratio > 0)
+        ? sampling_ratio
+        : ceil(roi_height / pooled_height);  // e.g., = 2
+    int roi_bin_grid_w =
+        (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
+
+    // We do average (integral) pooling inside a bin
+    const T count = roi_bin_grid_h * roi_bin_grid_w;  // e.g. = 4
+
+    for (int iy = 0; iy < roi_bin_grid_h; iy++) {
+      const T y = roi_start_h + ph * bin_size_h +
+          static_cast<T>(iy + .5f) * bin_size_h /
+              static_cast<T>(roi_bin_grid_h);  // e.g., 0.5, 1.5
+      for (int ix = 0; ix < roi_bin_grid_w; ix++) {
+        const T x = roi_start_w + pw * bin_size_w +
+            static_cast<T>(ix + .5f) * bin_size_w /
+                static_cast<T>(roi_bin_grid_w);
+
+        T w1, w2, w3, w4;
+        int x_low, x_high, y_low, y_high;
+
+        bilinear_interpolate_gradient(
+            height,
+            width,
+            y,
+            x,
+            &w1,
+            &w2,
+            &w3,
+            &w4,
+            &x_low,
+            &x_high,
+            &y_low,
+            &y_high,
+            index);
+
+        T g1 = top_diff_this_bin * w1 / count;
+        T g2 = top_diff_this_bin * w2 / count;
+        T g3 = top_diff_this_bin * w3 / count;
+        T g4 = top_diff_this_bin * w4 / count;
+
+        if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
+          // atomic add is not needed for now since it is single threaded
+          add(static_cast<T>(g1), offset_bottom_diff + y_low * width + x_low);
+          add(static_cast<T>(g2), offset_bottom_diff + y_low * width + x_high);
+          add(static_cast<T>(g3), offset_bottom_diff + y_high * width + x_low);
+          add(static_cast<T>(g4), offset_bottom_diff + y_high * width + x_high);
+        }  // if
+      }  // ix
+    }  // iy
+  }  // for
+}  // ROIAlignBackward
+
+
+template<typename xpu>
+void ROIAlignForwardCompute(const nnvm::NodeAttrs& attrs,
+                            const OpContext& ctx,
+                            const std::vector<TBlob>& in_data,
+                            const std::vector<OpReqType>& req,
+                            const std::vector<TBlob>& out_data) {
+  using namespace mshadow;
+  size_t expected_in = 2;
+  size_t expected_out = 1;
+  CHECK_EQ(in_data.size(), expected_in);
+  CHECK_EQ(out_data.size(), expected_out);
+  CHECK_EQ(out_data[roialign::kOut].shape_[0], in_data[roialign::kBox].shape_[0]);
+
+  const ROIAlignParam param = nnvm::get<ROIAlignParam>(attrs.parsed);
+
+  const int count = out_data[roialign::kOut].Size();
+  // const int num_rois = in_data[roialign::kBox].size(0);
+  const int channels = in_data[roialign::kData].size(1);
+  const int height = in_data[roialign::kData].size(2);
+  const int width = in_data[roialign::kData].size(3);
+  const int pooled_height = out_data[roialign::kOut].size(2);
+  const int pooled_width = out_data[roialign::kOut].size(3);
+  const int rois_cols = in_data[roialign::kBox].size(1);
+
+  // assume all the data and gradient have the same type
+  MSHADOW_REAL_TYPE_SWITCH(in_data[0].type_flag_, DType, {
+    const DType *bottom_data = in_data[roialign::kData].dptr<DType>();
+    const DType *bottom_rois = in_data[roialign::kBox].dptr<DType>();
+    DType *top_data = out_data[roialign::kOut].dptr<DType>();
+
+    ROIAlignForward<DType>(count, bottom_data, param.spatial_scale, channels,
+                           height, width, pooled_height, pooled_width, -1, bottom_rois,
+                           rois_cols, top_data);
+  })
+}
+
+template<typename xpu>
+void ROIAlignBackwardCompute(const nnvm::NodeAttrs& attrs,
+                             const OpContext& ctx,
+                             const std::vector<TBlob>& inputs,
+                             const std::vector<OpReqType>& req,
+                             const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+
+  CHECK_EQ(inputs.size(), 2);
+  CHECK_EQ(outputs.size(), 2);
+  // the order here relates to the order in ROIAlignGrad
+  std::vector<TBlob> out_grad(1, inputs[0]);
+  std::vector<TBlob> in_data(1, inputs[1]);
+  // std::vector<TBlob> out_data(1, inputs[2]);
+
+  CHECK_EQ(out_grad[0].shape_[0], in_data[0].shape_[0]);
+  CHECK_NE(req[0], kWriteInplace) <<
+    "ROIAlign: Backward doesn't support kWriteInplace.";
+  CHECK_NE(req[1], kWriteInplace) <<
+    "ROIAlign: Backward doesn't support kWriteInplace.";
+
+  const ROIAlignParam param = nnvm::get<ROIAlignParam>(attrs.parsed);
+
+  const int count = out_grad[0].Size();
+  const int num_rois = in_data[0].size(0);
+  const int channels = outputs[0].size(1);
+  const int height = outputs[0].size(2);
+  const int width = outputs[0].size(3);
+  const int pooled_height = out_grad[0].size(2);
+  const int pooled_width = out_grad[0].size(3);
+  const int rois_cols = in_data[0].size(1);
+
+  Stream<cpu> *s = ctx.get_stream<cpu>();
+  // assume all the data and gradient have the same type
+  MSHADOW_REAL_TYPE_SWITCH(out_grad[0].type_flag_, DType, {
+    const DType *top_diff = out_grad[0].dptr<DType>();
+    const DType *bottom_rois = in_data[0].dptr<DType>();
+    DType *grad_in = outputs[0].dptr<DType>();
+
+    if (kAddTo == req[roialign::kData] || kWriteTo == req[roialign::kData]) {
+      if (kWriteTo == req[roialign::kData]) {
+        Fill<false>(s, outputs[0], kWriteTo, static_cast<DType>(0));
+      }
+      ROIAlignBackward<DType>(count, top_diff, num_rois, param.spatial_scale,
+                     channels, height, width, pooled_height, pooled_width,
+                     -1, grad_in, bottom_rois, rois_cols);
+    }
+    if (kWriteTo == req[roialign::kBox]) {
+      Fill<false>(s, outputs[1], kWriteTo, static_cast<DType>(0));
+    }
+  })
+}
+
+DMLC_REGISTER_PARAMETER(ROIAlignParam);
+
+NNVM_REGISTER_OP(_contrib_ROIAlign)
+.describe(R"code(
+This operator takes a 4D feature map as an input array and region proposals as `rois`,
+then aligns the feature map over sub-regions of the input and produces a fixed-size output array.
+This operator is typically used in Faster R-CNN and Mask R-CNN networks.
+
+Unlike ROI Pooling, ROI Align removes the harsh quantization, properly aligning
+the extracted features with the input. RoIAlign computes the value of each sampling point
+by bilinear interpolation from the nearby grid points on the feature map. No quantization is
+performed on any coordinates involved in the RoI, its bins, or the sampling points.
+Bilinear interpolation is used to compute the exact values of the
+input features at four regularly sampled locations in each RoI bin.
+The sampled values are then aggregated within each bin by average pooling.
+
+
+Reference
+---------
+
+He, Kaiming, et al. "Mask R-CNN." ICCV, 2017
+)code" ADD_FILELINE)
+.set_num_inputs(2)
+.set_num_outputs(1)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+    [](const NodeAttrs& attrs) {
+  return std::vector<std::string>{"data", "rois"};
+})
+.set_attr<nnvm::FListOutputNames>("FListOutputNames",
+    [](const NodeAttrs& attrs) {
+  return std::vector<std::string>{"output"};
+})
+.set_attr_parser(ParamParser<ROIAlignParam>)
+.set_attr<nnvm::FInferShape>("FInferShape", [](const nnvm::NodeAttrs& attrs,
+      std::vector<TShape> *in_shape, std::vector<TShape> *out_shape){
+  using namespace mshadow;
+  const ROIAlignParam param = nnvm::get<ROIAlignParam>(attrs.parsed);
+  CHECK_EQ(in_shape->size(), 2) << "Input:[data, rois]";
+  // data: [batch_size, c, h, w]
+  TShape dshape = in_shape->at(roialign::kData);
+  CHECK_EQ(dshape.ndim(), 4) << "data should be a 4D tensor";
+  // bbox: [num_rois, 5]
+  TShape bshape = in_shape->at(roialign::kBox);
+  CHECK_EQ(bshape.ndim(), 2) << "bbox should be a 2D tensor of shape [batch, 5]";
+  CHECK_EQ(bshape[1], 5) << "bbox should be a 2D tensor of shape [batch, 5]";
+  // out: [num_rois, c, pooled_h, pooled_w]
+  out_shape->clear();
+  out_shape->push_back(
+       Shape4(bshape[0], dshape[1], param.pooled_size[0], param.pooled_size[1]));
+  return true;
+})
+.set_attr<nnvm::FInferType>("FInferType", [](const nnvm::NodeAttrs& attrs,
+      std::vector<int> *in_type, std::vector<int> *out_type) {
+  CHECK_EQ(in_type->size(), 2);
+  int dtype = (*in_type)[0];
+  CHECK_EQ(dtype, (*in_type)[1]);
+  CHECK_NE(dtype, -1) << "Input must have specified type";
+
+  out_type->clear();
+  out_type->push_back(dtype);
+  return true;
+})
+.set_attr<FCompute>("FCompute<cpu>", ROIAlignForwardCompute<cpu>)
+.set_attr<nnvm::FGradient>("FGradient",
+  [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
+    std::vector<nnvm::NodeEntry> heads;
+    heads.push_back(ograds[roialign::kOut]);
+    heads.push_back(n->inputs[roialign::kBox]);
+    return MakeGradNode("_backward_ROIAlign", n, heads, n->attrs.dict);
+  })
+.add_argument("data", "NDArray-or-Symbol", "Input data to the pooling operator, a 4D Feature maps")
+.add_argument("rois", "NDArray-or-Symbol", "Bounding box coordinates, a 2D array")
+.add_arguments(ROIAlignParam::__FIELDS__());
+
+
+NNVM_REGISTER_OP(_backward_ROIAlign)
+.set_num_outputs(2)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr_parser(ParamParser<ROIAlignParam>)
+.set_attr<FCompute>("FCompute<cpu>", ROIAlignBackwardCompute<cpu>);
+
+}  // namespace op
+}  // namespace mxnet
+
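
For readers skimming the kernels above: a minimal NumPy sketch of how a single
output bin is computed, using the same sampling formulas as ROIAlignForward
(the out-of-boundary-returns-zero branch is omitted for brevity, and all names
are illustrative):

    import numpy as np

    def roi_align_bin(feat, roi_start_h, roi_start_w, bin_size_h, bin_size_w,
                      ph, pw, grid_h, grid_w):
        # Average of grid_h x grid_w bilinear samples inside bin (ph, pw);
        # feat is one channel of the feature map as a 2D numpy array.
        H, W = feat.shape
        val = 0.0
        for iy in range(grid_h):
            # Sample at sub-cell centers: offsets 0.5, 1.5, ... within the bin.
            y = roi_start_h + ph * bin_size_h + (iy + 0.5) * bin_size_h / grid_h
            y = min(max(y, 0.0), H - 1)
            for ix in range(grid_w):
                x = roi_start_w + pw * bin_size_w + (ix + 0.5) * bin_size_w / grid_w
                x = min(max(x, 0.0), W - 1)
                y0, x0 = int(y), int(x)
                y1, x1 = min(y0 + 1, H - 1), min(x0 + 1, W - 1)
                ly, lx = y - y0, x - x0
                val += ((1 - ly) * (1 - lx) * feat[y0, x0]
                        + (1 - ly) * lx * feat[y0, x1]
                        + ly * (1 - lx) * feat[y1, x0]
                        + ly * lx * feat[y1, x1])
        return val / (grid_h * grid_w)

    # e.g. one 8x8 channel, ROI starting at (1, 1), 2x2 bins, 2x2 sample grid:
    feat = np.arange(64, dtype=np.float64).reshape(8, 8)
    print(roi_align_bin(feat, 1.0, 1.0, 2.0, 2.0, 0, 0, 2, 2))  # -> 18.0

Since the FCompute wrappers pass sampling_ratio = -1, grid_h and grid_w fall
back to ceil(roi_height / pooled_height) and ceil(roi_width / pooled_width),
exactly as in the C++ code above.
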
diff --git a/src/operator/contrib/roi_align.cu b/src/operator/contrib/roi_align.cu
new file mode 100644
index 00000000000..21066ea15fa
--- /dev/null
+++ b/src/operator/contrib/roi_align.cu
@@ -0,0 +1,484 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file roi_align.cu
+ * \brief roi align operator
+ * \author Hang Zhang
+ * Adapted from Caffe2
+*/
+#include "./roi_align-inl.h"
+
+
+namespace mxnet {
+namespace op {
+
+#define CUDA_1D_KERNEL_LOOP(i, n)                                 \
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
+       i += blockDim.x * gridDim.x)
+
+using namespace mshadow::cuda;
+
+// The maximum number of blocks to use in the default kernel call.
+constexpr int ROI_MAXIMUM_NUM_BLOCKS = 4096;
+
+/**
+ * @brief Compute the number of blocks needed to run N threads.
+ */
+inline int ROI_GET_BLOCKS(const int N) {
+  return std::max(
+      std::min(
+          (N + kMaxThreadsPerBlock - 1) / kMaxThreadsPerBlock,
+          ROI_MAXIMUM_NUM_BLOCKS),
+      // Use at least 1 block, since CUDA does not allow empty block
+      1);
+}
+
+
+template <typename T>
+__device__ T bilinear_interpolate(
+    const T* bottom_data,
+    const int height,
+    const int width,
+    T y,
+    T x,
+    const int index /* index for debug only*/) {
+  // deal with cases where the inverse-mapped point is outside the feature map
+  if (y < -1.0 || y > height || x < -1.0 || x > width) {
+    // empty
+    return 0;
+  }
+
+  if (y <= 0) {
+    y = 0;
+  }
+  if (x <= 0) {
+    x = 0;
+  }
+
+  int y_low = static_cast<int>(y);
+  int x_low = static_cast<int>(x);
+  int y_high;
+  int x_high;
+
+  if (y_low >= height - 1) {
+    y_high = y_low = height - 1;
+    y = (T)y_low;
+  } else {
+    y_high = y_low + 1;
+  }
+
+  if (x_low >= width - 1) {
+    x_high = x_low = width - 1;
+    x = (T)x_low;
+  } else {
+    x_high = x_low + 1;
+  }
+
+  T ly = y - y_low;
+  T lx = x - x_low;
+  T hy = 1. - ly, hx = 1. - lx;
+  // do bilinear interpolation
+  T v1 = bottom_data[y_low * width + x_low];
+  T v2 = bottom_data[y_low * width + x_high];
+  T v3 = bottom_data[y_high * width + x_low];
+  T v4 = bottom_data[y_high * width + x_high];
+  T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
+
+  T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+
+  return val;
+}
+
+template <typename T>
+__global__ void RoIAlignForwardKernel(
+    const int nthreads,
+    const T* bottom_data,
+    const T spatial_scale,
+    const int channels,
+    const int height,
+    const int width,
+    const int pooled_height,
+    const int pooled_width,
+    const int sampling_ratio,
+    const T* bottom_rois,
+    T* top_data) {
+  CUDA_1D_KERNEL_LOOP(index, nthreads) {
+    // (n, c, ph, pw) is an element in the pooled output
+    int pw = index % pooled_width;
+    int ph = (index / pooled_width) % pooled_height;
+    int c = (index / pooled_width / pooled_height) % channels;
+    int n = index / pooled_width / pooled_height / channels;
+
+    const T* offset_bottom_rois = bottom_rois + n * 5;
+    int roi_batch_ind = offset_bottom_rois[0];
+
+    // Do not use rounding; this implementation detail is critical
+    T roi_start_w = offset_bottom_rois[1] * spatial_scale;
+    T roi_start_h = offset_bottom_rois[2] * spatial_scale;
+    T roi_end_w = offset_bottom_rois[3] * spatial_scale;
+    T roi_end_h = offset_bottom_rois[4] * spatial_scale;
+    // T roi_start_w = round(offset_bottom_rois[1] * spatial_scale);
+    // T roi_start_h = round(offset_bottom_rois[2] * spatial_scale);
+    // T roi_end_w = round(offset_bottom_rois[3] * spatial_scale);
+    // T roi_end_h = round(offset_bottom_rois[4] * spatial_scale);
+
+    // Force malformed ROIs to be 1x1
+    T roi_width = max(roi_end_w - roi_start_w, (T)1.);
+    T roi_height = max(roi_end_h - roi_start_h, (T)1.);
+    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
+    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
+
+    const T* offset_bottom_data =
+        bottom_data + (roi_batch_ind * channels + c) * height * width;
+
+    // We use roi_bin_grid to sample the grid and mimic the integral over the bin
+    int roi_bin_grid_h = (sampling_ratio > 0)
+        ? sampling_ratio
+        : ceil(roi_height / pooled_height);  // e.g., = 2
+    int roi_bin_grid_w =
+        (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
+
+    // We do average (integral) pooling inside a bin
+    const T count = roi_bin_grid_h * roi_bin_grid_w;  // e.g. = 4
+
+    T output_val = 0.;
+    for (int iy = 0; iy < roi_bin_grid_h; iy++) {  // e.g., iy = 0, 1
+      const T y = roi_start_h + ph * bin_size_h +
+          static_cast<T>(iy + .5f) * bin_size_h /
+              static_cast<T>(roi_bin_grid_h);  // e.g., 0.5, 1.5
+      for (int ix = 0; ix < roi_bin_grid_w; ix++) {
+        const T x = roi_start_w + pw * bin_size_w +
+            static_cast<T>(ix + .5f) * bin_size_w /
+                static_cast<T>(roi_bin_grid_w);
+
+        T val = bilinear_interpolate(
+            offset_bottom_data, height, width, y, x, index);
+        output_val += val;
+      }
+    }
+    output_val /= count;
+
+    top_data[index] = output_val;
+  }
+}
+
+
+template <typename T>
+__device__ void bilinear_interpolate_gradient(
+    const int height,
+    const int width,
+    T y,
+    T x,
+    T* w1,
+    T* w2,
+    T* w3,
+    T* w4,
+    int* x_low,
+    int* x_high,
+    int* y_low,
+    int* y_high,
+    const int /*index*/ /* index for debug only*/) {
+  // deal with cases where the inverse-mapped point is outside the feature map
+  if (y < -1.0 || y > height || x < -1.0 || x > width) {
+    // empty
+    *w1 = *w2 = *w3 = *w4 = 0.;
+    *x_low = *x_high = *y_low = *y_high = -1;
+    return;
+  }
+
+  if (y <= 0) {
+    y = 0;
+  }
+  if (x <= 0) {
+    x = 0;
+  }
+
+  *y_low = static_cast<int>(y);
+  *x_low = static_cast<int>(x);
+
+  if (*y_low >= height - 1) {
+    *y_high = *y_low = height - 1;
+    y = (T)*y_low;
+  } else {
+    *y_high = *y_low + 1;
+  }
+
+  if (*x_low >= width - 1) {
+    *x_high = *x_low = width - 1;
+    x = (T)*x_low;
+  } else {
+    *x_high = *x_low + 1;
+  }
+
+  T ly = y - *y_low;
+  T lx = x - *x_low;
+  T hy = 1. - ly, hx = 1. - lx;
+
+  // reference in forward
+  // T v1 = bottom_data[*y_low * width + *x_low];
+  // T v2 = bottom_data[*y_low * width + *x_high];
+  // T v3 = bottom_data[*y_high * width + *x_low];
+  // T v4 = bottom_data[*y_high * width + *x_high];
+  // T val = (w1 * v1 + *w2 * v2 + *w3 * v3 + *w4 * v4);
+
+  *w1 = hy * hx, *w2 = hy * lx, *w3 = ly * hx, *w4 = ly * lx;
+
+  return;
+}
+
+template <typename T>
+__global__ void RoIAlignBackwardKernel(
+    const int nthreads,
+    const T* top_diff,
+    const int num_rois,
+    const T spatial_scale,
+    const int channels,
+    const int height,
+    const int width,
+    const int pooled_height,
+    const int pooled_width,
+    const int sampling_ratio,
+    T* bottom_diff,
+    const T* bottom_rois) {
+  CUDA_1D_KERNEL_LOOP(index, nthreads) {
+    // (n, c, ph, pw) is an element in the pooled output
+    int pw = index % pooled_width;
+    int ph = (index / pooled_width) % pooled_height;
+    int c = (index / pooled_width / pooled_height) % channels;
+    int n = index / pooled_width / pooled_height / channels;
+
+    const T* offset_bottom_rois = bottom_rois + n * 5;
+    int roi_batch_ind = offset_bottom_rois[0];
+
+    // Do not use rounding; this implementation detail is critical
+    T roi_start_w = offset_bottom_rois[1] * spatial_scale;
+    T roi_start_h = offset_bottom_rois[2] * spatial_scale;
+    T roi_end_w = offset_bottom_rois[3] * spatial_scale;
+    T roi_end_h = offset_bottom_rois[4] * spatial_scale;
+    // T roi_start_w = round(offset_bottom_rois[1] * spatial_scale);
+    // T roi_start_h = round(offset_bottom_rois[2] * spatial_scale);
+    // T roi_end_w = round(offset_bottom_rois[3] * spatial_scale);
+    // T roi_end_h = round(offset_bottom_rois[4] * spatial_scale);
+
+    // Force malformed ROIs to be 1x1
+    T roi_width = max(roi_end_w - roi_start_w, (T)1.);
+    T roi_height = max(roi_end_h - roi_start_h, (T)1.);
+    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
+    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
+
+    T* offset_bottom_diff =
+        bottom_diff + (roi_batch_ind * channels + c) * height * width;
+
+    int top_offset = (n * channels + c) * pooled_height * pooled_width;
+    const T* offset_top_diff = top_diff + top_offset;
+    const T top_diff_this_bin = offset_top_diff[ph * pooled_width + pw];
+
+    // We use roi_bin_grid to sample the grid and mimic the integral over the bin
+    int roi_bin_grid_h = (sampling_ratio > 0)
+        ? sampling_ratio
+        : ceil(roi_height / pooled_height);  // e.g., = 2
+    int roi_bin_grid_w =
+        (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
+
+    // We do average (integral) pooling inside a bin
+    const T count = roi_bin_grid_h * roi_bin_grid_w;  // e.g. = 4
+
+    for (int iy = 0; iy < roi_bin_grid_h; iy++) {  // e.g., iy = 0, 1
+      const T y = roi_start_h + ph * bin_size_h +
+          static_cast<T>(iy + .5f) * bin_size_h /
+              static_cast<T>(roi_bin_grid_h);  // e.g., 0.5, 1.5
+      for (int ix = 0; ix < roi_bin_grid_w; ix++) {
+        const T x = roi_start_w + pw * bin_size_w +
+            static_cast<T>(ix + .5f) * bin_size_w /
+                static_cast<T>(roi_bin_grid_w);
+
+        T w1, w2, w3, w4;
+        int x_low, x_high, y_low, y_high;
+
+        bilinear_interpolate_gradient(
+            height,
+            width,
+            y,
+            x,
+            &w1,
+            &w2,
+            &w3,
+            &w4,
+            &x_low,
+            &x_high,
+            &y_low,
+            &y_high,
+            index);
+
+        T g1 = top_diff_this_bin * w1 / count;
+        T g2 = top_diff_this_bin * w2 / count;
+        T g3 = top_diff_this_bin * w3 / count;
+        T g4 = top_diff_this_bin * w4 / count;
+
+        if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
+          atomicAdd(
+              offset_bottom_diff + y_low * width + x_low, static_cast<T>(g1));
+          atomicAdd(
+              offset_bottom_diff + y_low * width + x_high, static_cast<T>(g2));
+          atomicAdd(
+              offset_bottom_diff + y_high * width + x_low, static_cast<T>(g3));
+          atomicAdd(
+              offset_bottom_diff + y_high * width + x_high, static_cast<T>(g4));
+          /*
+          gpu_atomic_add(
+              static_cast<T>(g1), offset_bottom_diff + y_low * width + x_low);
+          gpu_atomic_add(
+              static_cast<T>(g2), offset_bottom_diff + y_low * width + x_high);
+          gpu_atomic_add(
+              static_cast<T>(g3), offset_bottom_diff + y_high * width + x_low);
+          gpu_atomic_add(
+              static_cast<T>(g4), offset_bottom_diff + y_high * width + x_high);
+          */
+        }  // if
+      }  // ix
+    }  // iy
+  }  // CUDA_1D_KERNEL_LOOP
+}  // RoIAlignBackward
+
+template<typename xpu>
+void ROIAlignForwardCompute(const nnvm::NodeAttrs& attrs,
+                            const OpContext& ctx,
+                            const std::vector<TBlob>& in_data,
+                            const std::vector<OpReqType>& req,
+                            const std::vector<TBlob>& out_data) {
+  using namespace mshadow;
+  size_t expected_in = 2;
+  size_t expected_out = 1;
+  CHECK_EQ(in_data.size(), expected_in);
+  CHECK_EQ(out_data.size(), expected_out);
+  CHECK_EQ(out_data[roialign::kOut].shape_[0], in_data[roialign::kBox].shape_[0]);
+
+  const ROIAlignParam param = nnvm::get<ROIAlignParam>(attrs.parsed);
+
+  const int count = out_data[roialign::kOut].Size();
+  const int num_rois = in_data[roialign::kBox].size(0);
+  const int channels = in_data[roialign::kData].size(1);
+  const int height = in_data[roialign::kData].size(2);
+  const int width = in_data[roialign::kData].size(3);
+  const int pooled_height = out_data[roialign::kOut].size(2);
+  const int pooled_width = out_data[roialign::kOut].size(3);
+
+  Stream<gpu> *s = ctx.get_stream<gpu>();
+  cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s);
+  MSHADOW_REAL_TYPE_SWITCH(in_data[0].type_flag_, DType, {
+    const DType *bottom_data = in_data[roialign::kData].dptr<DType>();
+    const DType *bottom_rois = in_data[roialign::kBox].dptr<DType>();
+    DType *top_data = out_data[roialign::kOut].dptr<DType>();
+    RoIAlignForwardKernel<DType>
+      <<<ROI_GET_BLOCKS(count),
+         kMaxThreadsPerBlock,
+         0,
+         stream>>>(
+          count,
+          bottom_data,
+          param.spatial_scale,
+          channels,
+          height,
+          width,
+          pooled_height,
+          pooled_width,
+          -1,
+          bottom_rois,
+          top_data);
+  })
+}
+
+
+template<typename xpu>
+void ROIAlignBackwardCompute(const nnvm::NodeAttrs& attrs,
+                             const OpContext& ctx,
+                             const std::vector<TBlob>& inputs,
+                             const std::vector<OpReqType>& req,
+                             const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+
+  CHECK_EQ(inputs.size(), 2);
+  CHECK_EQ(outputs.size(), 2);
+  // the order here relates to the order in ROIAlignGrad
+  std::vector<TBlob> out_grad(1, inputs[0]);
+  std::vector<TBlob> in_data(1, inputs[1]);
+  // std::vector<TBlob> out_data(1, inputs[2]);
+
+  CHECK_EQ(out_grad[0].shape_[0], in_data[0].shape_[0]);
+  CHECK_NE(req[0], kWriteInplace) <<
+    "ROIAlign: Backward doesn't support kWriteInplace.";
+  CHECK_NE(req[1], kWriteInplace) <<
+    "ROIAlign: Backward doesn't support kWriteInplace.";
+
+  const ROIAlignParam param = nnvm::get<ROIAlignParam>(attrs.parsed);
+
+  const int count = out_grad[0].Size();
+  const int num_rois = in_data[0].size(0);
+  const int channels = outputs[0].size(1);
+  const int height = outputs[0].size(2);
+  const int width = outputs[0].size(3);
+  const int pooled_height = out_grad[0].size(2);
+  const int pooled_width = out_grad[0].size(3);
+
+  Stream<gpu> *s = ctx.get_stream<gpu>();
+  cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s);
+
+  // assume all the data and gradient have the same type
+  MSHADOW_REAL_TYPE_SWITCH(out_grad[0].type_flag_, DType, {
+    const DType *top_diff = out_grad[0].dptr<DType>();
+    const DType *bottom_rois = in_data[0].dptr<DType>();
+    DType *grad_in = outputs[0].dptr<DType>();
+
+    if (kWriteTo == req[roialign::kBox]) {
+      Fill<false>(s, outputs[1], kWriteTo, static_cast<DType>(0));
+    }
+    if (kNullOp == req[roialign::kData]) return;
+    if (kWriteTo == req[roialign::kData]) {
+      Fill<false>(s, outputs[0], kWriteTo, static_cast<DType>(0));
+    }
+    RoIAlignBackwardKernel<DType>
+    <<<ROI_GET_BLOCKS(count),
+       kMaxThreadsPerBlock,
+       0,
+       stream>>>(
+        count,
+        top_diff,
+        num_rois,
+        param.spatial_scale,
+        channels,
+        height,
+        width,
+        pooled_height,
+        pooled_width,
+        -1,
+        grad_in,
+        bottom_rois);
+  })
+}
+
+
+NNVM_REGISTER_OP(_contrib_ROIAlign)
+.set_attr<FCompute>("FCompute<gpu>", ROIAlignForwardCompute<gpu>);
+
+NNVM_REGISTER_OP(_backward_ROIAlign)
+.set_attr<FCompute>("FCompute<gpu>", ROIAlignBackwardCompute<gpu>);
+
+}  // namespace op
+}  // namespace mxnet
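
A note on the launch configuration above: the kernels use one thread per
output element, with the block count from ROI_GET_BLOCKS capped at 4096;
CUDA_1D_KERNEL_LOOP is a grid-stride loop, so any remainder is still covered.
A sketch of the block arithmetic (the 512-thread figure is an illustrative
stand-in for mshadow's kMaxThreadsPerBlock):

    def roi_get_blocks(n, threads_per_block=512, max_blocks=4096):
        # Ceiling division, clamped to [1, max_blocks].
        return max(min((n + threads_per_block - 1) // threads_per_block,
                       max_blocks), 1)

    # e.g. 7 ROIs x 256 channels x 7x7 bins = 87808 elements -> 172 blocks
    print(roi_get_blocks(7 * 256 * 7 * 7))

The backward kernel accumulates with atomicAdd because sampling grids from
neighbouring bins (and overlapping ROIs) can write to the same input cell.
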
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index e7976e01f9d..c5bdee1d1e5 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -5965,6 +5965,7 @@ def get_output_names_callback(name, arr):
                             name='pooling')
     check_name(us_sym, ['pooling_output'])
 
+
 @with_seed()
 def test_activation():
     shape=(9, 10)
@@ -6018,6 +6019,157 @@ def test_context_num_gpus():
         if str(e).find("CUDA") == -1:
             raise e
 
+
+@with_seed()
+def test_op_roi_align():
+    # Adapted from https://github.com/wkcn/MobulaOP/blob/master/tests/test_roi_align_op.py
+    def bilinear_interpolate(bottom, height, width, y, x):
+        if y < -1.0 or y > height or x < -1.0 or x > width:
+            return 0.0, []
+        x = max(0.0, x)
+        y = max(0.0, y)
+        x_low = int(x)
+        y_low = int(y)
+        if x_low >= width - 1:
+            x_low = x_high = width - 1
+            x = x_low
+        else:
+            x_high = x_low + 1
+
+        if y_low >= height - 1:
+            y_low = y_high = height - 1
+            y = y_low
+        else:
+            y_high = y_low + 1
+
+        ly = y - y_low
+        lx = x - x_low
+        hy = 1.0 - ly
+        hx = 1.0 - lx
+
+        v1 = bottom[y_low, x_low]
+        v2 = bottom[y_low, x_high]
+        v3 = bottom[y_high, x_low]
+        v4 = bottom[y_high, x_high]
+
+        '''
+        ----------->x
+        |hx hy | lx hy
+        |------+------
+        |hx ly | lx ly
+        V
+        y
+        v1|v2
+        --+--
+        v3|v4
+        '''
+        w1 = hy * hx
+        w2 = hy * lx
+        w3 = ly * hx
+        w4 = ly * lx
+
+        val = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4
+        grad = [(y_low, x_low, w1), (y_low, x_high, w2),
+                (y_high, x_low, w3), (y_high, x_high, w4)
+               ]
+        return val, grad
+
+    def roialign_forward_backward(data, rois, pooled_size, spatial_scale, sampling_ratio, dy):
+        N, C, H, W = data.shape
+        R = rois.shape[0]
+        PH, PW = pooled_size
+        assert len(rois.shape) == 2
+        assert rois.shape[1] == 5
+
+        out = np.zeros((R, C, PH, PW))
+        dx = np.zeros_like(data)
+        drois = np.zeros_like(rois)
+
+        for r in range(R):
+            batch_ind = int(rois[r, 0])
+            sw, sh, ew, eh = rois[r, 1:5] * spatial_scale
+            roi_w = max(ew - sw, 1.0)
+            roi_h = max(eh - sh, 1.0)
+            bin_h = roi_h * 1.0 / PH
+            bin_w = roi_w * 1.0 / PW
+            bdata = data[batch_ind]
+            if sampling_ratio > 0:
+                roi_bin_grid_h = roi_bin_grid_w = sampling_ratio
+            else:
+                roi_bin_grid_h = int(np.ceil(roi_h * 1.0 / PH))
+                roi_bin_grid_w = int(np.ceil(roi_w * 1.0 / PW))
+            count = roi_bin_grid_h * roi_bin_grid_w
+            for c in range(C):
+                for ph in range(PH):
+                    for pw in range(PW):
+                        val = 0.0
+                        for iy in range(roi_bin_grid_h):
+                            y = sh + ph * bin_h + (iy + 0.5) * bin_h / roi_bin_grid_h
+                            for ix in range(roi_bin_grid_w):
+                                x = sw + pw * bin_w + (ix + 0.5) * bin_w / roi_bin_grid_w
+                                v, g = bilinear_interpolate(bdata[c], H, W, y, x)
+                                val += v
+                                # compute grad
+                                for qy, qx, qw in g:
+                                    dx[batch_ind, c, qy, qx] += dy[r, c, ph, pw] * qw * 1.0 / count
+
+                        out[r, c, ph, pw] = val * 1.0 / count
+        return out, [dx, drois]
+
+    def test_roi_align_value():
+        ctx = default_context()
+        dtype = np.float32
+
+        dlen = 224
+        N, C, H, W = 5, 3, 16, 16
+        assert H == W
+        R = 7
+        pooled_size = (3, 4)
+
+        spatial_scale = H * 1.0 / dlen
+        sampling_ratio = 0
+        data = mx.nd.array(np.arange(N*C*W*H).reshape((N, C, H, W)), ctx=ctx, dtype=dtype)
+        # data = mx.nd.random.uniform(0, 1, (N, C, H, W), dtype=dtype)
+        center_xy = mx.nd.random.uniform(0, dlen, (R, 2), ctx=ctx, dtype=dtype)
+        wh = mx.nd.random.uniform(0, dlen, (R, 2), ctx=ctx, dtype=dtype)
+        batch_ind = mx.nd.array(np.random.randint(0, N, size=(R, 1)), ctx=ctx)
+        pos = mx.nd.concat(center_xy - wh / 2, center_xy + wh / 2, dim=1)
+        rois = mx.nd.concat(batch_ind, pos, dim=1)
+
+        data.attach_grad()
+        rois.attach_grad()
+        with mx.autograd.record():
+            output = mx.nd.contrib.ROIAlign(data, rois, pooled_size=pooled_size,
+                    spatial_scale=spatial_scale)
+        dy = mx.nd.random.uniform(-1, 1, (R, C) + pooled_size, ctx=ctx, dtype=dtype)
+        output.backward(dy)
+        real_output, [dx, drois] = roialign_forward_backward(data.asnumpy(), rois.asnumpy(), pooled_size, spatial_scale, sampling_ratio, dy.asnumpy())
+        assert np.allclose(output.asnumpy(), real_output)
+        # C float and Python float differ slightly in precision, hence the atol.
+        assert np.allclose(data.grad.asnumpy(), dx, atol=1e-6), np.abs(data.grad.asnumpy() - dx).max()
+        assert np.allclose(rois.grad.asnumpy(), drois)
+
+    # modified from test_roipooling()
+    def test_roi_align_autograd():
+        ctx = default_context()
+        data = mx.symbol.Variable(name='data')
+        rois = mx.symbol.Variable(name='rois')
+        test = mx.symbol.contrib.ROIAlign(data=data, rois=rois, pooled_size=(4, 4), spatial_scale=1)
+
+        x1 = np.random.rand(4, 1, 12, 12).astype('float64')
+        x2 = np.array([[0, 1.1, 1.1, 6.2, 6.2], [2, 6.1, 2.1, 8.2, 11.2],
+                       [1, 3.1, 1.1, 5.2, 10.2]], dtype='float64')
+
+        check_numeric_gradient(sym=test, location=[x1, x2],
+                               grad_nodes={'data':'write', 'rois':'null'},
+                               numeric_eps=1e-4, rtol=1e-1, atol=1e-4, ctx=ctx)
+        check_numeric_gradient(sym=test, location=[x1, x2],
+                               grad_nodes={'data':'add', 'rois':'null'},
+                               numeric_eps=1e-4, rtol=1e-1, atol=1e-4, ctx=ctx)
+
+    test_roi_align_value()
+    test_roi_align_autograd()
+
 
 if __name__ == '__main__':
     import nose


 
