You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2019/05/25 04:06:54 UTC
[incubator-mxnet] branch master updated: Revert the change broadcast_to param shape (#14998)
This is an automated email from the ASF dual-hosted git repository.
haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 9250a73 Revert the change broadcast_to param shape (#14998)
9250a73 is described below
commit 9250a73d7891c235b91a8f06b14a6bf687892487
Author: reminisce <wu...@gmail.com>
AuthorDate: Fri May 24 21:06:25 2019 -0700
Revert the change broadcast_to param shape (#14998)
---
src/operator/tensor/broadcast_reduce_op.h | 4 ++--
tests/python/unittest/test_operator.py | 13 +++++++++++--
2 files changed, 13 insertions(+), 4 deletions(-)
diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h
index 1723e9a..c7c4993 100644
--- a/src/operator/tensor/broadcast_reduce_op.h
+++ b/src/operator/tensor/broadcast_reduce_op.h
@@ -379,7 +379,7 @@ inline bool BroadcastAxesShape(const nnvm::NodeAttrs& attrs,
inline bool BroadcastToShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *in_attrs,
- mxnet::ShapeVector *out_attrs) {
+ mxnet::ShapeVector *out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
mxnet::TShape& ishape = (*in_attrs)[0];
@@ -389,7 +389,7 @@ inline bool BroadcastToShape(const nnvm::NodeAttrs& attrs,
<< "Operand of shape " << ishape << " cannot be broadcasted to " << param.shape;
mxnet::TShape oshape = param.shape;
for (int i = 0; i < ishape.ndim(); ++i) {
- if (oshape[i] != -1) {
+ if (oshape[i] != 0) {
CHECK(ishape[i] == oshape[i] || ishape[i] == 1)
<< "Array cannot be broadcasted from " << ishape << " to " << param.shape;
} else {
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index cb9b2f9..a419bc5 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -2535,19 +2535,27 @@ def test_broadcast():
size = tuple([shape[ele] for ele in axis])
for ele in axis:
shape[ele] = 1
+ target_shape_with_zero = list(target_shape)
+ for idx in range(len(target_shape_with_zero)):
+ if idx not in axis:
+ target_shape_with_zero[idx] = 0
+ break
+
a = mx.symbol.Variable('a')
sym_bcast_axis = mx.symbol.broadcast_axis(a, axis=axis, size=size)
sym_bcast_to = mx.symbol.broadcast_to(a, shape=tuple(target_shape))
+ sym_bcast_to_with_zero = mx.symbol.broadcast_to(a, shape=tuple(target_shape_with_zero))
sym_bcast_like = mx.symbol.broadcast_like(a, sym_bcast_to)
+
def test_broadcasting_ele(sym_bcast):
dat_npy = np.random.rand(*shape)
groundtruth = dat_npy
grad_nd = mx.nd.empty(shape)
outgrad_npy = np.random.rand(*target_shape)
grad_groundtruth = np_reduce(outgrad_npy, axis=axis, keepdims=True,
- numpy_reduce_func=np.sum)
+ numpy_reduce_func=np.sum)
net = sym_bcast.bind(default_context(), args={'a': mx.nd.array(dat_npy)},
- args_grad={'a': grad_nd})
+ args_grad={'a': grad_nd})
net.forward(is_train=True)
assert (net.outputs[0].shape == target_shape).all()
assert_almost_equal(net.outputs[0].asnumpy(), groundtruth, rtol=1e-4)
@@ -2555,6 +2563,7 @@ def test_broadcast():
assert_almost_equal(grad_nd.asnumpy(), grad_groundtruth, rtol=1e-4)
test_broadcasting_ele(sym_bcast_axis)
test_broadcasting_ele(sym_bcast_to)
+ test_broadcasting_ele(sym_bcast_to_with_zero)
test_broadcasting_ele(sym_bcast_like)