You are viewing a plain-text version of this content; the link to its canonical version was not preserved in this rendering.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2018/08/10 06:34:13 UTC

[incubator-mxnet] branch master updated: [MXNET-749] Correct usages of `CutSubgraph` in 3 control flow operators (#12078)

This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new af15853  [MXNET-749] Correct usages of `CutSubgraph` in 3 control flow operators (#12078)
af15853 is described below

commit af15853a5989e8e40a7300669c8d0c42f0be79ad
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Thu Aug 9 23:34:02 2018 -0700

    [MXNET-749] Correct usages of `CutSubgraph` in 3 control flow operators (#12078)
    
    * Fix cut graph
    
    * Copy only when necessary
    
    * Add unittest for while_loop
    
    * Add unittest for foreach
    
    * Add unittest for cond
    
    * Avoid magic number: 0 => kDefaultStorage
---
 python/mxnet/symbol/contrib.py                     |  28 +++---
 src/c_api/c_api_symbolic.cc                        |   6 +-
 src/operator/control_flow.cc                       |   3 +
 tests/python/unittest/test_contrib_control_flow.py | 101 +++++++++++++++++++++
 4 files changed, 123 insertions(+), 15 deletions(-)

diff --git a/python/mxnet/symbol/contrib.py b/python/mxnet/symbol/contrib.py
index 1d42cf7..38195bd 100644
--- a/python/mxnet/symbol/contrib.py
+++ b/python/mxnet/symbol/contrib.py
@@ -127,7 +127,7 @@ def _cut_subgraph(subg):
 # This construct a subgraph for given output nodes.
 # If an output node is one of the input nodes, we call identity to make sure
 # that outputs nodes are different from input nodes.
-def _construct_subgraph(sym_out, sym_states):
+def _construct_subgraph(sym_out, sym_states, name):
     sym_out = _as_list(sym_out)
     sym_states = _as_list(sym_states)
     all_outputs = []
@@ -137,18 +137,16 @@ def _construct_subgraph(sym_out, sym_states):
 
     flat_out = []
     all_input_names = g.list_inputs()
-    output_names = [o.name for o in sym_out]
+    output_names = {o.name for o in sym_out}
     for o in sym_out:
-        if o.name in all_input_names:
+        if o.name in all_input_names or o.list_attr().get("__subgraph_name__", "") != name:
             flat_out.append(symbol.op.identity(o))
         else:
             flat_out.append(o)
 
     for s in sym_states:
-        if s.name in all_input_names or s.name in output_names:
-            # There is a problem if the outputs are the same as the inputs
-            # or the first output. By calling identity, we can make sure that
-            # all symbols will refer to different NDArrays.
+        if s.name in all_input_names or s.name in output_names or \
+           s.list_attr().get("__subgraph_name__", "") != name:
             flat_out.append(symbol.op.identity(s))
         else:
             flat_out.append(s)
@@ -256,7 +254,7 @@ def foreach(body, data, init_states, name="foreach"):
         num_out_data = len(sym_out)
         num_states = len(sym_states)
         num_outputs = num_out_data + num_states
-        g = _construct_subgraph(sym_out, sym_states)
+        g = _construct_subgraph(sym_out, sym_states, name)
 
     input_syms = _get_graph_inputs(g)
     cut_syms = _cut_subgraph(g)
@@ -469,9 +467,12 @@ def while_loop(cond, func, loop_vars, max_iterations=None, name="while_loop"):
             num_outputs = len(outputs) + len(final_state)
             # nnvm cut-graph does not allow inputs and outputs overlap
             # so we calculate the name of inputs, and copy outputs once it overlaps with inputs
-            all_input_names = symbol.Group(outputs + final_state).list_inputs()
-            make_identity = lambda x: symbol.op.identity(x) if x.name in all_input_names else x
             # group all outputs of graph_func
+            all_input_names = symbol.Group(outputs + final_state).list_inputs()
+            in_input = lambda x: x.name in all_input_names
+            in_graph = lambda x: x.list_attr().get("__subgraph_name__", "") == subgraph_name
+            make_identity = lambda x: symbol.op.identity(x) if in_input(x) or not in_graph(x) \
+                                      else x
             graph = symbol.Group(list(map(make_identity, outputs + final_state)))
         return graph, num_out_data, num_outputs
 
@@ -627,9 +628,12 @@ def cond(pred, then_func, else_func, name="cond"):
             num_outputs = len(outputs)
             # nnvm cut-graph does not allow inputs and outputs overlap
             # so we calculate the name of inputs, and copy outputs once it overlaps with inputs
-            all_input_names = symbol.Group(outputs).list_inputs()
-            make_identity = lambda x: symbol.op.identity(x) if x.name in all_input_names else x
             # group all outputs of graph_func
+            all_input_names = symbol.Group(outputs).list_inputs()
+            in_input = lambda x: x.name in all_input_names
+            in_graph = lambda x: x.list_attr().get("__subgraph_name__", "") == subgraph_name
+            make_identity = lambda x: symbol.op.identity(x) if in_input(x) or not in_graph(x) \
+                                      else x
             graph = symbol.Group(list(map(make_identity, outputs)))
         return graph, num_outputs
 
diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc
index c27a59a..35ecec7 100644
--- a/src/c_api/c_api_symbolic.cc
+++ b/src/c_api/c_api_symbolic.cc
@@ -372,13 +372,13 @@ int MXSymbolCutSubgraph(SymbolHandle sym, SymbolHandle **input_symbols,
   // a subgraph.
   API_BEGIN();
   nnvm::Symbol *s = static_cast<nnvm::Symbol*>(sym);
-  std::string subg_attr = "__subgraph_name__";
+  const std::string subg_attr = "__subgraph_name__";
   auto out_node = s->outputs[0].node;
   auto it = out_node->attrs.dict.find(subg_attr);
   if (it != out_node->attrs.dict.end()) {
-    std::string subg_name = it->second;
+    const std::string &subg_name = it->second;
     std::vector<nnvm::NodeEntry *> input_entries;
-    DFSVisit(s->outputs, [subg_attr, subg_name, &input_entries]
+    DFSVisit(s->outputs, [&subg_attr, &subg_name, &input_entries]
              (nnvm::NodePtr n) {
       // If the node itself isn't in the subgraph, we ignore it.
       auto it = n->attrs.dict.find(subg_attr);
diff --git a/src/operator/control_flow.cc b/src/operator/control_flow.cc
index 7c1becc..d6b6703 100644
--- a/src/operator/control_flow.cc
+++ b/src/operator/control_flow.cc
@@ -1225,6 +1225,9 @@ static bool BackwardCondStorageType(const nnvm::NodeAttrs& attrs,
     CHECK(sync_in_in(input_locs, out_attrs, &subg_out_attrs, is_udf));
     return ret;
   };
+  for (const dim_t &cond_in : params.cond_input_locs) {
+    (*out_attrs)[cond_in] = kDefaultStorage;
+  }
   bool succ_0 = sub_pass(attrs.subgraphs[1], params.then_input_locs);
   bool succ_1 = sub_pass(attrs.subgraphs[2], params.else_input_locs);
   return succ_0 && succ_1;
diff --git a/tests/python/unittest/test_contrib_control_flow.py b/tests/python/unittest/test_contrib_control_flow.py
index f1188b5..a4b794c 100644
--- a/tests/python/unittest/test_contrib_control_flow.py
+++ b/tests/python/unittest/test_contrib_control_flow.py
@@ -1664,6 +1664,107 @@ def test_foreach_rnn():
         check_foreach_rnn(cell_type, num_states)
 
 
+@with_seed()
+def test_cut_subgraph_foreach():
+    class TestLayer(gluon.HybridBlock):
+        def __init__(self, prefix=None, params=None):
+            super(TestLayer, self).__init__(prefix=prefix, params=params)
+
+        def hybrid_forward(self, F, inputs, states):
+            def step1(data, states):
+                return data + 1, states
+            out1, states1 = F.contrib.foreach(step1, inputs, states)
+            out2, states2 = F.contrib.foreach(step1, out1, states)
+            def step2(data, states):
+                return data + states[0], states1
+            out, states = F.contrib.foreach(step2, out2, states)
+            return out
+
+    data = mx.nd.normal(loc=0, scale=1, shape=(5, 10))
+    states = mx.nd.normal(loc=0, scale=1, shape=(10))
+    layer = TestLayer()
+    layer.initialize(ctx=default_context())
+    res1 = layer(data, [states])
+
+    with mx.autograd.record():
+        res1 = layer(data, [states])
+
+    layer = TestLayer()
+    layer.initialize(ctx=default_context())
+    layer.hybridize()
+    res2 = layer(data, [states])
+
+    with mx.autograd.record():
+        res2 = layer(data, [states])
+    assert_almost_equal(res1.asnumpy(), res2.asnumpy(), rtol=0.001, atol=0.0001)
+
+
+@with_seed()
+def test_cut_subgraph_while_loop():
+    class TestLayer(gluon.HybridBlock):
+        def __init__(self, prefix=None, params=None):
+            super(TestLayer, self).__init__(prefix=prefix, params=params)
+        def hybrid_forward(self, F, data):
+            out1, data1 = F.contrib.while_loop(
+                cond=lambda i: i <= 5,
+                func=lambda i: (None, (i + 1, )),
+                loop_vars=(data, ),
+                max_iterations=10,
+            )
+            out2, data2 = F.contrib.while_loop(
+                cond=lambda i: data1[0],
+                func=lambda i: (None, (i + 1, )),
+                loop_vars=data1[0],
+                max_iterations=10,
+            )
+            return data2[0]
+    data = mx.nd.normal(loc=0, scale=1, shape=(1, ))
+    layer = TestLayer()
+    layer.initialize(ctx=default_context())
+    res1 = layer(data)
+    with mx.autograd.record():
+        res1 = layer(data)
+    layer = TestLayer()
+    layer.initialize(ctx=default_context())
+    layer.hybridize()
+    res2 = layer(data)
+    with mx.autograd.record():
+        res2 = layer(data)
+    assert_almost_equal(res1.asnumpy(), res2.asnumpy(), rtol=0.001, atol=0.0001)
+
+
+@with_seed()
+def test_cut_subgraph_cond():
+    class TestLayer(gluon.HybridBlock):
+        def __init__(self, prefix=None, params=None):
+            super(TestLayer, self).__init__(prefix=prefix, params=params)
+        def hybrid_forward(self, F, data):
+            (data1, ) = F.contrib.cond(
+                data > 0.5,
+                then_func=lambda: data * 2,
+                else_func=lambda: data * 3,
+            )
+            (data2, ) = F.contrib.cond(
+                data1 > 0.5,
+                then_func=lambda: data1 * 2,
+                else_func=lambda: data1 * 3,
+            )
+            return data2
+    data = mx.nd.normal(loc=0, scale=1, shape=(1, ))
+    layer = TestLayer()
+    layer.initialize(ctx=default_context())
+    res1 = layer(data)
+    with mx.autograd.record():
+        res1 = layer(data)
+    layer = TestLayer()
+    layer.initialize(ctx=default_context())
+    layer.hybridize()
+    res2 = layer(data)
+    with mx.autograd.record():
+        res2 = layer(data)
+    assert_almost_equal(res1.asnumpy(), res2.asnumpy(), rtol=0.001, atol=0.0001)
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()