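You are viewing a plain-text version of this content. The canonical, hyperlinked version is available at the original mail-archive page (the link is not preserved in this plain-text rendering).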
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/12/01 05:48:22 UTC

[GitHub] anirudh2290 closed pull request #13418: [MXNET-1185] Support large array in several operators (part 1)

anirudh2290 closed pull request #13418: [MXNET-1185] Support large array in several operators (part 1)
URL: https://github.com/apache/incubator-mxnet/pull/13418
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

Because this pull request came from a fork (a "foreign" pull request), GitHub will not
display its diff after the merge, so the diff is supplied below:

diff --git a/src/operator/elemwise_op_common.h b/src/operator/elemwise_op_common.h
index cf44da69915..4b8663bba6e 100644
--- a/src/operator/elemwise_op_common.h
+++ b/src/operator/elemwise_op_common.h
@@ -100,7 +100,7 @@ inline bool ElemwiseStorageAttr(const nnvm::NodeAttrs& attrs,
  *  \tparam rsp whether row sparse stype is supported
  *  \tparam rsp whether csr stype is supported
  */
-template<int n_in, int n_out, bool cpu_only, bool rsp, bool csr>
+template<index_t n_in, index_t n_out, bool cpu_only, bool rsp, bool csr>
 inline bool ElemwiseStorageType(const nnvm::NodeAttrs& attrs,
                                 const int dev_mask,
                                 DispatchMode* dispatch_mode,
@@ -115,7 +115,7 @@ inline bool ElemwiseStorageType(const nnvm::NodeAttrs& attrs,
 template<typename AttrType, bool (*is_none)(const AttrType&),
          bool (*assign)(AttrType*, const AttrType&), bool reverse_infer,
          std::string (*attr_string)(const AttrType&),
-         int n_in = -1, int n_out = -1>
+         index_t n_in = -1, index_t n_out = -1>
 inline bool ElemwiseAttr(const nnvm::NodeAttrs& attrs,
                          std::vector<AttrType> *in_attrs,
                          std::vector<AttrType> *out_attrs,
@@ -154,7 +154,7 @@ inline bool ElemwiseAttr(const nnvm::NodeAttrs& attrs,
   return true;
 }
 
-template<int n_in, int n_out>
+template<index_t n_in, index_t n_out>
 inline bool ElemwiseShape(const nnvm::NodeAttrs& attrs,
                           std::vector<TShape> *in_attrs,
                           std::vector<TShape> *out_attrs) {
@@ -168,7 +168,7 @@ inline bool ElemwiseShape(const nnvm::NodeAttrs& attrs,
     attrs, in_attrs, out_attrs, TShape());
 }
 
-template<int n_in, int n_out>
+template<index_t n_in, index_t n_out>
 inline bool ElemwiseType(const nnvm::NodeAttrs& attrs,
                          std::vector<int> *in_attrs,
                          std::vector<int> *out_attrs) {
diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h
index 5b106afd8d5..6cab1990858 100644
--- a/src/operator/mxnet_op.h
+++ b/src/operator/mxnet_op.h
@@ -289,8 +289,8 @@ inline int get_num_threads<cpu>(const int N) {
 
 /* \brief Compute flattened index given coordinates and shape. */
 template<int ndim>
-MSHADOW_XINLINE int ravel(const Shape<ndim>& coord, const Shape<ndim>& shape) {
-  int ret = 0;
+MSHADOW_XINLINE index_t ravel(const Shape<ndim>& coord, const Shape<ndim>& shape) {
+  index_t ret = 0;
   #pragma unroll
   for (int i = 0; i < ndim; ++i) {
     ret = ret * shape[i] + (shape[i] > coord[i]) * coord[i];
@@ -301,11 +301,11 @@ MSHADOW_XINLINE int ravel(const Shape<ndim>& coord, const Shape<ndim>& shape) {
 
 /* Compute coordinates from flattened index given shape */
 template<int ndim>
-MSHADOW_XINLINE Shape<ndim> unravel(const int idx, const Shape<ndim>& shape) {
+MSHADOW_XINLINE Shape<ndim> unravel(const index_t idx, const Shape<ndim>& shape) {
   Shape<ndim> ret;
   #pragma unroll
-  for (int i = ndim-1, j = idx; i >=0; --i) {
-    int tmp = j / shape[i];
+  for (index_t i = ndim-1, j = idx; i >=0; --i) {
+    auto tmp = j / shape[i];
     ret[i] = j - tmp*shape[i];
     j = tmp;
   }
@@ -315,8 +315,8 @@ MSHADOW_XINLINE Shape<ndim> unravel(const int idx, const Shape<ndim>& shape) {
 
 /* Compute dot product of two vector */
 template<int ndim>
-MSHADOW_XINLINE int dot(const Shape<ndim>& coord, const Shape<ndim>& stride) {
-  int ret = 0;
+MSHADOW_XINLINE index_t dot(const Shape<ndim>& coord, const Shape<ndim>& stride) {
+  index_t ret = 0;
   #pragma unroll
   for (int i = 0; i < ndim; ++i) {
     ret += coord[i] * stride[i];
@@ -327,12 +327,12 @@ MSHADOW_XINLINE int dot(const Shape<ndim>& coord, const Shape<ndim>& stride) {
 
 /* Combining unravel and dot */
 template<int ndim>
-MSHADOW_XINLINE int unravel_dot(const int idx, const Shape<ndim>& shape,
+MSHADOW_XINLINE index_t unravel_dot(const index_t idx, const Shape<ndim>& shape,
   const Shape<ndim>& stride) {
-  int ret = 0;
+  index_t ret = 0;
   #pragma unroll
-  for (int i = ndim-1, j = idx; i >=0; --i) {
-    int tmp = j / shape[i];
+  for (index_t i = ndim-1, j = idx; i >=0; --i) {
+    auto tmp = j / shape[i];
     ret += (j - tmp*shape[i])*stride[i];
     j = tmp;
   }
@@ -433,51 +433,51 @@ struct op_with_req {
 
   /*! \brief input is one tensor */
   template<typename DType>
-  MSHADOW_XINLINE static void Map(int i, DType *out, const DType *in) {
+  MSHADOW_XINLINE static void Map(index_t i, DType *out, const DType *in) {
     KERNEL_ASSIGN(out[i], req, OP::Map(in[i]));
   }
 
   /*! \brief inputs are two tensors */
   template<typename DType>
-  MSHADOW_XINLINE static void Map(int i, DType *out, const DType *lhs, const DType *rhs) {
+  MSHADOW_XINLINE static void Map(index_t i, DType *out, const DType *lhs, const DType *rhs) {
     KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], rhs[i]));
   }
 
   /*! \brief input is tensor and a scalar value */
   template<typename DType>
-  MSHADOW_XINLINE static void Map(int i, DType *out, const DType *in, const DType value) {
+  MSHADOW_XINLINE static void Map(index_t i, DType *out, const DType *in, const DType value) {
     KERNEL_ASSIGN(out[i], req, OP::Map(in[i], value));
   }
 
   /*! \brief input is tensor and two scalar value */
   template<typename DType>
-  MSHADOW_XINLINE static void Map(int i, DType *out, const DType *in,
+  MSHADOW_XINLINE static void Map(index_t i, DType *out, const DType *in,
                                   const DType value_1, const DType value_2) {
     KERNEL_ASSIGN(out[i], req, OP::Map(in[i], value_1, value_2));
   }
 
   /*! \brief No inputs (ie fill to constant value) */
   template<typename DType>
-  MSHADOW_XINLINE static void Map(int i, DType *out) {
+  MSHADOW_XINLINE static void Map(index_t i, DType *out) {
     KERNEL_ASSIGN(out[i], req, OP::Map());
   }
 
   /*! \brief input is single scalar value */
   template<typename DType>
-  MSHADOW_XINLINE static void Map(int i, DType *out, const DType value) {
+  MSHADOW_XINLINE static void Map(index_t i, DType *out, const DType value) {
     KERNEL_ASSIGN(out[i], req, OP::Map(value));
   }
 
   /*! \brief inputs are two tensors and a scalar value */
   template<typename DType>
-  MSHADOW_XINLINE static void Map(int i, DType *out,
+  MSHADOW_XINLINE static void Map(index_t i, DType *out,
                                   const DType *input_1, const DType *input_2, const DType value) {
     KERNEL_ASSIGN(out[i], req, OP::Map(input_1[i], input_2[i], value));
   }
 
   /*! \brief inputs are three tensors (ie backward grad with binary grad function) */
   template<typename DType>
-  MSHADOW_XINLINE static void Map(int i, DType *out,
+  MSHADOW_XINLINE static void Map(index_t i, DType *out,
                                   const DType *input_1,
                                   const DType *input_2,
                                   const DType *input_3) {
@@ -503,21 +503,21 @@ struct Kernel<OP, cpu> {
    * \param args Varargs to eventually pass to the OP::Map() function
    */
   template<typename ...Args>
-  inline static bool Launch(mshadow::Stream<cpu> *, const int N, Args... args) {
+  inline static bool Launch(mshadow::Stream<cpu> *, const size_t N, Args... args) {
 #ifdef _OPENMP
     const int omp_threads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
     if (omp_threads < 2) {
-      for (int i = 0; i < N; ++i) {
+      for (size_t i = 0; i < N; ++i) {
         OP::Map(i, args...);
       }
     } else {
       #pragma omp parallel for num_threads(omp_threads)
-      for (int i = 0; i < N; ++i) {
+      for (index_t i = 0; i < static_cast<index_t>(N); ++i) {
         OP::Map(i, args...);
       }
     }
 #else
-    for (int i = 0; i < N; ++i) {
+    for (size_t i = 0; i < N; ++i) {
       OP::Map(i, args...);
     }
 #endif
@@ -567,22 +567,22 @@ struct Kernel<OP, cpu> {
    * \param args Varargs to eventually pass to the OP::Map() function
    */
   template<typename PRIMITIVE_OP, typename DType, typename ...Args>
-  static void LaunchTuned(mshadow::Stream<cpu> *, const int N, Args... args) {
+  static void LaunchTuned(mshadow::Stream<cpu> *, const size_t N, Args... args) {
 #ifdef _OPENMP
     const int omp_threads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
     if (omp_threads < 2 || !tuned_op<PRIMITIVE_OP, DType>::UseOMP(
-      static_cast<size_t>(N), static_cast<size_t>(omp_threads))) {
-      for (int i = 0; i < N; ++i) {
+      N, static_cast<size_t>(omp_threads))) {
+      for (size_t i = 0; i < N; ++i) {
         OP::Map(i, args...);
       }
     } else {
       #pragma omp parallel for num_threads(omp_threads)
-      for (int i = 0; i < N; ++i) {
+      for (index_t i = 0; i < static_cast<index_t>(N); ++i) {
         OP::Map(i, args...);
       }
     }
 #else
-    for (int i = 0; i < N; ++i) {
+    for (size_t i = 0; i < N; ++i) {
       OP::Map(i, args...);
     }
 #endif
@@ -596,15 +596,15 @@ struct Kernel<OP, cpu> {
    * \param args Varargs to eventually pass to the UseOMP() and OP::Map() functions
    */
   template<typename ...Args>
-  inline static void LaunchEx(mshadow::Stream<cpu> *s, const int N, Args... args) {
+  inline static void LaunchEx(mshadow::Stream<cpu> *s, const size_t N, Args... args) {
 #ifdef _OPENMP
     const int omp_threads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
     if (omp_threads < 2) {
       OP::Map(0, N, args...);
     } else {
-      const int length = (N + omp_threads - 1) / omp_threads;
+      const auto length = (N + omp_threads - 1) / omp_threads;
       #pragma omp parallel for num_threads(omp_threads)
-      for (int i = 0; i < N; i += length) {
+      for (index_t i = 0; i < static_cast<index_t>(N); i += length) {
         OP::Map(i, i + length > N ? N - i : length, args...);
       }
     }
@@ -626,7 +626,7 @@ struct Kernel<OP, cpu> {
   template<typename DType, typename T = OP, typename ...Args>
   static MSHADOW_CINLINE
   typename std::enable_if<std::is_base_of<tunable, T>::value, bool>::type
-  Launch(mshadow::Stream<cpu> *s, const int N, DType *dest, Args... args) {
+  Launch(mshadow::Stream<cpu> *s, const size_t N, DType *dest, Args... args) {
     LaunchTuned<T, DType>(s, N, dest, args...);
     return true;
   }
@@ -644,7 +644,7 @@ struct Kernel<OP, cpu> {
   template<typename DType, typename T = OP, typename ...Args>
   static MSHADOW_CINLINE
   typename std::enable_if<std::is_base_of<tunable, typename T::Operation>::value, bool>::type
-  Launch(mshadow::Stream<cpu> *s, const int N, DType *dest, Args... args) {
+  Launch(mshadow::Stream<cpu> *s, const size_t N, DType *dest, Args... args) {
     LaunchTuned<typename T::Operation, DType>(s, N, dest, args...);
     return true;
   }
@@ -700,7 +700,7 @@ template<int val>
 struct set_to_int : public tunable {
   // mxnet_op version (when used directly with Kernel<>::Launch()) */
   template<typename DType>
-  MSHADOW_XINLINE static void Map(int i, DType *out) {
+  MSHADOW_XINLINE static void Map(index_t i, DType *out) {
     out[i] = DType(val);
   }
   // mshadow_op version (when used with op_with_req<>)
diff --git a/src/operator/random/sampler.h b/src/operator/random/sampler.h
index ca764e706c6..00963a6785e 100644
--- a/src/operator/random/sampler.h
+++ b/src/operator/random/sampler.h
@@ -43,32 +43,33 @@ namespace op {
 template<typename OP, typename xpu, typename GType, typename ...Args>
 inline static void LaunchRNG(mshadow::Stream<xpu> *s,
                              common::random::RandGenerator<xpu, GType> *gen,
-                             const int N, Args... args) {
+                             const index_t N, Args... args) {
   // minimal check to avoid division by zero, below.
   // if `N` is zero the map operation is a no-op in any case.
   if (N <= 0) {
     return;
   }
-  const int nloop = (N + RandGenerator<xpu>::kMinNumRandomPerThread - 1) /
+  const index_t nloop = (N + RandGenerator<xpu>::kMinNumRandomPerThread - 1) /
                     RandGenerator<xpu>::kMinNumRandomPerThread;
-  const int nthread = std::min(nloop, RandGenerator<xpu>::kNumRandomStates);
-  const int step = (N + nthread - 1) / nthread;
+  const index_t nthread = std::min(nloop,
+                                   static_cast<index_t>(RandGenerator<xpu>::kNumRandomStates));
+  const index_t step = (N + nthread - 1) / nthread;
   Kernel<OP, xpu>::Launch(s, nthread, *gen, N, step, args...);
 }
 
 #define RNG_KERNEL_LOOP(xpu, GType, thread_id, gen, N, step, ...)        \
-  const int start = thread_id * step;                                    \
-  const int end = start + step;                                          \
+  const index_t start = thread_id * step;                                    \
+  const index_t end = start + step;                                          \
   typename RandGenerator<xpu, GType>::Impl genImpl(&gen, thread_id);     \
-  for (int i = start; i < end && i < N; ++i) {                           \
+  for (index_t i = start; i < end && i < N; ++i) {                           \
     {__VA_ARGS__}                                                        \
   }
 
 template<typename xpu>
 struct SampleUniformKernel {
   template<typename IType, typename OType>
-  MSHADOW_XINLINE static void Map(int id, RandGenerator<xpu, OType> gen,
-                                  const int N, const int step,
+  MSHADOW_XINLINE static void Map(index_t id, RandGenerator<xpu, OType> gen,
+                                  const index_t N, const index_t step,
                                   index_t nParm, index_t nSample,
                                   const IType *lower, const IType *upper, OType *out) {
     RNG_KERNEL_LOOP(xpu, OType, id, gen, N, step, {
@@ -127,8 +128,8 @@ struct RandIntSampler {
 template<typename xpu>
 struct SampleNormalKernel {
   template<typename IType, typename OType>
-  MSHADOW_XINLINE static void Map(int id, RandGenerator<xpu, OType> gen,
-                                  const int N, const int step,
+  MSHADOW_XINLINE static void Map(index_t id, RandGenerator<xpu, OType> gen,
+                                  const index_t N, const index_t step,
                                   index_t nParm, index_t nSample,
                                   const IType *mean, const IType *std, OType *out) {
     RNG_KERNEL_LOOP(xpu, OType, id, gen, N, step, {
@@ -154,8 +155,8 @@ struct NormalSampler {
 template<typename xpu>
 struct SampleExponentialKernel {
   template<typename IType, typename OType>
-  MSHADOW_XINLINE static void Map(int id, RandGenerator<xpu, OType> gen,
-                                  const int N, const int step,
+  MSHADOW_XINLINE static void Map(index_t id, RandGenerator<xpu, OType> gen,
+                                  const index_t N, const index_t step,
                                   index_t nParm, index_t nSample,
                                   const IType *lambda, OType *out) {
     RNG_KERNEL_LOOP(xpu, OType, id, gen, N, step, {
@@ -202,8 +203,8 @@ MSHADOW_XINLINE OType SampleGamma(IType a, IType b, typename RandGenerator<xpu,
 template<typename xpu>
 struct SampleGammaKernel {
   template<typename IType, typename OType, typename FType>
-  MSHADOW_XINLINE static void Map(int id, RandGenerator<xpu, FType> gen,
-                                  const int N, const int step,
+  MSHADOW_XINLINE static void Map(index_t id, RandGenerator<xpu, FType> gen,
+                                  const index_t N, const index_t step,
                                   index_t nParm, index_t nSample,
                                   const IType *alpha, const IType *beta, OType *out) {
     RNG_KERNEL_LOOP(xpu, FType, id, gen, N, step, {
@@ -264,8 +265,8 @@ MSHADOW_XINLINE int SamplePoisson(float lambda, typename RandGenerator<xpu, floa
 template<typename xpu>
 struct SamplePoissonKernel {
   template<typename IType, typename OType>
-  MSHADOW_XINLINE static void Map(int id, RandGenerator<xpu, float> gen,
-                                  const int N, const int step,
+  MSHADOW_XINLINE static void Map(index_t id, RandGenerator<xpu, float> gen,
+                                  const index_t N, const index_t step,
                                   index_t nParm, index_t nSample,
                                   const IType *lambda, OType *out) {
     RNG_KERNEL_LOOP(xpu, float, id, gen, N, step, {
@@ -291,8 +292,8 @@ struct PoissonSampler {
 template<typename xpu>
 struct SampleNegativeBinomialKernel {
   template<typename IType, typename OType>
-  MSHADOW_XINLINE static void Map(int id, RandGenerator<xpu, float> gen,
-                                  const int N, const int step,
+  MSHADOW_XINLINE static void Map(index_t id, RandGenerator<xpu, float> gen,
+                                  const index_t N, const index_t step,
                                   index_t nParm, index_t nSample,
                                   const IType *k, const IType *p, OType *out) {
     RNG_KERNEL_LOOP(xpu, float, id, gen, N, step, {
@@ -323,8 +324,8 @@ struct NegativeBinomialSampler {
 template<typename xpu>
 struct SampleGeneralizedNegativeBinomialKernel {
   template<typename IType, typename OType>
-  MSHADOW_XINLINE static void Map(int id, RandGenerator<xpu, float> gen,
-                                  const int N, const int step,
+  MSHADOW_XINLINE static void Map(index_t id, RandGenerator<xpu, float> gen,
+                                  const index_t N, const index_t step,
                                   index_t nParm, index_t nSample,
                                   const IType *mu, const IType *alpha, OType *out) {
     RNG_KERNEL_LOOP(xpu, float, id, gen, N, step, {
diff --git a/src/operator/tensor/broadcast_reduce-inl.h b/src/operator/tensor/broadcast_reduce-inl.h
index 167fa34b083..141d2fb83d0 100644
--- a/src/operator/tensor/broadcast_reduce-inl.h
+++ b/src/operator/tensor/broadcast_reduce-inl.h
@@ -53,14 +53,14 @@ MSHADOW_XINLINE Shape<ndim> calc_stride(const Shape<ndim>& shape) {
 }
 
 template<int ndim>
-MSHADOW_XINLINE void unravel_dot(const int idx, const Shape<ndim>& shape,
-  const Shape<ndim>& stridej, const Shape<ndim>& stridek, int* j, int* k) {
+MSHADOW_XINLINE void unravel_dot(const index_t idx, const Shape<ndim>& shape,
+  const Shape<ndim>& stridej, const Shape<ndim>& stridek, index_t* j, index_t* k) {
   *j = 0;
   *k = 0;
   #pragma unroll
-  for (int i = ndim-1, idx_t = idx; i >=0; --i) {
-    const int tmp = idx_t / shape[i];
-    const int coord = idx_t - tmp*shape[i];
+  for (index_t i = ndim-1, idx_t = idx; i >=0; --i) {
+    const auto tmp = idx_t / shape[i];
+    const auto coord = idx_t - tmp*shape[i];
     *j += coord*stridej[i];
     *k += coord*stridek[i];
     idx_t = tmp;
@@ -68,11 +68,11 @@ MSHADOW_XINLINE void unravel_dot(const int idx, const Shape<ndim>& shape,
 }
 
 template<int ndim>
-MSHADOW_XINLINE Shape<ndim> unravel(const int idx, const Shape<ndim>& shape) {
+MSHADOW_XINLINE Shape<ndim> unravel(const index_t idx, const Shape<ndim>& shape) {
   Shape<ndim> ret;
   #pragma unroll
-  for (int i = ndim-1, j = idx; i >=0; --i) {
-    int tmp = j / shape[i];
+  for (index_t i = ndim-1, j = idx; i >=0; --i) {
+    auto tmp = j / shape[i];
     ret[i] = j - tmp*shape[i];
     j = tmp;
   }
@@ -80,10 +80,10 @@ MSHADOW_XINLINE Shape<ndim> unravel(const int idx, const Shape<ndim>& shape) {
 }
 
 template<int ndim>
-MSHADOW_XINLINE int ravel(const Shape<ndim>& coord, const Shape<ndim>& shape) {
-  int ret = 0;
+MSHADOW_XINLINE index_t ravel(const Shape<ndim>& coord, const Shape<ndim>& shape) {
+  index_t ret = 0;
   #pragma unroll
-  for (int i = 0; i < ndim; ++i) {
+  for (index_t i = 0; i < ndim; ++i) {
     ret = ret * shape[i] + (shape[i] > 1) * coord[i];
   }
   return ret;
@@ -111,12 +111,12 @@ MSHADOW_XINLINE int diff(const Shape<ndim>& small, const Shape<ndim>& big, Shape
 }
 
 template<int ndim>
-MSHADOW_XINLINE int unravel_dot(const int idx, const Shape<ndim>& shape,
+MSHADOW_XINLINE index_t unravel_dot(const index_t idx, const Shape<ndim>& shape,
   const Shape<ndim>& stride) {
-  int ret = 0;
+  index_t ret = 0;
   #pragma unroll
-  for (int i = ndim-1, j = idx; i >=0; --i) {
-    int tmp = j / shape[i];
+  for (index_t i = ndim-1, j = idx; i >=0; --i) {
+    auto tmp = j / shape[i];
     ret += (j - tmp*shape[i])*stride[i];
     j = tmp;
   }
@@ -124,8 +124,8 @@ MSHADOW_XINLINE int unravel_dot(const int idx, const Shape<ndim>& shape,
 }
 
 template<int ndim>
-MSHADOW_XINLINE int dot(const Shape<ndim>& coord, const Shape<ndim>& stride) {
-  int ret = 0;
+MSHADOW_XINLINE index_t dot(const Shape<ndim>& coord, const Shape<ndim>& stride) {
+  index_t ret = 0;
   #pragma unroll
   for (int i = 0; i < ndim; ++i)
     ret += coord[i] * stride[i];
@@ -142,27 +142,27 @@ MSHADOW_XINLINE void assign(DType* dst, const bool addto, const DType src) {
 }
 
 template<int ndim, typename DType, typename OP>
-MSHADOW_XINLINE void binary_broadcast_assign(const int idx, const bool addto,
+MSHADOW_XINLINE void binary_broadcast_assign(const index_t idx, const bool addto,
                                              const DType* __restrict lhs,
                                              const DType* __restrict rhs, DType* out,
                                              const Shape<ndim>& lshape, const Shape<ndim>& rshape,
                                              const Shape<ndim>& oshape) {
   const Shape<ndim> coord = unravel(idx, oshape);
-  const int j = ravel(coord, lshape);
-  const int k = ravel(coord, rshape);
+  const index_t j = ravel(coord, lshape);
+  const index_t k = ravel(coord, rshape);
   assign(&out[idx], addto, OP::Map(lhs[j], rhs[k]));
 }
 
 template<typename Reducer, int ndim, typename DType, typename OP>
-MSHADOW_XINLINE void seq_reduce_assign(const int idx, const int M, const bool addto,
+MSHADOW_XINLINE void seq_reduce_assign(const index_t idx, const size_t M, const bool addto,
                                        const DType* __restrict big, DType *small,
                                        const Shape<ndim>& bshape, const Shape<ndim>& sshape,
                                        const Shape<ndim>& rshape, const Shape<ndim>& rstride) {
   Shape<ndim> coord = unravel(idx, sshape);
-  int j = ravel(coord, bshape);
+  index_t j = ravel(coord, bshape);
   DType val, residual;
   Reducer::SetInitValue(val, residual);
-  for (int k = 0; k < M; ++k) {
+  for (size_t k = 0; k < M; ++k) {
     coord = unravel(k, rshape);
     Reducer::Reduce(val, OP::Map(big[j + dot(coord, rstride)]), residual);
   }
@@ -176,10 +176,10 @@ MSHADOW_XINLINE void seq_reduce_assign(const int idx, const int M, const bool ad
 #else
 
 template<int ndim, typename DType, typename OP>
-void binary_broadcast_compute(const int N, const bool addto, const DType *lhs,
+void binary_broadcast_compute(const size_t N, const bool addto, const DType *lhs,
                               const DType *rhs, DType *out, const Shape<ndim> lshape,
                               const Shape<ndim> rshape, const Shape<ndim> oshape) {
-  for (int idx = 0; idx < N; ++idx) {
+  for (size_t idx = 0; idx < N; ++idx) {
     binary_broadcast_assign<ndim, DType, OP>(idx, addto, lhs, rhs, out, lshape, rshape, oshape);
   }
 }
@@ -188,26 +188,26 @@ template<int ndim, typename DType, typename OP>
 void BinaryBroadcastComputeImpl(Stream<cpu> *s, const OpReqType req,
                                 const TBlob& lhs, const TBlob& rhs, const TBlob& out) {
   if (req == kNullOp) return;
-  int N = out.shape_.Size();
+  size_t N = out.shape_.Size();
   binary_broadcast_compute<ndim, DType, OP>(N, req == kAddTo, lhs.dptr<DType>(), rhs.dptr<DType>(),
                            out.dptr<DType>(), lhs.shape_.get<ndim>(), rhs.shape_.get<ndim>(),
                            out.shape_.get<ndim>());
 }
 
 template<typename Reducer, int ndim, typename DType, typename OP>
-void seq_reduce_compute(const int N, const int M, const bool addto,
+void seq_reduce_compute(const size_t N, const size_t M, const bool addto,
                         const DType *big, DType *small, const Shape<ndim> bshape,
                         const Shape<ndim> sshape, const Shape<ndim> rshape,
                         const Shape<ndim> rstride) {
   #pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
-  for (int idx = 0; idx < N; ++idx) {
+  for (index_t idx = 0; idx < static_cast<index_t>(N); ++idx) {
     seq_reduce_assign<Reducer, ndim, DType, OP>(idx, M, addto, big, small, bshape, sshape, rshape,
       rstride);
   }
 }
 
 template <typename Reducer, int ndim, typename DType, typename OP>
-void seq_reduce_compute_extra_mem(const int N, const int M, const bool addto,
+void seq_reduce_compute_extra_mem(const size_t N, const size_t M, const bool addto,
                                   const DType* big, DType* small,
                                   const Shape<ndim> bshape,
                                   const Shape<ndim> sshape,
@@ -215,12 +215,12 @@ void seq_reduce_compute_extra_mem(const int N, const int M, const bool addto,
                                   const Shape<ndim> rstride,
                                   const index_t* ws_dptr) {
   #pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
-  for (int idx = 0; idx < N; ++idx) {
+  for (index_t idx = 0; idx < static_cast<index_t>(N); ++idx) {
     Shape<ndim> coord = unravel(idx, sshape);
-    int j = ravel(coord, bshape);
+    index_t j = ravel(coord, bshape);
     DType val, residual;
     Reducer::SetInitValue(val, residual);
-    for (int k = 0; k < M; ++k) {
+    for (size_t k = 0; k < M; ++k) {
       Reducer::Reduce(val, OP::Map(big[j + ws_dptr[k]]), residual);
     }
     assign(&small[idx], addto, val);
@@ -233,7 +233,7 @@ void Reduce(Stream<cpu>* s, const TBlob& small, const OpReqType req,
   if (req == kNullOp) return;
   Shape<ndim> rshape, rstride;
   diff(small.shape_.get<ndim>(), big.shape_.get<ndim>(), &rshape, &rstride);
-  int N = small.shape_.Size(), M = rshape.Size();
+  size_t N = small.shape_.Size(), M = rshape.Size();
   seq_reduce_compute<Reducer, ndim, DType, OP>(
     N, M, req == kAddTo, big.dptr<DType>(), small.dptr<DType>(),
     big.shape_.get<ndim>(), small.shape_.get<ndim>(), rshape, rstride);
@@ -247,9 +247,9 @@ void ReduceWithExtraMem(Stream<cpu>* s, const TBlob& small, const OpReqType req,
   Shape<ndim> rshape, rstride;
   diff(small.shape_.get<ndim>(), big.shape_.get<ndim>(), &rshape, &rstride);
   index_t* ws_dptr = reinterpret_cast<index_t*>(workspace.dptr_);
-  int N = small.shape_.Size(), M = rshape.Size();
+  size_t N = small.shape_.Size(), M = rshape.Size();
   #pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
-  for (int k = 0; k < M; k++) {
+  for (index_t k = 0; k < static_cast<index_t>(M); k++) {
     Shape<ndim> coord = unravel(k, rshape);
     ws_dptr[k] = dot(coord, rstride);
   }
@@ -272,7 +272,7 @@ size_t ReduceWorkspaceSize(Stream<cpu> *s, const TShape& small, const OpReqType
 }
 
 template<typename Reducer, int ndim, typename DType, typename OP1, typename OP2>
-MSHADOW_XINLINE void seq_reduce_assign(const int idx, const int M, const bool addto,
+MSHADOW_XINLINE void seq_reduce_assign(const index_t idx, const size_t M, const bool addto,
                                        const DType* __restrict big, const DType* __restrict lhs,
                                        const DType* __restrict rhs, DType *small,
                                        const Shape<ndim>& big_shape, const Shape<ndim>& lhs_shape0,
@@ -282,20 +282,20 @@ MSHADOW_XINLINE void seq_reduce_assign(const int idx, const int M, const bool ad
                                        const Shape<ndim>& rstride, const Shape<ndim>& lhs_stride,
                                        const Shape<ndim>& rhs_stride) {
   Shape<ndim> coord = unravel(idx, small_shape);
-  const int idx_big0 = ravel(coord, big_shape);
-  const int idx_lhs0 = ravel(coord, lhs_shape0);
-  const int idx_rhs0 = ravel(coord, rhs_shape0);
+  const index_t idx_big0 = ravel(coord, big_shape);
+  const index_t idx_lhs0 = ravel(coord, lhs_shape0);
+  const index_t idx_rhs0 = ravel(coord, rhs_shape0);
   DType val, residual;
   Reducer::SetInitValue(val, residual);
-  for (int k = 0; k < M; ++k) {
+  for (size_t k = 0; k < M; ++k) {
     Shape<ndim> coord_big = unravel(k, rshape);
-    int idx_big = idx_big0 + dot(coord_big, rstride);
+    index_t idx_big = idx_big0 + dot(coord_big, rstride);
 
     Shape<ndim> coord_lhs = unravel(k, lhs_shape);
-    int idx_lhs = idx_lhs0 + dot(coord_lhs, lhs_stride);
+    index_t idx_lhs = idx_lhs0 + dot(coord_lhs, lhs_stride);
 
     Shape<ndim> coord_rhs = unravel(k, rhs_shape);
-    int idx_rhs = idx_rhs0 + dot(coord_rhs, rhs_stride);
+    index_t idx_rhs = idx_rhs0 + dot(coord_rhs, rhs_stride);
 
     Reducer::Reduce(val, OP1::Map(big[idx_big], OP2::Map(lhs[idx_lhs], rhs[idx_rhs])), residual);
   }
@@ -304,7 +304,7 @@ MSHADOW_XINLINE void seq_reduce_assign(const int idx, const int M, const bool ad
 }
 
 template<typename Reducer, int ndim, typename DType, typename OP1, typename OP2>
-void seq_reduce_compute(const int N, const int M, const bool addto,
+void seq_reduce_compute(const size_t N, const size_t M, const bool addto,
                         const DType *big, const DType *lhs, const DType *rhs, DType *small,
                         const Shape<ndim> big_shape, const Shape<ndim> small_shape,
                         const Shape<ndim> rshape, const Shape<ndim> rstride,
@@ -312,7 +312,7 @@ void seq_reduce_compute(const int N, const int M, const bool addto,
                         const Shape<ndim> rhs_shape, const Shape<ndim> rhs_stride,
                         const Shape<ndim>& lhs_shape0, const Shape<ndim>& rhs_shape0) {
   #pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
-  for (int idx = 0; idx < N; ++idx) {
+  for (index_t idx = 0; idx < static_cast<index_t>(N); ++idx) {
     seq_reduce_assign<Reducer, ndim, DType, OP1, OP2>(idx, M, addto, big, lhs, rhs, small,
       big_shape, lhs_shape0, rhs_shape0, small_shape, rshape, lhs_shape, rhs_shape, rstride,
       lhs_stride, rhs_stride);
@@ -326,8 +326,8 @@ void Reduce(Stream<cpu> *s, const TBlob& small, const OpReqType req,
   if (req == kNullOp) return;
   Shape<ndim> rshape, rstride;
   diff(small.shape_.get<ndim>(), big.shape_.get<ndim>(), &rshape, &rstride);
-  int N = small.shape_.Size();
-  int M = rshape.Size();
+  size_t N = small.shape_.Size();
+  size_t M = rshape.Size();
 
   Shape<ndim> lhs_shape, lhs_stride;
   diff(small.shape_.get<ndim>(), lhs.shape_.get<ndim>(), &lhs_shape, &lhs_stride);
diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.h b/src/operator/tensor/elemwise_binary_broadcast_op.h
index 391c3511712..304422038b8 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op.h
+++ b/src/operator/tensor/elemwise_binary_broadcast_op.h
@@ -190,7 +190,7 @@ namespace mxnet_op {
 template<int ndim, typename DType, typename OP>
 struct binary_broadcast_kernel {
   /*! \brief Map function for binary_broadcast_kernel */
-  MSHADOW_XINLINE static void Map(int base, int length, OpReqType req,
+  MSHADOW_XINLINE static void Map(index_t base, index_t length, OpReqType req,
                                   const Shape <ndim> &lstride, const Shape <ndim> &rstride,
                                   const Shape <ndim> &oshape, DType *lhs, DType *rhs,
                                   DType *out) {
@@ -199,7 +199,7 @@ struct binary_broadcast_kernel {
     auto ridx = static_cast<index_t>(dot(coord, rstride));
     KERNEL_ASSIGN(out[base], req, OP::Map(lhs[lidx], rhs[ridx]));
     // starts from 1 to avoid extra inc at end of loop
-    for (int i = 1; i < length; ++i) {
+    for (index_t i = 1; i < length; ++i) {
       inc(&coord, oshape, &lidx, lstride, &ridx, rstride);
       // When tuning, don't actually run the op, since it's not going to be tuned against
       // the actual op we'll eventually be using
@@ -208,7 +208,7 @@ struct binary_broadcast_kernel {
   }
 
   /*! \brief Map function for binary_broadcast_kernel */
-  MSHADOW_XINLINE static void Map(int base, int length, OpReqType req,
+  MSHADOW_XINLINE static void Map(index_t base, index_t length, OpReqType req,
                                   const Shape <ndim> &lstride, const Shape <ndim> &rstride,
                                   const Shape <ndim> &oshape, DType lhs, DType *rhs,
                                   DType *out) {
@@ -217,7 +217,7 @@ struct binary_broadcast_kernel {
     auto ridx = static_cast<index_t>(dot(coord, rstride));
     KERNEL_ASSIGN(out[base], req, OP::Map(lhs, rhs[ridx]));
     // starts from 1 to avoid extra inc at end of loop
-    for (int i = 1; i < length; ++i) {
+    for (index_t i = 1; i < length; ++i) {
       inc(&coord, oshape, &lidx, lstride, &ridx, rstride);
       // When tuning, don't actually run the op, since it's not going to be tuned against
       // the actual op we'll eventually be using
@@ -238,7 +238,7 @@ struct csr_dns_csr_broadcast_kernel {
    * \param out          ptr to the data buffer of the result csr matrix
    */
   template<typename DType, typename CType, typename RType>
-  MSHADOW_XINLINE static void Map(int row, const DType *csr_data, const CType *csr_indices,
+  MSHADOW_XINLINE static void Map(index_t row, const DType *csr_data, const CType *csr_indices,
                                   const RType *csr_indptr, const DType *dns, DType *out) {
     const nnvm::dim_t curr_row_i = csr_indptr[row];
     const nnvm::dim_t next_row_i = csr_indptr[row + 1];
@@ -257,7 +257,7 @@ struct csr_dns_csr_broadcast_kernel {
    * \param nnz         number of non-zero elements in input csr matrix
    */
   template<typename DType>
-  MSHADOW_XINLINE static void Map(int i, const DType *csr_data, const DType* scalar_ptr,
+  MSHADOW_XINLINE static void Map(index_t i, const DType *csr_data, const DType* scalar_ptr,
                                   DType *out, const nnvm::dim_t nnz) {
     const DType scale = scalar_ptr[0];
     if (i < nnz) {
@@ -269,7 +269,7 @@ struct csr_dns_csr_broadcast_kernel {
 template<int req, typename OP, bool reverse = false>
 struct csr_dns_map_kernel {
   template <typename DType, typename CType, typename RType>
-  MSHADOW_XINLINE static void Map(int row, const DType *csr_data, const CType *csr_indices,
+  MSHADOW_XINLINE static void Map(index_t row, const DType *csr_data, const CType *csr_indices,
                                   const RType *csr_indptr, DType *out, const nnvm::dim_t num_rows,
                                   const nnvm::dim_t num_cols) {
     if (row < num_rows) {
diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc
index 77236e068f8..c39418dbe41 100644
--- a/src/operator/tensor/indexing_op.cc
+++ b/src/operator/tensor/indexing_op.cc
@@ -36,7 +36,7 @@ struct TakeCPU {
   // K is the number of rows of in_data
   // i is the index of out_data
   template<typename DType, typename IType>
-  MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* in_data,
+  MSHADOW_XINLINE static void Map(index_t i, DType* out_data, const DType* in_data,
                                   const IType* idx, const size_t M, const int64_t K) {
     int64_t j = static_cast<int64_t>(idx[i]);
     if (clip) {
@@ -420,19 +420,19 @@ inline void SparseEmbeddingOpBackwardRspImpl<cpu>(const bool deterministic,
 
 template<typename DType, typename IType>
 inline typename std::enable_if<(!std::is_same<DType, mshadow::half::half_t>::value), void>::type
-GatherNDBackwardImpl(int N, int M, int K,
+GatherNDBackwardImpl(index_t N, index_t M, index_t K,
                      const mshadow::Shape<10> strides,
                      DType* out,
                      const DType* data,
                      const IType* indices,
                      mshadow::Stream<cpu> *s) {
 #pragma omp parallel for
-  for (int i = 0; i < N; i++) {
-    int offset = 0;
-    for (int j = 0; j < M; ++j) {
-      offset += strides[j] * static_cast<int>(indices[j*N + i]);
+  for (index_t i = 0; i < N; i++) {
+    index_t offset = 0;
+    for (index_t j = 0; j < M; ++j) {
+      offset += strides[j] * static_cast<index_t>(indices[j*N + i]);
     }
-    for (int j = 0; j < K; ++j) {
+    for (index_t j = 0; j < K; ++j) {
 #pragma omp atomic
       out[offset + j] += data[i * K + j];
     }
@@ -441,18 +441,18 @@ GatherNDBackwardImpl(int N, int M, int K,
 
 template<typename DType, typename IType>
 inline typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value, void>::type
-GatherNDBackwardImpl(int N, int M, int K,
+GatherNDBackwardImpl(index_t N, index_t M, index_t K,
                      const mshadow::Shape<10> strides,
                      DType* out,
                      const DType* data,
                      const IType* indices,
                      mshadow::Stream<cpu> *s) {
-  for (int i = 0; i < N; i++) {
-    int offset = 0;
-    for (int j = 0; j < M; ++j) {
-      offset += strides[j] * static_cast<int>(indices[j*N + i]);
+  for (index_t i = 0; i < N; i++) {
+    index_t offset = 0;
+    for (index_t j = 0; j < M; ++j) {
+      offset += strides[j] * static_cast<index_t>(indices[j*N + i]);
     }
-    for (int j = 0; j < K; ++j) {
+    for (index_t j = 0; j < K; ++j) {
       out[offset + j] += data[i * K + j];
     }
   }
diff --git a/src/operator/tensor/indexing_op.cu b/src/operator/tensor/indexing_op.cu
index 0d72b1815fd..bad3e5a1a6c 100644
--- a/src/operator/tensor/indexing_op.cu
+++ b/src/operator/tensor/indexing_op.cu
@@ -439,22 +439,22 @@ inline void SparseEmbeddingOpBackwardRspImpl<gpu>(const bool deterministic,
 
 struct backward_gather_nd_gpu {
   template<typename DType, typename IType>
-  MSHADOW_XINLINE static void Map(int i, int N, int M, int K,
+  MSHADOW_XINLINE static void Map(index_t i, index_t N, index_t M, index_t K,
                                   const mshadow::Shape<10> strides,
                                   DType* out, const DType* data,
                                   const IType* indices) {
-    int offset = 0;
-    for (int j = 0; j < M; ++j) {
-      offset += strides[j] * static_cast<int>(indices[j*N + i]);
+    index_t offset = 0;
+    for (index_t j = 0; j < M; ++j) {
+      offset += strides[j] * static_cast<index_t>(indices[j*N + i]);
     }
-    for (int j = 0; j < K; ++j) {
+    for (index_t j = 0; j < K; ++j) {
       atomicAdd(out + (offset + j), data[i * K + j]);
     }
   }
 };
 
 template<typename DType, typename IType>
-inline void GatherNDBackwardImpl(int N, int M, int K,
+inline void GatherNDBackwardImpl(index_t N, index_t M, index_t K,
                                  const mshadow::Shape<10> strides,
                                  DType* out,
                                  const DType* data,
diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h
index 92b6e21018e..fba331e2570 100644
--- a/src/operator/tensor/indexing_op.h
+++ b/src/operator/tensor/indexing_op.h
@@ -314,7 +314,8 @@ struct Take {
    * \param axis        axis id
    */
   template<typename DType, typename IType>
-  MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* in_data, const IType* idx,
+  MSHADOW_XINLINE static void Map(index_t i, DType* out_data, const DType* in_data,
+                                  const IType* idx,
                                   const mshadow::Shape<10> in_stride,
                                   const mshadow::Shape<10> out_stride,
                                   const int in_ndims, const int out_ndims, const int idx_ndims,
@@ -361,7 +362,7 @@ struct TakeRspKernel {
    * \param nnr         number of non-zero rows
    */
   template<typename DType, typename IType, typename RType>
-  MSHADOW_XINLINE static void Map(int i,
+  MSHADOW_XINLINE static void Map(index_t i,
                                   const IType* data,
                                   DType* out,
                                   const RType* weight_idx,
@@ -1395,15 +1396,15 @@ inline bool ScatterNDType(const nnvm::NodeAttrs& attrs,
 
 struct scatter_nd {
   template<typename DType, typename IType>
-  MSHADOW_XINLINE static void Map(int i, OpReqType req, int N, int M, int K,
+  MSHADOW_XINLINE static void Map(index_t i, OpReqType req, index_t N, index_t M, index_t K,
                                   const mshadow::Shape<10> strides,
                                   DType* out, const DType* data,
                                   const IType* indices) {
-    int offset = 0;
-    for (int j = 0; j < M; ++j) {
-      offset += strides[j] * static_cast<int>(indices[j*N + i]);
+    index_t offset = 0;
+    for (index_t j = 0; j < M; ++j) {
+      offset += strides[j] * static_cast<index_t>(indices[j*N + i]);
     }
-    for (int j = 0; j < K; ++j) {
+    for (index_t j = 0; j < K; ++j) {
       KERNEL_ASSIGN(out[offset+j], req, data[i*K + j]);
     }
   }
@@ -1416,17 +1417,18 @@ void ScatterNDForward(const nnvm::NodeAttrs& attrs,
                       const std::vector<OpReqType>& req,
                       const std::vector<TBlob>& outputs) {
   using namespace mshadow;
+  using nnvm::dim_t;
   CHECK_EQ(inputs.size(), 2U);
   CHECK_EQ(outputs.size(), 1U);
   if (req[0] == kNullOp) return;
   mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
   const TShape& oshape = outputs[0].shape_;
   const TShape& ishape = inputs[1].shape_;
-  int M = ishape[0];
-  int N = ishape.Size() / M;
-  int K = oshape.ProdShape(M, oshape.ndim());
+  dim_t M = ishape[0];
+  dim_t N = ishape.Size() / M;
+  dim_t K = oshape.ProdShape(M, oshape.ndim());
   mshadow::Shape<10> strides;
-  for (int i = M-1, stride = K; i >= 0; stride *= oshape[i], --i) strides[i] = stride;
+  for (dim_t i = M-1, stride = K; i >= 0; stride *= oshape[i], --i) strides[i] = stride;
   if (kWriteTo == req[0]) {
     Fill<true>(s, outputs[0], req[0], 0);
   }
@@ -1441,7 +1443,7 @@ void ScatterNDForward(const nnvm::NodeAttrs& attrs,
 
 template<typename DType, typename IType>
 inline typename std::enable_if<(!std::is_same<DType, mshadow::half::half_t>::value), void>::type
-GatherNDBackwardImpl(int N, int M, int K,
+GatherNDBackwardImpl(index_t N, index_t M, index_t K,
                      const mshadow::Shape<10> strides,
                      DType* out,
                      const DType* data,
@@ -1450,7 +1452,7 @@ GatherNDBackwardImpl(int N, int M, int K,
 
 template<typename DType, typename IType>
 inline typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value, void>::type
-GatherNDBackwardImpl(int N, int M, int K,
+GatherNDBackwardImpl(index_t N, index_t M, index_t K,
                      const mshadow::Shape<10> strides,
                      DType* out,
                      const DType* data,
@@ -1458,7 +1460,7 @@ GatherNDBackwardImpl(int N, int M, int K,
                      mshadow::Stream<cpu> *s);
 
 template<typename DType, typename IType>
-inline void GatherNDBackwardImpl(int N, int M, int K,
+inline void GatherNDBackwardImpl(index_t N, index_t M, index_t K,
                                  const mshadow::Shape<10> strides,
                                  DType* out,
                                  const DType* data,
@@ -1472,17 +1474,18 @@ void GatherNDBackward(const nnvm::NodeAttrs& attrs,
                       const std::vector<OpReqType>& req,
                       const std::vector<TBlob>& outputs) {
   using namespace mshadow;
+  using nnvm::dim_t;
   CHECK_EQ(inputs.size(), 2U);
   CHECK_EQ(outputs.size(), 1U);
   if (req[0] == kNullOp) return;
   mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
   const TShape& oshape = outputs[0].shape_;
   const TShape& ishape = inputs[1].shape_;
-  int M = ishape[0];
-  int N = ishape.Size() / M;
-  int K = oshape.ProdShape(M, oshape.ndim());
+  dim_t M = ishape[0];
+  dim_t N = ishape.Size() / M;
+  dim_t K = oshape.ProdShape(M, oshape.ndim());
   mshadow::Shape<10> strides;
-  for (int i = M-1, stride = K; i >= 0; stride *= oshape[i], --i) strides[i] = stride;
+  for (dim_t i = M-1, stride = K; i >= 0; stride *= oshape[i], --i) strides[i] = stride;
   if (kWriteTo == req[0]) {
     Fill<true>(s, outputs[0], req[0], 0);
   }
diff --git a/src/operator/tensor/init_op.h b/src/operator/tensor/init_op.h
index 4e52b087f10..e9e67cb1a4c 100644
--- a/src/operator/tensor/init_op.h
+++ b/src/operator/tensor/init_op.h
@@ -453,7 +453,7 @@ void EyeFill(const nnvm::NodeAttrs& attrs,
 
 struct range_fwd {
   template<typename DType>
-  MSHADOW_XINLINE static void Map(int i, int repeat, DType start, DType step,
+  MSHADOW_XINLINE static void Map(index_t i, int repeat, DType start, DType step,
                                   int req, DType* out) {
     KERNEL_ASSIGN(out[i], req, start + (i/repeat) * step);
   }
@@ -471,8 +471,8 @@ void RangeCompute(const nnvm::NodeAttrs& attrs,
   MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
       // Force unsigned params to take two's complement form on ARM to ensure consistency with x86
       // results.  Casting negative floats to unsigned types is undefined in the CPP standard.
-      auto step = std::is_signed<DType>() ? param.step : static_cast<int>(param.step);
-      auto start = std::is_signed<DType>() ? param.start : static_cast<int>(param.start);
+      auto step = std::is_signed<DType>() ? param.step : static_cast<index_t>(param.step);
+      auto start = std::is_signed<DType>() ? param.start : static_cast<index_t>(param.start);
       Kernel<range_fwd, xpu>::Launch(s,
                                      outputs[0].Size(),
                                      static_cast<int>(param.repeat),
diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h
index 9c81d87464d..3b229cf38eb 100644
--- a/src/operator/tensor/matrix_op-inl.h
+++ b/src/operator/tensor/matrix_op-inl.h
@@ -626,9 +626,9 @@ inline void GetIndexRange(const TShape& dshape,
                           const nnvm::Tuple<dmlc::optional<int>>& param_begin,
                           const nnvm::Tuple<dmlc::optional<int>>& param_end,
                           const nnvm::Tuple<dmlc::optional<int>>& param_step,
-                          common::StaticArray<int, ndim>* begin,
-                          common::StaticArray<int, ndim>* end,
-                          common::StaticArray<int, ndim>* step) {
+                          common::StaticArray<index_t, ndim>* begin,
+                          common::StaticArray<index_t, ndim>* end,
+                          common::StaticArray<index_t, ndim>* step) {
   CHECK_NE(dshape.ndim(), 0U);
   CHECK_LE(param_begin.ndim(), dshape.ndim())
     << "Slicing axis exceeds data dimensions";
@@ -646,8 +646,8 @@ inline void GetIndexRange(const TShape& dshape,
   }
 
   for (index_t i = 0; i < param_begin.ndim(); ++i) {
-    int b = 0, e = dshape[i], s = 1;
-    const int len = dshape[i];
+    index_t b = 0, e = dshape[i], s = 1;
+    const index_t len = dshape[i];
     if (param_step.ndim() != 0U) {
       const auto& opt_step_val = param_step[i];
       if (opt_step_val.has_value()) {
@@ -724,7 +724,7 @@ inline bool SliceOpShape(const nnvm::NodeAttrs& attrs,
   TShape oshape = dshape;
 
   MXNET_NDIM_SWITCH(dshape.ndim(), ndim, {
-    common::StaticArray<int, ndim> begin, end, step;
+    common::StaticArray<index_t, ndim> begin, end, step;
     GetIndexRange(dshape, param.begin, param.end, param.step, &begin, &end, &step);
     for (index_t i = 0; i < param.begin.ndim(); ++i) {
       const int b = begin[i], e = end[i], s = step[i];
@@ -743,19 +743,19 @@ template<int ndim, int req>
 struct slice_forward<ndim, req, gpu> {
   // i is the i-th row after flattening out into 2D tensor
   template<typename DType>
-  MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data,
+  MSHADOW_XINLINE static void Map(index_t i, DType* out, const DType* data,
                                   const mshadow::Shape<ndim> dshape,
                                   const mshadow::Shape<ndim> oshape,
-                                  const common::StaticArray<int, ndim> begin,
-                                  const common::StaticArray<int, ndim> step) {
-    const int data_last_dim_size = dshape[ndim-1];
-    const int out_last_dim_size = oshape[ndim-1];
-    const int step_last_dim = step[ndim-1];
-    const int begin_last_dim = begin[ndim-1];
-    const int j = i % out_last_dim_size;
-    int irow = 0;  // row id of flattend 2D data
-    int stride = 1;
-    int idx = i / out_last_dim_size;
+                                  const common::StaticArray<index_t, ndim> begin,
+                                  const common::StaticArray<index_t, ndim> step) {
+    const index_t data_last_dim_size = dshape[ndim-1];
+    const index_t out_last_dim_size = oshape[ndim-1];
+    const index_t step_last_dim = step[ndim-1];
+    const index_t begin_last_dim = begin[ndim-1];
+    const index_t j = i % out_last_dim_size;
+    index_t irow = 0;  // row id of flattend 2D data
+    index_t stride = 1;
+    index_t idx = i / out_last_dim_size;
     #pragma unroll
     for (int k = ndim - 2; k >= 0; --k) {
       irow += stride * ((idx % oshape[k]) * step[k] + begin[k]);
@@ -771,20 +771,20 @@ template<int ndim, int req>
 struct slice_forward<ndim, req, cpu> {
   // i is the i-th row after flattening out into 2D tensor
   template<typename DType>
-  MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data,
+  MSHADOW_XINLINE static void Map(index_t i, DType* out, const DType* data,
                                   const mshadow::Shape<ndim> dshape,
                                   const mshadow::Shape<ndim> oshape,
-                                  const common::StaticArray<int, ndim> begin,
-                                  const common::StaticArray<int, ndim> step) {
-    const int data_last_dim_size = dshape[ndim-1];
-    const int out_last_dim_size = oshape[ndim-1];
-    const int step_last_dim = step[ndim-1];
-    const int begin_last_dim = begin[ndim-1];
-    int out_offset = i * out_last_dim_size;
-    for (int j = 0; j < out_last_dim_size; ++j) {
-      int irow = 0;  // row id of flattend 2D data
-      int stride = 1;
-      int idx = i;
+                                  const common::StaticArray<index_t, ndim> begin,
+                                  const common::StaticArray<index_t, ndim> step) {
+    const index_t data_last_dim_size = dshape[ndim-1];
+    const index_t out_last_dim_size = oshape[ndim-1];
+    const index_t step_last_dim = step[ndim-1];
+    const index_t begin_last_dim = begin[ndim-1];
+    index_t out_offset = i * out_last_dim_size;
+    for (index_t j = 0; j < out_last_dim_size; ++j) {
+      index_t irow = 0;  // row id of flattend 2D data
+      index_t stride = 1;
+      index_t idx = i;
       #pragma unroll
       for (int k = ndim - 2; k >= 0; --k) {
         irow += stride * ((idx % oshape[k]) * step[k] + begin[k]);
@@ -813,11 +813,11 @@ void SliceOpForward(const nnvm::NodeAttrs& attrs,
   const TBlob& out = outputs[0];
   const SliceParam& param = nnvm::get<SliceParam>(attrs.parsed);
   MXNET_NDIM_SWITCH(data.ndim(), ndim, {
-    common::StaticArray<int, ndim> begin, end, step;
+    common::StaticArray<index_t, ndim> begin, end, step;
     GetIndexRange(data.shape_, param.begin, param.end, param.step, &begin, &end, &step);
     MSHADOW_TYPE_SWITCH(out.type_flag_, DType, {
       MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
-        int num_threads = out.shape_.FlatTo2D()[0];
+        size_t num_threads = out.shape_.FlatTo2D()[0];
         if (std::is_same<xpu, gpu>::value) {
           num_threads *= out.shape_.get<ndim>()[ndim - 1];
         }
@@ -836,20 +836,20 @@ template<int ndim, int req>
 struct slice_assign<ndim, req, cpu> {
   // i is the i-th row after flattening out into 2D tensor
   template<typename DType>
-  MSHADOW_XINLINE static void Map(int i, DType* out, const DType* val,
+  MSHADOW_XINLINE static void Map(index_t i, DType* out, const DType* val,
                                   const mshadow::Shape<ndim> oshape,
                                   const mshadow::Shape<ndim> vshape,
-                                  const common::StaticArray<int, ndim> begin,
-                                  const common::StaticArray<int, ndim> step) {
-    const int data_last_dim_size = oshape[ndim-1];
-    const int out_last_dim_size = vshape[ndim-1];
-    const int step_last_dim = step[ndim-1];
-    const int begin_last_dim = begin[ndim-1];
-    int offset = i * out_last_dim_size;
-    for (int j = 0; j < out_last_dim_size; ++j) {
-      int irow = 0;  // row id of flattend 2D out
-      int stride = 1;
-      int idx = i;
+                                  const common::StaticArray<index_t, ndim> begin,
+                                  const common::StaticArray<index_t, ndim> step) {
+    const index_t data_last_dim_size = oshape[ndim-1];
+    const index_t out_last_dim_size = vshape[ndim-1];
+    const index_t step_last_dim = step[ndim-1];
+    const index_t begin_last_dim = begin[ndim-1];
+    index_t offset = i * out_last_dim_size;
+    for (index_t j = 0; j < out_last_dim_size; ++j) {
+      index_t irow = 0;  // row id of flattend 2D out
+      index_t stride = 1;
+      index_t idx = i;
       #pragma unroll
       for (int k = ndim - 2; k >= 0; --k) {
         irow += stride * ((idx % vshape[k]) * step[k] + begin[k]);
@@ -866,19 +866,19 @@ template<int ndim, int req>
 struct slice_assign<ndim, req, gpu> {
   // i is the i-th row after flattening out into 2D tensor
   template<typename DType>
-  MSHADOW_XINLINE static void Map(int i, DType* out, const DType* val,
+  MSHADOW_XINLINE static void Map(index_t i, DType* out, const DType* val,
                                   const mshadow::Shape<ndim> oshape,
                                   const mshadow::Shape<ndim> vshape,
-                                  const common::StaticArray<int, ndim> begin,
-                                  const common::StaticArray<int, ndim> step) {
-    const int data_last_dim_size = oshape[ndim-1];
-    const int out_last_dim_size = vshape[ndim-1];
-    const int step_last_dim = step[ndim-1];
-    const int begin_last_dim = begin[ndim-1];
-    const int j = i % out_last_dim_size;
-    int irow = 0;  // row id of flattend 2D out
-    int stride = 1;
-    int idx = i / out_last_dim_size;
+                                  const common::StaticArray<index_t, ndim> begin,
+                                  const common::StaticArray<index_t, ndim> step) {
+    const index_t data_last_dim_size = oshape[ndim-1];
+    const index_t out_last_dim_size = vshape[ndim-1];
+    const index_t step_last_dim = step[ndim-1];
+    const index_t begin_last_dim = begin[ndim-1];
+    const index_t j = i % out_last_dim_size;
+    index_t irow = 0;  // row id of flattend 2D out
+    index_t stride = 1;
+    index_t idx = i / out_last_dim_size;
     #pragma unroll
     for (int k = ndim - 2; k >= 0; --k) {
       irow += stride * ((idx % vshape[k]) * step[k] + begin[k]);
@@ -911,7 +911,7 @@ void SliceOpBackward(const nnvm::NodeAttrs& attrs,
     LOG(FATAL) << "_slice_backward does not support kWriteInplace";
   }
   MXNET_NDIM_SWITCH(ograd.ndim(), ndim, {
-    common::StaticArray<int, ndim> begin, end, step;
+    common::StaticArray<index_t, ndim> begin, end, step;
     GetIndexRange(igrad.shape_, param.begin, param.end, param.step, &begin, &end, &step);
     MSHADOW_TYPE_SWITCH(ograd.type_flag_, DType, {
       MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
@@ -937,7 +937,7 @@ inline bool SliceAssignOpShape(const nnvm::NodeAttrs& attrs,
   TShape vshape = dshape;  // vshape is the value shape on the right hand side
   const SliceParam& param = nnvm::get<SliceParam>(attrs.parsed);
   MXNET_NDIM_SWITCH(dshape.ndim(), ndim, {
-    common::StaticArray<int, ndim> begin, end, step;
+    common::StaticArray<index_t, ndim> begin, end, step;
     GetIndexRange(dshape, param.begin, param.end, param.step, &begin, &end, &step);
     for (index_t i = 0; i < param.begin.ndim(); ++i) {
       const int b = begin[i], e = end[i], s = step[i];
@@ -975,7 +975,7 @@ void SliceAssignOpForward(const nnvm::NodeAttrs& attrs,
 
   const SliceParam& param = nnvm::get<SliceParam>(attrs.parsed);
   MXNET_NDIM_SWITCH(data.ndim(), ndim, {
-    common::StaticArray<int, ndim> begin, end, step;
+    common::StaticArray<index_t, ndim> begin, end, step;
     GetIndexRange(data.shape_, param.begin, param.end, param.step, &begin, &end, &step);
     MSHADOW_TYPE_SWITCH(out.type_flag_, DType, {
       MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
@@ -1024,20 +1024,20 @@ template<int ndim>
 struct slice_assign_scalar {
   // i is the i-th row after flattening out into 2D tensor
   template<typename DType>
-  MSHADOW_XINLINE static void Map(int i, DType* out, const DType val,
+  MSHADOW_XINLINE static void Map(index_t i, DType* out, const DType val,
                                   const OpReqType req,
                                   const mshadow::Shape<ndim> oshape,
                                   const mshadow::Shape<ndim> vshape,
-                                  const common::StaticArray<int, ndim> begin,
-                                  const common::StaticArray<int, ndim> step) {
-    const int data_last_dim_size = oshape[ndim-1];
-    const int out_last_dim_size = vshape[ndim-1];
-    const int step_last_dim = step[ndim-1];
-    const int begin_last_dim = begin[ndim-1];
-    for (int j = 0; j < out_last_dim_size; ++j) {
-      int irow = 0;  // row id of flattend 2D out
-      int stride = 1;
-      int idx = i;
+                                  const common::StaticArray<index_t, ndim> begin,
+                                  const common::StaticArray<index_t, ndim> step) {
+    const index_t data_last_dim_size = oshape[ndim-1];
+    const index_t out_last_dim_size = vshape[ndim-1];
+    const index_t step_last_dim = step[ndim-1];
+    const index_t begin_last_dim = begin[ndim-1];
+    for (index_t j = 0; j < out_last_dim_size; ++j) {
+      index_t irow = 0;  // row id of flattend 2D out
+      index_t stride = 1;
+      index_t idx = i;
       #pragma unroll
       for (int k = ndim - 2; k >= 0; --k) {
         irow += stride * ((idx % vshape[k]) * step[k] + begin[k]);
@@ -1076,7 +1076,7 @@ void SliceAssignScalarOpForward(const nnvm::NodeAttrs& attrs,
   TShape vshape = data.shape_;
   const SliceAssignScalarParam& param = nnvm::get<SliceAssignScalarParam>(attrs.parsed);
   MXNET_NDIM_SWITCH(data.ndim(), ndim, {
-    common::StaticArray<int, ndim> begin, end, step;
+    common::StaticArray<index_t, ndim> begin, end, step;
     GetIndexRange(data.shape_, param.begin, param.end, param.step, &begin, &end, &step);
     for (index_t i = 0; i < param.begin.ndim(); ++i) {
       const int b = begin[i], e = end[i], s = step[i];
@@ -1107,7 +1107,7 @@ struct SliceAxisParam : public dmlc::Parameter<SliceAxisParam> {
 };
 
 inline void GetSliceAxisParams(const SliceAxisParam& param, const TShape& ishape,
-                           int* axis, int* begin, int* end) {
+                           int* axis, index_t* begin, index_t* end) {
   *axis = param.axis;
   if (*axis < 0) {
     *axis += static_cast<int>(ishape.ndim());
@@ -1115,7 +1115,7 @@ inline void GetSliceAxisParams(const SliceAxisParam& param, const TShape& ishape
   CHECK(*axis < static_cast<int>(ishape.ndim()) && *axis >= 0) <<
     "Transformed axis must be smaller than the source ndim and larger than zero! Recieved axis=" <<
     param.axis << ", src_ndim=" << ishape.ndim() << ", transformed axis=" << *axis;
-  int axis_size = static_cast<int>(ishape[*axis]);
+  index_t axis_size = static_cast<index_t>(ishape[*axis]);
   *begin = param.begin;
   *end = -1;
   if (*begin < 0) {
@@ -1149,7 +1149,8 @@ inline bool SliceAxisShape(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(in_attrs->size(), 1U);
   CHECK_EQ(out_attrs->size(), 1U);
   TShape& ishape = (*in_attrs)[0];
-  int axis, begin, end;
+  int axis;
+  index_t begin, end;
   GetSliceAxisParams(param, ishape, &axis, &begin, &end);
   TShape shape(ishape.ndim());
   for (index_t i = 0; i < ishape.ndim(); ++i) {
@@ -1173,7 +1174,8 @@ void SliceAxis(const nnvm::NodeAttrs& attrs,
   using namespace mshadow::expr;
   const SliceAxisParam& param = nnvm::get<SliceAxisParam>(attrs.parsed);
   mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
-  int axis, begin, end;
+  int axis;
+  index_t begin, end;
   GetSliceAxisParams(param, inputs[0].shape_, &axis, &begin, &end);
   int ndim = static_cast<int>(outputs[0].ndim());
 
@@ -1207,7 +1209,8 @@ void SliceAxisGrad_(const nnvm::NodeAttrs& attrs,
   using namespace mshadow::op;
   using namespace mshadow::expr;
   mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
-  int axis, begin, end;
+  int axis;
+  index_t begin, end;
   GetSliceAxisParams(param, outputs[0].shape_, &axis, &begin, &end);
   int ndim = static_cast<int>(outputs[0].shape_.ndim());
 
@@ -1354,7 +1357,7 @@ void SliceLikeForward(const nnvm::NodeAttrs& attrs,
   SliceLikeInferRanges(ishape, from_shape, param.axes, &param_begin, &param_end, &param_step);
 
   MXNET_NDIM_SWITCH(data.ndim(), ndim, {
-    common::StaticArray<int, ndim> begin, end, step;
+    common::StaticArray<index_t, ndim> begin, end, step;
     GetIndexRange(data.shape_, param_begin, param_end, param_step, &begin, &end, &step);
     MSHADOW_TYPE_SWITCH(out.type_flag_, DType, {
       MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
@@ -1400,7 +1403,7 @@ void SliceLikeBackward(const nnvm::NodeAttrs& attrs,
   SliceLikeInferRanges(ishape, from_shape, param.axes, &param_begin, &param_end, &param_step);
 
   MXNET_NDIM_SWITCH(ograd.ndim(), ndim, {
-    common::StaticArray<int, ndim> begin, end, step;
+    common::StaticArray<index_t, ndim> begin, end, step;
     GetIndexRange(ograd.shape_, param_begin, param_end, param_step, &begin, &end, &step);
     MSHADOW_TYPE_SWITCH(ograd.type_flag_, DType, {
       MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
@@ -1429,7 +1432,7 @@ struct ClipParam : public dmlc::Parameter<ClipParam> {
 
 struct clip {
   template<typename DType>
-  MSHADOW_XINLINE static void Map(int i, DType* out, const DType* datas,
+  MSHADOW_XINLINE static void Map(index_t i, DType* out, const DType* datas,
                                   DType a_min, DType a_max) {
     DType data = datas[i];
     if (data > a_max) {
@@ -1445,7 +1448,7 @@ struct clip {
 
 struct clip_grad {
   template<typename DType>
-  MSHADOW_XINLINE static void Map(int i, DType* out, const DType* grad, const DType* datas,
+  MSHADOW_XINLINE static void Map(index_t i, DType* out, const DType* grad, const DType* datas,
                                   DType a_min, DType a_max) {
     DType data = datas[i];
     if (data > a_max) {
@@ -1934,7 +1937,7 @@ struct reverse {
   }
 #ifdef __CUDACC__
   template<typename DType>
-  __device__  static void Map(int index, index_t nreversedim, const DType *src, DType *dst,
+  __device__  static void Map(index_t index, index_t nreversedim, const DType *src, DType *dst,
                               const index_t * stride_,
                               const index_t * trailing_) {
     __shared__ index_t stride_share[REVERSE_MAX_DIM];
@@ -1949,7 +1952,7 @@ struct reverse {
   }
 #else
   template<typename DType>
-  MSHADOW_XINLINE  static void Map(int index, index_t nreversedim, const DType *src, DType *dst,
+  MSHADOW_XINLINE  static void Map(index_t index, index_t nreversedim, const DType *src, DType *dst,
                                    const index_t * stride_,
                                    const index_t * trailing_) {
     index_t new_idx = ReverseIndex(index, nreversedim, stride_, trailing_);
@@ -2141,10 +2144,10 @@ struct SqueezeParam : public dmlc::Parameter<SqueezeParam> {
 // move all the zeros to the last of the shape array
 // and keep the relative order of the non-zero values.
 // Returns the new shape size after moving all zeros to the end.
-inline uint32_t SqueezeShapeHelper(TShape* shape) {
+inline size_t SqueezeShapeHelper(TShape* shape) {
   CHECK(shape != nullptr);
-  uint32_t count = 0;
-  for (uint32_t i = 0; i < shape->ndim(); ++i) {
+  size_t count = 0;
+  for (size_t i = 0; i < shape->ndim(); ++i) {
     if ((*shape)[i] == 0) {
       ++count;
     } else {
@@ -2167,7 +2170,7 @@ inline bool SqueezeShape(const nnvm::NodeAttrs& attrs,
   if (param.axis.has_value()) {
     // preprocess axis
     TShape axes = param.axis.value();
-    for (uint32_t i = 0; i < axes.ndim(); ++i) {
+    for (size_t i = 0; i < axes.ndim(); ++i) {
       if (axes[i] < 0) {
         axes[i] += dndim;
         CHECK_GE(axes[i], 0)
@@ -2182,11 +2185,11 @@ inline bool SqueezeShape(const nnvm::NodeAttrs& attrs,
       oshape[axes[i]] = 0;
     }
   } else {
-    for (uint32_t i = 0; i < oshape.ndim(); ++i) {
+    for (size_t i = 0; i < oshape.ndim(); ++i) {
       if (oshape[i] == 1) oshape[i] = 0;
     }
   }
-  uint32_t oshape_size = SqueezeShapeHelper(&oshape);
+  size_t oshape_size = SqueezeShapeHelper(&oshape);
   if (oshape_size == 0) {  // corner case when dshape is (1, 1, 1, 1)
     oshape[0] = 1;
     oshape_size = 1;
@@ -2229,7 +2232,7 @@ inline bool DepthToSpaceOpShape(const nnvm::NodeAttrs& attrs,
 
   expected_out[0] = in_shape[0];
   expected_out[1] = in_shape[1] / (block * block);
-  uint32_t i = 2;
+  size_t i = 2;
   while (i < expected_out.ndim()) {
     expected_out[i] = in_shape[i] * block;
     ++i;
@@ -2259,9 +2262,9 @@ inline bool DepthToSpaceOpType(const nnvm::NodeAttrs& attrs,
  * \param inp_index         index within input tensor from where value is retrieved
  * \param offset_arr        array containing the linear offset of input tensor
  */
-MSHADOW_XINLINE void update_index(int index_position, int dim_size, int *idx,
-                                  int *inp_index, const int* offset_arr) {
-  int next_idx_val = *idx / dim_size;
+MSHADOW_XINLINE void update_index(index_t index_position, index_t dim_size, index_t *idx,
+                                  index_t *inp_index, const index_t* offset_arr) {  // int -> index_t so large flat offsets are not truncated (assuming 64-bit index_t)
+  index_t next_idx_val = *idx / dim_size;  // quotient = flat index over the remaining axes; remainder (next line) = coordinate on this axis
   *inp_index += (*idx - next_idx_val * dim_size) * offset_arr[index_position];
   *idx = next_idx_val;
 }
@@ -2280,9 +2283,9 @@ MSHADOW_XINLINE void update_index(int index_position, int dim_size, int *idx,
 template<int req>
 struct depth_to_space_forward {
   template<typename DType>
-  MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* in_data,
-                                  const int block, const int* size, const int* offset_arr) {
-    int inp_index = 0, idx = i, dim_size;
+  MSHADOW_XINLINE static void Map(index_t i, DType* out_data, const DType* in_data,
+                                  const int block, const index_t* size, const index_t* offset_arr) {
+    index_t inp_index = 0, idx = i, dim_size;
     dim_size = block;
     update_index(2, dim_size, &idx, &inp_index, offset_arr);
     dim_size = size[3];
@@ -2315,9 +2318,9 @@ struct depth_to_space_forward {
 template<int req>
 struct compute_offset_for_depth_to_space {
   template<typename DType>
-  MSHADOW_XINLINE static void Map(int i, DType* offset_arr, DType* size, const int block,
-                                  const int32_t size0, const int32_t size1, const int32_t size2,
-                                  const int32_t size3) {
+  MSHADOW_XINLINE static void Map(index_t i, DType* offset_arr, DType* size, const int block,
+                                  const index_t size0, const index_t size1, const index_t size2,
+                                  const index_t size3) {
     size[0] = size0;
     size[1] = size1;
     size[2] = size2;
@@ -2349,10 +2352,10 @@ void DepthToSpaceOpForward(const nnvm::NodeAttrs& attrs,
   int block = param.block_size;
 
   mshadow::Tensor<xpu, 1, char> workspace =
-    ctx.requested[0].get_space_typed<xpu, 1, char>(mshadow::Shape1(sizeof(int32_t) * 10), s);
+    ctx.requested[0].get_space_typed<xpu, 1, char>(mshadow::Shape1(sizeof(index_t) * 10), s);
   char* workspace_curr_ptr = workspace.dptr_;
-  int32_t* offset_arr = reinterpret_cast<int32_t*>(workspace_curr_ptr);
-  int32_t* size = reinterpret_cast<int32_t*>(workspace_curr_ptr + sizeof(int32_t) * 6);
+  index_t* offset_arr = reinterpret_cast<index_t*>(workspace_curr_ptr);
+  index_t* size = reinterpret_cast<index_t*>(workspace_curr_ptr + sizeof(index_t) * 6);
 
   MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
     MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
@@ -2431,9 +2434,9 @@ inline bool SpaceToDepthOpType(const nnvm::NodeAttrs& attrs,
 template<int req>
 struct space_to_depth_forward {
   template<typename DType>
-  MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* in_data, const int block,
-                                  const int* size, const int* offset_arr) {
-    int inp_index = 0, idx = i, dim_size;
+  MSHADOW_XINLINE static void Map(index_t i, DType* out_data, const DType* in_data, const int block,
+                                  const index_t* size, const index_t* offset_arr) {
+    index_t inp_index = 0, idx = i, dim_size;
     dim_size = size[3] / block;
     update_index(4, dim_size, &idx, &inp_index, offset_arr);
     dim_size = size[2] / block;
@@ -2466,9 +2469,9 @@ struct space_to_depth_forward {
 template<int req>
 struct compute_offset_for_space_to_depth {
   template<typename DType>
-  MSHADOW_XINLINE static void Map(int i, DType* offset_arr, DType* size, const int block,
-                                  const int32_t size0, const int32_t size1,
-                                  const int32_t size2, const int32_t size3) {
+  MSHADOW_XINLINE static void Map(index_t i, DType* offset_arr, DType* size, const int block,
+                                  const index_t size0, const index_t size1,
+                                  const index_t size2, const index_t size3) {
     size[0] = size0;
     size[1] = size1;
     size[2] = size2;
@@ -2500,10 +2503,10 @@ void SpaceToDepthOpForward(const nnvm::NodeAttrs& attrs,
   int block = param.block_size;
 
   mshadow::Tensor<xpu, 1, char> workspace =
-    ctx.requested[0].get_space_typed<xpu, 1, char>(mshadow::Shape1(sizeof(int32_t) * 10), s);
+    ctx.requested[0].get_space_typed<xpu, 1, char>(mshadow::Shape1(sizeof(index_t) * 10), s);
   char* workspace_curr_ptr = workspace.dptr_;
-  int32_t* offset_arr = reinterpret_cast<int32_t*>(workspace_curr_ptr);
-  int32_t* size = reinterpret_cast<int32_t*>(workspace_curr_ptr + sizeof(int32_t) * 6);
+  index_t* offset_arr = reinterpret_cast<index_t*>(workspace_curr_ptr);
+  index_t* size = reinterpret_cast<index_t*>(workspace_curr_ptr + sizeof(index_t) * 6);
 
   MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
     MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py
index 121acc174b5..a301362f2db 100644
--- a/tests/nightly/test_large_array.py
+++ b/tests/nightly/test_large_array.py
@@ -15,20 +15,126 @@
 # specific language governing permissions and limitations
 # under the License.
 
-import unittest
 import mxnet as mx
+import numpy as np
 from mxnet import gluon, nd
 
+# dimension constants: LARGE_X * SMALL_Y = 5e9 elements (> 2**32), so index math must not be 32-bit limited
+MEDIUM_X = 10000
+LARGE_X = 100000000
+LARGE_Y = 50000000
+SMALL_Y = 50
+LARGE_SIZE = LARGE_X * SMALL_Y  # 5,000,000,000 elements
+
+def test_gluon_embedding():
+    m = gluon.nn.Embedding(SMALL_Y, MEDIUM_X)
+    m.initialize()
+    a = nd.zeros((MEDIUM_X, SMALL_Y))
+    b = m(a)  # each of the MEDIUM_X * SMALL_Y indices becomes a MEDIUM_X-wide vector
+    assert b.shape == (MEDIUM_X, SMALL_Y, MEDIUM_X)
+    assert b.asnumpy().size == LARGE_SIZE  # MEDIUM_X * SMALL_Y * MEDIUM_X == LARGE_SIZE
+
+def test_ndarray_zeros():
+    a = nd.zeros(shape=(LARGE_X, SMALL_Y))
+    assert a[-1][0] == 0  # the last row starts beyond element 2**32
+    assert a.shape == (LARGE_X, SMALL_Y)
+    assert a.size == LARGE_SIZE
+
+def test_ndarray_ones():
+    a = nd.ones(shape=(LARGE_X, SMALL_Y))
+    assert a[-1][0] == 1  # the last row starts beyond element 2**32
+    assert nd.sum(a).asnumpy() == LARGE_SIZE  # every element contributes exactly 1
+
+def test_ndarray_random_uniform():
+    a = nd.random.uniform(shape=(LARGE_X, SMALL_Y))
+    assert a[-1][0] != 0  # a 0 here could mean the tail of the buffer was never written
+
+def test_ndarray_empty():
+    a = nd.empty((LARGE_X, SMALL_Y))  # only the shape is checked; contents are not read
+    assert a.shape == (LARGE_X, SMALL_Y)
+
+def test_elementwise():
+    a = nd.ones(shape=(LARGE_X, SMALL_Y))
+    b = nd.ones(shape=(LARGE_X, SMALL_Y))
+    res = a + b  # array + array
+    assert np.sum(res[-1].asnumpy() == 2) == a.shape[1]
+    res = a + 1  # array + scalar
+    assert np.sum(res[-1].asnumpy() == 2) == a.shape[1]
+    res = nd.sqrt(a + 3)  # sqrt(1 + 3) == 2
+    assert np.sum(res[-1].asnumpy() == 2) == a.shape[1]
+
+def test_reduce():
+    a = nd.ones(shape=(LARGE_X, SMALL_Y))
+    assert nd.sum(a).asnumpy() == a.shape[0] * a.shape[1]  # sum of ones == element count
+
+def test_dot():
+    a = nd.ones(shape=(LARGE_X, SMALL_Y))
+    b = nd.ones(shape=(SMALL_Y, SMALL_Y))
+    res = nd.dot(a, b)
+    assert np.sum(res[-1].asnumpy() == SMALL_Y) == b.shape[1]  # each output entry sums SMALL_Y ones
+
+def test_FullyConnected():
+    a = nd.ones(shape=(LARGE_X, SMALL_Y))
+    b = nd.ones(shape=(SMALL_Y, SMALL_Y))
+    res = nd.FullyConnected(a, b, num_hidden=b.shape[1], no_bias=True)
+    assert np.sum(res[-1].asnumpy() == SMALL_Y) == b.shape[1]  # all-ones inputs give SMALL_Y per entry
+
+def test_broadcast():
+    a = nd.ones(shape=(LARGE_X, SMALL_Y))
+    b = nd.arange(0, LARGE_X).reshape(LARGE_X, 1)  # last value is LARGE_X - 1 (float32 rounds it to 1e8)
+    res = nd.broadcast_to(b, shape=(b.shape[0], SMALL_Y))
+    assert np.sum(res[-1].asnumpy() == LARGE_X - 1) == res.shape[1]  # compared in the array's dtype
+    res = mx.nd.broadcast_like(b, a)
+    assert np.sum(res[-1].asnumpy() == LARGE_X - 1) == a.shape[1]
+
+def test_clip():
+    a = nd.arange(0, LARGE_X).reshape(LARGE_X, 1)
+    b = nd.broadcast_to(a, shape=(a.shape[0], SMALL_Y))
+    res = nd.clip(b, a_min=100, a_max=1000)  # last row holds ~LARGE_X, far above a_max
+    assert np.sum(res[-1].asnumpy() == 1000) == b.shape[1]
+
+def test_take():
+    a = nd.ones(shape=(LARGE_X, SMALL_Y))
+    idx = nd.arange(LARGE_X-1000, LARGE_X)  # the last 1000 row indices
+    res = nd.take(a, idx)
+    assert np.sum(res[-1].asnumpy() == 1) == res.shape[1]
+
+def test_slice():
+    a = nd.ones(shape=(LARGE_X, SMALL_Y))
+    res = nd.slice(a, begin=(LARGE_X-1000, 1), end=(LARGE_X, SMALL_Y))  # last 1000 rows, all columns but 0
+    assert np.sum(res[-1].asnumpy() == 1) == res.shape[1]
+
+def test_slice_assign():
+    a = nd.ones(shape=(LARGE_X, SMALL_Y))
+    a[LARGE_X-1:LARGE_X] = 1000  # overwrite only the final row
+    assert np.sum(a[-1].asnumpy() == 1000) == a.shape[1]
+
+def test_expand_dims():
+    a = nd.ones(shape=(LARGE_X, SMALL_Y))
+    res = nd.expand_dims(a, axis=1)  # inserts a size-1 axis at position 1
+    assert res.shape == (a.shape[0], 1, a.shape[1])
+
+def test_squeeze():
+    a = nd.ones(shape=(LARGE_X, SMALL_Y))
+    data = nd.expand_dims(a, axis=1)
+    res = nd.squeeze(data)  # should drop the size-1 axis added above
+    assert res.shape == a.shape
+
+def test_broadcast_div():
+    a = nd.ones(shape=(LARGE_X, SMALL_Y))
+    b = nd.ones(shape=(LARGE_X, 1)) * 2  # a (LARGE_X, 1) column, broadcast across SMALL_Y
+    res = a / b
+    assert np.sum(res[-1].asnumpy() == 0.5) == a.shape[1]  # 1 / 2 in every column of the last row
+
+def test_Dense(ctx=mx.cpu(0)):
+    data = mx.nd.ones(shape=(LARGE_Y, 100))  # LARGE_Y * 100 = 5e9 elements
+    linear = gluon.nn.Dense(100)
+    linear.initialize(ctx=ctx)
+    res = linear(data)
+    res.wait_to_read()
+    assert res.shape == (LARGE_Y, 100)
 
-class TestLargeArray(unittest.TestCase):
-    def test_ndarray2numpy(self):
-        m = gluon.nn.Embedding(14000, 128)
-        m.initialize()
-        ind = nd.zeros((700000, 128))
-        x = m(ind)
-        x.shape
-        test = x.asnumpy()
-        assert (x.shape == test.shape)
 
 if __name__ == '__main__':
-    unittest.main()
+    import nose
+    nose.runmodule()  # run this module's tests with nose (the unittest.TestCase class was removed)


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log in to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services