You are viewing a plain-text version of this content. The canonical link to the original was a hyperlink that is not preserved in this plain-text rendering.
Posted to commits@mxnet.apache.org by sk...@apache.org on 2018/09/08 19:30:12 UTC

[incubator-mxnet] branch master updated: allow foreach on input with 0 length (#12471)

This is an automated email from the ASF dual-hosted git repository.

skm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 4eb7626  allow foreach on input with 0 length (#12471)
4eb7626 is described below

commit 4eb7626fcc8c13edbc3fed3ce62dc1ec08a2a20a
Author: Lai Wei <ro...@gmail.com>
AuthorDate: Sat Sep 8 12:29:58 2018 -0700

    allow foreach on input with 0 length (#12471)
    
    * allow foreach on input with 0 length
    
    * add test foreach with unknown dim
---
 src/operator/control_flow.cc                       | 1 -
 tests/python/unittest/test_contrib_control_flow.py | 9 +++++++++
 2 files changed, 9 insertions(+), 1 deletion(-)

diff --git a/src/operator/control_flow.cc b/src/operator/control_flow.cc
index d6b6703..ba7f5c0 100644
--- a/src/operator/control_flow.cc
+++ b/src/operator/control_flow.cc
@@ -314,7 +314,6 @@ static bool ForeachShape(const nnvm::NodeAttrs& attrs,
 
   // For the shape of output data.
   size_t len = in_shape->at(0)[0];
-  CHECK_GT(len, 0);
   for (int i = 0; i < params.num_out_data; i++) {
     // If the output shape isn't inferred, we don't need to propogate the info.
     const auto& g_out_shape = subg_out_shape[i];
diff --git a/tests/python/unittest/test_contrib_control_flow.py b/tests/python/unittest/test_contrib_control_flow.py
index 1c23c91..dd5a4d6 100644
--- a/tests/python/unittest/test_contrib_control_flow.py
+++ b/tests/python/unittest/test_contrib_control_flow.py
@@ -2146,6 +2146,15 @@ def test_output_format_cond():
         for i in range(len(out1)):
             assert_almost_equal(out1[i].asnumpy(), out2[i].asnumpy(), rtol=0.001, atol=0.0001)
 
+def test_foreach_with_unkown_dim():
+    # MXNet supports using 0 as placeholder for unknown dimensions in shape
+    step = lambda data, states: (data + states[0], [states[0] * 2])
+    # input shape with NCHW format and N is unknown
+    data = mx.sym.var('data', shape=(0, 3, 32, 32))
+    states = [mx.sym.var('state')]
+    outs, states = mx.sym.contrib.foreach(step, data, states)
+    _, output_shape, _ = outs.infer_shape_partial()
+    assert_allclose((0, 3, 32, 32), output_shape[0])
 
 if __name__ == '__main__':
     import nose