Posted to commits@mxnet.apache.org by re...@apache.org on 2019/09/22 07:19:50 UTC

[incubator-mxnet] branch master updated: add numpy compatible trace (#16008)

This is an automated email from the ASF dual-hosted git repository.

reminisce pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new da6e744  add numpy compatible trace (#16008)
da6e744 is described below

commit da6e7442e6a4e00246c063298aadfb231a5ac63a
Author: Haozheng Fan <fh...@gmail.com>
AuthorDate: Sun Sep 22 15:19:24 2019 +0800

    add numpy compatible trace (#16008)
    
    add doc for trace
---
 python/mxnet/_numpy_op_doc.py          |  66 +++++++--
 src/operator/numpy/np_trace_op-inl.h   | 255 +++++++++++++++++++++++++++++++++
 src/operator/numpy/np_trace_op.cc      |  98 +++++++++++++
 src/operator/numpy/np_trace_op.cu      |  36 +++++
 tests/python/unittest/test_numpy_op.py |  82 +++++++++++
 5 files changed, 527 insertions(+), 10 deletions(-)

diff --git a/python/mxnet/_numpy_op_doc.py b/python/mxnet/_numpy_op_doc.py
index a83447f..f4787be 100644
--- a/python/mxnet/_numpy_op_doc.py
+++ b/python/mxnet/_numpy_op_doc.py
@@ -530,22 +530,22 @@ def _np_roll(a, shift, axis=None):
     Parameters
     ----------
     a : ndarray
-    Input array.
+        Input array.
     shift : int or tuple of ints
-    The number of places by which elements are shifted.  If a tuple,
-    then `axis` must be a tuple of the same size, and each of the
-    given axes is shifted by the corresponding number.  If an int
-    while `axis` is a tuple of ints, then the same value is used for
-    all given axes.
+        The number of places by which elements are shifted.  If a tuple,
+        then `axis` must be a tuple of the same size, and each of the
+        given axes is shifted by the corresponding number.  If an int
+        while `axis` is a tuple of ints, then the same value is used for
+        all given axes.
     axis : int or tuple of ints, optional
-    Axis or axes along which elements are shifted.  By default, the
-    array is flattened before shifting, after which the original
-    shape is restored.
+        Axis or axes along which elements are shifted.  By default, the
+        array is flattened before shifting, after which the original
+        shape is restored.
 
     Returns
     -------
     res : ndarray
-    Output array, with the same shape as `a`.
+        Output array, with the same shape as `a`.
 
     Notes
     -----
@@ -581,5 +581,51 @@ def _np_roll(a, shift, axis=None):
     >>> np.roll(x2, -1, axis=1)
     array([[1., 2., 3., 4., 0.],
            [6., 7., 8., 9., 5.]])
+   """
+
+
+def _np_trace(a, offset=0, axis1=0, axis2=1, out=None):
+    """trace(a, offset=0, axis1=0, axis2=1, out=None)
+
+    Return the sum along diagonals of the array.
+    If `a` is 2-D, the sum along its diagonal with the given offset
+    is returned, i.e., the sum of elements ``a[i,i+offset]`` for all i.
+    If `a` has more than two dimensions, then the axes specified by axis1 and
+    axis2 are used to determine the 2-D sub-arrays whose traces are returned.
+    The shape of the resulting array is the same as that of `a` with `axis1`
+    and `axis2` removed.
+
+    Parameters
+    ----------
+    a : ndarray
+        Input array, from which the diagonals are taken.
+    offset : int, optional
+        Offset of the diagonal from the main diagonal. Can be both positive
+        and negative. Defaults to 0.
+    axis1, axis2 : int, optional
+        Axes to be used as the first and second axis of the 2-D sub-arrays
+        from which the diagonals should be taken. Defaults are the first two
+        axes of `a`.
+    out : ndarray, optional
+        Array into which the output is placed. It must be of the right shape
+        and right type to hold the output.
+
+    Returns
+    -------
+    sum_along_diagonals : ndarray
+        If `a` is 2-D, the sum along the diagonal is returned.  If `a` has
+        larger dimensions, then an array of sums along diagonals is returned.
+
+    Examples
+    --------
+    >>> a = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
+    >>> np.trace(a)
+    array(3.)
+    >>> a = np.arange(8).reshape((2, 2, 2))
+    >>> np.trace(a)
+    array([6., 8.])
+    >>> a = np.arange(24).reshape((2, 2, 2, 3))
+    >>> np.trace(a).shape
+    (2, 3)
     """
     pass
diff --git a/src/operator/numpy/np_trace_op-inl.h b/src/operator/numpy/np_trace_op-inl.h
new file mode 100644
index 0000000..741c20b
--- /dev/null
+++ b/src/operator/numpy/np_trace_op-inl.h
@@ -0,0 +1,255 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file np_trace_op-inl.h
+ * \brief Function definition of the numpy-compatible matrix trace operator
+ */
+
+#ifndef MXNET_OPERATOR_NUMPY_NP_TRACE_OP_INL_H_
+#define MXNET_OPERATOR_NUMPY_NP_TRACE_OP_INL_H_
+
+#include <dmlc/parameter.h>
+#include <mxnet/operator_util.h>
+#include <vector>
+#include <utility>
+#include <algorithm>
+#include "../mxnet_op.h"
+#include "../operator_common.h"
+#include "../elemwise_op_common.h"
+#include "../tensor/broadcast_reduce_op.h"
+
+namespace mxnet {
+namespace op {
+
+struct NumpyTraceParam: public dmlc::Parameter<NumpyTraceParam> {
+  int offset, axis1, axis2;
+  DMLC_DECLARE_PARAMETER(NumpyTraceParam) {
+    DMLC_DECLARE_FIELD(offset)
+    .set_default(0)
+    .describe("Offset of the diagonal from the main diagonal. "
+              "Can be both positive and negative. Defaults to 0.");
+    DMLC_DECLARE_FIELD(axis1)
+    .set_default(0)
+    .describe("Axes to be used as the first axis of the 2-D sub-arrays "
+              "from which the diagonals should be taken. Defaults to 0.");
+    DMLC_DECLARE_FIELD(axis2)
+    .set_default(1)
+    .describe("Axes to be used as the second axis of the 2-D sub-arrays "
+              "from which the diagonals should be taken. Defaults to 1.");
+  }
+};
+
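+/*!
+ * \brief Kernel mapping one output element to the diagonal it reduces.
+ *        In the forward pass (back == false) output element i gathers and
+ *        sums dlength input elements spaced stride apart, starting at the
+ *        re-raveled index of i plus offset. In the backward pass
+ *        (back == true) the roles reverse: the output gradient a[i] is
+ *        scattered back to every element of that diagonal.
+ */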
+template<int ndim, int req, bool back>
+struct numpy_trace {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(index_t i, DType* out, const DType* a,
+                                  mshadow::Shape<ndim> oshape,
+                                  mshadow::Shape<ndim> ishape,
+                                  index_t stride, index_t offset, int dlength) {
+    using namespace mxnet_op;
+    using namespace mshadow;
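+    // map output element i to the start of its diagonal: unravel i in the
+    // merged output shape, re-ravel the same coordinates in the merged
+    // input shape (which lands on the sub-array element whose axis1/axis2
+    // indices are both zero), then apply the diagonal offset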
+    index_t j = ravel(unravel(i, oshape), ishape) + offset;
+    if (back) {
+      for (index_t k = 0; k < dlength; ++k) {
+        KERNEL_ASSIGN(out[j], req, a[i]);
+        j += stride;
+      }
+    } else {
+      if (req == kWriteTo) {
+        out[i] = 0;
+        for (index_t k = 0; k < dlength; ++k) {
+          out[i] += a[j];
+          j += stride;
+        }
+      } else if (req == kAddTo) {
+        for (index_t k = 0; k < dlength; ++k) {
+          out[i] += a[j];
+          j += stride;
+        }
+      }
+    }
+  }
+};
+
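+// Shared forward/backward driver. For the backward pass (back == true) the
+// caller swaps ishape and oshape, so in_data holds the incoming output
+// gradient and out_data the input gradient being filled.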
+template<typename xpu, bool back>
+void NumpyTraceOpProcess(const TBlob& in_data,
+                         const TBlob& out_data,
+                         const mxnet::TShape& ishape,
+                         const mxnet::TShape& oshape,
+                         index_t dsize,
+                         const NumpyTraceParam& param,
+                         mshadow::Stream<xpu> *s,
+                         const std::vector<OpReqType>& req) {
+  using namespace mxnet_op;
+  using namespace mshadow;
+  if (dsize == 0) {
+    if (back) {
+      if (out_data.Size() != 0) {
+        MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
+          MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
+            if (req_type == kWriteTo) {
+              out_data.FlatTo1D<xpu, DType>(s) = 0;
+            }
+          });
+        });
+      }
+    }
+    return;
+  } else if (ishape.Size() == 0) {
+    if (!back) {
+      MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
+        MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
+          if (req_type == kWriteTo) {
+            out_data.FlatTo1D<xpu, DType>(s) = 0;
+          }
+        });
+      });
+    }
+    return;
+  }
+  uint32_t x1 = CheckAxis(param.axis1, ishape.ndim());
+  uint32_t x2 = CheckAxis(param.axis2, ishape.ndim());
+
+  uint32_t idim = ishape.ndim();
+
+  uint32_t minx = x1, maxx = x2;
+  if (minx > maxx) {
+    std::swap(minx, maxx);
+  }
+
+  // merge contiguous axes that are not separated
+  // by axis1 or axis2, since they can be directly
+  // mapped to the output and there is no need
+  // to distinguish them
+  // (after this the input has at most three
+  // merged axes, which improves the ravel and
+  // unravel efficiency)
+
+  index_t oleading = 1,
+          obody = 1,
+          otrailing = 1;
+
+  for (uint32_t i = 0; i < minx; ++i) {
+    oleading *= ishape[i];
+  }
+  for (uint32_t i = minx + 1; i < maxx; ++i) {
+    obody *= ishape[i];
+  }
+  for (uint32_t i = maxx + 1; i < idim; ++i) {
+    otrailing *= ishape[i];
+  }
+
+  index_t ileading = oleading,
+          ibody = obody * ishape[minx],
+          itrailing = otrailing * ishape[maxx];
+
+  index_t stride1 = itrailing * obody,
+          stride2 = otrailing;
+  // stride1 + stride2 is the stride for
+  // iterating over the diagonal in question
+
+  if (x1 == maxx) {
+    std::swap(stride1, stride2);
+  }
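+
+  // worked example: ishape = (2, 3, 4, 5), axis1 = 1, axis2 = 3:
+  //   oleading = 2, obody = 4, otrailing = 1, so the output merges to (2, 4, 1);
+  //   ileading = 2, ibody = 3 * 4 = 12, itrailing = 5 * 1 = 5;
+  //   stride1 = 5 * 4 = 20 (one step along axis1), stride2 = 1 (along axis2),
+  //   so consecutive diagonal elements are 21 apart in the flat input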
+
+  // flat-index shift introduced by a nonzero diagonal offset:
+  // positive offsets move the diagonal along axis2, negative
+  // offsets move it along axis1
+  index_t offset;
+  if (param.offset > 0) {
+    offset = stride2 * param.offset;
+  } else if (param.offset < 0) {
+    offset = stride1 * -param.offset;
+  } else {
+    offset = 0;
+  }
+
+  // number of elements in the offset diagonal;
+  // may be non-positive when the offset pushes
+  // the diagonal outside the array
+  int dlength;
+  if (param.offset > 0) {
+    dlength = std::min(ishape[x1], ishape[x2] - param.offset);
+  } else if (param.offset < 0) {
+    dlength = std::min(ishape[x1] - (-param.offset), ishape[x2]);
+  } else {
+    dlength = std::min(ishape[x1], ishape[x2]);
+  }
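+  // e.g. ishape[x1] = 3, ishape[x2] = 5:
+  //   offset = 2  ->  dlength = min(3, 3) = 3
+  //   offset = 6  ->  dlength = min(3, -1) = -1, so the kernel loop
+  //   body never executes and the output is left at zero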
+
+  MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
+    MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
+      if (back && req_type == kWriteTo) {
+        // zero the input-gradient buffer before the diagonal values are
+        // scattered; skipped for kAddTo so accumulated gradients survive
+        out_data.FlatTo1D<xpu, DType>(s) = 0;
+      }
+      Kernel<numpy_trace<3, req_type, back>, xpu>::Launch(s, dsize, out_data.dptr<DType>(),
+                                                          in_data.dptr<DType>(),
+                                                          Shape3(oleading, obody, otrailing),
+                                                          Shape3(ileading, ibody, itrailing),
+                                                          stride1 + stride2, offset, dlength);
+    });
+  });
+}
+
+template<typename xpu>
+void NumpyTraceOpForward(const nnvm::NodeAttrs& attrs,
+                         const OpContext& ctx,
+                         const std::vector<TBlob>& inputs,
+                         const std::vector<OpReqType>& req,
+                         const std::vector<TBlob>& outputs) {
+  using namespace mxnet_op;
+  using namespace mshadow;
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+  const TBlob& in_data = inputs[0];
+  const TBlob& out_data = outputs[0];
+  const mxnet::TShape& ishape = inputs[0].shape_;
+  const mxnet::TShape& oshape = outputs[0].shape_;
+  const NumpyTraceParam& param = nnvm::get<NumpyTraceParam>(attrs.parsed);
+
+  NumpyTraceOpProcess<xpu, false>(in_data, out_data, ishape, oshape,
+                                  out_data.Size(), param, s, req);
+}
+
+template<typename xpu>
+void NumpyTraceOpBackward(const nnvm::NodeAttrs& attrs,
+                          const OpContext& ctx,
+                          const std::vector<TBlob>& inputs,
+                          const std::vector<OpReqType>& req,
+                          const std::vector<TBlob>& outputs) {
+  using namespace mxnet_op;
+  using namespace mshadow;
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+
+  const TBlob& in_data = inputs[0];
+  const TBlob& out_data = outputs[0];
+  const mxnet::TShape& ishape = inputs[0].shape_;
+  const mxnet::TShape& oshape = outputs[0].shape_;
+  const NumpyTraceParam& param = nnvm::get<NumpyTraceParam>(attrs.parsed);
+
+  NumpyTraceOpProcess<xpu, true>(in_data, out_data, oshape, ishape,
+                                 in_data.Size(), param, s, req);
+}
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_NUMPY_NP_TRACE_OP_INL_H_
diff --git a/src/operator/numpy/np_trace_op.cc b/src/operator/numpy/np_trace_op.cc
new file mode 100644
index 0000000..d97ac30
--- /dev/null
+++ b/src/operator/numpy/np_trace_op.cc
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2019 by Contributors
+ * \file np_trace_op.cc
+ * \brief CPU Implementation of numpy-compatible trace operator
+ */
+
+#include "./np_trace_op-inl.h"
+
+namespace mxnet {
+namespace op {
+
+inline bool NumpyTraceOpShape(const nnvm::NodeAttrs& attrs,
+                              mxnet::ShapeVector* in_attrs,
+                              mxnet::ShapeVector* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1);
+  CHECK_EQ(out_attrs->size(), 1);
+  const int ndim((*in_attrs)[0].ndim());
+  if (ndim < 2) {
+    return false;
+  }
+  std::vector<int> oshape(ndim - 2);
+  const NumpyTraceParam& param = nnvm::get<NumpyTraceParam>(attrs.parsed);
+  int x1 = CheckAxis(param.axis1, (*in_attrs)[0].ndim());
+  int x2 = CheckAxis(param.axis2, (*in_attrs)[0].ndim());
+  CHECK_NE(x1, x2) << "axis1 and axis2 cannot refer to the same axis " << x1;
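+  // the output shape drops axis1 and axis2 from the input shape,
+  // e.g. (2, 3, 4) with axis1=0, axis2=2 reduces to (3,)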
+  for (int i = 0, j = 0; i < ndim; ++i) {
+    if (i != x1 && i != x2) {
+      oshape[j++] = (*in_attrs)[0][i];
+    }
+  }
+  mxnet::TShape tshape(oshape.begin(), oshape.end());
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, tshape);
+  return true;
+}
+
+DMLC_REGISTER_PARAMETER(NumpyTraceParam);
+
+NNVM_REGISTER_OP(_np_trace)
+.describe(R"code(Computes the sum of the diagonal elements of a matrix.
+Input is a tensor *A* of dimension *n >= 2*.
+
+If *n=2*, we sum the diagonal elements. The result has shape ().
+
+If *n>2*, *trace* is performed separately on the matrix defined by *axis1* and *axis2* for all
+inputs (batch mode).
+
+Examples::
+
+   // Single matrix reduction
+   A = [[1.0, 1.0], [1.0, 7.0]]
+   trace(A) = 8.0
+
+   // Batch reduction over the default axes (axis1=0, axis2=1),
+   // so the trailing axis indexes the batch
+   A = [[[1.0, 1.0], [1.0, 7.0]], [[3.0, 0], [0, 17.0]]]
+   trace(A) = [1.0, 18.0]
+)code" ADD_FILELINE)
+.set_attr_parser(ParamParser<NumpyTraceParam>)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"data"};
+  })
+.set_attr<mxnet::FInferShape>("FInferShape", NumpyTraceOpShape)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<FCompute>("FCompute<cpu>", NumpyTraceOpForward<cpu>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_np_trace"})
+.add_argument("data", "NDArray-or-Symbol", "Input ndarray")
+.add_arguments(NumpyTraceParam::__FIELDS__());
+
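+// ElemwiseGradUseNone means the backward node receives only the output
+// gradient: trace is linear in its input, so the input values themselves
+// are not needed to compute the gradient.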
+NNVM_REGISTER_OP(_backward_np_trace)
+.set_attr_parser(ParamParser<NumpyTraceParam>)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FCompute>("FCompute<cpu>", NumpyTraceOpBackward<cpu>);
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/numpy/np_trace_op.cu b/src/operator/numpy/np_trace_op.cu
new file mode 100644
index 0000000..220e4ae
--- /dev/null
+++ b/src/operator/numpy/np_trace_op.cu
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file np_trace_op.cu
+ * \brief GPU Implementation of numpy-compatible trace operator
+ */
+#include "./np_trace_op-inl.h"
+
+namespace mxnet {
+namespace op {
+
+NNVM_REGISTER_OP(_np_trace)
+.set_attr<FCompute>("FCompute<gpu>", NumpyTraceOpForward<gpu>);
+
+NNVM_REGISTER_OP(_backward_np_trace)
+.set_attr<FCompute>("FCompute<gpu>", NumpyTraceOpBackward<gpu>);
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 42d9043..8846d57 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -2189,6 +2189,88 @@ def test_np_roll():
                                            numeric_eps=1e-3, rtol=1e-3, atol=1e-5, dtype=i_dtype[dtype])
 
 
+@with_seed()
+@use_np
+def test_np_trace():
+    class TestTrace(HybridBlock):
+        def __init__(self, axis1, axis2, offset):
+            super(TestTrace, self).__init__()
+            self._axis1 = axis1
+            self._axis2 = axis2
+            self._offset = offset
+
+        def hybrid_forward(self, F, data):
+            return F.np.trace(data, axis1=self._axis1, axis2=self._axis2, offset=self._offset)
+
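+    # expected gradient of trace: d(trace)/d(a) is 1 exactly on the selected
+    # diagonal (where idx[axis1] + offset == idx[axis2]) and 0 elsewhere,
+    # since each diagonal element contributes to the sum exactly once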
+    def g(data, axis1, axis2, offset):
+        idx = _np.indices(data.shape)
+        ret = _np.zeros_like(data)
+        ret[idx[axis1] + offset == idx[axis2]] = 1.0
+        return ret
+
+    shapes = [
+        (3, 3),
+        (3, 4),
+        (0, 0),
+        (3, 3, 3),
+        (0, 0, 0),
+        (2, 2, 4, 3),
+        (2, 2, 4, 3),
+        (2, 0, 3, 0),
+        (2, 0, 2, 3)
+    ]
+    offsets = range(-5, 5)
+    dtypes = ['int32', 'float16', 'float32', 'float64']
+    for hybridize in [True, False]:
+        for shape in shapes:
+            ndim = len(shape)
+            for axis1 in range(-ndim, ndim):
+                for axis2 in range(-ndim, ndim):
+                    if (axis1 + ndim) % ndim != (axis2 + ndim) % ndim:
+                        for offset in offsets:
+                            for dtype in dtypes:
+                                if dtype == 'float16':
+                                    rtol = atol = 1e-2
+                                else:
+                                    rtol = atol = 1e-5
+                                test_trace = TestTrace(axis1, axis2, offset)
+                                if hybridize:
+                                    test_trace.hybridize()
+                                data_np = _np.random.uniform(-10.0, 10.0, shape)
+                                data = mx.nd.array(data_np, dtype=dtype)
+                                data_np = data.asnumpy()
+                                data.attach_grad()
+                                expected_np = _np.trace(data_np, axis1=axis1, axis2=axis2, offset=offset)
+                                with mx.autograd.record():
+                                    out_mx = test_trace(data.as_np_ndarray())
+                                assert out_mx.shape == expected_np.shape
+                                assert_almost_equal(out_mx.asnumpy(), expected_np, rtol=rtol, atol=atol)
+                                out_mx.backward()
+                                backward_expected = g(data_np, axis1=axis1, axis2=axis2, offset=offset)
+                                assert_almost_equal(data.grad.asnumpy(), backward_expected, rtol=rtol, atol=atol)
+
+                                # Test imperative once again
+                                data = mx.nd.array(data_np, dtype=dtype)
+                                out_mx = np.trace(data.as_np_ndarray(), axis1=axis1, axis2=axis2, offset=offset)
+                                assert_almost_equal(out_mx.asnumpy(), expected_np, rtol=rtol, atol=atol)
+
+    # bad params
+    params = [
+        ([], 0, 1, 0),
+        ([2], 0, 1, 0),
+        ([3, 2, 2], 1, 1, 1),
+        ([3, 2, 2], 0, -4, 1)
+    ]
+    for shape, axis1, axis2, offset in params:
+        data_np = _np.random.uniform(-1.0, 1.0, shape)
+        data_mx = mx.nd.array(data_np)
+        try:
+            output = np.trace(data_mx.as_np_ndarray(), axis1=axis1, axis2=axis2, offset=offset)
+        except mx.base.MXNetError:
+            continue
+        assert False, "np.trace should have raised MXNetError for these parameters"
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()