You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2018/08/10 06:34:13 UTC
[incubator-mxnet] branch master updated: [MXNET-749] Correct usages of `CutSubgraph` in 3 control flow operators (#12078)
This is an automated email from the ASF dual-hosted git repository.
haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new af15853 [MXNET-749] Correct usages of `CutSubgraph` in 3 control flow operators (#12078)
af15853 is described below
commit af15853a5989e8e40a7300669c8d0c42f0be79ad
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Thu Aug 9 23:34:02 2018 -0700
[MXNET-749] Correct usages of `CutSubgraph` in 3 control flow operators (#12078)
* Fix cut graph
* Copy only when necessary
* Add unittest for while_loop
* Add unittest for foreach
* Add unittest for cond
* Avoid magic number: 0 => kUndefinedStorage
---
python/mxnet/symbol/contrib.py | 28 +++---
src/c_api/c_api_symbolic.cc | 6 +-
src/operator/control_flow.cc | 3 +
tests/python/unittest/test_contrib_control_flow.py | 101 +++++++++++++++++++++
4 files changed, 123 insertions(+), 15 deletions(-)
diff --git a/python/mxnet/symbol/contrib.py b/python/mxnet/symbol/contrib.py
index 1d42cf7..38195bd 100644
--- a/python/mxnet/symbol/contrib.py
+++ b/python/mxnet/symbol/contrib.py
@@ -127,7 +127,7 @@ def _cut_subgraph(subg):
# This construct a subgraph for given output nodes.
# If an output node is one of the input nodes, we call identity to make sure
# that outputs nodes are different from input nodes.
-def _construct_subgraph(sym_out, sym_states):
+def _construct_subgraph(sym_out, sym_states, name):
sym_out = _as_list(sym_out)
sym_states = _as_list(sym_states)
all_outputs = []
@@ -137,18 +137,16 @@ def _construct_subgraph(sym_out, sym_states):
flat_out = []
all_input_names = g.list_inputs()
- output_names = [o.name for o in sym_out]
+ output_names = {o.name for o in sym_out}
for o in sym_out:
- if o.name in all_input_names:
+ if o.name in all_input_names or o.list_attr().get("__subgraph_name__", "") != name:
flat_out.append(symbol.op.identity(o))
else:
flat_out.append(o)
for s in sym_states:
- if s.name in all_input_names or s.name in output_names:
- # There is a problem if the outputs are the same as the inputs
- # or the first output. By calling identity, we can make sure that
- # all symbols will refer to different NDArrays.
+ if s.name in all_input_names or s.name in output_names or \
+ s.list_attr().get("__subgraph_name__", "") != name:
flat_out.append(symbol.op.identity(s))
else:
flat_out.append(s)
@@ -256,7 +254,7 @@ def foreach(body, data, init_states, name="foreach"):
num_out_data = len(sym_out)
num_states = len(sym_states)
num_outputs = num_out_data + num_states
- g = _construct_subgraph(sym_out, sym_states)
+ g = _construct_subgraph(sym_out, sym_states, name)
input_syms = _get_graph_inputs(g)
cut_syms = _cut_subgraph(g)
@@ -469,9 +467,12 @@ def while_loop(cond, func, loop_vars, max_iterations=None, name="while_loop"):
num_outputs = len(outputs) + len(final_state)
# nnvm cut-graph does not allow inputs and outputs overlap
# so we calculate the name of inputs, and copy outputs once it overlaps with inputs
- all_input_names = symbol.Group(outputs + final_state).list_inputs()
- make_identity = lambda x: symbol.op.identity(x) if x.name in all_input_names else x
# group all outputs of graph_func
+ all_input_names = symbol.Group(outputs + final_state).list_inputs()
+ in_input = lambda x: x.name in all_input_names
+ in_graph = lambda x: x.list_attr().get("__subgraph_name__", "") == subgraph_name
+ make_identity = lambda x: symbol.op.identity(x) if in_input(x) or not in_graph(x) \
+ else x
graph = symbol.Group(list(map(make_identity, outputs + final_state)))
return graph, num_out_data, num_outputs
@@ -627,9 +628,12 @@ def cond(pred, then_func, else_func, name="cond"):
num_outputs = len(outputs)
# nnvm cut-graph does not allow inputs and outputs overlap
# so we calculate the name of inputs, and copy outputs once it overlaps with inputs
- all_input_names = symbol.Group(outputs).list_inputs()
- make_identity = lambda x: symbol.op.identity(x) if x.name in all_input_names else x
# group all outputs of graph_func
+ all_input_names = symbol.Group(outputs).list_inputs()
+ in_input = lambda x: x.name in all_input_names
+ in_graph = lambda x: x.list_attr().get("__subgraph_name__", "") == subgraph_name
+ make_identity = lambda x: symbol.op.identity(x) if in_input(x) or not in_graph(x) \
+ else x
graph = symbol.Group(list(map(make_identity, outputs)))
return graph, num_outputs
diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc
index c27a59a..35ecec7 100644
--- a/src/c_api/c_api_symbolic.cc
+++ b/src/c_api/c_api_symbolic.cc
@@ -372,13 +372,13 @@ int MXSymbolCutSubgraph(SymbolHandle sym, SymbolHandle **input_symbols,
// a subgraph.
API_BEGIN();
nnvm::Symbol *s = static_cast<nnvm::Symbol*>(sym);
- std::string subg_attr = "__subgraph_name__";
+ const std::string subg_attr = "__subgraph_name__";
auto out_node = s->outputs[0].node;
auto it = out_node->attrs.dict.find(subg_attr);
if (it != out_node->attrs.dict.end()) {
- std::string subg_name = it->second;
+ const std::string &subg_name = it->second;
std::vector<nnvm::NodeEntry *> input_entries;
- DFSVisit(s->outputs, [subg_attr, subg_name, &input_entries]
+ DFSVisit(s->outputs, [&subg_attr, &subg_name, &input_entries]
(nnvm::NodePtr n) {
// If the node itself isn't in the subgraph, we ignore it.
auto it = n->attrs.dict.find(subg_attr);
diff --git a/src/operator/control_flow.cc b/src/operator/control_flow.cc
index 7c1becc..d6b6703 100644
--- a/src/operator/control_flow.cc
+++ b/src/operator/control_flow.cc
@@ -1225,6 +1225,9 @@ static bool BackwardCondStorageType(const nnvm::NodeAttrs& attrs,
CHECK(sync_in_in(input_locs, out_attrs, &subg_out_attrs, is_udf));
return ret;
};
+ for (const dim_t &cond_in : params.cond_input_locs) {
+ (*out_attrs)[cond_in] = kDefaultStorage;
+ }
bool succ_0 = sub_pass(attrs.subgraphs[1], params.then_input_locs);
bool succ_1 = sub_pass(attrs.subgraphs[2], params.else_input_locs);
return succ_0 && succ_1;
diff --git a/tests/python/unittest/test_contrib_control_flow.py b/tests/python/unittest/test_contrib_control_flow.py
index f1188b5..a4b794c 100644
--- a/tests/python/unittest/test_contrib_control_flow.py
+++ b/tests/python/unittest/test_contrib_control_flow.py
@@ -1664,6 +1664,107 @@ def test_foreach_rnn():
check_foreach_rnn(cell_type, num_states)
+@with_seed()
+def test_cut_subgraph_foreach():
+ class TestLayer(gluon.HybridBlock):
+ def __init__(self, prefix=None, params=None):
+ super(TestLayer, self).__init__(prefix=prefix, params=params)
+
+ def hybrid_forward(self, F, inputs, states):
+ def step1(data, states):
+ return data + 1, states
+ out1, states1 = F.contrib.foreach(step1, inputs, states)
+ out2, states2 = F.contrib.foreach(step1, out1, states)
+ def step2(data, states):
+ return data + states[0], states1
+ out, states = F.contrib.foreach(step2, out2, states)
+ return out
+
+ data = mx.nd.normal(loc=0, scale=1, shape=(5, 10))
+ states = mx.nd.normal(loc=0, scale=1, shape=(10))
+ layer = TestLayer()
+ layer.initialize(ctx=default_context())
+ res1 = layer(data, [states])
+
+ with mx.autograd.record():
+ res1 = layer(data, [states])
+
+ layer = TestLayer()
+ layer.initialize(ctx=default_context())
+ layer.hybridize()
+ res2 = layer(data, [states])
+
+ with mx.autograd.record():
+ res2 = layer(data, [states])
+ assert_almost_equal(res1.asnumpy(), res2.asnumpy(), rtol=0.001, atol=0.0001)
+
+
+@with_seed()
+def test_cut_subgraph_while_loop():
+ class TestLayer(gluon.HybridBlock):
+ def __init__(self, prefix=None, params=None):
+ super(TestLayer, self).__init__(prefix=prefix, params=params)
+ def hybrid_forward(self, F, data):
+ out1, data1 = F.contrib.while_loop(
+ cond=lambda i: i <= 5,
+ func=lambda i: (None, (i + 1, )),
+ loop_vars=(data, ),
+ max_iterations=10,
+ )
+ out2, data2 = F.contrib.while_loop(
+ cond=lambda i: data1[0],
+ func=lambda i: (None, (i + 1, )),
+ loop_vars=data1[0],
+ max_iterations=10,
+ )
+ return data2[0]
+ data = mx.nd.normal(loc=0, scale=1, shape=(1, ))
+ layer = TestLayer()
+ layer.initialize(ctx=default_context())
+ res1 = layer(data)
+ with mx.autograd.record():
+ res1 = layer(data)
+ layer = TestLayer()
+ layer.initialize(ctx=default_context())
+ layer.hybridize()
+ res2 = layer(data)
+ with mx.autograd.record():
+ res2 = layer(data)
+ assert_almost_equal(res1.asnumpy(), res2.asnumpy(), rtol=0.001, atol=0.0001)
+
+
+@with_seed()
+def test_cut_subgraph_cond():
+ class TestLayer(gluon.HybridBlock):
+ def __init__(self, prefix=None, params=None):
+ super(TestLayer, self).__init__(prefix=prefix, params=params)
+ def hybrid_forward(self, F, data):
+ (data1, ) = F.contrib.cond(
+ data > 0.5,
+ then_func=lambda: data * 2,
+ else_func=lambda: data * 3,
+ )
+ (data2, ) = F.contrib.cond(
+ data1 > 0.5,
+ then_func=lambda: data1 * 2,
+ else_func=lambda: data1 * 3,
+ )
+ return data2
+ data = mx.nd.normal(loc=0, scale=1, shape=(1, ))
+ layer = TestLayer()
+ layer.initialize(ctx=default_context())
+ res1 = layer(data)
+ with mx.autograd.record():
+ res1 = layer(data)
+ layer = TestLayer()
+ layer.initialize(ctx=default_context())
+ layer.hybridize()
+ res2 = layer(data)
+ with mx.autograd.record():
+ res2 = layer(data)
+ assert_almost_equal(res1.asnumpy(), res2.asnumpy(), rtol=0.001, atol=0.0001)
+
+
if __name__ == '__main__':
import nose
nose.runmodule()