You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/08/03 22:48:11 UTC

[incubator-mxnet] branch master updated: Bug fixes in control flow operators (#11942)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 22c97ef  Bug fixes in control flow operators (#11942)
22c97ef is described below

commit 22c97efc418463befc12881892246f999a85b971
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Fri Aug 3 15:48:02 2018 -0700

    Bug fixes in control flow operators (#11942)
---
 python/mxnet/symbol/contrib.py | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/python/mxnet/symbol/contrib.py b/python/mxnet/symbol/contrib.py
index 8842883..1d42cf7 100644
--- a/python/mxnet/symbol/contrib.py
+++ b/python/mxnet/symbol/contrib.py
@@ -486,12 +486,12 @@ def while_loop(cond, func, loop_vars, max_iterations=None, name="while_loop"):
         input_id_to_loc = {}    # Dict[int, int], given id(sym), input_id_to_loc maps it
                                 # to a `loc`, where inputs[loc] = sym
         for graph in graphs:
-            # input_syms: all inputs to the `graph`
-            name_to_input_syms = {sym.name: sym for sym in _get_graph_inputs(graph)}
             # some loop_vars are inputs to `graph`, some are not
             name_to_loop_vars = {sym.name: sym for sym in loop_vars}
             # other inputs to `graph` created by cut_graph
             name_to_cut_g_syms = {sym.list_outputs()[0]: sym for sym in _cut_subgraph(graph)}
+            # input_syms: all inputs to the `graph`
+            name_to_input_syms = {sym.name: sym for sym in _get_graph_inputs(graph)}
             # also we collect the mapping from var's name to var's loc in loop_vars
             name_to_var_locs = {sym.name: i for i, sym in enumerate(loop_vars)}
             # collect arguments for each subgraph
@@ -644,12 +644,12 @@ def cond(pred, then_func, else_func, name="cond"):
         input_id_to_loc = {}    # Dict[int, int], given id(sym), input_id_to_loc maps it
                                 # to a `loc`, where inputs[loc] = sym
         for graph in graphs:
-            # input_syms: all inputs to the `graph`
-            name_to_input_syms = {sym.name: sym for sym in _get_graph_inputs(graph)}
             # some input_vars are inputs to `graph`, some are not
             name_to_input_vars = {sym.name: sym for sym in inputs}
             # other inputs to `graph` created by cut_graph
             name_to_cut_g_syms = {sym.list_outputs()[0]: sym for sym in _cut_subgraph(graph)}
+            # input_syms: all inputs to the `graph`
+            name_to_input_syms = {sym.name: sym for sym in _get_graph_inputs(graph)}
             # collect arguments for each subgraph
             input_locs = []                         # results from the second step
             for name in graph.list_inputs():
@@ -696,5 +696,4 @@ def cond(pred, then_func, else_func, name="cond"):
         else_input_locs=else_input_locs,
         num_outputs=then_num_outputs
     )
-    result = _to_symbol_tuple(result, "result")
-    return list(result)
+    return [result[i] for i in range(then_num_outputs)]