You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/06/30 01:05:40 UTC

[GitHub] eric-haibin-lin closed pull request #10889: [MXNET-382] Shape and Size Operator

eric-haibin-lin closed pull request #10889: [MXNET-382] Shape and Size Operator
URL: https://github.com/apache/incubator-mxnet/pull/10889
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/docs/api/python/ndarray/ndarray.md b/docs/api/python/ndarray/ndarray.md
index 323344d69c0..dda534151a1 100644
--- a/docs/api/python/ndarray/ndarray.md
+++ b/docs/api/python/ndarray/ndarray.md
@@ -124,6 +124,8 @@ The `ndarray` package provides several classes:
     :nosignatures:
 
     NDArray.T
+    NDArray.shape_array
+    NDArray.size_array
     NDArray.reshape
     NDArray.reshape_like
     NDArray.flatten
@@ -375,6 +377,8 @@ The `ndarray` package provides several classes:
     :nosignatures:
 
     cast
+    shape_array
+    size_array
     reshape
     reshape_like
     flatten
diff --git a/docs/api/python/symbol/symbol.md b/docs/api/python/symbol/symbol.md
index cc63e13e6ec..304b17803ed 100644
--- a/docs/api/python/symbol/symbol.md
+++ b/docs/api/python/symbol/symbol.md
@@ -191,6 +191,8 @@ Composite multiple symbols into a new one by an operator.
     :nosignatures:
 
     Symbol.astype
+    Symbol.shape_array
+    Symbol.size_array
     Symbol.reshape
     Symbol.reshape_like
     Symbol.flatten
@@ -373,6 +375,8 @@ Composite multiple symbols into a new one by an operator.
     :nosignatures:
 
     cast
+    shape_array
+    size_array
     reshape
     reshape_like
     flatten
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index 002ce3ebbc2..09395e2ec82 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -1254,6 +1254,22 @@ def flatten(self, *args, **kwargs):
         """
         return op.flatten(self, *args, **kwargs)
 
+    def shape_array(self, *args, **kwargs):
+        """Convenience fluent method for :py:func:`shape_array`.
+
+        The arguments are the same as for :py:func:`shape_array`, with
+        this array as data. Returns a 1D int64 array holding this shape.
+        """
+        return op.shape_array(self, *args, **kwargs)
+
+    def size_array(self, *args, **kwargs):
+        """Convenience fluent method for :py:func:`size_array`.
+
+        The arguments are the same as for :py:func:`size_array`, with
+        this array as data. Returns a 1D int64 array holding the size.
+        """
+        return op.size_array(self, *args, **kwargs)
+
     def expand_dims(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`expand_dims`.
 
diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index c5e2f5cb77d..b041f4ef646 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -1982,6 +1982,22 @@ def flatten(self, *args, **kwargs):
         """
         return op.flatten(self, *args, **kwargs)
 
+    def shape_array(self, *args, **kwargs):
+        """Convenience fluent method for :py:func:`shape_array`.
+
+        The arguments are the same as for :py:func:`shape_op`, with
+        this array as data.
+        """
+        return op.shape_array(self, *args, **kwargs)
+
+    def size_array(self, *args, **kwargs):
+        """Convenience fluent method for :py:func:`size_array`.
+
+        The arguments are the same as for :py:func:`size_array`, with
+        this array as data. The output is a 1D int64 array of the size.
+        """
+        return op.size_array(self, *args, **kwargs)
+
     def expand_dims(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`expand_dims`.
 
diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py
index 19fe0749598..ae5a473d228 100644
--- a/python/mxnet/test_utils.py
+++ b/python/mxnet/test_utils.py
@@ -1251,13 +1251,15 @@ def check_consistency(sym, ctx_list, scale=1.0, grad_req='write',
                np.dtype(np.float32): 1e-3,
                np.dtype(np.float64): 1e-5,
                np.dtype(np.uint8): 0,
-               np.dtype(np.int32): 0}
+               np.dtype(np.int32): 0,
+               np.dtype(np.int64): 0}
     elif isinstance(tol, numbers.Number):
         tol = {np.dtype(np.float16): tol,
                np.dtype(np.float32): tol,
                np.dtype(np.float64): tol,
                np.dtype(np.uint8): tol,
-               np.dtype(np.int32): tol}
+               np.dtype(np.int32): tol,
+               np.dtype(np.int64): tol}
 
     assert len(ctx_list) > 1
     if isinstance(sym, Symbol):
diff --git a/src/operator/tensor/elemwise_unary_op_basic.cc b/src/operator/tensor/elemwise_unary_op_basic.cc
index 46f62651c75..5b89d49f430 100644
--- a/src/operator/tensor/elemwise_unary_op_basic.cc
+++ b/src/operator/tensor/elemwise_unary_op_basic.cc
@@ -398,6 +398,98 @@ NNVM_REGISTER_OP(reshape_like)
 .add_argument("lhs", "NDArray-or-Symbol", "First input.")
 .add_argument("rhs", "NDArray-or-Symbol", "Second input.");
 
