You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2018/08/20 21:11:19 UTC
[incubator-mxnet] branch master updated: fix potential floating
number overflow, enable float16 (#12118)
This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new c479eb2 fix potential floating number overflow, enable float16 (#12118)
c479eb2 is described below
commit c479eb24eaab8857dca254ea76c1179b0f6fe36f
Author: Joshua Z. Zhang <ch...@gmail.com>
AuthorDate: Mon Aug 20 14:11:01 2018 -0700
fix potential floating number overflow, enable float16 (#12118)
* fix potential floating number overflow, enable float16
* fix cuda impl
* fix cuda imple
* fix template substitution for windows
* half_f substantiate operand + fix
* remove ambiguous operand + for mshadow half_T
* fix con't
* use int32_t as indices
* use overload
* try remove ambiguous function overloading
* thrust version limit
* change sizeof cast from floor to ceil when allocating buffers
* cleaner
* fix alignment of pointers
---
src/operator/contrib/bounding_box-inl.cuh | 4 +-
src/operator/contrib/bounding_box-inl.h | 86 +++++++++-------
src/operator/tensor/sort_op-inl.cuh | 135 ++++++++++++++++++++++---
tests/python/unittest/test_contrib_operator.py | 25 ++---
4 files changed, 184 insertions(+), 66 deletions(-)
diff --git a/src/operator/contrib/bounding_box-inl.cuh b/src/operator/contrib/bounding_box-inl.cuh
index fb1dacc..fd5e30b 100644
--- a/src/operator/contrib/bounding_box-inl.cuh
+++ b/src/operator/contrib/bounding_box-inl.cuh
@@ -45,9 +45,9 @@ struct valid_score {
template<typename DType>
int FilterScores(mshadow::Tensor<gpu, 1, DType> out_scores,
- mshadow::Tensor<gpu, 1, DType> out_sorted_index,
+ mshadow::Tensor<gpu, 1, int32_t> out_sorted_index,
mshadow::Tensor<gpu, 1, DType> scores,
- mshadow::Tensor<gpu, 1, DType> sorted_index,
+ mshadow::Tensor<gpu, 1, int32_t> sorted_index,
float valid_thresh) {
valid_score<DType> pred(static_cast<DType>(valid_thresh));
DType * end_scores = thrust::copy_if(thrust::device, scores.dptr_, scores.dptr_ + scores.MSize(),
diff --git a/src/operator/contrib/bounding_box-inl.h b/src/operator/contrib/bounding_box-inl.h
index f739dbc..8e96346 100644
--- a/src/operator/contrib/bounding_box-inl.h
+++ b/src/operator/contrib/bounding_box-inl.h
@@ -150,9 +150,9 @@ inline uint32_t BoxNMSNumVisibleOutputs(const NodeAttrs& attrs) {
template<typename DType>
int FilterScores(mshadow::Tensor<cpu, 1, DType> out_scores,
- mshadow::Tensor<cpu, 1, DType> out_sorted_index,
+ mshadow::Tensor<cpu, 1, int32_t> out_sorted_index,
mshadow::Tensor<cpu, 1, DType> scores,
- mshadow::Tensor<cpu, 1, DType> sorted_index,
+ mshadow::Tensor<cpu, 1, int32_t> sorted_index,
float valid_thresh) {
index_t j = 0;
for (index_t i = 0; i < scores.size(0); i++) {
@@ -230,7 +230,7 @@ MSHADOW_XINLINE DType BoxArea(const DType *box, int encode) {
/*!
* \brief compute areas specialized for nms to reduce computation
- *
+ *
* \param i the launched thread index (total thread num_batch * topk)
* \param out 1d array for areas (size num_batch * num_elem)
* \param in 1st coordinate of 1st box (buffer + coord_start)
@@ -243,7 +243,7 @@ MSHADOW_XINLINE DType BoxArea(const DType *box, int encode) {
struct compute_area {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType *out, const DType *in,
- const DType *indices, const DType *batch_start,
+ const int32_t *indices, const int32_t *batch_start,
int topk, int num_elem, int stride, int encode) {
int b = i / topk;
int k = i % topk;
@@ -302,7 +302,7 @@ MSHADOW_XINLINE DType Intersect(const DType *a, const DType *b, int encode) {
*/
struct nms_impl {
template<typename DType>
- MSHADOW_XINLINE static void Map(int i, DType *index, const DType *batch_start,
+ MSHADOW_XINLINE static void Map(int i, int32_t *index, const int32_t *batch_start,
const DType *input, const DType *areas,
int k, int ref, int num,
int stride, int offset_box, int offset_id,
@@ -326,8 +326,7 @@ struct nms_impl {
intersect *= Intersect(input + ref_offset + 1, input + pos_offset + 1, encode);
int ref_area_offset = static_cast<int>(index[ref]);
int pos_area_offset = static_cast<int>(index[pos]);
- DType iou = intersect / (areas[ref_area_offset] + areas[pos_area_offset] -
- intersect);
+ DType iou = intersect / (areas[ref_area_offset] + areas[pos_area_offset] - intersect);
if (iou > thresh) {
index[pos] = -1;
}
@@ -336,7 +335,7 @@ struct nms_impl {
/*!
* \brief Assign output of nms by indexing input
- *
+ *
* \param i the launched thread index (total num_batch)
* \param out output array [cls, conf, b0, b1, b2, b3]
* \param record book keeping the selected index for backward
@@ -349,7 +348,7 @@ struct nms_impl {
struct nms_assign {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType *out, DType *record, const DType *input,
- const DType *index, const DType *batch_start,
+ const int32_t *index, const int32_t *batch_start,
int k, int num, int stride) {
int count = 0;
for (int j = 0; j < k; ++j) {
@@ -404,7 +403,7 @@ void BoxNMSForward(const nnvm::NodeAttrs& attrs,
int num_batch = indim <= 2? 1 : in_shape.ProdShape(0, indim - 2);
int num_elem = in_shape[indim - 2];
int width_elem = in_shape[indim - 1];
- MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+ MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
Tensor<xpu, 3, DType> data = inputs[box_nms_enum::kData]
.get_with_shape<xpu, 3, DType>(Shape3(num_batch, num_elem, width_elem), s);
Tensor<xpu, 3, DType> out = outputs[box_nms_enum::kOut]
@@ -415,25 +414,33 @@ void BoxNMSForward(const nnvm::NodeAttrs& attrs,
// prepare workspace
Shape<1> sort_index_shape = Shape1(num_batch * num_elem);
Shape<3> buffer_shape = Shape3(num_batch, num_elem, width_elem);
- index_t workspace_size = 4 * sort_index_shape.Size();
Shape<1> batch_start_shape = Shape1(num_batch + 1);
- workspace_size += batch_start_shape.Size();
+
+ // index
+ index_t int32_size = sort_index_shape.Size() * 3 + batch_start_shape.Size();
+ index_t dtype_size = sort_index_shape.Size() * 2;
if (req[0] == kWriteInplace) {
- workspace_size += buffer_shape.Size();
+ dtype_size += buffer_shape.Size();
}
+ // ceil up when sizeof(DType) is larger than sizeof(int32_t)
+ index_t int32_offset = (int32_size * sizeof(int32_t) - 1) / sizeof(DType) + 1;
+ index_t workspace_size = int32_offset + dtype_size;
Tensor<xpu, 1, DType> workspace = ctx.requested[box_nms_enum::kTempSpace]
.get_space_typed<xpu, 1, DType>(Shape1(workspace_size), s);
- Tensor<xpu, 1, DType> sorted_index(workspace.dptr_, sort_index_shape, s);
- Tensor<xpu, 1, DType> scores(sorted_index.dptr_ + sorted_index.MSize(),
+ Tensor<xpu, 1, int32_t> sorted_index(
+ reinterpret_cast<int32_t*>(workspace.dptr_), sort_index_shape, s);
+ Tensor<xpu, 1, int32_t> all_sorted_index(sorted_index.dptr_ + sorted_index.MSize(),
sort_index_shape, s);
- Tensor<xpu, 1, DType> batch_id(scores.dptr_ + scores.MSize(), sort_index_shape,
- s);
- Tensor<xpu, 1, DType> areas(batch_id.dptr_ + batch_id.MSize(), sort_index_shape, s);
- Tensor<xpu, 1, DType> batch_start(areas.dptr_ + areas.MSize(), batch_start_shape, s);
+ Tensor<xpu, 1, int32_t> batch_id(
+ all_sorted_index.dptr_ + all_sorted_index.MSize(), sort_index_shape, s);
+ Tensor<xpu, 1, int32_t> batch_start(batch_id.dptr_ + batch_id.MSize(), batch_start_shape, s);
+ Tensor<xpu, 1, DType> scores(workspace.dptr_ + int32_offset,
+ sort_index_shape, s);
+ Tensor<xpu, 1, DType> areas(scores.dptr_ + scores.MSize(), sort_index_shape, s);
Tensor<xpu, 3, DType> buffer = data;
if (req[0] == kWriteInplace) {
// make copy
- buffer = Tensor<xpu, 3, DType>(batch_start.dptr_ + batch_start.MSize(), buffer_shape, s);
+ buffer = Tensor<xpu, 3, DType>(areas.dptr_ + areas.MSize(), buffer_shape, s);
buffer = F<mshadow_op::identity>(data);
}
@@ -451,10 +458,10 @@ void BoxNMSForward(const nnvm::NodeAttrs& attrs,
}
// use batch_id and areas as temporary storage
- Tensor<xpu, 1, DType> all_scores = batch_id;
- Tensor<xpu, 1, DType> all_sorted_index = areas;
+ Tensor<xpu, 1, DType> all_scores = areas;
+ // Tensor<xpu, 1, DType> all_sorted_index = areas;
all_scores = reshape(slice<2>(buffer, score_index, score_index + 1), all_scores.shape_);
- all_sorted_index = range<DType>(0, num_batch * num_elem);
+ all_sorted_index = range<int32_t>(0, num_batch * num_elem);
// filter scores but keep original sorted_index value
// move valid score and index to the front, return valid size
@@ -474,19 +481,19 @@ void BoxNMSForward(const nnvm::NodeAttrs& attrs,
// only sort the valid scores and batch_id
Shape<1> valid_score_shape = Shape1(num_valid);
Tensor<xpu, 1, DType> valid_scores(scores.dptr_, valid_score_shape, s);
- Tensor<xpu, 1, DType> valid_sorted_index(sorted_index.dptr_, valid_score_shape, s);
- Tensor<xpu, 1, DType> valid_batch_id(batch_id.dptr_, valid_score_shape, s);
+ Tensor<xpu, 1, int32_t> valid_sorted_index(sorted_index.dptr_, valid_score_shape, s);
+ Tensor<xpu, 1, int32_t> valid_batch_id(batch_id.dptr_, valid_score_shape, s);
// sort index by batch_id then score (stable sort)
mxnet::op::SortByKey(valid_scores, valid_sorted_index, false);
- valid_batch_id = F<mshadow_op::floor>(valid_sorted_index / ScalarExp<DType>(num_elem));
+ valid_batch_id = (valid_sorted_index / ScalarExp<int32_t>(num_elem));
mxnet::op::SortByKey(valid_batch_id, valid_sorted_index, true);
// calculate batch_start: accumulated sum to denote 1st sorted_index for a given batch_index
- valid_batch_id = F<mshadow_op::floor>(valid_sorted_index / ScalarExp<DType>(num_elem));
+ valid_batch_id = (valid_sorted_index / ScalarExp<int32_t>(num_elem));
for (int b = 0; b < num_batch + 1; b++) {
slice<0>(batch_start, b, b + 1) = reduce_keepdim<red::sum, false>(
- F<mshadow_op::less_than>(valid_batch_id, ScalarExp<DType>(b)), 0);
+ F<mshadow_op::less_than>(valid_batch_id, ScalarExp<int32_t>(b)), 0);
}
// pre-compute areas of candidates
@@ -721,11 +728,11 @@ inline bool MatchingShape(const nnvm::NodeAttrs& attrs,
struct bipartite_matching {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType *row_marker, DType *col_marker,
- const DType *scores, const DType *sorted_index,
+ const DType *scores, const int32_t *sorted_index,
int num_batch, int num_row, int num_col,
float threshold, bool is_ascend, int topk) {
int stride = num_row * num_col;
- const DType *index = sorted_index + i * stride;
+ const int32_t *index = sorted_index + i * stride;
const DType *score = scores + i * stride;
DType *rmarker = row_marker + i * num_row;
DType *cmarker = col_marker + i * num_col;
@@ -769,7 +776,7 @@ void BipartiteMatchingForward(const nnvm::NodeAttrs& attrs,
int row = dshape[dshape.ndim() - 2];
int col = dshape[dshape.ndim() - 1];
int batch_size = dshape.Size() / row / col;
- MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+ MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
Tensor<xpu, 1, DType> scores = inputs[0]
.get_with_shape<xpu, 1, DType>(Shape1(dshape.Size()), s);
Tensor<xpu, 2, DType> row_marker = outputs[0]
@@ -777,23 +784,24 @@ void BipartiteMatchingForward(const nnvm::NodeAttrs& attrs,
Tensor<xpu, 2, DType> col_marker = outputs[1]
.get_with_shape<xpu, 2, DType>(Shape2(batch_size, col), s);
Shape<1> sort_index_shape = Shape1(dshape.Size());
- index_t workspace_size = sort_index_shape.Size() * 3;
+ index_t workspace_size = sort_index_shape.Size();
+ workspace_size += ((sort_index_shape.Size() * sizeof(int32_t) - 1) / sizeof(DType) + 1) * 2;
Tensor<xpu, 1, DType> workspace = ctx.requested[0]
.get_space_typed<xpu, 1, DType>(Shape1(workspace_size), s);
- Tensor<xpu, 1, DType> sorted_index(workspace.dptr_,
- sort_index_shape, s);
- Tensor<xpu, 1, DType> batch_id(sorted_index.dptr_ + sorted_index.MSize(),
+ Tensor<xpu, 1, DType> scores_copy(workspace.dptr_,
sort_index_shape, s);
- Tensor<xpu, 1, DType> scores_copy(batch_id.dptr_ + batch_id.MSize(),
+ Tensor<xpu, 1, int32_t> sorted_index(reinterpret_cast<int32_t*>(
+ scores_copy.dptr_ + scores_copy.MSize()), sort_index_shape, s);
+ Tensor<xpu, 1, int32_t> batch_id(sorted_index.dptr_ + sorted_index.MSize(),
sort_index_shape, s);
// sort according to score
scores_copy = F<mshadow_op::identity>(scores);
- sorted_index = range<DType>(0, dshape.Size());
+ sorted_index = range<int32_t>(0, dshape.Size());
mxnet::op::SortByKey(scores_copy, sorted_index, param.is_ascend);
- batch_id = F<mshadow_op::floor>(sorted_index / ScalarExp<DType>(row * col));
+ batch_id = (sorted_index / ScalarExp<int32_t>(row * col));
mxnet::op::SortByKey(batch_id, scores_copy, true);
- batch_id = F<mshadow_op::floor>(sorted_index / ScalarExp<DType>(row * col));
+ batch_id = (sorted_index / ScalarExp<int32_t>(row * col));
mxnet::op::SortByKey(batch_id, sorted_index, true);
// bipartite matching, parallelization is limited to batch_size
diff --git a/src/operator/tensor/sort_op-inl.cuh b/src/operator/tensor/sort_op-inl.cuh
index 5ad3105..1a8e232 100644
--- a/src/operator/tensor/sort_op-inl.cuh
+++ b/src/operator/tensor/sort_op-inl.cuh
@@ -24,6 +24,7 @@
*/
#ifndef MXNET_OPERATOR_TENSOR_SORT_OP_INL_CUH_
#define MXNET_OPERATOR_TENSOR_SORT_OP_INL_CUH_
+#include <type_traits>
#include <thrust/device_ptr.h>
#include <thrust/sort.h>
#if defined(_MSC_VER) && __CUDACC_VER_MAJOR__ == 8 && __CUDACC_VER_BUILD__ != 44
@@ -40,6 +41,29 @@
namespace mxnet {
namespace op {
+namespace cuda {
+// Ascending comparator for thrust sorts over fp16 data: the raw storage type T
+// (CUDA half) is cast to mshadow::half::half_t, which provides operator<,
+// before comparing.  Used by the half_t SortByKeyImpl overloads below.
+template<typename T>
+struct less_half
+{
+ typedef T first_argument_type;
+ typedef T second_argument_type;
+ typedef bool result_type;
+ __host__ __device__ bool operator()(const T &lhs, const T &rhs) const {
+ return static_cast<mshadow::half::half_t>(lhs) < static_cast<mshadow::half::half_t>(rhs);
+ }
+};
+
+// Descending comparator, counterpart of less_half above.
+template<typename T>
+struct greater_half
+{
+ typedef T first_argument_type;
+ typedef T second_argument_type;
+ typedef bool result_type;
+ __host__ __device__ bool operator()(const T &lhs, const T &rhs) const {
+ // must be '>': with '<' this comparator is identical to less_half, so
+ // descending (is_ascend=false) fp16 sorts would silently sort ascending
+ return static_cast<mshadow::half::half_t>(lhs) > static_cast<mshadow::half::half_t>(rhs);
+ }
+};
+}
template <typename KDType, typename VDType, typename xpu>
inline typename std::enable_if<std::is_same<xpu, gpu>::value, size_t>::type
@@ -57,9 +81,12 @@ SortByKeyWorkspaceSize(const size_t num_keys) {
}
template<typename KDType, typename VDType>
-inline void SortByKey(mshadow::Tensor<gpu, 1, KDType> keys, mshadow::Tensor<gpu, 1, VDType> values,
- bool is_ascend, mshadow::Tensor<gpu, 1, char>* workspace,
- const int begin_bit, const int end_bit) {
+inline typename std::enable_if<!(std::is_same<KDType,mshadow::half::half_t>::value ||
+ std::is_same<VDType,mshadow::half::half_t>::value), void>::type
+SortByKeyImpl(mshadow::Tensor<gpu, 1, KDType> keys,
+ mshadow::Tensor<gpu, 1, VDType> values, bool is_ascend,
+ mshadow::Tensor<gpu, 1, char>* workspace,
+ const int begin_bit, const int end_bit) {
CHECK_EQ(keys.CheckContiguous(), true);
CHECK_EQ(values.CheckContiguous(), true);
#if CUDA_VERSION >= 7000
@@ -128,18 +155,100 @@ inline void SortByKey(mshadow::Tensor<gpu, 1, KDType> keys, mshadow::Tensor<gpu,
#endif
}
-template<typename DType>
-inline void SortByKey(mshadow::Tensor<gpu, 1, mshadow::half::half_t> keys,
- mshadow::Tensor<gpu, 1, DType> values, bool is_ascend,
- mshadow::Tensor<gpu, 1, char>* workspace, const int begin_bit, const int end_bit) {
- LOG(FATAL) << "SortByKey for half_t is not implemented!";
+template<typename KDType, typename VDType>
+inline typename std::enable_if<((!std::is_same<KDType,mshadow::half::half_t>::value) &&
+ std::is_same<VDType,mshadow::half::half_t>::value), void>::type
+SortByKeyImpl(mshadow::Tensor<gpu, 1, KDType> keys,
+ mshadow::Tensor<gpu, 1, VDType> values, bool is_ascend,
+ mshadow::Tensor<gpu, 1, char>* workspace,
+ const int begin_bit, const int end_bit) {
+ CHECK_EQ(keys.CheckContiguous(), true);
+ CHECK_EQ(values.CheckContiguous(), true);
+#if CUDA_VERSION >= 9000
+ cudaStream_t stream = mshadow::Stream<gpu>::GetStream(keys.stream_);
+ thrust::device_ptr<KDType> key_iter = thrust::device_pointer_cast(keys.dptr_);
+ thrust::device_ptr<half> value_iter = thrust::device_pointer_cast(
+ reinterpret_cast<half*>(values.dptr_));
+ if (is_ascend) {
+ thrust::stable_sort_by_key(
+ thrust::cuda::par.on(stream),
+ key_iter.get(), key_iter.get() + (keys.size(0)), value_iter.get(), thrust::less<KDType>());
+ } else {
+ thrust::stable_sort_by_key(
+ thrust::cuda::par.on(stream),
+ key_iter.get(), key_iter.get() + (keys.size(0)), value_iter.get(), thrust::greater<KDType>());
+ }
+ MSHADOW_CUDA_POST_KERNEL_CHECK(SortByKey);
+#else
+ LOG(FATAL) << "SortByKey with fp16 values is only supported for CUDA version >= 9.0";
+#endif
+}
+
+template<typename KDType, typename VDType>
+inline typename std::enable_if<(std::is_same<KDType,mshadow::half::half_t>::value &&
+ (!std::is_same<VDType,mshadow::half::half_t>::value)), void>::type
+SortByKeyImpl(mshadow::Tensor<gpu, 1, KDType> keys,
+ mshadow::Tensor<gpu, 1, VDType> values, bool is_ascend,
+ mshadow::Tensor<gpu, 1, char>* workspace,
+ const int begin_bit, const int end_bit) {
+ CHECK_EQ(keys.CheckContiguous(), true);
+ CHECK_EQ(values.CheckContiguous(), true);
+#if CUDA_VERSION >= 9000
+ cudaStream_t stream = mshadow::Stream<gpu>::GetStream(keys.stream_);
+ thrust::device_ptr<half> key_iter = thrust::device_pointer_cast(
+ reinterpret_cast<half*>(keys.dptr_));
+ thrust::device_ptr<VDType> value_iter = thrust::device_pointer_cast(values.dptr_);
+ if (is_ascend) {
+ thrust::stable_sort_by_key(
+ thrust::cuda::par.on(stream),
+ key_iter, key_iter + (keys.size(0)), value_iter, cuda::less_half<half>());
+ } else {
+ thrust::stable_sort_by_key(
+ thrust::cuda::par.on(stream),
+ key_iter, key_iter + (keys.size(0)), value_iter, cuda::greater_half<half>());
+ }
+ MSHADOW_CUDA_POST_KERNEL_CHECK(SortByKey);
+#else
+ LOG(FATAL) << "SortByKey with fp16 keys is only supported for CUDA version >= 9.0";
+#endif
+}
+
+// use thrust sorting when keys or values are half_t
+template<typename KDType, typename VDType>
+inline typename std::enable_if<(std::is_same<KDType,mshadow::half::half_t>::value &&
+ std::is_same<VDType,mshadow::half::half_t>::value), void>::type
+SortByKeyImpl(mshadow::Tensor<gpu, 1, KDType> keys,
+ mshadow::Tensor<gpu, 1, VDType> values, bool is_ascend,
+ mshadow::Tensor<gpu, 1, char>* workspace,
+ const int begin_bit, const int end_bit) {
+ CHECK_EQ(keys.CheckContiguous(), true);
+ CHECK_EQ(values.CheckContiguous(), true);
+#if CUDA_VERSION >= 9000
+ cudaStream_t stream = mshadow::Stream<gpu>::GetStream(keys.stream_);
+ thrust::device_ptr<half> key_iter = thrust::device_pointer_cast(
+ reinterpret_cast<half*>(keys.dptr_));
+ thrust::device_ptr<half> value_iter = thrust::device_pointer_cast(
+ reinterpret_cast<half*>(values.dptr_));
+ if (is_ascend) {
+ thrust::stable_sort_by_key(
+ thrust::cuda::par.on(stream),
+ key_iter, key_iter + (keys.size(0)), value_iter, cuda::less_half<half>());
+ } else {
+ thrust::stable_sort_by_key(
+ thrust::cuda::par.on(stream),
+ key_iter, key_iter + (keys.size(0)), value_iter, cuda::greater_half<half>());
+ }
+ MSHADOW_CUDA_POST_KERNEL_CHECK(SortByKey);
+#else
+ LOG(FATAL) << "SortByKey with fp16 keys and values is only supported for CUDA version >= 9.0";
+#endif
}
-template<typename DType>
-inline void SortByKey(mshadow::Tensor<gpu, 1, DType> keys,
- mshadow::Tensor<gpu, 1, mshadow::half::half_t> values, bool is_ascend,
- mshadow::Tensor<gpu, 1, char>* workspace, const int begin_bit, const int end_bit) {
- LOG(FATAL) << "SortByKey for half_t is not implemented!";
+template<typename KDType, typename VDType>
+inline void SortByKey(mshadow::Tensor<gpu, 1, KDType> keys, mshadow::Tensor<gpu, 1, VDType> values,
+ bool is_ascend, mshadow::Tensor<gpu, 1, char>* workspace,
+ const int begin_bit, const int end_bit) {
+ SortByKeyImpl(keys, values, is_ascend, workspace, begin_bit, end_bit);
}
} // namespace op
diff --git a/tests/python/unittest/test_contrib_operator.py b/tests/python/unittest/test_contrib_operator.py
index a220f08..fc6c1be 100644
--- a/tests/python/unittest/test_contrib_operator.py
+++ b/tests/python/unittest/test_contrib_operator.py
@@ -28,11 +28,12 @@ import unittest
def test_box_nms_op():
def test_box_nms_forward(data, expected, thresh=0.5, valid=0, topk=-1, coord=2, score=1, cid=0,
force=False, in_format='corner', out_format='corner'):
- data = mx.nd.array(data)
- out = mx.contrib.nd.box_nms(data, overlap_thresh=thresh, valid_thresh=valid, topk=topk,
- coord_start=coord, score_index=score, id_index=cid,
- force_suppress=force, in_format=in_format, out_format=out_format)
- assert_almost_equal(out.asnumpy(), expected)
+ for dtype in ['float16', 'float32', 'float64']:
+ data = mx.nd.array(data, dtype=dtype)
+ out = mx.contrib.nd.box_nms(data, overlap_thresh=thresh, valid_thresh=valid, topk=topk,
+ coord_start=coord, score_index=score, id_index=cid,
+ force_suppress=force, in_format=in_format, out_format=out_format)
+ assert_almost_equal(out.asnumpy(), expected.astype(dtype), rtol=1e-3, atol=1e-3)
def test_box_nms_backward(data, grad, expected, thresh=0.5, valid=0, topk=-1, coord=2, score=1,
cid=0, force=False, in_format='corner', out_format='corner'):
@@ -233,13 +234,13 @@ def test_box_iou_op():
def test_bipartite_matching_op():
def assert_match(inputs, x, y, threshold, is_ascend=False):
- inputs = mx.nd.array(inputs)
- x = np.array(x)
- y = np.array(y)
- a, b = mx.nd.contrib.bipartite_matching(inputs, threshold=threshold, is_ascend=is_ascend)
- print(a, b)
- assert_array_equal(a.asnumpy().astype('int64'), x.astype('int64'))
- assert_array_equal(b.asnumpy().astype('int64'), y.astype('int64'))
+ for dtype in ['float16', 'float32', 'float64']:
+ inputs = mx.nd.array(inputs, dtype=dtype)
+ x = np.array(x, dtype=dtype)
+ y = np.array(y, dtype=dtype)
+ a, b = mx.nd.contrib.bipartite_matching(inputs, threshold=threshold, is_ascend=is_ascend)
+ assert_array_equal(a.asnumpy().astype('int64'), x.astype('int64'))
+ assert_array_equal(b.asnumpy().astype('int64'), y.astype('int64'))
assert_match([[0.5, 0.6], [0.1, 0.2], [0.3, 0.4]], [1, -1, 0], [2, 0], 1e-12, False)
assert_match([[0.5, 0.6], [0.1, 0.2], [0.3, 0.4]], [-1, 0, 1], [1, 2], 100, True)