You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by re...@apache.org on 2018/08/04 04:41:44 UTC
[incubator-mxnet] 04/04: [DO NOT REVIEW] Fix bug of eliminating
cycles (#11907)
This is an automated email from the ASF dual-hosted git repository.
reminisce pushed a commit to branch subgraph
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
commit 6acf6cce1c64ff0a11f91f4e9b9e3588211e4e12
Author: reminisce <wu...@gmail.com>
AuthorDate: Thu Jul 26 21:01:32 2018 -0700
[DO NOT REVIEW] Fix bug of eliminating cycles (#11907)
* Fix cycle bug
* Fix decycle bug
* Fix comment
---
src/operator/subgraph/partition_graph.cc | 71 +++++++++++++++++++++++++++----
tests/python/unittest/test_subgraph_op.py | 11 +++++
2 files changed, 74 insertions(+), 8 deletions(-)
diff --git a/src/operator/subgraph/partition_graph.cc b/src/operator/subgraph/partition_graph.cc
index 0546430..9672877 100644
--- a/src/operator/subgraph/partition_graph.cc
+++ b/src/operator/subgraph/partition_graph.cc
@@ -147,9 +147,12 @@ bool LabelSubgraph(const Graph& g,
// key: nodes that serve as input/output nodes to the subgraph
// value: pair of vectors of nodes in the subgraph. The first vector contains the
// output nodes of the key in the subgraph, and the second vector contains the
- // input ndoes of the key in the subgraph. If both vectors are non-empty,
- // it means there is a loop between the subgraph and the key node.
- // When breaking the loop, we want to start removing the node with the largest node id.
+ // input nodes of the key in the subgraph.
+ // If one non-subgraph node has inputs from the subgraph and another non-subgraph node
+ // has outputs to the subgraph, and the first non-subgraph node is an ancestor
+ // of the second non-subgraph node, there exists a cycle.
+ // When breaking the cycle, we want to start by removing the node with the largest node id
+ // in the subgraph.
std::unordered_map<const nnvm::Node*,
std::pair<std::vector<const nnvm::Node*>,
std::vector<const nnvm::Node*>>> non_subgraph_node_map;
@@ -194,23 +197,75 @@ bool LabelSubgraph(const Graph& g,
}
}
}
+ // prepare to check if there is a cycle
auto node_cmp = [&] (const nnvm::Node* node1, const nnvm::Node* node2) {
return indexed_graph.node_id(node1) < indexed_graph.node_id(node2);
};
- // check whether there is a loop between the subgraph and its input/output nodes
- int excluded_node_id = -1;
+ std::vector<const nnvm::Node*> non_subgraph_nodes;
+ non_subgraph_nodes.reserve(non_subgraph_node_map.size());
for (auto& kv : non_subgraph_node_map) {
auto& output_nodes = kv.second.first;
+ std::sort(output_nodes.begin(), output_nodes.end(), node_cmp);
auto& input_nodes = kv.second.second;
+ std::sort(input_nodes.begin(), input_nodes.end(), node_cmp);
+ non_subgraph_nodes.push_back(kv.first);
+ }
+ // check whether there is a cycle between the subgraph and its input/output nodes
+ auto is_ancestor = [&](const nnvm::Node* ancestor, const nnvm::Node* descendant,
+ const std::vector<nnvm::Node*>& snodes) {
+ if (ancestor == descendant) return true;
+ std::stack<const nnvm::Node*> s;
+ s.push(descendant);
+ size_t count = 0;
+ while (!s.empty()) {
+ CHECK_LT(count, indexed_graph.num_nodes()) << "Finding ancestor failed. There is probably"
+ " a loop in the graph";
+ ++count;
+ const nnvm::Node* top = s.top();
+ s.pop();
+ if (top == ancestor) {
+ return true;
+ }
+ for (const auto& entry : top->inputs) {
+ // when searching for the ancestor, the path cannot cross any subgraph node
+ auto it = std::find(snodes.begin(), snodes.end(), entry.node.get());
+ if (it == snodes.end()) {
+ s.push(entry.node.get());
+ }
+ }
+ }
+ return false;
+ };
+ std::sort(non_subgraph_nodes.begin(), non_subgraph_nodes.end(), node_cmp);
+ int excluded_node_id = -1;
+ for (size_t i = 0; i < non_subgraph_nodes.size(); ++i) {
+ auto it1 = non_subgraph_node_map.find(non_subgraph_nodes[i]);
+ CHECK(it1 != non_subgraph_node_map.end());
+ auto& output_nodes = it1->second.first; // has been top sorted
+ auto& input_nodes = it1->second.second; // has been top sorted
if (!output_nodes.empty() && !input_nodes.empty()) {
- // there is a loop between kv->first and the subgraph
- std::sort(output_nodes.begin(), output_nodes.end(), node_cmp);
- std::sort(input_nodes.begin(), input_nodes.end(), node_cmp);
+ // there is a loop between node i and the subgraph
const auto node_id = std::max(indexed_graph.node_id(output_nodes.back()),
indexed_graph.node_id(input_nodes.back()));
excluded_node_id = std::max(excluded_node_id, static_cast<int>(node_id));
+ } else if (!input_nodes.empty()) {
+ // node i takes inputs from the subgraph, find out if there is a node j
+ // which has outputs to the subgraph and is also a descendant of node i.
+ for (size_t j = i + 1; j < non_subgraph_nodes.size(); ++j) {
+ auto it2 = non_subgraph_node_map.find(non_subgraph_nodes[j]);
+ CHECK(it2 != non_subgraph_node_map.end());
+ // i is topologically before j, j might be a direct/indirect output node of i
+ CHECK_LT(indexed_graph.node_id(it1->first), indexed_graph.node_id(it2->first));
+ if (!it2->second.first.empty() && is_ancestor(it1->first, it2->first, *subgraph_nodes)) {
+ // found a loop
+ const auto node_id = std::max(indexed_graph.node_id(input_nodes.back()),
+ indexed_graph.node_id(it2->second.first.back()));
+ excluded_node_id = std::max(excluded_node_id, static_cast<int>(node_id));
+ }
+ }
}
}
+
if (excluded_node_id != -1) {
CHECK_LT(excluded_node_id, static_cast<int>(simple_nodes.size()));
CHECK_NE(excluded_node_id, static_cast<int>(snid))
diff --git a/tests/python/unittest/test_subgraph_op.py b/tests/python/unittest/test_subgraph_op.py
index c3b408b..f6a33c2 100644
--- a/tests/python/unittest/test_subgraph_op.py
+++ b/tests/python/unittest/test_subgraph_op.py
@@ -114,12 +114,23 @@ def test_subgraph_exe():
for sym, op_names in get_graph():
check_subgraph_exe(sym, op_names)
+ def test_network_structure_7():
+ # in this graph, the subgraph node and the other two external nodes form a cycle
+ data = mx.sym.Variable('data', shape=(1,))
+ ret1 = mx.sym.sin(data)
+ ret2 = mx.sym.cos(ret1)
+ for _ in range(5):
+ ret2 = mx.sym.cos(ret2)
+ ret = ret1 + ret2
+ check_subgraph_exe(ret, ['sin', 'elemwise_add', '_plus', '_Plus'])
+
test_network_structure_1()
test_network_structure_2()
test_network_structure_3()
test_network_structure_4()
test_network_structure_5()
test_network_structure_6()
+ test_network_structure_7()
if __name__ == '__main__':