Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/05/30 18:59:19 UTC

[GitHub] piiswrong closed pull request #11025: added ravel/unravel operators

piiswrong closed pull request #11025: added ravel/unravel operators
URL: https://github.com/apache/incubator-mxnet/pull/11025

This is a PR merged from a forked repository. Because GitHub hides the
original diff of a forked pull request once it is merged, the diff is
reproduced below for the sake of provenance:

diff --git a/docs/api/python/ndarray/ndarray.md b/docs/api/python/ndarray/ndarray.md
index 5bc3c52f2a7..323344d69c0 100644
--- a/docs/api/python/ndarray/ndarray.md
+++ b/docs/api/python/ndarray/ndarray.md
@@ -430,6 +430,8 @@ The `ndarray` package provides several classes:
     one_hot
     pick
     where
+    ravel_multi_index
+    unravel_index
 ```
 
 ## Mathematical functions
diff --git a/docs/api/python/symbol/symbol.md b/docs/api/python/symbol/symbol.md
index f1e90a0c4d3..cc63e13e6ec 100644
--- a/docs/api/python/symbol/symbol.md
+++ b/docs/api/python/symbol/symbol.md
@@ -291,6 +291,8 @@ Composite multiple symbols into a new one by an operator.
     Symbol.take
     Symbol.one_hot
     Symbol.pick
+    Symbol.ravel_multi_index
+    Symbol.unravel_index
 ```
 
 ### Get internal and output symbol
diff --git a/src/operator/tensor/ravel.cc b/src/operator/tensor/ravel.cc
new file mode 100644
index 00000000000..94e38948434
--- /dev/null
+++ b/src/operator/tensor/ravel.cc
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file ravel.cc
+ * \brief CPU operators for ravel/unravel.
+ */
+#include "./ravel.h"
+
+namespace mxnet {
+namespace op {
+
+DMLC_REGISTER_PARAMETER(RavelParam);
+
+NNVM_REGISTER_OP(_ravel_multi_index)
+.add_alias("ravel_multi_index")
+.describe(R"code(Converts a batch of index arrays into an array of flat indices. The operator follows numpy conventions so a single multi index is given by a column of the input matrix. 
+
+Examples::
+
+   A = [[3,6,6],[4,5,1]]
+   ravel(A, shape=(7,6)) = [22,41,37]
+
+)code" ADD_FILELINE)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<RavelParam>)
+.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs)
+  { return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; })
+.set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs)
+  { return std::vector<std::string>{"data"}; } )
+.set_attr<nnvm::FInferShape>("FInferShape", RavelOpShape)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<FCompute>("FCompute<cpu>", RavelForward<cpu>)
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
+.add_argument("data", "NDArray-or-Symbol", "Batch of multi-indices")
+.add_arguments(RavelParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_unravel_index)
+.add_alias("unravel_index")
+.describe(R"code(Converts an array of flat indices into a batch of index arrays. The operator follows numpy conventions so a single multi index is given by a column of the output matrix.
+
+Examples::
+
+   A = [22,41,37]
+   unravel(A, shape=(7,6)) = [[3,6,6],[4,5,1]]
+
+)code" ADD_FILELINE)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<RavelParam>)
+.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs)
+  { return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; })
+.set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs)
+  { return std::vector<std::string>{"data"}; } )
+.set_attr<nnvm::FInferShape>("FInferShape", UnravelOpShape)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<FCompute>("FCompute<cpu>", UnravelForward<cpu>)
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
+.add_argument("data", "NDArray-or-Symbol", "Array of flat indices")
+.add_arguments(RavelParam::__FIELDS__());
+
+}  // namespace op
+}  // namespace mxnet
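
The two registrations above expose the operators under the public aliases
ravel_multi_index and unravel_index. As a quick sanity check, here is a
minimal NDArray usage sketch mirroring the docstring examples, assuming the
mx.nd.ravel_multi_index / mx.nd.unravel_index bindings generated from these
registrations, with a numpy cross-check:

    import mxnet as mx
    import numpy as np

    # Multi-indices are given column-wise: the columns (3,4), (6,5), (6,1)
    # are coordinates into an array of shape (7,6).
    data = mx.nd.array([[3, 6, 6], [4, 5, 1]])
    flat = mx.nd.ravel_multi_index(data, shape=(7, 6))
    print(flat.asnumpy())     # [22. 41. 37.]

    # unravel_index inverts the mapping, reproducing the columns.
    coords = mx.nd.unravel_index(flat, shape=(7, 6))
    print(coords.asnumpy())   # [[3. 6. 6.]
                              #  [4. 5. 1.]]

    # numpy reference for the same data.
    print(np.ravel_multi_index(np.array([[3, 6, 6], [4, 5, 1]]), dims=(7, 6)))
    # [22 41 37]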
diff --git a/src/operator/tensor/ravel.cu b/src/operator/tensor/ravel.cu
new file mode 100644
index 00000000000..cae50482013
--- /dev/null
+++ b/src/operator/tensor/ravel.cu
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ravel.cu
+ * \brief GPU operators for ravel/unravel.
+ */
+#include "./ravel.h"
+
+namespace mxnet {
+namespace op {
+
+NNVM_REGISTER_OP(_ravel_multi_index)
+.set_attr<FCompute>("FCompute<gpu>", RavelForward<gpu>);
+
+NNVM_REGISTER_OP(_unravel_index)
+.set_attr<FCompute>("FCompute<gpu>", UnravelForward<gpu>);
+
+}  // namespace op
+}  // namespace mxnet
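
The .cu file only re-registers the FCompute attribute: the kernels in
ravel.h are templated on the device type, so the same Map functions serve
both backends. A usage sketch under the assumption of a CUDA-enabled build
with at least one visible GPU (not part of the PR):

    import mxnet as mx

    # Same operator, dispatched to RavelForward<gpu> via the registration
    # above; only the context changes relative to the CPU example.
    data = mx.nd.array([[3, 6, 6], [4, 5, 1]], ctx=mx.gpu(0))
    print(mx.nd.ravel_multi_index(data, shape=(7, 6)))  # [22. 41. 37.] on gpu(0)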
diff --git a/src/operator/tensor/ravel.h b/src/operator/tensor/ravel.h
new file mode 100644
index 00000000000..1eb61e1b681
--- /dev/null
+++ b/src/operator/tensor/ravel.h
@@ -0,0 +1,166 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file ravel.h
+ * \brief Operators for ravel/unravel of indices.
+ */
+#ifndef MXNET_OPERATOR_TENSOR_RAVEL_H_
+#define MXNET_OPERATOR_TENSOR_RAVEL_H_
+
+#include <mxnet/operator_util.h>
+#include <vector>
+#include <algorithm>
+#include "../mshadow_op.h"
+#include "../mxnet_op.h"
+#include "../operator_common.h"
+#include "../elemwise_op_common.h"
+
+namespace mxnet {
+namespace op {
+
+struct RavelParam : public dmlc::Parameter<RavelParam> {
+  TShape shape;
+  DMLC_DECLARE_PARAMETER(RavelParam) {
+    DMLC_DECLARE_FIELD(shape)
+      .set_default(TShape())
+      .describe("Shape of the array into which the multi-indices apply.");
+  }
+};
+
+inline bool RavelOpShape(const nnvm::NodeAttrs& attrs,
+                         std::vector<TShape>* in_attrs,
+                         std::vector<TShape>* out_attrs) {
+  using namespace mshadow;
+  const TShape& shape = nnvm::get<RavelParam>(attrs.parsed).shape;
+  CHECK_EQ(in_attrs->size(), 1);
+  CHECK_EQ(out_attrs->size(), 1);
+  CHECK_GT(shape.ndim(), 0) << "Empty shape parameter for ravel operator.";
+  if ((*in_attrs)[0].ndim() > 0) {
+    CHECK_EQ((*in_attrs)[0].ndim(), 2)
+      << "Input to ravel operator must be two-dimensional.";
+    CHECK_EQ((*in_attrs)[0][0], shape.ndim())
+      << "First dimension of input of ravel operator does not match shape parameter dimension.";
+    SHAPE_ASSIGN_CHECK(*out_attrs, 0, Shape1((*in_attrs)[0][1]));
+    return true;
+  }
+  if ((*out_attrs)[0].ndim() > 0) {
+    SHAPE_ASSIGN_CHECK(*in_attrs, 0, Shape2(shape.ndim(), (*out_attrs)[0][0]));
+    return true;
+  }
+  return false;
+}
+
+inline bool UnravelOpShape(const nnvm::NodeAttrs& attrs,
+                           std::vector<TShape>* in_attrs,
+                           std::vector<TShape>* out_attrs) {
+  using namespace mshadow;
+  const TShape& shape = nnvm::get<RavelParam>(attrs.parsed).shape;
+  CHECK_EQ(in_attrs->size(), 1);
+  CHECK_EQ(out_attrs->size(), 1);
+  CHECK_GT(shape.ndim(), 0) << "Empty shape parameter for unravel operator.";
+  if ((*in_attrs)[0].ndim() > 0) {
+    SHAPE_ASSIGN_CHECK(*out_attrs, 0, Shape2(shape.ndim(), (*in_attrs)[0][0]));
+    return true;
+  }
+  if ((*out_attrs)[0].ndim() > 0) {
+    CHECK_EQ((*out_attrs)[0].ndim(), 2)
+      << "Output of unravel operator must be two-dimensional.";
+    CHECK_EQ((*out_attrs)[0][0], shape.ndim())
+      << "First dimension of output of ravel operator does not match shape parameter dimension.";
+    SHAPE_ASSIGN_CHECK(*in_attrs, 0, Shape1((*out_attrs)[0][1]));
+    return true;
+  }
+  return false;
+}
+
+struct ravel_index {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, index_t N, index_t ndim, index_t *shape,
+                                  DType *unravelled, DType *ravelled) {
+    index_t ret = 0;
+    #pragma unroll
+    for (index_t j = 0; j < ndim; ++j) {
+      ret = ret * shape[j] + unravelled[i+j*N];
+    }
+    ravelled[i] = ret;
+  }
+};
+
+struct unravel_index {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, index_t N, index_t ndim, index_t *shape,
+                                  DType *unravelled, DType *ravelled) {
+    index_t idx(ravelled[i]);
+    #pragma unroll
+    for (int j = ndim; j--; ) {
+      index_t tmp = idx / shape[j];
+      unravelled[i+j*N] = idx - tmp*shape[j];
+      idx = tmp;
+    }
+  }
+};
+
+template<typename xpu>
+void RavelForward(const nnvm::NodeAttrs& attrs,
+                  const OpContext& ctx,
+                  const std::vector<TBlob>& inputs,
+                  const std::vector<OpReqType>& req,
+                  const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  const TShape& shape = nnvm::get<RavelParam>(attrs.parsed).shape;
+  std::vector<index_t> buffer(shape.data(), shape.data()+shape.ndim());
+  Tensor<xpu, 1, index_t> work
+    = ctx.requested[0].get_space_typed<xpu, 1, index_t>(Shape1(shape.ndim()), s);
+  Copy(work, Tensor<cpu, 1, index_t>(&buffer[0], Shape1(buffer.size()), 0), s);
+  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+    Tensor<xpu, 1, OType> in = inputs[0].FlatTo1D<xpu, OType>(s);
+    Tensor<xpu, 1, OType> out = outputs[0].FlatTo1D<xpu, OType>(s);
+    mxnet_op::Kernel<ravel_index, xpu>::Launch(s, out.size(0), out.size(0), in.size(0)/out.size(0),
+                                               work.dptr_, in.dptr_, out.dptr_);
+  });
+}
+
+template<typename xpu>
+void UnravelForward(const nnvm::NodeAttrs& attrs,
+                    const OpContext& ctx,
+                    const std::vector<TBlob>& inputs,
+                    const std::vector<OpReqType>& req,
+                    const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  const TShape& shape = nnvm::get<RavelParam>(attrs.parsed).shape;
+  std::vector<index_t> buffer(shape.data(), shape.data()+shape.ndim());
+  Tensor<xpu, 1, index_t> work
+    = ctx.requested[0].get_space_typed<xpu, 1, index_t>(Shape1(shape.ndim()), s);
+  Copy(work, Tensor<cpu, 1, index_t>(&buffer[0], Shape1(buffer.size()), 0), s);
+  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+    Tensor<xpu, 1, OType> in = inputs[0].FlatTo1D<xpu, OType>(s);
+    Tensor<xpu, 1, OType> out = outputs[0].FlatTo1D<xpu, OType>(s);
+    mxnet_op::Kernel<unravel_index, xpu>::Launch(s, in.size(0), in.size(0), out.size(0)/in.size(0),
+                                                 work.dptr_, out.dptr_, in.dptr_);
+  });
+}
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_TENSOR_RAVEL_H_
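
The arithmetic in the two kernels is the standard row-major flattening and
its inverse: ravel_index accumulates Horner-style over the dimensions, and
unravel_index peels coordinates off with division and remainder, from the
trailing dimension backwards. A pure-Python restatement of the per-element
logic for a single multi-index (illustration only; the function names here
are mine, not the operator API):

    def ravel_index(multi_idx, shape):
        ret = 0
        for j in range(len(shape)):
            ret = ret * shape[j] + multi_idx[j]  # Horner-style accumulation
        return ret

    def unravel_index(flat_idx, shape):
        out = [0] * len(shape)
        for j in reversed(range(len(shape))):
            flat_idx, out[j] = divmod(flat_idx, shape[j])
        return out

    assert ravel_index([3, 4], (7, 6)) == 22   # matches the docstring example
    assert unravel_index(22, (7, 6)) == [3, 4]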
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index e7976e01f9d..e94e1240f8e 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -6006,6 +6006,21 @@ def test_activation():
         finite_diff_unary_op(
             name, op[0], shape, op[3], op[4], rtol_fd, atol_fd, num_eps)
 
+@with_seed()
+def test_ravel():
+    # be aware that check_symbolic_forward will use float type internally
+    # for the arrays and that limits the representable flat index range.
+    # Taking dim==4 and a range of [0,..,100] for the data can already
+    # cause precision issues and break this test.
+    for dim in [1, 2, 3, 4]:
+        data = np.random.randint(50, size=(dim, 500))
+        shape = tuple(np.add(np.amax(data, axis=1), [1]))
+        a = mx.sym.Variable('a')
+        ravel_npy = np.ravel_multi_index(data, shape)
+        b = mx.sym.ravel_multi_index(a, shape=shape)
+        check_symbolic_forward(b, location={'a': data}, expected=[ravel_npy])
+        c = mx.sym.unravel_index(a, shape=shape)
+        check_symbolic_forward(c, location={'a': ravel_npy}, expected=[data])
 
 def test_context_num_gpus():
     try:
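
The comment in test_ravel is worth unpacking: check_symbolic_forward carries
the indices in floating point (float32 by default), which represents
integers exactly only up to 2**24. With dim == 4 and per-dimension values up
to 100, flat indices approach 100**4 = 10**8 and would round, so the test
deliberately keeps the ranges small. A short illustration of the limit:

    import numpy as np

    # float32 has a 24-bit significand, so integers above 2**24 round.
    print(np.float32(2**24))      # 16777216.0
    print(np.float32(2**24 + 1))  # 16777216.0 -- rounded, off by one
    print(100**4 > 2**24)         # True: why test_ravel keeps values < 50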


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services