You are viewing a plain-text version of this content. The canonical link to it ("here") was a hyperlink in the original page and is not preserved in this plain-text copy.
Posted to commits@mxnet.apache.org by ar...@apache.org on 2019/05/24 16:21:41 UTC

[incubator-mxnet] branch master updated: Add cpu implementation for Deformable PSROIPooling (#14886)

This is an automated email from the ASF dual-hosted git repository.

arcadiaphy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new eb0b8af  Add cpu implementation for Deformable PSROIPooling (#14886)
eb0b8af is described below

commit eb0b8af43091bfebc6e4b89bfee7d97c095785d1
Author: Wang Jiajun <wa...@gmail.com>
AuthorDate: Fri May 24 11:21:15 2019 -0500

    Add cpu implementation for Deformable PSROIPooling (#14886)
    
    * add cpu deformable_psroi_pooling forward
    
    * add cpu deformable_psroi_pooling backward
    
    * add consistency checks
    
    * fix nullptr
    
    * fix code style
    
    * fix lint
    
    * fix code style
    
    * update to index_t
    
    * fix lint
    
    * fix compile
---
 .../contrib/deformable_psroi_pooling-inl.h         |  50 +--
 src/operator/contrib/deformable_psroi_pooling.cc   | 328 +++++++++++++++--
 src/operator/contrib/deformable_psroi_pooling.cu   | 406 ++++++++++-----------
 tests/python/gpu/test_operator_gpu.py              |  18 +
 4 files changed, 534 insertions(+), 268 deletions(-)

diff --git a/src/operator/contrib/deformable_psroi_pooling-inl.h b/src/operator/contrib/deformable_psroi_pooling-inl.h
index e466c06..78124d2 100644
--- a/src/operator/contrib/deformable_psroi_pooling-inl.h
+++ b/src/operator/contrib/deformable_psroi_pooling-inl.h
@@ -51,11 +51,11 @@ namespace deformablepsroipool {
 struct DeformablePSROIPoolingParam : public dmlc::Parameter<DeformablePSROIPoolingParam> {
   // mxnet::TShape pooled_size;
   float spatial_scale;
-  int output_dim;
-  int group_size;
-  int pooled_size;
-  int part_size;
-  int sample_per_part;
+  index_t output_dim;
+  index_t group_size;
+  index_t pooled_size;
+  index_t part_size;
+  index_t sample_per_part;
   float trans_std;
   bool no_trans;
   DMLC_DECLARE_PARAMETER(DeformablePSROIPoolingParam) {
@@ -82,10 +82,10 @@ class DeformablePSROIPoolingOp : public Operator {
   }
 
   virtual void Forward(const OpContext &ctx,
-    const std::vector<TBlob> &in_data,
-    const std::vector<OpReqType> &req,
-    const std::vector<TBlob> &out_data,
-    const std::vector<TBlob> &aux_args) {
+                       const std::vector<TBlob> &in_data,
+                       const std::vector<OpReqType> &req,
+                       const std::vector<TBlob> &out_data,
+                       const std::vector<TBlob> &aux_args) {
     using namespace mshadow;
     size_t in_expected = param_.no_trans? 2 : 3;
     size_t out_expected = 2;
@@ -119,12 +119,12 @@ class DeformablePSROIPoolingOp : public Operator {
   }
 
   virtual void Backward(const OpContext &ctx,
-    const std::vector<TBlob> &out_grad,
-    const std::vector<TBlob> &in_data,
-    const std::vector<TBlob> &out_data,
-    const std::vector<OpReqType> &req,
-    const std::vector<TBlob> &in_grad,
-    const std::vector<TBlob> &aux_args) {
+                        const std::vector<TBlob> &out_grad,
+                        const std::vector<TBlob> &in_data,
+                        const std::vector<TBlob> &out_data,
+                        const std::vector<OpReqType> &req,
+                        const std::vector<TBlob> &in_grad,
+                        const std::vector<TBlob> &aux_args) {
     using namespace mshadow;
     size_t in_expected = param_.no_trans ? 2 : 3;
     size_t out_expected = 2;
@@ -216,8 +216,8 @@ class DeformablePSROIPoolingProp : public OperatorProperty {
   }
 
   bool InferShape(mxnet::ShapeVector *in_shape,
-    mxnet::ShapeVector *out_shape,
-    mxnet::ShapeVector *aux_shape) const override {
+                  mxnet::ShapeVector *out_shape,
+                  mxnet::ShapeVector *aux_shape) const override {
     using namespace mshadow;
     if (param_.no_trans) {
       CHECK_EQ(in_shape->size(), 2) << "Input:[data, rois]";
@@ -248,8 +248,8 @@ class DeformablePSROIPoolingProp : public OperatorProperty {
   }
 
   bool InferType(std::vector<int> *in_type,
-    std::vector<int> *out_type,
-    std::vector<int> *aux_type) const override {
+                 std::vector<int> *out_type,
+                 std::vector<int> *aux_type) const override {
     CHECK_GE(in_type->size(), 2);
     int dtype = (*in_type)[0];
     CHECK_EQ(dtype, (*in_type)[1]);
@@ -272,10 +272,9 @@ class DeformablePSROIPoolingProp : public OperatorProperty {
   }
 
   // decalre dependency and inplace optimization options
-  std::vector<int> DeclareBackwardDependency(
-    const std::vector<int> &out_grad,
-    const std::vector<int> &in_data,
-    const std::vector<int> &out_data) const override {
+  std::vector<int> DeclareBackwardDependency(const std::vector<int> &out_grad,
+                                             const std::vector<int> &in_data,
+                                             const std::vector<int> &out_data) const override {
     if (param_.no_trans) {
       return{ out_grad[deformablepsroipool::kOut], in_data[deformablepsroipool::kData],
               in_data[deformablepsroipool::kBox], out_data[deformablepsroipool::kTopCount] };
@@ -292,8 +291,9 @@ class DeformablePSROIPoolingProp : public OperatorProperty {
     return NULL;
   }
 
-  Operator* CreateOperatorEx(Context ctx, mxnet::ShapeVector *in_shape,
-    std::vector<int> *in_type) const override;
+  Operator* CreateOperatorEx(Context ctx,
+                             mxnet::ShapeVector *in_shape,
+                             std::vector<int> *in_type) const override;
 
 
  private:
diff --git a/src/operator/contrib/deformable_psroi_pooling.cc b/src/operator/contrib/deformable_psroi_pooling.cc
index d9d4cf8..697376d 100644
--- a/src/operator/contrib/deformable_psroi_pooling.cc
+++ b/src/operator/contrib/deformable_psroi_pooling.cc
@@ -35,43 +35,309 @@ using std::max;
 using std::min;
 using std::floor;
 using std::ceil;
+using std::round;
 
 namespace mshadow {
+
+  template <typename DType>  // Bilinear sample of a plane with row stride `width` at (x, y).
+  inline DType bilinear_interp_cpu(const DType* data,
+                                   const DType x, const DType y,
+                                   const index_t width, const index_t height) {
+    const index_t x_lo = floor(x);  // the forward kernel clamps x into [0, width - 1]
+    const index_t x_hi = ceil(x);
+    const index_t y_lo = floor(y);  // ...and y into [0, height - 1]; `height` is unused
+    const index_t y_hi = ceil(y);
+    const DType wx = static_cast<DType>(x - x_lo);  // weight of column x_hi
+    const DType wy = static_cast<DType>(y - y_lo);  // weight of row y_hi
+    const DType* row_lo = data + y_lo * width;
+    const DType* row_hi = data + y_hi * width;
+    // Term order matches the CUDA bilinear_interp.
+    const DType sample =
+        (1 - wx) * (1 - wy) * row_lo[x_lo] + (1 - wx) * wy * row_hi[x_lo] +
+        wx * (1 - wy) * row_lo[x_hi] + wx * wy * row_hi[x_hi];
+    return sample;
+  }
+
+  template <typename DType>  // Forward pass; fills top_data and per-element sample counts.
+  inline void DeformablePSROIPoolForwardCPU(const index_t count, const DType* bottom_data,
+                                            const DType spatial_scale, const index_t channels,
+                                            const index_t height, const index_t width,
+                                            const index_t pooled_height, const index_t pooled_width,
+                                            const DType* bottom_rois, const DType* bottom_trans,
+                                            const bool no_trans, const DType trans_std,
+                                            const index_t sample_per_part, const index_t output_dim,
+                                            const index_t group_size, const index_t part_size,
+                                            const index_t num_classes,
+                                            const index_t channels_each_class,
+                                            DType* top_data, DType* top_count) {
+    const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+#pragma omp parallel for num_threads(omp_threads)
+    for (index_t index = 0; index < count; index++) {
+      // Element `index` is laid out as (n, ctop, ph, pw) and writes only its own output slot.
+      index_t pw = index % pooled_width;
+      index_t ph = (index / pooled_width) % pooled_height;
+      index_t ctop = (index / pooled_width / pooled_height) % output_dim;
+      index_t n = index / pooled_width / pooled_height / output_dim;
+
+      // ROI row n is [batch_index, x1, y1, x2, y2]; map it into feature-map coordinates.
+      const DType* offset_bottom_rois = bottom_rois + n * 5;
+      index_t roi_batch_ind = offset_bottom_rois[0];
+      DType roi_start_w = static_cast<DType>(round(offset_bottom_rois[1])) * spatial_scale - 0.5;
+      DType roi_start_h = static_cast<DType>(round(offset_bottom_rois[2])) * spatial_scale - 0.5;
+      DType roi_end_w = static_cast<DType>(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;
+      DType roi_end_h = static_cast<DType>(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5;
+
+      // Force too small ROIs to be 1x1
+      DType roi_width = max(roi_end_w - roi_start_w, static_cast<DType>(0.1));  // avoid 0
+      DType roi_height = max(roi_end_h - roi_start_h, static_cast<DType>(0.1));
+
+      // Compute w and h at bottom
+      DType bin_size_h = roi_height / static_cast<DType>(pooled_height);
+      DType bin_size_w = roi_width / static_cast<DType>(pooled_width);
+
+      DType sub_bin_size_h = bin_size_h / static_cast<DType>(sample_per_part);
+      DType sub_bin_size_w = bin_size_w / static_cast<DType>(sample_per_part);
+      // Offset for this (class, part) cell, scaled by trans_std; zero when no_trans.
+      index_t part_h = floor(static_cast<DType>(ph) / pooled_height * part_size);
+      index_t part_w = floor(static_cast<DType>(pw) / pooled_width * part_size);
+      index_t class_id = ctop / channels_each_class;
+      DType trans_x = no_trans ? static_cast<DType>(0) :
+        bottom_trans[(((n * num_classes + class_id) * 2)
+                        * part_size + part_h)
+                        * part_size + part_w] * trans_std;
+      DType trans_y = no_trans ? static_cast<DType>(0) :
+        bottom_trans[(((n * num_classes + class_id) * 2 + 1)
+                        * part_size + part_h)
+                        * part_size + part_w] * trans_std;
+
+      DType wstart = static_cast<DType>(pw) * bin_size_w + roi_start_w;
+      wstart += trans_x * roi_width;
+      DType hstart = static_cast<DType>(ph) * bin_size_h + roi_start_h;
+      hstart += trans_y * roi_height;
+
+      DType sum = 0;
+      index_t num_samples = 0;  // not `count`: that name is the parameter bounding the loop
+      index_t gw = floor(static_cast<DType>(pw) * group_size / pooled_width);
+      index_t gh = floor(static_cast<DType>(ph) * group_size / pooled_height);
+      gw = min(max(gw, static_cast<index_t>(0)), group_size - 1);
+      gh = min(max(gh, static_cast<index_t>(0)), group_size - 1);
+      const index_t c = (ctop * group_size + gh) * group_size + gw;  // loop-invariant, hoisted
+      const DType* plane = bottom_data + (roi_batch_ind * channels + c) * height * width;
+      // Average over the sample_per_part x sample_per_part grid inside this bin.
+      for (index_t ih = 0; ih < sample_per_part; ih++) {
+        for (index_t iw = 0; iw < sample_per_part; iw++) {
+          DType w = wstart + iw * sub_bin_size_w;
+          DType h = hstart + ih * sub_bin_size_h;
+          // Skip samples more than half a pixel outside the map; clamp the rest onto it.
+          if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5) {
+            continue;
+          }
+          w = min(max(w, static_cast<DType>(0)), static_cast<DType>(width - 1));
+          h = min(max(h, static_cast<DType>(0)), static_cast<DType>(height - 1));
+          // Bilinear interpolation within the selected channel plane.
+          DType val = bilinear_interp_cpu(plane, w, h, width, height);
+          sum += val;
+          num_samples++;
+        }
+      }
+      top_data[index] = num_samples == 0 ? static_cast<DType>(0) : sum / num_samples;
+      top_count[index] = static_cast<DType>(num_samples);  // read back by the backward pass
+    }
+  }
+
  template<typename DType>
   inline void DeformablePSROIPoolForward(const Tensor<cpu, 4, DType> &out,
-    const Tensor<cpu, 4, DType> &data,
-    const Tensor<cpu, 2, DType> &bbox,
-    const Tensor<cpu, 4, DType> &trans,
-    const Tensor<cpu, 4, DType> &top_count,
-    const bool no_trans,
-    const float spatial_scale,
-    const int output_dim,
-    const int group_size,
-    const int pooled_size,
-    const int part_size,
-    const int sample_per_part,
-    const float trans_std) {
-    // NOT_IMPLEMENTED;
+                                         const Tensor<cpu, 4, DType> &data,
+                                         const Tensor<cpu, 2, DType> &bbox,
+                                         const Tensor<cpu, 4, DType> &trans,
+                                         const Tensor<cpu, 4, DType> &top_count,
+                                         const bool no_trans, const float spatial_scale,
+                                         const index_t output_dim, const index_t group_size,
+                                         const index_t pooled_size, const index_t part_size,
+                                         const index_t sample_per_part, const float trans_std) {
+    const DType *bottom_data = data.dptr_;  // unpack tensors for DeformablePSROIPoolForwardCPU
+    const DType *bottom_rois = bbox.dptr_;
+    const DType *bottom_trans = no_trans ? nullptr : trans.dptr_;  // not dereferenced if no_trans
+    DType *top_data = out.dptr_;
+    DType *top_count_data = top_count.dptr_;
+    const index_t count = out.shape_.Size();  // one kernel iteration per output element
+    const index_t channels = data.size(1);
+    const index_t height = data.size(2);
+    const index_t width = data.size(3);
+    const index_t pooled_height = pooled_size;  // pooled output is pooled_size x pooled_size
+    const index_t pooled_width = pooled_size;
+    const index_t num_classes = no_trans ? 1 : trans.size(1) / 2;  // (x, y) maps per class
+    const index_t channels_each_class = no_trans ? output_dim : output_dim / num_classes;
+    DeformablePSROIPoolForwardCPU<DType>(count, bottom_data, spatial_scale,
+                                         channels, height, width,
+                                         pooled_height, pooled_width,
+                                         bottom_rois, bottom_trans,
+                                         no_trans, trans_std, sample_per_part,
+                                         output_dim, group_size, part_size, num_classes,
+                                         channels_each_class, top_data, top_count_data);
+
     return;
   }
 
+  template <typename DType>  // Backward pass; accumulates (+=) into the gradient buffers.
+  inline void DeformablePSROIPoolBackwardAccCPU(const index_t count, const DType* top_diff,
+                                                const DType* top_count, const index_t num_rois,
+                                                const DType spatial_scale, const index_t channels,
+                                                const index_t height, const index_t width,
+                                                const index_t pooled_height,
+                                                const index_t pooled_width,
+                                                const index_t output_dim,
+                                                DType* bottom_data_diff,
+                                                DType* bottom_trans_diff,
+                                                const DType* bottom_data,
+                                                const DType* bottom_rois,
+                                                const DType* bottom_trans,
+                                                const bool no_trans,
+                                                const DType trans_std,
+                                                const index_t sample_per_part,
+                                                const index_t group_size,
+                                                const index_t part_size,
+                                                const index_t num_classes,
+                                                const index_t channels_each_class) {
+    // Kept serial: different outputs can scatter-add into the same gradient entries.
+    for (index_t index = 0; index < count; index++) {
+      // Outputs that took no forward samples get no gradient; skip them before any ROI math.
+      if (top_count[index] <= 0) continue;
+      // The output is in order (n, ctop, ph, pw)
+      index_t pw = index % pooled_width;
+      index_t ph = (index / pooled_width) % pooled_height;
+      index_t ctop = (index / pooled_width / pooled_height) % output_dim;
+      index_t n = index / pooled_width / pooled_height / output_dim;
+
+      // [start, end) interval for spatial sampling
+      const DType* offset_bottom_rois = bottom_rois + n * 5;
+      index_t roi_batch_ind = offset_bottom_rois[0];
+      DType roi_start_w = static_cast<DType>(round(offset_bottom_rois[1])) * spatial_scale - 0.5;
+      DType roi_start_h = static_cast<DType>(round(offset_bottom_rois[2])) * spatial_scale - 0.5;
+      DType roi_end_w = static_cast<DType>(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;
+      DType roi_end_h = static_cast<DType>(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5;
+
+      // Force too small ROIs to be 1x1
+      DType roi_width = max(roi_end_w - roi_start_w, static_cast<DType>(0.1));  // avoid 0
+      DType roi_height = max(roi_end_h - roi_start_h, static_cast<DType>(0.1));
+
+      // Compute w and h at bottom
+      DType bin_size_h = roi_height / static_cast<DType>(pooled_height);
+      DType bin_size_w = roi_width / static_cast<DType>(pooled_width);
+
+      DType sub_bin_size_h = bin_size_h / static_cast<DType>(sample_per_part);
+      DType sub_bin_size_w = bin_size_w / static_cast<DType>(sample_per_part);
+
+      index_t part_h = floor(static_cast<DType>(ph) / pooled_height * part_size);
+      index_t part_w = floor(static_cast<DType>(pw) / pooled_width * part_size);
+      index_t class_id = ctop / channels_each_class;
+      DType trans_x = no_trans ? static_cast<DType>(0) :
+        bottom_trans[(((n * num_classes + class_id) * 2)
+                        * part_size + part_h)
+                        * part_size + part_w] * trans_std;
+      DType trans_y = no_trans ? static_cast<DType>(0) :
+        bottom_trans[(((n * num_classes + class_id) * 2 + 1)
+                        * part_size + part_h)
+                        * part_size + part_w] * trans_std;
+
+      DType wstart = static_cast<DType>(pw) * bin_size_w + roi_start_w;
+      wstart += trans_x * roi_width;
+      DType hstart = static_cast<DType>(ph) * bin_size_h + roi_start_h;
+      hstart += trans_y * roi_height;
+
+      DType diff_val = top_diff[index] / top_count[index];
+      index_t gw = floor(static_cast<DType>(pw) * group_size / pooled_width);
+      index_t gh = floor(static_cast<DType>(ph) * group_size / pooled_height);
+      gw = min(max(gw, static_cast<index_t>(0)), group_size - 1);
+      gh = min(max(gh, static_cast<index_t>(0)), group_size - 1);
+      // Channel plane and offset cell are fixed per output, so compute them once here.
+      const index_t c = (ctop * group_size + gh) * group_size + gw;
+      const index_t plane_offset = (roi_batch_ind * channels + c) * height * width;
+      const DType* data_plane = bottom_data + plane_offset;
+      DType* diff_plane = bottom_data_diff + plane_offset;
+      const index_t offset_trans_diff = (((n * num_classes + class_id) * 2)
+        * part_size + part_h) * part_size + part_w;
+
+      for (index_t ih = 0; ih < sample_per_part; ih++) {
+        for (index_t iw = 0; iw < sample_per_part; iw++) {
+          DType w = wstart + iw * sub_bin_size_w;
+          DType h = hstart + ih * sub_bin_size_h;
+          // Same sample rejection and clamping as the forward pass.
+          if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5) {
+            continue;
+          }
+          w = min(max(w, static_cast<DType>(0)), static_cast<DType>(width - 1));
+          h = min(max(h, static_cast<DType>(0)), static_cast<DType>(height - 1));
+          // backward on feature: spread diff_val over the 4 corners by bilinear weight
+          index_t x0 = floor(w);
+          index_t x1 = ceil(w);
+          index_t y0 = floor(h);
+          index_t y1 = ceil(h);
+          DType dist_x = w - x0, dist_y = h - y0;
+          DType q00 = (1 - dist_x) * (1 - dist_y);
+          DType q01 = (1 - dist_x) * dist_y;
+          DType q10 = dist_x * (1 - dist_y);
+          DType q11 = dist_x * dist_y;
+          diff_plane[y0 * width + x0] += q00 * diff_val;
+          diff_plane[y1 * width + x0] += q01 * diff_val;
+          diff_plane[y0 * width + x1] += q10 * diff_val;
+          diff_plane[y1 * width + x1] += q11 * diff_val;
+          if (no_trans) {
+            continue;
+          }
+          // backward on offsets: d(sample)/d(x, y), scaled by trans_std * diff_val * ROI size
+          DType U00 = data_plane[y0 * width + x0];
+          DType U01 = data_plane[y1 * width + x0];
+          DType U10 = data_plane[y0 * width + x1];
+          DType U11 = data_plane[y1 * width + x1];
+          DType diff_x = U11 * dist_y + U10 * (1 - dist_y) - U01 * dist_y - U00 * (1 - dist_y);
+          diff_x *= trans_std * diff_val * roi_width;
+          DType diff_y = U11 * dist_x + U01 * (1 - dist_x) - U10 * dist_x - U00 * (1 - dist_x);
+          diff_y *= trans_std * diff_val * roi_height;
+          bottom_trans_diff[offset_trans_diff] += diff_x;
+          bottom_trans_diff[offset_trans_diff + part_size * part_size] += diff_y;
+        }
+      }
+    }
+  }
+
  template<typename DType>
   inline void DeformablePSROIPoolBackwardAcc(const Tensor<cpu, 4, DType> &in_grad,
-    const Tensor<cpu, 4, DType> &trans_grad,
-    const Tensor<cpu, 4, DType> &out_grad,
-    const Tensor<cpu, 4, DType> &data,
-    const Tensor<cpu, 2, DType> &bbox,
-    const Tensor<cpu, 4, DType> &trans,
-    const Tensor<cpu, 4, DType> &top_count,
-    const bool no_trans,
-    const float spatial_scale,
-    const int output_dim,
-    const int group_size,
-    const int pooled_size,
-    const int part_size,
-    const int sample_per_part,
-    const float trans_std) {
-    // NOT_IMPLEMENTED;
+                                             const Tensor<cpu, 4, DType> &trans_grad,
+                                             const Tensor<cpu, 4, DType> &out_grad,
+                                             const Tensor<cpu, 4, DType> &data,
+                                             const Tensor<cpu, 2, DType> &bbox,
+                                             const Tensor<cpu, 4, DType> &trans,
+                                             const Tensor<cpu, 4, DType> &top_count,
+                                             const bool no_trans, const float spatial_scale,
+                                             const index_t output_dim, const index_t group_size,
+                                             const index_t pooled_size, const index_t part_size,
+                                             const index_t sample_per_part, const float trans_std) {
+    const DType *top_diff = out_grad.dptr_;  // unpack for DeformablePSROIPoolBackwardAccCPU
+    const DType *bottom_data = data.dptr_;
+    const DType *bottom_rois = bbox.dptr_;
+    const DType *bottom_trans = no_trans ? nullptr : trans.dptr_;
+    DType *bottom_data_diff = in_grad.dptr_;
+    DType *bottom_trans_diff = no_trans ? nullptr : trans_grad.dptr_;  // unused when no_trans
+    const DType *top_count_data = top_count.dptr_;
+    const index_t count = out_grad.shape_.Size();
+    const index_t num_rois = bbox.size(0);  // NOTE(review): not read by the CPU kernel
+    const index_t channels = in_grad.size(1);
+    const index_t height = in_grad.size(2);
+    const index_t width = in_grad.size(3);
+    const index_t pooled_height = pooled_size;
+    const index_t pooled_width = pooled_size;
+    const index_t num_classes = no_trans ? 1 : trans_grad.size(1) / 2;
+    const index_t channels_each_class = no_trans ? output_dim : output_dim / num_classes;
+    DeformablePSROIPoolBackwardAccCPU<DType>(count, top_diff, top_count_data, num_rois,
+                                             spatial_scale, channels, height, width,
+                                             pooled_height, pooled_width, output_dim,
+                                             bottom_data_diff, bottom_trans_diff,
+                                             bottom_data, bottom_rois, bottom_trans,
+                                             no_trans, trans_std, sample_per_part,
+                                             group_size, part_size, num_classes,
+                                             channels_each_class);
+    // The kernel accumulates (+=) into in_grad/trans_grad; it does not zero them first.
     return;
   }
 }  // namespace mshadow
@@ -88,9 +354,9 @@ namespace op {
     return op;
   }
 
-  Operator *DeformablePSROIPoolingProp::CreateOperatorEx(
-    Context ctx, mxnet::ShapeVector *in_shape,
-    std::vector<int> *in_type) const {
+  Operator *DeformablePSROIPoolingProp::CreateOperatorEx(Context ctx,
+                                                         mxnet::ShapeVector *in_shape,
+                                                         std::vector<int> *in_type) const {
     mxnet::ShapeVector out_shape, aux_shape;
     std::vector<int> out_type, aux_type;
     CHECK(InferType(in_type, &out_type, &aux_type));
diff --git a/src/operator/contrib/deformable_psroi_pooling.cu b/src/operator/contrib/deformable_psroi_pooling.cu
index bf7d1c0..6c89746 100644
--- a/src/operator/contrib/deformable_psroi_pooling.cu
+++ b/src/operator/contrib/deformable_psroi_pooling.cu
@@ -46,56 +46,52 @@ for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
 namespace mshadow {
 namespace cuda {
   template <typename DType>
-  __device__ DType bilinear_interp(
-    const DType* data,
-    const DType x,
-    const DType y,
-    const int width,
-    const int height) {
-    int x1 = floor(x);
-    int x2 = ceil(x);
-    int y1 = floor(y);
-    int y2 = ceil(y);
+  __device__ DType bilinear_interp(const DType* data,  // row-major plane, row stride `width`
+                                   const DType x, const DType y,
+                                   const index_t width, const index_t height) {  // height unused
+    index_t x1 = floor(x);  // the kernel clamps x to [0, width - 1] and y to [0, height - 1]
+    index_t x2 = ceil(x);
+    index_t y1 = floor(y);
+    index_t y2 = ceil(y);
     DType dist_x = static_cast<DType>(x - x1);
     DType dist_y = static_cast<DType>(y - y1);
-    DType value11 = data[y1*width + x1];
-    DType value12 = data[y2*width + x1];
-    DType value21 = data[y1*width + x2];
-    DType value22 = data[y2*width + x2];
-    DType value = (1 - dist_x)*(1 - dist_y)*value11 + (1 - dist_x)*dist_y*value12
-      + dist_x*(1 - dist_y)*value21 + dist_x*dist_y*value22;
+    DType value11 = data[y1 * width + x1];
+    DType value12 = data[y2 * width + x1];
+    DType value21 = data[y1 * width + x2];
+    DType value22 = data[y2 * width + x2];
+    DType value = (1 - dist_x) * (1 - dist_y) * value11 + (1 - dist_x) * dist_y * value12 +
+      dist_x * (1 - dist_y) * value21 + dist_x * dist_y * value22;
     return value;
   }
 
   template <typename DType>
-  __global__ void DeformablePSROIPoolForwardKernel(
-    const int count,
-    const DType* bottom_data,
-    const DType spatial_scale,
-    const int channels,
-    const int height, const int width,
-    const int pooled_height, const int pooled_width,
-    const DType* bottom_rois, const DType* bottom_trans,
-    const bool no_trans,
-    const DType trans_std,
-    const int sample_per_part,
-    const int output_dim,
-    const int group_size,
-    const int part_size,
-    const int num_classes,
-    const int channels_each_class,
-    DType* top_data,
-    DType* top_count) {
+  __global__ void DeformablePSROIPoolForwardKernel(const index_t count,
+                                                   const DType* bottom_data,
+                                                   const DType spatial_scale,
+                                                   const index_t channels,
+                                                   const index_t height, const index_t width,
+                                                   const index_t pooled_height,
+                                                   const index_t pooled_width,
+                                                   const DType* bottom_rois,
+                                                   const DType* bottom_trans,
+                                                   const bool no_trans, const DType trans_std,
+                                                   const index_t sample_per_part,
+                                                   const index_t output_dim,
+                                                   const index_t group_size,
+                                                   const index_t part_size,
+                                                   const index_t num_classes,
+                                                   const index_t channels_each_class,
+                                                   DType* top_data, DType* top_count) {
     CUDA_KERNEL_LOOP(index, count) {
       // The output is in order (n, ctop, ph, pw)
-      int pw = index % pooled_width;
-      int ph = (index / pooled_width) % pooled_height;
-      int ctop = (index / pooled_width / pooled_height) % output_dim;
-      int n = index / pooled_width / pooled_height / output_dim;
+      index_t pw = index % pooled_width;
+      index_t ph = (index / pooled_width) % pooled_height;
+      index_t ctop = (index / pooled_width / pooled_height) % output_dim;
+      index_t n = index / pooled_width / pooled_height / output_dim;
 
       // [start, end) interval for spatial sampling
       const DType* offset_bottom_rois = bottom_rois + n * 5;
-      int roi_batch_ind = offset_bottom_rois[0];
+      index_t roi_batch_ind = offset_bottom_rois[0];
       DType roi_start_w = static_cast<DType>(round(offset_bottom_rois[1])) * spatial_scale - 0.5;
       DType roi_start_h = static_cast<DType>(round(offset_bottom_rois[2])) * spatial_scale - 0.5;
       DType roi_end_w = static_cast<DType>(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;
@@ -112,9 +108,9 @@ namespace cuda {
       DType sub_bin_size_h = bin_size_h / static_cast<DType>(sample_per_part);
       DType sub_bin_size_w = bin_size_w / static_cast<DType>(sample_per_part);
 
-      int part_h = floor(static_cast<DType>(ph) / pooled_height*part_size);
-      int part_w = floor(static_cast<DType>(pw) / pooled_width*part_size);
-      int class_id = ctop / channels_each_class;
+      index_t part_h = floor(static_cast<DType>(ph) / pooled_height * part_size);
+      index_t part_w = floor(static_cast<DType>(pw) / pooled_width * part_size);
+      index_t class_id = ctop / channels_each_class;
       DType trans_x = no_trans ? static_cast<DType>(0) :
         bottom_trans[(((n * num_classes + class_id) * 2)
                         * part_size + part_h)
@@ -124,33 +120,32 @@ namespace cuda {
                         * part_size + part_h)
                         * part_size + part_w] * trans_std;
 
-      DType wstart = static_cast<DType>(pw)* bin_size_w
-        + roi_start_w;
+      DType wstart = static_cast<DType>(pw) * bin_size_w + roi_start_w;
       wstart += trans_x * roi_width;
-      DType hstart = static_cast<DType>(ph) * bin_size_h
-        + roi_start_h;
+      DType hstart = static_cast<DType>(ph) * bin_size_h + roi_start_h;
       hstart += trans_y * roi_height;
 
       DType sum = 0;
-      int count = 0;
-      int gw = floor(static_cast<DType>(pw) * group_size / pooled_width);
-      int gh = floor(static_cast<DType>(ph)* group_size / pooled_height);
-      gw = min(max(gw, 0), group_size - 1);
-      gh = min(max(gh, 0), group_size - 1);
+      index_t count = 0;
+      index_t gw = floor(static_cast<DType>(pw) * group_size / pooled_width);
+      index_t gh = floor(static_cast<DType>(ph) * group_size / pooled_height);
+      gw = min(max(gw, static_cast<index_t>(0)), group_size - 1);
+      gh = min(max(gh, static_cast<index_t>(0)), group_size - 1);
 
       const DType* offset_bottom_data = bottom_data + (roi_batch_ind * channels) * height * width;
-      for (int ih = 0; ih < sample_per_part; ih++) {
-        for (int iw = 0; iw < sample_per_part; iw++) {
-          DType w = wstart + iw*sub_bin_size_w;
-          DType h = hstart + ih*sub_bin_size_h;
+      for (index_t ih = 0; ih < sample_per_part; ih++) {
+        for (index_t iw = 0; iw < sample_per_part; iw++) {
+          DType w = wstart + iw * sub_bin_size_w;
+          DType h = hstart + ih * sub_bin_size_h;
           // bilinear interpolation
-          if (w<-0.5 || w>width - 0.5 || h<-0.5 || h>height - 0.5) {
+          if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5) {
             continue;
           }
           w = min(max(w, 0.), width - 1.);
           h = min(max(h, 0.), height - 1.);
-          int c = (ctop*group_size + gh)*group_size + gw;
-          DType val = bilinear_interp(offset_bottom_data + c*height*width, w, h, width, height);
+          index_t c = (ctop * group_size + gh) * group_size + gw;
+          DType val = bilinear_interp(offset_bottom_data + c * height * width,
+                                      w, h, width, height);
           sum += val;
           count++;
         }
@@ -162,75 +157,74 @@ namespace cuda {
 
   template<typename DType>
   inline void DeformablePSROIPoolForward(const Tensor<gpu, 4, DType> &out,
-    const Tensor<gpu, 4, DType> &data,
-    const Tensor<gpu, 2, DType> &bbox,
-    const Tensor<gpu, 4, DType> &trans,
-    const Tensor<gpu, 4, DType> &top_count,
-    const bool no_trans,
-    const float spatial_scale,
-    const int output_dim,
-    const int group_size,
-    const int pooled_size,
-    const int part_size,
-    const int sample_per_part,
-    const float trans_std) {
-    // LOG(INFO) << "DeformablePSROIPoolForward";
+                                         const Tensor<gpu, 4, DType> &data,
+                                         const Tensor<gpu, 2, DType> &bbox,
+                                         const Tensor<gpu, 4, DType> &trans,
+                                         const Tensor<gpu, 4, DType> &top_count,
+                                         const bool no_trans, const float spatial_scale,
+                                         const index_t output_dim, const index_t group_size,
+                                         const index_t pooled_size, const index_t part_size,
+                                         const index_t sample_per_part, const float trans_std) {
     const DType *bottom_data = data.dptr_;
     const DType *bottom_rois = bbox.dptr_;
     const DType *bottom_trans = no_trans ? NULL : trans.dptr_;
     DType *top_data = out.dptr_;
     DType *top_count_data = top_count.dptr_;
-    const int count = out.shape_.Size();
-    const int channels = data.size(1);
-    const int height = data.size(2);
-    const int width = data.size(3);
-    const int pooled_height = pooled_size;
-    const int pooled_width = pooled_size;
-    const int num_classes = no_trans ? 1 : trans.size(1) / 2;
-    const int channels_each_class = no_trans ? output_dim : output_dim / num_classes;
+    const index_t count = out.shape_.Size();  // total number of output elements
+    const index_t channels = data.size(1);
+    const index_t height = data.size(2);
+    const index_t width = data.size(3);
+    const index_t pooled_height = pooled_size;
+    const index_t pooled_width = pooled_size;
+    const index_t num_classes = no_trans ? 1 : trans.size(1) / 2;  // 2 offset channels per class
+    const index_t channels_each_class = no_trans ? output_dim : output_dim / num_classes;
 
     cudaStream_t stream = Stream<gpu>::GetStream(out.stream_);
-    DeformablePSROIPoolForwardKernel<DType> << <mxnet::op::mxnet_op::cuda_get_num_blocks(count),
-      kBaseThreadNum, 0, stream >> >(
-      count, bottom_data, spatial_scale, channels, height, width, pooled_height, pooled_width,
-      bottom_rois, bottom_trans, no_trans, trans_std, sample_per_part, output_dim,
-      group_size, part_size, num_classes, channels_each_class, top_data, top_count_data);
+    DeformablePSROIPoolForwardKernel<DType><<<  // grid sized from `count` output elements
+      mxnet::op::mxnet_op::cuda_get_num_blocks(count), kBaseThreadNum,
+      0, stream>>>(count, bottom_data, spatial_scale, channels, height, width,
+                   pooled_height, pooled_width, bottom_rois, bottom_trans,
+                   no_trans, trans_std, sample_per_part, output_dim,
+                   group_size, part_size, num_classes,
+                   channels_each_class, top_data, top_count_data);
     DeformablePSROIPOOLING_CUDA_CHECK(cudaPeekAtLastError());
   }
 
 
   template <typename DType>
-  __global__ void DeformablePSROIPoolBackwardAccKernel(
-    const int count,
-    const DType* top_diff,
-    const DType* top_count,
-    const int num_rois,
-    const DType spatial_scale,
-    const int channels,
-    const int height, const int width,
-    const int pooled_height, const int pooled_width,
-    const int output_dim,
-    DType* bottom_data_diff, DType* bottom_trans_diff,
-    const DType* bottom_data,
-    const DType* bottom_rois,
-    const DType* bottom_trans,
-    const bool no_trans,
-    const DType trans_std,
-    const int sample_per_part,
-    const int group_size,
-    const int part_size,
-    const int num_classes,
-    const int channels_each_class) {
+  __global__ void DeformablePSROIPoolBackwardAccKernel(const index_t count,
+                                                       const DType* top_diff,
+                                                       const DType* top_count,
+                                                       const index_t num_rois,
+                                                       const DType spatial_scale,
+                                                       const index_t channels,
+                                                       const index_t height,
+                                                       const index_t width,
+                                                       const index_t pooled_height,
+                                                       const index_t pooled_width,
+                                                       const index_t output_dim,
+                                                       DType* bottom_data_diff,
+                                                       DType* bottom_trans_diff,
+                                                       const DType* bottom_data,
+                                                       const DType* bottom_rois,
+                                                       const DType* bottom_trans,
+                                                       const bool no_trans,
+                                                       const DType trans_std,
+                                                       const index_t sample_per_part,
+                                                       const index_t group_size,
+                                                       const index_t part_size,
+                                                       const index_t num_classes,
+                                                       const index_t channels_each_class) {
     CUDA_KERNEL_LOOP(index, count) {
       // The output is in order (n, ctop, ph, pw)
-      int pw = index % pooled_width;
-      int ph = (index / pooled_width) % pooled_height;
-      int ctop = (index / pooled_width / pooled_height) % output_dim;
-      int n = index / pooled_width / pooled_height / output_dim;
+      index_t pw = index % pooled_width;
+      index_t ph = (index / pooled_width) % pooled_height;
+      index_t ctop = (index / pooled_width / pooled_height) % output_dim;
+      index_t n = index / pooled_width / pooled_height / output_dim;
 
       // [start, end) interval for spatial sampling
       const DType* offset_bottom_rois = bottom_rois + n * 5;
-      int roi_batch_ind = offset_bottom_rois[0];
+      index_t roi_batch_ind = offset_bottom_rois[0];
       DType roi_start_w = static_cast<DType>(round(offset_bottom_rois[1])) * spatial_scale - 0.5;
       DType roi_start_h = static_cast<DType>(round(offset_bottom_rois[2])) * spatial_scale - 0.5;
       DType roi_end_w = static_cast<DType>(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;
@@ -247,9 +241,9 @@ namespace cuda {
       DType sub_bin_size_h = bin_size_h / static_cast<DType>(sample_per_part);
       DType sub_bin_size_w = bin_size_w / static_cast<DType>(sample_per_part);
 
-      int part_h = floor(static_cast<DType>(ph) / pooled_height*part_size);
-      int part_w = floor(static_cast<DType>(pw) / pooled_width*part_size);
-      int class_id = ctop / channels_each_class;
+      index_t part_h = floor(static_cast<DType>(ph) / pooled_height * part_size);
+      index_t part_w = floor(static_cast<DType>(pw) / pooled_width * part_size);
+      index_t class_id = ctop / channels_each_class;
       DType trans_x = no_trans ? static_cast<DType>(0) :
         bottom_trans[(((n * num_classes + class_id) * 2)
                         * part_size + part_h)
@@ -259,11 +253,9 @@ namespace cuda {
                         * part_size + part_h)
                         * part_size + part_w] * trans_std;
 
-      DType wstart = static_cast<DType>(pw)* bin_size_w
-        + roi_start_w;
+      DType wstart = static_cast<DType>(pw) * bin_size_w + roi_start_w;
       wstart += trans_x * roi_width;
-      DType hstart = static_cast<DType>(ph) * bin_size_h
-        + roi_start_h;
+      DType hstart = static_cast<DType>(ph) * bin_size_h + roi_start_h;
       hstart += trans_y * roi_height;
 
       if (top_count[index] <= 0) {
@@ -272,51 +264,49 @@ namespace cuda {
       DType diff_val = top_diff[index] / top_count[index];
       const DType* offset_bottom_data = bottom_data + roi_batch_ind * channels * height * width;
       DType* offset_bottom_data_diff = bottom_data_diff + roi_batch_ind * channels * height * width;
-      int gw = floor(static_cast<DType>(pw)* group_size / pooled_width);
-      int gh = floor(static_cast<DType>(ph)* group_size / pooled_height);
-      gw = min(max(gw, 0), group_size - 1);
-      gh = min(max(gh, 0), group_size - 1);
-
-      for (int ih = 0; ih < sample_per_part; ih++) {
-        for (int iw = 0; iw < sample_per_part; iw++) {
-          DType w = wstart + iw*sub_bin_size_w;
-          DType h = hstart + ih*sub_bin_size_h;
+      index_t gw = floor(static_cast<DType>(pw) * group_size / pooled_width);
+      index_t gh = floor(static_cast<DType>(ph) * group_size / pooled_height);
+      gw = min(max(gw, static_cast<index_t>(0)), group_size - 1);
+      gh = min(max(gh, static_cast<index_t>(0)), group_size - 1);
+
+      for (index_t ih = 0; ih < sample_per_part; ih++) {
+        for (index_t iw = 0; iw < sample_per_part; iw++) {
+          DType w = wstart + iw * sub_bin_size_w;
+          DType h = hstart + ih * sub_bin_size_h;
           // bilinear interpolation
-          if (w<-0.5 || w>width - 0.5 || h<-0.5 || h>height - 0.5) {
+          if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5) {
             continue;
           }
           w = min(max(w, 0.), width - 1.);
           h = min(max(h, 0.), height - 1.);
-          int c = (ctop*group_size + gh)*group_size + gw;
+          index_t c = (ctop * group_size + gh) * group_size + gw;
           // backward on feature
-          int x0 = floor(w);
-          int x1 = ceil(w);
-          int y0 = floor(h);
-          int y1 = ceil(h);
+          index_t x0 = floor(w);
+          index_t x1 = ceil(w);
+          index_t y0 = floor(h);
+          index_t y1 = ceil(h);
           DType dist_x = w - x0, dist_y = h - y0;
-          DType q00 = (1 - dist_x)*(1 - dist_y);
-          DType q01 = (1 - dist_x)*dist_y;
-          DType q10 = dist_x*(1 - dist_y);
-          DType q11 = dist_x*dist_y;
-          int bottom_index_base = c * height *width;
-          atomicAdd(offset_bottom_data_diff + bottom_index_base + y0*width + x0, q00*diff_val);
-          atomicAdd(offset_bottom_data_diff + bottom_index_base + y1*width + x0, q01*diff_val);
-          atomicAdd(offset_bottom_data_diff + bottom_index_base + y0*width + x1, q10*diff_val);
-          atomicAdd(offset_bottom_data_diff + bottom_index_base + y1*width + x1, q11*diff_val);
+          DType q00 = (1 - dist_x) * (1 - dist_y);
+          DType q01 = (1 - dist_x) * dist_y;
+          DType q10 = dist_x * (1 - dist_y);
+          DType q11 = dist_x * dist_y;
+          index_t bottom_index_base = c * height * width;
+          atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x0, q00 * diff_val);
+          atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x0, q01 * diff_val);
+          atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x1, q10 * diff_val);
+          atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x1, q11 * diff_val);
 
           if (no_trans) {
             continue;
           }
-          DType U00 = offset_bottom_data[bottom_index_base + y0*width + x0];
-          DType U01 = offset_bottom_data[bottom_index_base + y1*width + x0];
-          DType U10 = offset_bottom_data[bottom_index_base + y0*width + x1];
-          DType U11 = offset_bottom_data[bottom_index_base + y1*width + x1];
-          DType diff_x = (U11*dist_y + U10*(1 - dist_y) - U01*dist_y - U00*(1 - dist_y))
-            *trans_std*diff_val;
-          diff_x *= roi_width;
-          DType diff_y = (U11*dist_x + U01*(1 - dist_x) - U10*dist_x - U00*(1 - dist_x))
-            *trans_std*diff_val;
-          diff_y *= roi_height;
+          DType U00 = offset_bottom_data[bottom_index_base + y0 * width + x0];
+          DType U01 = offset_bottom_data[bottom_index_base + y1 * width + x0];
+          DType U10 = offset_bottom_data[bottom_index_base + y0 * width + x1];
+          DType U11 = offset_bottom_data[bottom_index_base + y1 * width + x1];
+          DType diff_x = U11 * dist_y + U10 * (1 - dist_y) - U01 * dist_y - U00 * (1 - dist_y);
+          diff_x *= trans_std * diff_val * roi_width;
+          DType diff_y = U11 * dist_x + U01 * (1 - dist_x) - U10 * dist_x - U00 * (1 - dist_x);
+          diff_y *= trans_std * diff_val * roi_height;
 
           atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2)
                                            * part_size + part_h)
@@ -332,21 +322,16 @@ namespace cuda {
 
   template<typename DType>
   inline void DeformablePSROIPoolBackwardAcc(const Tensor<gpu, 4, DType> &in_grad,
-    const Tensor<gpu, 4, DType> &trans_grad,
-    const Tensor<gpu, 4, DType> &out_grad,
-    const Tensor<gpu, 4, DType> &data,
-    const Tensor<gpu, 2, DType> &bbox,
-    const Tensor<gpu, 4, DType> &trans,
-    const Tensor<gpu, 4, DType> &top_count,
-    const bool no_trans,
-    const float spatial_scale,
-    const int output_dim,
-    const int group_size,
-    const int pooled_size,
-    const int part_size,
-    const int sample_per_part,
-    const float trans_std) {
-    // LOG(INFO) << "DeformablePSROIPoolBackward";
+                                             const Tensor<gpu, 4, DType> &trans_grad,
+                                             const Tensor<gpu, 4, DType> &out_grad,
+                                             const Tensor<gpu, 4, DType> &data,
+                                             const Tensor<gpu, 2, DType> &bbox,
+                                             const Tensor<gpu, 4, DType> &trans,
+                                             const Tensor<gpu, 4, DType> &top_count,
+                                             const bool no_trans, const float spatial_scale,
+                                             const index_t output_dim, const index_t group_size,
+                                             const index_t pooled_size, const index_t part_size,
+                                             const index_t sample_per_part, const float trans_std) {
     const DType *top_diff = out_grad.dptr_;
     const DType *bottom_data = data.dptr_;
     const DType *bottom_rois = bbox.dptr_;
@@ -354,23 +339,25 @@ namespace cuda {
     DType *bottom_data_diff = in_grad.dptr_;
     DType *bottom_trans_diff = no_trans ? NULL : trans_grad.dptr_;
     const DType *top_count_data = top_count.dptr_;
-    const int count = out_grad.shape_.Size();
-    const int num_rois = bbox.size(0);
-    const int channels = in_grad.size(1);
-    const int height = in_grad.size(2);
-    const int width = in_grad.size(3);
-    const int pooled_height = pooled_size;
-    const int pooled_width = pooled_size;
-    const int num_classes = no_trans ? 1 : trans_grad.size(1) / 2;
-    const int channels_each_class = no_trans ? output_dim : output_dim / num_classes;
+    const index_t count = out_grad.shape_.Size();  // number of output-gradient elements
+    const index_t num_rois = bbox.size(0);
+    const index_t channels = in_grad.size(1);
+    const index_t height = in_grad.size(2);
+    const index_t width = in_grad.size(3);
+    const index_t pooled_height = pooled_size;
+    const index_t pooled_width = pooled_size;
+    const index_t num_classes = no_trans ? 1 : trans_grad.size(1) / 2;  // x/y pair per class
+    const index_t channels_each_class = no_trans ? output_dim : output_dim / num_classes;
 
     cudaStream_t stream = Stream<gpu>::GetStream(in_grad.stream_);
-    DeformablePSROIPoolBackwardAccKernel<DType> << <mxnet::op::mxnet_op::cuda_get_num_blocks(count),
-      kBaseThreadNum, 0, stream >> >(
-      count, top_diff, top_count_data, num_rois, spatial_scale, channels, height, width,
-      pooled_height, pooled_width, output_dim, bottom_data_diff, bottom_trans_diff,
-      bottom_data, bottom_rois, bottom_trans, no_trans, trans_std, sample_per_part,
-      group_size, part_size, num_classes, channels_each_class);
+    DeformablePSROIPoolBackwardAccKernel<DType><<<  // one loop index per out_grad element
+      mxnet::op::mxnet_op::cuda_get_num_blocks(count), kBaseThreadNum,
+      0, stream >>>(count, top_diff, top_count_data, num_rois, spatial_scale,
+                    channels, height, width, pooled_height, pooled_width,
+                    output_dim, bottom_data_diff, bottom_trans_diff,
+                    bottom_data, bottom_rois, bottom_trans,
+                    no_trans, trans_std, sample_per_part, group_size,
+                    part_size, num_classes, channels_each_class);
     DeformablePSROIPOOLING_CUDA_CHECK(cudaPeekAtLastError());
   }
 
@@ -378,41 +365,36 @@ namespace cuda {
 
   template<typename DType>
   inline void DeformablePSROIPoolForward(const Tensor<gpu, 4, DType> &out,
-    const Tensor<gpu, 4, DType> &data,
-    const Tensor<gpu, 2, DType> &bbox,
-    const Tensor<gpu, 4, DType> &trans,
-    const Tensor<gpu, 4, DType> &top_count,
-    const bool no_trans,
-    const float spatial_scale,
-    const int output_dim,
-    const int group_size,
-    const int pooled_size,
-    const int part_size,
-    const int sample_per_part,
-    const float trans_std) {
-    cuda::DeformablePSROIPoolForward(out, data, bbox, trans, top_count, no_trans, spatial_scale,
-      output_dim, group_size, pooled_size, part_size, sample_per_part, trans_std);
+                                         const Tensor<gpu, 4, DType> &data,
+                                         const Tensor<gpu, 2, DType> &bbox,
+                                         const Tensor<gpu, 4, DType> &trans,
+                                         const Tensor<gpu, 4, DType> &top_count,
+                                         const bool no_trans, const float spatial_scale,
+                                         const index_t output_dim, const index_t group_size,
+                                         const index_t pooled_size, const index_t part_size,
+                                         const index_t sample_per_part, const float trans_std) {
+    cuda::DeformablePSROIPoolForward(out, data, bbox, trans, top_count,  // delegate to cuda::
+                                     no_trans, spatial_scale, output_dim,
+                                     group_size, pooled_size, part_size,
+                                     sample_per_part, trans_std);  // forwarded unchanged
   }
 
   template<typename DType>
   inline void DeformablePSROIPoolBackwardAcc(const Tensor<gpu, 4, DType> &in_grad,
-    const Tensor<gpu, 4, DType> &trans_grad,
-    const Tensor<gpu, 4, DType> &out_grad,
-    const Tensor<gpu, 4, DType> &data,
-    const Tensor<gpu, 2, DType> &bbox,
-    const Tensor<gpu, 4, DType> &trans,
-    const Tensor<gpu, 4, DType> &top_count,
-    const bool no_trans,
-    const float spatial_scale,
-    const int output_dim,
-    const int group_size,
-    const int pooled_size,
-    const int part_size,
-    const int sample_per_part,
-    const float trans_std) {
-    cuda::DeformablePSROIPoolBackwardAcc(in_grad, trans_grad, out_grad, data, bbox, trans,
-      top_count, no_trans, spatial_scale, output_dim, group_size, pooled_size, part_size,
-      sample_per_part, trans_std);
+                                             const Tensor<gpu, 4, DType> &trans_grad,
+                                             const Tensor<gpu, 4, DType> &out_grad,
+                                             const Tensor<gpu, 4, DType> &data,
+                                             const Tensor<gpu, 2, DType> &bbox,
+                                             const Tensor<gpu, 4, DType> &trans,
+                                             const Tensor<gpu, 4, DType> &top_count,
+                                             const bool no_trans, const float spatial_scale,
+                                             const index_t output_dim, const index_t group_size,
+                                             const index_t pooled_size, const index_t part_size,
+                                             const index_t sample_per_part, const float trans_std) {
+    cuda::DeformablePSROIPoolBackwardAcc(in_grad, trans_grad, out_grad, data, bbox,
+                                         trans, top_count, no_trans, spatial_scale,
+                                         output_dim, group_size, pooled_size,
+                                         part_size, sample_per_part, trans_std);  // pass-through
   }
 
 }  // namespace mshadow
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index 9c88dc1..9c004cd 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -1633,6 +1633,24 @@ def test_deformable_psroipooling_with_type():
                  'deformable_psroipool_trans': (2, 4, 3, 3),
                  'type_dict': {'deformable_psroipool_data': np.float16, 'deformable_psroipool_rois': np.float16,
                                'deformable_psroipool_trans': np.float16}},
+                {'ctx': mx.cpu(0),
+                 'deformable_psroipool_data': (1, 18, 14, 14),
+                 'deformable_psroipool_rois': (2, 5),
+                 'deformable_psroipool_trans': (2, 4, 3, 3),
+                 'type_dict': {'deformable_psroipool_data': np.float64, 'deformable_psroipool_rois': np.float64,
+                               'deformable_psroipool_trans': np.float64}},
+                {'ctx': mx.cpu(0),
+                 'deformable_psroipool_data': (1, 18, 14, 14),
+                 'deformable_psroipool_rois': (2, 5),
+                 'deformable_psroipool_trans': (2, 4, 3, 3),
+                 'type_dict': {'deformable_psroipool_data': np.float32, 'deformable_psroipool_rois': np.float32,
+                               'deformable_psroipool_trans': np.float32}},
+                {'ctx': mx.cpu(0),
+                 'deformable_psroipool_data': (1, 18, 14, 14),
+                 'deformable_psroipool_rois': (2, 5),
+                 'deformable_psroipool_trans': (2, 4, 3, 3),
+                 'type_dict': {'deformable_psroipool_data': np.float16, 'deformable_psroipool_rois': np.float16,
+                               'deformable_psroipool_trans': np.float16}},
                 ]
 
     check_consistency(sym, ctx_list, scale=0.1, tol=tol,