Posted to commits@mxnet.apache.org by ha...@apache.org on 2019/07/18 00:01:03 UTC

[incubator-mxnet] 27/42: numpy-compatible cumsum (#15309)

This is an automated email from the ASF dual-hosted git repository.

haoj pushed a commit to branch numpy
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit 62ea3d233c1dc353116d45a6e8ebd21fd5e7d63a
Author: Hao Jin <hj...@gmail.com>
AuthorDate: Sun Jun 23 12:27:16 2019 +0800

    numpy-compatible cumsum (#15309)
---
 src/operator/numpy/np_cumsum-inl.h     | 184 +++++++++++++++++++++++++++++++++
 src/operator/numpy/np_cumsum.cc        |  92 +++++++++++++++++
 src/operator/numpy/np_cumsum.cu        |  37 +++++++
 tests/python/unittest/test_numpy_op.py |  42 ++++++++
 4 files changed, 355 insertions(+)
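
The new `_np_cumsum` operator mirrors `numpy.cumsum`: `axis=None` flattens the input before accumulating, and the optional `dtype` selects the accumulator and output type. A minimal plain-NumPy sketch of the target semantics (illustrative only, not part of this patch):

    import numpy as np

    a = np.array([[1, 2, 3],
                  [4, 5, 6]], dtype=np.int32)
    print(np.cumsum(a))                    # axis=None flattens: [ 1  3  6 10 15 21]
    print(np.cumsum(a, axis=1))            # per-row scan: [[1 3 6] [4 9 15]]
    print(np.cumsum(a, dtype=np.float64))  # accumulate and return in float64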

diff --git a/src/operator/numpy/np_cumsum-inl.h b/src/operator/numpy/np_cumsum-inl.h
new file mode 100644
index 0000000..a9d2d8b
--- /dev/null
+++ b/src/operator/numpy/np_cumsum-inl.h
@@ -0,0 +1,184 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file np_cumsum-inl.h
+ * \brief Function definition of numpy-compatible cumsum operator
+ */
+
+#ifndef MXNET_OPERATOR_NUMPY_NP_CUMSUM_INL_H_
+#define MXNET_OPERATOR_NUMPY_NP_CUMSUM_INL_H_
+
+#include <mxnet/base.h>
+#include <mxnet/operator_util.h>
+#include <vector>
+#include "../mxnet_op.h"
+#include "../operator_common.h"
+#include "../elemwise_op_common.h"
+
+namespace mxnet {
+namespace op {
+
+struct CumsumParam : public dmlc::Parameter<CumsumParam> {
+  dmlc::optional<int> axis;
+  dmlc::optional<int> dtype;
+  DMLC_DECLARE_PARAMETER(CumsumParam) {
+    DMLC_DECLARE_FIELD(axis)
+      .set_default(dmlc::optional<int>())
+      .describe("Axis along which the cumulative sum is computed."
+        " The default (None) is to compute the cumsum over the flattened array.");
+    DMLC_DECLARE_FIELD(dtype)
+      .add_enum("float16", mshadow::kFloat16)
+      .add_enum("float32", mshadow::kFloat32)
+      .add_enum("float64", mshadow::kFloat64)
+      .add_enum("int8", mshadow::kInt8)
+      .add_enum("int32", mshadow::kInt32)
+      .add_enum("int64", mshadow::kInt64)
+      .set_default(dmlc::optional<int>())
+      .describe("Type of the returned array and of the accumulator in which the elements"
+                " are summed. If dtype is not specified, it defaults to the dtype of a,"
+                " unless a has an integer dtype with a precision less than that of the"
+                " default platform integer. In that case, the default platform integer is used.");
+  }
+};
+
+struct cumsum_forward {
+  template<typename IType, typename OType>
+  MSHADOW_XINLINE static void Map(int i,
+                                  OType *out,
+                                  const IType *in,
+                                  const int middle,
+                                  const int trailing) {
+    int left = i / trailing, right = i % trailing;
+    int offset = left * middle * trailing + right;
+    const IType *lane_in = in + offset;
+    OType *lane_out = out + offset;
+    lane_out[0] = OType(lane_in[0]);
+    for (int j = 1; j < middle; ++j) {
+      lane_out[j * trailing] = lane_out[(j - 1) * trailing] + OType(lane_in[j * trailing]);
+    }
+  }
+};
+
+template<typename xpu>
+void CumsumForwardImpl(const OpContext& ctx,
+                       const TBlob& in,
+                       const TBlob& out,
+                       const dmlc::optional<int>& axis) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+
+  int middle = axis.has_value() ? out.shape_[axis.value()] : out.Size();
+  if (middle == 0 || out.Size() == 0) return;
+  int trailing = 1;
+  if (axis.has_value()) {
+    for (int i = axis.value() + 1; i < out.shape_.ndim(); ++i) {
+      trailing *= out.shape_[i];
+    }
+  }
+
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  MSHADOW_TYPE_SWITCH(in.type_flag_, IType, {
+    MSHADOW_TYPE_SWITCH(out.type_flag_, OType, {
+      Kernel<cumsum_forward, xpu>::Launch(
+        s, out.Size() / middle, out.dptr<OType>(),
+        in.dptr<IType>(), middle, trailing);
+    });
+  });
+}
+
+template<typename xpu>
+void CumsumForward(const nnvm::NodeAttrs& attrs,
+                   const OpContext& ctx,
+                   const std::vector<TBlob>& inputs,
+                   const std::vector<OpReqType>& req,
+                   const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  const CumsumParam &param = nnvm::get<CumsumParam>(attrs.parsed);
+
+  CumsumForwardImpl<xpu>(ctx, inputs[0], outputs[0], param.axis);
+}
+
+struct cumsum_backward {
+  template<typename IType, typename OType>
+  MSHADOW_XINLINE static void Map(int i,
+                                  IType *igrad,
+                                  const OType *ograd,
+                                  const int middle,
+                                  const int trailing) {
+    int left = i / trailing, right = i % trailing;
+    int offset = left * middle * trailing + right;
+    const OType *lane_ograd = ograd + offset;
+    IType *lane_igrad = igrad + offset;
+    lane_igrad[(middle - 1) * trailing] = IType(lane_ograd[(middle - 1) * trailing]);
+    for (int j = middle - 2; j >= 0; --j) {
+      lane_igrad[j * trailing] = lane_igrad[(j + 1) * trailing] + IType(lane_ograd[j * trailing]);
+    }
+  }
+};
+
+template<typename xpu>
+void CumsumBackwardImpl(const OpContext& ctx,
+                        const TBlob& ograd,
+                        const TBlob& igrad,
+                        const dmlc::optional<int>& axis) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  int middle = axis.has_value() ? igrad.shape_[axis.value()] : igrad.Size();
+  if (middle == 0 || igrad.Size() == 0) return;
+  int trailing = 1;
+  if (axis.has_value()) {
+    for (int i = axis.value() + 1; i < igrad.shape_.ndim(); ++i) {
+      trailing *= igrad.shape_[i];
+    }
+  }
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  MSHADOW_TYPE_SWITCH(igrad.type_flag_, IType, {
+    MSHADOW_TYPE_SWITCH(ograd.type_flag_, OType, {
+      Kernel<cumsum_backward, xpu>::Launch(
+        s, igrad.Size() / middle, igrad.dptr<IType>(),
+        ograd.dptr<OType>(), middle, trailing);
+    });
+  });
+}
+
+template<typename xpu>
+void CumsumBackward(const nnvm::NodeAttrs& attrs,
+                    const OpContext& ctx,
+                    const std::vector<TBlob>& inputs,
+                    const std::vector<OpReqType>& req,
+                    const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  const CumsumParam &param = nnvm::get<CumsumParam>(attrs.parsed);
+
+  CumsumBackwardImpl<xpu>(ctx, inputs[0], outputs[0], param.axis);
+}
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_NUMPY_NP_CUMSUM_INL_H_
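
Both kernels view the tensor as a set of independent lanes: `middle` is the extent of the scanned axis, `trailing` is the product of the dimensions after it, and each launch index owns one lane, scanning `middle` elements at stride `trailing`. The backward kernel runs the same scan from the far end of each lane, since each input element contributes with weight 1 to every later output along the axis. A minimal NumPy sketch of the forward lane decomposition (separate from the patch, for illustration; `middle`/`trailing` match the kernel parameters of the same names):

    import numpy as np

    def cumsum_lanes(x, axis):
        # middle = extent of `axis`; trailing = product of the dims after it
        middle = x.shape[axis]
        trailing = int(np.prod(x.shape[axis + 1:], dtype=np.int64))
        flat = x.reshape(-1)
        out = np.empty_like(flat)
        for i in range(x.size // middle):    # one launch index per lane
            left, right = i // trailing, i % trailing
            off = left * middle * trailing + right
            out[off] = flat[off]
            for j in range(1, middle):       # strided scan along the lane
                out[off + j * trailing] = (out[off + (j - 1) * trailing]
                                           + flat[off + j * trailing])
        return out.reshape(x.shape)

    x = np.arange(24).reshape(2, 3, 4)
    assert (cumsum_lanes(x, 1) == np.cumsum(x, axis=1)).all()

The unit test added below checks the backward pass against the equivalent flip/cumsum/flip identity applied to the output gradient.
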
diff --git a/src/operator/numpy/np_cumsum.cc b/src/operator/numpy/np_cumsum.cc
new file mode 100644
index 0000000..8f16f25
--- /dev/null
+++ b/src/operator/numpy/np_cumsum.cc
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file np_cumsum.cc
+ * \brief CPU implementation of numpy-compatible cumsum operator
+ */
+
+#include "./np_cumsum-inl.h"
+
+namespace mxnet {
+namespace op {
+
+inline bool CumsumShape(const nnvm::NodeAttrs& attrs,
+                        mxnet::ShapeVector *in_attrs,
+                        mxnet::ShapeVector *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  const CumsumParam &param = nnvm::get<CumsumParam>(attrs.parsed);
+
+  if (param.axis.has_value()) {
+    return ElemwiseShape<1, 1>(attrs, in_attrs, out_attrs);
+  } else {
+    TShape out_shape(1, in_attrs->at(0).Size());
+    SHAPE_ASSIGN_CHECK(*out_attrs, 0, out_shape);
+    return shape_is_known(out_attrs->at(0));
+  }
+}
+
+inline bool CumsumType(const nnvm::NodeAttrs& attrs,
+                       std::vector<int> *in_attrs,
+                       std::vector<int> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  const CumsumParam &param = nnvm::get<CumsumParam>(attrs.parsed);
+
+  if (param.dtype.has_value()) {
+    TYPE_ASSIGN_CHECK(*out_attrs, 0, param.dtype.value());
+  } else {
+    TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
+    TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
+  }
+
+  return out_attrs->at(0) != -1 && in_attrs->at(0) != -1;
+}
+
+DMLC_REGISTER_PARAMETER(CumsumParam);
+
+NNVM_REGISTER_OP(_np_cumsum)
+.set_attr_parser(ParamParser<CumsumParam>)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"a"};
+  })
+.set_attr<mxnet::FInferShape>("FInferShape", CumsumShape)
+.set_attr<nnvm::FInferType>("FInferType", CumsumType)
+.set_attr<FCompute>("FCompute<cpu>", CumsumForward<cpu>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_np_cumsum"})
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::pair<int, int> >{{0, 0}};
+  })
+.add_argument("a", "NDArray-or-Symbol", "Input ndarray")
+.add_arguments(CumsumParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_backward_np_cumsum)
+.set_attr_parser(ParamParser<CumsumParam>)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FCompute>("FCompute<cpu>", CumsumBackward<cpu>);
+
+}  // namespace op
+}  // namespace mxnet
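
CumsumShape passes the input shape through when `axis` is given and otherwise produces a 1-D shape of the input's total size; CumsumType takes the explicit `dtype` when present and otherwise ties the output dtype to the input's (the platform-integer promotion mentioned in the parameter docstring is NumPy's rule; the type function registered here simply mirrors the input dtype). The same rules as a short Python sketch (hypothetical helpers, illustration only):

    import numpy as np

    def infer_shape(in_shape, axis):
        # axis given -> same shape; axis=None -> flattened 1-D shape
        if axis is not None:
            return in_shape
        return (int(np.prod(in_shape, dtype=np.int64)),)

    def infer_dtype(in_dtype, dtype):
        # explicit dtype wins; otherwise the output follows the input
        return dtype if dtype is not None else in_dtype

    assert infer_shape((2, 3, 4), 1) == (2, 3, 4)
    assert infer_shape((2, 3, 4), None) == (24,)
    assert infer_dtype(np.float16, None) == np.float16
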
diff --git a/src/operator/numpy/np_cumsum.cu b/src/operator/numpy/np_cumsum.cu
new file mode 100644
index 0000000..cc574eb
--- /dev/null
+++ b/src/operator/numpy/np_cumsum.cu
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file np_cumsum.cu
+ * \brief GPU implementation of numpy-compatible cumsum operator
+ */
+
+#include "./np_cumsum-inl.h"
+
+namespace mxnet {
+namespace op {
+
+NNVM_REGISTER_OP(_np_cumsum)
+.set_attr<FCompute>("FCompute<gpu>", CumsumForward<gpu>);
+
+NNVM_REGISTER_OP(_backward_np_cumsum)
+.set_attr<FCompute>("FCompute<gpu>", CumsumBackward<gpu>);
+
+}  // namespace op
+}  // namespace mxnet
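
The GPU build re-registers the shared kernels from np_cumsum-inl.h; `Kernel<...>::Launch` dispatches `out.Size() / middle` work items, one per lane, so the available parallelism depends on the chosen axis. A quick illustration (not part of the patch):

    # shape (2, 3, 4) -> 24 elements
    print(24 // 3)    # axis=1: middle=3, so 8 lanes scan in parallel
    print(24 // 24)   # axis=None: a single flattened lane, sequential scan
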
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 3ce0440..7a43083 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -979,6 +979,48 @@ def test_np_split():
 
 @with_seed()
 @npx.use_np_shape
+def test_np_cumsum():
+    def np_cumsum_backward(ograd, axis=None, dtype=None):
+        return _np.flip(_np.cumsum(_np.flip(ograd, axis=axis), axis=axis, dtype=dtype), axis=axis)
+
+    @npx.use_np_shape
+    class TestCumsum(HybridBlock):
+        def __init__(self, axis=None, dtype=None):
+            super(TestCumsum, self).__init__()
+            self._axis = axis
+            self._dtype = dtype
+
+        def hybrid_forward(self, F, a):
+            return F.np.cumsum(a, axis=self._axis, dtype=self._dtype)
+
+    shapes = [(2, 3, 4), (2, 0, 3), ()]
+    for hybridize in [True, False]:
+        for shape in shapes:
+            for axis in [None] + [i for i in range(0, len(shape))]:
+                for otype in [None, _np.float32, _np.float64]:
+                    test_cumsum = TestCumsum(axis=axis, dtype=otype)
+                    if hybridize:
+                        test_cumsum.hybridize()
+                    for itype in [_np.float16, _np.float32, _np.float64]:
+                        x = rand_ndarray(shape).astype(itype).as_np_ndarray()
+                        x.attach_grad()
+                        np_out = _np.cumsum(x.asnumpy(), axis=axis, dtype=otype)
+                        with mx.autograd.record():
+                            mx_out = test_cumsum(x)
+                        assert mx_out.shape == np_out.shape
+                        assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+                        mx_out.backward()
+                        np_backward = np_cumsum_backward(_np.ones(np_out.shape, dtype=otype),
+                                                         axis=axis, dtype=otype).reshape(x.shape)
+                        assert_almost_equal(x.grad.asnumpy(), np_backward, rtol=1e-3, atol=1e-5)
+
+                        mx_out = np.cumsum(x, axis=axis, dtype=otype)
+                        np_out = _np.cumsum(x.asnumpy(), axis=axis, dtype=otype)
+                        assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+
+
+@with_seed()
+@npx.use_np_shape
 def test_np_tile():
     config = [
         ((), ()),