You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/08/31 21:11:17 UTC
[incubator-mxnet] branch master updated: Fix shape inference bug (#7682)
This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 6d71577 Fix shape inference bug (#7682)
6d71577 is described below
commit 6d7157768e216bcb4b505f98555fc72286d160fb
Author: reminisce <wu...@gmail.com>
AuthorDate: Thu Aug 31 14:11:14 2017 -0700
Fix shape inference bug (#7682)
---
src/executor/infer_graph_attr_pass.cc | 5 ++++-
tests/python/unittest/test_symbol.py | 25 +++++++++++++++++++++++++
2 files changed, 29 insertions(+), 1 deletion(-)
diff --git a/src/executor/infer_graph_attr_pass.cc b/src/executor/infer_graph_attr_pass.cc
index 144c371..76b95f3 100644
--- a/src/executor/infer_graph_attr_pass.cc
+++ b/src/executor/infer_graph_attr_pass.cc
@@ -160,7 +160,10 @@ nnvm::Graph InferAttr(nnvm::Graph &&ret,
uint32_t eid = idx.entry_id(nid, igrad[i].index);
if (fis_none(rshape[eid])) {
rshape[eid] = rshape[idx.entry_id(fnode.inputs[i])];
- } else {
+ } else if (!fis_none(rshape[idx.entry_id(fnode.inputs[i])])) {
+ // Need to skip empty forward shape, because it may not be
+ // available now and it is possible to infer the forward
+ // shape in one of the next a few passes
CHECK_EQ(rshape[eid], rshape[idx.entry_id(fnode.inputs[i])])
<< "Backward shape inconsistent with the forward shape";
}
diff --git a/tests/python/unittest/test_symbol.py b/tests/python/unittest/test_symbol.py
index 4d162ec..4a2cdb3 100644
--- a/tests/python/unittest/test_symbol.py
+++ b/tests/python/unittest/test_symbol.py
@@ -286,6 +286,31 @@ def test_zero_prop2():
assert False
+def test_simple_bind_special_case():
+ """Regression test for a shape inference failure in simple_bind that
+ appeared after moving the simple_bind logic from frontend to backend.
+ It tests against a network similar to the following one.
+
+ Network diagram (weight is also the weight input of fc_op):
+ weight --> abs_op --> sum_op --
+ |--> add_op
+ data --> fc_op --> sum_op --
+
+ Only data's shape is given. If shape inference starts from the weight
+ node, the node entries of abs_op and sum_op are still unknown after the
+ first forward pass, so several shapes remain unknown when it is done.
+ The backward inference pass then assumed that no forward node entry had
+ an unknown shape, which led to a CHECK_EQ failure. The test passes as
+ long as simple_bind completes without raising.
+ """
+ data_shape = (5, 13)
+ data = mx.sym.Variable('data')
+ fc = mx.sym.FullyConnected(data=data, num_hidden=1, no_bias=True, name='fc')
+ modified_weight = mx.sym.abs(fc.get_internals()['fc_weight'])
+ net = mx.sym.sum(modified_weight) + mx.sym.sum(fc)
+ net.simple_bind(ctx=mx.cpu(), data=data_shape)
+
+
if __name__ == '__main__':
import nose
nose.runmodule()
--
To stop receiving notification emails like this one, please contact
"commits@mxnet.apache.org" <commits@mxnet.apache.org>.