You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2018/10/24 23:14:36 UTC
[incubator-mxnet] branch master updated: fix indptr[0] for take(csr) (#12927)
This is an automated email from the ASF dual-hosted git repository.
haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 57176cd fix indptr[0] for take(csr) (#12927)
57176cd is described below
commit 57176cd84fa2715ae4707086d764c0a87cc30b54
Author: Haibin Lin <li...@gmail.com>
AuthorDate: Wed Oct 24 16:14:22 2018 -0700
fix indptr[0] for take(csr) (#12927)
---
src/operator/tensor/indexing_op.cc | 7 +++++--
tests/python/unittest/test_sparse_ndarray.py | 5 +++--
2 files changed, 8 insertions(+), 4 deletions(-)
diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc
index 98e2536..710b502 100644
--- a/src/operator/tensor/indexing_op.cc
+++ b/src/operator/tensor/indexing_op.cc
@@ -137,7 +137,10 @@ struct CsrTakeRowCountKernel {
MSHADOW_XINLINE static void Map(int tid, RType* out_indptr,
const RType* src_indptr, const IType* idx_ptr,
const nnvm::dim_t num_rows) {
- if (tid == 0) out_indptr[0] = 0;
+ if (tid == 0) {
+ out_indptr[0] = 0;
+ return;
+ }
nnvm::dim_t idx = static_cast<nnvm::dim_t>(idx_ptr[tid - 1]);
// clip mode
if (clip) {
@@ -181,7 +184,7 @@ void TakeOpForwardCsrImpl<cpu>(const TakeParam& params,
out.CheckAndAllocAuxData(kIndPtr, {Shape1(num_rows + 1)});
MSHADOW_TYPE_SWITCH(idx.type_flag_, IType, {
- MSHADOW_SGL_DBL_TYPE_SWITCH(arr.dtype(), DType, {
+ MSHADOW_TYPE_SWITCH(arr.dtype(), DType, {
MSHADOW_IDX_TYPE_SWITCH(out.aux_type(kIdx), RType, {
RType* out_indptr = out.aux_data(kIndPtr).dptr<RType>();
const RType* src_indptr = arr.aux_data(kIndPtr).dptr<RType>();
diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py
index 8dd250c..43d370b 100644
--- a/tests/python/unittest/test_sparse_ndarray.py
+++ b/tests/python/unittest/test_sparse_ndarray.py
@@ -1018,13 +1018,14 @@ def test_sparse_take():
def check_sparse_take(density, mode):
data_shape = rand_shape_2d()
idx_shape = (np.random.randint(low=1, high=10),)
- data = rand_ndarray(data_shape, 'csr', density=density)
+ data = rand_ndarray(data_shape, 'csr', density=density).astype('int32')
idx = mx.nd.array(np.random.randint(low=-5, high=15, size=idx_shape))
- result = mx.nd.take(data, idx, mode=mode)
data_np = data.asnumpy()
idx_np = idx.asnumpy().astype('int32')
expected_result = np.take(data_np, idx_np, mode=mode, axis=0)
+ result = mx.nd.take(data, idx, mode=mode)
assert_almost_equal(result.asnumpy(), expected_result)
+ assert result.indptr[0].asscalar() == 0
densities = [0, 0.5, 1]
modes = ['clip', 'wrap']
for d in densities: