You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by sk...@apache.org on 2018/09/08 19:30:12 UTC
[incubator-mxnet] branch master updated: allow foreach on input with 0 length (#12471)
This is an automated email from the ASF dual-hosted git repository.
skm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 4eb7626 allow foreach on input with 0 length (#12471)
4eb7626 is described below
commit 4eb7626fcc8c13edbc3fed3ce62dc1ec08a2a20a
Author: Lai Wei <ro...@gmail.com>
AuthorDate: Sat Sep 8 12:29:58 2018 -0700
allow foreach on input with 0 length (#12471)
* allow foreach on input with 0 length
* add test foreach with unknown dim
---
src/operator/control_flow.cc | 1 -
tests/python/unittest/test_contrib_control_flow.py | 9 +++++++++
2 files changed, 9 insertions(+), 1 deletion(-)
diff --git a/src/operator/control_flow.cc b/src/operator/control_flow.cc
index d6b6703..ba7f5c0 100644
--- a/src/operator/control_flow.cc
+++ b/src/operator/control_flow.cc
@@ -314,7 +314,6 @@ static bool ForeachShape(const nnvm::NodeAttrs& attrs,
// For the shape of output data.
size_t len = in_shape->at(0)[0];
- CHECK_GT(len, 0);
for (int i = 0; i < params.num_out_data; i++) {
// If the output shape isn't inferred, we don't need to propogate the info.
const auto& g_out_shape = subg_out_shape[i];
diff --git a/tests/python/unittest/test_contrib_control_flow.py b/tests/python/unittest/test_contrib_control_flow.py
index 1c23c91..dd5a4d6 100644
--- a/tests/python/unittest/test_contrib_control_flow.py
+++ b/tests/python/unittest/test_contrib_control_flow.py
@@ -2146,6 +2146,15 @@ def test_output_format_cond():
for i in range(len(out1)):
assert_almost_equal(out1[i].asnumpy(), out2[i].asnumpy(), rtol=0.001, atol=0.0001)
+def test_foreach_with_unkown_dim():
+ # Regression test for contrib.foreach shape inference when the input's first
+ # (iterated) dimension is 0. ForeachShape takes in_shape->at(0)[0] as the loop
+ # length, and before this change it rejected 0 via CHECK_GT(len, 0).
+ # NOTE(review): "unkown" in the test name is a typo for "unknown".
+ # MXNet supports using 0 as placeholder for unknown dimensions in shape
+ # Each step adds the state to the current data slice and doubles the state.
+ step = lambda data, states: (data + states[0], [states[0] * 2])
+ # input shape with NCHW format and N is unknown
+ data = mx.sym.var('data', shape=(0, 3, 32, 32))
+ # The state is declared without a shape, so only partial inference is possible.
+ states = [mx.sym.var('state')]
+ outs, states = mx.sym.contrib.foreach(step, data, states)
+ # infer_shape_partial returns (arg_shapes, out_shapes, aux_shapes); keep the output shapes.
+ _, output_shape, _ = outs.infer_shape_partial()
+ # The output must keep the unknown (0) leading dimension instead of failing inference.
+ assert_allclose((0, 3, 32, 32), output_shape[0])
if __name__ == '__main__':
import nose