[GitHub] szha closed pull request #12250: [MXNET-507] Set arbitrary dtype for ret_indices in ordering ops

szha closed pull request #12250: [MXNET-507] Set arbitrary dtype for ret_indices in ordering ops

diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h
index c3f6dc6558e..f11a497c564 100644
--- a/src/operator/mxnet_op.h
+++ b/src/operator/mxnet_op.h
@@ -176,6 +176,52 @@ inline int get_num_threads<cpu>(const int N) {
     LOG(FATAL) << "Unknown type enum " << type;            \
+/*!
+ * \brief Dispatch on a runtime mshadow type flag, binding DType to the matching C++ type
+ *        and executing the given body inside its own braces. float16 is rejected.
+ * \param type  mshadow type flag to switch on (e.g. mshadow::kFloat32).
+ * \param DType name of the type alias that each case defines for use in the body.
+ * \param ...   statement body executed with DType bound.
+ *
+ * Handled flags: kFloat32 -> float, kFloat64 -> double, kUint8 -> uint8_t,
+ * kInt8 -> int8_t, kInt32 -> int32_t, kInt64 -> int64_t.
+ * kFloat16 hits LOG(FATAL) ("This operation does not support float16"); any other
+ * flag hits LOG(FATAL) ("Unknown type enum ...").
+ */
+#define MXNET_NO_FLOAT16_TYPE_SWITCH(type, DType, ...)     \
+  switch (type) {                                          \
+  case mshadow::kFloat32:                                  \
+    {                                                      \
+      typedef float DType;                                 \
+      {__VA_ARGS__}                                        \
+    }                                                      \
+    break;                                                 \
+  case mshadow::kFloat64:                                  \
+    {                                                      \
+      typedef double DType;                                \
+      {__VA_ARGS__}                                        \
+    }                                                      \
+    break;                                                 \
+  case mshadow::kFloat16:                                  \
+    LOG(FATAL) << "This operation does not "               \
+                  "support float16";                       \
+    break;                                                 \
+  case mshadow::kUint8:                                    \
+    {                                                      \
+      typedef uint8_t DType;                               \
+      {__VA_ARGS__}                                        \
+    }                                                      \
+    break;                                                 \
+  case mshadow::kInt8:                                     \
+    {                                                      \
+      typedef int8_t DType;                                \
+      {__VA_ARGS__}                                        \
+    }                                                      \
+    break;                                                 \
+  case mshadow::kInt32:                                    \
+    {                                                      \
+      typedef int32_t DType;                               \
+      {__VA_ARGS__}                                        \
+    }                                                      \
+    break;                                                 \
+  case mshadow::kInt64:                                    \
+    {                                                      \
+      typedef int64_t DType;                               \
+      {__VA_ARGS__}                                        \
+    }                                                      \
+    break;                                                 \
+  default:                                                 \
+    LOG(FATAL) << "Unknown type enum " << type;            \
+  }
  * \brief assign the val to out according
diff --git a/src/operator/tensor/ordering_op-inl.h b/src/operator/tensor/ordering_op-inl.h
index a6f638e2932..c1a5b89db09 100644
--- a/src/operator/tensor/ordering_op-inl.h
+++ b/src/operator/tensor/ordering_op-inl.h
@@ -58,6 +58,7 @@ struct TopKParam : public dmlc::Parameter<TopKParam> {
   int k;
   int ret_typ;
   bool is_ascend;
+  int dtype;
     .describe("Axis along which to choose the top k indices."
@@ -79,6 +80,16 @@ struct TopKParam : public dmlc::Parameter<TopKParam> {
       .describe("Whether to choose k largest or k smallest elements."
                 " Top K largest elements will be chosen if set to false.");
+    .add_enum("uint8", mshadow::kUint8)
+    .add_enum("int32", mshadow::kInt32)
+    .add_enum("float16", mshadow::kFloat16)
+    .add_enum("float32", mshadow::kFloat32)
+    .add_enum("float64", mshadow::kFloat64)
+    .set_default(mshadow::kFloat32)
+    .describe("DType of the output indices when ret_typ is \"indices\" or \"both\". "
+              "An error will be raised if the selected data type cannot precisely represent the "
+              "indices.");
@@ -97,12 +108,23 @@ struct SortParam : public dmlc::Parameter<SortParam> {
 struct ArgSortParam : public dmlc::Parameter<ArgSortParam> {
   dmlc::optional<int> axis;
   bool is_ascend;
+  int dtype;
     .describe("Axis along which to sort the input tensor."
               " If not given, the flattened array is used. Default is -1.");
       .describe("Whether to sort in ascending or descending order.");
+    .add_enum("uint8", mshadow::kUint8)
+    .add_enum("int32", mshadow::kInt32)
+    .add_enum("float16", mshadow::kFloat16)
+    .add_enum("float32", mshadow::kFloat32)
+    .add_enum("float64", mshadow::kFloat64)
+    .set_default(mshadow::kFloat32)
+    .describe("DType of the output indices. It is only valid when ret_typ is \"indices\" or"
+              " \"both\". An error will be raised if the selected data type cannot precisely "
+              "represent the indices.");
@@ -154,29 +176,22 @@ inline void ParseTopKParam(const TShape& src_shape, const TopKParam& param, TSha
 using namespace mshadow;
-template<typename xpu>
-void TopKSort(const Tensor<xpu, 1, real_t>& dat,
-              const Tensor<xpu, 1, int>& ind,
-              const Tensor<xpu, 1, char>& work,
-              int K, int N, bool is_ascend,
-              Stream<xpu> *s);
-MSHADOW_FORCE_INLINE void TopKSort<cpu>(const Tensor<cpu, 1, real_t>& dat,
-                                        const Tensor<cpu, 1, int>& ind,
-                                        const Tensor<cpu, 1, char>& work,
-                                        int K, int N, bool is_ascend,
-                                        Stream<cpu> *s) {
+template<typename DType>
+MSHADOW_FORCE_INLINE void TopKSort(const Tensor<cpu, 1, DType>& dat,
+                                   const Tensor<cpu, 1, int>& ind,
+                                   const Tensor<cpu, 1, char>& work,
+                                   int K, int N, bool is_ascend,
+                                   Stream<cpu> *s) {
   // Use full sort when K is relatively large.
   const bool full_sort(K*8 > N);
   // Batch size.
-  const int M(work.size(0)/(sizeof(real_t)*N));
+  const int M(work.size(0)/(sizeof(DType)*N));
   const int omp_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount());
   #pragma omp parallel for num_threads(omp_threads)
   for (int i = 0; i < M; ++i) {
     // Tensor `work` stores the flattened source data, while `dat` stores the sorted result.
-    real_t *vals = reinterpret_cast<real_t*>(work.dptr_);
-    real_t *sorted_vals = dat.dptr_+i*N;
+    DType *vals = reinterpret_cast<DType*>(work.dptr_);
+    DType *sorted_vals = dat.dptr_+i*N;
     int *indices = ind.dptr_+i*N;
     if (is_ascend) {
       if (full_sort) {
@@ -285,12 +300,12 @@ __global__ void PartialSortSmallK(int K, int N, DType *val, int *ind, bool is_as
-MSHADOW_FORCE_INLINE void TopKSort<gpu>(const Tensor<gpu, 1, real_t>& dat,
-                                        const Tensor<gpu, 1, int>& ind,
-                                        const Tensor<gpu, 1, char>& work,
-                                        int K, int N, bool is_ascend,
-                                        Stream<gpu> *s) {
+template<typename DType>
+MSHADOW_FORCE_INLINE void TopKSort(const Tensor<gpu, 1, DType>& dat,
+                                   const Tensor<gpu, 1, int>& ind,
+                                   const Tensor<gpu, 1, char>& work,
+                                   int K, int N, bool is_ascend,
+                                   Stream<gpu> *s) {
   // Use full sort for all but very small K for which we
   // can do a partial sort entirely within shared memory.
   const bool full_sort(K > 5);
@@ -311,7 +326,7 @@ MSHADOW_FORCE_INLINE void TopKSort<gpu>(const Tensor<gpu, 1, real_t>& dat,
   } else {
     const int nthreads(mshadow::cuda::kBaseThreadNum);
-    PartialSortSmallK<<<M, nthreads, nthreads*K*(sizeof(int)+sizeof(real_t)),
+    PartialSortSmallK<<<M, nthreads, nthreads*K*(sizeof(int)+sizeof(DType)),
                         (K, N, dat.dptr_, ind.dptr_, is_ascend);
@@ -331,25 +346,25 @@ MSHADOW_FORCE_INLINE void TopKSort<gpu>(const Tensor<gpu, 1, real_t>& dat,
    * \param k the K elements to keep
    * \param param the topk parameters
    * \tparam xpu the device type.
+   * \tparam DType type of the output value/mask.
+   * \tparam IDType type of the output indices.
-template<typename xpu>
-void TopKImpl(RunContext ctx,
-              Resource resource,
+template<typename xpu, typename DType, typename IDType>
+void TopKImpl(const RunContext &ctx,
+              const Resource &resource,
+              const std::vector<OpReqType>& req,
               const TBlob& src,
               const std::vector<TBlob>& ret,
               const TopKParam& param) {
   using namespace mshadow;
   using namespace mshadow::expr;
-  for (auto ret_ele : ret) {
-    CHECK_EQ(ret_ele.type_flag_, src.type_flag_);
-  }
   // 1. Parse and initialize information
   Stream<xpu> *s = ctx.get_stream<xpu>();
   Tensor<xpu, 1, char> workspace;
   Tensor<xpu, 1, char> temp_workspace;
-  Tensor<xpu, 1, real_t> sorted_dat;
+  Tensor<xpu, 1, DType> sorted_dat;
   Tensor<xpu, 1, int> indices, sel_indices;
-  Tensor<xpu, 2, real_t> mask_val;
+  Tensor<xpu, 2, DType> mask_val;
   int batch_size, element_num;  // number of batches + the size of each batch
   int axis = 0;
   bool do_transpose = false;
@@ -358,25 +373,29 @@ void TopKImpl(RunContext ctx,
   TShape target_shape;
   ParseTopKParam(src.shape_, param,
                  &target_shape, &batch_size, &element_num, &axis, &k, &do_transpose, &is_ascend);
-  Tensor<xpu, 3, real_t> dat = src.FlatTo3D<xpu, real_t>(axis, axis, s);
+  CHECK_LE(element_num, mxnet::common::MaxIntegerValue<IDType>())
+    << "'IDType' does not have a sufficient precision to represent the indices of the input array. "
+    << "The total element_num is " << element_num << ", but the selected IDType can only represent "
+    << mxnet::common::MaxIntegerValue<IDType>() << " elements";
+  Tensor<xpu, 3, DType> dat = src.FlatTo3D<xpu, DType>(axis, axis, s);
   size_t temp_size = 0;
   // Temp space needed by the gpu-based full sorts.
   temp_size = std::max(temp_size, mxnet::op::SortByKeyWorkspaceSize<int, int, xpu>(src.Size()));
-  temp_size = std::max(temp_size, mxnet::op::SortByKeyWorkspaceSize<int, real_t, xpu>(src.Size()));
-  temp_size = std::max(temp_size, mxnet::op::SortByKeyWorkspaceSize<real_t, int, xpu>(src.Size()));
+  temp_size = std::max(temp_size, mxnet::op::SortByKeyWorkspaceSize<int, DType, xpu>(src.Size()));
+  temp_size = std::max(temp_size, mxnet::op::SortByKeyWorkspaceSize<DType, int, xpu>(src.Size()));
   // Additional temp space for gpu full sorts for batch ids.
   temp_size += sizeof(int) * src.Size();
   // Temp space for cpu sorts.
-  temp_size = std::max(temp_size, sizeof(real_t) * src.Size());
-  size_t workspace_size = temp_size + sizeof(real_t) * src.Size() + sizeof(int) * src.Size();
+  temp_size = std::max(temp_size, sizeof(DType) * src.Size());
+  size_t workspace_size = temp_size + sizeof(DType) * src.Size() + sizeof(int) * src.Size();
   if (param.ret_typ == topk_enum::kReturnMask) {
-    workspace_size += sizeof(int) * batch_size * k + sizeof(real_t) * batch_size * k;
+    workspace_size += sizeof(int) * batch_size * k + sizeof(DType) * batch_size * k;
   workspace = resource.get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
   char* workspace_curr_ptr = workspace.dptr_;
-  sorted_dat = Tensor<xpu, 1, real_t>(reinterpret_cast<real_t*>(workspace_curr_ptr),
+  sorted_dat = Tensor<xpu, 1, DType>(reinterpret_cast<DType*>(workspace_curr_ptr),
                                       Shape1(src.Size()), s);  // contain sorted dat
-  workspace_curr_ptr += sizeof(real_t) * src.Size();
+  workspace_curr_ptr += sizeof(DType) * src.Size();
   indices = Tensor<xpu, 1, int>(reinterpret_cast<int*>(workspace_curr_ptr),
                                 Shape1(src.Size()), s);  // indices in the original matrix
   workspace_curr_ptr += sizeof(int) * src.Size();
@@ -385,28 +404,28 @@ void TopKImpl(RunContext ctx,
     sel_indices = Tensor<xpu, 1, int>(reinterpret_cast<int*>(workspace_curr_ptr),
                                       Shape1(batch_size * k), s);
     workspace_curr_ptr += sizeof(int) * batch_size * k;
-    mask_val = Tensor<xpu, 2, real_t>(reinterpret_cast<real_t*>(workspace_curr_ptr),
+    mask_val = Tensor<xpu, 2, DType>(reinterpret_cast<DType*>(workspace_curr_ptr),
                                       Shape2(batch_size * k, 1), s);
-    workspace_curr_ptr += sizeof(real_t) * batch_size * k;
-    mask_val = scalar<real_t>(1);
+    workspace_curr_ptr += sizeof(DType) * batch_size * k;
+    mask_val = scalar<DType>(1);
     CHECK_EQ(sel_indices.CheckContiguous(), true);
     CHECK_EQ(mask_val.CheckContiguous(), true);
   if (std::is_same<xpu, cpu>::value) {
-    Tensor<xpu, 1, real_t> flattened_data;
+    Tensor<xpu, 1, DType> flattened_data;
     if (do_transpose) {
-      flattened_data = Tensor<xpu, 1, real_t>(reinterpret_cast<real_t*>(workspace_curr_ptr),
+      flattened_data = Tensor<xpu, 1, DType>(reinterpret_cast<DType*>(workspace_curr_ptr),
                                               Shape1(src.Size()), s);
-      workspace_curr_ptr += sizeof(real_t) * src.Size();
+      workspace_curr_ptr += sizeof(DType) * src.Size();
       flattened_data = reshape(transpose(dat, Shape3(0, 2, 1)), Shape1(src.Size()));
       CHECK_EQ(flattened_data.CheckContiguous(), true);
     } else {
-      flattened_data = src.FlatTo1D<xpu, real_t>(s);
+      flattened_data = src.FlatTo1D<xpu, DType>(s);
     // `temp_workspace` stores the flattened data
     temp_workspace = Tensor<xpu, 1, char>(reinterpret_cast<char*>(flattened_data.dptr_),
-                                          Shape1(sizeof(real_t)*src.Size()), s);
+                                          Shape1(sizeof(DType)*src.Size()), s);
     CHECK_EQ(temp_workspace.CheckContiguous(), true);
   } else {
     if (do_transpose) {
@@ -436,9 +455,9 @@ void TopKImpl(RunContext ctx,
   // Cast `ret_indices` from int to real_t could introduce conversion error when the element_num
   // is large enough.
   if (param.ret_typ == topk_enum::kReturnMask) {
-    Tensor<xpu, 2, real_t> ret_mask =
-      ret[0].get_with_shape<xpu, 2, real_t>(Shape2(ret[0].Size(), 1), s);
-    ret_mask = scalar<real_t>(0);
+    Tensor<xpu, 2, DType> ret_mask =
+      ret[0].get_with_shape<xpu, 2, DType>(Shape2(ret[0].Size(), 1), s);
+    ret_mask = scalar<DType>(0);
     sel_indices = reshape(slice<1>(
@@ -450,53 +469,53 @@ void TopKImpl(RunContext ctx,
       sel_indices = transpose_indices(sel_indices, Shape3(src_shape[0], src_shape[2], src_shape[1]),
                                       Shape3(0, 2, 1));
-    IndexFill(ret_mask, sel_indices, mask_val);
+    if (req[0] == kNullOp) {
+      return;
+    } else if (req[0] == kWriteTo) {
+      IndexFill(ret_mask, sel_indices, mask_val);
+    } else {
+      LOG(FATAL) << "req=" << req[0] << " is not supported yet.";
+    }
   } else if (param.ret_typ == topk_enum::kReturnIndices) {
     if (do_transpose) {
-      Tensor<xpu, 3, real_t> ret_indices = ret[0].FlatTo3D<xpu, real_t>(axis, axis, s);
-      ret_indices = tcast<real_t>(F<mshadow_op::mod>(
-                      transpose(slice<2>(inplace_reshape(indices,
-                                                         Shape3(ret_indices.shape_[0],
-                                                                ret_indices.shape_[2],
-                                                                element_num)),
-                                         0, k),
-                                Shape3(0, 2, 1)),
-                      element_num));
-    } else {
-      Tensor<xpu, 2, real_t> ret_indices =
-        ret[0].get_with_shape<xpu, 2, real_t>(Shape2(batch_size, k), s);
-      ret_indices = tcast<real_t>(F<mshadow_op::mod>(
-                      slice<1>(inplace_reshape(indices, Shape2(batch_size, element_num)),
+      Tensor<xpu, 3, IDType> ret_indices = ret[0].FlatTo3D<xpu, IDType>(axis, axis, s);
+      ASSIGN_DISPATCH(ret_indices, req[0], tcast<IDType>(F<mshadow_op::mod>(transpose(
+                      slice<2>(inplace_reshape(indices,
+                                               Shape3(ret_indices.shape_[0],
+                                                      ret_indices.shape_[2],
+                                                      element_num)),
                                0, k),
-                      element_num));
+                      Shape3(0, 2, 1)), element_num)));
+    } else {
+      Tensor<xpu, 2, IDType> ret_indices =
+        ret[0].get_with_shape<xpu, 2, IDType>(Shape2(batch_size, k), s);
+      ASSIGN_DISPATCH(ret_indices, req[0], tcast<IDType>(F<mshadow_op::mod>(slice<1>(
+                      inplace_reshape(indices, Shape2(batch_size, element_num)), 0, k),
+                      element_num)));
   } else {
     if (do_transpose) {
-      Tensor<xpu, 3, real_t> ret_value = ret[0].FlatTo3D<xpu, real_t>(axis, axis, s);
-      Tensor<xpu, 3, real_t> ret_indices = ret[1].FlatTo3D<xpu, real_t>(axis, axis, s);
-      ret_value = transpose(
+      Tensor<xpu, 3, DType> ret_value = ret[0].FlatTo3D<xpu, DType>(axis, axis, s);
+      Tensor<xpu, 3, IDType> ret_indices = ret[1].FlatTo3D<xpu, IDType>(axis, axis, s);
+      ASSIGN_DISPATCH(ret_value, req[0], transpose(
                                     Shape3(ret_value.shape_[0], ret_value.shape_[2], element_num)),
-                            0, k),
-                   Shape3(0, 2, 1));
-      ret_indices = tcast<real_t>(F<mshadow_op::mod>(
-                      transpose(slice<2>(inplace_reshape(indices,
-                                                         Shape3(ret_indices.shape_[0],
-                                                         ret_indices.shape_[2],
-                                                         element_num)),
-                                         0, k),
-                                Shape3(0, 2, 1)),
-                      element_num));
+                            0, k), Shape3(0, 2, 1)));
+      ASSIGN_DISPATCH(ret_indices, req[1], tcast<IDType>(F<mshadow_op::mod>(transpose(
+                      slice<2>(inplace_reshape(indices,
+                                               Shape3(ret_indices.shape_[0],
+                                                      ret_indices.shape_[2],
+                                                      element_num)),
+                               0, k), Shape3(0, 2, 1)), element_num)));
     } else {
-      Tensor<xpu, 2, real_t> ret_value =
-        ret[0].get_with_shape<xpu, 2, real_t>(Shape2(batch_size, k), s);
-      Tensor<xpu, 2, real_t> ret_indices =
-        ret[1].get_with_shape<xpu, 2, real_t>(Shape2(batch_size, k), s);
-      ret_value = slice<1>(inplace_reshape(sorted_dat, Shape2(batch_size, element_num)), 0, k);
-      ret_indices = tcast<real_t>(F<mshadow_op::mod>(
-                      slice<1>(inplace_reshape(indices, Shape2(batch_size, element_num)),
-                               0, k),
-                      element_num));
+      Tensor<xpu, 2, DType> ret_value =
+        ret[0].get_with_shape<xpu, 2, DType>(Shape2(batch_size, k), s);
+      Tensor<xpu, 2, IDType> ret_indices =
+        ret[1].get_with_shape<xpu, 2, IDType>(Shape2(batch_size, k), s);
+      ASSIGN_DISPATCH(ret_value, req[0],
+             slice<1>(inplace_reshape(sorted_dat, Shape2(batch_size, element_num)), 0, k));
+      ASSIGN_DISPATCH(ret_indices, req[1], tcast<IDType>(F<mshadow_op::mod>(slice<1>(
+                 inplace_reshape(indices, Shape2(batch_size, element_num)), 0, k), element_num)));
@@ -508,9 +527,17 @@ void TopK(const nnvm::NodeAttrs& attrs,
           const std::vector<OpReqType>& req,
           const std::vector<TBlob>& outputs) {
   const TopKParam& param = nnvm::get<TopKParam>(attrs.parsed);
-  // TODO(sxjscience) We can support inplace in the future
-  CHECK_EQ(req[0], kWriteTo) << "TopK does not support inplace";
-  TopKImpl<xpu>(ctx.run_ctx, ctx.requested[0], inputs[0], outputs, param);
+  if (param.ret_typ == topk_enum::kReturnIndices || param.ret_typ == topk_enum::kReturnBoth) {
+    MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+      MSHADOW_TYPE_SWITCH(param.dtype, IDType, {
+        TopKImpl<xpu, DType, IDType>(ctx.run_ctx, ctx.requested[0], req, inputs[0], outputs, param);
+      })
+    });
+  } else {
+    MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+      TopKImpl<xpu, DType, int>(ctx.run_ctx, ctx.requested[0], req, inputs[0], outputs, param);
+    });
+  }
 template<typename xpu>
@@ -520,13 +547,14 @@ void Sort(const nnvm::NodeAttrs& attrs,
           const std::vector<OpReqType>& req,
           const std::vector<TBlob>& outputs) {
   const SortParam& param = nnvm::get<SortParam>(attrs.parsed);
-  CHECK_EQ(req[0], kWriteTo) << "Sort does not support inplace";
   TopKParam topk_param;
   topk_param.axis = param.axis;
   topk_param.is_ascend = param.is_ascend;
   topk_param.k = 0;
   topk_param.ret_typ = topk_enum::kReturnValue;
-  TopKImpl<xpu>(ctx.run_ctx, ctx.requested[0], inputs[0], outputs, topk_param);
+  MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    TopKImpl<xpu, DType, int>(ctx.run_ctx, ctx.requested[0], req, inputs[0], outputs, topk_param);
+  });
 template<typename xpu>
@@ -536,26 +564,30 @@ void ArgSort(const nnvm::NodeAttrs& attrs,
              const std::vector<OpReqType>& req,
              const std::vector<TBlob>& outputs) {
   const ArgSortParam& param = nnvm::get<ArgSortParam>(attrs.parsed);
-  CHECK_EQ(req[0], kWriteTo) << "ArgSort does not support inplace";
   TopKParam topk_param;
   topk_param.axis = param.axis;
   topk_param.is_ascend = param.is_ascend;
   topk_param.k = 0;
+  topk_param.dtype = param.dtype;
   topk_param.ret_typ = topk_enum::kReturnIndices;
-  TopKImpl<xpu>(ctx.run_ctx, ctx.requested[0], inputs[0], outputs, topk_param);
+  MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    MSHADOW_TYPE_SWITCH(param.dtype, IDType, {
+      TopKImpl<xpu, DType, IDType>(ctx.run_ctx,
+                                   ctx.requested[0], req, inputs[0], outputs, topk_param);
+    });
+  });
-template<typename xpu>
-void TopKBackward_(const nnvm::NodeAttrs& attrs,
-                  const OpContext& ctx,
-                  const std::vector<TBlob>& inputs,
-                  const std::vector<OpReqType>& req,
-                  const std::vector<TBlob>& outputs) {
+template<typename xpu, typename DType, typename IDType>
+void TopKBackwardImpl(const OpContext &ctx,
+                      const std::vector<TBlob>& inputs,
+                      const std::vector<OpReqType>& req,
+                      const std::vector<TBlob>& outputs,
+                      const TopKParam& param) {
   CHECK_NE(req[0], kWriteInplace);
   using namespace mshadow;
   using namespace mshadow::expr;
   Stream<xpu> *s = ctx.run_ctx.get_stream<xpu>();
-  const TopKParam& param = nnvm::get<TopKParam>(attrs.parsed);
   CHECK(param.ret_typ == topk_enum::kReturnValue || param.ret_typ == topk_enum::kReturnBoth);
   int batch_size, element_num;  // number of batches + the size of each batch
   int axis = 0;
@@ -565,23 +597,28 @@ void TopKBackward_(const nnvm::NodeAttrs& attrs,
   TShape target_shape;
   ParseTopKParam(outputs[0].shape_, param,
                  &target_shape, &batch_size, &element_num, &axis, &k, &do_transpose, &is_ascend);
-  Tensor<xpu, 1, real_t> workspace =
-    ctx.requested[0].get_space_typed<xpu, 1, real_t>(Shape1(batch_size * k * 2 + batch_size), s);
-  Tensor<xpu, 1, real_t> sel_indices =
-    Tensor<xpu, 1, real_t>(workspace.dptr_, Shape1(batch_size * k), s);
-  Tensor<xpu, 1, real_t> batch_shift =
-    Tensor<xpu, 1, real_t>(workspace.dptr_ + batch_size * k, Shape1(batch_size), s);
-  Tensor<xpu, 1, real_t> dummy_index =
-    Tensor<xpu, 1, real_t>(workspace.dptr_ + batch_size * k + batch_size,
+  CHECK_LE(element_num, mxnet::common::MaxIntegerValue<IDType>())
+    << "'IDType' does not have a sufficient precision to represent the indices of the input array. "
+    << "The total element_num is " << element_num << ", but the selected IDType can only represent "
+    << mxnet::common::MaxIntegerValue<IDType>() << " elements";
+  Tensor<xpu, 1, int> workspace =
+    ctx.requested[0].get_space_typed<xpu, 1, int>(Shape1(batch_size * k * 2 + batch_size), s);
+  Tensor<xpu, 1, int> sel_indices =
+    Tensor<xpu, 1, int>(workspace.dptr_, Shape1(batch_size * k), s);
+  Tensor<xpu, 1, int> batch_shift =
+    Tensor<xpu, 1, int>(workspace.dptr_ + batch_size * k, Shape1(batch_size), s);
+  Tensor<xpu, 1, int> dummy_index =
+    Tensor<xpu, 1, int>(workspace.dptr_ + batch_size * k + batch_size,
                            Shape1(batch_size * k), s);
-  Tensor<xpu, 2, real_t> out_grad =
-    inputs[0].get_with_shape<xpu, 2, real_t>(Shape2(inputs[0].shape_.Size(), 1), s);
-  Tensor<xpu, 2, real_t> in_grad =
-    outputs[0].get_with_shape<xpu, 2, real_t>(Shape2(outputs[0].shape_.Size(), 1), s);
-  mxnet_op::Kernel<range_fwd, xpu>::Launch(s, batch_size, 1, 0.0f,
-    static_cast<real_t>(element_num), kWriteTo, batch_shift.dptr_);
+  Tensor<xpu, 2, DType> out_grad =
+    inputs[0].get_with_shape<xpu, 2, DType>(Shape2(inputs[0].shape_.Size(), 1), s);
+  Tensor<xpu, 2, DType> in_grad =
+    outputs[0].get_with_shape<xpu, 2, DType>(Shape2(outputs[0].shape_.Size(), 1), s);
+  mxnet_op::Kernel<range_fwd, xpu>::Launch(s, batch_size, 1, 0, element_num, kWriteTo,
+                                           batch_shift.dptr_);
   if (do_transpose) {
-    Tensor<xpu, 1, real_t> indices = inputs[2].FlatTo1D<xpu, real_t>(s);
+    Tensor<xpu, 1, IDType> indices = inputs[2].FlatTo1D<xpu, IDType>(s);
     TShape src_shape = outputs[0].shape_.FlatTo3D(axis);
     sel_indices = reshape(transpose(
@@ -589,26 +626,26 @@ void TopKBackward_(const nnvm::NodeAttrs& attrs,
                                          TShape(Shape3(src_shape[0], src_shape[2], k))),
                             Shape3(0, 2, 1)),
                           Shape1(batch_size * k));
-    sel_indices += indices;
+    sel_indices += tcast<int>(indices);
     sel_indices = transpose_indices(sel_indices, Shape3(src_shape[0], src_shape[2], src_shape[1]),
                                     Shape3(0, 2, 1));
   } else {
-    Tensor<xpu, 2, real_t> indices =
-      inputs[2].get_with_shape<xpu, 2, real_t>(Shape2(batch_size, k), s);
-    sel_indices = reshape(indices +
+    Tensor<xpu, 2, IDType> indices =
+      inputs[2].get_with_shape<xpu, 2, IDType>(Shape2(batch_size, k), s);
+    sel_indices = reshape(tcast<int>(indices) +
                           broadcast_to(inplace_reshape(batch_shift, Shape2(batch_size, 1)),
                                        TShape(Shape2(batch_size, k))),
                           Shape1(batch_size * k));
   CHECK_EQ(sel_indices.CheckContiguous(), true);
   if (kWriteTo == req[0]) {
-    in_grad = scalar<real_t>(0);
+    in_grad = scalar<DType>(0);
     IndexFill(in_grad, sel_indices, out_grad);
   } else if (kAddTo == req[0]) {
     // TODO(sxjscience) We can use AddTakeGrad in the future.
     // However, the current implementation of AddTakeGrad is not so efficient.
-    mxnet_op::Kernel<range_fwd, xpu>::Launch(s, sel_indices.shape_.Size(), 1, 0.0f,
-      1.0f, kWriteTo, dummy_index.dptr_);
+    mxnet_op::Kernel<range_fwd, xpu>::Launch(s, sel_indices.shape_.Size(), 1, 0, 1, kWriteTo,
+                                             dummy_index.dptr_);
     mxnet::op::AddTakeGradLargeBatch(in_grad, sel_indices, dummy_index, out_grad);
   } else if (kNullOp == req[0]) {
@@ -617,6 +654,28 @@ void TopKBackward_(const nnvm::NodeAttrs& attrs,
+template<typename xpu>
+void TopKBackward_(const nnvm::NodeAttrs& attrs,
+                   const OpContext& ctx,
+                   const std::vector<TBlob>& inputs,
+                   const std::vector<OpReqType>& req,
+                   const std::vector<TBlob>& outputs) {
+  const TopKParam& param = nnvm::get<TopKParam>(attrs.parsed);
+  if (param.ret_typ == topk_enum::kReturnBoth) {
+    MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+      MSHADOW_TYPE_SWITCH(param.dtype, IDType, {
+        TopKBackwardImpl<xpu, DType, IDType>(ctx, inputs, req, outputs, param);
+      });
+    });
+  } else if (param.ret_typ == topk_enum::kReturnValue) {
+    MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+      TopKBackwardImpl<xpu, DType, int>(ctx, inputs, req, outputs, param);
+    });
+  } else {
+    LOG(FATAL) << "Not Implemented";
+  }
 inline uint32_t TopKNumOutputs(const NodeAttrs& attrs) {
   const TopKParam& param = nnvm::get<TopKParam>(attrs.parsed);
   if (param.ret_typ == topk_enum::kReturnIndices ||
@@ -639,8 +698,36 @@ inline uint32_t TopKNumVisibleOutputs(const NodeAttrs& attrs) {
 inline bool TopKType(const nnvm::NodeAttrs& attrs,
                      std::vector<int> *in_attrs,
                      std::vector<int> *out_attrs) {
-  return ElemwiseAttr<int, type_is_none, type_assign, true, type_string>(
-    attrs, in_attrs, out_attrs, -1);
+  const TopKParam& param = nnvm::get<TopKParam>(attrs.parsed);
+  int data_type = -1;
+  size_t in_size = in_attrs->size();
+  size_t out_size = out_attrs->size();
+  CHECK_EQ(in_size, 1);
+  CHECK(out_size == 1 || out_size == 2);
+  if (out_size > 1) {
+    if (param.ret_typ == topk_enum::kReturnValue) {
+      CHECK(type_assign(&(*out_attrs)[1], mshadow::kInt32))
+        << "Failed to set the type of ret_indices.";
+    } else {
+      CHECK(type_assign(&(*out_attrs)[1], param.dtype))
+        << "Failed to set the type of ret_indices.";
+    }
+  }
+  if (param.ret_typ == topk_enum::kReturnIndices) {
+    CHECK(type_assign(&(*out_attrs)[0], param.dtype))
+            << "Failed to set the type of ret_indices.";
+  } else {
+    CHECK(type_assign(&data_type, (*in_attrs)[0])) << "Incompatible dtype of input, in_attrs[0]="
+                                                   << (*in_attrs)[0];
+    CHECK(type_assign(&data_type, (*out_attrs)[0])) << "Incompatible dtype of output, out_attrs[0]="
+                                                    << (*out_attrs)[0];
+    CHECK(type_assign(&(*in_attrs)[0], data_type)) << "Incompatible dtype of input, in_attrs[0]="
+                                                   << (*in_attrs)[0];
+    CHECK(type_assign(&(*out_attrs)[0], data_type)) << "Incompatible dtype of output, out_attrs[0]="
+                                                    << (*out_attrs)[0];
+    if (data_type == -1) return false;
+  }
+  return true;
 inline bool TopKShapeImpl(const TopKParam& param,
@@ -679,6 +766,28 @@ inline bool TopKShape(const nnvm::NodeAttrs& attrs,
   return TopKShapeImpl(param, in_attrs, out_attrs);
+inline bool SortType(const nnvm::NodeAttrs& attrs,
+                     std::vector<int> *in_attrs,
+                     std::vector<int> *out_attrs) {
+  int data_type = -1;
+  size_t in_size = in_attrs->size();
+  size_t out_size = out_attrs->size();
+  CHECK_EQ(in_size, 1);
+  CHECK_EQ(out_size, 2);
+  CHECK(type_assign(&(*out_attrs)[1], mshadow::kInt32))
+          << "Failed to set the type of ret_indices to int32.";
+  CHECK(type_assign(&data_type, (*in_attrs)[0])) << "Incompatible dtype of input, in_attrs[0]="
+                                                 << (*in_attrs)[0];
+  CHECK(type_assign(&data_type, (*out_attrs)[0])) << "Incompatible dtype of output, out_attrs[0]="
+                                                  << (*out_attrs)[0];
+  CHECK(type_assign(&(*in_attrs)[0], data_type)) << "Incompatible dtype of input, in_attrs[0]="
+                                                 << (*in_attrs)[0];
+  CHECK(type_assign(&(*out_attrs)[0], data_type)) << "Incompatible dtype of output, out_attrs[0]="
+                                                  << (*out_attrs)[0];
+  if (data_type == -1) return false;
+  return true;
 inline bool SortShape(const nnvm::NodeAttrs& attrs,
                       std::vector<TShape> *in_attrs,
                       std::vector<TShape> *out_attrs) {
@@ -691,6 +800,15 @@ inline bool SortShape(const nnvm::NodeAttrs& attrs,
   return TopKShapeImpl(topk_param, in_attrs, out_attrs);
+// FInferType for argsort. The single output holds sort indices whose dtype is
+// taken from ArgSortParam::dtype (user-selectable), not from the input; the
+// input dtype is not constrained by this function.
+inline bool ArgSortType(const nnvm::NodeAttrs& attrs,
+                        std::vector<int> *in_attrs,
+                        std::vector<int> *out_attrs) {
+  const ArgSortParam& param = nnvm::get<ArgSortParam>(attrs.parsed);
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  // Report the dtype that was actually requested; it is not necessarily int32.
+  CHECK(type_assign(&(*out_attrs)[0], param.dtype))
+          << "Failed to set the type of ret_indices to dtype=" << param.dtype
+          << ", out_attrs[0]=" << (*out_attrs)[0];
+  return true;
 inline bool ArgSortShape(const nnvm::NodeAttrs& attrs,
                          std::vector<TShape> *in_attrs,
                          std::vector<TShape> *out_attrs) {
diff --git a/src/operator/tensor/ b/src/operator/tensor/
index ebd7c62ec88..1e2832d3763 100644
--- a/src/operator/tensor/
+++ b/src/operator/tensor/
@@ -35,6 +35,7 @@ DMLC_REGISTER_PARAMETER(ArgSortParam);
 .describe(R"code(Returns the top *k* elements in an input array along the given axis.
+ The returned elements will be sorted.
@@ -128,7 +129,7 @@ Examples::
 .set_attr<nnvm::FInferShape>("FInferShape", SortShape)
-.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 2>)
+.set_attr<nnvm::FInferType>("FInferType", SortType)
 .set_attr<nnvm::FNumVisibleOutputs>("FNumVisibleOutputs", [](const NodeAttrs& attrs) { return 1; })
 .set_attr<FCompute>("FCompute<cpu>", Sort<cpu>)
@@ -178,7 +179,7 @@ Examples::
 .set_attr<nnvm::FInferShape>("FInferShape", ArgSortShape)
-.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<nnvm::FInferType>("FInferType", ArgSortType)
 .set_attr<FCompute>("FCompute<cpu>", ArgSort<cpu>)
 .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
diff --git a/tests/python/unittest/ b/tests/python/unittest/
index 2db39d5dd53..c9bc0cd1e1e 100644
--- a/tests/python/unittest/
+++ b/tests/python/unittest/
@@ -661,17 +661,19 @@ def gt_topk(dat, axis, ret_typ, k, is_ascend):
     # values, making it hard to generate a numpy 'golden copy' to compare against
     # the mxnet operator.  The 'mask' function is particularly hard to test given that
     # equal values might span the 'k' boundary.  Issue exposed with seed 1405838964.
-    def get_values(ensure_unique):
-        while True:
-            data = np.float32(np.random.normal(size=(dat_size, dat_size, dat_size, dat_size)))
-            if not ensure_unique:
-                return data
-            num_unique_values = len(set(data.flatten()))
-            if data.size == num_unique_values:
-                return data
-    a_npy = get_values(ensure_unique=True)
-    a_nd = mx.nd.array(a_npy, ctx=ctx)
+    def get_values(ensure_unique, dtype):
+        if dtype == np.int16 or dtype == np.int32 or dtype == np.int64:
+            return np.arange(dat_size ** 4, dtype=dtype).reshape((dat_size, dat_size, dat_size, dat_size))
+        elif dtype == np.float32 or dtype == np.float64:
+            while True:
+                data = np.random.normal(size=(dat_size, dat_size, dat_size, dat_size)).astype(dtype)
+                if not ensure_unique:
+                    return data
+                num_unique_values = len(set(data.flatten()))
+                if data.size == num_unique_values:
+                    return data
+        else:
+            raise NotImplementedError
     # Produce a large matrix (256, 300096) as the input data, to cover the case which
     # has a large size of matrix (exceed the express range by float precisly), but
@@ -685,103 +687,161 @@ def get_large_matrix():
     large_matrix_npy = get_large_matrix()
     large_matrix_nd = mx.nd.array(large_matrix_npy, ctx=ctx)
-    # test for ret_typ=indices
-    nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="indices", k=3, is_ascend=True).asnumpy()
-    gt = gt_topk(a_npy, axis=1, ret_typ="indices", k=3, is_ascend=True)
-    assert_almost_equal(nd_ret_topk, gt)
-    nd_ret_topk = mx.nd.topk(a_nd, axis=3, ret_typ="indices", k=2, is_ascend=False).asnumpy()
-    gt = gt_topk(a_npy, axis=3, ret_typ="indices", k=2, is_ascend=False)
-    assert_almost_equal(nd_ret_topk, gt)
-    nd_ret_topk = mx.nd.topk(a_nd, axis=None, ret_typ="indices", k=21, is_ascend=False).asnumpy()
-    gt = gt_topk(a_npy, axis=None, ret_typ="indices", k=21, is_ascend=False)
-    assert_almost_equal(nd_ret_topk, gt)
     nd_ret_topk = mx.nd.topk(large_matrix_nd, axis=1, ret_typ="indices", k=5, is_ascend=False).asnumpy()
     gt = gt_topk(large_matrix_npy, axis=1, ret_typ="indices", k=5, is_ascend=False)
     assert_almost_equal(nd_ret_topk, gt)
-    # test for ret_typ=value
-    nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="value", k=3, is_ascend=True).asnumpy()
-    gt = gt_topk(a_npy, axis=1, ret_typ="value", k=3, is_ascend=True)
-    assert_almost_equal(nd_ret_topk, gt)
-    nd_ret_topk = mx.nd.topk(a_nd, axis=3, ret_typ="value", k=2, is_ascend=False).asnumpy()
-    gt = gt_topk(a_npy, axis=3, ret_typ="value", k=2, is_ascend=False)
-    assert_almost_equal(nd_ret_topk, gt)
-    nd_ret_topk = mx.nd.topk(a_nd, axis=None, ret_typ="value", k=21, is_ascend=False).asnumpy()
-    gt = gt_topk(a_npy, axis=None, ret_typ="value", k=21, is_ascend=False)
-    assert_almost_equal(nd_ret_topk, gt)
-    nd_ret_topk = mx.nd.topk(large_matrix_nd, axis=0, ret_typ="value", k=3, is_ascend=False).asnumpy()
-    gt = gt_topk(large_matrix_npy, axis=0, ret_typ="value", k=3, is_ascend=False)
-    assert_almost_equal(nd_ret_topk, gt)
-    nd_ret_topk = mx.nd.topk(large_matrix_nd, axis=1, ret_typ="value", k=5, is_ascend=False).asnumpy()
-    gt = gt_topk(large_matrix_npy, axis=1, ret_typ="value", k=5, is_ascend=False)
-    assert_almost_equal(nd_ret_topk, gt)
-    # test for ret_typ=mask
-    nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="mask", k=3, is_ascend=True).asnumpy()
-    gt = gt_topk(a_npy, axis=1, ret_typ="mask", k=3, is_ascend=True)
-    assert_almost_equal(nd_ret_topk, gt)
-    nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="mask", k=2, is_ascend=False).asnumpy()
-    gt = gt_topk(a_npy, axis=1, ret_typ="mask", k=2, is_ascend=False)
-    assert_almost_equal(nd_ret_topk, gt)
-    nd_ret_topk = mx.nd.topk(a_nd, axis=None, ret_typ="mask", k=21, is_ascend=False).asnumpy()
-    gt = gt_topk(a_npy, axis=None, ret_typ="mask", k=21, is_ascend=False)
-    assert_almost_equal(nd_ret_topk, gt)
-    # test for ret_typ=both
-    nd_ret_topk_val, nd_ret_topk_ind = mx.nd.topk(a_nd, axis=1, ret_typ="both", k=3, is_ascend=True)
-    nd_ret_topk_val = nd_ret_topk_val.asnumpy()
-    nd_ret_topk_ind = nd_ret_topk_ind.asnumpy()
-    gt_val = gt_topk(a_npy, axis=1, ret_typ="value", k=3, is_ascend=True)
-    gt_ind = gt_topk(a_npy, axis=1, ret_typ="indices", k=3, is_ascend=True)
-    assert_almost_equal(nd_ret_topk_val, gt_val)
-    assert_almost_equal(nd_ret_topk_ind, gt_ind)
-    # test for sort
-    nd_ret_sort = mx.nd.sort(a_nd, axis=1, is_ascend=True).asnumpy()
-    gt = gt_topk(a_npy, axis=1, ret_typ="value", k=dat_size, is_ascend=True)
-    assert_almost_equal(nd_ret_sort, gt)
-    nd_ret_sort = mx.nd.sort(a_nd, axis=None, is_ascend=False).asnumpy()
-    gt = gt_topk(a_npy, axis=None, ret_typ="value",
-                 k=dat_size*dat_size*dat_size*dat_size, is_ascend=False)
-    assert_almost_equal(nd_ret_sort, gt)
-    # test for argsort
-    nd_ret_argsort = mx.nd.argsort(a_nd, axis=3, is_ascend=True).asnumpy()
-    gt = gt_topk(a_npy, axis=3, ret_typ="indices", k=dat_size, is_ascend=True)
-    assert_almost_equal(nd_ret_argsort, gt)
-    nd_ret_argsort = mx.nd.argsort(a_nd, axis=None, is_ascend=False).asnumpy()
-    gt = gt_topk(a_npy, axis=None, ret_typ="indices",
-                 k=dat_size*dat_size*dat_size*dat_size, is_ascend=False)
-    assert_almost_equal(nd_ret_argsort, gt)
-    a = mx.nd.arange(0, 1024, step=1, repeat=1)
-    assert_almost_equal(a.topk(k=1024).asnumpy(), a.asnumpy()[::-1])
+    for dtype in [np.int16, np.int32, np.int64, np.float32, np.float64]:
+        a_npy = get_values(ensure_unique=True, dtype=dtype)
+        a_nd = mx.nd.array(a_npy, ctx=ctx)
+        # test for ret_typ=indices
+        nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="indices", k=3, is_ascend=True).asnumpy()
+        gt = gt_topk(a_npy, axis=1, ret_typ="indices", k=3, is_ascend=True)
+        assert_almost_equal(nd_ret_topk, gt)
+        nd_ret_topk = mx.nd.topk(a_nd, axis=3, ret_typ="indices", k=2, is_ascend=False).asnumpy()
+        gt = gt_topk(a_npy, axis=3, ret_typ="indices", k=2, is_ascend=False)
+        assert_almost_equal(nd_ret_topk, gt)
+        nd_ret_topk = mx.nd.topk(a_nd, axis=None, ret_typ="indices", k=21, is_ascend=False).asnumpy()
+        gt = gt_topk(a_npy, axis=None, ret_typ="indices", k=21, is_ascend=False)
+        assert_almost_equal(nd_ret_topk, gt)
+        # test for ret_typ=value
+        nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="value", k=3, is_ascend=True).asnumpy()
+        gt = gt_topk(a_npy, axis=1, ret_typ="value", k=3, is_ascend=True)
+        assert_almost_equal(nd_ret_topk, gt)
+        nd_ret_topk = mx.nd.topk(a_nd, axis=3, ret_typ="value", k=2, is_ascend=False).asnumpy()
+        gt = gt_topk(a_npy, axis=3, ret_typ="value", k=2, is_ascend=False)
+        assert_almost_equal(nd_ret_topk, gt)
+        nd_ret_topk = mx.nd.topk(a_nd, axis=None, ret_typ="value", k=21, is_ascend=False).asnumpy()
+        gt = gt_topk(a_npy, axis=None, ret_typ="value", k=21, is_ascend=False)
+        assert_almost_equal(nd_ret_topk, gt)
+        # test for ret_typ=mask
+        nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="mask", k=3, is_ascend=True).asnumpy()
+        gt = gt_topk(a_npy, axis=1, ret_typ="mask", k=3, is_ascend=True)
+        assert_almost_equal(nd_ret_topk, gt)
+        nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="mask", k=2, is_ascend=False).asnumpy()
+        gt = gt_topk(a_npy, axis=1, ret_typ="mask", k=2, is_ascend=False)
+        assert_almost_equal(nd_ret_topk, gt)
+        nd_ret_topk = mx.nd.topk(a_nd, axis=None, ret_typ="mask", k=21, is_ascend=False).asnumpy()
+        gt = gt_topk(a_npy, axis=None, ret_typ="mask", k=21, is_ascend=False)
+        assert_almost_equal(nd_ret_topk, gt)
+        # test for ret_typ=both
+        nd_ret_topk_val, nd_ret_topk_ind = mx.nd.topk(a_nd, axis=1, ret_typ="both", k=3, is_ascend=True)
+        nd_ret_topk_val = nd_ret_topk_val.asnumpy()
+        nd_ret_topk_ind = nd_ret_topk_ind.asnumpy()
+        gt_val = gt_topk(a_npy, axis=1, ret_typ="value", k=3, is_ascend=True)
+        gt_ind = gt_topk(a_npy, axis=1, ret_typ="indices", k=3, is_ascend=True)
+        assert_almost_equal(nd_ret_topk_val, gt_val)
+        assert_almost_equal(nd_ret_topk_ind, gt_ind)
+        # test for kNullOp
+        _, nd_ret_topk_ind = mx.nd.topk(a_nd, axis=1, ret_typ="both", k=3, is_ascend=True)
+        nd_ret_topk_ind = nd_ret_topk_ind.asnumpy()
+        assert_almost_equal(nd_ret_topk_ind, gt_ind)
+        # test for kNullOp
+        nd_ret_topk_val, _ = mx.nd.topk(a_nd, axis=1, ret_typ="both", k=3, is_ascend=True)
+        nd_ret_topk_val = nd_ret_topk_val.asnumpy()
+        assert_almost_equal(nd_ret_topk_val, gt_val)
+        # test for sort
+        nd_ret_sort = mx.nd.sort(a_nd, axis=1, is_ascend=True).asnumpy()
+        gt = gt_topk(a_npy, axis=1, ret_typ="value", k=dat_size, is_ascend=True)
+        assert_almost_equal(nd_ret_sort, gt)
+        nd_ret_sort = mx.nd.sort(a_nd, axis=None, is_ascend=False).asnumpy()
+        gt = gt_topk(a_npy, axis=None, ret_typ="value",
+                     k=dat_size*dat_size*dat_size*dat_size, is_ascend=False)
+        assert_almost_equal(nd_ret_sort, gt)
+        # test for argsort
+        for idtype in [np.int32, np.float16, np.float32, np.float64]:
+            nd_ret_argsort = mx.nd.argsort(a_nd, axis=3, is_ascend=True, dtype=idtype).asnumpy()
+            gt = gt_topk(a_npy, axis=3, ret_typ="indices", k=dat_size, is_ascend=True)
+            assert_almost_equal(nd_ret_argsort, gt)
+            nd_ret_argsort = mx.nd.argsort(a_nd, axis=None, is_ascend=False, dtype=idtype).asnumpy()
+            gt = gt_topk(a_npy, axis=None, ret_typ="indices",
+                         k=dat_size*dat_size*dat_size*dat_size, is_ascend=False)
+            assert_almost_equal(nd_ret_argsort, gt)
+        # Repeat those tests that don't involve indices.  These should pass even with
+        # duplicated input data values (over many repeated runs with different random seeds,
+        # this will be tested).
+        a_npy = get_values(ensure_unique=False, dtype=dtype)
+        a_nd = mx.nd.array(a_npy, ctx=ctx)
+        # test for ret_typ=value
+        nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="value", k=3, is_ascend=True).asnumpy()
+        gt = gt_topk(a_npy, axis=1, ret_typ="value", k=3, is_ascend=True)
+        assert_almost_equal(nd_ret_topk, gt)
+        nd_ret_topk = mx.nd.topk(a_nd, axis=3, ret_typ="value", k=2, is_ascend=False).asnumpy()
+        gt = gt_topk(a_npy, axis=3, ret_typ="value", k=2, is_ascend=False)
+        assert_almost_equal(nd_ret_topk, gt)
+        nd_ret_topk = mx.nd.topk(a_nd, axis=None, ret_typ="value", k=21, is_ascend=False).asnumpy()
+        gt = gt_topk(a_npy, axis=None, ret_typ="value", k=21, is_ascend=False)
+        assert_almost_equal(nd_ret_topk, gt)
+        # test for sort
+        nd_ret_sort = mx.nd.sort(a_nd, axis=1, is_ascend=True).asnumpy()
+        gt = gt_topk(a_npy, axis=1, ret_typ="value", k=dat_size, is_ascend=True)
+        assert_almost_equal(nd_ret_sort, gt)
+        nd_ret_sort = mx.nd.sort(a_nd, axis=None, is_ascend=False).asnumpy()
+        gt = gt_topk(a_npy, axis=None, ret_typ="value",
+                     k=dat_size*dat_size*dat_size*dat_size, is_ascend=False)
+        assert_almost_equal(nd_ret_sort, gt)
+    a = mx.nd.arange(0, 1024, step=1, repeat=1, dtype=np.int32)
+    assert_almost_equal(a.topk(k=1024, dtype=np.int32).asnumpy(), a.asnumpy()[::-1])
+    a.attach_grad()
+    k = 10
+    with mx.autograd.record():
+        b = mx.nd.topk(a, k=k, ret_typ='value')
+        b.backward(mx.nd.ones((k,), dtype=np.int32))
+    a_grad = a.grad.asnumpy()
+    for i in range(-1, - k - 1, -1):
+        assert a_grad[i] == 1
+    # test topk gradient with a small shape
+    for dtype in [np.int32, np.int64, np.float32, np.float64]:
+        a = mx.nd.arange(0, 1000, step=1, repeat=1, dtype=dtype)
+        a.attach_grad()
+        k = 10
+        ograd = mx.nd.arange(0, k, dtype=dtype)
+        with mx.autograd.record():
+            b = mx.nd.topk(a, k=k, ret_typ='value')
+            b.backward(ograd)
+        a_grad = a.grad.asnumpy()
+        ograd_npy = ograd.asnumpy()
+        for i in range(-1, - k - 1, -1):
+            assert a_grad[i] == ograd_npy[-i - 1]
     # Repeat those tests that don't involve indices.  These should pass even with
     # duplicated input data values (over many repeated runs with different random seeds,
     # this will be tested).
-    a_npy = get_values(ensure_unique=False)
-    a_nd = mx.nd.array(a_npy, ctx=ctx)
-    # test for ret_typ=value
-    nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="value", k=3, is_ascend=True).asnumpy()
-    gt = gt_topk(a_npy, axis=1, ret_typ="value", k=3, is_ascend=True)
-    assert_almost_equal(nd_ret_topk, gt)
-    nd_ret_topk = mx.nd.topk(a_nd, axis=3, ret_typ="value", k=2, is_ascend=False).asnumpy()
-    gt = gt_topk(a_npy, axis=3, ret_typ="value", k=2, is_ascend=False)
-    assert_almost_equal(nd_ret_topk, gt)
-    nd_ret_topk = mx.nd.topk(a_nd, axis=None, ret_typ="value", k=21, is_ascend=False).asnumpy()
-    gt = gt_topk(a_npy, axis=None, ret_typ="value", k=21, is_ascend=False)
-    assert_almost_equal(nd_ret_topk, gt)
-    # test for sort
-    nd_ret_sort = mx.nd.sort(a_nd, axis=1, is_ascend=True).asnumpy()
-    gt = gt_topk(a_npy, axis=1, ret_typ="value", k=dat_size, is_ascend=True)
-    assert_almost_equal(nd_ret_sort, gt)
-    nd_ret_sort = mx.nd.sort(a_nd, axis=None, is_ascend=False).asnumpy()
-    gt = gt_topk(a_npy, axis=None, ret_typ="value",
-                 k=dat_size*dat_size*dat_size*dat_size, is_ascend=False)
-    assert_almost_equal(nd_ret_sort, gt)
+    for dtype in [np.int16, np.int32, np.int64, np.float32, np.float64]:
+        a_npy = get_values(ensure_unique=False, dtype=dtype)
+        a_nd = mx.nd.array(a_npy, ctx=ctx)
+        # test for ret_typ=value
+        nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="value", k=3, is_ascend=True).asnumpy()
+        gt = gt_topk(a_npy, axis=1, ret_typ="value", k=3, is_ascend=True)
+        assert_almost_equal(nd_ret_topk, gt)
+        nd_ret_topk = mx.nd.topk(a_nd, axis=3, ret_typ="value", k=2, is_ascend=False).asnumpy()
+        gt = gt_topk(a_npy, axis=3, ret_typ="value", k=2, is_ascend=False)
+        assert_almost_equal(nd_ret_topk, gt)
+        nd_ret_topk = mx.nd.topk(a_nd, axis=None, ret_typ="value", k=21, is_ascend=False).asnumpy()
+        gt = gt_topk(a_npy, axis=None, ret_typ="value", k=21, is_ascend=False)
+        assert_almost_equal(nd_ret_topk, gt)
+        # test for sort
+        nd_ret_sort = mx.nd.sort(a_nd, axis=1, is_ascend=True).asnumpy()
+        gt = gt_topk(a_npy, axis=1, ret_typ="value", k=dat_size, is_ascend=True)
+        assert_almost_equal(nd_ret_sort, gt)
+        nd_ret_sort = mx.nd.sort(a_nd, axis=None, is_ascend=False).asnumpy()
+        gt = gt_topk(a_npy, axis=None, ret_typ="value",
+                     k=dat_size*dat_size*dat_size*dat_size, is_ascend=False)
+        assert_almost_equal(nd_ret_sort, gt)
 def test_ndarray_equal():


