Posted to commits@mxnet.apache.org by ha...@apache.org on 2019/05/25 04:06:54 UTC

[incubator-mxnet] branch master updated: Revert the change broadcast_to param shape (#14998)

This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 9250a73  Revert the change broadcast_to param shape (#14998)
9250a73 is described below

commit 9250a73d7891c235b91a8f06b14a6bf687892487
Author: reminisce <wu...@gmail.com>
AuthorDate: Fri May 24 21:06:25 2019 -0700

    Revert the change broadcast_to param shape (#14998)
---
 src/operator/tensor/broadcast_reduce_op.h |  4 ++--
 tests/python/unittest/test_operator.py    | 13 +++++++++++--
 2 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h
index 1723e9a..c7c4993 100644
--- a/src/operator/tensor/broadcast_reduce_op.h
+++ b/src/operator/tensor/broadcast_reduce_op.h
@@ -379,7 +379,7 @@ inline bool BroadcastAxesShape(const nnvm::NodeAttrs& attrs,
 
 inline bool BroadcastToShape(const nnvm::NodeAttrs& attrs,
                              mxnet::ShapeVector *in_attrs,
-                            mxnet::ShapeVector *out_attrs) {
+                             mxnet::ShapeVector *out_attrs) {
   CHECK_EQ(in_attrs->size(), 1U);
   CHECK_EQ(out_attrs->size(), 1U);
   mxnet::TShape& ishape = (*in_attrs)[0];
@@ -389,7 +389,7 @@ inline bool BroadcastToShape(const nnvm::NodeAttrs& attrs,
     << "Operand of shape " << ishape << " cannot be broadcasted to " << param.shape;
   mxnet::TShape oshape = param.shape;
   for (int i = 0; i < ishape.ndim(); ++i) {
-    if (oshape[i] != -1) {
+    if (oshape[i] != 0) {
       CHECK(ishape[i] == oshape[i] || ishape[i] == 1)
         << "Array cannot be broadcasted from " << ishape << " to " << param.shape;
     } else {
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index cb9b2f9..a419bc5 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -2535,19 +2535,27 @@ def test_broadcast():
         size = tuple([shape[ele] for ele in axis])
         for ele in axis:
             shape[ele] = 1
+        target_shape_with_zero = list(target_shape)
+        for idx in range(len(target_shape_with_zero)):
+            if idx not in axis:
+                target_shape_with_zero[idx] = 0
+                break
+
         a = mx.symbol.Variable('a')
         sym_bcast_axis = mx.symbol.broadcast_axis(a, axis=axis, size=size)
         sym_bcast_to = mx.symbol.broadcast_to(a, shape=tuple(target_shape))
+        sym_bcast_to_with_zero = mx.symbol.broadcast_to(a, shape=tuple(target_shape_with_zero))
         sym_bcast_like = mx.symbol.broadcast_like(a, sym_bcast_to)
+
         def test_broadcasting_ele(sym_bcast):
             dat_npy = np.random.rand(*shape)
             groundtruth = dat_npy
             grad_nd = mx.nd.empty(shape)
             outgrad_npy = np.random.rand(*target_shape)
             grad_groundtruth = np_reduce(outgrad_npy, axis=axis, keepdims=True,
-                                          numpy_reduce_func=np.sum)
+                                         numpy_reduce_func=np.sum)
             net = sym_bcast.bind(default_context(), args={'a': mx.nd.array(dat_npy)},
-                                                 args_grad={'a': grad_nd})
+                                 args_grad={'a': grad_nd})
             net.forward(is_train=True)
             assert (net.outputs[0].shape == target_shape).all()
             assert_almost_equal(net.outputs[0].asnumpy(), groundtruth, rtol=1e-4)
@@ -2555,6 +2563,7 @@ def test_broadcast():
             assert_almost_equal(grad_nd.asnumpy(), grad_groundtruth, rtol=1e-4)
         test_broadcasting_ele(sym_bcast_axis)
         test_broadcasting_ele(sym_bcast_to)
+        test_broadcasting_ele(sym_bcast_to_with_zero)
         test_broadcasting_ele(sym_bcast_like)
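
For context on the behavior this revert restores: a 0 in the shape argument of broadcast_to once again means "keep the corresponding dimension of the input", which is what the oshape[i] != 0 check and the new target_shape_with_zero test case exercise. A minimal sketch of that semantics, assuming the mx.nd.broadcast_to NDArray API (the test above uses the symbolic equivalent):

    import mxnet as mx

    a = mx.nd.ones((1, 3))
    # A 0 on axis 1 keeps the input's size (3) on that axis, while axis 0
    # is broadcast from 1 to 4, mirroring the oshape[i] != 0 branch above.
    b = mx.nd.broadcast_to(a, shape=(4, 0))
    print(b.shape)  # expected: (4, 3)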