Posted to commits@mxnet.apache.org by ha...@apache.org on 2018/06/30 01:05:45 UTC

[incubator-mxnet] branch master updated: [MXNET-382] Shape and Size Operator (#10889)

This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 33022f8  [MXNET-382] Shape and Size Operator (#10889)
33022f8 is described below

commit 33022f82a3088e95b61ed8ef3659e8c55433f556
Author: Anirudh <an...@gmail.com>
AuthorDate: Fri Jun 29 18:05:38 2018 -0700

    [MXNET-382] Shape and Size Operator (#10889)
    
    * Shape Operator
    
    * cuda
    
    * size op
    
    * lint issues
    
    * docs example
    
    * add docs, change op name to avoid conflict, add convenience fluent method
    
    * change name to _nd
    
    * fix test cases, add new kernel
    
    * test name fix.
    
    * solve gpu memory problem for size and shape
    
    * get rid of FIgnoreInputs attr of shape_nd
    
    * op name change
    
    * fix
    
    * retrigger CI
    
    * retrigger CI
    
    * retrigger CI
    
    * trigger CI
    
    * fix comments
    
    * cpplint
    
    * nit
    
    * trigger CI
---
 docs/api/python/ndarray/ndarray.md             |  4 ++
 docs/api/python/symbol/symbol.md               |  4 ++
 python/mxnet/ndarray/ndarray.py                | 16 +++++
 python/mxnet/symbol/symbol.py                  | 16 +++++
 python/mxnet/test_utils.py                     |  6 +-
 src/operator/tensor/elemwise_unary_op_basic.cc | 92 ++++++++++++++++++++++++++
 src/operator/tensor/elemwise_unary_op_basic.cu | 46 +++++++++++++
 tests/python/unittest/test_ndarray.py          |  4 +-
 tests/python/unittest/test_operator.py         | 18 +++++
 tests/python/unittest/test_symbol.py           |  4 +-
 10 files changed, 204 insertions(+), 6 deletions(-)
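
As a quick orientation, a minimal usage sketch of the two new operators, based on the
examples in the operator docs below (the import and printed values are illustrative):

    import mxnet as mx

    x = mx.nd.array([[1, 2, 3, 4], [5, 6, 7, 8]])
    print(mx.nd.shape_array(x))  # [2 4], a 1D int64 NDArray
    print(mx.nd.size_array(x))   # [8], a one-element int64 NDArray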

diff --git a/docs/api/python/ndarray/ndarray.md b/docs/api/python/ndarray/ndarray.md
index 323344d..dda5341 100644
--- a/docs/api/python/ndarray/ndarray.md
+++ b/docs/api/python/ndarray/ndarray.md
@@ -124,6 +124,8 @@ The `ndarray` package provides several classes:
     :nosignatures:
 
     NDArray.T
+    NDArray.shape_array
+    NDArray.size_array
     NDArray.reshape
     NDArray.reshape_like
     NDArray.flatten
@@ -375,6 +377,8 @@ The `ndarray` package provides several classes:
     :nosignatures:
 
     cast
+    shape_array
+    size_array
     reshape
     reshape_like
     flatten
diff --git a/docs/api/python/symbol/symbol.md b/docs/api/python/symbol/symbol.md
index cc63e13..304b178 100644
--- a/docs/api/python/symbol/symbol.md
+++ b/docs/api/python/symbol/symbol.md
@@ -191,6 +191,8 @@ Composite multiple symbols into a new one by an operator.
     :nosignatures:
 
     Symbol.astype
+    Symbol.shape_array
+    Symbol.size_array
     Symbol.reshape
     Symbol.reshape_like
     Symbol.flatten
@@ -373,6 +375,8 @@ Composite multiple symbols into a new one by an operator.
     :nosignatures:
 
     cast
+    shape_array
+    size_array
     reshape
     reshape_like
     flatten
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index 002ce3e..09395e2 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -1254,6 +1254,22 @@ fixed-size items.
         """
         return op.flatten(self, *args, **kwargs)
 
+    def shape_array(self, *args, **kwargs):
+        """Convenience fluent method for :py:func:`shape_array`.
+
+        The arguments are the same as for :py:func:`shape_array`, with
+        this array as data.
+        """
+        return op.shape_array(self, *args, **kwargs)
+
+    def size_array(self, *args, **kwargs):
+        """Convenience fluent method for :py:func:`size_array`.
+
+        The arguments are the same as for :py:func:`size_array`, with
+        this array as data.
+        """
+        return op.size_array(self, *args, **kwargs)
+
     def expand_dims(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`expand_dims`.
 
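
The fluent methods added above simply forward to the registered operators; a short
sketch of the equivalent fluent calls (same data as the docs example):

    x = mx.nd.array([[1, 2, 3, 4], [5, 6, 7, 8]])
    assert (x.shape_array().asnumpy() == [2, 4]).all()
    assert (x.size_array().asnumpy() == [8]).all()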
diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index c5e2f5c..b041f4e 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -1982,6 +1982,22 @@ class Symbol(SymbolBase):
         """
         return op.flatten(self, *args, **kwargs)
 
+    def shape_array(self, *args, **kwargs):
+        """Convenience fluent method for :py:func:`shape_array`.
+
+        The arguments are the same as for :py:func:`shape_array`, with
+        this array as data.
+        """
+        return op.shape_array(self, *args, **kwargs)
+
+    def size_array(self, *args, **kwargs):
+        """Convenience fluent method for :py:func:`size_array`.
+
+        The arguments are the same as for :py:func:`size_array`, with
+        this array as data.
+        """
+        return op.size_array(self, *args, **kwargs)
+
     def expand_dims(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`expand_dims`.
 
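
On the symbolic side the same fluent methods compose into the graph; a sketch of
evaluating one (the variable name and input shape are illustrative):

    data = mx.sym.Variable('data')
    shape_sym = data.shape_array()
    out = shape_sym.eval(ctx=mx.cpu(), data=mx.nd.ones((2, 4)))
    print(out[0])  # [2 4] as an int64 NDArray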
diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py
index 19fe074..ae5a473 100644
--- a/python/mxnet/test_utils.py
+++ b/python/mxnet/test_utils.py
@@ -1251,13 +1251,15 @@ def check_consistency(sym, ctx_list, scale=1.0, grad_req='write',
                np.dtype(np.float32): 1e-3,
                np.dtype(np.float64): 1e-5,
                np.dtype(np.uint8): 0,
-               np.dtype(np.int32): 0}
+               np.dtype(np.int32): 0,
+               np.dtype(np.int64): 0}
     elif isinstance(tol, numbers.Number):
         tol = {np.dtype(np.float16): tol,
                np.dtype(np.float32): tol,
                np.dtype(np.float64): tol,
                np.dtype(np.uint8): tol,
-               np.dtype(np.int32): tol}
+               np.dtype(np.int32): tol,
+               np.dtype(np.int64): tol}
 
     assert len(ctx_list) > 1
     if isinstance(sym, Symbol):
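
With int64 added to the tolerance tables, check_consistency can compare the new
operators' int64 outputs across contexts with exact (zero-tolerance) matching; a
sketch, assuming a GPU build (the context-list format follows test_utils conventions):

    import numpy as np
    import mxnet as mx
    from mxnet.test_utils import check_consistency

    sym = mx.sym.shape_array(mx.sym.Variable('data'))
    ctx_list = [{'ctx': mx.cpu(), 'data': (2, 4), 'type_dict': {'data': np.float32}},
                {'ctx': mx.gpu(0), 'data': (2, 4), 'type_dict': {'data': np.float32}}]
    check_consistency(sym, ctx_list)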
diff --git a/src/operator/tensor/elemwise_unary_op_basic.cc b/src/operator/tensor/elemwise_unary_op_basic.cc
index 46f6265..5b89d49 100644
--- a/src/operator/tensor/elemwise_unary_op_basic.cc
+++ b/src/operator/tensor/elemwise_unary_op_basic.cc
@@ -398,6 +398,98 @@ NNVM_REGISTER_OP(reshape_like)
 .add_argument("lhs", "NDArray-or-Symbol", "First input.")
 .add_argument("rhs", "NDArray-or-Symbol", "Second input.");
 
+void ShapeComputeCPU(const nnvm::NodeAttrs& attrs,
+                     const OpContext& ctx,
+                     const std::vector<TBlob>& inputs,
+                     const std::vector<OpReqType>& req,
+                     const std::vector<TBlob>& outputs) {
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  const TBlob& in_data = inputs[0];
+  const TBlob& out_data = outputs[0];
+  memcpy(out_data.dptr_, in_data.shape_.data(), in_data.ndim() * sizeof(int64_t));
+}
+
+NNVM_REGISTER_OP(shape_array)
+.describe(R"code(Returns a 1D int64 array containing the shape of data.
+
+Example::
+
+  shape_array([[1,2,3,4], [5,6,7,8]]) = [2,4]
+
+)code" ADD_FILELINE)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<FCompute>("FCompute<cpu>", ShapeComputeCPU)
+.set_attr<nnvm::FInferShape>("FInferShape",
+  [](const nnvm::NodeAttrs& attrs,
+     std::vector<TShape> *in_attrs,
+     std::vector<TShape> *out_attrs) {
+    CHECK_EQ(in_attrs->size(), 1U);
+    CHECK_EQ(out_attrs->size(), 1U);
+    TShape target_shape(1);
+    target_shape[0] = in_attrs->at(0).ndim();
+    SHAPE_ASSIGN_CHECK(*out_attrs, 0, target_shape);
+    return !shape_is_none(out_attrs->at(0));
+  })
+.set_attr<nnvm::FInferType>("FInferType",
+  [](const nnvm::NodeAttrs& attrs,
+     std::vector<int>* in_attrs,
+     std::vector<int>* out_attrs) {
+    CHECK_EQ(in_attrs->size(), 1U);
+    CHECK_EQ(out_attrs->size(), 1U);
+    TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kInt64);
+    return out_attrs->at(0) != -1;
+  })
+.add_argument("data", "NDArray-or-Symbol", "Input Array.");
+
+void SizeComputeCPU(const nnvm::NodeAttrs& attrs,
+                    const OpContext& ctx,
+                    const std::vector<TBlob>& inputs,
+                    const std::vector<OpReqType>& req,
+                    const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  const TBlob& in_data = inputs[0];
+  const TBlob& out_data = outputs[0];
+  const index_t size_var = in_data.Size();
+  memcpy(out_data.dptr_, &size_var, 1U * sizeof(int64_t));
+}
+
+NNVM_REGISTER_OP(size_array)
+.describe(R"code(Returns a 1D int64 array containing the size of data.
+
+Example::
+
+  size_array([[1,2,3,4], [5,6,7,8]]) = [8]
+
+)code" ADD_FILELINE)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<FCompute>("FCompute<cpu>", SizeComputeCPU)
+.set_attr<nnvm::FInferShape>("FInferShape",
+  [](const nnvm::NodeAttrs& attrs,
+     std::vector<TShape> *in_attrs,
+     std::vector<TShape> *out_attrs) {
+    CHECK_EQ(in_attrs->size(), 1U);
+    CHECK_EQ(out_attrs->size(), 1U);
+    SHAPE_ASSIGN_CHECK(*out_attrs, 0, 1U);
+    return !shape_is_none(out_attrs->at(0));
+  })
+.set_attr<nnvm::FInferType>("FInferType",
+  [](const nnvm::NodeAttrs& attrs,
+     std::vector<int>* in_attrs,
+     std::vector<int>* out_attrs) {
+    CHECK_EQ(in_attrs->size(), 1U);
+    CHECK_EQ(out_attrs->size(), 1U);
+    TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kInt64);
+    return out_attrs->at(0) != -1;
+  })
+.add_argument("data", "NDArray-or-Symbol", "Input Array.");
 
 DMLC_REGISTER_PARAMETER(CastParam);
 NNVM_REGISTER_OP(Cast)
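
A NumPy reference for what the two CPU kernels above compute, as a mental model only
(it mirrors the semantics, not the C++ implementation):

    import numpy as np

    def shape_array_ref(x):
        # 1D int64 array holding the input's shape, as ShapeComputeCPU writes out
        return np.array(np.shape(x), dtype=np.int64)

    def size_array_ref(x):
        # one-element int64 array holding the total element count
        return np.array([np.size(x)], dtype=np.int64)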
diff --git a/src/operator/tensor/elemwise_unary_op_basic.cu b/src/operator/tensor/elemwise_unary_op_basic.cu
index 3c8b49a..8e2a099 100644
--- a/src/operator/tensor/elemwise_unary_op_basic.cu
+++ b/src/operator/tensor/elemwise_unary_op_basic.cu
@@ -77,6 +77,52 @@ NNVM_REGISTER_OP(_identity_with_attr_like_rhs)
 NNVM_REGISTER_OP(reshape_like)
 .set_attr<FCompute>("FCompute<gpu>", UnaryOp::IdentityCompute<gpu>);
 
+void ShapeComputeGPU(const nnvm::NodeAttrs& attrs,
+                     const OpContext& ctx,
+                     const std::vector<TBlob>& inputs,
+                     const std::vector<OpReqType>& req,
+                     const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  const TBlob& in_data = inputs[0];
+  const TBlob& out_data = outputs[0];
+  mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
+  cudaMemcpyAsync(out_data.dptr_,
+                  in_data.shape_.data(),
+                  in_data.ndim() * sizeof(int64_t),
+                  cudaMemcpyHostToDevice,
+                  mshadow::Stream<gpu>::GetStream(s));
+}
+
+NNVM_REGISTER_OP(shape_array)
+.set_attr<FCompute>("FCompute<gpu>", ShapeComputeGPU);
+
+void SizeComputeGPU(const nnvm::NodeAttrs& attrs,
+                    const OpContext& ctx,
+                    const std::vector<TBlob>& inputs,
+                    const std::vector<OpReqType>& req,
+                    const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  const TBlob& in_data = inputs[0];
+  const TBlob& out_data = outputs[0];
+  mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
+  const index_t size_var = in_data.Size();
+  cudaMemcpyAsync(out_data.dptr_,
+                  &size_var,
+                  1U * sizeof(int64_t),
+                  cudaMemcpyHostToDevice,
+                  mshadow::Stream<gpu>::GetStream(s));
+}
+
+NNVM_REGISTER_OP(size_array)
+.set_attr<FCompute>("FCompute<gpu>", SizeComputeGPU);
+
 NNVM_REGISTER_OP(Cast)
 .set_attr<FCompute>("FCompute<gpu>", CastCompute<gpu>);
 
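
On GPU the shape and size are host-side metadata, so both kernels above stage them
into device memory with cudaMemcpyAsync on the operator's stream; the output therefore
lives on the input's context. A sketch, assuming a GPU build:

    a = mx.nd.ones((3, 5), ctx=mx.gpu(0))
    s = mx.nd.shape_array(a)
    print(s.context, s.asnumpy())  # gpu(0) [3 5]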
diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py
index a01514d..cf5906a 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -932,8 +932,8 @@ def test_ndarray_fluent():
                 assert almost_equal(regular.asnumpy(), fluent.asnumpy(), equal_nan=equal_nan)
 
     for func in ['flatten', 'norm', 'round', 'rint', 'fix', 'floor', 'ceil', 'trunc', 'zeros_like',
-                 'ones_like', 'abs', 'sign', 'sin', 'cos', 'degrees', 'radians',
-                 'exp', 'expm1', 'square', 'reciprocal', 'argmax_channel']:
+                 'ones_like', 'abs', 'sign', 'sin', 'cos', 'degrees', 'radians', 'exp', 'expm1',
+                 'square', 'reciprocal', 'argmax_channel', 'shape_array', 'size_array']:
         check_fluent_regular(func, {})
 
     for func in ['arccosh', 'arcsin', 'arccos', 'arctan', 'tan', 'sinh', 'cosh', 'tanh',
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 2707d8f..28ad555 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -823,6 +823,24 @@ def test_sigmoid():
     check_symbolic_backward(y, [xa], [np.ones(shape)], [ya * (1 - ya)])
 
 @with_seed()
+def test_shape_array():
+    for i in range(1,6):
+        shape = rand_shape_nd(i)
+        x = np.random.ranf(shape)
+        y = mx.nd.shape_array(mx.nd.array(x))
+        expected_y = np.shape(x)
+        assert same(y.asnumpy(), expected_y)
+
+@with_seed()
+def test_size_array():
+    for i in range(1,6):
+        shape = rand_shape_nd(i)
+        x = np.random.ranf(shape)
+        y = mx.nd.size_array(mx.nd.array(x))
+        expected_y = np.size(x)
+        assert same(y.asnumpy(), expected_y)
+
+@with_seed()
 def test_hard_sigmoid():
     def fhardsigmoid(a, alpha=0.2, beta=0.5):
         return np.maximum(np.zeros(a.shape, dtype=a.dtype),
diff --git a/tests/python/unittest/test_symbol.py b/tests/python/unittest/test_symbol.py
index 387428a..aece9a3 100644
--- a/tests/python/unittest/test_symbol.py
+++ b/tests/python/unittest/test_symbol.py
@@ -190,8 +190,8 @@ def test_symbol_fluent():
                                      equal_nan=equal_nan)
 
     for func in ['flatten', 'norm', 'round', 'rint', 'fix', 'floor', 'ceil', 'trunc', 'zeros_like',
-                 'ones_like', 'abs', 'sign', 'sin', 'cos', 'degrees', 'radians',
-                 'exp', 'expm1',  'square', 'reciprocal', 'argmax_channel']:
+                 'ones_like', 'abs', 'sign', 'sin', 'cos', 'degrees', 'radians', 'exp', 'expm1',
+                 'square', 'reciprocal', 'argmax_channel', 'shape_array', 'size_array']:
         check_fluent_regular(func, {})
 
     for func in ['arccosh', 'arcsin', 'arccos', 'arctan', 'tan', 'sinh', 'cosh', 'tanh',