Posted to commits@mxnet.apache.org by sk...@apache.org on 2018/08/18 00:23:26 UTC

[incubator-mxnet] branch master updated: [MXNET-806] Report error when shape mismatch in "where" operator (#12174)

This is an automated email from the ASF dual-hosted git repository.

skm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 54af0cc  [MXNET-806] Report error when shape mismatch in "where" operator (#12174)
54af0cc is described below

commit 54af0ccabaceb19b60e69fdafd2975142a9018ef
Author: Lin Yuan <ap...@gmail.com>
AuthorDate: Fri Aug 17 17:23:18 2018 -0700

    [MXNET-806] Report error when shape mismatch in "where" operator (#12174)
    
    * Fix undefined behavior in where operator
    
    * Add unit test
---
 src/operator/tensor/control_flow_op.h  |  2 +-
 tests/python/unittest/test_operator.py | 16 +++++++++++++++-
 2 files changed, 16 insertions(+), 2 deletions(-)

diff --git a/src/operator/tensor/control_flow_op.h b/src/operator/tensor/control_flow_op.h
index 503bc7c..94e6510 100644
--- a/src/operator/tensor/control_flow_op.h
+++ b/src/operator/tensor/control_flow_op.h
@@ -188,7 +188,7 @@ inline bool WhereOpShape(const nnvm::NodeAttrs& attrs,
     SHAPE_ASSIGN_CHECK(*in_attrs, 0, tshape);
     return true;
   } else if ((*in_attrs)[0].ndim() == 1) {
-    return (*in_attrs)[0].Size() == static_cast<size_t>(tshape[0]);
+    CHECK_EQ((*in_attrs)[0].Size(), static_cast<size_t>(tshape[0]));
   }
   return false;
 }
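
Context for the hunk above (illustrative commentary, not part of the patch): returning false from WhereOpShape only tells the shape-inference pass that shapes could not be decided, so a mismatched 1-D condition was not reported as the user error it is. CHECK_EQ instead aborts inference on mismatch, and the dmlc::Error it throws surfaces in the Python frontend as MXNetError. A minimal repro against the imperative API, assuming an MXNet build that includes this patch:

    import mxnet as mx
    from mxnet.base import MXNetError

    x = mx.nd.array([[2, 3], [4, 5], [6, 7]])      # shape (3, 2)
    y = mx.nd.array([[8, 9], [10, 11], [12, 13]])  # shape (3, 2)

    # A 1-D condition selects whole rows of x or y, so its size must
    # equal x.shape[0] (3 here); a size-2 condition is a user error.
    try:
        mx.nd.where(condition=mx.nd.array([1, 0]), x=x, y=y)
    except MXNetError as e:
        print('shape mismatch reported:', e)
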
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 55b797c..c49d402 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -4473,6 +4473,20 @@ def test_where():
             condition_np, x_np, y_np = get_forward_inputs_condition_vector(shape)
         check_numeric_gradient(where_sym, [condition_np, x_np, y_np], grad_nodes=['x', 'y'])
 
+    def test_invalid_shape():
+        condition = mx.sym.Variable('condition')
+        x = mx.sym.Variable('x')
+        y = mx.sym.Variable('y')
+        where_sym = mx.sym.where(condition, x, y)
+
+        assert_exception(lambda: where_sym.eval(x=mx.nd.array([[2,3],[4,5],[6,7]]),
+                                                y=mx.nd.array([[8,9],[10,11],[12,13]]),
+                                                condition=mx.nd.array([1,0])), MXNetError)
+
+        assert_exception(lambda: mx.nd.where(x=mx.nd.array([[2,3],[4,5],[6,7]]),
+                                             y=mx.nd.array([[8,9],[10,11],[12,13]]),
+                                             condition=mx.nd.array([1,0])), MXNetError)
+
     test_where_helper((5, 9), True)
     test_where_helper((5, 9), False)
     test_where_helper((5, 7, 9), True)
@@ -4483,7 +4497,7 @@ def test_where():
     test_where_numeric_gradient((5, 9), False)
     test_where_numeric_gradient((5, 7, 9), True)
     test_where_numeric_gradient((5, 7, 9), False)
-
+    test_invalid_shape()
 
 @with_seed()
 def test_new_softmax():
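
For readers reproducing the new test outside the suite: assert_exception is the helper from mxnet.test_utils that invokes a callable and asserts it raises the given exception type. A self-contained sketch of the symbolic path, adapted from the test added above (the print line is ours):

    import mxnet as mx
    from mxnet.base import MXNetError
    from mxnet.test_utils import assert_exception

    condition = mx.sym.Variable('condition')
    x = mx.sym.Variable('x')
    y = mx.sym.Variable('y')
    where_sym = mx.sym.where(condition, x, y)

    # eval() binds the graph and runs one forward pass; with this patch,
    # shape inference rejects the size-2 condition against 3-row inputs.
    assert_exception(lambda: where_sym.eval(x=mx.nd.array([[2, 3], [4, 5], [6, 7]]),
                                            y=mx.nd.array([[8, 9], [10, 11], [12, 13]]),
                                            condition=mx.nd.array([1, 0])),
                     MXNetError)
    print('MXNetError raised as expected')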