You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by sk...@apache.org on 2018/08/18 00:23:26 UTC
[incubator-mxnet] branch master updated: [MXNET-806] Report error when shape mismatch in "where" operator (#12174)
This is an automated email from the ASF dual-hosted git repository.
skm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 54af0cc [MXNET-806] Report error when shape mismatch in "where" operator (#12174)
54af0cc is described below
commit 54af0ccabaceb19b60e69fdafd2975142a9018ef
Author: Lin Yuan <ap...@gmail.com>
AuthorDate: Fri Aug 17 17:23:18 2018 -0700
[MXNET-806] Report error when shape mismatch in "where" operator (#12174)
* Fix undefined behavior in where operator
* Add unit test
---
src/operator/tensor/control_flow_op.h | 2 +-
tests/python/unittest/test_operator.py | 16 +++++++++++++++-
2 files changed, 16 insertions(+), 2 deletions(-)
diff --git a/src/operator/tensor/control_flow_op.h b/src/operator/tensor/control_flow_op.h
index 503bc7c..94e6510 100644
--- a/src/operator/tensor/control_flow_op.h
+++ b/src/operator/tensor/control_flow_op.h
@@ -188,7 +188,7 @@ inline bool WhereOpShape(const nnvm::NodeAttrs& attrs,
SHAPE_ASSIGN_CHECK(*in_attrs, 0, tshape);
return true;
} else if ((*in_attrs)[0].ndim() == 1) {
- return (*in_attrs)[0].Size() == static_cast<size_t>(tshape[0]);
+ CHECK_EQ((*in_attrs)[0].Size(), static_cast<size_t>(tshape[0]));
}
return false;
}
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 55b797c..c49d402 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -4473,6 +4473,20 @@ def test_where():
condition_np, x_np, y_np = get_forward_inputs_condition_vector(shape)
check_numeric_gradient(where_sym, [condition_np, x_np, y_np], grad_nodes=['x', 'y'])
+ def test_invalid_shape():
+ condition = mx.sym.Variable('condition')
+ x = mx.sym.Variable('x')
+ y = mx.sym.Variable('y')
+ where_sym = mx.sym.where(condition, x, y)
+
+ assert_exception(lambda: where_sym.eval(x=mx.nd.array([[2,3],[4,5],[6,7]]),
+ y=mx.nd.array([[8,9],[10,11],[12,13]]),
+ condition=mx.nd.array([1,0])), MXNetError)
+
+ assert_exception(lambda: mx.nd.where(x=mx.nd.array([[2,3],[4,5],[6,7]]),
+ y=mx.nd.array([[8,9],[10,11],[12,13]]),
+ condition=mx.nd.array([1,0])), MXNetError)
+
test_where_helper((5, 9), True)
test_where_helper((5, 9), False)
test_where_helper((5, 7, 9), True)
@@ -4483,7 +4497,7 @@ def test_where():
test_where_numeric_gradient((5, 9), False)
test_where_numeric_gradient((5, 7, 9), True)
test_where_numeric_gradient((5, 7, 9), False)
-
+ test_invalid_shape()
@with_seed()
def test_new_softmax():