You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by re...@apache.org on 2018/08/04 04:41:44 UTC

[incubator-mxnet] 04/04: [DO NOT REVIEW] Fix bug of eliminating cycles (#11907)

This is an automated email from the ASF dual-hosted git repository.

reminisce pushed a commit to branch subgraph
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit 6acf6cce1c64ff0a11f91f4e9b9e3588211e4e12
Author: reminisce <wu...@gmail.com>
AuthorDate: Thu Jul 26 21:01:32 2018 -0700

    [DO NOT REVIEW] Fix bug of eliminating cycles (#11907)
    
    * Fix cycle bug
    
    * Fix decycle bug
    
    * Fix comment
---
 src/operator/subgraph/partition_graph.cc  | 71 +++++++++++++++++++++++++++----
 tests/python/unittest/test_subgraph_op.py | 11 +++++
 2 files changed, 74 insertions(+), 8 deletions(-)

diff --git a/src/operator/subgraph/partition_graph.cc b/src/operator/subgraph/partition_graph.cc
index 0546430..9672877 100644
--- a/src/operator/subgraph/partition_graph.cc
+++ b/src/operator/subgraph/partition_graph.cc
@@ -147,9 +147,12 @@ bool LabelSubgraph(const Graph& g,
   // key: nodes that serve as input/output nodes to the subgraph
   // value: pair of vectors of nodes in the subgraph. The first vector contains the
   // output nodes of the key in the subgraph, and the second vector contains the
-  // input ndoes of the key in the subgraph. If both vectors are non-empty,
-  // it means there is a loop between the subgraph and the key node.
-  // When breaking the loop, we want to start removing the node with the largest node id.
+  // input nodes of the key in the subgraph.
+  // If a non-subgraph node has inputs from the subgraph and the other non-subgraph node
+  // has outputs to the subgraph, and the first non-subgraph node is an ancestor
+  // of the second non-subgraph node, there exists a cycle.
+  // When breaking the cycle, we want to start by removing the node with the largest node id
+  // in the subgraph.
   std::unordered_map<const nnvm::Node*,
     std::pair<std::vector<const nnvm::Node*>,
               std::vector<const nnvm::Node*>>> non_subgraph_node_map;
@@ -194,23 +197,75 @@ bool LabelSubgraph(const Graph& g,
       }
     }
   }
+  // prepare to check if there is a cycle
   auto node_cmp = [&] (const nnvm::Node* node1, const nnvm::Node* node2) {
     return indexed_graph.node_id(node1) < indexed_graph.node_id(node2);
   };
-  // check whether there is a loop between the subgraph and its input/output nodes
-  int excluded_node_id = -1;
+  std::vector<const nnvm::Node*> non_subgraph_nodes;
+  non_subgraph_nodes.reserve(non_subgraph_node_map.size());
   for (auto& kv : non_subgraph_node_map) {
     auto& output_nodes = kv.second.first;
+    std::sort(output_nodes.begin(), output_nodes.end(), node_cmp);
     auto& input_nodes = kv.second.second;
+    std::sort(input_nodes.begin(), input_nodes.end(), node_cmp);
+    non_subgraph_nodes.push_back(kv.first);
+  }
+  // check whether there is a cycle between the subgraph and its input/output nodes
+  auto is_ancestor = [&](const nnvm::Node* ancestor, const nnvm::Node* descendant,
+                         const std::vector<nnvm::Node*>& snodes) {
+    if (ancestor == descendant) return true;
+    std::stack<const nnvm::Node*> s;
+    s.push(descendant);
+    size_t count = 0;
+    while (!s.empty()) {
+      CHECK_LT(count, indexed_graph.num_nodes()) << "Finding ancestor failed. There is probably"
+                                                    " a loop in the graph";
+      ++count;
+      const nnvm::Node* top = s.top();
+      s.pop();
+      if (top == ancestor) {
+        return true;
+      }
+      for (const auto& entry : top->inputs) {
+        // when searching for the ancestor, the path cannot cross any subgraph node
+        auto it = std::find(snodes.begin(), snodes.end(), entry.node.get());
+        if (it == snodes.end()) {
+          s.push(entry.node.get());
+        }
+      }
+    }
+    return false;
+  };
+  std::sort(non_subgraph_nodes.begin(), non_subgraph_nodes.end(), node_cmp);
+  int excluded_node_id = -1;
+  for (size_t i = 0; i < non_subgraph_nodes.size(); ++i) {
+    auto it1 = non_subgraph_node_map.find(non_subgraph_nodes[i]);
+    CHECK(it1 != non_subgraph_node_map.end());
+    auto& output_nodes = it1->second.first;  // has been top sorted
+    auto& input_nodes = it1->second.second;  // has been top sorted
     if (!output_nodes.empty() && !input_nodes.empty()) {
-      // there is a loop between kv->first and the subgraph
-      std::sort(output_nodes.begin(), output_nodes.end(), node_cmp);
-      std::sort(input_nodes.begin(), input_nodes.end(), node_cmp);
+      // there is a loop between node i and the subgraph
       const auto node_id = std::max(indexed_graph.node_id(output_nodes.back()),
                                     indexed_graph.node_id(input_nodes.back()));
       excluded_node_id = std::max(excluded_node_id, static_cast<int>(node_id));
+    } else if (!input_nodes.empty()) {
+      // node i takes inputs from the subgraph; find out if there is a node j
+      // which feeds outputs into the subgraph and is also a descendant of node i.
+      for (size_t j = i + 1; j < non_subgraph_nodes.size(); ++j) {
+        auto it2 = non_subgraph_node_map.find(non_subgraph_nodes[j]);
+        CHECK(it2 != non_subgraph_node_map.end());
+        // i is topologically before j, j might be a direct/indirect output node of i
+        CHECK_LT(indexed_graph.node_id(it1->first), indexed_graph.node_id(it2->first));
+        if (!it2->second.first.empty() && is_ancestor(it1->first, it2->first, *subgraph_nodes)) {
+          // found a loop
+          const auto node_id = std::max(indexed_graph.node_id(input_nodes.back()),
+                                        indexed_graph.node_id(it2->second.first.back()));
+          excluded_node_id = std::max(excluded_node_id, static_cast<int>(node_id));
+        }
+      }
     }
   }
+
   if (excluded_node_id != -1) {
     CHECK_LT(excluded_node_id, static_cast<int>(simple_nodes.size()));
     CHECK_NE(excluded_node_id, static_cast<int>(snid))
diff --git a/tests/python/unittest/test_subgraph_op.py b/tests/python/unittest/test_subgraph_op.py
index c3b408b..f6a33c2 100644
--- a/tests/python/unittest/test_subgraph_op.py
+++ b/tests/python/unittest/test_subgraph_op.py
@@ -114,12 +114,23 @@ def test_subgraph_exe():
         for sym, op_names in get_graph():
             check_subgraph_exe(sym, op_names)
 
+    def test_network_structure_7():
+        # in this graph, the subgraph node and the other two external nodes form a cycle
+        data = mx.sym.Variable('data', shape=(1,))
+        ret1 = mx.sym.sin(data)
+        ret2 = mx.sym.cos(ret1)
+        for _ in range(5):
+            ret2 = mx.sym.cos(ret2)
+        ret = ret1 + ret2
+        check_subgraph_exe(ret, ['sin', 'elemwise_add', '_plus', '_Plus'])
+
     test_network_structure_1()
     test_network_structure_2()
     test_network_structure_3()
     test_network_structure_4()
     test_network_structure_5()
     test_network_structure_6()
+    test_network_structure_7()
 
 
 if __name__ == '__main__':