Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/08/22 23:35:27 UTC

[GitHub] szha closed pull request #12250: [MXNET-507] Set arbitrary dtype for ret_indices in ordering ops
URL: https://github.com/apache/incubator-mxnet/pull/12250

This is a PR merged from a forked repository. As GitHub hides the
original diff on merge, it is displayed below for the sake of provenance:

diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h
index c3f6dc6558e..f11a497c564 100644
--- a/src/operator/mxnet_op.h
+++ b/src/operator/mxnet_op.h
@@ -176,6 +176,52 @@ inline int get_num_threads<cpu>(const int N) {
     LOG(FATAL) << "Unknown type enum " << type;            \
   }
 
+#define MXNET_NO_FLOAT16_TYPE_SWITCH(type, DType, ...)     \
+  switch (type) {                                          \
+  case mshadow::kFloat32:                                  \
+    {                                                      \
+      typedef float DType;                                 \
+      {__VA_ARGS__}                                        \
+    }                                                      \
+    break;                                                 \
+  case mshadow::kFloat64:                                  \
+    {                                                      \
+      typedef double DType;                                \
+      {__VA_ARGS__}                                        \
+    }                                                      \
+    break;                                                 \
+  case mshadow::kFloat16:                                  \
+    LOG(FATAL) << "This operation does not "               \
+                  "support float16";                       \
+    break;                                                 \
+  case mshadow::kUint8:                                    \
+    {                                                      \
+      typedef uint8_t DType;                               \
+      {__VA_ARGS__}                                        \
+    }                                                      \
+    break;                                                 \
+  case mshadow::kInt8:                                     \
+    {                                                      \
+      typedef int8_t DType;                                \
+      {__VA_ARGS__}                                        \
+    }                                                      \
+    break;                                                 \
+  case mshadow::kInt32:                                    \
+    {                                                      \
+      typedef int32_t DType;                               \
+      {__VA_ARGS__}                                        \
+    }                                                      \
+    break;                                                 \
+  case mshadow::kInt64:                                    \
+    {                                                      \
+      typedef int64_t DType;                               \
+      {__VA_ARGS__}                                        \
+    }                                                      \
+    break;                                                 \
+  default:                                                 \
+    LOG(FATAL) << "Unknown type enum " << type;            \
+  }
+
 
 /*!
  * \brief assign the val to out according
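For context: the new MXNET_NO_FLOAT16_TYPE_SWITCH mirrors MSHADOW_TYPE_SWITCH but turns the float16 case into a hard failure, so every ordering op dispatched through it rejects half-precision input. A minimal Python sketch of the user-visible effect, assuming an MXNet build that includes this patch:

    import numpy as np
    import mxnet as mx

    x = mx.nd.array([3, 1, 2], dtype=np.float16)
    try:
        mx.nd.sort(x).asnumpy()  # asnumpy() forces execution, hitting the float16 case
    except mx.base.MXNetError as err:
        print('rejected as expected:', err)  # "This operation does not support float16"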
diff --git a/src/operator/tensor/ordering_op-inl.h b/src/operator/tensor/ordering_op-inl.h
index a6f638e2932..c1a5b89db09 100644
--- a/src/operator/tensor/ordering_op-inl.h
+++ b/src/operator/tensor/ordering_op-inl.h
@@ -58,6 +58,7 @@ struct TopKParam : public dmlc::Parameter<TopKParam> {
   int k;
   int ret_typ;
   bool is_ascend;
+  int dtype;
   DMLC_DECLARE_PARAMETER(TopKParam) {
     DMLC_DECLARE_FIELD(axis).set_default(dmlc::optional<int>(-1))
     .describe("Axis along which to choose the top k indices."
@@ -79,6 +80,16 @@ struct TopKParam : public dmlc::Parameter<TopKParam> {
     DMLC_DECLARE_FIELD(is_ascend).set_default(false)
       .describe("Whether to choose k largest or k smallest elements."
                 " Top K largest elements will be chosen if set to false.");
+    DMLC_DECLARE_FIELD(dtype)
+    .add_enum("uint8", mshadow::kUint8)
+    .add_enum("int32", mshadow::kInt32)
+    .add_enum("float16", mshadow::kFloat16)
+    .add_enum("float32", mshadow::kFloat32)
+    .add_enum("float64", mshadow::kFloat64)
+    .set_default(mshadow::kFloat32)
+    .describe("DType of the output indices when ret_typ is \"indices\" or \"both\". "
+              "An error will be raised if the selected data type cannot precisely represent the "
+              "indices.");
   }
 };
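The new `dtype` field lets callers pick the storage type of the indices output. A short sketch of the intended front-end behaviour (parameter names and defaults as declared above; assuming this patch):

    import numpy as np
    import mxnet as mx

    x = mx.nd.array([[1, 3, 2], [6, 4, 5]])
    idx = mx.nd.topk(x, axis=1, k=2, ret_typ='indices', dtype='int32')
    print(idx.dtype)      # <class 'numpy.int32'> rather than the float32 default
    print(idx.asnumpy())  # [[1 2]
                          #  [0 2]]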
 
@@ -97,12 +108,23 @@ struct SortParam : public dmlc::Parameter<SortParam> {
 struct ArgSortParam : public dmlc::Parameter<ArgSortParam> {
   dmlc::optional<int> axis;
   bool is_ascend;
+  int dtype;
   DMLC_DECLARE_PARAMETER(ArgSortParam) {
     DMLC_DECLARE_FIELD(axis).set_default(dmlc::optional<int>(-1))
     .describe("Axis along which to sort the input tensor."
               " If not given, the flattened array is used. Default is -1.");
     DMLC_DECLARE_FIELD(is_ascend).set_default(true)
       .describe("Whether to sort in ascending or descending order.");
+    DMLC_DECLARE_FIELD(dtype)
+    .add_enum("uint8", mshadow::kUint8)
+    .add_enum("int32", mshadow::kInt32)
+    .add_enum("float16", mshadow::kFloat16)
+    .add_enum("float32", mshadow::kFloat32)
+    .add_enum("float64", mshadow::kFloat64)
+    .set_default(mshadow::kFloat32)
+    .describe("DType of the output indices. An error will be raised if the selected "
+              "data type cannot precisely represent the indices.");
   }
 };
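The same knob on argsort, which previously always produced float32 indices; the default stays float32 for backward compatibility (again a sketch assuming this patch):

    import mxnet as mx

    x = mx.nd.array([0.3, 0.1, 0.2])
    print(mx.nd.argsort(x).dtype)                 # float32 -- unchanged default
    print(mx.nd.argsort(x, dtype='int32').dtype)  # int32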
 
@@ -154,29 +176,22 @@ inline void ParseTopKParam(const TShape& src_shape, const TopKParam& param, TSha
 
 using namespace mshadow;
 
-template<typename xpu>
-void TopKSort(const Tensor<xpu, 1, real_t>& dat,
-              const Tensor<xpu, 1, int>& ind,
-              const Tensor<xpu, 1, char>& work,
-              int K, int N, bool is_ascend,
-              Stream<xpu> *s);
-
-template<>
-MSHADOW_FORCE_INLINE void TopKSort<cpu>(const Tensor<cpu, 1, real_t>& dat,
-                                        const Tensor<cpu, 1, int>& ind,
-                                        const Tensor<cpu, 1, char>& work,
-                                        int K, int N, bool is_ascend,
-                                        Stream<cpu> *s) {
+template<typename DType>
+MSHADOW_FORCE_INLINE void TopKSort(const Tensor<cpu, 1, DType>& dat,
+                                   const Tensor<cpu, 1, int>& ind,
+                                   const Tensor<cpu, 1, char>& work,
+                                   int K, int N, bool is_ascend,
+                                   Stream<cpu> *s) {
   // Use full sort when K is relatively large.
   const bool full_sort(K*8 > N);
   // Batch size.
-  const int M(work.size(0)/(sizeof(real_t)*N));
+  const int M(work.size(0)/(sizeof(DType)*N));
   const int omp_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount());
   #pragma omp parallel for num_threads(omp_threads)
   for (int i = 0; i < M; ++i) {
     // Tensor `work` stores the flattened source data, while `dat` stores the sorted result.
-    real_t *vals = reinterpret_cast<real_t*>(work.dptr_);
-    real_t *sorted_vals = dat.dptr_+i*N;
+    DType *vals = reinterpret_cast<DType*>(work.dptr_);
+    DType *sorted_vals = dat.dptr_+i*N;
     int *indices = ind.dptr_+i*N;
     if (is_ascend) {
       if (full_sort) {
@@ -285,12 +300,12 @@ __global__ void PartialSortSmallK(int K, int N, DType *val, int *ind, bool is_as
   }
 }
 
-template<>
-MSHADOW_FORCE_INLINE void TopKSort<gpu>(const Tensor<gpu, 1, real_t>& dat,
-                                        const Tensor<gpu, 1, int>& ind,
-                                        const Tensor<gpu, 1, char>& work,
-                                        int K, int N, bool is_ascend,
-                                        Stream<gpu> *s) {
+template<typename DType>
+MSHADOW_FORCE_INLINE void TopKSort(const Tensor<gpu, 1, DType>& dat,
+                                   const Tensor<gpu, 1, int>& ind,
+                                   const Tensor<gpu, 1, char>& work,
+                                   int K, int N, bool is_ascend,
+                                   Stream<gpu> *s) {
   // Use full sort for all but very small K for which we
   // can do a partial sort entirely within shared memory.
   const bool full_sort(K > 5);
@@ -311,7 +326,7 @@ MSHADOW_FORCE_INLINE void TopKSort<gpu>(const Tensor<gpu, 1, real_t>& dat,
     }
   } else {
     const int nthreads(mshadow::cuda::kBaseThreadNum);
-    PartialSortSmallK<<<M, nthreads, nthreads*K*(sizeof(int)+sizeof(real_t)),
+    PartialSortSmallK<<<M, nthreads, nthreads*K*(sizeof(int)+sizeof(DType)),
                         mshadow::Stream<gpu>::GetStream(s)>>>
                         (K, N, dat.dptr_, ind.dptr_, is_ascend);
   }
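Both TopKSort overloads choose between a full sort and a partial selection: the CPU path falls back to a full sort once K*8 > N, while the GPU path uses the shared-memory PartialSortSmallK kernel only for K <= 5. The numpy sketch below mirrors the CPU heuristic purely for illustration; it is not the kernel's algorithm:

    import numpy as np

    def topk_indices(a, k, is_ascend=False):
        # Full sort when K is a sizable fraction of N (mirrors `K*8 > N` above),
        # otherwise select the k extremes first and sort only those.
        n = a.shape[-1]
        key = a if is_ascend else -a
        if k * 8 > n:
            return np.argsort(key)[..., :k]
        part = np.argpartition(key, k - 1)[..., :k]            # unordered top-k
        order = np.argsort(np.take_along_axis(key, part, -1))  # order those k
        return np.take_along_axis(part, order, -1)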
@@ -331,25 +346,25 @@ MSHADOW_FORCE_INLINE void TopKSort<gpu>(const Tensor<gpu, 1, real_t>& dat,
    * \param k the K elements to keep
    * \param param the topk parameters
    * \tparam xpu the device type.
+   * \tparam DType type of the output value/mask.
+   * \tparam IDType type of the output indices.
    */
-template<typename xpu>
-void TopKImpl(RunContext ctx,
-              Resource resource,
+template<typename xpu, typename DType, typename IDType>
+void TopKImpl(const RunContext &ctx,
+              const Resource &resource,
+              const std::vector<OpReqType>& req,
               const TBlob& src,
               const std::vector<TBlob>& ret,
               const TopKParam& param) {
   using namespace mshadow;
   using namespace mshadow::expr;
-  for (auto ret_ele : ret) {
-    CHECK_EQ(ret_ele.type_flag_, src.type_flag_);
-  }
   // 1. Parse and initialize information
   Stream<xpu> *s = ctx.get_stream<xpu>();
   Tensor<xpu, 1, char> workspace;
   Tensor<xpu, 1, char> temp_workspace;
-  Tensor<xpu, 1, real_t> sorted_dat;
+  Tensor<xpu, 1, DType> sorted_dat;
   Tensor<xpu, 1, int> indices, sel_indices;
-  Tensor<xpu, 2, real_t> mask_val;
+  Tensor<xpu, 2, DType> mask_val;
   int batch_size, element_num;  // number of batches + the size of each batch
   int axis = 0;
   bool do_transpose = false;
@@ -358,25 +373,29 @@ void TopKImpl(RunContext ctx,
   TShape target_shape;
   ParseTopKParam(src.shape_, param,
                  &target_shape, &batch_size, &element_num, &axis, &k, &do_transpose, &is_ascend);
-  Tensor<xpu, 3, real_t> dat = src.FlatTo3D<xpu, real_t>(axis, axis, s);
+  CHECK_LE(element_num, mxnet::common::MaxIntegerValue<IDType>())
+    << "'IDType' does not have sufficient precision to represent the indices of the input array. "
+    << "The total element_num is " << element_num << ", but the selected IDType can only represent "
+    << mxnet::common::MaxIntegerValue<IDType>() << " elements";
+  Tensor<xpu, 3, DType> dat = src.FlatTo3D<xpu, DType>(axis, axis, s);
   size_t temp_size = 0;
   // Temp space needed by the gpu-based full sorts.
   temp_size = std::max(temp_size, mxnet::op::SortByKeyWorkspaceSize<int, int, xpu>(src.Size()));
-  temp_size = std::max(temp_size, mxnet::op::SortByKeyWorkspaceSize<int, real_t, xpu>(src.Size()));
-  temp_size = std::max(temp_size, mxnet::op::SortByKeyWorkspaceSize<real_t, int, xpu>(src.Size()));
+  temp_size = std::max(temp_size, mxnet::op::SortByKeyWorkspaceSize<int, DType, xpu>(src.Size()));
+  temp_size = std::max(temp_size, mxnet::op::SortByKeyWorkspaceSize<DType, int, xpu>(src.Size()));
   // Additional temp space for gpu full sorts for batch ids.
   temp_size += sizeof(int) * src.Size();
   // Temp space for cpu sorts.
-  temp_size = std::max(temp_size, sizeof(real_t) * src.Size());
-  size_t workspace_size = temp_size + sizeof(real_t) * src.Size() + sizeof(int) * src.Size();
+  temp_size = std::max(temp_size, sizeof(DType) * src.Size());
+  size_t workspace_size = temp_size + sizeof(DType) * src.Size() + sizeof(int) * src.Size();
   if (param.ret_typ == topk_enum::kReturnMask) {
-    workspace_size += sizeof(int) * batch_size * k + sizeof(real_t) * batch_size * k;
+    workspace_size += sizeof(int) * batch_size * k + sizeof(DType) * batch_size * k;
   }
   workspace = resource.get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
   char* workspace_curr_ptr = workspace.dptr_;
-  sorted_dat = Tensor<xpu, 1, real_t>(reinterpret_cast<real_t*>(workspace_curr_ptr),
+  sorted_dat = Tensor<xpu, 1, DType>(reinterpret_cast<DType*>(workspace_curr_ptr),
                                       Shape1(src.Size()), s);  // contain sorted dat
-  workspace_curr_ptr += sizeof(real_t) * src.Size();
+  workspace_curr_ptr += sizeof(DType) * src.Size();
   indices = Tensor<xpu, 1, int>(reinterpret_cast<int*>(workspace_curr_ptr),
                                 Shape1(src.Size()), s);  // indices in the original matrix
   workspace_curr_ptr += sizeof(int) * src.Size();
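The CHECK_LE guard above exists because an index stored in a floating-point IDType stays exact only up to the significand width: float32 represents consecutive integers only up to 2**24, float16 only up to 2**11. A quick numpy illustration:

    import numpy as np

    print(np.float32(2**24) == np.float32(2**24 + 1))  # True: adjacent indices collide
    print(np.float16(2048) == np.float16(2049))        # True: float16 is far tighter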
@@ -385,28 +404,28 @@ void TopKImpl(RunContext ctx,
     sel_indices = Tensor<xpu, 1, int>(reinterpret_cast<int*>(workspace_curr_ptr),
                                       Shape1(batch_size * k), s);
     workspace_curr_ptr += sizeof(int) * batch_size * k;
-    mask_val = Tensor<xpu, 2, real_t>(reinterpret_cast<real_t*>(workspace_curr_ptr),
+    mask_val = Tensor<xpu, 2, DType>(reinterpret_cast<DType*>(workspace_curr_ptr),
                                       Shape2(batch_size * k, 1), s);
-    workspace_curr_ptr += sizeof(real_t) * batch_size * k;
-    mask_val = scalar<real_t>(1);
+    workspace_curr_ptr += sizeof(DType) * batch_size * k;
+    mask_val = scalar<DType>(1);
     CHECK_EQ(sel_indices.CheckContiguous(), true);
     CHECK_EQ(mask_val.CheckContiguous(), true);
   }
 
   if (std::is_same<xpu, cpu>::value) {
-    Tensor<xpu, 1, real_t> flattened_data;
+    Tensor<xpu, 1, DType> flattened_data;
     if (do_transpose) {
-      flattened_data = Tensor<xpu, 1, real_t>(reinterpret_cast<real_t*>(workspace_curr_ptr),
+      flattened_data = Tensor<xpu, 1, DType>(reinterpret_cast<DType*>(workspace_curr_ptr),
                                               Shape1(src.Size()), s);
-      workspace_curr_ptr += sizeof(real_t) * src.Size();
+      workspace_curr_ptr += sizeof(DType) * src.Size();
       flattened_data = reshape(transpose(dat, Shape3(0, 2, 1)), Shape1(src.Size()));
       CHECK_EQ(flattened_data.CheckContiguous(), true);
     } else {
-      flattened_data = src.FlatTo1D<xpu, real_t>(s);
+      flattened_data = src.FlatTo1D<xpu, DType>(s);
     }
     // `temp_workspace` stores the flattened data
     temp_workspace = Tensor<xpu, 1, char>(reinterpret_cast<char*>(flattened_data.dptr_),
-                                          Shape1(sizeof(real_t)*src.Size()), s);
+                                          Shape1(sizeof(DType)*src.Size()), s);
     CHECK_EQ(temp_workspace.CheckContiguous(), true);
   } else {
     if (do_transpose) {
@@ -436,9 +455,9 @@ void TopKImpl(RunContext ctx,
   // Casting `ret_indices` from int to real_t could introduce conversion errors when element_num
   // is large enough.
   if (param.ret_typ == topk_enum::kReturnMask) {
-    Tensor<xpu, 2, real_t> ret_mask =
-      ret[0].get_with_shape<xpu, 2, real_t>(Shape2(ret[0].Size(), 1), s);
-    ret_mask = scalar<real_t>(0);
+    Tensor<xpu, 2, DType> ret_mask =
+      ret[0].get_with_shape<xpu, 2, DType>(Shape2(ret[0].Size(), 1), s);
+    ret_mask = scalar<DType>(0);
     sel_indices = reshape(slice<1>(
                               inplace_reshape(indices,
                                               Shape2(batch_size,
@@ -450,53 +469,53 @@ void TopKImpl(RunContext ctx,
       sel_indices = transpose_indices(sel_indices, Shape3(src_shape[0], src_shape[2], src_shape[1]),
                                       Shape3(0, 2, 1));
     }
-    IndexFill(ret_mask, sel_indices, mask_val);
+    if (req[0] == kNullOp) {
+      return;
+    } else if (req[0] == kWriteTo) {
+      IndexFill(ret_mask, sel_indices, mask_val);
+    } else {
+      LOG(FATAL) << "req=" << req[0] << " is not supported yet.";
+    }
   } else if (param.ret_typ == topk_enum::kReturnIndices) {
     if (do_transpose) {
-      Tensor<xpu, 3, real_t> ret_indices = ret[0].FlatTo3D<xpu, real_t>(axis, axis, s);
-      ret_indices = tcast<real_t>(F<mshadow_op::mod>(
-                      transpose(slice<2>(inplace_reshape(indices,
-                                                         Shape3(ret_indices.shape_[0],
-                                                                ret_indices.shape_[2],
-                                                                element_num)),
-                                         0, k),
-                                Shape3(0, 2, 1)),
-                      element_num));
-    } else {
-      Tensor<xpu, 2, real_t> ret_indices =
-        ret[0].get_with_shape<xpu, 2, real_t>(Shape2(batch_size, k), s);
-      ret_indices = tcast<real_t>(F<mshadow_op::mod>(
-                      slice<1>(inplace_reshape(indices, Shape2(batch_size, element_num)),
+      Tensor<xpu, 3, IDType> ret_indices = ret[0].FlatTo3D<xpu, IDType>(axis, axis, s);
+      ASSIGN_DISPATCH(ret_indices, req[0], tcast<IDType>(F<mshadow_op::mod>(transpose(
+                      slice<2>(inplace_reshape(indices,
+                                               Shape3(ret_indices.shape_[0],
+                                                      ret_indices.shape_[2],
+                                                      element_num)),
                                0, k),
-                      element_num));
+                      Shape3(0, 2, 1)), element_num)));
+    } else {
+      Tensor<xpu, 2, IDType> ret_indices =
+        ret[0].get_with_shape<xpu, 2, IDType>(Shape2(batch_size, k), s);
+      ASSIGN_DISPATCH(ret_indices, req[0], tcast<IDType>(F<mshadow_op::mod>(slice<1>(
+                      inplace_reshape(indices, Shape2(batch_size, element_num)), 0, k),
+                      element_num)));
     }
   } else {
     if (do_transpose) {
-      Tensor<xpu, 3, real_t> ret_value = ret[0].FlatTo3D<xpu, real_t>(axis, axis, s);
-      Tensor<xpu, 3, real_t> ret_indices = ret[1].FlatTo3D<xpu, real_t>(axis, axis, s);
-      ret_value = transpose(
+      Tensor<xpu, 3, DType> ret_value = ret[0].FlatTo3D<xpu, DType>(axis, axis, s);
+      Tensor<xpu, 3, IDType> ret_indices = ret[1].FlatTo3D<xpu, IDType>(axis, axis, s);
+      ASSIGN_DISPATCH(ret_value, req[0], transpose(
                    slice<2>(inplace_reshape(sorted_dat,
                                     Shape3(ret_value.shape_[0], ret_value.shape_[2], element_num)),
-                            0, k),
-                   Shape3(0, 2, 1));
-      ret_indices = tcast<real_t>(F<mshadow_op::mod>(
-                      transpose(slice<2>(inplace_reshape(indices,
-                                                         Shape3(ret_indices.shape_[0],
-                                                         ret_indices.shape_[2],
-                                                         element_num)),
-                                         0, k),
-                                Shape3(0, 2, 1)),
-                      element_num));
+                            0, k), Shape3(0, 2, 1)));
+      ASSIGN_DISPATCH(ret_indices, req[1], tcast<IDType>(F<mshadow_op::mod>(transpose(
+                      slice<2>(inplace_reshape(indices,
+                                               Shape3(ret_indices.shape_[0],
+                                                      ret_indices.shape_[2],
+                                                      element_num)),
+                               0, k), Shape3(0, 2, 1)), element_num)));
     } else {
-      Tensor<xpu, 2, real_t> ret_value =
-        ret[0].get_with_shape<xpu, 2, real_t>(Shape2(batch_size, k), s);
-      Tensor<xpu, 2, real_t> ret_indices =
-        ret[1].get_with_shape<xpu, 2, real_t>(Shape2(batch_size, k), s);
-      ret_value = slice<1>(inplace_reshape(sorted_dat, Shape2(batch_size, element_num)), 0, k);
-      ret_indices = tcast<real_t>(F<mshadow_op::mod>(
-                      slice<1>(inplace_reshape(indices, Shape2(batch_size, element_num)),
-                               0, k),
-                      element_num));
+      Tensor<xpu, 2, DType> ret_value =
+        ret[0].get_with_shape<xpu, 2, DType>(Shape2(batch_size, k), s);
+      Tensor<xpu, 2, IDType> ret_indices =
+        ret[1].get_with_shape<xpu, 2, IDType>(Shape2(batch_size, k), s);
+      ASSIGN_DISPATCH(ret_value, req[0],
+             slice<1>(inplace_reshape(sorted_dat, Shape2(batch_size, element_num)), 0, k));
+      ASSIGN_DISPATCH(ret_indices, req[1], tcast<IDType>(F<mshadow_op::mod>(slice<1>(
+                 inplace_reshape(indices, Shape2(batch_size, element_num)), 0, k), element_num)));
     }
   }
 }
@@ -508,9 +527,17 @@ void TopK(const nnvm::NodeAttrs& attrs,
           const std::vector<OpReqType>& req,
           const std::vector<TBlob>& outputs) {
   const TopKParam& param = nnvm::get<TopKParam>(attrs.parsed);
-  // TODO(sxjscience) We can support inplace in the future
-  CHECK_EQ(req[0], kWriteTo) << "TopK does not support inplace";
-  TopKImpl<xpu>(ctx.run_ctx, ctx.requested[0], inputs[0], outputs, param);
+  if (param.ret_typ == topk_enum::kReturnIndices || param.ret_typ == topk_enum::kReturnBoth) {
+    MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+      MSHADOW_TYPE_SWITCH(param.dtype, IDType, {
+        TopKImpl<xpu, DType, IDType>(ctx.run_ctx, ctx.requested[0], req, inputs[0], outputs, param);
+      })
+    });
+  } else {
+    MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+      TopKImpl<xpu, DType, int>(ctx.run_ctx, ctx.requested[0], req, inputs[0], outputs, param);
+    });
+  }
 }
 
 template<typename xpu>
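With the nested type switches above, the value and index outputs can now carry different dtypes. A sketch of the resulting behaviour (assuming this patch):

    import numpy as np
    import mxnet as mx

    x = mx.nd.array([[2, 0, 1]], dtype=np.float64)
    val, ind = mx.nd.topk(x, k=2, ret_typ='both', dtype='int32')
    print(val.dtype, ind.dtype)  # float64 int32 -- DType and IDType respectively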
@@ -520,13 +547,14 @@ void Sort(const nnvm::NodeAttrs& attrs,
           const std::vector<OpReqType>& req,
           const std::vector<TBlob>& outputs) {
   const SortParam& param = nnvm::get<SortParam>(attrs.parsed);
-  CHECK_EQ(req[0], kWriteTo) << "Sort does not support inplace";
   TopKParam topk_param;
   topk_param.axis = param.axis;
   topk_param.is_ascend = param.is_ascend;
   topk_param.k = 0;
   topk_param.ret_typ = topk_enum::kReturnValue;
-  TopKImpl<xpu>(ctx.run_ctx, ctx.requested[0], inputs[0], outputs, topk_param);
+  MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    TopKImpl<xpu, DType, int>(ctx.run_ctx, ctx.requested[0], req, inputs[0], outputs, topk_param);
+  });
 }
 
 template<typename xpu>
@@ -536,26 +564,30 @@ void ArgSort(const nnvm::NodeAttrs& attrs,
              const std::vector<OpReqType>& req,
              const std::vector<TBlob>& outputs) {
   const ArgSortParam& param = nnvm::get<ArgSortParam>(attrs.parsed);
-  CHECK_EQ(req[0], kWriteTo) << "ArgSort does not support inplace";
   TopKParam topk_param;
   topk_param.axis = param.axis;
   topk_param.is_ascend = param.is_ascend;
   topk_param.k = 0;
+  topk_param.dtype = param.dtype;
   topk_param.ret_typ = topk_enum::kReturnIndices;
-  TopKImpl<xpu>(ctx.run_ctx, ctx.requested[0], inputs[0], outputs, topk_param);
+  MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    MSHADOW_TYPE_SWITCH(param.dtype, IDType, {
+      TopKImpl<xpu, DType, IDType>(ctx.run_ctx,
+                                   ctx.requested[0], req, inputs[0], outputs, topk_param);
+    });
+  });
 }
 
-template<typename xpu>
-void TopKBackward_(const nnvm::NodeAttrs& attrs,
-                  const OpContext& ctx,
-                  const std::vector<TBlob>& inputs,
-                  const std::vector<OpReqType>& req,
-                  const std::vector<TBlob>& outputs) {
+template<typename xpu, typename DType, typename IDType>
+void TopKBackwardImpl(const OpContext &ctx,
+                      const std::vector<TBlob>& inputs,
+                      const std::vector<OpReqType>& req,
+                      const std::vector<TBlob>& outputs,
+                      const TopKParam& param) {
   CHECK_NE(req[0], kWriteInplace);
   using namespace mshadow;
   using namespace mshadow::expr;
   Stream<xpu> *s = ctx.run_ctx.get_stream<xpu>();
-  const TopKParam& param = nnvm::get<TopKParam>(attrs.parsed);
   CHECK(param.ret_typ == topk_enum::kReturnValue || param.ret_typ == topk_enum::kReturnBoth);
   int batch_size, element_num;  // number of batches + the size of each batch
   int axis = 0;
@@ -565,23 +597,28 @@ void TopKBackward_(const nnvm::NodeAttrs& attrs,
   TShape target_shape;
   ParseTopKParam(outputs[0].shape_, param,
                  &target_shape, &batch_size, &element_num, &axis, &k, &do_transpose, &is_ascend);
-  Tensor<xpu, 1, real_t> workspace =
-    ctx.requested[0].get_space_typed<xpu, 1, real_t>(Shape1(batch_size * k * 2 + batch_size), s);
-  Tensor<xpu, 1, real_t> sel_indices =
-    Tensor<xpu, 1, real_t>(workspace.dptr_, Shape1(batch_size * k), s);
-  Tensor<xpu, 1, real_t> batch_shift =
-    Tensor<xpu, 1, real_t>(workspace.dptr_ + batch_size * k, Shape1(batch_size), s);
-  Tensor<xpu, 1, real_t> dummy_index =
-    Tensor<xpu, 1, real_t>(workspace.dptr_ + batch_size * k + batch_size,
+  CHECK_LE(element_num, mxnet::common::MaxIntegerValue<IDType>())
+    << "'IDType' does not have sufficient precision to represent the indices of the input array. "
+    << "The total element_num is " << element_num << ", but the selected IDType can only represent "
+    << mxnet::common::MaxIntegerValue<IDType>() << " elements";
+  Tensor<xpu, 1, int> workspace =
+    ctx.requested[0].get_space_typed<xpu, 1, int>(Shape1(batch_size * k * 2 + batch_size), s);
+  Tensor<xpu, 1, int> sel_indices =
+    Tensor<xpu, 1, int>(workspace.dptr_, Shape1(batch_size * k), s);
+  Tensor<xpu, 1, int> batch_shift =
+    Tensor<xpu, 1, int>(workspace.dptr_ + batch_size * k, Shape1(batch_size), s);
+  Tensor<xpu, 1, int> dummy_index =
+    Tensor<xpu, 1, int>(workspace.dptr_ + batch_size * k + batch_size,
                            Shape1(batch_size * k), s);
-  Tensor<xpu, 2, real_t> out_grad =
-    inputs[0].get_with_shape<xpu, 2, real_t>(Shape2(inputs[0].shape_.Size(), 1), s);
-  Tensor<xpu, 2, real_t> in_grad =
-    outputs[0].get_with_shape<xpu, 2, real_t>(Shape2(outputs[0].shape_.Size(), 1), s);
-  mxnet_op::Kernel<range_fwd, xpu>::Launch(s, batch_size, 1, 0.0f,
-    static_cast<real_t>(element_num), kWriteTo, batch_shift.dptr_);
+
+  Tensor<xpu, 2, DType> out_grad =
+    inputs[0].get_with_shape<xpu, 2, DType>(Shape2(inputs[0].shape_.Size(), 1), s);
+  Tensor<xpu, 2, DType> in_grad =
+    outputs[0].get_with_shape<xpu, 2, DType>(Shape2(outputs[0].shape_.Size(), 1), s);
+  mxnet_op::Kernel<range_fwd, xpu>::Launch(s, batch_size, 1, 0, element_num, kWriteTo,
+                                           batch_shift.dptr_);
   if (do_transpose) {
-    Tensor<xpu, 1, real_t> indices = inputs[2].FlatTo1D<xpu, real_t>(s);
+    Tensor<xpu, 1, IDType> indices = inputs[2].FlatTo1D<xpu, IDType>(s);
     TShape src_shape = outputs[0].shape_.FlatTo3D(axis);
     sel_indices = reshape(transpose(
                             broadcast_to(inplace_reshape(batch_shift,
@@ -589,26 +626,26 @@ void TopKBackward_(const nnvm::NodeAttrs& attrs,
                                          TShape(Shape3(src_shape[0], src_shape[2], k))),
                             Shape3(0, 2, 1)),
                           Shape1(batch_size * k));
-    sel_indices += indices;
+    sel_indices += tcast<int>(indices);
     sel_indices = transpose_indices(sel_indices, Shape3(src_shape[0], src_shape[2], src_shape[1]),
                                     Shape3(0, 2, 1));
   } else {
-    Tensor<xpu, 2, real_t> indices =
-      inputs[2].get_with_shape<xpu, 2, real_t>(Shape2(batch_size, k), s);
-    sel_indices = reshape(indices +
+    Tensor<xpu, 2, IDType> indices =
+      inputs[2].get_with_shape<xpu, 2, IDType>(Shape2(batch_size, k), s);
+    sel_indices = reshape(tcast<int>(indices) +
                           broadcast_to(inplace_reshape(batch_shift, Shape2(batch_size, 1)),
                                        TShape(Shape2(batch_size, k))),
                           Shape1(batch_size * k));
   }
   CHECK_EQ(sel_indices.CheckContiguous(), true);
   if (kWriteTo == req[0]) {
-    in_grad = scalar<real_t>(0);
+    in_grad = scalar<DType>(0);
     IndexFill(in_grad, sel_indices, out_grad);
   } else if (kAddTo == req[0]) {
     // TODO(sxjscience) We can use AddTakeGrad in the future.
     // However, the current implementation of AddTakeGrad is not so efficient.
-    mxnet_op::Kernel<range_fwd, xpu>::Launch(s, sel_indices.shape_.Size(), 1, 0.0f,
-      1.0f, kWriteTo, dummy_index.dptr_);
+    mxnet_op::Kernel<range_fwd, xpu>::Launch(s, sel_indices.shape_.Size(), 1, 0, 1, kWriteTo,
+                                             dummy_index.dptr_);
     mxnet::op::AddTakeGradLargeBatch(in_grad, sel_indices, dummy_index, out_grad);
   } else if (kNullOp == req[0]) {
     return;
@@ -617,6 +654,28 @@ void TopKBackward_(const nnvm::NodeAttrs& attrs,
   }
 }
 
+template<typename xpu>
+void TopKBackward_(const nnvm::NodeAttrs& attrs,
+                   const OpContext& ctx,
+                   const std::vector<TBlob>& inputs,
+                   const std::vector<OpReqType>& req,
+                   const std::vector<TBlob>& outputs) {
+  const TopKParam& param = nnvm::get<TopKParam>(attrs.parsed);
+  if (param.ret_typ == topk_enum::kReturnBoth) {
+    MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+      MSHADOW_TYPE_SWITCH(param.dtype, IDType, {
+        TopKBackwardImpl<xpu, DType, IDType>(ctx, inputs, req, outputs, param);
+      });
+    });
+  } else if (param.ret_typ == topk_enum::kReturnValue) {
+    MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+      TopKBackwardImpl<xpu, DType, int>(ctx, inputs, req, outputs, param);
+    });
+  } else {
+    LOG(FATAL) << "Not Implemented";
+  }
+}
+
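The backward pass scatters the incoming gradient back to the positions that produced the top-k values and leaves everything else at zero, whatever the index dtype. A small sketch (assuming this patch):

    import mxnet as mx

    a = mx.nd.array([0.1, 0.4, 0.2, 0.3])
    a.attach_grad()
    with mx.autograd.record():
        v = mx.nd.topk(a, k=2, ret_typ='value')  # [0.4, 0.3]
    v.backward(mx.nd.array([10., 20.]))
    print(a.grad.asnumpy())                      # [ 0. 10.  0. 20.]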
 inline uint32_t TopKNumOutputs(const NodeAttrs& attrs) {
   const TopKParam& param = nnvm::get<TopKParam>(attrs.parsed);
   if (param.ret_typ == topk_enum::kReturnIndices ||
@@ -639,8 +698,36 @@ inline uint32_t TopKNumVisibleOutputs(const NodeAttrs& attrs) {
 inline bool TopKType(const nnvm::NodeAttrs& attrs,
                      std::vector<int> *in_attrs,
                      std::vector<int> *out_attrs) {
-  return ElemwiseAttr<int, type_is_none, type_assign, true, type_string>(
-    attrs, in_attrs, out_attrs, -1);
+  const TopKParam& param = nnvm::get<TopKParam>(attrs.parsed);
+  int data_type = -1;
+  size_t in_size = in_attrs->size();
+  size_t out_size = out_attrs->size();
+  CHECK_EQ(in_size, 1);
+  CHECK(out_size == 1 || out_size == 2);
+  if (out_size > 1) {
+    if (param.ret_typ == topk_enum::kReturnValue) {
+      CHECK(type_assign(&(*out_attrs)[1], mshadow::kInt32))
+        << "Failed to set the type of ret_indices.";
+    } else {
+      CHECK(type_assign(&(*out_attrs)[1], param.dtype))
+        << "Failed to set the type of ret_indices.";
+    }
+  }
+  if (param.ret_typ == topk_enum::kReturnIndices) {
+    CHECK(type_assign(&(*out_attrs)[0], param.dtype))
+            << "Failed to set the type of ret_indices.";
+  } else {
+    CHECK(type_assign(&data_type, (*in_attrs)[0])) << "Incompatible dtype of input, in_attrs[0]="
+                                                   << (*in_attrs)[0];
+    CHECK(type_assign(&data_type, (*out_attrs)[0])) << "Incompatible dtype of output, out_attrs[0]="
+                                                    << (*out_attrs)[0];
+    CHECK(type_assign(&(*in_attrs)[0], data_type)) << "Incompatible dtype of input, in_attrs[0]="
+                                                   << (*in_attrs)[0];
+    CHECK(type_assign(&(*out_attrs)[0], data_type)) << "Incompatible dtype of output, out_attrs[0]="
+                                                    << (*out_attrs)[0];
+    if (data_type == -1) return false;
+  }
+  return true;
 }
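TopKType ties the data dtype of the input and the value output together while pinning the index output to param.dtype (or int32 when ret_typ is "value"). Symbolic type inference should therefore report, e.g. (a sketch; the exact printed form may differ by version):

    import numpy as np
    import mxnet as mx

    data = mx.sym.Variable('data')
    out = mx.sym.topk(data, k=2, ret_typ='both', dtype='float64')
    arg_types, out_types, _ = out.infer_type(data=np.float32)
    print(out_types)  # [numpy.float32, numpy.float64]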
 
 inline bool TopKShapeImpl(const TopKParam& param,
@@ -679,6 +766,28 @@ inline bool TopKShape(const nnvm::NodeAttrs& attrs,
   return TopKShapeImpl(param, in_attrs, out_attrs);
 }
 
+inline bool SortType(const nnvm::NodeAttrs& attrs,
+                     std::vector<int> *in_attrs,
+                     std::vector<int> *out_attrs) {
+  int data_type = -1;
+  size_t in_size = in_attrs->size();
+  size_t out_size = out_attrs->size();
+  CHECK_EQ(in_size, 1);
+  CHECK_EQ(out_size, 2);
+  CHECK(type_assign(&(*out_attrs)[1], mshadow::kInt32))
+          << "Failed to set the type of ret_indices to int32.";
+  CHECK(type_assign(&data_type, (*in_attrs)[0])) << "Incompatible dtype of input, in_attrs[0]="
+                                                 << (*in_attrs)[0];
+  CHECK(type_assign(&data_type, (*out_attrs)[0])) << "Incompatible dtype of output, out_attrs[0]="
+                                                  << (*out_attrs)[0];
+  CHECK(type_assign(&(*in_attrs)[0], data_type)) << "Incompatible dtype of input, in_attrs[0]="
+                                                 << (*in_attrs)[0];
+  CHECK(type_assign(&(*out_attrs)[0], data_type)) << "Incompatible dtype of output, out_attrs[0]="
+                                                  << (*out_attrs)[0];
+  if (data_type == -1) return false;
+  return true;
+}
+
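SortType pins sort's hidden second output (the internal indices) to int32 while the visible value output follows the input dtype, e.g. (sketch, assuming this patch):

    import numpy as np
    import mxnet as mx

    x = mx.nd.array([3, 1, 2], dtype=np.float64)
    print(mx.nd.sort(x).dtype)  # float64 -- only the value output is user-visible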
 inline bool SortShape(const nnvm::NodeAttrs& attrs,
                       std::vector<TShape> *in_attrs,
                       std::vector<TShape> *out_attrs) {
@@ -691,6 +800,15 @@ inline bool SortShape(const nnvm::NodeAttrs& attrs,
   return TopKShapeImpl(topk_param, in_attrs, out_attrs);
 }
 
+inline bool ArgSortType(const nnvm::NodeAttrs& attrs,
+                        std::vector<int> *in_attrs,
+                        std::vector<int> *out_attrs) {
+  const ArgSortParam& param = nnvm::get<ArgSortParam>(attrs.parsed);
+  CHECK(type_assign(&(*out_attrs)[0], param.dtype))
+          << "Failed to set the type of ret_indices.";
+  return true;
+}
+
 inline bool ArgSortShape(const nnvm::NodeAttrs& attrs,
                          std::vector<TShape> *in_attrs,
                          std::vector<TShape> *out_attrs) {
diff --git a/src/operator/tensor/ordering_op.cc b/src/operator/tensor/ordering_op.cc
index ebd7c62ec88..1e2832d3763 100644
--- a/src/operator/tensor/ordering_op.cc
+++ b/src/operator/tensor/ordering_op.cc
@@ -35,6 +35,7 @@ DMLC_REGISTER_PARAMETER(ArgSortParam);
 
 NNVM_REGISTER_OP(topk)
 .describe(R"code(Returns the top *k* elements in an input array along the given axis.
+ The returned elements will be sorted.
 
 Examples::
 
@@ -128,7 +129,7 @@ Examples::
 .set_num_outputs(2)
 .set_attr_parser(ParamParser<SortParam>)
 .set_attr<nnvm::FInferShape>("FInferShape", SortShape)
-.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 2>)
+.set_attr<nnvm::FInferType>("FInferType", SortType)
 .set_attr<nnvm::FNumVisibleOutputs>("FNumVisibleOutputs", [](const NodeAttrs& attrs) { return 1; })
 .set_attr<FCompute>("FCompute<cpu>", Sort<cpu>)
 .set_attr<nnvm::FGradient>("FGradient",
@@ -178,7 +179,7 @@ Examples::
 .set_num_outputs(1)
 .set_attr_parser(ParamParser<ArgSortParam>)
 .set_attr<nnvm::FInferShape>("FInferShape", ArgSortShape)
-.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<nnvm::FInferType>("FInferType", ArgSortType)
 .set_attr<FCompute>("FCompute<cpu>", ArgSort<cpu>)
 .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
 .set_attr<FResourceRequest>("FResourceRequest",
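The doc line added to topk above ("The returned elements will be sorted.") is easy to check directly (sketch):

    import mxnet as mx

    x = mx.nd.array([0.3, 0.9, 0.1, 0.7])
    print(mx.nd.topk(x, k=3, ret_typ='value').asnumpy())  # [0.9 0.7 0.3] -- sorted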
diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py
index 2db39d5dd53..c9bc0cd1e1e 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -661,17 +661,19 @@ def gt_topk(dat, axis, ret_typ, k, is_ascend):
     # values, making it hard to generate a numpy 'golden copy' to compare against
     # the mxnet operator.  The 'mask' function is particularly hard to test given that
     # equal values might span the 'k' boundary.  Issue exposed with seed 1405838964.
-    def get_values(ensure_unique):
-        while True:
-            data = np.float32(np.random.normal(size=(dat_size, dat_size, dat_size, dat_size)))
-            if not ensure_unique:
-                return data
-            num_unique_values = len(set(data.flatten()))
-            if data.size == num_unique_values:
-                return data
-
-    a_npy = get_values(ensure_unique=True)
-    a_nd = mx.nd.array(a_npy, ctx=ctx)
+    def get_values(ensure_unique, dtype):
+        if dtype == np.int16 or dtype == np.int32 or dtype == np.int64:
+            return np.arange(dat_size ** 4, dtype=dtype).reshape((dat_size, dat_size, dat_size, dat_size))
+        elif dtype == np.float32 or dtype == np.float64:
+            while True:
+                data = np.random.normal(size=(dat_size, dat_size, dat_size, dat_size)).astype(dtype)
+                if not ensure_unique:
+                    return data
+                num_unique_values = len(set(data.flatten()))
+                if data.size == num_unique_values:
+                    return data
+        else:
+            raise NotImplementedError
 
     # Produce a large matrix (256, 300096) as the input data, to cover the case of a
     # matrix so large that its flattened indices exceed the range float32 can express precisely, but
@@ -685,103 +687,161 @@ def get_large_matrix():
     large_matrix_npy = get_large_matrix()
     large_matrix_nd = mx.nd.array(large_matrix_npy, ctx=ctx)
 
-    # test for ret_typ=indices
-    nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="indices", k=3, is_ascend=True).asnumpy()
-    gt = gt_topk(a_npy, axis=1, ret_typ="indices", k=3, is_ascend=True)
-    assert_almost_equal(nd_ret_topk, gt)
-    nd_ret_topk = mx.nd.topk(a_nd, axis=3, ret_typ="indices", k=2, is_ascend=False).asnumpy()
-    gt = gt_topk(a_npy, axis=3, ret_typ="indices", k=2, is_ascend=False)
-    assert_almost_equal(nd_ret_topk, gt)
-    nd_ret_topk = mx.nd.topk(a_nd, axis=None, ret_typ="indices", k=21, is_ascend=False).asnumpy()
-    gt = gt_topk(a_npy, axis=None, ret_typ="indices", k=21, is_ascend=False)
-    assert_almost_equal(nd_ret_topk, gt)
     nd_ret_topk = mx.nd.topk(large_matrix_nd, axis=1, ret_typ="indices", k=5, is_ascend=False).asnumpy()
     gt = gt_topk(large_matrix_npy, axis=1, ret_typ="indices", k=5, is_ascend=False)
     assert_almost_equal(nd_ret_topk, gt)
 
-    # test for ret_typ=value
-    nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="value", k=3, is_ascend=True).asnumpy()
-    gt = gt_topk(a_npy, axis=1, ret_typ="value", k=3, is_ascend=True)
-    assert_almost_equal(nd_ret_topk, gt)
-    nd_ret_topk = mx.nd.topk(a_nd, axis=3, ret_typ="value", k=2, is_ascend=False).asnumpy()
-    gt = gt_topk(a_npy, axis=3, ret_typ="value", k=2, is_ascend=False)
-    assert_almost_equal(nd_ret_topk, gt)
-    nd_ret_topk = mx.nd.topk(a_nd, axis=None, ret_typ="value", k=21, is_ascend=False).asnumpy()
-    gt = gt_topk(a_npy, axis=None, ret_typ="value", k=21, is_ascend=False)
-    assert_almost_equal(nd_ret_topk, gt)
-    nd_ret_topk = mx.nd.topk(large_matrix_nd, axis=0, ret_typ="value", k=3, is_ascend=False).asnumpy()
-    gt = gt_topk(large_matrix_npy, axis=0, ret_typ="value", k=3, is_ascend=False)
-    assert_almost_equal(nd_ret_topk, gt)
-    nd_ret_topk = mx.nd.topk(large_matrix_nd, axis=1, ret_typ="value", k=5, is_ascend=False).asnumpy()
-    gt = gt_topk(large_matrix_npy, axis=1, ret_typ="value", k=5, is_ascend=False)
-    assert_almost_equal(nd_ret_topk, gt)
-
-    # test for ret_typ=mask
-    nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="mask", k=3, is_ascend=True).asnumpy()
-    gt = gt_topk(a_npy, axis=1, ret_typ="mask", k=3, is_ascend=True)
-    assert_almost_equal(nd_ret_topk, gt)
-    nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="mask", k=2, is_ascend=False).asnumpy()
-    gt = gt_topk(a_npy, axis=1, ret_typ="mask", k=2, is_ascend=False)
-    assert_almost_equal(nd_ret_topk, gt)
-    nd_ret_topk = mx.nd.topk(a_nd, axis=None, ret_typ="mask", k=21, is_ascend=False).asnumpy()
-    gt = gt_topk(a_npy, axis=None, ret_typ="mask", k=21, is_ascend=False)
-    assert_almost_equal(nd_ret_topk, gt)
-
-    # test for ret_typ=both
-    nd_ret_topk_val, nd_ret_topk_ind = mx.nd.topk(a_nd, axis=1, ret_typ="both", k=3, is_ascend=True)
-    nd_ret_topk_val = nd_ret_topk_val.asnumpy()
-    nd_ret_topk_ind = nd_ret_topk_ind.asnumpy()
-    gt_val = gt_topk(a_npy, axis=1, ret_typ="value", k=3, is_ascend=True)
-    gt_ind = gt_topk(a_npy, axis=1, ret_typ="indices", k=3, is_ascend=True)
-    assert_almost_equal(nd_ret_topk_val, gt_val)
-    assert_almost_equal(nd_ret_topk_ind, gt_ind)
-
-    # test for sort
-    nd_ret_sort = mx.nd.sort(a_nd, axis=1, is_ascend=True).asnumpy()
-    gt = gt_topk(a_npy, axis=1, ret_typ="value", k=dat_size, is_ascend=True)
-    assert_almost_equal(nd_ret_sort, gt)
-    nd_ret_sort = mx.nd.sort(a_nd, axis=None, is_ascend=False).asnumpy()
-    gt = gt_topk(a_npy, axis=None, ret_typ="value",
-                 k=dat_size*dat_size*dat_size*dat_size, is_ascend=False)
-    assert_almost_equal(nd_ret_sort, gt)
-
-    # test for argsort
-    nd_ret_argsort = mx.nd.argsort(a_nd, axis=3, is_ascend=True).asnumpy()
-    gt = gt_topk(a_npy, axis=3, ret_typ="indices", k=dat_size, is_ascend=True)
-    assert_almost_equal(nd_ret_argsort, gt)
-    nd_ret_argsort = mx.nd.argsort(a_nd, axis=None, is_ascend=False).asnumpy()
-    gt = gt_topk(a_npy, axis=None, ret_typ="indices",
-                 k=dat_size*dat_size*dat_size*dat_size, is_ascend=False)
-    assert_almost_equal(nd_ret_argsort, gt)
-
-    a = mx.nd.arange(0, 1024, step=1, repeat=1)
-    assert_almost_equal(a.topk(k=1024).asnumpy(), a.asnumpy()[::-1])
+    for dtype in [np.int16, np.int32, np.int64, np.float32, np.float64]:
+        a_npy = get_values(ensure_unique=True, dtype=dtype)
+        a_nd = mx.nd.array(a_npy, ctx=ctx)
+
+        # test for ret_typ=indices
+        nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="indices", k=3, is_ascend=True).asnumpy()
+        gt = gt_topk(a_npy, axis=1, ret_typ="indices", k=3, is_ascend=True)
+        assert_almost_equal(nd_ret_topk, gt)
+        nd_ret_topk = mx.nd.topk(a_nd, axis=3, ret_typ="indices", k=2, is_ascend=False).asnumpy()
+        gt = gt_topk(a_npy, axis=3, ret_typ="indices", k=2, is_ascend=False)
+        assert_almost_equal(nd_ret_topk, gt)
+        nd_ret_topk = mx.nd.topk(a_nd, axis=None, ret_typ="indices", k=21, is_ascend=False).asnumpy()
+        gt = gt_topk(a_npy, axis=None, ret_typ="indices", k=21, is_ascend=False)
+        assert_almost_equal(nd_ret_topk, gt)
+
+        # test for ret_typ=value
+        nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="value", k=3, is_ascend=True).asnumpy()
+        gt = gt_topk(a_npy, axis=1, ret_typ="value", k=3, is_ascend=True)
+        assert_almost_equal(nd_ret_topk, gt)
+        nd_ret_topk = mx.nd.topk(a_nd, axis=3, ret_typ="value", k=2, is_ascend=False).asnumpy()
+        gt = gt_topk(a_npy, axis=3, ret_typ="value", k=2, is_ascend=False)
+        assert_almost_equal(nd_ret_topk, gt)
+        nd_ret_topk = mx.nd.topk(a_nd, axis=None, ret_typ="value", k=21, is_ascend=False).asnumpy()
+        gt = gt_topk(a_npy, axis=None, ret_typ="value", k=21, is_ascend=False)
+        assert_almost_equal(nd_ret_topk, gt)
+
+        # test for ret_typ=mask
+        nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="mask", k=3, is_ascend=True).asnumpy()
+        gt = gt_topk(a_npy, axis=1, ret_typ="mask", k=3, is_ascend=True)
+        assert_almost_equal(nd_ret_topk, gt)
+        nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="mask", k=2, is_ascend=False).asnumpy()
+        gt = gt_topk(a_npy, axis=1, ret_typ="mask", k=2, is_ascend=False)
+        assert_almost_equal(nd_ret_topk, gt)
+        nd_ret_topk = mx.nd.topk(a_nd, axis=None, ret_typ="mask", k=21, is_ascend=False).asnumpy()
+        gt = gt_topk(a_npy, axis=None, ret_typ="mask", k=21, is_ascend=False)
+        assert_almost_equal(nd_ret_topk, gt)
+
+        # test for ret_typ=both
+        nd_ret_topk_val, nd_ret_topk_ind = mx.nd.topk(a_nd, axis=1, ret_typ="both", k=3, is_ascend=True)
+        nd_ret_topk_val = nd_ret_topk_val.asnumpy()
+        nd_ret_topk_ind = nd_ret_topk_ind.asnumpy()
+        gt_val = gt_topk(a_npy, axis=1, ret_typ="value", k=3, is_ascend=True)
+        gt_ind = gt_topk(a_npy, axis=1, ret_typ="indices", k=3, is_ascend=True)
+        assert_almost_equal(nd_ret_topk_val, gt_val)
+        assert_almost_equal(nd_ret_topk_ind, gt_ind)
+        # test for kNullOp
+        _, nd_ret_topk_ind = mx.nd.topk(a_nd, axis=1, ret_typ="both", k=3, is_ascend=True)
+        nd_ret_topk_ind = nd_ret_topk_ind.asnumpy()
+        assert_almost_equal(nd_ret_topk_ind, gt_ind)
+        # test for kNullOp
+        nd_ret_topk_val, _ = mx.nd.topk(a_nd, axis=1, ret_typ="both", k=3, is_ascend=True)
+        nd_ret_topk_val = nd_ret_topk_val.asnumpy()
+        assert_almost_equal(nd_ret_topk_val, gt_val)
+
+        # test for sort
+        nd_ret_sort = mx.nd.sort(a_nd, axis=1, is_ascend=True).asnumpy()
+        gt = gt_topk(a_npy, axis=1, ret_typ="value", k=dat_size, is_ascend=True)
+        assert_almost_equal(nd_ret_sort, gt)
+        nd_ret_sort = mx.nd.sort(a_nd, axis=None, is_ascend=False).asnumpy()
+        gt = gt_topk(a_npy, axis=None, ret_typ="value",
+                     k=dat_size*dat_size*dat_size*dat_size, is_ascend=False)
+        assert_almost_equal(nd_ret_sort, gt)
+
+        # test for argsort
+        for idtype in [np.int32, np.float16, np.float32, np.float64]:
+            nd_ret_argsort = mx.nd.argsort(a_nd, axis=3, is_ascend=True, dtype=idtype).asnumpy()
+            gt = gt_topk(a_npy, axis=3, ret_typ="indices", k=dat_size, is_ascend=True)
+            assert_almost_equal(nd_ret_argsort, gt)
+            nd_ret_argsort = mx.nd.argsort(a_nd, axis=None, is_ascend=False, dtype=idtype).asnumpy()
+            gt = gt_topk(a_npy, axis=None, ret_typ="indices",
+                         k=dat_size*dat_size*dat_size*dat_size, is_ascend=False)
+            assert_almost_equal(nd_ret_argsort, gt)
+
+        # Repeat those tests that don't involve indices.  These should pass even with
+        # duplicated input data values (over many repeated runs with different random seeds,
+        # this will be tested).
+        a_npy = get_values(ensure_unique=False, dtype=dtype)
+        a_nd = mx.nd.array(a_npy, ctx=ctx)
+
+        # test for ret_typ=value
+        nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="value", k=3, is_ascend=True).asnumpy()
+        gt = gt_topk(a_npy, axis=1, ret_typ="value", k=3, is_ascend=True)
+        assert_almost_equal(nd_ret_topk, gt)
+        nd_ret_topk = mx.nd.topk(a_nd, axis=3, ret_typ="value", k=2, is_ascend=False).asnumpy()
+        gt = gt_topk(a_npy, axis=3, ret_typ="value", k=2, is_ascend=False)
+        assert_almost_equal(nd_ret_topk, gt)
+        nd_ret_topk = mx.nd.topk(a_nd, axis=None, ret_typ="value", k=21, is_ascend=False).asnumpy()
+        gt = gt_topk(a_npy, axis=None, ret_typ="value", k=21, is_ascend=False)
+        assert_almost_equal(nd_ret_topk, gt)
+
+        # test for sort
+        nd_ret_sort = mx.nd.sort(a_nd, axis=1, is_ascend=True).asnumpy()
+        gt = gt_topk(a_npy, axis=1, ret_typ="value", k=dat_size, is_ascend=True)
+        assert_almost_equal(nd_ret_sort, gt)
+        nd_ret_sort = mx.nd.sort(a_nd, axis=None, is_ascend=False).asnumpy()
+        gt = gt_topk(a_npy, axis=None, ret_typ="value",
+                     k=dat_size*dat_size*dat_size*dat_size, is_ascend=False)
+        assert_almost_equal(nd_ret_sort, gt)
+
+    a = mx.nd.arange(0, 1024, step=1, repeat=1, dtype=np.int32)
+    assert_almost_equal(a.topk(k=1024, dtype=np.int32).asnumpy(), a.asnumpy()[::-1])
+    a.attach_grad()
+
+    k = 10
+    with mx.autograd.record():
+        b = mx.nd.topk(a, k=k, ret_typ='value')
+        b.backward(mx.nd.ones((k,), dtype=np.int32))
+    a_grad = a.grad.asnumpy()
+    for i in range(-1, - k - 1, -1):
+        assert a_grad[i] == 1
+
+    # test topk gradient with a small shape
+    for dtype in [np.int32, np.int64, np.float32, np.float64]:
+        a = mx.nd.arange(0, 1000, step=1, repeat=1, dtype=dtype)
+        a.attach_grad()
+        k = 10
+        ograd = mx.nd.arange(0, k, dtype=dtype)
+        with mx.autograd.record():
+            b = mx.nd.topk(a, k=k, ret_typ='value')
+            b.backward(ograd)
+        a_grad = a.grad.asnumpy()
+        ograd_npy = ograd.asnumpy()
+        for i in range(-1, - k - 1, -1):
+            assert a_grad[i] == ograd_npy[-i - 1]
 
     # Repeat those tests that don't involve indices.  These should pass even with
     # duplicated input data values (over many repeated runs with different random seeds,
     # this will be tested).
-    a_npy = get_values(ensure_unique=False)
-    a_nd = mx.nd.array(a_npy, ctx=ctx)
-
-    # test for ret_typ=value
-    nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="value", k=3, is_ascend=True).asnumpy()
-    gt = gt_topk(a_npy, axis=1, ret_typ="value", k=3, is_ascend=True)
-    assert_almost_equal(nd_ret_topk, gt)
-    nd_ret_topk = mx.nd.topk(a_nd, axis=3, ret_typ="value", k=2, is_ascend=False).asnumpy()
-    gt = gt_topk(a_npy, axis=3, ret_typ="value", k=2, is_ascend=False)
-    assert_almost_equal(nd_ret_topk, gt)
-    nd_ret_topk = mx.nd.topk(a_nd, axis=None, ret_typ="value", k=21, is_ascend=False).asnumpy()
-    gt = gt_topk(a_npy, axis=None, ret_typ="value", k=21, is_ascend=False)
-    assert_almost_equal(nd_ret_topk, gt)
-
-    # test for sort
-    nd_ret_sort = mx.nd.sort(a_nd, axis=1, is_ascend=True).asnumpy()
-    gt = gt_topk(a_npy, axis=1, ret_typ="value", k=dat_size, is_ascend=True)
-    assert_almost_equal(nd_ret_sort, gt)
-    nd_ret_sort = mx.nd.sort(a_nd, axis=None, is_ascend=False).asnumpy()
-    gt = gt_topk(a_npy, axis=None, ret_typ="value",
-                 k=dat_size*dat_size*dat_size*dat_size, is_ascend=False)
-    assert_almost_equal(nd_ret_sort, gt)
+    for dtype in [np.int16, np.int32, np.int64, np.float32, np.float64]:
+        a_npy = get_values(ensure_unique=False, dtype=dtype)
+        a_nd = mx.nd.array(a_npy, ctx=ctx)
+
+        # test for ret_typ=value
+        nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="value", k=3, is_ascend=True).asnumpy()
+        gt = gt_topk(a_npy, axis=1, ret_typ="value", k=3, is_ascend=True)
+        assert_almost_equal(nd_ret_topk, gt)
+        nd_ret_topk = mx.nd.topk(a_nd, axis=3, ret_typ="value", k=2, is_ascend=False).asnumpy()
+        gt = gt_topk(a_npy, axis=3, ret_typ="value", k=2, is_ascend=False)
+        assert_almost_equal(nd_ret_topk, gt)
+        nd_ret_topk = mx.nd.topk(a_nd, axis=None, ret_typ="value", k=21, is_ascend=False).asnumpy()
+        gt = gt_topk(a_npy, axis=None, ret_typ="value", k=21, is_ascend=False)
+        assert_almost_equal(nd_ret_topk, gt)
+
+        # test for sort
+        nd_ret_sort = mx.nd.sort(a_nd, axis=1, is_ascend=True).asnumpy()
+        gt = gt_topk(a_npy, axis=1, ret_typ="value", k=dat_size, is_ascend=True)
+        assert_almost_equal(nd_ret_sort, gt)
+        nd_ret_sort = mx.nd.sort(a_nd, axis=None, is_ascend=False).asnumpy()
+        gt = gt_topk(a_npy, axis=None, ret_typ="value",
+                     k=dat_size*dat_size*dat_size*dat_size, is_ascend=False)
+        assert_almost_equal(nd_ret_sort, gt)
 
 @with_seed()
 def test_ndarray_equal():


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services