You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/08/31 21:11:17 UTC

[incubator-mxnet] branch master updated: Fix shape inference bug (#7682)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 6d71577  Fix shape inference bug (#7682)
6d71577 is described below

commit 6d7157768e216bcb4b505f98555fc72286d160fb
Author: reminisce <wu...@gmail.com>
AuthorDate: Thu Aug 31 14:11:14 2017 -0700

    Fix shape inference bug (#7682)
---
 src/executor/infer_graph_attr_pass.cc |  5 ++++-
 tests/python/unittest/test_symbol.py  | 25 +++++++++++++++++++++++++
 2 files changed, 29 insertions(+), 1 deletion(-)

diff --git a/src/executor/infer_graph_attr_pass.cc b/src/executor/infer_graph_attr_pass.cc
index 144c371..76b95f3 100644
--- a/src/executor/infer_graph_attr_pass.cc
+++ b/src/executor/infer_graph_attr_pass.cc
@@ -160,7 +160,10 @@ nnvm::Graph InferAttr(nnvm::Graph &&ret,
           uint32_t eid = idx.entry_id(nid, igrad[i].index);
           if (fis_none(rshape[eid])) {
             rshape[eid] = rshape[idx.entry_id(fnode.inputs[i])];
-          } else {
+          } else if (!fis_none(rshape[idx.entry_id(fnode.inputs[i])])) {
+            // Skip the check while the forward shape is still unknown: it may
+            // not be available yet, and it can still be inferred in one of the
+            // next few passes.
             CHECK_EQ(rshape[eid], rshape[idx.entry_id(fnode.inputs[i])])
                 << "Backward shape inconsistent with the forward shape";
           }
diff --git a/tests/python/unittest/test_symbol.py b/tests/python/unittest/test_symbol.py
index 4d162ec..4a2cdb3 100644
--- a/tests/python/unittest/test_symbol.py
+++ b/tests/python/unittest/test_symbol.py
@@ -286,6 +286,31 @@ def test_zero_prop2():
     assert False
 
 
+def test_simple_bind_special_case():
+    """This is a special case that results in shape inference
+    failure after moving simple_bind logic from frontend to backend.
+    Added here for testing against the network similar to the following one.
+
+    Network diagram:
+    weight --> abs_op --> sum_op --
+                                   |--> add_op
+    data   --> fc_op  --> sum_op --
+
+    Given data's shape, if the shape inference starts from the weight node,
+    then the node entries of abs_op and sum_op are unknown in the
+    forward pass. Therefore, several shapes are still unknown after the
+    first forward pass is done. The backward inference pass previously
+    assumed that no forward node entry had an unknown shape, and
+    consequently it failed the CHECK_EQ.
+    """
+    data_shape = (5, 13)
+    data = mx.sym.Variable('data')
+    fc = mx.sym.FullyConnected(data=data, num_hidden=1, no_bias=True, name='fc')
+    modified_weight = mx.sym.abs(fc.get_internals()['fc_weight'])
+    net = mx.sym.sum(modified_weight) + mx.sym.sum(fc)
+    net.simple_bind(ctx=mx.cpu(), data=data_shape)
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()

-- 
To stop receiving notification emails like this one, please contact
['"commits@mxnet.apache.org" <co...@mxnet.apache.org>'].