You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/06/20 02:18:47 UTC

[GitHub] zheng-da commented on a change in pull request #11251: [WIP] Graph partitioner and subgraph op

zheng-da commented on a change in pull request #11251: [WIP] Graph partitioner and subgraph op
URL: https://github.com/apache/incubator-mxnet/pull/11251#discussion_r196630254
 
 

 ##########
 File path: src/operator/subgraph/partition_graph.cc
 ##########
 @@ -0,0 +1,688 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2018 by Contributors
+ * \file partition_graph.cc
+ * \brief
+ */
+#include <nnvm/graph.h>
+#include <nnvm/pass.h>
+#include <mxnet/op_attr_types.h>
+#include <algorithm>
+#include <queue>
+#include <stack>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "./default_subgraph_op.h"
+#include "./common.h"
+
+// Forward declaration of a helper whose definition lives in nnvm rather than in
+// this file. NOTE(review): presumably it is not exported through a public nnvm
+// header, which is why it is declared by hand here; confirm the signature still
+// matches the nnvm definition.
+namespace nnvm {
+NodePtr CreateVariableNode(const std::string& name);
+}
+
+namespace mxnet {
+
+namespace op {
+
+// Shorthands for the nnvm graph types used throughout this file.
+using nnvm::Symbol;
+using nnvm::Node;
+using nnvm::NodePtr;
+using nnvm::NodeEntry;
+using nnvm::Graph;
+
+// Guards debug-only helpers such as PrintSubgraph (compiled under `#if SUBGRAPH_DEBUG`).
+// TODO(junwu): Change this to 0
+#define SUBGRAPH_DEBUG 1
+
+namespace sg {  // sg stands for subgraph
+
+#if SUBGRAPH_DEBUG
+/*! \brief Log the names of all nodes of a subgraph, each followed by a space. */
+void PrintSubgraph(const std::vector<SimpleNode*>& simple_nodes) {
+  std::string op_names;
+  for (const SimpleNode* sn : simple_nodes) {
+    op_names += sn->node->attrs.name;
+    op_names += ' ';
+  }
+  LOG(INFO) << "Subgraph node names: " << op_names;
+}
+
+/*! \brief Log the node name, index and version of a single NodeEntry. */
+void PrintNodeEntry(const nnvm::NodeEntry& entry) {
+  std::string ret = "NodeEntry: node_name=";
+  ret += entry.node->attrs.name;
+  ret += ", index=";
+  ret += std::to_string(entry.index);
+  ret += ", version=";
+  ret += std::to_string(entry.version);
+  LOG(INFO) << ret;
+}
+
+/*! \brief Log every entry of `entries`, one line each, via PrintNodeEntry. */
+void PrintNodeEntries(const std::vector<nnvm::NodeEntry*>& entries) {
+  for (const nnvm::NodeEntry* e : entries) {
+    PrintNodeEntry(*e);
+  }
+}
+#endif
+
+/*!
+ * \brief Build one SimpleNode per node of g so that the graph can be walked
+ * along both input and output edges. Each SimpleNode's `outputs` maps a
+ * consumer node to the list of that consumer's input slots reading from it.
+ * \param g the MXNet computational graph
+ * \param simple_nodes output: the created nodes, appended in DFSVisit order,
+ *        which this code requires to coincide with g's indexed-graph node ids
+ */
+void CreateSimpleGraph(const Graph& g,
+                       std::vector<SimpleNodePtr>* simple_nodes) {
+  const auto& indexed_graph = g.indexed_graph();
+  simple_nodes->reserve(indexed_graph.num_nodes());
+  DFSVisit(g.outputs, [&](const NodePtr& node) {
+    SimpleNodePtr snode = SimpleNode::Create();
+    snode->node = node.get();
+    const auto& node_inputs = snode->node->inputs;
+    for (size_t slot = 0; slot < node_inputs.size(); ++slot) {
+      const auto producer_nid = indexed_graph.node_id(node_inputs[slot].node.get());
+      // The producer must already have been visited (and appended) by the DFS.
+      CHECK_LT(producer_nid, simple_nodes->size());
+      auto& consumers = (*simple_nodes)[producer_nid]->outputs;
+      auto found = consumers.find(snode->node);
+      if (found != consumers.end()) {
+        found->second.push_back(slot);
+      } else {
+        consumers.emplace(snode->node, std::vector<size_t>{slot});
+      }
+    }
+    simple_nodes->emplace_back(std::move(snode));
+  });
+}
+
+/*!
+ * \brief Set the label of every node listed in subgraph_nodes back to -1
+ * (the "not yet visited" value tested by LabelSubgraph), then empty the list.
+ * \param g graph whose indexed graph maps nodes to positions in simple_nodes
+ * \param simple_nodes all simple nodes, indexed by node id
+ * \param subgraph_nodes nodes whose labels are reset; cleared on return
+ */
+void ResetNodeLabels(const nnvm::Graph& g,
+                     const std::vector<SimpleNodePtr>& simple_nodes,
+                     std::vector<nnvm::Node*>* subgraph_nodes) {
+  const auto& idx = g.indexed_graph();
+  for (nnvm::Node* member : *subgraph_nodes) {
+    simple_nodes[idx.node_id(member)]->label = -1;
+  }
+  subgraph_nodes->clear();
+}
+
+/*!
+ * \brief Starting from a seed node, traverse the graph breadth-first along both
+ * input and output edges. Every neighbor accepted by subgraph_selector (and not
+ * in excluded_nodes) whose label is still -1 gets `label` and is collected in
+ * subgraph_nodes. Afterwards, check whether some node outside the candidate
+ * subgraph both consumes from and feeds into the subgraph, i.e. there is a loop
+ * between the subgraph and that outside node. If so, the subgraph node with the
+ * largest node id involved in any such loop is added to excluded_nodes (when
+ * excluded_nodes is non-null), the labels of the nodes in subgraph_nodes are
+ * reset to -1, subgraph_nodes is cleared, and false is returned. Otherwise
+ * subgraph_nodes is sorted by node id and true is returned.
+ * \param g the whole graph
+ * \param subgraph_selector determines whether a visited neighbor should be chosen or not
+ * \param label the label of the current subgraph
+ * \param snid node id of the seed simple node
+ * \param simple_nodes all simple nodes in the top sorted order (indexed by node id)
+ * \param subgraph_nodes output: all the nodes belonging to the same subgraph as the seed
+ * \param excluded_nodes nodes that must not join the current subgraph; may be null
+ * \return true if no loop was found, false if a node was chosen for exclusion
+ */
+bool LabelSubgraph(const Graph& g,
+                   SubgraphSelectorPtr subgraph_selector,
+                   const int label,
+                   const size_t snid,  // simple node id, this is a seed
+                   const std::vector<SimpleNodePtr>& simple_nodes,
+                   std::vector<nnvm::Node*>* subgraph_nodes,
+                   std::unordered_set<const nnvm::Node*>* excluded_nodes = nullptr) {
+  const auto& indexed_graph = g.indexed_graph();
+  auto is_excluded = [&](const nnvm::Node* n) {
+    return excluded_nodes != nullptr && excluded_nodes->count(n) > 0;
+  };
+  std::queue<SimpleNode*> node_queue;
+  if (!is_excluded(simple_nodes[snid]->node)) {
+    CHECK_EQ(simple_nodes[snid]->label, -1);
+    simple_nodes[snid]->label = label;
+    node_queue.push(simple_nodes[snid].get());
+  }
+  // key: nodes that serve as input/output nodes to the subgraph
+  // value: pair of vectors of nodes in the subgraph. The first vector contains the
+  // output nodes of the key in the subgraph, and the second vector contains the
+  // input nodes of the key in the subgraph. If both vectors are non-empty,
+  // it means there is a loop between the subgraph and the key node.
+  // When breaking the loop, we want to start removing the node with the largest node id.
+  std::unordered_map<const nnvm::Node*,
+    std::pair<std::vector<const nnvm::Node*>,
+              std::vector<const nnvm::Node*>>> non_subgraph_node_map;
+  // NOTE(review): a selected neighbor whose label is not -1 is either already in
+  // this subgraph or was labeled earlier; in the latter case it is neither
+  // enqueued nor recorded in non_subgraph_node_map, so it takes no part in the
+  // loop check below.
+  while (!node_queue.empty()) {
+    SimpleNode* cur_node = node_queue.front();
+    node_queue.pop();
+    // cur_node's label was already assigned when it was enqueued.
+    subgraph_nodes->push_back(cur_node->node);
+    // get qualified adjacent input nodes
+    for (auto& e : cur_node->node->inputs) {
+      const bool select_input = !is_excluded(e.node.get())
+        && subgraph_selector->SelectInput(*cur_node->node, *e.node);
+      if (select_input) {
+        // e.node is a subgraph node
+        const auto nid = indexed_graph.node_id(e.node.get());
+        CHECK_LT(nid, simple_nodes.size());
+        // this node has not been visited yet
+        if (simple_nodes[nid]->label == -1) {
+          simple_nodes[nid]->label = label;
+          node_queue.push(simple_nodes[nid].get());
+        }
+      } else {
+        // e.node is an input node of the subgraph
+        non_subgraph_node_map[e.node.get()].first.push_back(cur_node->node);
+      }
+    }
+    // get qualified output nodes
+    for (auto it = cur_node->outputs.begin(); it != cur_node->outputs.end(); ++it) {
+      const bool select_output = !is_excluded(it->first)
+          && subgraph_selector->SelectOutput(*cur_node->node, *it->first);
+      if (select_output) {
+        // it->first is a subgraph node
+        const auto nid = indexed_graph.node_id(it->first);
+        CHECK_LT(nid, simple_nodes.size());
+        // this node has not been visited yet
+        if (simple_nodes[nid]->label == -1) {
+          simple_nodes[nid]->label = label;
+          node_queue.push(simple_nodes[nid].get());
+        }
+      } else {
+        // it->first is an output node of the subgraph
+        non_subgraph_node_map[it->first].second.push_back(cur_node->node);
+      }
+    }
+  }
+  auto node_cmp = [&] (const nnvm::Node* node1, const nnvm::Node* node2) {
+    return indexed_graph.node_id(node1) < indexed_graph.node_id(node2);
+  };
+  // check whether there is a loop between the subgraph and its input/output nodes
+  int excluded_node_id = -1;
+  for (const auto& kv : non_subgraph_node_map) {
+    const auto& output_nodes = kv.second.first;
+    const auto& input_nodes = kv.second.second;
+    if (!output_nodes.empty() && !input_nodes.empty()) {
+      // there is a loop between kv.first and the subgraph; only the largest node
+      // id on each side matters, so a linear max scan is enough (no full sort)
+      const auto max_output = *std::max_element(output_nodes.begin(), output_nodes.end(),
+                                                node_cmp);
+      const auto max_input = *std::max_element(input_nodes.begin(), input_nodes.end(),
+                                               node_cmp);
+      const auto node_id = std::max(indexed_graph.node_id(max_output),
+                                    indexed_graph.node_id(max_input));
+      excluded_node_id = std::max(excluded_node_id, static_cast<int>(node_id));
+    }
+  }
+  if (excluded_node_id != -1) {
+    CHECK_LT(excluded_node_id, static_cast<int>(simple_nodes.size()));
+    CHECK_NE(excluded_node_id, static_cast<int>(snid))
+      << "A cycle is found in the computational graph between nodes "
+      << simple_nodes[excluded_node_id]->node->attrs.name << " and "
+      << simple_nodes[snid]->node->attrs.name;
+    // excluded_nodes defaults to nullptr; only record the node when the caller
+    // actually tracks exclusions instead of dereferencing a null pointer.
+    if (excluded_nodes != nullptr) {
+      excluded_nodes->insert(simple_nodes[excluded_node_id]->node);
+    }
+    ResetNodeLabels(g, simple_nodes, subgraph_nodes);
+    return false;
+  }
+  std::sort(subgraph_nodes->begin(), subgraph_nodes->end(), node_cmp);
+  return true;
+}
+
+/*!
+ * \brief Finds all the nodes belonging to the same subgraph given a seed node.
+ * LabelSubgraph is called repeatedly. Each time it reports a loop between the
+ * candidate subgraph and an outside node, the node it picked is kept in a local
+ * exclusion set, and the search is retried. If no loop-free subgraph is found
+ * within simple_nodes.size() squared attempts, the seed node alone becomes the
+ * subgraph.
+ * \param g the whole graph
+ * \param subgraph_selector determines whether the visited node should be chosen or not
+ * \param label the label of the current subgraph
+ * \param snid node id of the seed simple node
+ * \param simple_nodes all simple nodes in the top sorted order
+ * \param subgraph_nodes output: subgraph node candidates sorted by node id
+ *        (the topological order of the indexed graph); expected empty on entry
+ */
+void PreSelectSubgraphNodes(const Graph& g,
+                            SubgraphSelectorPtr subgraph_selector,
+                            const int label,
+                            const size_t snid,
+                            const std::vector<SimpleNodePtr>& simple_nodes,
+                            std::vector<nnvm::Node*>* subgraph_nodes) {
+  std::unordered_set<const nnvm::Node*> excluded_nodes;
+  const size_t max_num_retry = simple_nodes.size() * simple_nodes.size();
+  size_t count = 0;
+  bool success = false;
+  while (!success && count < max_num_retry) {
+    success = LabelSubgraph(g, subgraph_selector, label, snid, simple_nodes,
+                            subgraph_nodes, &excluded_nodes);
+    if (!success) {
+      CHECK(!excluded_nodes.empty());
+      std::string excluded_node_names;
+      for (auto node : excluded_nodes) {
+        excluded_node_names += node->attrs.name;
+        excluded_node_names += ", ";
+      }
+      LOG(INFO) << "Found a cycle when BFS from node " << simple_nodes[snid]->node->attrs.name
+                << ". Excluding nodes " << excluded_node_names << "and retrying";
+    }
+    ++count;
+  }
+  if (!success) {
+    // A failed LabelSubgraph call has already reset labels and cleared subgraph_nodes.
+    LOG(INFO) << "Tried " << count << " times of finding subgraphs starting from node "
+              << simple_nodes[snid]->node->attrs.name << " without success because a loop "
+                 "is always found between the subgraph and some other nodes. Will treat "
+                 "seed node " << simple_nodes[snid]->node->attrs.name
+              << " as a subgraph with one node";
+    CHECK(subgraph_nodes->empty());
+    simple_nodes[snid]->label = label;
+    subgraph_nodes->push_back(simple_nodes[snid]->node);
+  }
+}
+
+/*!
+ * \brief Split a set of candidate nodes into connected groups. Connectivity
+ * follows both input and output edges, restricted to the candidates. Each group
+ * is appended as a new subgraph: all of its nodes are labeled with the current
+ * *subgraph_id, the id is advanced once per group, and the group is sorted by
+ * node id.
+ * \param g the whole graph
+ * \param nodes candidate nodes to be grouped
+ * \param simple_nodes all simple nodes, indexed by node id
+ * \param subgraphs output: one vector of simple nodes appended per group
+ * \param subgraph_id in/out: id of the next group; incremented for every group
+ */
+void PostProcessNodeCandidates(const nnvm::Graph& g,
+                               const std::vector<nnvm::Node*>& nodes,
+                               const std::vector<SimpleNodePtr>& simple_nodes,
+                               std::vector<std::vector<SimpleNode*>>* subgraphs,
+                               size_t* subgraph_id) {
+  const auto& indexed_graph = g.indexed_graph();
+  // Candidates not yet assigned to any group.
+  std::unordered_set<nnvm::Node*> unassigned(nodes.begin(), nodes.end());
+  auto by_node_id = [&] (const SimpleNode* lhs, const SimpleNode* rhs) {
+    return indexed_graph.node_id(lhs->node) < indexed_graph.node_id(rhs->node);
+  };
+  for (nnvm::Node* seed : nodes) {
+    if (unassigned.count(seed) == 0) {
+      continue;  // already absorbed into an earlier group
+    }
+    CHECK_EQ(unassigned.erase(seed), 1U);
+    subgraphs->emplace_back();
+    std::vector<SimpleNode*>& group = subgraphs->back();
+    std::queue<nnvm::Node*> frontier;
+    // Label `member`, add it to the current group and schedule it for expansion.
+    auto adopt = [&](nnvm::Node* member) {
+      const auto member_nid = indexed_graph.node_id(member);
+      simple_nodes[member_nid]->label = *subgraph_id;
+      group.push_back(simple_nodes[member_nid].get());
+      frontier.push(member);
+    };
+    adopt(seed);
+    while (!frontier.empty()) {
+      nnvm::Node* cur = frontier.front();
+      frontier.pop();
+      for (const auto& e : cur->inputs) {
+        if (unassigned.erase(e.node.get()) != 0) {
+          adopt(e.node.get());
+        }
+      }
+      const SimpleNode* cur_snode = simple_nodes[indexed_graph.node_id(cur)].get();
+      for (const auto& kv : cur_snode->outputs) {
+        if (unassigned.erase(kv.first) != 0) {
+          adopt(kv.first);
+        }
+      }
+    }
+    ++(*subgraph_id);
+    std::sort(group.begin(), group.end(), by_node_id);
+  }
+  CHECK(unassigned.empty());
+}
+
+/*!
+ * \brief Finds subgraphs with all nodes that meet certain criteria.
+ * All nodes in a subgraph are marked with the same label.
+ */
+void FindSubgraphs(Graph* g,
+                   const SubgraphProperty &subg_prop,
+                   const std::vector<SimpleNodePtr>& simple_nodes,
+                   std::vector<std::vector<SimpleNode*>>* subgraph_nodes) {
+  //CHECK(simple_nodes != nullptr);
+  const auto& indexed_graph = g->indexed_graph();
+  CHECK_EQ(indexed_graph.num_nodes(), simple_nodes.size());
+  auto node_cmp = [&] (const nnvm::Node* node1, const nnvm::Node* node2) {
+    return indexed_graph.node_id(node1) < indexed_graph.node_id(node2);
+  };
+  size_t subgraph_id = 0;
+  for (size_t i = 0; i < simple_nodes.size(); ++i) {
+    nnvm::Node* node = simple_nodes[i]->node;
+    auto subgraph_selector = subg_prop.CreateSubgraphSelector();
+    if (subgraph_selector->Select(*node) && simple_nodes[i]->label == -1) {
+      // pre-select nodes that can be grouped in a subgraph
+      std::vector<nnvm::Node*> preselected_nodes;
+      PreSelectSubgraphNodes(*g, subgraph_selector, subgraph_id, i, simple_nodes,
+                             &preselected_nodes);
+
+      // filter out unqualified pre-selected nodes
+      std::vector<nnvm::Node*> filtered_nodes = preselected_nodes;
+      subgraph_selector->Filter(g, &filtered_nodes);
 
 Review comment:
   Should we have Filter to return a vector of Nodes that are selected?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services