+void ShapeComputeCPU(const nnvm::NodeAttrs& attrs,
+                     const OpContext& ctx,
+                     const std::vector<TBlob>& inputs,
+                     const std::vector<OpReqType>& req,
+                     const std::vector<TBlob>& outputs) {
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  // Copy the input's dimensions (host metadata) straight into the int64 output.
+  const TShape& in_shape = inputs[0].shape_;
+  memcpy(outputs[0].dptr_, in_shape.data(), in_shape.ndim() * sizeof(int64_t));
+}
+
+NNVM_REGISTER_OP(shape_array)
+.describe(R"code(Returns a 1D int64 array containing the shape of data.
+
+Example::
+
+  shape_array([[1,2,3,4], [5,6,7,8]]) = [2,4]
+
+)code" ADD_FILELINE)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<FCompute>("FCompute<cpu>", ShapeComputeCPU)
+.set_attr<nnvm::FInferShape>("FInferShape",
+  [](const nnvm::NodeAttrs& attrs, std::vector<TShape> *in_attrs,
+     std::vector<TShape> *out_attrs) {
+    CHECK_EQ(in_attrs->size(), 1U);
+    CHECK_EQ(out_attrs->size(), 1U);
+    TShape target_shape(1);  // 1-D output: one entry per input dimension
+    target_shape[0] = in_attrs->at(0).ndim();
+    SHAPE_ASSIGN_CHECK(*out_attrs, 0, target_shape);
+    return !shape_is_none(out_attrs->at(0));
+  })
+.set_attr<nnvm::FInferType>("FInferType",
+  [](const nnvm::NodeAttrs& attrs, std::vector<int>* in_attrs,
+     std::vector<int>* out_attrs) {
+    CHECK_EQ(in_attrs->size(), 1U);
+    CHECK_EQ(out_attrs->size(), 1U);
+    TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kInt64);
+    return out_attrs->at(0) != -1;
+  })
+// Only the input's shape is read, never its values, so its gradient is zero.
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
+.add_argument("data", "NDArray-or-Symbol", "Input Array.");
+
+void SizeComputeCPU(const nnvm::NodeAttrs& attrs,
+                    const OpContext& ctx,
+                    const std::vector<TBlob>& inputs,
+                    const std::vector<OpReqType>& req,
+                    const std::vector<TBlob>& outputs) {
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  const TBlob& in_data = inputs[0];
+  const TBlob& out_data = outputs[0];
+  // Hold the count in a genuine int64_t: index_t may be a 32-bit type, and
+  // copying sizeof(int64_t) bytes out of it would read past the variable.
+  const int64_t size_var = static_cast<int64_t>(in_data.Size());
+  memcpy(out_data.dptr_, &size_var, sizeof(size_var));
+}
+
+NNVM_REGISTER_OP(size_array)
+.describe(R"code(Returns a 1D int64 array containing the size of data.
+
+Example::
+
+  size_array([[1,2,3,4], [5,6,7,8]]) = [8]
+
+)code" ADD_FILELINE)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<FCompute>("FCompute<cpu>", SizeComputeCPU)
+.set_attr<nnvm::FInferShape>("FInferShape",
+  [](const nnvm::NodeAttrs& attrs, std::vector<TShape> *in_attrs,
+     std::vector<TShape> *out_attrs) {
+    CHECK_EQ(in_attrs->size(), 1U);
+    CHECK_EQ(out_attrs->size(), 1U);
+    SHAPE_ASSIGN_CHECK(*out_attrs, 0, 1U);  // 1-D output holding a single count
+    return !shape_is_none(out_attrs->at(0));
+  })
+.set_attr<nnvm::FInferType>("FInferType",
+  [](const nnvm::NodeAttrs& attrs, std::vector<int>* in_attrs,
+     std::vector<int>* out_attrs) {
+    CHECK_EQ(in_attrs->size(), 1U);
+    CHECK_EQ(out_attrs->size(), 1U);
+    TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kInt64);
+    return out_attrs->at(0) != -1;
+  })
+// Only the input's element count is read, never its values, so its gradient is zero.
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
+.add_argument("data", "NDArray-or-Symbol", "Input Array.");
 
 DMLC_REGISTER_PARAMETER(CastParam);
 NNVM_REGISTER_OP(Cast)
diff --git a/src/operator/tensor/elemwise_unary_op_basic.cu b/src/operator/tensor/elemwise_unary_op_basic.cu
index 3c8b49ac0a2..8e2a0991b01 100644
--- a/src/operator/tensor/elemwise_unary_op_basic.cu
+++ b/src/operator/tensor/elemwise_unary_op_basic.cu
@@ -77,6 +77,52 @@ NNVM_REGISTER_OP(_identity_with_attr_like_rhs)
 NNVM_REGISTER_OP(reshape_like)
 .set_attr<FCompute>("FCompute<gpu>", UnaryOp::IdentityCompute<gpu>);
 
