You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink itself is not preserved in this plain-text copy).
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/10/11 16:14:45 UTC

[GitHub] sxjscience closed pull request #12446: [WIP][Bugfix] Fix flaky topk

sxjscience closed pull request #12446: [WIP][Bugfix] Fix flaky topk
URL: https://github.com/apache/incubator-mxnet/pull/12446
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below, because GitHub would not otherwise display it:

diff --git a/3rdparty/mshadow b/3rdparty/mshadow
index 696803bd772..463c0dffe3e 160000
--- a/3rdparty/mshadow
+++ b/3rdparty/mshadow
@@ -1 +1 @@
-Subproject commit 696803bd7723ade8230af878460d96c68a550fbc
+Subproject commit 463c0dffe3eae8c39caf7989c85b7244823df27e
diff --git a/3rdparty/tvm b/3rdparty/tvm
index 0f053c82a74..290226e1c9a 160000
--- a/3rdparty/tvm
+++ b/3rdparty/tvm
@@ -1 +1 @@
-Subproject commit 0f053c82a747b4dcdf49570ec87c17e0067b7439
+Subproject commit 290226e1c9adbb3e598f9ed9184018df1c12be33
diff --git a/ci/docker/install/ubuntu_gcc8.sh b/ci/docker/install/ubuntu_gcc8.sh
old mode 100755
new mode 100644
diff --git a/ci/docker/install/ubuntu_julia.sh b/ci/docker/install/ubuntu_julia.sh
old mode 100755
new mode 100644
diff --git a/cpp-package/example/unittests/unit_test_mlp_csv.sh b/cpp-package/example/unittests/unit_test_mlp_csv.sh
old mode 100755
new mode 100644
diff --git a/docker/docker-python/build_python_dockerfile.sh b/docker/docker-python/build_python_dockerfile.sh
old mode 100755
new mode 100644
diff --git a/example/image-classification/symbols/resnetv1.py b/example/image-classification/symbols/resnetv1.py
old mode 100755
new mode 100644
diff --git a/julia/deps/cpcblas.sh b/julia/deps/cpcblas.sh
old mode 100755
new mode 100644
diff --git a/julia/models/Inception/get.sh b/julia/models/Inception/get.sh
old mode 100755
new mode 100644
diff --git a/julia/test/travis/run_test.sh b/julia/test/travis/run_test.sh
old mode 100755
new mode 100644
diff --git a/julia/test/travis/setup_env.sh b/julia/test/travis/setup_env.sh
old mode 100755
new mode 100644
diff --git a/scala-package/examples/scripts/benchmark/run_image_inference_bm.sh b/scala-package/examples/scripts/benchmark/run_image_inference_bm.sh
old mode 100755
new mode 100644
diff --git a/scala-package/examples/scripts/benchmark/run_text_charrnn_bm.sh b/scala-package/examples/scripts/benchmark/run_text_charrnn_bm.sh
old mode 100755
new mode 100644
diff --git a/scala-package/examples/scripts/infer/imageclassifier/get_resnet_18_data.sh b/scala-package/examples/scripts/infer/imageclassifier/get_resnet_18_data.sh
old mode 100755
new mode 100644
diff --git a/src/operator/tensor/ordering_op-inl.h b/src/operator/tensor/ordering_op-inl.h
index 18bd7608e4c..8f09cbd72a8 100644
--- a/src/operator/tensor/ordering_op-inl.h
+++ b/src/operator/tensor/ordering_op-inl.h
@@ -176,6 +176,14 @@ inline void ParseTopKParam(const TShape& src_shape, const TopKParam& param, TSha
 
 using namespace mshadow;
 
+
+struct fill_ind_to_one {  // Kernel functor: for each launched index i, sets out[indices[i]] = 1.
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, const int* indices, DType* out) {
+    out[indices[i]] = static_cast<DType>(1);  // no bounds check: indices[i] must be a valid offset into out
+  }
+};
+
 template<typename DType>
 MSHADOW_FORCE_INLINE void TopKSort(const Tensor<cpu, 1, DType>& dat,
                                    const Tensor<cpu, 1, int>& ind,
@@ -313,7 +321,8 @@ MSHADOW_FORCE_INLINE void TopKSort(const Tensor<gpu, 1, DType>& dat,
   const int M(dat.size(0)/N);
   if (full_sort) {
     // Divide workspace into two parts. The first one is needed to store batch ids.
-    const int id_size(sizeof(int)*ind.size(0));
+    size_t alignment = std::max(sizeof(DType), sizeof(int));
+    size_t id_size = PadBytes(sizeof(int) * ind.size(0), alignment);
     Tensor<gpu, 1, int> batch_id(reinterpret_cast<int*>(work.dptr_), Shape1(ind.size(0)), s);
     Tensor<gpu, 1, char> sort_work(work.dptr_+id_size, Shape1(work.size(0)-id_size), s);
     mxnet::op::SortByKey(dat, ind, is_ascend, &sort_work);
@@ -364,12 +373,12 @@ void TopKImpl(const RunContext &ctx,
   Tensor<xpu, 1, char> temp_workspace;
   Tensor<xpu, 1, DType> sorted_dat;
   Tensor<xpu, 1, int> indices, sel_indices;
-  Tensor<xpu, 2, DType> mask_val;
   int batch_size, element_num;  // number of batches + the size of each batch
   int axis = 0;
   bool do_transpose = false;
   bool is_ascend = false;
   int k = 0;
+  size_t alignment = std::max(sizeof(DType), sizeof(int));
   TShape target_shape;
   ParseTopKParam(src.shape_, param,
                  &target_shape, &batch_size, &element_num, &axis, &k, &do_transpose, &is_ascend);
@@ -387,32 +396,28 @@ void TopKImpl(const RunContext &ctx,
   temp_size = std::max(temp_size,
     mxnet::op::SortByKeyWorkspaceSize<DType, int, xpu>(src.Size()));
   // Additional temp space for gpu full sorts for batch ids.
-  temp_size += sizeof(int) * src.Size();
+  temp_size += PadBytes(sizeof(int) * src.Size(), alignment);
   // Temp space for cpu sorts.
-  temp_size = std::max(temp_size, sizeof(DType) * static_cast<size_t>(src.Size()));
-  index_t workspace_size = temp_size + sizeof(DType) * src.Size() + sizeof(int) * src.Size();
+  temp_size = std::max(temp_size, sizeof(DType) * src.Size());
+  size_t workspace_size = temp_size + PadBytes(sizeof(DType) * src.Size(), alignment)
+                                    + PadBytes(sizeof(int) * src.Size(), alignment);
   if (param.ret_typ == topk_enum::kReturnMask) {
-    workspace_size += sizeof(int) * batch_size * k + sizeof(DType) * batch_size * k;
+    workspace_size += PadBytes(sizeof(int) * batch_size * k, alignment);
   }
   workspace = resource.get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
   char* workspace_curr_ptr = workspace.dptr_;
   sorted_dat = Tensor<xpu, 1, DType>(reinterpret_cast<DType*>(workspace_curr_ptr),
                                       Shape1(src.Size()), s);  // contain sorted dat
-  workspace_curr_ptr += sizeof(DType) * src.Size();
+  workspace_curr_ptr += PadBytes(sizeof(DType) * src.Size(), alignment);
   indices = Tensor<xpu, 1, int>(reinterpret_cast<int*>(workspace_curr_ptr),
                                 Shape1(src.Size()), s);  // indices in the original matrix
-  workspace_curr_ptr += sizeof(int) * src.Size();
+  workspace_curr_ptr += PadBytes(sizeof(int) * src.Size(), alignment);
 
   if (param.ret_typ == topk_enum::kReturnMask) {
     sel_indices = Tensor<xpu, 1, int>(reinterpret_cast<int*>(workspace_curr_ptr),
                                       Shape1(batch_size * k), s);
-    workspace_curr_ptr += sizeof(int) * batch_size * k;
-    mask_val = Tensor<xpu, 2, DType>(reinterpret_cast<DType*>(workspace_curr_ptr),
-                                      Shape2(batch_size * k, 1), s);
-    workspace_curr_ptr += sizeof(DType) * batch_size * k;
-    mask_val = scalar<DType>(1);
+    workspace_curr_ptr += PadBytes(sizeof(int) * batch_size * k, alignment);
     CHECK_EQ(sel_indices.CheckContiguous(), true);
-    CHECK_EQ(mask_val.CheckContiguous(), true);
   }
 
   if (std::is_same<xpu, cpu>::value) {
@@ -458,8 +463,7 @@ void TopKImpl(const RunContext &ctx,
   // Cast `ret_indices` from int to real_t could introduce conversion error when the element_num
   // is large enough.
   if (param.ret_typ == topk_enum::kReturnMask) {
-    Tensor<xpu, 2, DType> ret_mask =
-      ret[0].get_with_shape<xpu, 2, DType>(Shape2(ret[0].Size(), 1), s);
+    Tensor<xpu, 1, DType> ret_mask = ret[0].FlatTo1D<xpu, DType>(s);
     ret_mask = scalar<DType>(0);
     sel_indices = reshape(slice<1>(
                               inplace_reshape(indices,
@@ -475,7 +479,8 @@ void TopKImpl(const RunContext &ctx,
     if (req[0] == kNullOp) {
       return;
     } else if (req[0] == kWriteTo) {
-      IndexFill(ret_mask, sel_indices, mask_val);
+      mxnet_op::Kernel<fill_ind_to_one, xpu>::Launch(s, batch_size * k,
+                                                     sel_indices.dptr_, ret_mask.dptr_);
     } else {
       LOG(FATAL) << "req=" << req[0] << " is not supported yet.";
     }
diff --git a/src/operator/tensor/sort_op-inl.cuh b/src/operator/tensor/sort_op-inl.cuh
index 1a8e2325ef4..f0caee4f5cb 100644
--- a/src/operator/tensor/sort_op-inl.cuh
+++ b/src/operator/tensor/sort_op-inl.cuh
@@ -74,8 +74,9 @@ SortByKeyWorkspaceSize(const size_t num_keys) {
   size_t sortpairs_bytes = 0;
   cub::DeviceRadixSort::SortPairs<KDType, VDType>(NULL, sortpairs_bytes,
       NULL, NULL, NULL, NULL, num_keys);
-  size_t keys_bytes = num_keys*sizeof(KDType);
-  size_t values_bytes = num_keys*sizeof(VDType);
+  size_t alignment = std::max(sizeof(KDType), sizeof(VDType));
+  size_t keys_bytes = PadBytes(num_keys*sizeof(KDType), alignment);
+  size_t values_bytes = PadBytes(num_keys*sizeof(VDType), alignment);
   return (keys_bytes + values_bytes + sortpairs_bytes);
 #endif
 }
@@ -96,8 +97,9 @@ SortByKeyImpl(mshadow::Tensor<gpu, 1, KDType> keys,
     // Workspace given, sort using CUB
     CHECK_EQ(workspace->CheckContiguous(), true);
     // workspace = [keys_out, values_out, temporary_storage]
-    size_t keys_bytes = keys.size(0)*sizeof(KDType);
-    size_t values_bytes = keys.size(0)*sizeof(VDType);
+    size_t alignment = std::max(sizeof(KDType), sizeof(VDType));
+    size_t keys_bytes = PadBytes(keys.size(0)*sizeof(KDType), alignment);
+    size_t values_bytes = PadBytes(keys.size(0)*sizeof(VDType), alignment);
     // Get the size of internal storage (for checking purposes only)
     size_t sortpairs_bytes = 0;
     if (is_ascend) {
diff --git a/src/operator/tensor/sort_op.h b/src/operator/tensor/sort_op.h
index 3fa95bb660f..5881060d987 100644
--- a/src/operator/tensor/sort_op.h
+++ b/src/operator/tensor/sort_op.h
@@ -31,6 +31,11 @@
 #include <type_traits>
 
 namespace mxnet {
+
+inline size_t PadBytes(size_t num_bytes, size_t alignment) {  // Round num_bytes up to the next multiple of alignment.
+  return num_bytes + (alignment - num_bytes % alignment) % alignment;  // alignment must be non-zero (% by 0 is UB)
+}
+
 namespace op {
 /*!
  * \brief CPU/GPU: Sort key-value pairs stored in separate places. (Stable sort is performed!)
diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py
index 00cf30ae62d..0044ae38f5e 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -723,7 +723,7 @@ def get_large_matrix():
     gt = gt_topk(large_matrix_npy, axis=1, ret_typ="indices", k=5, is_ascend=False)
     assert_almost_equal(nd_ret_topk, gt)
 
-    for dtype in [ np.int32, np.int64, np.float32, np.float64]:
+    for dtype in [np.int32, np.int64, np.float32, np.float64]:
         a_npy = get_values(ensure_unique=True, dtype=dtype)
         a_nd = mx.nd.array(a_npy, ctx=ctx, dtype=dtype)
 
@@ -754,9 +754,6 @@ def get_large_matrix():
         assert_almost_equal(nd_ret_topk, gt)
 
         # test for ret_typ=mask
-        # test needs to be re-enabled once flaky topk gets fixed
-        # tracked in https://github.com/apache/incubator-mxnet/pull/12446
-        '''
         nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="mask", k=3, is_ascend=True).asnumpy()
         assert nd_ret_topk.dtype == dtype
         gt = gt_topk(a_npy, axis=1, ret_typ="mask", k=3, is_ascend=True)
@@ -767,7 +764,7 @@ def get_large_matrix():
         nd_ret_topk = mx.nd.topk(a_nd, axis=None, ret_typ="mask", k=21, is_ascend=False).asnumpy()
         gt = gt_topk(a_npy, axis=None, ret_typ="mask", k=21, is_ascend=False)
         assert_almost_equal(nd_ret_topk, gt)
-        '''
+
         # test for ret_typ=both
         nd_ret_topk_val, nd_ret_topk_ind = mx.nd.topk(a_nd, axis=1, ret_typ="both", k=3, is_ascend=True)
         nd_ret_topk_val = nd_ret_topk_val.asnumpy()
@@ -800,6 +797,7 @@ def get_large_matrix():
         # test for argsort
         for idtype in [np.int32, np.float16, np.float32, np.float64]:
             nd_ret_argsort = mx.nd.argsort(a_nd, axis=3, is_ascend=True, dtype=idtype).asnumpy()
+            assert nd_ret_argsort.dtype == idtype
             gt = gt_topk(a_npy, axis=3, ret_typ="indices", k=dat_size, is_ascend=True)
             assert_almost_equal(nd_ret_argsort, gt)
             nd_ret_argsort = mx.nd.argsort(a_nd, axis=None, is_ascend=False, dtype=idtype).asnumpy()
@@ -863,7 +861,7 @@ def get_large_matrix():
     # Repeat those tests that don't involve indices.  These should pass even with
     # duplicated input data values (over many repeated runs with different random seeds,
     # this will be tested).
-    for dtype in [ np.int32, np.int64, np.float32, np.float64]:
+    for dtype in [np.int32, np.int64, np.float32, np.float64]:
         a_npy = get_values(ensure_unique=False, dtype=dtype)
         a_nd = mx.nd.array(a_npy, ctx=ctx, dtype=dtype)
 


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services