You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2018/10/24 23:14:36 UTC

[incubator-mxnet] branch master updated: fix indptr[0] for take(csr) (#12927)

This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 57176cd  fix indpt[0] for take(csr) (#12927)
57176cd is described below

commit 57176cd84fa2715ae4707086d764c0a87cc30b54
Author: Haibin Lin <li...@gmail.com>
AuthorDate: Wed Oct 24 16:14:22 2018 -0700

    fix indpt[0] for take(csr) (#12927)
---
 src/operator/tensor/indexing_op.cc           | 7 +++++--
 tests/python/unittest/test_sparse_ndarray.py | 5 +++--
 2 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc
index 98e2536..710b502 100644
--- a/src/operator/tensor/indexing_op.cc
+++ b/src/operator/tensor/indexing_op.cc
@@ -137,7 +137,10 @@ struct CsrTakeRowCountKernel {
   MSHADOW_XINLINE static void Map(int tid, RType* out_indptr,
                                   const RType* src_indptr, const IType* idx_ptr,
                                   const nnvm::dim_t num_rows) {
-    if (tid == 0) out_indptr[0] = 0;
+    if (tid == 0) {
+      out_indptr[0] = 0;
+      return;
+    }
     nnvm::dim_t idx = static_cast<nnvm::dim_t>(idx_ptr[tid - 1]);
     // clip mode
     if (clip) {
@@ -181,7 +184,7 @@ void TakeOpForwardCsrImpl<cpu>(const TakeParam& params,
   out.CheckAndAllocAuxData(kIndPtr, {Shape1(num_rows + 1)});
 
   MSHADOW_TYPE_SWITCH(idx.type_flag_, IType, {
-    MSHADOW_SGL_DBL_TYPE_SWITCH(arr.dtype(), DType, {
+    MSHADOW_TYPE_SWITCH(arr.dtype(), DType, {
       MSHADOW_IDX_TYPE_SWITCH(out.aux_type(kIdx), RType, {
         RType* out_indptr = out.aux_data(kIndPtr).dptr<RType>();
         const RType* src_indptr = arr.aux_data(kIndPtr).dptr<RType>();
diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py
index 8dd250c..43d370b 100644
--- a/tests/python/unittest/test_sparse_ndarray.py
+++ b/tests/python/unittest/test_sparse_ndarray.py
@@ -1018,13 +1018,14 @@ def test_sparse_take():
     def check_sparse_take(density, mode):
         data_shape = rand_shape_2d()
         idx_shape = (np.random.randint(low=1, high=10),)
-        data = rand_ndarray(data_shape, 'csr', density=density)
+        data = rand_ndarray(data_shape, 'csr', density=density).astype('int32')
         idx = mx.nd.array(np.random.randint(low=-5, high=15, size=idx_shape))
-        result = mx.nd.take(data, idx, mode=mode)
         data_np = data.asnumpy()
         idx_np = idx.asnumpy().astype('int32')
         expected_result = np.take(data_np, idx_np, mode=mode, axis=0)
+        result = mx.nd.take(data, idx, mode=mode)
         assert_almost_equal(result.asnumpy(), expected_result)
+        assert result.indptr[0].asscalar() == 0
     densities = [0, 0.5, 1]
     modes = ['clip', 'wrap']
     for d in densities: