You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/08/03 22:48:11 UTC
[incubator-mxnet] branch master updated: Bug fixes in control flow operators (#11942)
This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 22c97ef Bug fixes in control flow operators (#11942)
22c97ef is described below
commit 22c97efc418463befc12881892246f999a85b971
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Fri Aug 3 15:48:02 2018 -0700
Bug fixes in control flow operators (#11942)
---
python/mxnet/symbol/contrib.py | 11 +++++------
1 file changed, 5 insertions(+), 6 deletions(-)
diff --git a/python/mxnet/symbol/contrib.py b/python/mxnet/symbol/contrib.py
index 8842883..1d42cf7 100644
--- a/python/mxnet/symbol/contrib.py
+++ b/python/mxnet/symbol/contrib.py
@@ -486,12 +486,12 @@ def while_loop(cond, func, loop_vars, max_iterations=None, name="while_loop"):
input_id_to_loc = {} # Dict[int, int], given id(sym), input_id_to_loc maps it
# to a `loc`, where inputs[loc] = sym
for graph in graphs:
- # input_syms: all inputs to the `graph`
- name_to_input_syms = {sym.name: sym for sym in _get_graph_inputs(graph)}
# some loop_vars are inputs to `graph`, some are not
name_to_loop_vars = {sym.name: sym for sym in loop_vars}
# other inputs to `graph` created by cut_graph
name_to_cut_g_syms = {sym.list_outputs()[0]: sym for sym in _cut_subgraph(graph)}
+ # input_syms: all inputs to the `graph`
+ name_to_input_syms = {sym.name: sym for sym in _get_graph_inputs(graph)}
# also we collect the mapping from var's name to var's loc in loop_vars
name_to_var_locs = {sym.name: i for i, sym in enumerate(loop_vars)}
# collect arguments for each subgraph
@@ -644,12 +644,12 @@ def cond(pred, then_func, else_func, name="cond"):
input_id_to_loc = {} # Dict[int, int], given id(sym), input_id_to_loc maps it
# to a `loc`, where inputs[loc] = sym
for graph in graphs:
- # input_syms: all inputs to the `graph`
- name_to_input_syms = {sym.name: sym for sym in _get_graph_inputs(graph)}
# some input_vars are inputs to `graph`, some are not
name_to_input_vars = {sym.name: sym for sym in inputs}
# other inputs to `graph` created by cut_graph
name_to_cut_g_syms = {sym.list_outputs()[0]: sym for sym in _cut_subgraph(graph)}
+ # input_syms: all inputs to the `graph`
+ name_to_input_syms = {sym.name: sym for sym in _get_graph_inputs(graph)}
# collect arguments for each subgraph
input_locs = [] # results from the second step
for name in graph.list_inputs():
@@ -696,5 +696,4 @@ def cond(pred, then_func, else_func, name="cond"):
else_input_locs=else_input_locs,
num_outputs=then_num_outputs
)
- result = _to_symbol_tuple(result, "result")
- return list(result)
+ return [result[i] for i in range(then_num_outputs)]