You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by re...@apache.org on 2018/08/04 04:41:42 UTC
[incubator-mxnet] 02/04: Fix duplicate entry bugs (#11767)
This is an automated email from the ASF dual-hosted git repository.
reminisce pushed a commit to branch subgraph
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
commit 6393058a0a053538a92cd8402bec62ed27b6d581
Author: reminisce <wu...@gmail.com>
AuthorDate: Tue Jul 17 22:57:13 2018 -0700
Fix duplicate entry bugs (#11767)
---
src/operator/subgraph/partition_graph.cc | 6 +++++-
tests/python/unittest/test_subgraph_op.py | 8 ++++++++
2 files changed, 13 insertions(+), 1 deletion(-)
diff --git a/src/operator/subgraph/partition_graph.cc b/src/operator/subgraph/partition_graph.cc
index 11af49a..c935c8f 100644
--- a/src/operator/subgraph/partition_graph.cc
+++ b/src/operator/subgraph/partition_graph.cc
@@ -622,6 +622,10 @@ void TopSortEntries(const Graph& g,
if (visited.count(node) == 0U) {
s.emplace(node, 0U, &e);
visited.insert(node);
+ } else {
+ // The entry's source node has been visited before.
+ // Marking the order for it.
+ entry_top_order_map->emplace(&e, entry_top_order_map->size());
}
while (!s.empty()) {
auto& top = s.top();
@@ -641,7 +645,7 @@ void TopSortEntries(const Graph& g,
visited.insert(input_node);
} else {
// The entry's source node has been visited before.
- // Marking order for it.
+ // Marking the order for it.
entry_top_order_map->emplace(&entry, entry_top_order_map->size());
}
}
diff --git a/tests/python/unittest/test_subgraph_op.py b/tests/python/unittest/test_subgraph_op.py
index f08c42c..0e5c1e0 100644
--- a/tests/python/unittest/test_subgraph_op.py
+++ b/tests/python/unittest/test_subgraph_op.py
@@ -124,10 +124,18 @@ def test_input_name_order():
check_input_order(ret, ['exp', 'BatchNorm'])
check_input_order(ret, ['BatchNorm'])
+ def test_network_structure_5():
+ # the last op has multiple duplicate outputs
+ data = mx.sym.var('data')
+ ret = mx.sym.exp(data)
+ ret = mx.sym.Group([ret, ret, ret])
+ check_input_order(ret, ['exp'])
+
test_network_structure_1()
test_network_structure_2()
test_network_structure_3()
test_network_structure_4()
+ test_network_structure_5()
if __name__ == '__main__':