You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink is not preserved in this plain-text copy).
Posted to commits@mxnet.apache.org by re...@apache.org on 2018/08/04 04:41:42 UTC

[incubator-mxnet] 02/04: Fix duplicate entry bugs (#11767)

This is an automated email from the ASF dual-hosted git repository.

reminisce pushed a commit to branch subgraph
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit 6393058a0a053538a92cd8402bec62ed27b6d581
Author: reminisce <wu...@gmail.com>
AuthorDate: Tue Jul 17 22:57:13 2018 -0700

    Fix duplicate entry bugs (#11767)
---
 src/operator/subgraph/partition_graph.cc  | 6 +++++-
 tests/python/unittest/test_subgraph_op.py | 8 ++++++++
 2 files changed, 13 insertions(+), 1 deletion(-)

diff --git a/src/operator/subgraph/partition_graph.cc b/src/operator/subgraph/partition_graph.cc
index 11af49a..c935c8f 100644
--- a/src/operator/subgraph/partition_graph.cc
+++ b/src/operator/subgraph/partition_graph.cc
@@ -622,6 +622,10 @@ void TopSortEntries(const Graph& g,
     if (visited.count(node) == 0U) {
       s.emplace(node, 0U, &e);
       visited.insert(node);
+    } else {
+      // The entry's source node has been visited before.
+      // Marking the order for it.
+      entry_top_order_map->emplace(&e, entry_top_order_map->size());
     }
     while (!s.empty()) {
       auto& top = s.top();
@@ -641,7 +645,7 @@ void TopSortEntries(const Graph& g,
           visited.insert(input_node);
         } else {
           // The entry's source node has been visited before.
-          // Marking order for it.
+          // Marking the order for it.
           entry_top_order_map->emplace(&entry, entry_top_order_map->size());
         }
       }
diff --git a/tests/python/unittest/test_subgraph_op.py b/tests/python/unittest/test_subgraph_op.py
index f08c42c..0e5c1e0 100644
--- a/tests/python/unittest/test_subgraph_op.py
+++ b/tests/python/unittest/test_subgraph_op.py
@@ -124,10 +124,18 @@ def test_input_name_order():
         check_input_order(ret, ['exp', 'BatchNorm'])
         check_input_order(ret, ['BatchNorm'])
 
+    def test_network_structure_5():
+        # the last op has multiple duplicate outputs
+        data = mx.sym.var('data')
+        ret = mx.sym.exp(data)
+        ret = mx.sym.Group([ret, ret, ret])
+        check_input_order(ret, ['exp'])
+
     test_network_structure_1()
     test_network_structure_2()
     test_network_structure_3()
     test_network_structure_4()
+    test_network_structure_5()
 
 
 if __name__ == '__main__':