Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/08/20 13:57:59 UTC

[GitHub] sxjscience closed pull request #11134: [MXNET-507] Set dtype=int32 for ret_indices in ordering ops

sxjscience closed pull request #11134: [MXNET-507] Set dtype=int32 for ret_indices in ordering ops
URL: https://github.com/apache/incubator-mxnet/pull/11134
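
In short, this change adds a "dtype" parameter to the ordering operators so
the returned indices no longer have to share the input's floating-point
dtype. A minimal usage sketch, assuming an MXNet build that includes this
patch (the parameter names mirror the tests at the end of the diff):

    import mxnet as mx
    import numpy as np

    x = mx.nd.array([[1, 3, 2],
                     [4, 0, 5]])

    # Request int32 indices instead of the backward-compatible float32 default.
    vals, inds = mx.nd.topk(x, axis=1, k=2, ret_typ='both', dtype=np.int32)
    print(inds.dtype)   # expected: numpy.int32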

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h
index c3f6dc6558e..696006d4df8 100644
--- a/src/operator/mxnet_op.h
+++ b/src/operator/mxnet_op.h
@@ -176,6 +176,51 @@ inline int get_num_threads<cpu>(const int N) {
     LOG(FATAL) << "Unknown type enum " << type;            \
   }
 
+#define MXNET_NO_FLOAT16_TYPE_SWITCH(type, DType, ...)     \
+  switch (type) {                                          \
+  case mshadow::kFloat32:                                  \
+    {                                                      \
+      typedef float DType;                                 \
+      {__VA_ARGS__}                                        \
+    }                                                      \
+    break;                                                 \
+  case mshadow::kFloat64:                                  \
+    {                                                      \
+      typedef double DType;                                \
+      {__VA_ARGS__}                                        \
+    }                                                      \
+    break;                                                 \
+  case mshadow::kFloat16:                                  \
+    LOG(FATAL) << "This operation does not "               \
+                  "support float16";                       \
+    break;                                                 \
+  case mshadow::kUint8:                                    \
+    {                                                      \
+      typedef uint8_t DType;                               \
+      {__VA_ARGS__}                                        \
+    }                                                      \
+    break;                                                 \
+  case mshadow::kInt8:                                     \
+    {                                                      \
+      typedef int8_t DType;                                \
+      {__VA_ARGS__}                                        \
+    }                                                      \
+    break;                                                 \
+  case mshadow::kInt32:                                    \
+    {                                                      \
+      typedef int32_t DType;                               \
+      {__VA_ARGS__}                                        \
+    }                                                      \
+    break;                                                 \
+  case mshadow::kInt64:                                    \
+    {                                                      \
+      typedef int64_t DType;                               \
+      {__VA_ARGS__}                                        \
+    }                                                      \
+    break;                                                 \
+  default:                                                 \
+    LOG(FATAL) << "Unknown type enum " << type;            \
+  }
 
 /*!
  * \brief assign the val to out according
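
The MXNET_NO_FLOAT16_TYPE_SWITCH macro above dispatches like the existing
MSHADOW_TYPE_SWITCH but refuses float16 inputs, since the ordering kernels
are not implemented for half precision. A sketch of the user-visible effect,
assuming the fatal log surfaces as MXNetError at the Python frontend (as
LOG(FATAL) normally does):

    import mxnet as mx
    import numpy as np

    x = mx.nd.ones((8,), dtype=np.float16)
    try:
        mx.nd.topk(x, k=2).asnumpy()   # asnumpy() forces the lazy op to execute
    except mx.base.MXNetError:
        print('float16 input rejected, as the macro dictates')
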
diff --git a/src/operator/tensor/ordering_op-inl.h b/src/operator/tensor/ordering_op-inl.h
index 105ee8b90db..4e0bc423785 100644
--- a/src/operator/tensor/ordering_op-inl.h
+++ b/src/operator/tensor/ordering_op-inl.h
@@ -58,6 +58,7 @@ struct TopKParam : public dmlc::Parameter<TopKParam> {
   int k;
   int ret_typ;
   bool is_ascend;
+  int dtype;
   DMLC_DECLARE_PARAMETER(TopKParam) {
     DMLC_DECLARE_FIELD(axis).set_default(dmlc::optional<int>(-1))
     .describe("Axis along which to choose the top k indices."
@@ -79,6 +80,16 @@ struct TopKParam : public dmlc::Parameter<TopKParam> {
     DMLC_DECLARE_FIELD(is_ascend).set_default(false)
       .describe("Whether to choose k largest or k smallest elements."
                 " Top K largest elements will be chosen if set to false.");
+    DMLC_DECLARE_FIELD(dtype)
+    .add_enum("uint8", mshadow::kUint8)
+    .add_enum("int32", mshadow::kInt32)
+    .add_enum("float16", mshadow::kFloat16)
+    .add_enum("float32", mshadow::kFloat32)
+    .add_enum("float64", mshadow::kFloat64)
+    .set_default(mshadow::kFloat32)
+    .describe("DType of the output indices when ret_typ is \"indices\" or \"both\". "
+              "An error will be raised if the selected data type cannot precisely represent the "
+              "indices.");
   }
 };
 
@@ -97,12 +108,23 @@ struct SortParam : public dmlc::Parameter<SortParam> {
 struct ArgSortParam : public dmlc::Parameter<ArgSortParam> {
   dmlc::optional<int> axis;
   bool is_ascend;
+  int dtype;
   DMLC_DECLARE_PARAMETER(ArgSortParam) {
     DMLC_DECLARE_FIELD(axis).set_default(dmlc::optional<int>(-1))
     .describe("Axis along which to sort the input tensor."
               " If not given, the flattened array is used. Default is -1.");
     DMLC_DECLARE_FIELD(is_ascend).set_default(true)
       .describe("Whether to sort in ascending or descending order.");
+    DMLC_DECLARE_FIELD(dtype)
+    .add_enum("uint8", mshadow::kUint8)
+    .add_enum("int32", mshadow::kInt32)
+    .add_enum("float16", mshadow::kFloat16)
+    .add_enum("float32", mshadow::kFloat32)
+    .add_enum("float64", mshadow::kFloat64)
+    .set_default(mshadow::kFloat32)
+    .describe("DType of the output indices. It is only valid when ret_typ is \"indices\" or"
+              " \"both\". An error will be raised if the selected data type cannot precisely "
+              "represent the indices.");
   }
 };
 
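The dtype field added to ArgSortParam defaults to float32, so existing code
keeps its output type; int32 indices are opt-in. A sketch of the difference,
assuming a build with this patch:

    import mxnet as mx
    import numpy as np

    x = mx.nd.array([3, 1, 2])
    print(mx.nd.argsort(x).dtype)                   # float32, the compatible default
    print(mx.nd.argsort(x, dtype=np.int32).dtype)   # int32, the new opt-in behavior
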
@@ -154,19 +176,12 @@ inline void ParseTopKParam(const TShape& src_shape, const TopKParam& param, TSha
 
 using namespace mshadow;
 
-template<typename xpu>
-void TopKSort(const Tensor<xpu, 1, real_t>& dat,
-              const Tensor<xpu, 1, int>& ind,
-              const Tensor<xpu, 1, char>& work,
-              int K, int N, bool is_ascend,
-              Stream<xpu> *s);
-
-template<>
-MSHADOW_FORCE_INLINE void TopKSort<cpu>(const Tensor<cpu, 1, real_t>& dat,
-                                        const Tensor<cpu, 1, int>& ind,
-                                        const Tensor<cpu, 1, char>& work,
-                                        int K, int N, bool is_ascend,
-                                        Stream<cpu> *s) {
+template<typename DType>
+MSHADOW_FORCE_INLINE void TopKSort(const Tensor<cpu, 1, DType>& dat,
+                                   const Tensor<cpu, 1, int>& ind,
+                                   const Tensor<cpu, 1, char>& work,
+                                   int K, int N, bool is_ascend,
+                                   Stream<cpu> *s) {
   // Use full sort when K is relatively large.
   const bool full_sort(K*8 > N);
   // Batch size.
@@ -174,7 +189,7 @@ MSHADOW_FORCE_INLINE void TopKSort<cpu>(const Tensor<cpu, 1, real_t>& dat,
   const int omp_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount());
   #pragma omp parallel for num_threads(omp_threads)
   for (int i = 0; i < M; ++i) {
-    real_t *vals = dat.dptr_;
+    DType *vals = dat.dptr_;
     int *indices = ind.dptr_+i*N;
     if (is_ascend) {
       if (full_sort) {
@@ -193,7 +208,7 @@ MSHADOW_FORCE_INLINE void TopKSort<cpu>(const Tensor<cpu, 1, real_t>& dat,
                           [&](const int& i1, const int& i2){ return vals[i1] > vals[i2]; });
       }
     }
-    real_t *buff = reinterpret_cast<real_t*>(work.dptr_)+i*K;
+    DType *buff = reinterpret_cast<DType*>(work.dptr_)+i*K;
     for (int j = 0; j < K; ++j) {
       buff[j] = vals[indices[j]];
     }
@@ -285,12 +300,12 @@ __global__ void PartialSortSmallK(int K, int N, DType *val, int *ind, bool is_as
   }
 }
 
-template<>
-MSHADOW_FORCE_INLINE void TopKSort<gpu>(const Tensor<gpu, 1, real_t>& dat,
-                                        const Tensor<gpu, 1, int>& ind,
-                                        const Tensor<gpu, 1, char>& work,
-                                        int K, int N, bool is_ascend,
-                                        Stream<gpu> *s) {
+template<typename DType>
+MSHADOW_FORCE_INLINE void TopKSort(const Tensor<gpu, 1, DType>& dat,
+                                   const Tensor<gpu, 1, int>& ind,
+                                   const Tensor<gpu, 1, char>& work,
+                                   int K, int N, bool is_ascend,
+                                   Stream<gpu> *s) {
   // Use full sort for all but very small K for which we
   // can do a partial sort entirely within shared memory.
   const bool full_sort(K > 5);
@@ -311,7 +326,7 @@ MSHADOW_FORCE_INLINE void TopKSort<gpu>(const Tensor<gpu, 1, real_t>& dat,
     }
   } else {
     const int nthreads(mshadow::cuda::kBaseThreadNum);
-    PartialSortSmallK<<<M, nthreads, nthreads*K*(sizeof(int)+sizeof(real_t)),
+    PartialSortSmallK<<<M, nthreads, nthreads*K*(sizeof(int)+sizeof(DType)),
                         mshadow::Stream<gpu>::GetStream(s)>>>
                         (K, N, dat.dptr_, ind.dptr_, is_ascend);
   }
@@ -331,25 +346,25 @@ MSHADOW_FORCE_INLINE void TopKSort<gpu>(const Tensor<gpu, 1, real_t>& dat,
    * \param k the K elements to keep
    * \param param the topk parameters
    * \tparam xpu the device type.
+   * \tparam DType type of the output value/mask.
+   * \tparam IDType type of the output indices.
    */
-template<typename xpu>
-void TopKImpl(RunContext ctx,
-              Resource resource,
+template<typename xpu, typename DType, typename IDType>
+void TopKImpl(const RunContext &ctx,
+              const Resource &resource,
+              const std::vector<OpReqType>& req,
               const TBlob& src,
               const std::vector<TBlob>& ret,
               const TopKParam& param) {
   using namespace mshadow;
   using namespace mshadow::expr;
-  for (auto ret_ele : ret) {
-    CHECK_EQ(ret_ele.type_flag_, src.type_flag_);
-  }
   // 1. Parse and initialize information
   Stream<xpu> *s = ctx.get_stream<xpu>();
   Tensor<xpu, 1, char> workspace;
   Tensor<xpu, 1, char> temp_workspace;
-  Tensor<xpu, 1, real_t> sorted_dat;
+  Tensor<xpu, 1, DType> sorted_dat;
   Tensor<xpu, 1, int> indices, sel_indices;
-  Tensor<xpu, 2, real_t> mask_val;
+  Tensor<xpu, 2, DType> mask_val;
   int batch_size, element_num;  // number of batches + the size of each batch
   int axis = 0;
   bool do_transpose = false;
@@ -358,25 +373,29 @@ void TopKImpl(RunContext ctx,
   TShape target_shape;
   ParseTopKParam(src.shape_, param,
                  &target_shape, &batch_size, &element_num, &axis, &k, &do_transpose, &is_ascend);
-  Tensor<xpu, 3, real_t> dat = src.FlatTo3D<xpu, real_t>(axis, axis, s);
+  CHECK_LE(element_num, mxnet::common::MaxIntegerValue<IDType>())
+    << "'IDType' does not have a sufficient precision to represent the indices of the input array. "
+    << "The total element_num is " << element_num << ", but the selected IDType can only represent "
+    << mxnet::common::MaxIntegerValue<IDType>() << " elements";
+  Tensor<xpu, 3, DType> dat = src.FlatTo3D<xpu, DType>(axis, axis, s);
   size_t temp_size = 0;
   // Temp space needed by the gpu-based full sorts.
   temp_size = std::max(temp_size, mxnet::op::SortByKeyWorkspaceSize<int, int, xpu>(src.Size()));
-  temp_size = std::max(temp_size, mxnet::op::SortByKeyWorkspaceSize<int, real_t, xpu>(src.Size()));
-  temp_size = std::max(temp_size, mxnet::op::SortByKeyWorkspaceSize<real_t, int, xpu>(src.Size()));
+  temp_size = std::max(temp_size, mxnet::op::SortByKeyWorkspaceSize<int, DType, xpu>(src.Size()));
+  temp_size = std::max(temp_size, mxnet::op::SortByKeyWorkspaceSize<DType, int, xpu>(src.Size()));
   // Additional temp space for gpu full sorts for batch ids.
   temp_size += sizeof(int) * src.Size();
   // Temp space for cpu sorts.
-  temp_size = std::max(temp_size, sizeof(real_t) * src.Size());
-  size_t workspace_size = temp_size + sizeof(real_t) * src.Size() + sizeof(int) * src.Size();
+  temp_size = std::max(temp_size, sizeof(DType) * src.Size());
+  size_t workspace_size = temp_size + sizeof(DType) * src.Size() + sizeof(int) * src.Size();
   if (param.ret_typ == topk_enum::kReturnMask) {
-    workspace_size += sizeof(int) * batch_size * k + sizeof(real_t) * batch_size * k;
+    workspace_size += sizeof(int) * batch_size * k + sizeof(DType) * batch_size * k;
   }
   workspace = resource.get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
   char* workspace_curr_ptr = workspace.dptr_;
-  sorted_dat = Tensor<xpu, 1, real_t>(reinterpret_cast<real_t*>(workspace_curr_ptr),
+  sorted_dat = Tensor<xpu, 1, DType>(reinterpret_cast<DType*>(workspace_curr_ptr),
                                       Shape1(src.Size()), s);  // contain sorted dat
-  workspace_curr_ptr += sizeof(real_t) * src.Size();
+  workspace_curr_ptr += sizeof(DType) * src.Size();
   indices = Tensor<xpu, 1, int>(reinterpret_cast<int*>(workspace_curr_ptr),
                                 Shape1(src.Size()), s);  // indices in the original matrix
   workspace_curr_ptr += sizeof(int) * src.Size();
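
The CHECK_LE guard in the hunk above rejects index dtypes that cannot
represent every position exactly. Assuming MaxIntegerValue for float16 is
2^11 = 2048 (the largest bound below which float16 represents all integers
exactly), the behavior would look like this:

    import mxnet as mx
    import numpy as np

    big = mx.nd.arange(0, 5000)   # more positions than float16 can index exactly
    try:
        mx.nd.argsort(big, dtype=np.float16).asnumpy()
    except mx.base.MXNetError:
        print('float16 cannot precisely represent indices up to 4999')

    small = mx.nd.arange(0, 100)  # well within float16's exact-integer range
    print(mx.nd.argsort(small, dtype=np.float16).dtype)   # float16
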
@@ -394,10 +413,10 @@ void TopKImpl(RunContext ctx,
     sel_indices = Tensor<xpu, 1, int>(reinterpret_cast<int*>(workspace_curr_ptr),
                                       Shape1(batch_size * k), s);
     workspace_curr_ptr += sizeof(int) * batch_size * k;
-    mask_val = Tensor<xpu, 2, real_t>(reinterpret_cast<real_t*>(workspace_curr_ptr),
+    mask_val = Tensor<xpu, 2, DType>(reinterpret_cast<DType*>(workspace_curr_ptr),
                                       Shape2(batch_size * k, 1), s);
-    workspace_curr_ptr += sizeof(real_t) * batch_size * k;
-    mask_val = scalar<real_t>(1);
+    workspace_curr_ptr += sizeof(DType) * batch_size * k;
+    mask_val = scalar<DType>(1);
     CHECK_EQ(sel_indices.CheckContiguous(), true);
     CHECK_EQ(mask_val.CheckContiguous(), true);
   }
@@ -411,9 +430,9 @@ void TopKImpl(RunContext ctx,
 
   // 3. Assign results to the ret blob
   if (param.ret_typ == topk_enum::kReturnMask) {
-    Tensor<xpu, 2, real_t> ret_mask =
-      ret[0].get_with_shape<xpu, 2, real_t>(Shape2(ret[0].Size(), 1), s);
-    ret_mask = scalar<real_t>(0);
+    Tensor<xpu, 2, DType> ret_mask =
+      ret[0].get_with_shape<xpu, 2, DType>(Shape2(ret[0].Size(), 1), s);
+    ret_mask = scalar<DType>(0);
     sel_indices = reshape(slice<1>(
                               inplace_reshape(indices,
                                               Shape2(batch_size,
@@ -425,49 +444,56 @@ void TopKImpl(RunContext ctx,
       sel_indices = transpose_indices(sel_indices, Shape3(src_shape[0], src_shape[2], src_shape[1]),
                                       Shape3(0, 2, 1));
     }
-    IndexFill(ret_mask, sel_indices, mask_val);
+    if (req[0] == kNullOp) {
+      return;
+    } else if (req[0] == kWriteTo) {
+      IndexFill(ret_mask, sel_indices, mask_val);
+    } else {
+      LOG(FATAL) << "req=" << req[0] << " is not supported yet.";
+    }
   } else if (param.ret_typ == topk_enum::kReturnIndices) {
     indices = F<mshadow_op::mod>(indices, element_num);
     if (do_transpose) {
-      Tensor<xpu, 3, real_t> ret_indices = ret[0].FlatTo3D<xpu, real_t>(axis, axis, s);
-      ret_indices = tcast<real_t>(transpose(
+      Tensor<xpu, 3, IDType> ret_indices = ret[0].FlatTo3D<xpu, IDType>(axis, axis, s);
+      Assign(ret_indices, req[0], tcast<IDType>(transpose(
                       slice<2>(inplace_reshape(indices,
                                                Shape3(ret_indices.shape_[0],
                                                       ret_indices.shape_[2],
                                                       element_num)),
                                0, k),
-                      Shape3(0, 2, 1)));
+                      Shape3(0, 2, 1))));
     } else {
-      Tensor<xpu, 2, real_t> ret_indices =
-        ret[0].get_with_shape<xpu, 2, real_t>(Shape2(batch_size, k), s);
-      ret_indices = tcast<real_t>(slice<1>(
-                      inplace_reshape(indices, Shape2(batch_size, element_num)), 0, k));
+      Tensor<xpu, 2, IDType> ret_indices =
+        ret[0].get_with_shape<xpu, 2, IDType>(Shape2(batch_size, k), s);
+      Assign(ret_indices, req[0], tcast<IDType>(slice<1>(
+                      inplace_reshape(indices, Shape2(batch_size, element_num)), 0, k)));
     }
   } else {
     indices = F<mshadow_op::mod>(indices, element_num);
     if (do_transpose) {
-      Tensor<xpu, 3, real_t> ret_value = ret[0].FlatTo3D<xpu, real_t>(axis, axis, s);
-      Tensor<xpu, 3, real_t> ret_indices = ret[1].FlatTo3D<xpu, real_t>(axis, axis, s);
-      ret_value = transpose(
+      Tensor<xpu, 3, DType> ret_value = ret[0].FlatTo3D<xpu, DType>(axis, axis, s);
+      Tensor<xpu, 3, IDType> ret_indices = ret[1].FlatTo3D<xpu, IDType>(axis, axis, s);
+      Assign(ret_value, req[0], transpose(
                    slice<2>(inplace_reshape(sorted_dat,
                                     Shape3(ret_value.shape_[0], ret_value.shape_[2], element_num)),
                             0, k),
-                   Shape3(0, 2, 1));
-      ret_indices = tcast<real_t>(transpose(
+                   Shape3(0, 2, 1)));
+      Assign(ret_indices, req[1], tcast<IDType>(transpose(
                       slice<2>(inplace_reshape(indices,
                                                Shape3(ret_indices.shape_[0],
                                                       ret_indices.shape_[2],
                                                       element_num)),
                                0, k),
-                      Shape3(0, 2, 1)));
+                      Shape3(0, 2, 1))));
     } else {
-      Tensor<xpu, 2, real_t> ret_value =
-        ret[0].get_with_shape<xpu, 2, real_t>(Shape2(batch_size, k), s);
-      Tensor<xpu, 2, real_t> ret_indices =
-        ret[1].get_with_shape<xpu, 2, real_t>(Shape2(batch_size, k), s);
-      ret_value = slice<1>(inplace_reshape(sorted_dat, Shape2(batch_size, element_num)), 0, k);
-      ret_indices = tcast<real_t>(slice<1>(
-                      inplace_reshape(indices, Shape2(batch_size, element_num)), 0, k));
+      Tensor<xpu, 2, DType> ret_value =
+        ret[0].get_with_shape<xpu, 2, DType>(Shape2(batch_size, k), s);
+      Tensor<xpu, 2, IDType> ret_indices =
+        ret[1].get_with_shape<xpu, 2, IDType>(Shape2(batch_size, k), s);
+      Assign(ret_value, req[0],
+             slice<1>(inplace_reshape(sorted_dat, Shape2(batch_size, element_num)), 0, k));
+      Assign(ret_indices, req[1], tcast<IDType>(slice<1>(
+                      inplace_reshape(indices, Shape2(batch_size, element_num)), 0, k)));
     }
   }
 }
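
The hunks above replace direct assignments with Assign(dst, req, expr),
which dispatches on the write request instead of always overwriting. A rough
numpy analogue of that dispatch (a hypothetical helper, not MXNet API):

    import numpy as np

    def assign(out, req, val):
        # 'null' (kNullOp): the caller does not want this output; skip it.
        if req == 'null':
            return
        # 'write' (kWriteTo): overwrite the destination.
        elif req == 'write':
            out[...] = val
        # 'add' (kAddTo): accumulate, e.g. for gradient aggregation.
        elif req == 'add':
            out += val
        else:
            raise ValueError('unsupported request: %s' % req)

Note that the mask path above supports only the null and write requests and
LOG(FATAL)s otherwise.
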
@@ -479,9 +505,17 @@ void TopK(const nnvm::NodeAttrs& attrs,
           const std::vector<OpReqType>& req,
           const std::vector<TBlob>& outputs) {
   const TopKParam& param = nnvm::get<TopKParam>(attrs.parsed);
-  // TODO(sxjscience) We can support inplace in the future
-  CHECK_EQ(req[0], kWriteTo) << "TopK does not support inplace";
-  TopKImpl<xpu>(ctx.run_ctx, ctx.requested[0], inputs[0], outputs, param);
+  if (param.ret_typ == topk_enum::kReturnIndices || param.ret_typ == topk_enum::kReturnBoth) {
+    MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+      MSHADOW_TYPE_SWITCH(param.dtype, IDType, {
+        TopKImpl<xpu, DType, IDType>(ctx.run_ctx, ctx.requested[0], req, inputs[0], outputs, param);
+      })
+    });
+  } else {
+    MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+      TopKImpl<xpu, DType, int>(ctx.run_ctx, ctx.requested[0], req, inputs[0], outputs, param);
+    });
+  }
 }
 
 template<typename xpu>
@@ -491,13 +525,14 @@ void Sort(const nnvm::NodeAttrs& attrs,
           const std::vector<OpReqType>& req,
           const std::vector<TBlob>& outputs) {
   const SortParam& param = nnvm::get<SortParam>(attrs.parsed);
-  CHECK_EQ(req[0], kWriteTo) << "Sort does not support inplace";
   TopKParam topk_param;
   topk_param.axis = param.axis;
   topk_param.is_ascend = param.is_ascend;
   topk_param.k = 0;
   topk_param.ret_typ = topk_enum::kReturnValue;
-  TopKImpl<xpu>(ctx.run_ctx, ctx.requested[0], inputs[0], outputs, topk_param);
+  MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    TopKImpl<xpu, DType, int>(ctx.run_ctx, ctx.requested[0], req, inputs[0], outputs, topk_param);
+  });
 }
 
 template<typename xpu>
@@ -507,26 +542,30 @@ void ArgSort(const nnvm::NodeAttrs& attrs,
              const std::vector<OpReqType>& req,
              const std::vector<TBlob>& outputs) {
   const ArgSortParam& param = nnvm::get<ArgSortParam>(attrs.parsed);
-  CHECK_EQ(req[0], kWriteTo) << "ArgSort does not support inplace";
   TopKParam topk_param;
   topk_param.axis = param.axis;
   topk_param.is_ascend = param.is_ascend;
   topk_param.k = 0;
+  topk_param.dtype = param.dtype;
   topk_param.ret_typ = topk_enum::kReturnIndices;
-  TopKImpl<xpu>(ctx.run_ctx, ctx.requested[0], inputs[0], outputs, topk_param);
+  MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    MSHADOW_TYPE_SWITCH(param.dtype, IDType, {
+      TopKImpl<xpu, DType, IDType>(ctx.run_ctx,
+                                   ctx.requested[0], req, inputs[0], outputs, topk_param);
+    });
+  });
 }
 
-template<typename xpu>
-void TopKBackward_(const nnvm::NodeAttrs& attrs,
-                  const OpContext& ctx,
-                  const std::vector<TBlob>& inputs,
-                  const std::vector<OpReqType>& req,
-                  const std::vector<TBlob>& outputs) {
+template<typename xpu, typename DType, typename IDType>
+void TopKBackwardImpl(const OpContext &ctx,
+                      const std::vector<TBlob>& inputs,
+                      const std::vector<OpReqType>& req,
+                      const std::vector<TBlob>& outputs,
+                      const TopKParam& param) {
   CHECK_NE(req[0], kWriteInplace);
   using namespace mshadow;
   using namespace mshadow::expr;
   Stream<xpu> *s = ctx.run_ctx.get_stream<xpu>();
-  const TopKParam& param = nnvm::get<TopKParam>(attrs.parsed);
   CHECK(param.ret_typ == topk_enum::kReturnValue || param.ret_typ == topk_enum::kReturnBoth);
   int batch_size, element_num;  // number of batches + the size of each batch
   int axis = 0;
@@ -536,23 +575,28 @@ void TopKBackward_(const nnvm::NodeAttrs& attrs,
   TShape target_shape;
   ParseTopKParam(outputs[0].shape_, param,
                  &target_shape, &batch_size, &element_num, &axis, &k, &do_transpose, &is_ascend);
-  Tensor<xpu, 1, real_t> workspace =
-    ctx.requested[0].get_space_typed<xpu, 1, real_t>(Shape1(batch_size * k * 2 + batch_size), s);
-  Tensor<xpu, 1, real_t> sel_indices =
-    Tensor<xpu, 1, real_t>(workspace.dptr_, Shape1(batch_size * k), s);
-  Tensor<xpu, 1, real_t> batch_shift =
-    Tensor<xpu, 1, real_t>(workspace.dptr_ + batch_size * k, Shape1(batch_size), s);
-  Tensor<xpu, 1, real_t> dummy_index =
-    Tensor<xpu, 1, real_t>(workspace.dptr_ + batch_size * k + batch_size,
+  CHECK_LE(element_num, mxnet::common::MaxIntegerValue<IDType>())
+    << "'IDType' does not have a sufficient precision to represent the indices of the input array. "
+    << "The total element_num is " << element_num << ", but the selected IDType can only represent "
+    << mxnet::common::MaxIntegerValue<IDType>() << " elements";
+  Tensor<xpu, 1, int> workspace =
+    ctx.requested[0].get_space_typed<xpu, 1, int>(Shape1(batch_size * k * 2 + batch_size), s);
+  Tensor<xpu, 1, int> sel_indices =
+    Tensor<xpu, 1, int>(workspace.dptr_, Shape1(batch_size * k), s);
+  Tensor<xpu, 1, int> batch_shift =
+    Tensor<xpu, 1, int>(workspace.dptr_ + batch_size * k, Shape1(batch_size), s);
+  Tensor<xpu, 1, int> dummy_index =
+    Tensor<xpu, 1, int>(workspace.dptr_ + batch_size * k + batch_size,
                            Shape1(batch_size * k), s);
-  Tensor<xpu, 2, real_t> out_grad =
-    inputs[0].get_with_shape<xpu, 2, real_t>(Shape2(inputs[0].shape_.Size(), 1), s);
-  Tensor<xpu, 2, real_t> in_grad =
-    outputs[0].get_with_shape<xpu, 2, real_t>(Shape2(outputs[0].shape_.Size(), 1), s);
-  mxnet_op::Kernel<range_fwd, xpu>::Launch(s, batch_size, 1, 0.0f,
-    static_cast<real_t>(element_num), kWriteTo, batch_shift.dptr_);
+
+  Tensor<xpu, 2, DType> out_grad =
+    inputs[0].get_with_shape<xpu, 2, DType>(Shape2(inputs[0].shape_.Size(), 1), s);
+  Tensor<xpu, 2, DType> in_grad =
+    outputs[0].get_with_shape<xpu, 2, DType>(Shape2(outputs[0].shape_.Size(), 1), s);
+  mxnet_op::Kernel<range_fwd, xpu>::Launch(s, batch_size, 1, 0, element_num, kWriteTo,
+                                           batch_shift.dptr_);
   if (do_transpose) {
-    Tensor<xpu, 1, real_t> indices = inputs[2].FlatTo1D<xpu, real_t>(s);
+    Tensor<xpu, 1, IDType> indices = inputs[2].FlatTo1D<xpu, IDType>(s);
     TShape src_shape = outputs[0].shape_.FlatTo3D(axis);
     sel_indices = reshape(transpose(
                             broadcast_to(inplace_reshape(batch_shift,
@@ -560,26 +604,26 @@ void TopKBackward_(const nnvm::NodeAttrs& attrs,
                                          TShape(Shape3(src_shape[0], src_shape[2], k))),
                             Shape3(0, 2, 1)),
                           Shape1(batch_size * k));
-    sel_indices += indices;
+    sel_indices += tcast<int>(indices);
     sel_indices = transpose_indices(sel_indices, Shape3(src_shape[0], src_shape[2], src_shape[1]),
                                     Shape3(0, 2, 1));
   } else {
-    Tensor<xpu, 2, real_t> indices =
-      inputs[2].get_with_shape<xpu, 2, real_t>(Shape2(batch_size, k), s);
-    sel_indices = reshape(indices +
+    Tensor<xpu, 2, IDType> indices =
+      inputs[2].get_with_shape<xpu, 2, IDType>(Shape2(batch_size, k), s);
+    sel_indices = reshape(tcast<int>(indices) +
                           broadcast_to(inplace_reshape(batch_shift, Shape2(batch_size, 1)),
                                        TShape(Shape2(batch_size, k))),
                           Shape1(batch_size * k));
   }
   CHECK_EQ(sel_indices.CheckContiguous(), true);
   if (kWriteTo == req[0]) {
-    in_grad = scalar<real_t>(0);
+    in_grad = scalar<DType>(0);
     IndexFill(in_grad, sel_indices, out_grad);
   } else if (kAddTo == req[0]) {
     // TODO(sxjscience) We can use AddTakeGrad in the future.
     // However, the current implementation of AddTakeGrad is not so efficient.
-    mxnet_op::Kernel<range_fwd, xpu>::Launch(s, sel_indices.shape_.Size(), 1, 0.0f,
-      1.0f, kWriteTo, dummy_index.dptr_);
+    mxnet_op::Kernel<range_fwd, xpu>::Launch(s, sel_indices.shape_.Size(), 1, 0, 1, kWriteTo,
+                                             dummy_index.dptr_);
     mxnet::op::AddTakeGradLargeBatch(in_grad, sel_indices, dummy_index, out_grad);
   } else if (kNullOp == req[0]) {
     return;
@@ -588,6 +632,28 @@ void TopKBackward_(const nnvm::NodeAttrs& attrs,
   }
 }
 
+template<typename xpu>
+void TopKBackward_(const nnvm::NodeAttrs& attrs,
+                   const OpContext& ctx,
+                   const std::vector<TBlob>& inputs,
+                   const std::vector<OpReqType>& req,
+                   const std::vector<TBlob>& outputs) {
+  const TopKParam& param = nnvm::get<TopKParam>(attrs.parsed);
+  if (param.ret_typ == topk_enum::kReturnBoth) {
+    MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+      MSHADOW_TYPE_SWITCH(param.dtype, IDType, {
+        TopKBackwardImpl<xpu, DType, IDType>(ctx, inputs, req, outputs, param);
+      });
+    });
+  } else if (param.ret_typ == topk_enum::kReturnValue) {
+    MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+      TopKBackwardImpl<xpu, DType, int>(ctx, inputs, req, outputs, param);
+    });
+  } else {
+    LOG(FATAL) << "Not Implemented";
+  }
+}
+
 inline uint32_t TopKNumOutputs(const NodeAttrs& attrs) {
   const TopKParam& param = nnvm::get<TopKParam>(attrs.parsed);
   if (param.ret_typ == topk_enum::kReturnIndices ||
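
Putting the backward hunks above together: TopKBackwardImpl zeroes (or
accumulates into) the input gradient and routes each kept value's gradient
back to the slot it came from. A numpy sketch for the flat, 1-D case
(indices are unique, so scatter-add and the kWriteTo index-fill coincide):

    import numpy as np

    def topk_backward(out_grad, indices, in_size):
        # Route each output gradient back to its source position.
        in_grad = np.zeros(in_size, dtype=out_grad.dtype)
        np.add.at(in_grad, indices.astype(np.int64), out_grad)
        return in_grad

    print(topk_backward(np.ones(3), np.array([7, 2, 5]), 10))
    # 1.0 at positions 2, 5 and 7; zeros elsewhere
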
@@ -610,8 +676,36 @@ inline uint32_t TopKNumVisibleOutputs(const NodeAttrs& attrs) {
 inline bool TopKType(const nnvm::NodeAttrs& attrs,
                      std::vector<int> *in_attrs,
                      std::vector<int> *out_attrs) {
-  return ElemwiseAttr<int, type_is_none, type_assign, true, type_string>(
-    attrs, in_attrs, out_attrs, -1);
+  const TopKParam& param = nnvm::get<TopKParam>(attrs.parsed);
+  int data_type = -1;
+  size_t in_size = in_attrs->size();
+  size_t out_size = out_attrs->size();
+  CHECK_EQ(in_size, 1);
+  CHECK(out_size == 1 || out_size == 2);
+  if (out_size > 1) {
+    if (param.ret_typ == topk_enum::kReturnValue) {
+      CHECK(type_assign(&(*out_attrs)[1], mshadow::kInt32))
+        << "Failed to set the type of ret_indices.";
+    } else {
+      CHECK(type_assign(&(*out_attrs)[1], param.dtype))
+        << "Failed to set the type of ret_indices.";
+    }
+  }
+  if (param.ret_typ == topk_enum::kReturnIndices) {
+    CHECK(type_assign(&(*out_attrs)[0], param.dtype))
+            << "Failed to set the type of ret_indices.";
+  } else {
+    CHECK(type_assign(&data_type, (*in_attrs)[0])) << "Incompatible dtype of input, in_attrs[0]="
+                                                   << (*in_attrs)[0];
+    CHECK(type_assign(&data_type, (*out_attrs)[0])) << "Incompatible dtype of output, out_attrs[0]="
+                                                    << (*out_attrs)[0];
+    CHECK(type_assign(&(*in_attrs)[0], data_type)) << "Incompatible dtype of input, in_attrs[0]="
+                                                   << (*in_attrs)[0];
+    CHECK(type_assign(&(*out_attrs)[0], data_type)) << "Incompatible dtype of output, out_attrs[0]="
+                                                    << (*out_attrs)[0];
+    if (data_type == -1) return false;
+  }
+  return true;
 }
 
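TopKType above threads types through type_assign. A Python sketch of its
contract as relied on here (my reading, not the nnvm source): -1 means
"unknown", and assignment succeeds when the slot is unknown, the incoming
type is unknown, or the two agree.

    def type_assign(slot, dtype):
        # slot is a one-element list standing in for the C++ int* out-param.
        if slot[0] == -1:
            slot[0] = dtype
            return True
        return dtype == -1 or slot[0] == dtype

    out_attr = [-1]
    assert type_assign(out_attr, 4)        # unknown slot: takes the type
    assert type_assign(out_attr, 4)        # same type again: fine
    assert not type_assign(out_attr, 0)    # conflicting type: rejected
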
 inline bool TopKShapeImpl(const TopKParam& param,
@@ -650,6 +744,28 @@ inline bool TopKShape(const nnvm::NodeAttrs& attrs,
   return TopKShapeImpl(param, in_attrs, out_attrs);
 }
 
+inline bool SortType(const nnvm::NodeAttrs& attrs,
+                     std::vector<int> *in_attrs,
+                     std::vector<int> *out_attrs) {
+  int data_type = -1;
+  size_t in_size = in_attrs->size();
+  size_t out_size = out_attrs->size();
+  CHECK_EQ(in_size, 1);
+  CHECK_EQ(out_size, 2);
+  CHECK(type_assign(&(*out_attrs)[1], mshadow::kInt32))
+          << "Failed to set the type of ret_indices to int32.";
+  CHECK(type_assign(&data_type, (*in_attrs)[0])) << "Incompatible dtype of input, in_attrs[0]="
+                                                 << (*in_attrs)[0];
+  CHECK(type_assign(&data_type, (*out_attrs)[0])) << "Incompatible dtype of output, out_attrs[0]="
+                                                  << (*out_attrs)[0];
+  CHECK(type_assign(&(*in_attrs)[0], data_type)) << "Incompatible dtype of input, in_attrs[0]="
+                                                 << (*in_attrs)[0];
+  CHECK(type_assign(&(*out_attrs)[0], data_type)) << "Incompatible dtype of output, out_attrs[0]="
+                                                  << (*out_attrs)[0];
+  if (data_type == -1) return false;
+  return true;
+}
+
 inline bool SortShape(const nnvm::NodeAttrs& attrs,
                       std::vector<TShape> *in_attrs,
                       std::vector<TShape> *out_attrs) {
@@ -662,6 +778,15 @@ inline bool SortShape(const nnvm::NodeAttrs& attrs,
   return TopKShapeImpl(topk_param, in_attrs, out_attrs);
 }
 
+inline bool ArgSortType(const nnvm::NodeAttrs& attrs,
+                        std::vector<int> *in_attrs,
+                        std::vector<int> *out_attrs) {
+  const ArgSortParam& param = nnvm::get<ArgSortParam>(attrs.parsed);
+  CHECK(type_assign(&(*out_attrs)[0], param.dtype))
+          << "Failed to set the type of ret_indices to int32.";
+  return true;
+}
+
 inline bool ArgSortShape(const nnvm::NodeAttrs& attrs,
                          std::vector<TShape> *in_attrs,
                          std::vector<TShape> *out_attrs) {
diff --git a/src/operator/tensor/ordering_op.cc b/src/operator/tensor/ordering_op.cc
index ebd7c62ec88..f30be6c12b8 100644
--- a/src/operator/tensor/ordering_op.cc
+++ b/src/operator/tensor/ordering_op.cc
@@ -128,7 +128,7 @@ Examples::
 .set_num_outputs(2)
 .set_attr_parser(ParamParser<SortParam>)
 .set_attr<nnvm::FInferShape>("FInferShape", SortShape)
-.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 2>)
+.set_attr<nnvm::FInferType>("FInferType", SortType)
 .set_attr<nnvm::FNumVisibleOutputs>("FNumVisibleOutputs", [](const NodeAttrs& attrs) { return 1; })
 .set_attr<FCompute>("FCompute<cpu>", Sort<cpu>)
 .set_attr<nnvm::FGradient>("FGradient",
@@ -178,7 +178,7 @@ Examples::
 .set_num_outputs(1)
 .set_attr_parser(ParamParser<ArgSortParam>)
 .set_attr<nnvm::FInferShape>("FInferShape", ArgSortShape)
-.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<nnvm::FInferType>("FInferType", ArgSortType)
 .set_attr<FCompute>("FCompute<cpu>", ArgSort<cpu>)
 .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
 .set_attr<FResourceRequest>("FResourceRequest",
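
The sort registration above declares two outputs (the values plus the int32
indices its gradient needs) but exposes only one through FNumVisibleOutputs.
A sketch of the frontend view, assuming this build:

    import mxnet as mx

    x = mx.nd.array([3, 1, 2])
    out = mx.nd.sort(x)
    print(out)   # a single NDArray of sorted values; the hidden indices
                 # output exists only for the gradient pass
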
diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py
index aeaa0b72679..30abd5a9a3e 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -639,106 +639,147 @@ def gt_topk(dat, axis, ret_typ, k, is_ascend):
     # values, making it hard to generate a numpy 'golden copy' to compare against
     # the mxnet operator.  The 'mask' function is particularly hard to test given that
     # equal values might span the 'k' boundary.  Issue exposed with seed 1405838964.
-    def get_values(ensure_unique):
-        while True:
-            data = np.float32(np.random.normal(size=(dat_size, dat_size, dat_size, dat_size)))
-            if not ensure_unique:
-                return data
-            num_unique_values = len(set(data.flatten()))
-            if data.size == num_unique_values:
-                return data
-
-    a_npy = get_values(ensure_unique=True)
-    a_nd = mx.nd.array(a_npy, ctx=ctx)
-
-    # test for ret_typ=indices
-    nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="indices", k=3, is_ascend=True).asnumpy()
-    gt = gt_topk(a_npy, axis=1, ret_typ="indices", k=3, is_ascend=True)
-    assert_almost_equal(nd_ret_topk, gt)
-    nd_ret_topk = mx.nd.topk(a_nd, axis=3, ret_typ="indices", k=2, is_ascend=False).asnumpy()
-    gt = gt_topk(a_npy, axis=3, ret_typ="indices", k=2, is_ascend=False)
-    assert_almost_equal(nd_ret_topk, gt)
-    nd_ret_topk = mx.nd.topk(a_nd, axis=None, ret_typ="indices", k=21, is_ascend=False).asnumpy()
-    gt = gt_topk(a_npy, axis=None, ret_typ="indices", k=21, is_ascend=False)
-    assert_almost_equal(nd_ret_topk, gt)
-
-    # test for ret_typ=value
-    nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="value", k=3, is_ascend=True).asnumpy()
-    gt = gt_topk(a_npy, axis=1, ret_typ="value", k=3, is_ascend=True)
-    assert_almost_equal(nd_ret_topk, gt)
-    nd_ret_topk = mx.nd.topk(a_nd, axis=3, ret_typ="value", k=2, is_ascend=False).asnumpy()
-    gt = gt_topk(a_npy, axis=3, ret_typ="value", k=2, is_ascend=False)
-    assert_almost_equal(nd_ret_topk, gt)
-    nd_ret_topk = mx.nd.topk(a_nd, axis=None, ret_typ="value", k=21, is_ascend=False).asnumpy()
-    gt = gt_topk(a_npy, axis=None, ret_typ="value", k=21, is_ascend=False)
-    assert_almost_equal(nd_ret_topk, gt)
-
-    # test for ret_typ=mask
-    nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="mask", k=3, is_ascend=True).asnumpy()
-    gt = gt_topk(a_npy, axis=1, ret_typ="mask", k=3, is_ascend=True)
-    assert_almost_equal(nd_ret_topk, gt)
-    nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="mask", k=2, is_ascend=False).asnumpy()
-    gt = gt_topk(a_npy, axis=1, ret_typ="mask", k=2, is_ascend=False)
-    assert_almost_equal(nd_ret_topk, gt)
-    nd_ret_topk = mx.nd.topk(a_nd, axis=None, ret_typ="mask", k=21, is_ascend=False).asnumpy()
-    gt = gt_topk(a_npy, axis=None, ret_typ="mask", k=21, is_ascend=False)
-    assert_almost_equal(nd_ret_topk, gt)
-
-    # test for ret_typ=both
-    nd_ret_topk_val, nd_ret_topk_ind = mx.nd.topk(a_nd, axis=1, ret_typ="both", k=3, is_ascend=True)
-    nd_ret_topk_val = nd_ret_topk_val.asnumpy()
-    nd_ret_topk_ind = nd_ret_topk_ind.asnumpy()
-    gt_val = gt_topk(a_npy, axis=1, ret_typ="value", k=3, is_ascend=True)
-    gt_ind = gt_topk(a_npy, axis=1, ret_typ="indices", k=3, is_ascend=True)
-    assert_almost_equal(nd_ret_topk_val, gt_val)
-    assert_almost_equal(nd_ret_topk_ind, gt_ind)
-
-    # test for sort
-    nd_ret_sort = mx.nd.sort(a_nd, axis=1, is_ascend=True).asnumpy()
-    gt = gt_topk(a_npy, axis=1, ret_typ="value", k=dat_size, is_ascend=True)
-    assert_almost_equal(nd_ret_sort, gt)
-    nd_ret_sort = mx.nd.sort(a_nd, axis=None, is_ascend=False).asnumpy()
-    gt = gt_topk(a_npy, axis=None, ret_typ="value",
-                 k=dat_size*dat_size*dat_size*dat_size, is_ascend=False)
-    assert_almost_equal(nd_ret_sort, gt)
-
-    # test for argsort
-    nd_ret_argsort = mx.nd.argsort(a_nd, axis=3, is_ascend=True).asnumpy()
-    gt = gt_topk(a_npy, axis=3, ret_typ="indices", k=dat_size, is_ascend=True)
-    assert_almost_equal(nd_ret_argsort, gt)
-    nd_ret_argsort = mx.nd.argsort(a_nd, axis=None, is_ascend=False).asnumpy()
-    gt = gt_topk(a_npy, axis=None, ret_typ="indices",
-                 k=dat_size*dat_size*dat_size*dat_size, is_ascend=False)
-    assert_almost_equal(nd_ret_argsort, gt)
-
-    a = mx.nd.arange(0, 1024, step=1, repeat=1)
-    assert_almost_equal(a.topk(k=1024).asnumpy(), a.asnumpy()[::-1])
+    def get_values(ensure_unique, dtype):
+        if dtype == np.int16 or dtype == np.int32 or dtype == np.int64:
+            return np.arange(dat_size ** 4, dtype=dtype).reshape((dat_size, dat_size, dat_size, dat_size))
+        elif dtype == np.float32 or dtype == np.float64:
+            while True:
+                data = np.random.normal(size=(dat_size, dat_size, dat_size, dat_size)).astype(dtype)
+                if not ensure_unique:
+                    return data
+                num_unique_values = len(set(data.flatten()))
+                if data.size == num_unique_values:
+                    return data
+        else:
+            raise NotImplementedError
+
+    for dtype in [np.int16, np.int32, np.int64, np.float32, np.float64]:
+        a_npy = get_values(ensure_unique=True, dtype=dtype)
+        a_nd = mx.nd.array(a_npy, ctx=ctx)
+
+        # test for ret_typ=indices
+        nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="indices", k=3, is_ascend=True).asnumpy()
+        gt = gt_topk(a_npy, axis=1, ret_typ="indices", k=3, is_ascend=True)
+        assert_almost_equal(nd_ret_topk, gt)
+        nd_ret_topk = mx.nd.topk(a_nd, axis=3, ret_typ="indices", k=2, is_ascend=False).asnumpy()
+        gt = gt_topk(a_npy, axis=3, ret_typ="indices", k=2, is_ascend=False)
+        assert_almost_equal(nd_ret_topk, gt)
+        nd_ret_topk = mx.nd.topk(a_nd, axis=None, ret_typ="indices", k=21, is_ascend=False).asnumpy()
+        gt = gt_topk(a_npy, axis=None, ret_typ="indices", k=21, is_ascend=False)
+        assert_almost_equal(nd_ret_topk, gt)
+
+        # test for ret_typ=value
+        nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="value", k=3, is_ascend=True).asnumpy()
+        gt = gt_topk(a_npy, axis=1, ret_typ="value", k=3, is_ascend=True)
+        assert_almost_equal(nd_ret_topk, gt)
+        nd_ret_topk = mx.nd.topk(a_nd, axis=3, ret_typ="value", k=2, is_ascend=False).asnumpy()
+        gt = gt_topk(a_npy, axis=3, ret_typ="value", k=2, is_ascend=False)
+        assert_almost_equal(nd_ret_topk, gt)
+        nd_ret_topk = mx.nd.topk(a_nd, axis=None, ret_typ="value", k=21, is_ascend=False).asnumpy()
+        gt = gt_topk(a_npy, axis=None, ret_typ="value", k=21, is_ascend=False)
+        assert_almost_equal(nd_ret_topk, gt)
+
+        # test for ret_typ=mask
+        nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="mask", k=3, is_ascend=True).asnumpy()
+        gt = gt_topk(a_npy, axis=1, ret_typ="mask", k=3, is_ascend=True)
+        assert_almost_equal(nd_ret_topk, gt)
+        nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="mask", k=2, is_ascend=False).asnumpy()
+        gt = gt_topk(a_npy, axis=1, ret_typ="mask", k=2, is_ascend=False)
+        assert_almost_equal(nd_ret_topk, gt)
+        nd_ret_topk = mx.nd.topk(a_nd, axis=None, ret_typ="mask", k=21, is_ascend=False).asnumpy()
+        gt = gt_topk(a_npy, axis=None, ret_typ="mask", k=21, is_ascend=False)
+        assert_almost_equal(nd_ret_topk, gt)
+
+        # test for ret_typ=both
+        nd_ret_topk_val, nd_ret_topk_ind = mx.nd.topk(a_nd, axis=1, ret_typ="both", k=3, is_ascend=True)
+        nd_ret_topk_val = nd_ret_topk_val.asnumpy()
+        nd_ret_topk_ind = nd_ret_topk_ind.asnumpy()
+        gt_val = gt_topk(a_npy, axis=1, ret_typ="value", k=3, is_ascend=True)
+        gt_ind = gt_topk(a_npy, axis=1, ret_typ="indices", k=3, is_ascend=True)
+        assert_almost_equal(nd_ret_topk_val, gt_val)
+        assert_almost_equal(nd_ret_topk_ind, gt_ind)
+        # test for kNullOp
+        _, nd_ret_topk_ind = mx.nd.topk(a_nd, axis=1, ret_typ="both", k=3, is_ascend=True)
+        nd_ret_topk_ind = nd_ret_topk_ind.asnumpy()
+        assert_almost_equal(nd_ret_topk_ind, gt_ind)
+        # test for kNullOp
+        nd_ret_topk_val, _ = mx.nd.topk(a_nd, axis=1, ret_typ="both", k=3, is_ascend=True)
+        nd_ret_topk_val = nd_ret_topk_val.asnumpy()
+        assert_almost_equal(nd_ret_topk_val, gt_val)
+
+        # test for sort
+        nd_ret_sort = mx.nd.sort(a_nd, axis=1, is_ascend=True).asnumpy()
+        gt = gt_topk(a_npy, axis=1, ret_typ="value", k=dat_size, is_ascend=True)
+        assert_almost_equal(nd_ret_sort, gt)
+        nd_ret_sort = mx.nd.sort(a_nd, axis=None, is_ascend=False).asnumpy()
+        gt = gt_topk(a_npy, axis=None, ret_typ="value",
+                     k=dat_size*dat_size*dat_size*dat_size, is_ascend=False)
+        assert_almost_equal(nd_ret_sort, gt)
+
+        # test for argsort
+        for idtype in [np.int32, np.float16, np.float32, np.float64]:
+            nd_ret_argsort = mx.nd.argsort(a_nd, axis=3, is_ascend=True, dtype=idtype).asnumpy()
+            gt = gt_topk(a_npy, axis=3, ret_typ="indices", k=dat_size, is_ascend=True)
+            assert_almost_equal(nd_ret_argsort, gt)
+            nd_ret_argsort = mx.nd.argsort(a_nd, axis=None, is_ascend=False, dtype=idtype).asnumpy()
+            gt = gt_topk(a_npy, axis=None, ret_typ="indices",
+                         k=dat_size*dat_size*dat_size*dat_size, is_ascend=False)
+            assert_almost_equal(nd_ret_argsort, gt)
+
+    # test topk with a big shape
+    a = mx.nd.arange(0, 1024, step=1, repeat=1, dtype=np.int32)
+    assert_almost_equal(a.topk(k=1024, dtype=np.int32).asnumpy(), a.asnumpy()[::-1])
+    a.attach_grad()
+
+    k = 10
+    with mx.autograd.record():
+        b = mx.nd.topk(a, k=k, ret_typ='value')
+        b.backward(mx.nd.ones((k,), dtype=np.int32))
+    a_grad = a.grad.asnumpy()
+    for i in range(-1, - k - 1, -1):
+        assert a_grad[i] == 1
+
+    # test topk gradient with a small shape
+    for dtype in [np.int32, np.int64, np.float32, np.float64]:
+        a = mx.nd.arange(0, 1000, step=1, repeat=1, dtype=dtype)
+        a.attach_grad()
+        k = 10
+        ograd = mx.nd.arange(0, k, dtype=dtype)
+        with mx.autograd.record():
+            b = mx.nd.topk(a, k=k, ret_typ='value')
+            b.backward(ograd)
+        a_grad = a.grad.asnumpy()
+        ograd_npy = ograd.asnumpy()
+        for i in range(-1, - k - 1, -1):
+            assert a_grad[i] == ograd_npy[-i - 1]
+
 
     # Repeat those tests that don't involve indices.  These should pass even with
     # duplicated input data values (over many repeated runs with different random seeds,
     # this will be tested).
-    a_npy = get_values(ensure_unique=False)
-    a_nd = mx.nd.array(a_npy, ctx=ctx)
-
-    # test for ret_typ=value
-    nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="value", k=3, is_ascend=True).asnumpy()
-    gt = gt_topk(a_npy, axis=1, ret_typ="value", k=3, is_ascend=True)
-    assert_almost_equal(nd_ret_topk, gt)
-    nd_ret_topk = mx.nd.topk(a_nd, axis=3, ret_typ="value", k=2, is_ascend=False).asnumpy()
-    gt = gt_topk(a_npy, axis=3, ret_typ="value", k=2, is_ascend=False)
-    assert_almost_equal(nd_ret_topk, gt)
-    nd_ret_topk = mx.nd.topk(a_nd, axis=None, ret_typ="value", k=21, is_ascend=False).asnumpy()
-    gt = gt_topk(a_npy, axis=None, ret_typ="value", k=21, is_ascend=False)
-    assert_almost_equal(nd_ret_topk, gt)
-
-    # test for sort
-    nd_ret_sort = mx.nd.sort(a_nd, axis=1, is_ascend=True).asnumpy()
-    gt = gt_topk(a_npy, axis=1, ret_typ="value", k=dat_size, is_ascend=True)
-    assert_almost_equal(nd_ret_sort, gt)
-    nd_ret_sort = mx.nd.sort(a_nd, axis=None, is_ascend=False).asnumpy()
-    gt = gt_topk(a_npy, axis=None, ret_typ="value",
-                 k=dat_size*dat_size*dat_size*dat_size, is_ascend=False)
-    assert_almost_equal(nd_ret_sort, gt)
+    for dtype in [np.int16, np.int32, np.int64, np.float32, np.float64]:
+        a_npy = get_values(ensure_unique=False, dtype=dtype)
+        a_nd = mx.nd.array(a_npy, ctx=ctx)
+
+        # test for ret_typ=value
+        nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="value", k=3, is_ascend=True).asnumpy()
+        gt = gt_topk(a_npy, axis=1, ret_typ="value", k=3, is_ascend=True)
+        assert_almost_equal(nd_ret_topk, gt)
+        nd_ret_topk = mx.nd.topk(a_nd, axis=3, ret_typ="value", k=2, is_ascend=False).asnumpy()
+        gt = gt_topk(a_npy, axis=3, ret_typ="value", k=2, is_ascend=False)
+        assert_almost_equal(nd_ret_topk, gt)
+        nd_ret_topk = mx.nd.topk(a_nd, axis=None, ret_typ="value", k=21, is_ascend=False).asnumpy()
+        gt = gt_topk(a_npy, axis=None, ret_typ="value", k=21, is_ascend=False)
+        assert_almost_equal(nd_ret_topk, gt)
+
+        # test for sort
+        nd_ret_sort = mx.nd.sort(a_nd, axis=1, is_ascend=True).asnumpy()
+        gt = gt_topk(a_npy, axis=1, ret_typ="value", k=dat_size, is_ascend=True)
+        assert_almost_equal(nd_ret_sort, gt)
+        nd_ret_sort = mx.nd.sort(a_nd, axis=None, is_ascend=False).asnumpy()
+        gt = gt_topk(a_npy, axis=None, ret_typ="value",
+                     k=dat_size*dat_size*dat_size*dat_size, is_ascend=False)
+        assert_almost_equal(nd_ret_sort, gt)
 
 @with_seed()
 def test_ndarray_equal():

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services