Posted to commits@mxnet.apache.org by re...@apache.org on 2019/09/23 05:43:41 UTC

[incubator-mxnet] branch numpy_staging_prs updated: numpy operator nonzero (#15838)

This is an automated email from the ASF dual-hosted git repository.

reminisce pushed a commit to branch numpy_staging_prs
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/numpy_staging_prs by this push:
     new c36819e  numpy operator nonzero (#15838)
c36819e is described below

commit c36819e97fb7ac8159b00ab31e04a48149fc0c40
Author: tingying <ti...@u.northwestern.edu>
AuthorDate: Mon Sep 23 13:43:14 2019 +0800

    numpy operator nonzero (#15838)
    
    * add cpu test and handle 0-dim
    
    * add FGradient with MakeZeroGradNodes
    
    * handle 0-dim and 0-shape and add test on gpu
    
    * add doc
    
    * fix bug in review
    
    * do not use thrust::inclusive_scan on cpu
    
    * fix format error
    
    * edit test and remove gpu test
    
    The output is the same as numpy.transpose(numpy.nonzero(x))
    
    * fix errors from review
    
    * edit test
---
 python/mxnet/_numpy_op_doc.py          |  48 ++++++++++++
 src/operator/numpy/np_nonzero_op-inl.h |  65 +++++++++++++++++
 src/operator/numpy/np_nonzero_op.cc    | 129 ++++++++++++++++++++++++++++++++
 src/operator/numpy/np_nonzero_op.cu    | 130 +++++++++++++++++++++++++++++++++
 tests/python/unittest/test_numpy_op.py |  32 ++++++++
 5 files changed, 404 insertions(+)
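
For orientation, a minimal usage sketch of the behavior this commit adds (assuming a
build that includes this branch; `npx.set_np()` and the import style below follow the
existing numpy-compatible API, and the variable names are illustrative):

    import numpy as onp
    from mxnet import np, npx
    npx.set_np()

    x = np.array([[3, 0, 0], [0, 4, 0], [5, 6, 0]])
    out = npx.nonzero(x)                           # one row of indices per non-zero element, int64
    ref = onp.transpose(onp.nonzero(x.asnumpy()))  # the NumPy reference the docstring and tests compare against
    assert out.shape == ref.shape                  # (4, 2) for this input
    assert (out.asnumpy() == ref).all()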

diff --git a/python/mxnet/_numpy_op_doc.py b/python/mxnet/_numpy_op_doc.py
index f4787be..6d2776e 100644
--- a/python/mxnet/_numpy_op_doc.py
+++ b/python/mxnet/_numpy_op_doc.py
@@ -109,7 +109,55 @@ def _np_cumsum(a, axis=None, dtype=None, out=None):
     >>> np.cumsum(a,axis=1)      # sum over columns for each of the 2 rows
     array([[ 1,  3,  6],
            [ 4,  9, 15]])
+    """
+    pass
+
+
+def _npx_nonzero(a):
+    """
+    nonzero(a)
+
+    Return the indices of the elements that are non-zero.
+
+    Returns an ndarray whose ndim is 2. Each row contains the indices
+    of one non-zero element. The values in `a` are always tested and returned in
+    row-major, C-style order.
+
+    The result of this is always a 2-D array, with a row for
+    each non-zero element.
+
+    Parameters
+    ----------
+    a : array_like
+        Input array.
+
+    Returns
+    -------
+    array : ndarray
+        Indices of elements that are non-zero.
 
+    Notes
+    -----
+    This function differs from the original numpy.nonzero in the following aspects:
+        - Does not support python numeric inputs.
+        - The return value is the same as numpy.transpose(numpy.nonzero(a)).
+
+    Examples
+    --------
+    >>> x = np.array([[3, 0, 0], [0, 4, 0], [5, 6, 0]])
+    >>> x
+    array([[3, 0, 0],
+           [0, 4, 0],
+           [5, 6, 0]])
+    >>> npx.nonzero(x)
+    array([[0, 0],
+           [1, 1],
+           [2, 0],
+           [2, 1]], dtype=int64)
+
+    >>> np.transpose(npx.nonzero(x))
+    array([[0, 1, 2, 2],
+           [0, 1, 0, 1]], dtype=int64)
     """
     pass
 
diff --git a/src/operator/numpy/np_nonzero_op-inl.h b/src/operator/numpy/np_nonzero_op-inl.h
new file mode 100644
index 0000000..88929c4
--- /dev/null
+++ b/src/operator/numpy/np_nonzero_op-inl.h
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file np_nonzero_op-inl.h
+*/
+
+#ifndef MXNET_OPERATOR_NUMPY_NP_NONZERO_OP_INL_H_
+#define MXNET_OPERATOR_NUMPY_NP_NONZERO_OP_INL_H_
+
+#include <dmlc/logging.h>
+#include <dmlc/parameter.h>
+#include <mxnet/operator.h>
+#include <mxnet/ndarray.h>
+#include <map>
+#include <vector>
+#include <string>
+#include <utility>
+#include <algorithm>
+#include "../operator_common.h"
+#include "../mxnet_op.h"
+#include "../tensor/init_op.h"
+#include "../mshadow_op.h"
+#include "../elemwise_op_common.h"
+
+namespace mxnet {
+namespace op {
+
+struct NonzeroForwardKernel {
+  template<int ndim>
+  MSHADOW_XINLINE static void Map(int i,
+                                  int64_t* out,
+                                  const int32_t* idx,
+                                  const mshadow::Shape<ndim> shape) {
+    int32_t prev = (i == 0) ? 0 : idx[i - 1];
+    int32_t curr = idx[i];
+    if (prev != curr) {
+      mshadow::Shape<ndim> coord = mxnet_op::unravel<ndim>(i, shape);
+      for (int j = 0; j < ndim; j++) {
+        out[prev * ndim + j] = coord[j];
+      }
+    }
+  }
+};
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_NUMPY_NP_NONZERO_OP_INL_H_
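
NonzeroForwardKernel above is the scatter half of a standard prefix-sum (stream
compaction) scheme: `idx` is an inclusive scan of a 0/1 non-zero indicator, so
`prev != curr` at position i means element i is non-zero, and `prev` is its row in the
output. A rough NumPy sketch of the same logic, for illustration only (the function
name is made up here, and the 0-dim and zero-size special cases handled in the .cc/.cu
files are omitted):

    import numpy as np

    def nonzero_via_prefix_sum(a):
        flat = a.reshape(-1)
        prefix = np.cumsum(flat != 0)                 # inclusive scan; plays the role of `idx`
        out = np.empty((int(prefix[-1]), a.ndim), dtype=np.int64)
        for i in range(flat.size):                    # one iteration per Map(i, ...) call
            prev = 0 if i == 0 else prefix[i - 1]
            if prev != prefix[i]:                     # element i is non-zero
                out[prev] = np.unravel_index(i, a.shape)  # same job as mxnet_op::unravel
        return out

    # nonzero_via_prefix_sum(np.array([[3, 0, 0], [0, 4, 0], [5, 6, 0]]))
    # matches the docstring example: [[0, 0], [1, 1], [2, 0], [2, 1]]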
diff --git a/src/operator/numpy/np_nonzero_op.cc b/src/operator/numpy/np_nonzero_op.cc
new file mode 100644
index 0000000..00f9081
--- /dev/null
+++ b/src/operator/numpy/np_nonzero_op.cc
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file np_nonzero_op.cc
+*/
+#include "np_nonzero_op-inl.h"
+
+namespace mxnet {
+namespace op {
+
+bool NonzeroType(const nnvm::NodeAttrs& attrs,
+                 std::vector<int> *in_attrs,
+                 std::vector<int> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1);
+  CHECK_EQ(out_attrs->size(), 1);
+  // Output must be int64.
+  TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kInt64);
+  return out_attrs->at(0) != -1;
+}
+
+#define MAXDIM 5
+
+bool NonzeroStorageType(const nnvm::NodeAttrs& attrs,
+                        const int dev_mask,
+                        DispatchMode* dispatch_mode,
+                        std::vector<int> *in_attrs,
+                        std::vector<int> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1);
+  CHECK_EQ(out_attrs->size(), 1);
+  for (int &attr : *in_attrs) {
+    CHECK_EQ(attr, kDefaultStorage) << "Only default storage is supported";
+  }
+  for (int &attr : *out_attrs) {
+    attr = kDefaultStorage;
+  }
+  *dispatch_mode = DispatchMode::kFComputeEx;
+  return true;
+}
+
+void NonzeroForwardCPU(const nnvm::NodeAttrs& attrs,
+                       const OpContext &ctx,
+                       const std::vector<NDArray> &inputs,
+                       const std::vector<OpReqType> &req,
+                       const std::vector<NDArray> &outputs) {
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  const NDArray &in = inputs[0];
+  const NDArray &out = outputs[0];
+  CHECK_LE(in.shape().ndim(), MAXDIM) << "ndim of input cannot be larger than " << MAXDIM;
+  // 0-dim
+  if (0 == in.shape().ndim()) {
+    MSHADOW_TYPE_SWITCH(in.dtype(), DType, {
+      DType* in_dptr = in.data().dptr<DType>();
+      if (*in_dptr) {
+        mxnet::TShape s(2, 1);
+        const_cast<NDArray &>(out).Init(s);
+        *(out.data().dptr<int64_t>()) = 0;
+      } else {
+        mxnet::TShape s(2, 1);
+        s[0] = 0;
+        const_cast<NDArray &>(out).Init(s);
+      }
+    });
+    return;
+  }
+  size_t in_size = in.shape().Size();
+  // 0-shape
+  if (0 == in_size) {
+    mxnet::TShape s(2, in.shape().ndim());
+    s[0] = 0;
+    const_cast<NDArray &>(out).Init(s);
+    return;
+  }
+  std::vector<int32_t> prefix_sum(in_size, 0);
+  size_t valid_num = 0;
+  // Calculate prefix sum
+  MSHADOW_TYPE_SWITCH(in.dtype(), DType, {
+    DType* in_dptr = in.data().dptr<DType>();
+    for (size_t i = 0; i < in_size; i++) {
+      prefix_sum[i] = (i == 0) ? 0 : prefix_sum[i - 1];
+      prefix_sum[i] += (in_dptr[i]) ? 1 : 0;
+    }
+  });
+  valid_num = prefix_sum[in_size - 1];
+  // set the output shape forcefully
+  mxnet::TShape s(2, in.shape().ndim());
+  s[0] = valid_num;
+  const_cast<NDArray &>(out).Init(s);
+  // get the shape from the input
+  MXNET_NDIM_SWITCH(in.shape().ndim(), ndim, {
+    mshadow::Shape<ndim> shape = in.shape().get<ndim>();
+    mshadow::Stream<cpu> *stream = ctx.get_stream<cpu>();
+    mxnet_op::Kernel<NonzeroForwardKernel, cpu>::Launch(
+      stream, in_size, out.data().dptr<int64_t>(), prefix_sum.data(), shape);
+  })
+}
+
+NNVM_REGISTER_OP(_npx_nonzero)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"x"};
+  })
+.set_attr<nnvm::FInferType>("FInferType", NonzeroType)
+.set_attr<FComputeEx>("FComputeEx<cpu>", NonzeroForwardCPU)
+.set_attr<FInferStorageType>("FInferStorageType", NonzeroStorageType)
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
+.add_argument("x", "NDArray-or-Symbol", "The input array.");
+
+}  // namespace op
+}  // namespace mxnet
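
The two early returns above pin down the edge-case shapes: a 0-dim input yields a
(1, 1) or (0, 1) result, and a zero-size input yields (0, ndim). A hedged sketch of
what that means at the Python level (same build assumption as above; these assertions
restate the branches in NonzeroForwardCPU rather than quoting a test run):

    from mxnet import np, npx
    npx.set_np()

    # 0-dim input: a single row [[0]] if the scalar is non-zero, an empty (0, 1) result otherwise
    assert npx.nonzero(np.array(3)).shape == (1, 1)
    assert npx.nonzero(np.array(0)).shape == (0, 1)

    # zero-size input: no rows, one column per input dimension
    assert npx.nonzero(np.zeros((1, 0))).shape == (0, 2)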
diff --git a/src/operator/numpy/np_nonzero_op.cu b/src/operator/numpy/np_nonzero_op.cu
new file mode 100644
index 0000000..33925ea
--- /dev/null
+++ b/src/operator/numpy/np_nonzero_op.cu
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file np_nonzero_op.cu
+*/
+
+#include "np_nonzero_op-inl.h"
+#include <cub/cub.cuh>
+
+namespace mxnet {
+namespace op {
+
+struct PrefixSumInit {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i,
+                                  int32_t* out,
+                                  DType* in) {
+    if (in[i]) {
+      out[i] = 1;
+    } else {
+      out[i] = 0;
+    }
+  }
+};
+
+#define MAXDIM 5
+
+void NonzeroForwardGPU(const nnvm::NodeAttrs& attrs,
+                       const OpContext &ctx,
+                       const std::vector<NDArray> &inputs,
+                       const std::vector<OpReqType> &req,
+                       const std::vector<NDArray> &outputs) {
+  using namespace mshadow;
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  const NDArray &in = inputs[0];
+  const NDArray &out = outputs[0];
+  CHECK_LE(in.shape().ndim(), MAXDIM) << "ndim of input cannot be larger than " << MAXDIM;
+  size_t in_size = in.shape().Size();
+  // 0-shape
+  if (0 == in_size) {
+    mxnet::TShape s(2, in.shape().ndim());
+    s[0] = 0;
+    const_cast<NDArray &>(out).Init(s);
+    return;
+  }
+  int32_t valid_num = 0;
+  Stream<gpu>* stream = ctx.get_stream<gpu>();
+  int32_t* prefix_sum = nullptr;
+  void* d_temp_storage = nullptr;
+  size_t temp_storage_bytes = 0;
+  // Calculate total temporary memory size
+  cub::DeviceScan::InclusiveSum(d_temp_storage,
+                                temp_storage_bytes,
+                                prefix_sum,
+                                prefix_sum,
+                                in_size,
+                                Stream<gpu>::GetStream(stream));
+  size_t buffer_size = in_size * sizeof(int32_t);
+  temp_storage_bytes += buffer_size;
+  // Allocate workspace on GPU and assign the buffer pointers
+  Tensor<gpu, 1, char> workspace =
+    ctx.requested[0].get_space_typed<gpu, 1, char>(Shape1(temp_storage_bytes), stream);
+  prefix_sum = reinterpret_cast<int32_t*>(workspace.dptr_);
+  d_temp_storage = workspace.dptr_ + buffer_size;
+  MSHADOW_TYPE_SWITCH(in.dtype(), DType, {
+    mxnet_op::Kernel<PrefixSumInit, gpu>::Launch(
+      stream, in_size, prefix_sum, in.data().dptr<DType>());
+  });
+  // Calculate prefix sum
+  cub::DeviceScan::InclusiveSum(d_temp_storage,
+                                temp_storage_bytes,
+                                prefix_sum,
+                                prefix_sum,
+                                in_size,
+                                Stream<gpu>::GetStream(stream));
+  CUDA_CALL(cudaMemcpy(&valid_num, &prefix_sum[in_size - 1], sizeof(int32_t),
+                       cudaMemcpyDeviceToHost));
+  // 0-dim
+  if (0 == in.shape().ndim()) {
+    mxnet::TShape s(2, 1);
+    if (valid_num) {
+      const_cast<NDArray &>(out).Init(s);
+      int64_t temp = 0;
+      CUDA_CALL(cudaMemcpy(out.data().dptr<int64_t>(), &temp, sizeof(int64_t),
+                           cudaMemcpyHostToDevice));
+    } else {
+      s[0] = 0;
+      const_cast<NDArray &>(out).Init(s);
+    }
+    return;
+  }
+  // Set the output shape forcefully
+  mxnet::TShape s(2, in.shape().ndim());
+  s[0] = valid_num;
+  const_cast<NDArray &>(out).Init(s);
+  // get the shape from the input
+  MXNET_NDIM_SWITCH(in.shape().ndim(), ndim, {
+    mshadow::Shape<ndim> shape = in.shape().get<ndim>();
+    mxnet_op::Kernel<NonzeroForwardKernel, gpu>::Launch(
+      stream, in_size, out.data().dptr<int64_t>(), prefix_sum, shape);
+  })
+}
+
+NNVM_REGISTER_OP(_npx_nonzero)
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
+.set_attr<FComputeEx>("FComputeEx<gpu>", NonzeroForwardGPU);
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 3d30012..8d12419 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -2447,6 +2447,38 @@ def test_np_arctan2():
                 assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol)
 
 
+@with_seed()
+@use_np
+def test_np_nonzero():
+    class TestNonzero(HybridBlock):
+        def __init__(self):
+            super(TestNonzero, self).__init__()
+
+        def hybrid_forward(self, F, x):
+            return F.npx.nonzero(x)
+
+    types = ['int32', 'int64', 'float64', 'float32', 'float16']
+    for hybridize in [True, False]:
+        for shape in [(), (1, 2, 3), (1, 0)]:
+            for oneType in types:
+                rtol, atol = 1e-3, 1e-5
+                test_nonzero = TestNonzero()
+                if hybridize:
+                    test_nonzero.hybridize()
+                x = rand_ndarray(shape, dtype=oneType).as_np_ndarray()
+                np_out = _np.nonzero(x.asnumpy())
+                np_out = _np.transpose(np_out)
+                mx_out = test_nonzero(x)
+                assert mx_out.shape == np_out.shape
+                assert_almost_equal(mx_out.asnumpy(), np_out, rtol, atol)
+
+                # Test imperative once again
+                mx_out = npx.nonzero(x)
+                np_out = _np.nonzero(x.asnumpy())
+                np_out = _np.transpose(np_out)
+                assert_almost_equal(mx_out.asnumpy(), np_out, rtol, atol)
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()