+void ShapeComputeGPU(const nnvm::NodeAttrs& attrs,
+                     const OpContext& ctx,
+                     const std::vector<TBlob>& inputs,
+                     const std::vector<OpReqType>& req,
+                     const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  // The shape is host-side metadata, so enqueue a host-to-device copy of its
+  // int64 dimensions on this operator's stream.
+  const TShape& in_shape = inputs[0].shape_;
+  const size_t num_bytes = in_shape.ndim() * sizeof(int64_t);
+  cudaStream_t stream = mshadow::Stream<gpu>::GetStream(ctx.get_stream<gpu>());
+  cudaMemcpyAsync(outputs[0].dptr_, in_shape.data(), num_bytes,
+                  cudaMemcpyHostToDevice, stream);
+}
+
+// Attach the GPU kernel to the shape_array operator registered in the .cc file.
+NNVM_REGISTER_OP(shape_array)
+.set_attr<FCompute>("FCompute<gpu>", ShapeComputeGPU);
+
+void SizeComputeGPU(const nnvm::NodeAttrs& attrs,
+                    const OpContext& ctx,
+                    const std::vector<TBlob>& inputs,
+                    const std::vector<OpReqType>& req,
+                    const std::vector<TBlob>& outputs) {
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  const TBlob& in_data = inputs[0];
+  mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
+  // Hold the count in a genuine int64_t: index_t may be a 32-bit type, and
+  // copying sizeof(int64_t) bytes out of it would read past the variable.
+  const int64_t size_var = static_cast<int64_t>(in_data.Size());
+  // size_var is pageable host memory; CUDA stages it before the call returns.
+  cudaMemcpyAsync(outputs[0].dptr_,
+                  &size_var,
+                  sizeof(size_var),
+                  cudaMemcpyHostToDevice,
+                  mshadow::Stream<gpu>::GetStream(s));
+}
+
+NNVM_REGISTER_OP(size_array)
+.set_attr<FCompute>("FCompute<gpu>", SizeComputeGPU);
+
 NNVM_REGISTER_OP(Cast)
 .set_attr<FCompute>("FCompute<gpu>", CastCompute<gpu>);
 
diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py
index aeaa0b72679..4f0662de2c0 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -931,8 +931,8 @@ def check_fluent_regular(func, kwargs, shape=(5, 17, 1), equal_nan=False):
                 assert almost_equal(regular.asnumpy(), fluent.asnumpy(), equal_nan=equal_nan)
 
     for func in ['flatten', 'norm', 'round', 'rint', 'fix', 'floor', 'ceil', 'trunc', 'zeros_like',
-                 'ones_like', 'abs', 'sign', 'sin', 'cos', 'degrees', 'radians',
-                 'exp', 'expm1', 'square', 'reciprocal', 'argmax_channel']:
+                 'ones_like', 'abs', 'sign', 'sin', 'cos', 'degrees', 'radians', 'exp', 'expm1',
+                 'square', 'reciprocal', 'argmax_channel', 'shape_array', 'size_array']:
         check_fluent_regular(func, {})
 
     for func in ['arccosh', 'arcsin', 'arccos', 'arctan', 'tan', 'sinh', 'cosh', 'tanh',
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index e07a602b8c1..0b429520f9a 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -820,6 +820,24 @@ def fsigmoid(a):
     check_symbolic_forward(y, [xa], [ya])
     check_symbolic_backward(y, [xa], [np.ones(shape)], [ya * (1 - ya)])
 
+@with_seed()
+def test_shape_array():
+    for i in range(1,6):
+        shape = rand_shape_nd(i)
+        x = np.random.ranf(shape)
+        y = mx.nd.shape_array(mx.nd.array(x))
+        expected_y = np.shape(x)
+        same(y.asnumpy(), expected_y)
+
+@with_seed()
+def test_size_array():
+    for i in range(1,6):
+        shape = rand_shape_nd(i)
+        x = np.random.ranf(shape)
+        y = mx.nd.size_array(mx.nd.array(x))
+        expected_y = np.size(x)
+        same(y.asnumpy(), expected_y)
+
 @with_seed()
 def test_hard_sigmoid():
     def fhardsigmoid(a, alpha=0.2, beta=0.5):
diff --git a/tests/python/unittest/test_symbol.py b/tests/python/unittest/test_symbol.py
index 387428ab296..aece9a37812 100644
--- a/tests/python/unittest/test_symbol.py
+++ b/tests/python/unittest/test_symbol.py
@@ -190,8 +190,8 @@ def check_fluent_regular(func, kwargs, shape=(5, 17, 1), equal_nan=False):
                                      equal_nan=equal_nan)
 
     for func in ['flatten', 'norm', 'round', 'rint', 'fix', 'floor', 'ceil', 'trunc', 'zeros_like',
-                 'ones_like', 'abs', 'sign', 'sin', 'cos', 'degrees', 'radians',
-                 'exp', 'expm1',  'square', 'reciprocal', 'argmax_channel']:
+                 'ones_like', 'abs', 'sign', 'sin', 'cos', 'degrees', 'radians', 'exp', 'expm1',
+                 'square', 'reciprocal', 'argmax_channel', 'shape_array', 'size_array']:
         check_fluent_regular(func, {})
 
     for func in ['arccosh', 'arcsin', 'arccos', 'arctan', 'tan', 'sinh', 'cosh', 'tanh',


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services