Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/10/05 20:12:31 UTC

[GitHub] anirudh2290 closed pull request #11742: [MXNET-623] Fixing an integer overflow bug in large NDArray

URL: https://github.com/apache/incubator-mxnet/pull/11742

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is supplied below for
the sake of provenance:

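The mshadow submodule bump at the top of this diff widens mshadow's index_t, and the rest of the changes adjust loop counters and casts so the code stays correct and warning-free with the wider type. As a minimal sketch of the overflow being fixed (assuming the old index_t was a 32-bit unsigned type, which the 1U-literal fixes below suggest):

    #include <cstdint>
    #include <cstdio>

    typedef uint32_t index_t;  // assumption: the old 32-bit mshadow index type

    int main() {
      // A modestly large NDArray already has more elements than a 32-bit
      // index can hold; the element count silently wraps modulo 2^32:
      uint64_t true_size = 700000ULL * 128 * 128;  // ~11.5e9 elements
      index_t wrapped = static_cast<index_t>(true_size);
      std::printf("true size: %llu, wrapped index_t: %u\n",
                  static_cast<unsigned long long>(true_size), wrapped);
      return 0;
    }
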
diff --git a/3rdparty/mshadow b/3rdparty/mshadow
index 8a9e337f3c4..696803bd772 160000
--- a/3rdparty/mshadow
+++ b/3rdparty/mshadow
@@ -1 +1 @@
-Subproject commit 8a9e337f3c4794876bd04d5351d967333bcabee3
+Subproject commit 696803bd7723ade8230af878460d96c68a550fbc
diff --git a/src/c_api/c_api_function.cc b/src/c_api/c_api_function.cc
index cea8c9553cc..7091be2e72c 100644
--- a/src/c_api/c_api_function.cc
+++ b/src/c_api/c_api_function.cc
@@ -55,7 +55,7 @@ std::vector<nnvm::NodeEntry> Gradient(
   g->inputs = out_grads;
 
   std::vector<nnvm::NodeEntry> ret;
-  for (index_t i = 0; i < g->num_outputs(); ++i) {
+  for (uint32_t i = 0; i < g->num_outputs(); ++i) {
     ret.emplace_back(nnvm::NodeEntry{g, i, 0});
   }
 
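A note on why the counter above becomes uint32_t rather than the widened index_t: nnvm builds NodeEntry with braced initialization, which forbids narrowing conversions, so a 64-bit counter would stop compiling against a 32-bit index field (the static_cast<uint32_t> in custom.cc further below points the same way). A standalone sketch of the rule, with a hypothetical stand-in struct:

    #include <cstdint>

    struct Entry {      // stand-in for nnvm::NodeEntry's index/version fields
      uint32_t index;
      uint32_t version;
    };

    Entry make(int64_t i) {
      // return Entry{i, 0};  // error: narrowing conversion int64_t -> uint32_t
      return Entry{static_cast<uint32_t>(i), 0};  // explicit cast compiles
    }
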
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index 922917f7947..aa45df16922 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -1308,7 +1308,7 @@ void GraphExecutor::ExecuteMonCallback(size_t nid) {
     }
   }
   CHECK_EQ(opnode.exec->out_array.size(), output_names.size());
-  for (index_t i = 0; i < opnode.exec->out_array.size(); ++i) {
+  for (size_t i = 0; i < opnode.exec->out_array.size(); ++i) {
     NDArray *cpy = new NDArray(opnode.exec->out_array[i]);
     std::string name = inode.source->attrs.name + "_" + output_names[i];
     this->monitor_callback_(name.c_str(), reinterpret_cast<void*>(cpy));
diff --git a/src/io/image_iter_common.h b/src/io/image_iter_common.h
index 8580ff8f9f9..a2324a4b5c5 100644
--- a/src/io/image_iter_common.h
+++ b/src/io/image_iter_common.h
@@ -42,7 +42,7 @@ class ImageLabelMap {
    * \param label_width predefined label_width
    */
   explicit ImageLabelMap(const char *path_imglist,
-                         mshadow::index_t label_width,
+                         index_t label_width,
                          bool silent) {
     this->label_width = label_width;
     image_index_.clear();
@@ -58,7 +58,7 @@ class ImageLabelMap {
       // skip space
       while (isspace(*p) && p != end) ++p;
       image_index_.push_back(static_cast<size_t>(atol(p)));
-      for (size_t i = 0; i < label_width; ++i) {
+      for (index_t i = 0; i < label_width; ++i) {
         // skip till space
         while (!isspace(*p) && p != end) ++p;
         // skip space
@@ -171,7 +171,7 @@ struct ImageRecParserParam : public dmlc::Parameter<ImageRecParserParam> {
 // Batch parameters
 struct BatchParam : public dmlc::Parameter<BatchParam> {
   /*! \brief label width */
-  index_t batch_size;
+  uint32_t batch_size;
   /*! \brief use round roubin to handle overflow batch */
   bool round_batch;
   // declare parameters
diff --git a/src/io/iter_image_recordio_2.cc b/src/io/iter_image_recordio_2.cc
index 668eb0059f8..f4148395549 100644
--- a/src/io/iter_image_recordio_2.cc
+++ b/src/io/iter_image_recordio_2.cc
@@ -75,7 +75,7 @@ class ImageRecordIOParser2 {
   cv::Mat TJimdecode(cv::Mat buf, int color);
 #endif
 #endif
-  inline unsigned ParseChunk(DType* data_dptr, real_t* label_dptr, const unsigned current_size,
+  inline size_t ParseChunk(DType* data_dptr, real_t* label_dptr, const size_t current_size,
     dmlc::InputSplit::Blob * chunk);
   inline void CreateMeanImg(void);
 
@@ -104,10 +104,10 @@ class ImageRecordIOParser2 {
   /*! \brief temp space */
   mshadow::TensorContainer<cpu, 3> img_;
   /*! \brief internal instance order */
-  std::vector<std::pair<unsigned, unsigned> > inst_order_;
-  unsigned inst_index_;
+  std::vector<std::pair<size_t, size_t> > inst_order_;
+  size_t inst_index_;
   /*! \brief internal counter tracking number of already parsed entries */
-  unsigned n_parsed_;
+  size_t n_parsed_;
   /*! \brief overflow marker */
   bool overflow;
   /*! \brief unit size */
@@ -200,7 +200,7 @@ inline void ImageRecordIOParser2<DType>::Init(
                       "larger chunk size";
       }
       // 1.1 ratio is for a bit more shuffle parts to avoid boundary issue
-      unsigned num_shuffle_parts =
+      size_t num_shuffle_parts =
           std::ceil(source_->GetTotalSize() * 1.1 /
                     (param_.num_parts * (param_.shuffle_chunk_size << 20UL)));
 
@@ -262,7 +262,7 @@ inline bool ImageRecordIOParser2<DType>::ParseNext(DataBatch *out) {
   }
   CHECK(source_ != nullptr);
   dmlc::InputSplit::Blob chunk;
-  unsigned current_size = 0;
+  size_t current_size = 0;
   out->index.resize(batch_param_.batch_size);
 
   // InitBatch
@@ -295,7 +295,7 @@ inline bool ImageRecordIOParser2<DType>::ParseNext(DataBatch *out) {
 
   while (current_size < batch_param_.batch_size) {
     // int n_to_copy;
-    unsigned n_to_out = 0;
+    size_t n_to_out = 0;
     if (n_parsed_ == 0) {
       if (source_->NextBatch(&chunk, batch_param_.batch_size)) {
         inst_order_.clear();
@@ -328,15 +328,16 @@ inline bool ImageRecordIOParser2<DType>::ParseNext(DataBatch *out) {
         n_to_out = 0;
       }
     } else {
-      int n_to_copy = std::min(n_parsed_, batch_param_.batch_size - current_size);
+      size_t n_to_copy = std::min(n_parsed_,
+                                  static_cast<size_t>(batch_param_.batch_size) - current_size);
       n_parsed_ -= n_to_copy;
       // Copy
       #pragma omp parallel for num_threads(param_.preprocess_threads)
-      for (int i = 0; i < n_to_copy; ++i) {
+      for (int i = 0; i < static_cast<int>(n_to_copy); ++i) {
         omp_exc_.Run([&] {
-        std::pair<unsigned, unsigned> place = inst_order_[inst_index_ + i];
+        std::pair<size_t, size_t> place = inst_order_[inst_index_ + i];
         const DataInst& batch = temp_[place.first][place.second];
-        for (unsigned j = 0; j < batch.data.size(); ++j) {
+        for (size_t j = 0; j < batch.data.size(); ++j) {
           CHECK_EQ(unit_size_[j], batch.data[j].Size());
           MSHADOW_TYPE_SWITCH(out->data[j].data().type_flag_, dtype, {
           mshadow::Copy(
@@ -482,18 +483,18 @@ cv::Mat ImageRecordIOParser2<DType>::TJimdecode(cv::Mat image, int color) {
 
 // Returns the number of images that are put into output
 template<typename DType>
-inline unsigned ImageRecordIOParser2<DType>::ParseChunk(DType* data_dptr, real_t* label_dptr,
-  const unsigned current_size, dmlc::InputSplit::Blob * chunk) {
+inline size_t ImageRecordIOParser2<DType>::ParseChunk(DType* data_dptr, real_t* label_dptr,
+  const size_t current_size, dmlc::InputSplit::Blob * chunk) {
   temp_.resize(param_.preprocess_threads);
 #if MXNET_USE_OPENCV
   // save opencv out
   dmlc::RecordIOChunkReader reader(*chunk, 0, 1);
-  unsigned gl_idx = current_size;
+  size_t gl_idx = current_size;
   #pragma omp parallel num_threads(param_.preprocess_threads)
   {
     omp_exc_.Run([&] {
     CHECK(omp_get_num_threads() == param_.preprocess_threads);
-    unsigned int tid = omp_get_thread_num();
+    int tid = omp_get_thread_num();
     // dmlc::RecordIOChunkReader reader(*chunk, tid, param_.preprocess_threads);
     ImageRecordIO rec;
     dmlc::InputSplit::Blob blob;
@@ -502,7 +503,7 @@ inline unsigned ImageRecordIOParser2<DType>::ParseChunk(DType* data_dptr, real_t
     out_tmp.Clear();
     while (true) {
       bool reader_has_data;
-      unsigned idx;
+      size_t idx;
       #pragma omp critical
       {
         reader_has_data = reader.NextRecord(&blob);
@@ -567,7 +568,7 @@ inline unsigned ImageRecordIOParser2<DType>::ParseChunk(DType* data_dptr, real_t
         data = mshadow::Tensor<cpu, 3, DType>(data_dptr + idx*unit_size_[0],
           mshadow::Shape3(n_channels, res.rows, res.cols));
       } else {
-        out_tmp.Push(static_cast<unsigned>(rec.image_index()),
+        out_tmp.Push(static_cast<size_t>(rec.image_index()),
                  mshadow::Shape3(n_channels, res.rows, res.cols),
                  mshadow::Shape1(param_.label_width));
         data = out_tmp.data().Back();
@@ -612,7 +613,7 @@ inline unsigned ImageRecordIOParser2<DType>::ParseChunk(DType* data_dptr, real_t
   });
   }
   omp_exc_.Rethrow();
-  return (std::min(batch_param_.batch_size, gl_idx) - current_size);
+  return (std::min(static_cast<size_t>(batch_param_.batch_size), gl_idx) - current_size);
 #else
   LOG(FATAL) << "Opencv is needed for image decoding and augmenting.";
   return 0;
@@ -633,8 +634,8 @@ inline void ImageRecordIOParser2<DType>::CreateMeanImg(void) {
       inst_order_.clear();
       // Parse chunk w/o putting anything in out
       ParseChunk(nullptr, nullptr, batch_param_.batch_size, &chunk);
-      for (unsigned i = 0; i < inst_order_.size(); ++i) {
-        std::pair<unsigned, unsigned> place = inst_order_[i];
+      for (size_t i = 0; i < inst_order_.size(); ++i) {
+        std::pair<size_t, size_t> place = inst_order_[i];
         mshadow::Tensor<cpu, 3> outimg =
           temp_[place.first][place.second].data[0].template get<cpu, 3, real_t>();
         if (imcnt == 0) {
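
A note on the std::min change in ParseNext above: std::min deduces a single template parameter, so its two arguments must already share one type, and mixing, say, a size_t with an int64_t expression fails to compile. Casting batch_size up front keeps the whole expression in size_t. A minimal sketch of the pattern:

    #include <algorithm>
    #include <cstddef>
    #include <cstdint>

    size_t n_to_copy(size_t n_parsed, uint32_t batch_size, size_t current_size) {
      // std::min(n_parsed, int64_t{batch_size} - int64_t{current_size}) would
      // fail to deduce (size_t vs int64_t); the cast pins the common type.
      return std::min(n_parsed, static_cast<size_t>(batch_size) - current_size);
    }
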
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index 47e0c5bbe75..04b6a0217e2 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -2105,10 +2105,10 @@ void Imdecode(NDArray *ret, NDArray mean, size_t index,
   if (mean.is_none()) {
     MSHADOW_TYPE_SWITCH(buff.dtype(), DType, {
       mshadow::Tensor<cpu, 4, DType> tensor = buff.data().get<cpu, 4, DType>();
-      for (index_t i = 0; i < y1-y0; i++) {
+      for (size_t i = 0; i < y1-y0; i++) {
         uchar* im_data = res.ptr<uchar>(y0+i) + res.channels()*x0;
-        for (index_t j = 0; j < x1-x0; j++) {
-          for (index_t k = 0; k < n_channels; k++) {
+        for (size_t j = 0; j < x1-x0; j++) {
+          for (size_t k = 0; k < n_channels; k++) {
             tensor[0][k][i][j] = DType(im_data[k]);  // NOLINT(*)
           }
           im_data += res.channels();
@@ -2125,10 +2125,10 @@ void Imdecode(NDArray *ret, NDArray mean, size_t index,
     MSHADOW_TYPE_SWITCH(buff.dtype(), DType, {
       mshadow::Tensor<cpu, 4, DType> tensor = buff.data().get<cpu, 4, DType>();
       mshadow::Tensor<cpu, 3, DType> tmean = mean.data().get<cpu, 3, DType>();
-      for (index_t i = 0; i < y1-y0; i++) {
+      for (size_t i = 0; i < y1-y0; i++) {
         uchar* im_data = res.ptr<uchar>(y0+i) + res.channels()*x0;
-        for (index_t j = 0; j < x1-x0; j++) {
-          for (index_t k = 0; k < n_channels; k++) {
+        for (size_t j = 0; j < x1-x0; j++) {
+          for (size_t k = 0; k < n_channels; k++) {
             tensor[0][k][i][j] = DType(im_data[k]) - tmean[k][i][j];  // NOLINT(*)
           }
           im_data += res.channels();
diff --git a/src/ndarray/ndarray_function.cc b/src/ndarray/ndarray_function.cc
index 022302aca40..43295d6e101 100644
--- a/src/ndarray/ndarray_function.cc
+++ b/src/ndarray/ndarray_function.cc
@@ -92,7 +92,7 @@ void ElementwiseSumRspImpl(mshadow::Stream<cpu>* s,
               auto out_value_cur_row = out_values[irow];
               const auto offset = row_idx_ptr - nd_indices_start;
               auto nd_value_cur_row = nd_values[offset];
-              for (size_t j = 0; j < nd_value_cur_row.shape_[0]; ++j) {
+              for (index_t j = 0; j < nd_value_cur_row.shape_[0]; ++j) {
                 out_value_cur_row[j] += nd_value_cur_row[j];
               }
               ++irow;
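
The loop above moves in the opposite direction, from size_t to index_t: mshadow shapes store their extents as index_t, so matching the counter to the extent's own type avoids a signed/unsigned mismatch once index_t is a 64-bit signed type. A minimal sketch under that assumption:

    #include <cstdint>

    typedef int64_t index_t;  // assumption: the widened mshadow index type

    // The counter shares the extent's type, so the loop condition compares
    // signed against signed.
    void add_row(float* out, const float* in, index_t len) {
      for (index_t j = 0; j < len; ++j) {
        out[j] += in[j];
      }
    }
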
diff --git a/src/operator/batch_norm_v1-inl.h b/src/operator/batch_norm_v1-inl.h
index 1e048452275..f4116e30186 100644
--- a/src/operator/batch_norm_v1-inl.h
+++ b/src/operator/batch_norm_v1-inl.h
@@ -286,14 +286,14 @@ class BatchNormV1Prop : public OperatorProperty {
     // For other input types, these parameters have the same type as input
     // NOTE: This requirement is from cuDNN (v. 4 and 5)
     int dtype_param = (dtype == kFloat16) ? kFloat32 : dtype;
-    for (index_t i = 1; i < in_type->size(); ++i) {
+    for (size_t i = 1; i < in_type->size(); ++i) {
       if ((*in_type)[i] == -1) {
         (*in_type)[i] = dtype_param;
       } else {
         UNIFORM_TYPE_CHECK((*in_type)[i], dtype_param, ListArguments()[i]);
       }
     }
-    for (index_t i = 0; i < aux_type->size(); ++i) {
+    for (size_t i = 0; i < aux_type->size(); ++i) {
       if ((*aux_type)[i] != -1) {
         UNIFORM_TYPE_CHECK((*aux_type)[i], dtype_param, ListArguments()[i]);
       }
diff --git a/src/operator/bilinear_sampler.cu b/src/operator/bilinear_sampler.cu
index 2e6be3e1ef3..03734a61316 100644
--- a/src/operator/bilinear_sampler.cu
+++ b/src/operator/bilinear_sampler.cu
@@ -51,8 +51,8 @@ __global__ void BilinearSamplerForwardKernel(const int i_c, const int i_h,
     int h = (index / o_w) % o_h;
     int c = (index / o_w / o_h) % o_c;
     int n = index / o_w / o_h / o_c;
-    index_t out_index = n * o_c * o_h * o_w + c * o_h * o_w + h * o_w + w;
-    index_t grid_index = n * o_h * o_w * 2 + h * o_w + w;
+    int out_index = n * o_c * o_h * o_w + c * o_h * o_w + h * o_w + w;
+    int grid_index = n * o_h * o_w * 2 + h * o_w + w;
     DType y_real = (*(grid + grid_index + o_h * o_w) + 1) * (i_h - 1) / 2;
     DType x_real = (*(grid + grid_index) + 1) * (i_w - 1) / 2;
     int top_left_y = static_cast<int>(floor(y_real));
@@ -96,7 +96,7 @@ __global__ void BilinearSamplerBackwardKernel(const int i_c, const int i_h,
     int n = index / o_w / o_h;
     DType top_left_y_gw = 0.0;
     DType top_left_x_gw = 0.0;
-    index_t grid_src_index = n * o_h * o_w * 2 + h * o_w + w;
+    int grid_src_index = n * o_h * o_w * 2 + h * o_w + w;
     DType y_real = (*(grid_src + grid_src_index + o_h * o_w) + 1) * (i_h - 1) / 2;
     DType x_real = (*(grid_src + grid_src_index) + 1) * (i_w - 1) / 2;
 
@@ -104,8 +104,8 @@ __global__ void BilinearSamplerBackwardKernel(const int i_c, const int i_h,
     int top_left_x = static_cast<int>(floor(x_real));
     DType top_left_y_w = 1.0 - (y_real - top_left_y);
     DType top_left_x_w = 1.0 - (x_real - top_left_x);
-    for (index_t c = 0; c < o_c; ++c) {
-      index_t grad_index = n * o_c * o_h * o_w + c * o_h * o_w + h * o_w + w;
+    for (int c = 0; c < o_c; ++c) {
+      int grad_index = n * o_c * o_h * o_w + c * o_h * o_w + h * o_w + w;
       int data_index = n * i_c * i_h * i_w + c * i_h * i_w + top_left_y * i_w + top_left_x;
       // calc 4 vertex value in input data
       DType top_left_v = 0;
diff --git a/src/operator/channel_op_common.h b/src/operator/channel_op_common.h
index 00cd8ae084b..1afc13ad259 100644
--- a/src/operator/channel_op_common.h
+++ b/src/operator/channel_op_common.h
@@ -44,7 +44,7 @@ inline void concatenate_helper(const std::vector<mshadow::Tensor<xpu, dim, DType
     mshadow::Tensor<xpu, dim, DType> out = *output;
     size_t size = input.size();
     index_t begin = 0;
-    for (index_t i = 0; i < size; ++i) {
+    for (size_t i = 0; i < size; ++i) {
       index_t end = begin + input[i].size(cdim);
       Assign(slice<cdim>(out, begin, end), req, input[i]);
       begin = end;
@@ -79,7 +79,7 @@ void split_helper(const mshadow::Tensor<xpu, dim, DType> &input,
     std::vector<mshadow::Tensor<xpu, dim, DType> > out = *output;
     size_t size = out.size();
     index_t begin = 0;
-    for (index_t i = 0; i < size; ++i) {
+    for (size_t i = 0; i < size; ++i) {
       index_t end = begin + out[i].size(cdim);
       Assign(out[i], req[i], slice<cdim>(input, begin, end));
       begin = end;
diff --git a/src/operator/contrib/count_sketch-inl.h b/src/operator/contrib/count_sketch-inl.h
index 76d1a7efb87..dd3bf54ab6a 100644
--- a/src/operator/contrib/count_sketch-inl.h
+++ b/src/operator/contrib/count_sketch-inl.h
@@ -185,7 +185,7 @@ class CountSketchProp : public OperatorProperty {
     CHECK_GE(in_type->size(), 1);
     int dtype = (*in_type)[0];
     CHECK_NE(dtype, -1) << "First input must have specified type";
-    for (index_t i = 0; i < in_type->size(); ++i) {
+    for (size_t i = 0; i < in_type->size(); ++i) {
       if ((*in_type)[i] == -1) {
         (*in_type)[i] = dtype;
       } else {
diff --git a/src/operator/contrib/deformable_convolution-inl.h b/src/operator/contrib/deformable_convolution-inl.h
index 480f675bdbf..7328eb38308 100644
--- a/src/operator/contrib/deformable_convolution-inl.h
+++ b/src/operator/contrib/deformable_convolution-inl.h
@@ -129,7 +129,7 @@ class DeformableConvolutionOp : public Operator {
     // calculate the shape of col_buffer
     TShape col_buffer_shape(num_spatial_axes_ + 1);
     col_buffer_shape[0] = conv_in_channels_ * param_.kernel.Size();
-    for (index_t i = 1; i < col_buffer_shape.ndim(); ++i) {
+    for (size_t i = 1; i < col_buffer_shape.ndim(); ++i) {
       col_buffer_shape[i] = out_data[0].shape_[i + 1];
     }
     // create a column buffer using workspace and col_buffer_shape
@@ -453,7 +453,7 @@ class DeformableConvolutionProp : public OperatorProperty {
     CHECK_GE(in_type->size(), 1U);
     int dtype = (*in_type)[0];
     CHECK_NE(dtype, -1) << "First input must have specified type";
-    for (index_t i = 0; i < in_type->size(); ++i) {
+    for (size_t i = 0; i < in_type->size(); ++i) {
       if ((*in_type)[i] == -1) {
         (*in_type)[i] = dtype;
       } else {
diff --git a/src/operator/contrib/fft-inl.h b/src/operator/contrib/fft-inl.h
index be7b64aeb0c..c5c8574f19e 100644
--- a/src/operator/contrib/fft-inl.h
+++ b/src/operator/contrib/fft-inl.h
@@ -258,7 +258,7 @@ class FFTProp : public OperatorProperty {
     CHECK_GE(in_type->size(), 1);
     int dtype = (*in_type)[0];
     CHECK_NE(dtype, -1) << "First input must have specified type";
-    for (index_t i = 0; i < in_type->size(); ++i) {
+    for (size_t i = 0; i < in_type->size(); ++i) {
       if ((*in_type)[i] == -1) {
         (*in_type)[i] = dtype;
       } else {
diff --git a/src/operator/contrib/ifft-inl.h b/src/operator/contrib/ifft-inl.h
index e48d653d927..da560c3c517 100644
--- a/src/operator/contrib/ifft-inl.h
+++ b/src/operator/contrib/ifft-inl.h
@@ -250,7 +250,7 @@ class IFFTProp : public OperatorProperty {
     CHECK_GE(in_type->size(), 1);
     int dtype = (*in_type)[0];
     CHECK_NE(dtype, -1) << "First input must have specified type";
-    for (index_t i=0; i < in_type->size(); ++i) {
+    for (size_t i=0; i < in_type->size(); ++i) {
       if ((*in_type)[i] == -1) {
         (*in_type)[i] = dtype;
       } else {
diff --git a/src/operator/contrib/sync_batch_norm-inl.h b/src/operator/contrib/sync_batch_norm-inl.h
index 1f548dbc7e5..78f1c09dfe0 100644
--- a/src/operator/contrib/sync_batch_norm-inl.h
+++ b/src/operator/contrib/sync_batch_norm-inl.h
@@ -500,14 +500,14 @@ class SyncBatchNormProp : public OperatorProperty {
     // For other input types, these parameters have the same type as input
     // NOTE: This requirement is from cuDNN (v. 4 and 5)
     int dtype_param = (dtype == kFloat16) ? kFloat32 : dtype;
-    for (index_t i = 1; i < in_type->size(); ++i) {
+    for (size_t i = 1; i < in_type->size(); ++i) {
       if ((*in_type)[i] == -1) {
         (*in_type)[i] = dtype_param;
       } else {
         UNIFORM_TYPE_CHECK((*in_type)[i], dtype_param, ListArguments()[i]);
       }
     }
-    for (index_t i = 0; i < aux_type->size(); ++i) {
+    for (size_t i = 0; i < aux_type->size(); ++i) {
       if ((*aux_type)[i] != -1) {
         UNIFORM_TYPE_CHECK((*aux_type)[i], dtype_param, ListArguments()[i]);
       }
diff --git a/src/operator/convolution_v1-inl.h b/src/operator/convolution_v1-inl.h
index d8310e6f1fc..758ce12d800 100644
--- a/src/operator/convolution_v1-inl.h
+++ b/src/operator/convolution_v1-inl.h
@@ -335,12 +335,10 @@ class ConvolutionV1Op : public Operator {
                                      oshape[2] * oshape[3]);
     // param_.workspace is in elements of sizeof(DType)
     // if param_.workspace is set to zero the nstep_ equals ishape[0] (batch)
-    nstep_ = std::max(
-        std::min(
-            static_cast<index_t>(
-                param_.workspace / (shape_colunit_.Size() + shape_dstunit_.Size())),
-            ishape[0]),
-        1U);
+    nstep_ = std::max<index_t>(
+        std::min(static_cast<index_t>(param_.workspace) /
+          (shape_colunit_.Size() + shape_dstunit_.Size()), ishape[0]),
+      1);
 
     mshadow::Shape<2> scol = mshadow::Shape2(shape_colunit_[0],
                                              shape_colunit_[1] * nstep_);
@@ -502,7 +500,7 @@ class ConvolutionV1Prop : public OperatorProperty {
     CHECK_GE(in_type->size(), 1);
     int dtype = (*in_type)[0];
     CHECK_NE(dtype, -1) << "First input must have specified type";
-    for (index_t i = 0; i < in_type->size(); ++i) {
+    for (size_t i = 0; i < in_type->size(); ++i) {
       if ((*in_type)[i] == -1) {
         (*in_type)[i] = dtype;
       } else {
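
The nstep_ rewrite above (mirrored in deconvolution-inl.h below) trades the 1U literal for an explicit template argument: std::max<index_t> converts both operands, so the call keeps compiling when index_t is no longer unsigned int. A minimal sketch, assuming the widened index type:

    #include <algorithm>
    #include <cstdint>

    typedef int64_t index_t;  // assumption: the widened mshadow index type

    index_t effective_batch(index_t workspace, index_t unit, index_t batch) {
      // std::max(std::min(...), 1U) would fail to deduce (index_t vs unsigned);
      // the explicit <index_t> lets the literal convert instead.
      return std::max<index_t>(std::min(workspace / unit, batch), 1);
    }
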
diff --git a/src/operator/custom/custom.cc b/src/operator/custom/custom.cc
index 2061324e9e2..1bcbbbcf3e8 100644
--- a/src/operator/custom/custom.cc
+++ b/src/operator/custom/custom.cc
@@ -238,14 +238,14 @@ std::vector<nnvm::NodeEntry> Gradient(
   }
 
   std::vector<nnvm::NodeEntry> ret;
-  for (index_t i = 0; i < params.num_args; ++i) {
-    ret.emplace_back(nnvm::NodeEntry{g, i, 0});
+  for (size_t i = 0; i < params.num_args; ++i) {
+    ret.emplace_back(nnvm::NodeEntry{g, static_cast<uint32_t>(i), 0});
   }
   if (params.num_auxs) {
     nnvm::NodePtr ng = nnvm::Node::Create();
     ng->attrs.op = nnvm::Op::Get("_NoGradient");
     ng->attrs.name = "NoGradient";
-    for (index_t i = 0; i < params.num_auxs; ++i) {
+    for (size_t i = 0; i < params.num_auxs; ++i) {
       ret.emplace_back(nnvm::NodeEntry{ng, 0, 0});
     }
   }
diff --git a/src/operator/custom/native_op-inl.h b/src/operator/custom/native_op-inl.h
index 3c222071a4d..05ad124ad34 100644
--- a/src/operator/custom/native_op-inl.h
+++ b/src/operator/custom/native_op-inl.h
@@ -77,7 +77,7 @@ class NativeOp : public Operator {
     s->Wait();
     param_.pinfo->forward(ptrs.size(), ptrs.data(), ndims.data(), shapes.data(),
         tags.data(), param_.pinfo->p_forward);
-    for (index_t i = 0; i < out_data.size(); ++i) {
+    for (size_t i = 0; i < out_data.size(); ++i) {
       CHECK_NE(req[i], kAddTo) << "NativeOp doesn't support AddTo for output";
       if (req[i] != kNullOp) {
         std::stringstream ss;
@@ -111,7 +111,7 @@ class NativeOp : public Operator {
     s->Wait();
     param_.pinfo->backward(ptrs.size(), ptrs.data(), ndims.data(), shapes.data(),
         tags.data(), param_.pinfo->p_backward);
-    for (index_t i = 0; i < in_grad.size(); ++i) {
+    for (size_t i = 0; i < in_grad.size(); ++i) {
       CHECK_NE(req[i], kAddTo) << "NativeOp doesn't support AddTo for output";
       if (req[i] != kNullOp) {
         std::stringstream ss;
diff --git a/src/operator/elemwise_op_common.h b/src/operator/elemwise_op_common.h
index 16aa0c388cd..cf44da69915 100644
--- a/src/operator/elemwise_op_common.h
+++ b/src/operator/elemwise_op_common.h
@@ -197,8 +197,8 @@ struct ElemwiseGradUseOut {
   std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr& n,
                                           const std::vector<nnvm::NodeEntry>& ograds) const {
     std::vector<nnvm::NodeEntry> heads;
-    index_t n_out = n->num_outputs();
-    for (index_t i = 0; i < n_out; ++i) {
+    uint32_t n_out = n->num_outputs();
+    for (uint32_t i = 0; i < n_out; ++i) {
       heads.emplace_back(nnvm::NodeEntry{n, i, 0});
     }
     return MakeNonlossGradNode(op_name, n, ograds, heads, n->attrs.dict);
@@ -214,8 +214,8 @@ struct ElemwiseGradUseInOut {
     for (auto& h : n->inputs) {
       heads.push_back(h);
     }
-    index_t n_out = n->num_outputs();
-    for (index_t i = 0; i < n_out; ++i) {
+    uint32_t n_out = n->num_outputs();
+    for (uint32_t i = 0; i < n_out; ++i) {
       heads.emplace_back(nnvm::NodeEntry{n, i, 0});
     }
     return MakeGradNode(op_name, n, heads, n->attrs.dict);
diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index f28f5d7a436..be542ba5b6b 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -363,7 +363,7 @@ static bool BatchNormType(const nnvm::NodeAttrs& attrs,
       dtype_param = mshadow::DataType<AccRealX>::kFlag; });
   std::vector<std::string> args{"data", "gamma", "beta", "mean", "var"};
   CHECK_LE(in_type->size(), args.size());
-  for (index_t i = 1; i < in_type->size(); ++i) {
+  for (size_t i = 1; i < in_type->size(); ++i) {
     if ((*in_type)[i] == -1) {
       (*in_type)[i] = dtype_param;
     } else {
diff --git a/src/operator/nn/convolution.cc b/src/operator/nn/convolution.cc
index d5abe629123..53b0c1380ed 100644
--- a/src/operator/nn/convolution.cc
+++ b/src/operator/nn/convolution.cc
@@ -279,7 +279,7 @@ static bool ConvolutionType(const nnvm::NodeAttrs& attrs,
   CHECK_GE(in_type->size(), 1U);
   int dtype = (*in_type)[0];
   CHECK_NE(dtype, -1) << "First input must have specified type";
-  for (index_t i = 0; i < in_type->size(); ++i) {
+  for (size_t i = 0; i < in_type->size(); ++i) {
     if ((*in_type)[i] == -1) {
       (*in_type)[i] = dtype;
     } else {
diff --git a/src/operator/nn/deconvolution-inl.h b/src/operator/nn/deconvolution-inl.h
index c7ccfb2fb4e..d89a489c018 100644
--- a/src/operator/nn/deconvolution-inl.h
+++ b/src/operator/nn/deconvolution-inl.h
@@ -131,7 +131,7 @@ struct DeconvolutionParam : public dmlc::Parameter<DeconvolutionParam> {
     if (bCal) {
       size_t input_ndim = input.ndim();
 
-      for (index_t i = 0; i < ndim; i++) {
+      for (size_t i = 0; i < ndim; i++) {
         // input.ndim() can be larger than ndim, in case that the complete input
         // shape was passed and not only the ndim last ones
         o_pad[i] = stride[i] * (input[(input_ndim - ndim) + i] - 1) + DilatedKernelSize(i);
@@ -141,7 +141,7 @@ struct DeconvolutionParam : public dmlc::Parameter<DeconvolutionParam> {
         o_pad[i] = (o_pad[i] + 1) / 2;
       }
     } else {
-      for (index_t i = 0; i < ndim; i++) {
+      for (size_t i = 0; i < ndim; i++) {
         o_pad[i] = pad[i];
         o_adj[i] = adj[i];
       }
@@ -459,12 +459,10 @@ class DeconvolutionOp {
                                      oshape[1] / param_.num_group,
                                      oshape[2] * oshape[3]);
     // See convolution for workspace calculations. nstep_ will be the effective batch size
-    nstep_ = std::max(
-        std::min(
-            static_cast<index_t>(
-                param_.workspace / (shape_colunit_.Size() + shape_dstunit_.Size())),
-            ishape[0]),
-        1U);
+    nstep_ = std::max<index_t>(
+        std::min(static_cast<index_t>(param_.workspace) /
+          (shape_colunit_.Size() + shape_dstunit_.Size()), ishape[0]),
+      1);
 
     mshadow::Shape<2> scol = mshadow::Shape2(shape_colunit_[0],
                                              shape_colunit_[1] * nstep_);
diff --git a/src/operator/nn/deconvolution.cc b/src/operator/nn/deconvolution.cc
index 1ab391d92b0..039c732c831 100644
--- a/src/operator/nn/deconvolution.cc
+++ b/src/operator/nn/deconvolution.cc
@@ -248,7 +248,7 @@ static bool DeconvolutionType(const nnvm::NodeAttrs& attrs,
   CHECK_GE(in_type->size(), 1U);
   int dtype = (*in_type)[0];
   CHECK_NE(dtype, -1) << "First input must have specified type";
-  for (index_t i = 0; i < in_type->size(); ++i) {
+  for (size_t i = 0; i < in_type->size(); ++i) {
     if ((*in_type)[i] == -1) {
       (*in_type)[i] = dtype;
     } else {
diff --git a/src/operator/nn/lrn.cc b/src/operator/nn/lrn.cc
index 99f0dc46f1b..020cb479acc 100644
--- a/src/operator/nn/lrn.cc
+++ b/src/operator/nn/lrn.cc
@@ -57,7 +57,7 @@ bool LRNType(const nnvm::NodeAttrs& attrs,
   CHECK_GE(in_type->size(), 1U);
   int dtype = (*in_type)[0];
   CHECK_NE(dtype, -1) << "First input must have specified type";
-  for (index_t i = 0; i < in_type->size(); ++i) {
+  for (size_t i = 0; i < in_type->size(); ++i) {
     if ((*in_type)[i] == -1) {
       (*in_type)[i] = dtype;
     } else {
diff --git a/src/operator/nn/upsampling-inl.h b/src/operator/nn/upsampling-inl.h
index 4b9159edd17..feb44c894a7 100644
--- a/src/operator/nn/upsampling-inl.h
+++ b/src/operator/nn/upsampling-inl.h
@@ -48,8 +48,8 @@ enum UpSamplingMultiInputMode {kConcat, kSum};
 }  // namespace up_enum
 
 struct UpSamplingParam : public dmlc::Parameter<UpSamplingParam> {
-  index_t scale;
-  index_t num_filter;
+  int scale;
+  int num_filter;
   int sample_type;
   int num_args;
   int multi_input_mode;
diff --git a/src/operator/nn/upsampling.cc b/src/operator/nn/upsampling.cc
index 5aa111e26f7..b6b3d873df7 100644
--- a/src/operator/nn/upsampling.cc
+++ b/src/operator/nn/upsampling.cc
@@ -92,7 +92,7 @@ static bool UpSamplingType(const nnvm::NodeAttrs& attrs,
   CHECK_GE(in_type->size(), 1U);
   int dtype = (*in_type)[0];
   CHECK_NE(dtype, -1) << "First input must have specified type";
-  for (index_t i = 0; i < in_type->size(); ++i) {
+  for (size_t i = 0; i < in_type->size(); ++i) {
     if ((*in_type)[i] == -1) {
       (*in_type)[i] = dtype;
     } else {
diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h
index 6a4c3d02707..d7c14172477 100644
--- a/src/operator/operator_common.h
+++ b/src/operator/operator_common.h
@@ -395,7 +395,7 @@ inline std::vector<nnvm::NodeEntry> MakeGradNode(
   auto p = MakeNode(op_name, n->attrs.name + "_backward",
                     &inputs, &dict, &n);
   std::vector<nnvm::NodeEntry> ret;
-  for (index_t i = 0; i < p->num_outputs(); ++i) {
+  for (uint32_t i = 0; i < p->num_outputs(); ++i) {
     ret.emplace_back(nnvm::NodeEntry{p, i, 0});
   }
   return ret;
@@ -406,7 +406,7 @@ inline std::vector<nnvm::NodeEntry> MakeZeroGradNodes(
     const nnvm::NodePtr& n,
     const std::vector<nnvm::NodeEntry>& ograds) {
   std::vector<nnvm::NodeEntry> ret;
-  for (index_t i = 0; i < n->num_inputs(); ++i) {
+  for (uint32_t i = 0; i < n->num_inputs(); ++i) {
     std::ostringstream os;
     if (1 == n->num_inputs()) {
       os << n->attrs.name << "_backward";
@@ -445,7 +445,7 @@ inline std::vector<nnvm::NodeEntry> MakeNonlossGradNode(
   p->inputs.insert(p->inputs.end(), ograds.begin(), ograds.end());
   p->inputs.insert(p->inputs.end(), inputs.begin(), inputs.end());
   std::vector<nnvm::NodeEntry> ret;
-  for (index_t i = 0; i < p->num_outputs(); ++i) {
+  for (uint32_t i = 0; i < p->num_outputs(); ++i) {
     ret.emplace_back(nnvm::NodeEntry{p, i, 0});
   }
   return ret;
diff --git a/src/operator/pad.cu b/src/operator/pad.cu
index 372683a2be8..1aab12a3a79 100644
--- a/src/operator/pad.cu
+++ b/src/operator/pad.cu
@@ -56,9 +56,9 @@ __global__ void image_2d_pad_edge_kernel(Tensor<gpu, 4, DType> dst,
   int oStartY = max(0, padT);
 
   int inputPointX =
-      min(max(padL, outputPointX), src.size(3) + padL - 1) - oStartX + iStartX;
+      min(max(padL, outputPointX), static_cast<int>(src.size(3)) + padL - 1) - oStartX + iStartX;
   int inputPointY =
-      min(max(padT, outputPointY), src.size(2) + padT - 1) - oStartY + iStartY;
+      min(max(padT, outputPointY), static_cast<int>(src.size(2)) + padT - 1) - oStartY + iStartY;
 
   DType valueToCopy = src[batch][plane][inputPointY][inputPointX];
   dst[batch][plane][outputPointY][outputPointX] = valueToCopy;
@@ -98,9 +98,9 @@ __global__ void image_2d_pad_edge_grad_kernel(
   int iStartY = max(0, -padT);
   int oStartX = max(0, padL);
   int oStartY = max(0, padT);
-  int inputPointX = min(max(padL, outputPointX), grad_in.size(3) + padL - 1) -
+  int inputPointX = min(max(padL, outputPointX), static_cast<int>(grad_in.size(3)) + padL - 1) -
                     oStartX + iStartX;
-  int inputPointY = min(max(padT, outputPointY), grad_in.size(2) + padT - 1) -
+  int inputPointY = min(max(padT, outputPointY), static_cast<int>(grad_in.size(2)) + padT - 1) -
                     oStartY + iStartY;
   DType valueToCopy = grad_out[batch][plane][outputPointY][outputPointX];
   atomicAdd(&grad_in[batch][plane][inputPointY][inputPointX], valueToCopy);
@@ -346,11 +346,11 @@ __global__ void image_3d_pad_edge_kernel(Tensor<gpu, 5, DType> dst,
   int oStartZ = max(0, padF);
 
   int inputPointX =
-      min(max(padL, outputPointX), src.size(4) + padL - 1) - oStartX + iStartX;
+      min(max(padL, outputPointX), static_cast<int>(src.size(4)) + padL - 1) - oStartX + iStartX;
   int inputPointY =
-      min(max(padT, outputPointY), src.size(3) + padT - 1) - oStartY + iStartY;
+      min(max(padT, outputPointY), static_cast<int>(src.size(3)) + padT - 1) - oStartY + iStartY;
   int inputPointZ =
-      min(max(padF, outputPointZ), src.size(2) + padF - 1) - oStartZ + iStartZ;
+      min(max(padF, outputPointZ), static_cast<int>(src.size(2)) + padF - 1) - oStartZ + iStartZ;
 
   DType valueToCopy = src[batch][plane][inputPointZ][inputPointY][inputPointX];
   dst[batch][plane][outputPointZ][outputPointY][outputPointX] = valueToCopy;
@@ -395,11 +395,11 @@ __global__ void image_3d_pad_edge_grad_kernel(
   int oStartY = max(0, padT);
   int oStartZ = max(0, padF);
 
-  int inputPointX = min(max(padL, outputPointX), grad_in.size(4) + padL - 1) -
+  int inputPointX = min(max(padL, outputPointX), static_cast<int>(grad_in.size(4)) + padL - 1) -
                     oStartX + iStartX;
-  int inputPointY = min(max(padT, outputPointY), grad_in.size(3) + padT - 1) -
+  int inputPointY = min(max(padT, outputPointY), static_cast<int>(grad_in.size(3)) + padT - 1) -
                     oStartY + iStartY;
-  int inputPointZ = min(max(padF, outputPointZ), grad_in.size(2) + padF - 1) -
+  int inputPointZ = min(max(padF, outputPointZ), static_cast<int>(grad_in.size(2)) + padF - 1) -
                     oStartZ + iStartZ;
   DType valueToCopy =
       grad_out[batch][plane][outputPointZ][outputPointY][outputPointX];
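
The static_cast<int> additions above keep both operands of min/max in the same type; one hazard they rule out is the classic signed/unsigned comparison trap, which matters here because pad offsets can be negative. A small self-contained illustration:

    #include <cstdint>
    #include <cstdio>

    int main() {
      int padL = -1;
      uint32_t extent = 4;  // stand-in for src.size(3) (assumption: unsigned)
      // In the comparison the int converts to unsigned: -1 becomes 4294967295.
      if (padL < extent) {
        std::printf("signed comparison\n");    // not reached
      } else {
        std::printf("unsigned wraparound\n");  // this prints
      }
      // Casting the extent keeps the arithmetic signed:
      if (padL < static_cast<int>(extent)) {
        std::printf("correct with cast\n");    // this prints
      }
      return 0;
    }
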
diff --git a/src/operator/random/shuffle_op.cu b/src/operator/random/shuffle_op.cu
index 5bf8320c078..51588494a63 100644
--- a/src/operator/random/shuffle_op.cu
+++ b/src/operator/random/shuffle_op.cu
@@ -60,7 +60,7 @@ void ShuffleForwardGPU(const nnvm::NodeAttrs& attrs,
   const index_t stride = size / first_axis_len;
   Stream<gpu> *s = ctx.get_stream<gpu>();
   MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
-    using KeyType = index_t;
+    using KeyType = uint32_t;
     Tensor<gpu, 1, DType> in = inputs[0].get_with_shape<gpu, 1, DType>(Shape1(size), s);
     Tensor<gpu, 1, DType> out = outputs[0].get_with_shape<gpu, 1, DType>(Shape1(size), s);
     Random<gpu, KeyType> *prnd = ctx.requested[0].get_random<gpu, KeyType>(s);
@@ -82,7 +82,8 @@ void ShuffleForwardGPU(const nnvm::NodeAttrs& attrs,
       Tensor<gpu, 1, index_t> indices(reinterpret_cast<index_t*>(tmp_space_ptr),
                                       Shape1(first_axis_len), s);
       tmp_space_ptr += sizeof(index_t) * first_axis_len;
-      Kernel<range_fwd, gpu>::Launch(s, first_axis_len, 1, 0U, 1U, kWriteTo, indices.dptr_);
+      Kernel<range_fwd, gpu>::Launch(s, static_cast<int>(first_axis_len),
+                                     1, index_t(0), index_t(1), kWriteTo, indices.dptr_);
       Tensor<gpu, 1, KeyType> keys(reinterpret_cast<KeyType*>(tmp_space_ptr),
                                    Shape1(first_axis_len), s);
       tmp_space_ptr += sizeof(KeyType) * first_axis_len;
diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h
index 3901a805b9d..11ca066c783 100644
--- a/src/operator/rnn-inl.h
+++ b/src/operator/rnn-inl.h
@@ -689,7 +689,7 @@ class RNNProp : public OperatorProperty {
     CHECK_GE(in_type->size(), 1U);
     int dtype = (*in_type)[0];
     CHECK_NE(dtype, -1) << "First input must have specified type";
-    for (index_t i = 0; i < in_type->size(); ++i) {
+    for (size_t i = 0; i < in_type->size(); ++i) {
       if ((*in_type)[i] == -1) {
         (*in_type)[i] = dtype;
       } else {
diff --git a/src/operator/sequence_last-inl.h b/src/operator/sequence_last-inl.h
index 58562862a4e..1a59473cfc3 100644
--- a/src/operator/sequence_last-inl.h
+++ b/src/operator/sequence_last-inl.h
@@ -278,7 +278,7 @@ class SequenceLastProp : public OperatorProperty {
     CHECK_GE(in_type->size(), param_.use_sequence_length ? 2U : 1U);
     int dtype = (*in_type)[0];
     CHECK_NE(dtype, -1) << "First input must have specified type";
-    for (index_t i = 0; i < in_type->size(); ++i) {
+    for (size_t i = 0; i < in_type->size(); ++i) {
       if ((*in_type)[i] == -1) {
         (*in_type)[i] = dtype;
       } else {
diff --git a/src/operator/sequence_mask-inl.h b/src/operator/sequence_mask-inl.h
index a34cea04965..c93ffb5f17b 100644
--- a/src/operator/sequence_mask-inl.h
+++ b/src/operator/sequence_mask-inl.h
@@ -267,7 +267,7 @@ class SequenceMaskProp : public OperatorProperty {
     CHECK_GE(in_type->size(), param_.use_sequence_length ? 2U : 1U);
     int dtype = (*in_type)[0];
     CHECK_NE(dtype, -1) << "First input must have specified type";
-    for (index_t i = 0; i < in_type->size(); ++i) {
+    for (size_t i = 0; i < in_type->size(); ++i) {
       if ((*in_type)[i] == -1) {
         (*in_type)[i] = dtype;
       } else {
diff --git a/src/operator/sequence_reverse-inl.h b/src/operator/sequence_reverse-inl.h
index 943ca6e933c..5c48729e18f 100644
--- a/src/operator/sequence_reverse-inl.h
+++ b/src/operator/sequence_reverse-inl.h
@@ -246,7 +246,7 @@ class SequenceReverseProp : public OperatorProperty {
     CHECK_GE(in_type->size(), param_.use_sequence_length ? 2U : 1U);
     int dtype = (*in_type)[0];
     CHECK_NE(dtype, -1) << "First input must have specified type";
-    for (index_t i = 0; i < in_type->size(); ++i) {
+    for (size_t i = 0; i < in_type->size(); ++i) {
       if ((*in_type)[i] == -1) {
         (*in_type)[i] = dtype;
       } else {
diff --git a/src/operator/softmax_output-inl.h b/src/operator/softmax_output-inl.h
index 9a4db2c9694..fec321b97e4 100644
--- a/src/operator/softmax_output-inl.h
+++ b/src/operator/softmax_output-inl.h
@@ -333,7 +333,7 @@ class SoftmaxOutputProp : public OperatorProperty {
     CHECK_GE(in_type->size(), 1U);
     int dtype = (*in_type)[0];
     CHECK_NE(dtype, -1) << "First input must have specified type";
-    for (index_t i = 0; i < in_type->size(); ++i) {
+    for (size_t i = 0; i < in_type->size(); ++i) {
       if ((*in_type)[i] == -1) {
         (*in_type)[i] = dtype;
       } else {
diff --git a/src/operator/svm_output-inl.h b/src/operator/svm_output-inl.h
index 9ae0ced7a74..011b9ad1028 100644
--- a/src/operator/svm_output-inl.h
+++ b/src/operator/svm_output-inl.h
@@ -159,7 +159,7 @@ class SVMOutputProp : public OperatorProperty {
     CHECK_GE(in_type->size(), 1U);
     int dtype = (*in_type)[0];
     CHECK_NE(dtype, -1) << "First input must have specified type";
-    for (index_t i = 0; i < in_type->size(); ++i) {
+    for (size_t i = 0; i < in_type->size(); ++i) {
       if ((*in_type)[i] == -1) {
         (*in_type)[i] = dtype;
       } else {
diff --git a/src/operator/tensor/indexing_op-inl.cuh b/src/operator/tensor/indexing_op-inl.cuh
index b2f514e20cd..5c5236363a5 100644
--- a/src/operator/tensor/indexing_op-inl.cuh
+++ b/src/operator/tensor/indexing_op-inl.cuh
@@ -213,10 +213,10 @@ inline void AddTakeGradLargeBatchKernelLaunch(mshadow::Tensor<gpu, 2, DType> dst
   cudaStream_t stream = mshadow::Stream<gpu>::GetStream(dst.stream_);
   const int num_unique_est = min(num_rows, src.size(0));
   const int max_nthread = 128;
-  const int num_y = max(src.size(0)/num_unique_est, 1);
+  const int num_y = max(static_cast<int>(src.size(0))/num_unique_est, 1);
   const int block_dim_x = kWarpSize;
   const int block_dim_y = min(num_y, max_nthread/block_dim_x);
-  const int SZ = min((src.size(1) + block_dim_x - 1) / block_dim_x, 4);
+  const int SZ = min((static_cast<int>(src.size(1)) + block_dim_x - 1) / block_dim_x, 4);
   const int grid_dim_x = (src.size(1) + block_dim_x * SZ - 1) / (block_dim_x * SZ);
   const int grid_dim_y = min(num_unique_est, mshadow::cuda::kBaseGridNum);
   dim3 dimBlock(block_dim_x, block_dim_y);
diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h
index 78e1fa1d9c6..bfee083ca1a 100644
--- a/src/operator/tensor/matrix_op-inl.h
+++ b/src/operator/tensor/matrix_op-inl.h
@@ -86,7 +86,7 @@ inline TShape InferReshapeShape(const nnvm::Tuple<IType>& shape,
   }
   auto dshape_len = dshape_vec.size();
   auto params_len = param_shape_vec.size();
-  for (index_t i = 0; i < params_len; ++i) {
+  for (size_t i = 0; i < params_len; ++i) {
     IType proposed_dim = param_shape_vec[i];
     if (proposed_dim == 0) {
       // keep same
@@ -2085,7 +2085,7 @@ void StackOpForward(const nnvm::NodeAttrs& attrs,
     Shape<3> oshape = Shape3(leading, mid, trailing);
     out = outputs[0].get_with_shape<xpu, 3, DType>(oshape, s);
 
-    for (index_t i = 0; i < inputs.size(); ++i) {
+    for (size_t i = 0; i < inputs.size(); ++i) {
       Shape<3> dshape = Shape3(leading, 1, trailing);
       data[i] = inputs[i].get_with_shape<xpu, 3, DType>(dshape, s);
     }
@@ -2119,7 +2119,7 @@ void StackOpBackward(const nnvm::NodeAttrs& attrs,
     Shape<3> oshape = Shape3(leading, mid, trailing);
     grad = inputs[0].get_with_shape<xpu, 3, DType>(oshape, s);
 
-    for (index_t i = 0; i < outputs.size(); ++i) {
+    for (size_t i = 0; i < outputs.size(); ++i) {
       Shape<3> dshape = Shape3(leading, 1, trailing);
       grad_in[i] = outputs[i].get_with_shape<xpu, 3, DType>(dshape, s);
     }
diff --git a/src/operator/tensor/ordering_op-inl.h b/src/operator/tensor/ordering_op-inl.h
index c1a5b89db09..18bd7608e4c 100644
--- a/src/operator/tensor/ordering_op-inl.h
+++ b/src/operator/tensor/ordering_op-inl.h
@@ -380,14 +380,17 @@ void TopKImpl(const RunContext &ctx,
   Tensor<xpu, 3, DType> dat = src.FlatTo3D<xpu, DType>(axis, axis, s);
   size_t temp_size = 0;
   // Temp space needed by the gpu-based full sorts.
-  temp_size = std::max(temp_size, mxnet::op::SortByKeyWorkspaceSize<int, int, xpu>(src.Size()));
-  temp_size = std::max(temp_size, mxnet::op::SortByKeyWorkspaceSize<int, DType, xpu>(src.Size()));
-  temp_size = std::max(temp_size, mxnet::op::SortByKeyWorkspaceSize<DType, int, xpu>(src.Size()));
+  temp_size = std::max(temp_size,
+    mxnet::op::SortByKeyWorkspaceSize<int, int, xpu>(src.Size()));
+  temp_size = std::max(temp_size,
+    mxnet::op::SortByKeyWorkspaceSize<int, DType, xpu>(src.Size()));
+  temp_size = std::max(temp_size,
+    mxnet::op::SortByKeyWorkspaceSize<DType, int, xpu>(src.Size()));
   // Additional temp space for gpu full sorts for batch ids.
   temp_size += sizeof(int) * src.Size();
   // Temp space for cpu sorts.
-  temp_size = std::max(temp_size, sizeof(DType) * src.Size());
-  size_t workspace_size = temp_size + sizeof(DType) * src.Size() + sizeof(int) * src.Size();
+  temp_size = std::max(temp_size, sizeof(DType) * static_cast<size_t>(src.Size()));
+  index_t workspace_size = temp_size + sizeof(DType) * src.Size() + sizeof(int) * src.Size();
   if (param.ret_typ == topk_enum::kReturnMask) {
     workspace_size += sizeof(int) * batch_size * k + sizeof(DType) * batch_size * k;
   }
diff --git a/src/operator/tensor/ordering_op.cc b/src/operator/tensor/ordering_op.cc
index 1e2832d3763..189ea19fa6f 100644
--- a/src/operator/tensor/ordering_op.cc
+++ b/src/operator/tensor/ordering_op.cc
@@ -74,8 +74,8 @@ Examples::
     const TopKParam& param = nnvm::get<TopKParam>(n->attrs.parsed);
     if (param.ret_typ == topk_enum::kReturnValue || param.ret_typ == topk_enum::kReturnBoth) {
       std::vector<nnvm::NodeEntry> inputs;
-      index_t n_out = n->num_outputs();
-      for (index_t i = 0; i < n_out; ++i) {
+      uint32_t n_out = n->num_outputs();
+      for (uint32_t i = 0; i < n_out; ++i) {
         inputs.emplace_back(nnvm::NodeEntry{ n, i, 0 });
       }
       return MakeNonlossGradNode("_backward_topk", n, {ograds[0]}, inputs, n->attrs.dict);
@@ -136,8 +136,8 @@ Examples::
   [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
     const SortParam& param = nnvm::get<SortParam>(n->attrs.parsed);
     std::vector<nnvm::NodeEntry> inputs;
-    index_t n_out = n->num_outputs();
-    for (index_t i = 0; i < n_out; ++i) {
+    uint32_t n_out = n->num_outputs();
+    for (uint32_t i = 0; i < n_out; ++i) {
       inputs.emplace_back(nnvm::NodeEntry{ n, i, 0 });
     }
     return MakeNonlossGradNode("_backward_topk", n, {ograds[0]}, inputs,
diff --git a/tests/nightly/test_all.sh b/tests/nightly/test_all.sh
index 04d895fecf2..73f0f588fe9 100755
--- a/tests/nightly/test_all.sh
+++ b/tests/nightly/test_all.sh
@@ -122,4 +122,7 @@ juLog -name=BuildWithoutCUDNN -error=Error build
 # python: multi gpus lenet + mnist
 juLog -name=Python.Multi.Lenet.Mnist -error=Error python multi_lenet.py
 
+# python: large tensor
+juLog -name=Python.LargeTensor -error=Fail python test_large_array.py
+
 exit $errors
diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py
new file mode 100644
index 00000000000..121acc174b5
--- /dev/null
+++ b/tests/nightly/test_large_array.py
@@ -0,0 +1,34 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+import mxnet as mx
+from mxnet import gluon, nd
+
+
+class TestLargeArray(unittest.TestCase):
+    def test_ndarray2numpy(self):
+        m = gluon.nn.Embedding(14000, 128)
+        m.initialize()
+        ind = nd.zeros((700000, 128))
+        x = m(ind)
+        x.shape
+        test = x.asnumpy()
+        assert (x.shape == test.shape)
+
+if __name__ == '__main__':
+    unittest.main()

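For scale: the new test runs an Embedding(14000, 128) over a (700000, 128) index array, so the output has shape (700000, 128, 128), roughly 11.5 billion elements — the same magnitude sketched near the top of this diff and well past what a 32-bit index can address. The asnumpy() call then forces a full copy through the formerly overflowing path.
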
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services