You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/08/03 22:48:03 UTC

[GitHub] szha closed pull request #11942: [MXNET-749] Bug fixes in control flow operators

szha closed pull request #11942: [MXNET-749] Bug fixes in control flow operators
URL: https://github.com/apache/incubator-mxnet/pull/11942
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below, since GitHub would not otherwise display it:

diff --git a/python/mxnet/symbol/contrib.py b/python/mxnet/symbol/contrib.py
index 884288364b3..1d42cf7c18f 100644
--- a/python/mxnet/symbol/contrib.py
+++ b/python/mxnet/symbol/contrib.py
@@ -486,12 +486,12 @@ def _union_inputs(*graphs):
         input_id_to_loc = {}    # Dict[int, int], given id(sym), input_id_to_loc maps it
                                 # to a `loc`, where inputs[loc] = sym
         for graph in graphs:
-            # input_syms: all inputs to the `graph`
-            name_to_input_syms = {sym.name: sym for sym in _get_graph_inputs(graph)}
             # some loop_vars are inputs to `graph`, some are not
             name_to_loop_vars = {sym.name: sym for sym in loop_vars}
             # other inputs to `graph` created by cut_graph
             name_to_cut_g_syms = {sym.list_outputs()[0]: sym for sym in _cut_subgraph(graph)}
+            # input_syms: all inputs to the `graph`
+            name_to_input_syms = {sym.name: sym for sym in _get_graph_inputs(graph)}
             # also we collect the mapping from var's name to var's loc in loop_vars
             name_to_var_locs = {sym.name: i for i, sym in enumerate(loop_vars)}
             # collect arguments for each subgraph
@@ -644,12 +644,12 @@ def _union_inputs(*graphs):
         input_id_to_loc = {}    # Dict[int, int], given id(sym), input_id_to_loc maps it
                                 # to a `loc`, where inputs[loc] = sym
         for graph in graphs:
-            # input_syms: all inputs to the `graph`
-            name_to_input_syms = {sym.name: sym for sym in _get_graph_inputs(graph)}
             # some input_vars are inputs to `graph`, some are not
             name_to_input_vars = {sym.name: sym for sym in inputs}
             # other inputs to `graph` created by cut_graph
             name_to_cut_g_syms = {sym.list_outputs()[0]: sym for sym in _cut_subgraph(graph)}
+            # input_syms: all inputs to the `graph`
+            name_to_input_syms = {sym.name: sym for sym in _get_graph_inputs(graph)}
             # collect arguments for each subgraph
             input_locs = []                         # results from the second step
             for name in graph.list_inputs():
@@ -696,5 +696,4 @@ def _union_inputs(*graphs):
         else_input_locs=else_input_locs,
         num_outputs=then_num_outputs
     )
-    result = _to_symbol_tuple(result, "result")
-    return list(result)
+    return [result[i] for i in range(then_num_outputs)]


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to this message, please log in to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services