Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2020/10/01 23:32:19 UTC

[GitHub] [incubator-mxnet] ptrendx opened a new pull request #19269: Faster pointwise fusion graph pass

ptrendx opened a new pull request #19269:
URL: https://github.com/apache/incubator-mxnet/pull/19269


   ## Description ##
   This PR greatly speeds up the pointwise fusion graph pass. As a test case I used XLNet from Gluon-NLP, which previously exposed the slowness of this graph pass (#17105). The total time of fwd+bwd fusion after the fix for that issue was ~8 s; after this PR it is ~11 ms, with 3 ms of that taken by IndexedGraph construction from the original graph (so the actual fusion graph pass takes ~8 ms for that case - about a 1000x improvement).
   
   The motivation for this PR is to make the pointwise fusion graph pass lightweight enough that it could be run every time the shapes in the graph change, enabling the fusion to be shape-aware. This is important because in 2.0 NumPy semantics make multiple operations (like add, sub etc.) broadcast by default, even if in the end they are simple elementwise operations. This PR does not fully get us there (I would like the pass to take <5 ms for a network like XLNet, to be sure it is lightweight enough not to be a bottleneck for any use case), although some parts would no longer be needed if we could run this pass after infershape is already done (e.g. we would not need to insert `FusedOpHelper`/`FusedOpOutHelper`, which currently takes a little over 1 ms). That said, it is a big step in that direction.
   
   The main problem that the fusion graph pass needs to solve is preventing cycles from being formed: nodes which both consume the output of and provide the input to a single subgraph. The original pointwise fusion graph pass avoided cycles by first constructing a mapping that records, for a given node, which nodes must not be placed in the same subgraph with it. The construction of the mapping was simple but inefficient (it was improved in #17114, but remained pretty slow):
   ```
   for each node n that does not qualify to be in a subgraph:
       outputs = nodes reached by DFS from n in the direction of n's outputs
       inputs = nodes reached by DFS from n in the direction of n's inputs
       for each output in outputs:
           for each input in inputs:
               put (input, output) and (output, input) pair into the exclusion mapping
   ```
   This was an `O(n^3)` algorithm (improved in #17114 to be `O(n^2)` on average) and required a separate `BidirectionalGraph` data structure that enabled DFS in both directions.
   
   The new algorithm for pointwise fusion graph pass is designed based on 2 observations:
    - most of the entries in the exclusion mappings are not useful, as those nodes are not considered to be part of the same subgraph
    - we can traverse the graph in topological order, which the original algorithm did not take advantage of
   
   In the new algorithm the graph is traversed in topological order (so we are sure that all the inputs of the current node were already processed) and each node has its own exclusion set of subgraphs that it can't be a part of. The exclusion set of a node is constructed as the union of the exclusion sets of its inputs, plus, if the node itself is ineligible to be in a subgraph, all the subsets that its inputs are part of. Because subsets can merge (e.g. when you have an operation `a + b` where `a` is part of subgraph s1 and `b` is part of subgraph s2, then s1 and s2 need to be merged into a single subgraph containing everything from s1, s2 and the `+` operator), a mapping is maintained to know which other subsets the current subset is merged with.
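
The exclusion-set rule above can be illustrated with a minimal sketch. This is not the code from this PR: `Node`, `Result` and `AssignSubsets` are hypothetical names, exclusion sets are plain `std::set<int>` rather than shared interval vectors, and subset merging is left out entirely. Because of that last simplification, a fusable node here joins at most one input subset, and every input subset it did not join is added to its exclusion set (the "ineligible node" rule from the text is the special case where it joins none).

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <vector>

// Hypothetical, simplified node: indices of its input nodes plus a flag
// saying whether the operator may be placed in a fused subgraph.
struct Node {
  std::vector<int> inputs;
  bool fusable;
};

struct Result {
  std::vector<int> subset;              // subset id per node, -1 = not fused
  std::vector<std::set<int>> excluded;  // subset ids the node must not join
};

// Nodes must be given in topological order (inputs before consumers).
Result AssignSubsets(const std::vector<Node>& nodes) {
  Result r;
  r.subset.assign(nodes.size(), -1);
  r.excluded.resize(nodes.size());
  int next_subset = 0;
  for (std::size_t i = 0; i < nodes.size(); ++i) {
    std::set<int>& excl = r.excluded[i];
    // Start from the union of the exclusion sets of all inputs.
    for (int in : nodes[i].inputs) {
      assert(in >= 0 && static_cast<std::size_t>(in) < i);
      excl.insert(r.excluded[in].begin(), r.excluded[in].end());
    }
    if (nodes[i].fusable) {
      // Join the first input subset that is not excluded (no merging here).
      for (int in : nodes[i].inputs) {
        const int s = r.subset[in];
        if (s != -1 && excl.count(s) == 0) {
          r.subset[i] = s;
          break;
        }
      }
      // Otherwise start a new subset containing only this node.
      if (r.subset[i] == -1) r.subset[i] = next_subset++;
    }
    // Every input subset this node did not join is off limits downstream:
    // joining it later would route data out of the subset and back in.
    for (int in : nodes[i].inputs) {
      const int s = r.subset[in];
      if (s != -1 && s != r.subset[i]) excl.insert(s);
    }
  }
  return r;
}
```

For a graph `a -> c -> b` with an extra edge `a -> n` and `b -> n`, where only `c` is not fusable, `n` inherits the subset of `a` in its exclusion set through `b`, so it can only join the subset of `b`.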
   
   There are a number of additional optimizations:
    - because of the topological ordering of the graph traversal, subset ids in the exclusion set are typically consecutive numbers (or a small group of runs of consecutive numbers) -> therefore in the exclusion set we actually keep intervals of numbers instead of the numbers themselves
    - in most cases the union of exclusion sets is equal to one of the sets -> in order to avoid costly unnecessary memory allocations, we share the exclusion sets between nodes
    - the fwd and bwd fusions are done together in a single pass to avoid overheads
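
To make the interval representation from the first bullet concrete, here is a small sketch of a set of subset ids stored as sorted closed intervals. The alias `IntervalVec` mirrors the type name used in `simple_partition_pass.cc`, but its definition as a vector of `std::pair<int, int>` is an assumption, and `Contains` and `InsertId` are hypothetical helpers that are not part of this PR.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

using Interval = std::pair<int, int>;       // closed range [first, second]
using IntervalVec = std::vector<Interval>;  // sorted, disjoint, non-adjacent

// Returns true if id lies inside one of the intervals.
bool Contains(const IntervalVec& set, int id) {
  for (const Interval& iv : set) {
    if (id < iv.first) return false;  // sorted: later intervals start higher
    if (id <= iv.second) return true;
  }
  return false;
}

// Inserts id, growing or joining neighbouring intervals where possible.
void InsertId(IntervalVec* set, int id) {
  assert(set != nullptr);
  std::size_t i = 0;
  // Skip intervals that end too early to contain or touch id.
  while (i < set->size() && (*set)[i].second < id - 1) ++i;
  if (i == set->size() || id < (*set)[i].first - 1) {
    // No existing interval touches id: add a new single-element interval.
    set->insert(set->begin() + static_cast<std::ptrdiff_t>(i),
                Interval(id, id));
    return;
  }
  Interval& cur = (*set)[i];
  if (id < cur.first) cur.first = id;    // id directly precedes the interval
  if (id > cur.second) cur.second = id;  // id directly follows the interval
  // Growing to the right may have closed the gap to the next interval.
  if (i + 1 < set->size() && (*set)[i + 1].first <= cur.second + 1) {
    cur.second = (*set)[i + 1].second;
    set->erase(set->begin() + static_cast<std::ptrdiff_t>(i + 1));
  }
}
```

With consecutive ids the vector stays at one or two entries no matter how many ids are inserted, which is the property the bullet relies on.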
   
   The second part of the PR is the overhaul of the actual subgraph substitution - previously it was done 1 subgraph at a time, with multiple `DFSVisit` calls per subgraph. Unfortunately `DFSVisit` is costly and most of that work was wasted (as a DFS over the entire graph was needed for the substitution of a few nodes). In the new approach a new graph is created based on the subgraph assignment generated in the previous part, which requires only a single pass over the graph to apply all subgraphs.
   
   @samskalicky @Caenorst @mk-61
   
   
   ## Checklist ##
   ### Essentials ###
   - [ ] Changes are complete (i.e. I finished coding on this PR)
   - [x] All changes have test coverage
   - [x] Code is well-documented
   
   ## Comments ##
   - If this change is a backward incompatible change, why must this change be made.
   - Interesting edge cases to note here
   


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-mxnet] mxnet-bot commented on pull request #19269: Faster pointwise fusion graph pass

Posted by GitBox <gi...@apache.org>.
mxnet-bot commented on pull request #19269:
URL: https://github.com/apache/incubator-mxnet/pull/19269#issuecomment-710675678


   Undefined action detected. 
   Permissible actions are : run ci [all], run ci [job1, job2] 
   Example : @mxnet-bot run ci [all] 
   Example : @mxnet-bot run ci [centos-cpu, clang]





[GitHub] [incubator-mxnet] ptrendx commented on pull request #19269: Faster pointwise fusion graph pass

Posted by GitBox <gi...@apache.org>.
ptrendx commented on pull request #19269:
URL: https://github.com/apache/incubator-mxnet/pull/19269#issuecomment-707222118


   Hi @mseth10, I think all 3 of your questions are quite connected. I agree that the nomenclature here is slightly confusing. I use "subsets" because in this algorithm there are actually 2 parts:
    - the first one just finds the sets of nodes that can be fused together but does not do any actual fusion
    - the second one turns those subsets into the actual subgraphs and puts them inside the fusion node
   About the mapping - consider the following graph:
   ![example](https://user-images.githubusercontent.com/8398980/95767353-793be580-0c69-11eb-9ff6-e7d085ba6bc6.png)
   where `a` and `b` are the outputs of some previously identified subgraphs (let's call those subgraphs `A` and `B`). Node `c` (which can't be fused) has `A` in its exclusion list. Then the `+` node, which can be fused, is considered. It merges subgraphs `A` and `B` together under one of them (let's say `B`). Subgraph `A` then needs a mapping that says that it is now a part of `B` (and `B` should also have an inverse mapping saying that it contains subgraphs `A` and `B`, in order to be able to check against nodes like `c`, which only saw `A` but did not see `B`).
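
A rough sketch of the forward half of that mapping, under stated assumptions: `SubsetMap` and `IsExcluded` are hypothetical names, not the data structure used in this PR, and instead of keeping the inverse mapping in `B` the check below resolves every excluded id through the forward mapping.

```cpp
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// Tracks which subset each subset id has been merged into.
class SubsetMap {
 public:
  explicit SubsetMap(int num_subsets)
      : parent_(static_cast<std::size_t>(num_subsets)) {
    std::iota(parent_.begin(), parent_.end(), 0);  // each id maps to itself
  }

  // Follows the mapping until it reaches the subset that id now belongs to.
  int Resolve(int id) const {
    assert(id >= 0 && static_cast<std::size_t>(id) < parent_.size());
    while (parent_[static_cast<std::size_t>(id)] != id) {
      id = parent_[static_cast<std::size_t>(id)];
    }
    return id;
  }

  // Records that subset `from` is now a part of subset `into`.
  void MergeInto(int from, int into) {
    const int target = Resolve(into);
    parent_[static_cast<std::size_t>(Resolve(from))] = target;
  }

 private:
  std::vector<int> parent_;
};

// True if `candidate` is, after merges, one of the excluded subsets.
bool IsExcluded(const SubsetMap& map, const std::vector<int>& exclusion,
                int candidate) {
  const int target = map.Resolve(candidate);
  for (int id : exclusion) {
    if (map.Resolve(id) == target) return true;
  }
  return false;
}
```

In the example above, with `A` as id 0 and `B` as id 1, the exclusion list of `c` is `{0}`. Before the `+` node is processed `B` is not excluded for `c`; after `MergeInto(0, 1)` it is, even though `c` never saw `B`.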





[GitHub] [incubator-mxnet] mxnet-bot commented on pull request #19269: Faster pointwise fusion graph pass

Posted by GitBox <gi...@apache.org>.
mxnet-bot commented on pull request #19269:
URL: https://github.com/apache/incubator-mxnet/pull/19269#issuecomment-710675909


   Jenkins CI successfully triggered : [centos-cpu]





[GitHub] [incubator-mxnet] ptrendx commented on pull request #19269: Faster pointwise fusion graph pass

Posted by GitBox <gi...@apache.org>.
ptrendx commented on pull request #19269:
URL: https://github.com/apache/incubator-mxnet/pull/19269#issuecomment-710675669


   @ChaiBapchya FYI - there seem to be some connection issues in CI - first the sanity test failed because of a connection reset, then the centos-cpu test failed with
   ```
   [2020-10-16T21:36:45.894Z] ERROR: for centos7_cpu  ('Connection broken: IncompleteRead(0 bytes read)', IncompleteRead(0 bytes read))
   ```
   (https://jenkins.mxnet-ci.amazon-ml.com/job/mxnet-validation/job/centos-cpu/job/PR-19269/8/display/redirect)
   
   @mxnet-bot run ci [centos-cpu]





[GitHub] [incubator-mxnet] ptrendx commented on a change in pull request #19269: Faster pointwise fusion graph pass

Posted by GitBox <gi...@apache.org>.
ptrendx commented on a change in pull request #19269:
URL: https://github.com/apache/incubator-mxnet/pull/19269#discussion_r506721813



##########
File path: src/imperative/simple_partition_pass.cc
##########
@@ -0,0 +1,292 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2020 by Contributors
+ * \file simple_partition_pass.cc
+ * \brief Utilities used in simple partition pass
+ * \author Przemyslaw Tredak
+ */
+
+#include "./simple_partition_pass.h"
+#include <memory>
+#include <utility>
+
+namespace mxnet {
+namespace exec {
+
+namespace detail {
+
+void PrintSets(const IntervalVec* const sets_p) {
+  if (sets_p == nullptr || sets_p->size() == 0) {
+    std::cout << "{}" << std::endl;
+    return;
+  }
+  const auto& sets = *sets_p;
+  std::cout << "{";
+  for (size_t i = 0; i < sets.size() - 1; ++i) {
+    std::cout << "[" << sets[i].first << "," << sets[i].second << "], ";
+  }
+  std::cout << "[" << sets[sets.size()-1].first << ","
+            << sets[sets.size()-1].second << "]}" << std::endl;
+}
+
+const IntervalVec* LargerSet(const IntervalVec* const first,
+                             const IntervalVec* const second) noexcept {
+  const IntervalVec* ret = nullptr;
+  auto first_iter = first->begin();
+  auto second_iter = second->begin();
+  while (first_iter != first->end() &&
+         second_iter != second->end()) {
+    if (*first_iter == *second_iter) {
+      ++first_iter;
+      ++second_iter;
+    } else {
+      // Entry in first set not seen in the second set
+      if (first_iter->second < second_iter->first) {
+        if (ret == first || ret == nullptr) {
+          ret = first;
+          ++first_iter;
+        } else {
+          return nullptr;
+        }
+        continue;
+      }
+      // Entry in second set not seen in the first set
+      if (second_iter->second < first_iter->first) {
+        if (ret == second || ret == nullptr) {
+          ret = second;
+          ++second_iter;
+        } else {
+          return nullptr;
+        }
+        continue;
+      }
+      // Entry in first set fully encloses the entry in the second set
+      if (first_iter->first <= second_iter->first &&
+          first_iter->second >= second_iter->second) {
+        if (ret == first || ret == nullptr) {
+          ret = first;
+          ++second_iter;
+        } else {
+          return nullptr;
+        }
+        continue;
+      }
+      // Entry in second set fully encloses the entry in the first set
+      if (second_iter->first <= first_iter->first &&
+          second_iter->second >= first_iter->second) {
+        if (ret == second || ret == nullptr) {
+          ret = second;
+          ++first_iter;
+        } else {
+          return nullptr;
+        }
+        continue;
+      }
+      // Entries intersect but one is not fully enclosed in the other
+      return nullptr;
+    }
+  }
+  if (ret == nullptr) {
+    // The common part is the same
+    return second_iter == second->end() ? first : second;
+  } else {
+    if ((ret == first && second_iter == second->end()) ||
+        (ret == second && first_iter == first->end())) {
+      return ret;
+    }
+  }
+  return nullptr;
+}
+
+void MergeSets(const IntervalVec** const my_set,
+               const IntervalVec* const other_set,
+               std::vector<std::unique_ptr<const IntervalVec>>* const storage) noexcept {
+  if ((*my_set == nullptr) || (*my_set)->size() == 0) {
+    *my_set = other_set;
+    return;
+  }
+  if (other_set == nullptr || other_set->size() == 0) {
+    return;
+  }
+  auto* larger_set = LargerSet(*my_set, other_set);
+  if (larger_set != nullptr) {
+    *my_set = larger_set;
+    return;
+  }
+  auto my_iter = (*my_set)->cbegin();
+  auto other_iter = other_set->cbegin();
+  auto new_set = IntervalVec();
+  int last_end = -10;  // less than -1
+  while (my_iter != (*my_set)->cend() &&
+         other_iter != other_set->cend()) {
+    const auto& mine = *my_iter;
+    const auto& other = *other_iter;
+    if (other.second < mine.first - 1) {
+      // other interval is before ours
+      if (last_end >= other.first - 1) {
+        new_set.back().second = other.second;

Review comment:
       I looked at it and actually the other branch does not need the `max` there. I put the `max` in the other branch (and you are right, I forgot about this one) while debugging a problem in the previous version of the code, which was caused by always incrementing both iterators in the last branch (the intersecting one). Once that branch is fixed with the conditional increment of iterators, the `max` is not needed - `last_end` is always less than the `*.second` in those 2 branches.







[GitHub] [incubator-mxnet] ptrendx commented on a change in pull request #19269: Faster pointwise fusion graph pass

Posted by GitBox <gi...@apache.org>.
ptrendx commented on a change in pull request #19269:
URL: https://github.com/apache/incubator-mxnet/pull/19269#discussion_r506688373



##########
File path: src/imperative/pointwise_fusion_pass.cc
##########
@@ -57,281 +58,321 @@ void WarnFusionNotSupported() {
 #if MXNET_USE_CUDA
 
 namespace {
-  bool IsFusionCompatible(nnvm::Node* n) {
-    using namespace mxnet::fusion;
-    if (n->op() == nullptr)
-      return false;
-    std::string op_name = n->op()->name;
-    if (ops_desc.count(op_name))
-      return true;
-    if (slice_ops.count(op_name))
-      return false;
-    if (std::find(variable_io_ops.begin(),
-                  variable_io_ops.end(),
-                  op_name) !=
-        variable_io_ops.end())
-      return true;
-    if (op_name == "LeakyReLU") {
-        std::string act_type = n->attrs.dict.at("act_type");
-        if (LeakyReLU_ops.count(act_type))
-          return true;
-        else
-          return false;
-    }
-    if (op_name == "_backward_LeakyReLU") {
-        std::string act_type = n->attrs.dict.at("act_type");
-        if (LeakyReLU_bwd_ops.count(act_type))
-          return true;
-        else
-          return false;
-    }
+
+bool IsFusionCompatible(const nnvm::Node* n) {
+  using namespace mxnet::fusion;
+  if (n->op() == nullptr)
+    return false;
+  const std::string& op_name = n->op()->name;
+  if (ops_desc.count(op_name))
+    return true;
+  if (slice_ops.count(op_name))
     return false;
+  if (std::find(variable_io_ops.begin(),
+                variable_io_ops.end(),
+                op_name) !=
+      variable_io_ops.end())
+    return true;
+  if (op_name == "LeakyReLU") {
+      std::string act_type = n->attrs.dict.at("act_type");
+      if (LeakyReLU_ops.count(act_type))
+        return true;
+      else
+        return false;
   }
+  if (op_name == "_backward_LeakyReLU") {
+      std::string act_type = n->attrs.dict.at("act_type");
+      if (LeakyReLU_bwd_ops.count(act_type))
+        return true;
+      else
+        return false;
+  }
+  return false;
+}
 
-  bool IsInputsOnlyCompatible(nnvm::Node* n) {
-    using namespace mxnet::fusion;
-    if (n->op() == nullptr)
-      return false;
-    std::string op_name = n->op()->name;
-    if (slice_ops.count(op_name)) {
-      if (op_name == "slice") {
-        // slice with non-default step attribute is not supported
-        // currently
-        if (n->attrs.dict.count("step") &&
-            !(n->attrs.dict.at("step") == "()" ||
-              n->attrs.dict.at("step") == "[]")) {
-          return false;
-        }
+bool IsInputsOnlyCompatible(const nnvm::Node* n) {
+  using namespace mxnet::fusion;
+  if (n->op() == nullptr)
+    return false;
+  const std::string& op_name = n->op()->name;
+  if (slice_ops.count(op_name)) {
+    if (op_name == "slice") {
+      // slice with non-default step attribute is not supported
+      // currently
+      if (n->attrs.dict.count("step") &&
+          !(n->attrs.dict.at("step") == "()" ||
+            n->attrs.dict.at("step") == "[]")) {
+        return false;
       }
-      return true;
     }
-    return false;
+    return true;
   }
+  return false;
+}
+
+void CreateSubgraphNode(const nnvm::Graph& subgraph,
+                        size_t inputs_size,
+                        nnvm::Node* subgraph_node) {
+  static const Op* fused_op_ptr = Op::Get("_FusedOp");
+  subgraph_node->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>());
+  subgraph_node->attrs.subgraphs.back()->outputs = subgraph.outputs;
+  subgraph_node->attrs.dict["num_inputs"] = std::to_string(inputs_size);
+  subgraph_node->attrs.dict["num_outputs"] = std::to_string(subgraph.outputs.size());
+  subgraph_node->attrs.op = fused_op_ptr;
+  subgraph_node->op()->attr_parser(&(subgraph_node->attrs));
+}
+
+struct EntryInfo {
+  int source_node;
+  int index;
+};
 
-  nnvm::ObjectPtr CreateSubgraphNode(const Graph& subgraph, size_t inputs_size) {
-    nnvm::Symbol subgraph_sym;
-    auto node = nnvm::Node::Create();
-    subgraph_sym.outputs = subgraph.outputs;
-    node->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>(subgraph_sym));
-    node->attrs.name = "FusedOp";
-    node->attrs.dict["num_inputs"] = std::to_string(inputs_size);
-    node->attrs.dict["num_outputs"] = std::to_string(subgraph.outputs.size());
-    node->attrs.op = Op::Get("_FusedOp");
-    node->op()->attr_parser(&(node->attrs));
-    return node;
+inline int SetInsert(const EntryInfo& new_elem,
+                     std::vector<EntryInfo>* elements) {
+  for (size_t i = 0; i < elements->size(); ++i) {
+    if ((new_elem.source_node == elements->at(i).source_node) &&
+        (new_elem.index == elements->at(i).index)) {
+      return i;
+    }
   }
+  elements->emplace_back(new_elem);
+  return elements->size() - 1;
+}
+
 }  // namespace
 
-/*!
- * \brief Replace a set of nodes by a subgraph node.
- *        This function is used specifically in pointwise fusion.
+/* \brief Create (if necessary) copy of the graph, replacing subgraphs with
+ *        FusedOps. If there are no subgraphs to be replaced, the
+ *        original graph is returned.
+ * \param g original graph.
+ * \param subgraph_assignment assignment of nodes in g's IndexedGraphs to
+ *                            subgraphs. Values from -1 to num_subgraphs - 1
+ *                            are allowed, -1 means that the node is not in a
+ *                            subgraph.
+ * \param num_subgraphs number of subgraphs.
+ * \param create_subgraph_node function used to prepare the subgraph node.
  */
 template<typename FCreateNode>
-Graph ReplaceSubgraphsPointwise(Graph&& g, const std::vector<NodeRawPtrSet>& subgraph_sets,
-                                FCreateNode create_subgraph_node) {
-  for (auto subgraph_set : subgraph_sets) {
-    // Create MXNet subgraph
-    Graph subgraph;
-    const auto sub_outputs_in_main = GetSubgraphOutputs(g, subgraph_set);
-    subgraph.outputs.resize(sub_outputs_in_main.size());
-    for (auto p : sub_outputs_in_main) {
-      subgraph.outputs[p.second] = p.first;
-    }
-    // To generate a subgraph an input has to be replaced by data node (no op)
-    // and it has to be agnostic to the node from which it's an output
-    // (For example, even if two inputs are two different outputs from the same node,
-    // they need to be replaced by two completely separate data nodes)
-    auto inputs = GetSubgraphInputs(subgraph, subgraph_set);
-    auto subgraph_node = create_subgraph_node(subgraph, inputs.size());
-    subgraph_node->inputs = inputs;
-    // replug inputs of node out of subgraph to be output of the subgraph node
-    // if it was a node in the subgraph
-    DFSVisit(g.outputs,
-        [&subgraph_node, &subgraph_set, &sub_outputs_in_main](const nnvm::ObjectPtr node) {
-      if (!subgraph_set.count(node.get())) {
-        for (auto &e : node->inputs) {
-          auto it = sub_outputs_in_main.find(e);
-          if (it != sub_outputs_in_main.end()) {
-            e.node = subgraph_node;
-            e.index = it->second;
-          }
-        }
-      }
-    });
-    // replug outputs of the graph to be output of the subgraph node
-    // if it was a node in the subgraph
-    for (auto &e : g.outputs) {
-      auto it = sub_outputs_in_main.find(e);
-      if (it != sub_outputs_in_main.end()) {
-        e.node = subgraph_node;
-        e.index = it->second;
-      }
+Graph CopyAndReplaceSubgraphs(const Graph& g,
+                              const std::vector<int>& subgraph_assignment,
+                              const int num_subgraphs,
+                              FCreateNode create_subgraph_node) {
+  if (num_subgraphs == 0) {
+    return g;
+  }
+
+  Graph ret;
+
+  const auto& idx = g.indexed_graph();
+
+  CHECK_EQ(idx.num_nodes(), subgraph_assignment.size()) <<
+    "Every node in the graph needs to be included in subgraph assignment.";
+
+  std::vector<nnvm::ObjectPtr> new_nodes;
+  new_nodes.reserve(idx.num_nodes());
+  struct SubgraphInfo {
+    nnvm::Graph graph;
+    nnvm::ObjectPtr subgraph_node;
+    std::vector<EntryInfo> outputs;
+    std::vector<EntryInfo> inputs;
+    std::vector<nnvm::ObjectPtr> input_nodes;
+  };
+
+  std::vector<SubgraphInfo> subgraphs(num_subgraphs);
+
+  for (auto& info : subgraphs) {
+    info.subgraph_node = nnvm::Node::Create();
+  }
+
+  for (size_t i = 0; i < idx.num_nodes(); ++i) {
+    // First copy the node, it will be used
+    // either in the new graph or inside a
+    // subgraph. Variables are not copied.
+    if (idx[i].source->op() != nullptr) {
+      new_nodes.emplace_back(nnvm::Node::Create());
+      auto& node_copy = new_nodes.back();
+      node_copy->attrs = idx[i].source->attrs;
+      node_copy->info = idx[i].source->info;
+    } else {
+      new_nodes.emplace_back(idx[i].weak_ref.lock());
+      continue;
     }
-    // move control dependencies between nodes of the subgraph and out of the subgraph
-    // to a dependencies between the subgraph node and the nodes out of the subgraph
-    DFSVisit(subgraph.outputs, [&subgraph_node, &subgraph_set](const nnvm::ObjectPtr& node) {
-      if (subgraph_set.count(node.get())) {
-        auto it = node->control_deps.begin();
-        static auto& is_fusion = Op::GetAttr<exec::TIsFusionHelper>("TIsFusionHelper");
-        std::vector<nnvm::ObjectPtr> new_control_deps;
-        // Use the first control dependency to get the inferattr helper
-        if (it != node->control_deps.end()) {
-          if (subgraph_set.count(it->get())) {
-            new_control_deps.push_back(*it);
+    auto& node_copy = new_nodes.back();
+    const int subgraph_id = subgraph_assignment[i];
+    if (subgraph_id != -1) {
+      auto& info = subgraphs[subgraph_id];
+      for (const auto& input : idx[i].inputs) {
+        const int their_subgraph = subgraph_assignment[input.node_id];
+        if (their_subgraph == subgraph_id) {
+          node_copy->inputs.emplace_back(new_nodes[input.node_id],
+                                         input.index,
+                                         input.version);
+        } else {
+          int input_num;
+          int output_num;
+          if (their_subgraph == -1) {
+            input_num = SetInsert({static_cast<int>(input.node_id),
+                                   static_cast<int>(input.index)}, &(info.inputs));
           } else {
-            if ((*it)->is_variable() || !is_fusion.get((*it)->op(), false)) {
-              uint32_t node_id = subgraph_node->control_deps.size();
-              subgraph_node->control_deps.push_back(*it);
-              auto helper_node = op::MakeNode("_FusedOpOutHelper",
-                                              "FusedOp_" + node->attrs.name + "_outhelper",
-                                              nullptr,
-                                              nullptr,
-                                              nullptr);
-              helper_node->attrs.parsed =
-                FusedOpHelperParamPtr(new FusedOpHelperParam(
-                      nnvm::get<FusedOpPtr>(subgraph_node->attrs.parsed),
-                      node_id));
-              new_control_deps.push_back(helper_node);
+            auto& their_subgraph_info = subgraphs[their_subgraph];
+            output_num = SetInsert({static_cast<int>(input.node_id),
+                                    static_cast<int>(input.index)},
+                                   &(their_subgraph_info.outputs));
+            input_num = SetInsert({static_cast<int>(idx.num_nodes() + their_subgraph),
+                                   output_num},
+                                  &(info.inputs));
+          }
+          if (static_cast<size_t>(input_num) == info.input_nodes.size()) {
+            info.input_nodes.emplace_back(nnvm::Node::Create());
+            info.input_nodes.back()->attrs.name = "input_" + std::to_string(input_num);
+            if (their_subgraph == -1) {
+              info.subgraph_node->inputs.emplace_back(new_nodes[input.node_id],
+                                                      input.index,
+                                                      input.version);
             } else {
-              new_control_deps.push_back(*it);
+              info.subgraph_node->inputs.emplace_back(subgraphs[their_subgraph].subgraph_node,
+                                                      output_num,
+                                                      input.version);
             }
           }
-          ++it;
+          node_copy->inputs.emplace_back(info.input_nodes[input_num], 0, 0);
         }
-        node->control_deps = new_control_deps;
       }
-    });
-
-    std::ostringstream name_oss;
-    // the name of the new node will be the concatenation of all the node names in the subgraph
-    DFSVisit(subgraph.outputs, [&name_oss](const nnvm::ObjectPtr n) {
-      if (n->op() != nullptr) {
-        name_oss << n->op()->name << "_";
+    } else {
+      for (const auto& input : idx[i].inputs) {
+        const int subgraph_id = subgraph_assignment[input.node_id];
+        if (subgraph_id == -1) {
+          node_copy->inputs.emplace_back(new_nodes[input.node_id],
+                                         input.index,
+                                         input.version);
+        } else {
+          auto& info = subgraphs[subgraph_id];
+          const int output_num = SetInsert({static_cast<int>(input.node_id),
+                                            static_cast<int>(input.index)},
+                                           &(info.outputs));
+          node_copy->inputs.emplace_back(info.subgraph_node,
+                                         output_num,
+                                         input.version);
+        }
       }
-    });
-    auto subgraph_name = name_oss.str();
-    subgraph_name.pop_back();
-    subgraph_node->attrs.name = subgraph_name;
+    }
 
-    const auto& index = subgraph.indexed_graph();
-    DFSVisit(g.outputs, [&subgraph_node, &subgraph_set, &index](const nnvm::ObjectPtr& node) {
-      for (auto &e : node->control_deps) {
-        if (subgraph_set.count(e.get())) {
-          uint32_t node_id = index.node_id(e.get());
-          auto helper_node = op::MakeNode("_FusedOpHelper",
-                                          subgraph_node->attrs.name + "_"
-                                          + node->attrs.name + "_helper",
-                                          nullptr,
-                                          nullptr,
-                                          nullptr);
-          helper_node->attrs.parsed =
-            FusedOpHelperParamPtr(new FusedOpHelperParam(
-                  nnvm::get<FusedOpPtr>(subgraph_node->attrs.parsed),
-                  node_id));
-          e = helper_node;
-        }
+    // Control deps
+    for (const auto& dep : idx[i].control_deps) {
+      if (subgraph_id == subgraph_assignment[dep]) {
+        node_copy->control_deps.emplace_back(new_nodes[dep]);
       }
-    });
+    }
   }
-  Graph new_graph;
-  new_graph.outputs = g.outputs;
-  return new_graph;
-}
 
-/* \brief Add nodes as inputs to the subgraph. This is used for operations
- *        which are only compatible when they are the first nodes in the
- *        subgraph.
- */
-template <typename IsCompatible>
-void AddInputsOnlyCompatible(const Graph &g,
-                             std::vector<std::unordered_set<nnvm::Node*> >* subsets,
-                             IsCompatible is_compatible) {
-  std::unordered_map<nnvm::Node*, uint32_t> node2setidx;
-  size_t subgraphs_fullsize = 0;
-  for (auto& s : *subsets) {
-    subgraphs_fullsize += s.size();
-  }
-  node2setidx.reserve(subgraphs_fullsize);
-  for (size_t i = 0; i < subsets->size(); ++i) {
-    for (auto& n : (*subsets)[i]) {
-      node2setidx.insert({n, i});
+  ret.outputs.reserve(idx.outputs().size());
+  for (const auto& output : idx.outputs()) {
+    const int subgraph_id = subgraph_assignment[output.node_id];
+    if (subgraph_id == -1) {
+      ret.outputs.emplace_back(new_nodes[output.node_id],
+                               output.index,
+                               output.version);
+    } else {
+      const int output_num = SetInsert({static_cast<int>(output.node_id),
+                                        static_cast<int>(output.index)},
+                                       &(subgraphs[subgraph_id].outputs));
+      ret.outputs.emplace_back(subgraphs[subgraph_id].subgraph_node,
+                               output_num,
+                               output.version);
     }
   }
-  std::vector<std::vector<nnvm::Node*> > to_add(subsets->size());
-  DFSVisit(g.outputs, [&is_compatible, &node2setidx, &to_add](const nnvm::ObjectPtr& n) {
-    const auto& it = node2setidx.find(n.get());
-    if (it != node2setidx.end()) {
-      for (auto& e : n->inputs) {
-        if (is_compatible(e.node.get()))
-          to_add[it->second].push_back(e.node.get());
-      }
+
+  for (auto& info : subgraphs) {
+    info.graph.outputs.reserve(info.outputs.size());
+    for (const auto& entry_info : info.outputs) {
+      info.graph.outputs.emplace_back(new_nodes[entry_info.source_node],
+                                      entry_info.index,
+                                      0);
     }
-  });
+    create_subgraph_node(info.graph, info.inputs.size(), info.subgraph_node.get());
+  }
 
-  // Avoid duplicating the node that is input of two subsets
-  std::unordered_set<nnvm::Node*> added;
-  for (size_t i = 0; i < subsets->size(); ++i) {
-    std::vector<nnvm::NodeEntry> heads;
-    for (auto n : subsets->at(i)) {
-      for (auto e : n->inputs) {
-        if (!subsets->at(i).count(e.node.get()))
-          heads.push_back(e);
+  for (size_t i = 0; i < idx.num_nodes(); ++i) {
+    // Add _FusedOpHelper nodes
+    const int subgraph_id = subgraph_assignment[i];
+    for (size_t dep_num = 0; dep_num < idx[i].control_deps.size(); ++dep_num) {
+      const auto& dep = idx[i].control_deps[dep_num];
+      const int their_subgraph_id = subgraph_assignment[dep];
+      if (subgraph_assignment[i] != -1 && subgraph_assignment[dep] == -1) {

Review comment:
       :+1:




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-mxnet] ptrendx commented on a change in pull request #19269: Faster pointwise fusion graph pass

Posted by GitBox <gi...@apache.org>.
ptrendx commented on a change in pull request #19269:
URL: https://github.com/apache/incubator-mxnet/pull/19269#discussion_r506722767



##########
File path: src/imperative/simple_partition_pass.cc
##########
@@ -0,0 +1,292 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2020 by Contributors
+ * \file simple_partition_pass.cc
+ * \brief Utilities used in simple partition pass
+ * \author Przemyslaw Tredak
+ */
+
+#include "./simple_partition_pass.h"
+#include <memory>
+#include <utility>
+
+namespace mxnet {
+namespace exec {
+
+namespace detail {
+
+void PrintSets(const IntervalVec* const sets_p) {
+  if (sets_p == nullptr || sets_p->size() == 0) {
+    std::cout << "{}" << std::endl;
+    return;
+  }
+  const auto& sets = *sets_p;
+  std::cout << "{";
+  for (size_t i = 0; i < sets.size() - 1; ++i) {
+    std::cout << "[" << sets[i].first << "," << sets[i].second << "], ";
+  }
+  std::cout << "[" << sets[sets.size()-1].first << ","
+            << sets[sets.size()-1].second << "]}" << std::endl;
+}
+
+const IntervalVec* LargerSet(const IntervalVec* const first,
+                             const IntervalVec* const second) noexcept {
+  const IntervalVec* ret = nullptr;
+  auto first_iter = first->begin();
+  auto second_iter = second->begin();
+  while (first_iter != first->end() &&
+         second_iter != second->end()) {
+    if (*first_iter == *second_iter) {
+      ++first_iter;
+      ++second_iter;
+    } else {
+      // Entry in first set not seen in the second set
+      if (first_iter->second < second_iter->first) {
+        if (ret == first || ret == nullptr) {
+          ret = first;
+          ++first_iter;
+        } else {
+          return nullptr;
+        }
+        continue;
+      }
+      // Entry in second set not seen in the first set
+      if (second_iter->second < first_iter->first) {
+        if (ret == second || ret == nullptr) {
+          ret = second;
+          ++second_iter;
+        } else {
+          return nullptr;
+        }
+        continue;
+      }
+      // Entry in first set fully encloses the entry in the second set
+      if (first_iter->first <= second_iter->first &&
+          first_iter->second >= second_iter->second) {
+        if (ret == first || ret == nullptr) {
+          ret = first;
+          ++second_iter;
+        } else {
+          return nullptr;
+        }
+        continue;
+      }
+      // Entry in second set fully encloses the entry in the first set
+      if (second_iter->first <= first_iter->first &&
+          second_iter->second >= first_iter->second) {
+        if (ret == second || ret == nullptr) {
+          ret = second;
+          ++first_iter;
+        } else {
+          return nullptr;
+        }
+        continue;
+      }
+      // Entries intersect but one is not fully enclosed in the other
+      return nullptr;
+    }
+  }
+  if (ret == nullptr) {
+    // The common part is the same
+    return second_iter == second->end() ? first : second;
+  } else {
+    if ((ret == first && second_iter == second->end()) ||
+        (ret == second && first_iter == first->end())) {
+      return ret;
+    }
+  }
+  return nullptr;
+}
+
+void MergeSets(const IntervalVec** const my_set,
+               const IntervalVec* const other_set,
+               std::vector<std::unique_ptr<const IntervalVec>>* const storage) noexcept {
+  if ((*my_set == nullptr) || (*my_set)->size() == 0) {
+    *my_set = other_set;
+    return;
+  }
+  if (other_set == nullptr || other_set->size() == 0) {
+    return;
+  }
+  auto* larger_set = LargerSet(*my_set, other_set);
+  if (larger_set != nullptr) {
+    *my_set = larger_set;
+    return;
+  }
+  auto my_iter = (*my_set)->cbegin();
+  auto other_iter = other_set->cbegin();
+  auto new_set = IntervalVec();
+  int last_end = -10;  // less than -1
+  while (my_iter != (*my_set)->cend() &&
+         other_iter != other_set->cend()) {
+    const auto& mine = *my_iter;
+    const auto& other = *other_iter;
+    if (other.second < mine.first - 1) {
+      // other interval is before ours
+      if (last_end >= other.first - 1) {
+        new_set.back().second = other.second;
+      } else {
+        new_set.emplace_back(other);
+      }
+      last_end = other.second;
+      ++other_iter;
+    } else if (other.first > mine.second + 1) {
+      // other interval is after ours
+      if (last_end >= mine.first - 1) {
+        new_set.back().second = std::max(mine.second, last_end);
+      } else {
+        new_set.emplace_back(mine);
+      }
+      last_end = new_set.back().second;
+      ++my_iter;
+    } else {
+      // Intervals can be merged together
+      Interval n(std::min(mine.first, other.first),
+                 std::max(mine.second, other.second));
+      if (last_end >= n.first - 1) {
+        new_set.back().second = n.second;
+      } else {
+        new_set.emplace_back(n);
+      }
+      last_end = n.second;
+      if (other.second >= mine.second) {
+        ++my_iter;
+      }
+      if (mine.second >= other.second) {
+        ++other_iter;
+      }
+    }
+  }
+  // Add the rest of entries
+  for (; my_iter != (*my_set)->cend(); ++my_iter) {
+    auto& mine = new_set.back();
+    const auto& other = *my_iter;
+    if (other.second < mine.first - 1) {
+      // other interval is before ours, should never happen
+      continue;
+    } else if (other.first > mine.second + 1) {
+      // other interval is after ours
+      new_set.emplace_back(other);
+    } else {
+      // Intervals can be merged together
+      mine.first = std::min(mine.first, other.first);
+      mine.second = std::max(mine.second, other.second);
+    }
+  }
+  for (; other_iter != other_set->cend(); ++other_iter) {

Review comment:
       Yes, I refactored it to have only 1 loop.
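
       For illustration only, one way such a single-loop tail could look. This is a sketch, not necessarily the code that ended up in the PR. It assumes `Interval` is a `std::pair<int, int>` and `IntervalVec` is a `std::vector<Interval>` (consistent with how they are used in the diff, but not confirmed by it); `AppendRemaining` is a hypothetical helper name.

```cpp
#include <algorithm>
#include <utility>
#include <vector>

// Assumed definitions; the real ones live in simple_partition_pass.h.
using Interval = std::pair<int, int>;
using IntervalVec = std::vector<Interval>;
using Iter = IntervalVec::const_iterator;

// Hypothetical helper: after the two-way merge loop, at least one of the
// two ranges is exhausted, so a single loop over the non-empty one suffices.
void AppendRemaining(Iter a, Iter a_end, Iter b, Iter b_end, IntervalVec* out) {
  // Pick whichever range still has entries (both may be empty).
  Iter it = (a != a_end) ? a : b;
  const Iter end = (a != a_end) ? a_end : b_end;
  for (; it != end; ++it) {
    if (!out->empty() && out->back().second >= it->first - 1) {
      // Touching or overlapping the last stored interval: extend it.
      out->back().second = std::max(out->back().second, it->second);
    } else {
      out->push_back(*it);
    }
  }
}
```

       With `out = {[0,4]}` and a remaining range `{[5,6], [9,9]}`, this produces `{[0,6], [9,9]}`.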







[GitHub] [incubator-mxnet] ptrendx commented on pull request #19269: Faster pointwise fusion graph pass

Posted by GitBox <gi...@apache.org>.
ptrendx commented on pull request #19269:
URL: https://github.com/apache/incubator-mxnet/pull/19269#issuecomment-707996703


   @mxnet-bot run ci [unix-cpu, unix-gpu]





[GitHub] [incubator-mxnet] mseth10 commented on a change in pull request #19269: Faster pointwise fusion graph pass

Posted by GitBox <gi...@apache.org>.
mseth10 commented on a change in pull request #19269:
URL: https://github.com/apache/incubator-mxnet/pull/19269#discussion_r504578496



##########
File path: src/imperative/simple_partition_pass.cc
##########
@@ -0,0 +1,292 @@
+
+/*!
+ * Copyright (c) 2020 by Contributors
+ * \file simple_partition_pass.cc
+ * \brief Utilities used in simple partition pass
+ * \author Przemyslaw Tredak
+ */
+
+#include "./simple_partition_pass.h"
+#include <memory>
+#include <utility>
+
+namespace mxnet {
+namespace exec {
+
+namespace detail {
+
+void PrintSets(const IntervalVec* const sets_p) {

Review comment:
       Can we remove this?

##########
File path: src/imperative/simple_partition_pass.cc
##########
@@ -0,0 +1,292 @@
+
+/*!
+ * Copyright (c) 2020 by Contributors
+ * \file simple_partition_pass.cc
+ * \brief Utilities used in simple partition pass
+ * \author Przemyslaw Tredak
+ */
+
+#include "./simple_partition_pass.h"
+#include <memory>
+#include <utility>
+
+namespace mxnet {
+namespace exec {
+
+namespace detail {
+
+void PrintSets(const IntervalVec* const sets_p) {
+  if (sets_p == nullptr || sets_p->size() == 0) {
+    std::cout << "{}" << std::endl;
+    return;
+  }
+  const auto& sets = *sets_p;
+  std::cout << "{";
+  for (size_t i = 0; i < sets.size() - 1; ++i) {
+    std::cout << "[" << sets[i].first << "," << sets[i].second << "], ";
+  }
+  std::cout << "[" << sets[sets.size()-1].first << ","
+            << sets[sets.size()-1].second << "]}" << std::endl;
+}
+
+const IntervalVec* LargerSet(const IntervalVec* const first,
+                             const IntervalVec* const second) noexcept {
+  const IntervalVec* ret = nullptr;
+  auto first_iter = first->begin();
+  auto second_iter = second->begin();
+  while (first_iter != first->end() &&
+         second_iter != second->end()) {
+    if (*first_iter == *second_iter) {
+      ++first_iter;
+      ++second_iter;
+    } else {
+      // Entry in first set not seen in the second set
+      if (first_iter->second < second_iter->first) {
+        if (ret == first || ret == nullptr) {
+          ret = first;
+          ++first_iter;
+        } else {
+          return nullptr;
+        }
+        continue;
+      }
+      // Entry in second set not seen in the first set
+      if (second_iter->second < first_iter->first) {
+        if (ret == second || ret == nullptr) {
+          ret = second;
+          ++second_iter;
+        } else {
+          return nullptr;
+        }
+        continue;
+      }
+      // Entry in first set fully encloses the entry in the second set
+      if (first_iter->first <= second_iter->first &&
+          first_iter->second >= second_iter->second) {
+        if (ret == first || ret == nullptr) {
+          ret = first;
+          ++second_iter;
+        } else {
+          return nullptr;
+        }
+        continue;
+      }
+      // Entry in second set fully encloses the entry in the first set
+      if (second_iter->first <= first_iter->first &&
+          second_iter->second >= first_iter->second) {
+        if (ret == second || ret == nullptr) {
+          ret = second;
+          ++first_iter;
+        } else {
+          return nullptr;
+        }
+        continue;
+      }
+      // Entries intersect but one is not fully enclosed in the other
+      return nullptr;
+    }
+  }
+  if (ret == nullptr) {
+    // The common part is the same
+    return second_iter == second->end() ? first : second;
+  } else {
+    if ((ret == first && second_iter == second->end()) ||
+        (ret == second && first_iter == first->end())) {
+      return ret;
+    }
+  }
+  return nullptr;
+}
+
+void MergeSets(const IntervalVec** const my_set,
+               const IntervalVec* const other_set,
+               std::vector<std::unique_ptr<const IntervalVec>>* const storage) noexcept {
+  if ((*my_set == nullptr) || (*my_set)->size() == 0) {
+    *my_set = other_set;
+    return;
+  }
+  if (other_set == nullptr || other_set->size() == 0) {
+    return;
+  }
+  auto* larger_set = LargerSet(*my_set, other_set);
+  if (larger_set != nullptr) {
+    *my_set = larger_set;
+    return;
+  }
+  auto my_iter = (*my_set)->cbegin();
+  auto other_iter = other_set->cbegin();
+  auto new_set = IntervalVec();
+  int last_end = -10;  // less than -1
+  while (my_iter != (*my_set)->cend() &&
+         other_iter != other_set->cend()) {
+    const auto& mine = *my_iter;
+    const auto& other = *other_iter;
+    if (other.second < mine.first - 1) {
+      // other interval is before ours
+      if (last_end >= other.first - 1) {
+        new_set.back().second = other.second;

Review comment:
       Shouldn't it be `= std::max(other.second, last_end);`?
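
       To make the concern concrete, here is a minimal sketch (not the PR's code). It assumes `Interval` is a `std::pair<int, int>` and `IntervalVec` is a `std::vector<Interval>`, which matches their use in the diff but is not confirmed by it; `AppendCoalesced` is a hypothetical helper. It shows that taking the maximum keeps the stored end when the incoming interval ends earlier. Whether that situation can actually arise in `MergeSets` depends on invariants not visible in this hunk.

```cpp
#include <algorithm>
#include <utility>
#include <vector>

// Assumed definitions; the real ones live in simple_partition_pass.h.
using Interval = std::pair<int, int>;
using IntervalVec = std::vector<Interval>;

// Hypothetical helper: append `next` to `out`, extending the last stored
// interval when the two touch or overlap.
void AppendCoalesced(IntervalVec* out, const Interval& next) {
  if (!out->empty() && out->back().second >= next.first - 1) {
    // std::max guards against shrinking the stored interval when
    // `next` ends before the current end.
    out->back().second = std::max(out->back().second, next.second);
  } else {
    out->push_back(next);
  }
}
```

       Appending `[5,7]` to `{[0,10]}` leaves `{[0,10]}`; assigning `next.second` directly would instead shrink it to `[0,7]`.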

##########
File path: src/imperative/simple_partition_pass.cc
##########
@@ -0,0 +1,292 @@
+
+/*!
+ * Copyright (c) 2020 by Contributors
+ * \file simple_partition_pass.cc
+ * \brief Utilities used in simple partition pass
+ * \author Przemyslaw Tredak
+ */
+
+#include "./simple_partition_pass.h"
+#include <memory>
+#include <utility>
+
+namespace mxnet {
+namespace exec {
+
+namespace detail {
+
+void PrintSets(const IntervalVec* const sets_p) {
+  if (sets_p == nullptr || sets_p->size() == 0) {
+    std::cout << "{}" << std::endl;
+    return;
+  }
+  const auto& sets = *sets_p;
+  std::cout << "{";
+  for (size_t i = 0; i < sets.size() - 1; ++i) {
+    std::cout << "[" << sets[i].first << "," << sets[i].second << "], ";
+  }
+  std::cout << "[" << sets[sets.size()-1].first << ","
+            << sets[sets.size()-1].second << "]}" << std::endl;
+}
+
+const IntervalVec* LargerSet(const IntervalVec* const first,
+                             const IntervalVec* const second) noexcept {
+  const IntervalVec* ret = nullptr;
+  auto first_iter = first->begin();
+  auto second_iter = second->begin();
+  while (first_iter != first->end() &&
+         second_iter != second->end()) {
+    if (*first_iter == *second_iter) {
+      ++first_iter;
+      ++second_iter;
+    } else {
+      // Entry in first set not seen in the second set
+      if (first_iter->second < second_iter->first) {
+        if (ret == first || ret == nullptr) {
+          ret = first;
+          ++first_iter;
+        } else {
+          return nullptr;
+        }
+        continue;
+      }
+      // Entry in second set not seen in the first set
+      if (second_iter->second < first_iter->first) {
+        if (ret == second || ret == nullptr) {
+          ret = second;
+          ++second_iter;
+        } else {
+          return nullptr;
+        }
+        continue;
+      }
+      // Entry in first set fully encloses the entry in the second set
+      if (first_iter->first <= second_iter->first &&
+          first_iter->second >= second_iter->second) {
+        if (ret == first || ret == nullptr) {
+          ret = first;
+          ++second_iter;
+        } else {
+          return nullptr;
+        }
+        continue;
+      }
+      // Entry in second set fully encloses the entry in the first set
+      if (second_iter->first <= first_iter->first &&
+          second_iter->second >= first_iter->second) {
+        if (ret == second || ret == nullptr) {
+          ret = second;
+          ++first_iter;
+        } else {
+          return nullptr;
+        }
+        continue;
+      }
+      // Entries intersect but one is not fully enclosed in the other
+      return nullptr;
+    }
+  }
+  if (ret == nullptr) {
+    // The common part is the same
+    return second_iter == second->end() ? first : second;
+  } else {
+    if ((ret == first && second_iter == second->end()) ||
+        (ret == second && first_iter == first->end())) {
+      return ret;
+    }
+  }
+  return nullptr;
+}
+
+void MergeSets(const IntervalVec** const my_set,
+               const IntervalVec* const other_set,
+               std::vector<std::unique_ptr<const IntervalVec>>* const storage) noexcept {
+  if ((*my_set == nullptr) || (*my_set)->size() == 0) {
+    *my_set = other_set;
+    return;
+  }
+  if (other_set == nullptr || other_set->size() == 0) {
+    return;
+  }
+  auto* larger_set = LargerSet(*my_set, other_set);
+  if (larger_set != nullptr) {
+    *my_set = larger_set;
+    return;
+  }
+  auto my_iter = (*my_set)->cbegin();
+  auto other_iter = other_set->cbegin();
+  auto new_set = IntervalVec();
+  int last_end = -10;  // less than -1
+  while (my_iter != (*my_set)->cend() &&
+         other_iter != other_set->cend()) {
+    const auto& mine = *my_iter;
+    const auto& other = *other_iter;
+    if (other.second < mine.first - 1) {
+      // other interval is before ours
+      if (last_end >= other.first - 1) {
+        new_set.back().second = other.second;
+      } else {
+        new_set.emplace_back(other);
+      }
+      last_end = other.second;
+      ++other_iter;
+    } else if (other.first > mine.second + 1) {
+      // other interval is after ours
+      if (last_end >= mine.first - 1) {
+        new_set.back().second = std::max(mine.second, last_end);
+      } else {
+        new_set.emplace_back(mine);
+      }
+      last_end = new_set.back().second;
+      ++my_iter;
+    } else {
+      // Intervals can be merged together
+      Interval n(std::min(mine.first, other.first),
+                 std::max(mine.second, other.second));
+      if (last_end >= n.first - 1) {
+        new_set.back().second = n.second;
+      } else {
+        new_set.emplace_back(n);
+      }
+      last_end = n.second;
+      if (other.second >= mine.second) {
+        ++my_iter;
+      }
+      if (mine.second >= other.second) {
+        ++other_iter;
+      }
+    }
+  }
+  // Add the rest of entries
+  for (; my_iter != (*my_set)->cend(); ++my_iter) {
+    auto& mine = new_set.back();
+    const auto& other = *my_iter;
+    if (other.second < mine.first - 1) {
+      // other interval is before ours, should never happen
+      continue;
+    } else if (other.first > mine.second + 1) {
+      // other interval is after ours
+      new_set.emplace_back(other);
+    } else {
+      // Intervals can be merged together
+      mine.first = std::min(mine.first, other.first);
+      mine.second = std::max(mine.second, other.second);
+    }
+  }
+  for (; other_iter != other_set->cend(); ++other_iter) {

Review comment:
       Just to understand this logic: only one of the two `for` blocks will be executed, right?
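
       A small sketch of why that would hold, using a plain sorted two-way merge rather than the PR's `MergeSets`. It assumes `Interval` is a `std::pair<int, int>` and `IntervalVec` is a `std::vector<Interval>` (not confirmed by the diff); `Coalesce` and `MergeSorted` are hypothetical names. The `while` loop exits only once at least one iterator has reached its end, so at most one of the trailing `for` loops has anything left to do.

```cpp
#include <algorithm>
#include <utility>
#include <vector>

// Assumed definitions; the real ones live in simple_partition_pass.h.
using Interval = std::pair<int, int>;
using IntervalVec = std::vector<Interval>;

// Hypothetical helper: append `next`, merging with the last stored
// interval when they touch or overlap.
void Coalesce(IntervalVec* out, const Interval& next) {
  if (!out->empty() && out->back().second >= next.first - 1) {
    out->back().second = std::max(out->back().second, next.second);
  } else {
    out->push_back(next);
  }
}

// Merge two interval lists, each sorted by start and internally disjoint.
IntervalVec MergeSorted(const IntervalVec& a, const IntervalVec& b) {
  IntervalVec out;
  auto ia = a.cbegin();
  auto ib = b.cbegin();
  while (ia != a.cend() && ib != b.cend()) {
    // Always consume the interval that starts first.
    if (ia->first <= ib->first) {
      Coalesce(&out, *ia);
      ++ia;
    } else {
      Coalesce(&out, *ib);
      ++ib;
    }
  }
  // At least one input is exhausted here, so at most one loop body runs.
  for (; ia != a.cend(); ++ia) Coalesce(&out, *ia);
  for (; ib != b.cend(); ++ib) Coalesce(&out, *ib);
  return out;
}
```

       Merging `{[0,2], [10,12]}` with `{[3,5], [20,21]}` gives `{[0,5], [10,12], [20,21]}`; only the second tail loop has work to do in that case.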

##########
File path: src/imperative/pointwise_fusion_pass.cc
##########
@@ -57,281 +58,321 @@ void WarnFusionNotSupported() {
 #if MXNET_USE_CUDA
 
 namespace {
-  bool IsFusionCompatible(nnvm::Node* n) {
-    using namespace mxnet::fusion;
-    if (n->op() == nullptr)
-      return false;
-    std::string op_name = n->op()->name;
-    if (ops_desc.count(op_name))
-      return true;
-    if (slice_ops.count(op_name))
-      return false;
-    if (std::find(variable_io_ops.begin(),
-                  variable_io_ops.end(),
-                  op_name) !=
-        variable_io_ops.end())
-      return true;
-    if (op_name == "LeakyReLU") {
-        std::string act_type = n->attrs.dict.at("act_type");
-        if (LeakyReLU_ops.count(act_type))
-          return true;
-        else
-          return false;
-    }
-    if (op_name == "_backward_LeakyReLU") {
-        std::string act_type = n->attrs.dict.at("act_type");
-        if (LeakyReLU_bwd_ops.count(act_type))
-          return true;
-        else
-          return false;
-    }
+
+bool IsFusionCompatible(const nnvm::Node* n) {
+  using namespace mxnet::fusion;
+  if (n->op() == nullptr)
+    return false;
+  const std::string& op_name = n->op()->name;
+  if (ops_desc.count(op_name))
+    return true;
+  if (slice_ops.count(op_name))
     return false;
+  if (std::find(variable_io_ops.begin(),
+                variable_io_ops.end(),
+                op_name) !=
+      variable_io_ops.end())
+    return true;
+  if (op_name == "LeakyReLU") {
+      std::string act_type = n->attrs.dict.at("act_type");
+      if (LeakyReLU_ops.count(act_type))
+        return true;
+      else
+        return false;
   }
+  if (op_name == "_backward_LeakyReLU") {
+      std::string act_type = n->attrs.dict.at("act_type");
+      if (LeakyReLU_bwd_ops.count(act_type))
+        return true;
+      else
+        return false;
+  }
+  return false;
+}
 
-  bool IsInputsOnlyCompatible(nnvm::Node* n) {
-    using namespace mxnet::fusion;
-    if (n->op() == nullptr)
-      return false;
-    std::string op_name = n->op()->name;
-    if (slice_ops.count(op_name)) {
-      if (op_name == "slice") {
-        // slice with non-default step attribute is not supported
-        // currently
-        if (n->attrs.dict.count("step") &&
-            !(n->attrs.dict.at("step") == "()" ||
-              n->attrs.dict.at("step") == "[]")) {
-          return false;
-        }
+bool IsInputsOnlyCompatible(const nnvm::Node* n) {
+  using namespace mxnet::fusion;
+  if (n->op() == nullptr)
+    return false;
+  const std::string& op_name = n->op()->name;
+  if (slice_ops.count(op_name)) {
+    if (op_name == "slice") {
+      // slice with non-default step attribute is not supported
+      // currently
+      if (n->attrs.dict.count("step") &&
+          !(n->attrs.dict.at("step") == "()" ||
+            n->attrs.dict.at("step") == "[]")) {
+        return false;
       }
-      return true;
     }
-    return false;
+    return true;
   }
+  return false;
+}
+
+void CreateSubgraphNode(const nnvm::Graph& subgraph,
+                        size_t inputs_size,
+                        nnvm::Node* subgraph_node) {
+  static const Op* fused_op_ptr = Op::Get("_FusedOp");
+  subgraph_node->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>());
+  subgraph_node->attrs.subgraphs.back()->outputs = subgraph.outputs;
+  subgraph_node->attrs.dict["num_inputs"] = std::to_string(inputs_size);
+  subgraph_node->attrs.dict["num_outputs"] = std::to_string(subgraph.outputs.size());
+  subgraph_node->attrs.op = fused_op_ptr;
+  subgraph_node->op()->attr_parser(&(subgraph_node->attrs));
+}
+
+struct EntryInfo {
+  int source_node;
+  int index;
+};
 
-  nnvm::ObjectPtr CreateSubgraphNode(const Graph& subgraph, size_t inputs_size) {
-    nnvm::Symbol subgraph_sym;
-    auto node = nnvm::Node::Create();
-    subgraph_sym.outputs = subgraph.outputs;
-    node->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>(subgraph_sym));
-    node->attrs.name = "FusedOp";
-    node->attrs.dict["num_inputs"] = std::to_string(inputs_size);
-    node->attrs.dict["num_outputs"] = std::to_string(subgraph.outputs.size());
-    node->attrs.op = Op::Get("_FusedOp");
-    node->op()->attr_parser(&(node->attrs));
-    return node;
+inline int SetInsert(const EntryInfo& new_elem,
+                     std::vector<EntryInfo>* elements) {
+  for (size_t i = 0; i < elements->size(); ++i) {
+    if ((new_elem.source_node == elements->at(i).source_node) &&
+        (new_elem.index == elements->at(i).index)) {
+      return i;
+    }
   }
+  elements->emplace_back(new_elem);
+  return elements->size() - 1;
+}
+
 }  // namespace
 
-/*!
- * \brief Replace a set of nodes by a subgraph node.
- *        This function is used specifically in pointwise fusion.
+/* \brief Create (if necessary) copy of the graph, replacing subgraphs with
+ *        FusedOps. If there are no subgraphs to be replaced, the
+ *        original graph is returned.
+ * \param g original graph.
+ * \param subgraph_assignment assignment of nodes in g's IndexedGraphs to
+ *                            subgraphs. Values from -1 to num_subgraphs - 1
+ *                            are allowed, -1 means that the node is not in a
+ *                            subgraph.
+ * \param num_subgraphs number of subgraphs.
+ * \param create_subgraph_node function used to prepare the subgraph node.
  */
 template<typename FCreateNode>
-Graph ReplaceSubgraphsPointwise(Graph&& g, const std::vector<NodeRawPtrSet>& subgraph_sets,
-                                FCreateNode create_subgraph_node) {
-  for (auto subgraph_set : subgraph_sets) {
-    // Create MXNet subgraph
-    Graph subgraph;
-    const auto sub_outputs_in_main = GetSubgraphOutputs(g, subgraph_set);
-    subgraph.outputs.resize(sub_outputs_in_main.size());
-    for (auto p : sub_outputs_in_main) {
-      subgraph.outputs[p.second] = p.first;
-    }
-    // To generate a subgraph an input has to be replaced by data node (no op)
-    // and it has to be agnostic to the node from which it's an output
-    // (For example, even if two inputs are two different outputs from the same node,
-    // they need to be replaced by two completely separate data nodes)
-    auto inputs = GetSubgraphInputs(subgraph, subgraph_set);
-    auto subgraph_node = create_subgraph_node(subgraph, inputs.size());
-    subgraph_node->inputs = inputs;
-    // replug inputs of node out of subgraph to be output of the subgraph node
-    // if it was a node in the subgraph
-    DFSVisit(g.outputs,
-        [&subgraph_node, &subgraph_set, &sub_outputs_in_main](const nnvm::ObjectPtr node) {
-      if (!subgraph_set.count(node.get())) {
-        for (auto &e : node->inputs) {
-          auto it = sub_outputs_in_main.find(e);
-          if (it != sub_outputs_in_main.end()) {
-            e.node = subgraph_node;
-            e.index = it->second;
-          }
-        }
-      }
-    });
-    // replug outputs of the graph to be output of the subgraph node
-    // if it was a node in the subgraph
-    for (auto &e : g.outputs) {
-      auto it = sub_outputs_in_main.find(e);
-      if (it != sub_outputs_in_main.end()) {
-        e.node = subgraph_node;
-        e.index = it->second;
-      }
+Graph CopyAndReplaceSubgraphs(const Graph& g,
+                              const std::vector<int>& subgraph_assignment,
+                              const int num_subgraphs,
+                              FCreateNode create_subgraph_node) {
+  if (num_subgraphs == 0) {
+    return g;
+  }
+
+  Graph ret;
+
+  const auto& idx = g.indexed_graph();
+
+  CHECK_EQ(idx.num_nodes(), subgraph_assignment.size()) <<
+    "Every node in the graph needs to be included in subgraph assignment.";
+
+  std::vector<nnvm::ObjectPtr> new_nodes;
+  new_nodes.reserve(idx.num_nodes());
+  struct SubgraphInfo {
+    nnvm::Graph graph;
+    nnvm::ObjectPtr subgraph_node;
+    std::vector<EntryInfo> outputs;
+    std::vector<EntryInfo> inputs;
+    std::vector<nnvm::ObjectPtr> input_nodes;
+  };
+
+  std::vector<SubgraphInfo> subgraphs(num_subgraphs);
+
+  for (auto& info : subgraphs) {
+    info.subgraph_node = nnvm::Node::Create();
+  }
+
+  for (size_t i = 0; i < idx.num_nodes(); ++i) {
+    // First copy the node, it will be used
+    // either in the new graph or inside a
+    // subgraph. Variables are not copied.
+    if (idx[i].source->op() != nullptr) {
+      new_nodes.emplace_back(nnvm::Node::Create());
+      auto& node_copy = new_nodes.back();
+      node_copy->attrs = idx[i].source->attrs;
+      node_copy->info = idx[i].source->info;
+    } else {
+      new_nodes.emplace_back(idx[i].weak_ref.lock());
+      continue;
     }
-    // move control dependencies between nodes of the subgraph and out of the subgraph
-    // to a dependencies between the subgraph node and the nodes out of the subgraph
-    DFSVisit(subgraph.outputs, [&subgraph_node, &subgraph_set](const nnvm::ObjectPtr& node) {
-      if (subgraph_set.count(node.get())) {
-        auto it = node->control_deps.begin();
-        static auto& is_fusion = Op::GetAttr<exec::TIsFusionHelper>("TIsFusionHelper");
-        std::vector<nnvm::ObjectPtr> new_control_deps;
-        // Use the first control dependency to get the inferattr helper
-        if (it != node->control_deps.end()) {
-          if (subgraph_set.count(it->get())) {
-            new_control_deps.push_back(*it);
+    auto& node_copy = new_nodes.back();
+    const int subgraph_id = subgraph_assignment[i];
+    if (subgraph_id != -1) {
+      auto& info = subgraphs[subgraph_id];
+      for (const auto& input : idx[i].inputs) {
+        const int their_subgraph = subgraph_assignment[input.node_id];
+        if (their_subgraph == subgraph_id) {
+          node_copy->inputs.emplace_back(new_nodes[input.node_id],
+                                         input.index,
+                                         input.version);
+        } else {
+          int input_num;
+          int output_num;
+          if (their_subgraph == -1) {
+            input_num = SetInsert({static_cast<int>(input.node_id),
+                                   static_cast<int>(input.index)}, &(info.inputs));
           } else {
-            if ((*it)->is_variable() || !is_fusion.get((*it)->op(), false)) {
-              uint32_t node_id = subgraph_node->control_deps.size();
-              subgraph_node->control_deps.push_back(*it);
-              auto helper_node = op::MakeNode("_FusedOpOutHelper",
-                                              "FusedOp_" + node->attrs.name + "_outhelper",
-                                              nullptr,
-                                              nullptr,
-                                              nullptr);
-              helper_node->attrs.parsed =
-                FusedOpHelperParamPtr(new FusedOpHelperParam(
-                      nnvm::get<FusedOpPtr>(subgraph_node->attrs.parsed),
-                      node_id));
-              new_control_deps.push_back(helper_node);
+            auto& their_subgraph_info = subgraphs[their_subgraph];
+            output_num = SetInsert({static_cast<int>(input.node_id),
+                                    static_cast<int>(input.index)},
+                                   &(their_subgraph_info.outputs));
+            input_num = SetInsert({static_cast<int>(idx.num_nodes() + their_subgraph),
+                                   output_num},
+                                  &(info.inputs));
+          }
+          if (static_cast<size_t>(input_num) == info.input_nodes.size()) {
+            info.input_nodes.emplace_back(nnvm::Node::Create());
+            info.input_nodes.back()->attrs.name = "input_" + std::to_string(input_num);
+            if (their_subgraph == -1) {
+              info.subgraph_node->inputs.emplace_back(new_nodes[input.node_id],
+                                                      input.index,
+                                                      input.version);
             } else {
-              new_control_deps.push_back(*it);
+              info.subgraph_node->inputs.emplace_back(subgraphs[their_subgraph].subgraph_node,
+                                                      output_num,
+                                                      input.version);
             }
           }
-          ++it;
+          node_copy->inputs.emplace_back(info.input_nodes[input_num], 0, 0);
         }
-        node->control_deps = new_control_deps;
       }
-    });
-
-    std::ostringstream name_oss;
-    // the name of the new node will be the concatenation of all the node names in the subgraph
-    DFSVisit(subgraph.outputs, [&name_oss](const nnvm::ObjectPtr n) {
-      if (n->op() != nullptr) {
-        name_oss << n->op()->name << "_";
+    } else {
+      for (const auto& input : idx[i].inputs) {
+        const int subgraph_id = subgraph_assignment[input.node_id];
+        if (subgraph_id == -1) {
+          node_copy->inputs.emplace_back(new_nodes[input.node_id],
+                                         input.index,
+                                         input.version);
+        } else {
+          auto& info = subgraphs[subgraph_id];
+          const int output_num = SetInsert({static_cast<int>(input.node_id),
+                                            static_cast<int>(input.index)},
+                                           &(info.outputs));
+          node_copy->inputs.emplace_back(info.subgraph_node,
+                                         output_num,
+                                         input.version);
+        }
       }
-    });
-    auto subgraph_name = name_oss.str();
-    subgraph_name.pop_back();
-    subgraph_node->attrs.name = subgraph_name;
+    }
 
-    const auto& index = subgraph.indexed_graph();
-    DFSVisit(g.outputs, [&subgraph_node, &subgraph_set, &index](const nnvm::ObjectPtr& node) {
-      for (auto &e : node->control_deps) {
-        if (subgraph_set.count(e.get())) {
-          uint32_t node_id = index.node_id(e.get());
-          auto helper_node = op::MakeNode("_FusedOpHelper",
-                                          subgraph_node->attrs.name + "_"
-                                          + node->attrs.name + "_helper",
-                                          nullptr,
-                                          nullptr,
-                                          nullptr);
-          helper_node->attrs.parsed =
-            FusedOpHelperParamPtr(new FusedOpHelperParam(
-                  nnvm::get<FusedOpPtr>(subgraph_node->attrs.parsed),
-                  node_id));
-          e = helper_node;
-        }
+    // Control deps
+    for (const auto& dep : idx[i].control_deps) {
+      if (subgraph_id == subgraph_assignment[dep]) {
+        node_copy->control_deps.emplace_back(new_nodes[dep]);
       }
-    });
+    }
   }
-  Graph new_graph;
-  new_graph.outputs = g.outputs;
-  return new_graph;
-}
 
-/* \brief Add nodes as inputs to the subgraph. This is used for operations
- *        which are only compatible when they are the first nodes in the
- *        subgraph.
- */
-template <typename IsCompatible>
-void AddInputsOnlyCompatible(const Graph &g,
-                             std::vector<std::unordered_set<nnvm::Node*> >* subsets,
-                             IsCompatible is_compatible) {
-  std::unordered_map<nnvm::Node*, uint32_t> node2setidx;
-  size_t subgraphs_fullsize = 0;
-  for (auto& s : *subsets) {
-    subgraphs_fullsize += s.size();
-  }
-  node2setidx.reserve(subgraphs_fullsize);
-  for (size_t i = 0; i < subsets->size(); ++i) {
-    for (auto& n : (*subsets)[i]) {
-      node2setidx.insert({n, i});
+  ret.outputs.reserve(idx.outputs().size());
+  for (const auto& output : idx.outputs()) {
+    const int subgraph_id = subgraph_assignment[output.node_id];
+    if (subgraph_id == -1) {
+      ret.outputs.emplace_back(new_nodes[output.node_id],
+                               output.index,
+                               output.version);
+    } else {
+      const int output_num = SetInsert({static_cast<int>(output.node_id),
+                                        static_cast<int>(output.index)},
+                                       &(subgraphs[subgraph_id].outputs));
+      ret.outputs.emplace_back(subgraphs[subgraph_id].subgraph_node,
+                               output_num,
+                               output.version);
     }
   }
-  std::vector<std::vector<nnvm::Node*> > to_add(subsets->size());
-  DFSVisit(g.outputs, [&is_compatible, &node2setidx, &to_add](const nnvm::ObjectPtr& n) {
-    const auto& it = node2setidx.find(n.get());
-    if (it != node2setidx.end()) {
-      for (auto& e : n->inputs) {
-        if (is_compatible(e.node.get()))
-          to_add[it->second].push_back(e.node.get());
-      }
+
+  for (auto& info : subgraphs) {
+    info.graph.outputs.reserve(info.outputs.size());
+    for (const auto& entry_info : info.outputs) {
+      info.graph.outputs.emplace_back(new_nodes[entry_info.source_node],
+                                      entry_info.index,
+                                      0);
     }
-  });
+    create_subgraph_node(info.graph, info.inputs.size(), info.subgraph_node.get());
+  }
 
-  // Avoid duplicating the node that is input of two subsets
-  std::unordered_set<nnvm::Node*> added;
-  for (size_t i = 0; i < subsets->size(); ++i) {
-    std::vector<nnvm::NodeEntry> heads;
-    for (auto n : subsets->at(i)) {
-      for (auto e : n->inputs) {
-        if (!subsets->at(i).count(e.node.get()))
-          heads.push_back(e);
+  for (size_t i = 0; i < idx.num_nodes(); ++i) {
+    // Add _FusedOpHelper nodes
+    const int subgraph_id = subgraph_assignment[i];
+    for (size_t dep_num = 0; dep_num < idx[i].control_deps.size(); ++dep_num) {
+      const auto& dep = idx[i].control_deps[dep_num];
+      const int their_subgraph_id = subgraph_assignment[dep];
+      if (subgraph_assignment[i] != -1 && subgraph_assignment[dep] == -1) {
+        // Not in any subgraph, use FusedOpOutHelper
+        auto& info = subgraphs[subgraph_id];
+        size_t node_id = info.subgraph_node->control_deps.size();
+        info.subgraph_node->control_deps.emplace_back(new_nodes[dep]);
+        auto helper_node = op::MakeNode("_FusedOpOutHelper",
+                                        "FusedOp_" + new_nodes[i]->attrs.name + "_outhelper",
+                                        nullptr,
+                                        nullptr,
+                                        nullptr);
+        helper_node->attrs.parsed =
+          FusedOpHelperParamPtr(new FusedOpHelperParam(
+                nnvm::get<FusedOpPtr>(info.subgraph_node->attrs.parsed),
+                node_id));
+        new_nodes[i]->control_deps.insert(new_nodes[i]->control_deps.begin() + dep_num,
+                                          std::move(helper_node));
+      } else if (their_subgraph_id != subgraph_id &&

Review comment:
       Can `subgraph_id` be -1 here? If it can, should we add a check for that here?

##########
File path: src/imperative/pointwise_fusion_pass.cc
##########
@@ -57,281 +58,321 @@ void WarnFusionNotSupported() {
 #if MXNET_USE_CUDA
 
 namespace {
-  bool IsFusionCompatible(nnvm::Node* n) {
-    using namespace mxnet::fusion;
-    if (n->op() == nullptr)
-      return false;
-    std::string op_name = n->op()->name;
-    if (ops_desc.count(op_name))
-      return true;
-    if (slice_ops.count(op_name))
-      return false;
-    if (std::find(variable_io_ops.begin(),
-                  variable_io_ops.end(),
-                  op_name) !=
-        variable_io_ops.end())
-      return true;
-    if (op_name == "LeakyReLU") {
-        std::string act_type = n->attrs.dict.at("act_type");
-        if (LeakyReLU_ops.count(act_type))
-          return true;
-        else
-          return false;
-    }
-    if (op_name == "_backward_LeakyReLU") {
-        std::string act_type = n->attrs.dict.at("act_type");
-        if (LeakyReLU_bwd_ops.count(act_type))
-          return true;
-        else
-          return false;
-    }
+
+bool IsFusionCompatible(const nnvm::Node* n) {
+  using namespace mxnet::fusion;
+  if (n->op() == nullptr)
+    return false;
+  const std::string& op_name = n->op()->name;
+  if (ops_desc.count(op_name))
+    return true;
+  if (slice_ops.count(op_name))
     return false;
+  if (std::find(variable_io_ops.begin(),
+                variable_io_ops.end(),
+                op_name) !=
+      variable_io_ops.end())
+    return true;
+  if (op_name == "LeakyReLU") {
+      std::string act_type = n->attrs.dict.at("act_type");
+      if (LeakyReLU_ops.count(act_type))
+        return true;
+      else
+        return false;
   }
+  if (op_name == "_backward_LeakyReLU") {
+      std::string act_type = n->attrs.dict.at("act_type");
+      if (LeakyReLU_bwd_ops.count(act_type))
+        return true;
+      else
+        return false;
+  }
+  return false;
+}
 
-  bool IsInputsOnlyCompatible(nnvm::Node* n) {
-    using namespace mxnet::fusion;
-    if (n->op() == nullptr)
-      return false;
-    std::string op_name = n->op()->name;
-    if (slice_ops.count(op_name)) {
-      if (op_name == "slice") {
-        // slice with non-default step attribute is not supported
-        // currently
-        if (n->attrs.dict.count("step") &&
-            !(n->attrs.dict.at("step") == "()" ||
-              n->attrs.dict.at("step") == "[]")) {
-          return false;
-        }
+bool IsInputsOnlyCompatible(const nnvm::Node* n) {
+  using namespace mxnet::fusion;
+  if (n->op() == nullptr)
+    return false;
+  const std::string& op_name = n->op()->name;
+  if (slice_ops.count(op_name)) {
+    if (op_name == "slice") {
+      // slice with non-default step attribute is not supported
+      // currently
+      if (n->attrs.dict.count("step") &&
+          !(n->attrs.dict.at("step") == "()" ||
+            n->attrs.dict.at("step") == "[]")) {
+        return false;
       }
-      return true;
     }
-    return false;
+    return true;
   }
+  return false;
+}
+
+void CreateSubgraphNode(const nnvm::Graph& subgraph,
+                        size_t inputs_size,
+                        nnvm::Node* subgraph_node) {
+  static const Op* fused_op_ptr = Op::Get("_FusedOp");
+  subgraph_node->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>());
+  subgraph_node->attrs.subgraphs.back()->outputs = subgraph.outputs;
+  subgraph_node->attrs.dict["num_inputs"] = std::to_string(inputs_size);
+  subgraph_node->attrs.dict["num_outputs"] = std::to_string(subgraph.outputs.size());
+  subgraph_node->attrs.op = fused_op_ptr;
+  subgraph_node->op()->attr_parser(&(subgraph_node->attrs));
+}
+
+struct EntryInfo {
+  int source_node;
+  int index;
+};
 
-  nnvm::ObjectPtr CreateSubgraphNode(const Graph& subgraph, size_t inputs_size) {
-    nnvm::Symbol subgraph_sym;
-    auto node = nnvm::Node::Create();
-    subgraph_sym.outputs = subgraph.outputs;
-    node->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>(subgraph_sym));
-    node->attrs.name = "FusedOp";
-    node->attrs.dict["num_inputs"] = std::to_string(inputs_size);
-    node->attrs.dict["num_outputs"] = std::to_string(subgraph.outputs.size());
-    node->attrs.op = Op::Get("_FusedOp");
-    node->op()->attr_parser(&(node->attrs));
-    return node;
+inline int SetInsert(const EntryInfo& new_elem,
+                     std::vector<EntryInfo>* elements) {
+  for (size_t i = 0; i < elements->size(); ++i) {
+    if ((new_elem.source_node == elements->at(i).source_node) &&
+        (new_elem.index == elements->at(i).index)) {
+      return i;
+    }
   }
+  elements->emplace_back(new_elem);
+  return elements->size() - 1;
+}
+
 }  // namespace
 
-/*!
- * \brief Replace a set of nodes by a subgraph node.
- *        This function is used specifically in pointwise fusion.
+/* \brief Create (if necessary) copy of the graph, replacing subgraphs with
+ *        FusedOps. If there are no subgraphs to be replaced, the
+ *        original graph is returned.
+ * \param g original graph.
+ * \param subgraph_assignment assignment of nodes in g's IndexedGraphs to
+ *                            subgraphs. Values from -1 to num_subgraphs - 1
+ *                            are allowed, -1 means that the node is not in a
+ *                            subgraph.
+ * \param num_subgraphs number of subgraphs.
+ * \param create_subgraph_node function used to prepare the subgraph node.
  */
 template<typename FCreateNode>
-Graph ReplaceSubgraphsPointwise(Graph&& g, const std::vector<NodeRawPtrSet>& subgraph_sets,
-                                FCreateNode create_subgraph_node) {
-  for (auto subgraph_set : subgraph_sets) {
-    // Create MXNet subgraph
-    Graph subgraph;
-    const auto sub_outputs_in_main = GetSubgraphOutputs(g, subgraph_set);
-    subgraph.outputs.resize(sub_outputs_in_main.size());
-    for (auto p : sub_outputs_in_main) {
-      subgraph.outputs[p.second] = p.first;
-    }
-    // To generate a subgraph an input has to be replaced by data node (no op)
-    // and it has to be agnostic to the node from which it's an output
-    // (For example, even if two inputs are two different outputs from the same node,
-    // they need to be replaced by two completely separate data nodes)
-    auto inputs = GetSubgraphInputs(subgraph, subgraph_set);
-    auto subgraph_node = create_subgraph_node(subgraph, inputs.size());
-    subgraph_node->inputs = inputs;
-    // replug inputs of node out of subgraph to be output of the subgraph node
-    // if it was a node in the subgraph
-    DFSVisit(g.outputs,
-        [&subgraph_node, &subgraph_set, &sub_outputs_in_main](const nnvm::ObjectPtr node) {
-      if (!subgraph_set.count(node.get())) {
-        for (auto &e : node->inputs) {
-          auto it = sub_outputs_in_main.find(e);
-          if (it != sub_outputs_in_main.end()) {
-            e.node = subgraph_node;
-            e.index = it->second;
-          }
-        }
-      }
-    });
-    // replug outputs of the graph to be output of the subgraph node
-    // if it was a node in the subgraph
-    for (auto &e : g.outputs) {
-      auto it = sub_outputs_in_main.find(e);
-      if (it != sub_outputs_in_main.end()) {
-        e.node = subgraph_node;
-        e.index = it->second;
-      }
+Graph CopyAndReplaceSubgraphs(const Graph& g,
+                              const std::vector<int>& subgraph_assignment,
+                              const int num_subgraphs,
+                              FCreateNode create_subgraph_node) {
+  if (num_subgraphs == 0) {
+    return g;
+  }
+
+  Graph ret;
+
+  const auto& idx = g.indexed_graph();
+
+  CHECK_EQ(idx.num_nodes(), subgraph_assignment.size()) <<
+    "Every node in the graph needs to be included in subgraph assignment.";
+
+  std::vector<nnvm::ObjectPtr> new_nodes;
+  new_nodes.reserve(idx.num_nodes());
+  struct SubgraphInfo {
+    nnvm::Graph graph;
+    nnvm::ObjectPtr subgraph_node;
+    std::vector<EntryInfo> outputs;
+    std::vector<EntryInfo> inputs;
+    std::vector<nnvm::ObjectPtr> input_nodes;
+  };
+
+  std::vector<SubgraphInfo> subgraphs(num_subgraphs);
+
+  for (auto& info : subgraphs) {
+    info.subgraph_node = nnvm::Node::Create();
+  }
+
+  for (size_t i = 0; i < idx.num_nodes(); ++i) {
+    // First copy the node, it will be used
+    // either in the new graph or inside a
+    // subgraph. Variables are not copied.
+    if (idx[i].source->op() != nullptr) {
+      new_nodes.emplace_back(nnvm::Node::Create());
+      auto& node_copy = new_nodes.back();
+      node_copy->attrs = idx[i].source->attrs;
+      node_copy->info = idx[i].source->info;
+    } else {
+      new_nodes.emplace_back(idx[i].weak_ref.lock());
+      continue;
     }
-    // move control dependencies between nodes of the subgraph and out of the subgraph
-    // to a dependencies between the subgraph node and the nodes out of the subgraph
-    DFSVisit(subgraph.outputs, [&subgraph_node, &subgraph_set](const nnvm::ObjectPtr& node) {
-      if (subgraph_set.count(node.get())) {
-        auto it = node->control_deps.begin();
-        static auto& is_fusion = Op::GetAttr<exec::TIsFusionHelper>("TIsFusionHelper");
-        std::vector<nnvm::ObjectPtr> new_control_deps;
-        // Use the first control dependency to get the inferattr helper
-        if (it != node->control_deps.end()) {
-          if (subgraph_set.count(it->get())) {
-            new_control_deps.push_back(*it);
+    auto& node_copy = new_nodes.back();
+    const int subgraph_id = subgraph_assignment[i];
+    if (subgraph_id != -1) {
+      auto& info = subgraphs[subgraph_id];
+      for (const auto& input : idx[i].inputs) {
+        const int their_subgraph = subgraph_assignment[input.node_id];
+        if (their_subgraph == subgraph_id) {
+          node_copy->inputs.emplace_back(new_nodes[input.node_id],
+                                         input.index,
+                                         input.version);
+        } else {
+          int input_num;
+          int output_num;
+          if (their_subgraph == -1) {
+            input_num = SetInsert({static_cast<int>(input.node_id),
+                                   static_cast<int>(input.index)}, &(info.inputs));
           } else {
-            if ((*it)->is_variable() || !is_fusion.get((*it)->op(), false)) {
-              uint32_t node_id = subgraph_node->control_deps.size();
-              subgraph_node->control_deps.push_back(*it);
-              auto helper_node = op::MakeNode("_FusedOpOutHelper",
-                                              "FusedOp_" + node->attrs.name + "_outhelper",
-                                              nullptr,
-                                              nullptr,
-                                              nullptr);
-              helper_node->attrs.parsed =
-                FusedOpHelperParamPtr(new FusedOpHelperParam(
-                      nnvm::get<FusedOpPtr>(subgraph_node->attrs.parsed),
-                      node_id));
-              new_control_deps.push_back(helper_node);
+            auto& their_subgraph_info = subgraphs[their_subgraph];
+            output_num = SetInsert({static_cast<int>(input.node_id),
+                                    static_cast<int>(input.index)},
+                                   &(their_subgraph_info.outputs));
+            input_num = SetInsert({static_cast<int>(idx.num_nodes() + their_subgraph),
+                                   output_num},
+                                  &(info.inputs));
+          }
+          if (static_cast<size_t>(input_num) == info.input_nodes.size()) {
+            info.input_nodes.emplace_back(nnvm::Node::Create());
+            info.input_nodes.back()->attrs.name = "input_" + std::to_string(input_num);
+            if (their_subgraph == -1) {
+              info.subgraph_node->inputs.emplace_back(new_nodes[input.node_id],
+                                                      input.index,
+                                                      input.version);
             } else {
-              new_control_deps.push_back(*it);
+              info.subgraph_node->inputs.emplace_back(subgraphs[their_subgraph].subgraph_node,
+                                                      output_num,
+                                                      input.version);
             }
           }
-          ++it;
+          node_copy->inputs.emplace_back(info.input_nodes[input_num], 0, 0);
         }
-        node->control_deps = new_control_deps;
       }
-    });
-
-    std::ostringstream name_oss;
-    // the name of the new node will be the concatenation of all the node names in the subgraph
-    DFSVisit(subgraph.outputs, [&name_oss](const nnvm::ObjectPtr n) {
-      if (n->op() != nullptr) {
-        name_oss << n->op()->name << "_";
+    } else {
+      for (const auto& input : idx[i].inputs) {
+        const int subgraph_id = subgraph_assignment[input.node_id];
+        if (subgraph_id == -1) {
+          node_copy->inputs.emplace_back(new_nodes[input.node_id],
+                                         input.index,
+                                         input.version);
+        } else {
+          auto& info = subgraphs[subgraph_id];
+          const int output_num = SetInsert({static_cast<int>(input.node_id),
+                                            static_cast<int>(input.index)},
+                                           &(info.outputs));
+          node_copy->inputs.emplace_back(info.subgraph_node,
+                                         output_num,
+                                         input.version);
+        }
       }
-    });
-    auto subgraph_name = name_oss.str();
-    subgraph_name.pop_back();
-    subgraph_node->attrs.name = subgraph_name;
+    }
 
-    const auto& index = subgraph.indexed_graph();
-    DFSVisit(g.outputs, [&subgraph_node, &subgraph_set, &index](const nnvm::ObjectPtr& node) {
-      for (auto &e : node->control_deps) {
-        if (subgraph_set.count(e.get())) {
-          uint32_t node_id = index.node_id(e.get());
-          auto helper_node = op::MakeNode("_FusedOpHelper",
-                                          subgraph_node->attrs.name + "_"
-                                          + node->attrs.name + "_helper",
-                                          nullptr,
-                                          nullptr,
-                                          nullptr);
-          helper_node->attrs.parsed =
-            FusedOpHelperParamPtr(new FusedOpHelperParam(
-                  nnvm::get<FusedOpPtr>(subgraph_node->attrs.parsed),
-                  node_id));
-          e = helper_node;
-        }
+    // Control deps
+    for (const auto& dep : idx[i].control_deps) {
+      if (subgraph_id == subgraph_assignment[dep]) {
+        node_copy->control_deps.emplace_back(new_nodes[dep]);
       }
-    });
+    }
   }
-  Graph new_graph;
-  new_graph.outputs = g.outputs;
-  return new_graph;
-}
 
-/* \brief Add nodes as inputs to the subgraph. This is used for operations
- *        which are only compatible when they are the first nodes in the
- *        subgraph.
- */
-template <typename IsCompatible>
-void AddInputsOnlyCompatible(const Graph &g,
-                             std::vector<std::unordered_set<nnvm::Node*> >* subsets,
-                             IsCompatible is_compatible) {
-  std::unordered_map<nnvm::Node*, uint32_t> node2setidx;
-  size_t subgraphs_fullsize = 0;
-  for (auto& s : *subsets) {
-    subgraphs_fullsize += s.size();
-  }
-  node2setidx.reserve(subgraphs_fullsize);
-  for (size_t i = 0; i < subsets->size(); ++i) {
-    for (auto& n : (*subsets)[i]) {
-      node2setidx.insert({n, i});
+  ret.outputs.reserve(idx.outputs().size());
+  for (const auto& output : idx.outputs()) {
+    const int subgraph_id = subgraph_assignment[output.node_id];
+    if (subgraph_id == -1) {
+      ret.outputs.emplace_back(new_nodes[output.node_id],
+                               output.index,
+                               output.version);
+    } else {
+      const int output_num = SetInsert({static_cast<int>(output.node_id),
+                                        static_cast<int>(output.index)},
+                                       &(subgraphs[subgraph_id].outputs));
+      ret.outputs.emplace_back(subgraphs[subgraph_id].subgraph_node,
+                               output_num,
+                               output.version);
     }
   }
-  std::vector<std::vector<nnvm::Node*> > to_add(subsets->size());
-  DFSVisit(g.outputs, [&is_compatible, &node2setidx, &to_add](const nnvm::ObjectPtr& n) {
-    const auto& it = node2setidx.find(n.get());
-    if (it != node2setidx.end()) {
-      for (auto& e : n->inputs) {
-        if (is_compatible(e.node.get()))
-          to_add[it->second].push_back(e.node.get());
-      }
+
+  for (auto& info : subgraphs) {
+    info.graph.outputs.reserve(info.outputs.size());
+    for (const auto& entry_info : info.outputs) {
+      info.graph.outputs.emplace_back(new_nodes[entry_info.source_node],
+                                      entry_info.index,
+                                      0);
     }
-  });
+    create_subgraph_node(info.graph, info.inputs.size(), info.subgraph_node.get());
+  }
 
-  // Avoid duplicating the node that is input of two subsets
-  std::unordered_set<nnvm::Node*> added;
-  for (size_t i = 0; i < subsets->size(); ++i) {
-    std::vector<nnvm::NodeEntry> heads;
-    for (auto n : subsets->at(i)) {
-      for (auto e : n->inputs) {
-        if (!subsets->at(i).count(e.node.get()))
-          heads.push_back(e);
+  for (size_t i = 0; i < idx.num_nodes(); ++i) {
+    // Add _FusedOpHelper nodes
+    const int subgraph_id = subgraph_assignment[i];
+    for (size_t dep_num = 0; dep_num < idx[i].control_deps.size(); ++dep_num) {
+      const auto& dep = idx[i].control_deps[dep_num];
+      const int their_subgraph_id = subgraph_assignment[dep];
+      if (subgraph_assignment[i] != -1 && subgraph_assignment[dep] == -1) {

Review comment:
       Could we use the already computed `subgraph_id` and `their_subgraph_id` here instead of indexing `subgraph_assignment` again?
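
For illustration only, the quoted condition can be expressed with the two locals that the hunk computes just above it (`subgraph_id` is `subgraph_assignment[i]` and `their_subgraph_id` is `subgraph_assignment[dep]`). This is a sketch of the suggestion, not code from the PR; `NeedsOutHelper` is a hypothetical name introduced here.

```cpp
// Hypothetical helper (not part of the PR) mirroring the quoted condition.
// A value of -1 marks a node that is not assigned to any subgraph.
// Returns true when the current node lives inside a subgraph while the
// control dependency it points at lives outside of every subgraph.
inline bool NeedsOutHelper(int subgraph_id, int their_subgraph_id) {
  return subgraph_id != -1 && their_subgraph_id == -1;
}
```

Because both locals are plain copies of the corresponding `subgraph_assignment` entries, the behavior should be the same as in the quoted line; only the readability changes.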




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-mxnet] mxnet-bot commented on pull request #19269: Faster pointwise fusion graph pass

Posted by GitBox <gi...@apache.org>.
mxnet-bot commented on pull request #19269:
URL: https://github.com/apache/incubator-mxnet/pull/19269#issuecomment-707996746


   Jenkins CI successfully triggered : [unix-cpu, unix-gpu]





[GitHub] [incubator-mxnet] ptrendx edited a comment on pull request #19269: Faster pointwise fusion graph pass

Posted by GitBox <gi...@apache.org>.
ptrendx edited a comment on pull request #19269:
URL: https://github.com/apache/incubator-mxnet/pull/19269#issuecomment-706450846


   I fixed the problem with the cycle: every time I add a new node to a subset, I check whether its exclusion set is fully included in the exclusion set of the future fusion node. If so, I update all descendants of the fusion node (with the assumption that fusion subsets are typically not very big and their nodes are quite close together in the indexed graph, so the cost of that is not very high). Still, the check is quite expensive compared to the rest of the algorithm, and the cost of the graph pass for XLNet is now ~10-11 ms (still a huge improvement compared to the previous algorithm).
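
The inclusion check mentioned above can be sketched roughly as follows. This is a minimal illustration under stated assumptions, not the PR's implementation: exclusion sets are modeled here as sorted vectors of node ids (the PR appears to use an interval-based representation), and `ExclusionSet`, `IsIncluded` and `Merge` are hypothetical names. What happens after the check (updating descendants of the fusion node) is deliberately not modeled.

```cpp
#include <algorithm>
#include <iterator>
#include <vector>

// Assumption for this sketch: an exclusion set is a sorted vector of
// unique node ids.
using ExclusionSet = std::vector<int>;

// Returns true when every id in node_set is already present in fused_set.
// Both inputs must be sorted in ascending order, as std::includes requires.
inline bool IsIncluded(const ExclusionSet& node_set,
                       const ExclusionSet& fused_set) {
  return std::includes(fused_set.begin(), fused_set.end(),
                       node_set.begin(), node_set.end());
}

// Returns the sorted union of both sets, i.e. the exclusion set a fused
// node would carry after absorbing a node with exclusion set b.
inline ExclusionSet Merge(const ExclusionSet& a, const ExclusionSet& b) {
  ExclusionSet out;
  out.reserve(a.size() + b.size());
  std::set_union(a.begin(), a.end(), b.begin(), b.end(),
                 std::back_inserter(out));
  return out;
}
```

With sorted inputs both operations are linear in the combined size of the two sets, which is one reason a sorted or interval-based layout is a plausible choice when the check runs for every node added to a subset.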





[GitHub] [incubator-mxnet] mseth10 commented on pull request #19269: Faster pointwise fusion graph pass

Posted by GitBox <gi...@apache.org>.
mseth10 commented on pull request #19269:
URL: https://github.com/apache/incubator-mxnet/pull/19269#issuecomment-707054352


   Thanks for the optimization, @ptrendx. A couple of questions based on the PR description:
   > Because subsets can merge (...), the mapping is maintained to know which other subsets the current subset is merged with.
   
   1. What do you mean by subsets? Is it the same as subgraphs?
   2. What mapping are you referring to here?
   
   





[GitHub] [incubator-mxnet] ptrendx commented on pull request #19269: Faster pointwise fusion graph pass

Posted by GitBox <gi...@apache.org>.
ptrendx commented on pull request #19269:
URL: https://github.com/apache/incubator-mxnet/pull/19269#issuecomment-706450846


   I fixed the problem with the cycle: every time I add a new node to a subset, I check whether its exclusion set is fully included in the exclusion set of the future fusion node. If so, I update all descendants of the fusion node (with the assumption that fusion subsets are typically not very big and their nodes are quite close together in the indexed graph, so the cost of that is not very high). Still, the check is quite expensive compared to the rest of the algorithm, and the cost of the graph pass is now ~10-11 ms (still a huge improvement compared to the previous algorithm).


----------------------------------------------------------------



[GitHub] [incubator-mxnet] ptrendx commented on a change in pull request #19269: Faster pointwise fusion graph pass

Posted by GitBox <gi...@apache.org>.
ptrendx commented on a change in pull request #19269:
URL: https://github.com/apache/incubator-mxnet/pull/19269#discussion_r506690876



##########
File path: src/imperative/simple_partition_pass.cc
##########
@@ -0,0 +1,292 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2020 by Contributors
+ * \file simple_partition_pass.cc
+ * \brief Utilities used in simple partition pass
+ * \author Przemyslaw Tredak
+ */
+
+#include "./simple_partition_pass.h"
+#include <memory>
+#include <utility>
+
+namespace mxnet {
+namespace exec {
+
+namespace detail {
+
+void PrintSets(const IntervalVec* const sets_p) {

Review comment:
       Sure, it was used for debugging.




----------------------------------------------------------------



[GitHub] [incubator-mxnet] mxnet-bot commented on pull request #19269: Faster pointwise fusion graph pass

Posted by GitBox <gi...@apache.org>.
mxnet-bot commented on pull request #19269:
URL: https://github.com/apache/incubator-mxnet/pull/19269#issuecomment-702449495


   Hey @ptrendx, thanks for submitting the PR.
   All tests are already queued to run once. If tests fail, you can trigger one or more tests again with the following commands: 
   - To trigger all jobs: @mxnet-bot run ci [all] 
   - To trigger specific jobs: @mxnet-bot run ci [job1, job2] 
   *** 
   **CI supported jobs**: [website, sanity, windows-cpu, edge, unix-gpu, windows-gpu, centos-gpu, miscellaneous, unix-cpu, centos-cpu, clang]
   *** 
   _Note_: 
    Only the following 3 categories can trigger CI: PR Author, MXNet Committer, Jenkins Admin.
   All CI tests must pass before the PR can be merged. 
   


----------------------------------------------------------------



[GitHub] [incubator-mxnet] mseth10 commented on pull request #19269: Faster pointwise fusion graph pass

Posted by GitBox <gi...@apache.org>.
mseth10 commented on pull request #19269:
URL: https://github.com/apache/incubator-mxnet/pull/19269#issuecomment-707066337


   @ptrendx Regarding this edge case https://github.com/apache/incubator-mxnet/pull/19269#issuecomment-704372201, considering **a0 -> xa -> b0 -> b1 -> xb -> a1** as the topological graph, before reaching node **xb** we should already have combined **b0** and **b1** into subset **b** and updated the exclusion set of **b** to contain **a0**. If that's the case, **a0** should automatically be included in the exclusion set for **xb**. Why then do we need to "update all descendants of the fusion node"?
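
To make the concern behind this edge case concrete, here is a small sketch of why fusing **a0** with **a1** would be a problem. It assumes the arrows above denote direct edges of a simple chain, which the linked comment (not reproduced here) may define differently. The function `ContractionHasCycle` and the node numbering are hypothetical and are not part of the PR.

```cpp
#include <cassert>
#include <queue>
#include <utility>
#include <vector>

// Returns true if contracting every group of nodes into a single node
// produces a cycle. `group[i]` is the group id (0..num_groups-1) of node i,
// `edges` lists directed edges (from, to) between the original nodes.
bool ContractionHasCycle(const std::vector<int>& group,
                         int num_groups,
                         const std::vector<std::pair<int, int>>& edges) {
  std::vector<std::vector<int>> out(num_groups);
  std::vector<int> in_degree(num_groups, 0);
  for (const auto& e : edges) {
    const int from = group[e.first];
    const int to = group[e.second];
    assert(from >= 0 && from < num_groups);
    assert(to >= 0 && to < num_groups);
    if (from == to) continue;  // edges inside one group disappear
    out[from].push_back(to);
    ++in_degree[to];
  }
  // Kahn's algorithm: the graph is acyclic iff every node can be removed.
  std::queue<int> ready;
  for (int g = 0; g < num_groups; ++g) {
    if (in_degree[g] == 0) ready.push(g);
  }
  int removed = 0;
  while (!ready.empty()) {
    const int g = ready.front();
    ready.pop();
    ++removed;
    for (int next : out[g]) {
      if (--in_degree[next] == 0) ready.push(next);
    }
  }
  return removed != num_groups;
}
```

Numbering the nodes a0=0, xa=1, b0=2, b1=3, xb=4, a1=5 with edges `{0,1}, {1,2}, {2,3}, {3,4}, {4,5}`, the grouping `{0, 1, 2, 2, 3, 4}` (only **b0** and **b1** fused) stays acyclic, while `{0, 1, 2, 2, 3, 0}` (also fusing **a0** with **a1**) yields a cycle through **xa**, **b** and **xb**. The exclusion sets exist to rule out the second grouping without running a check like this on every candidate merge.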


----------------------------------------------------------------



[GitHub] [incubator-mxnet] mseth10 commented on a change in pull request #19269: Faster pointwise fusion graph pass

Posted by GitBox <gi...@apache.org>.
mseth10 commented on a change in pull request #19269:
URL: https://github.com/apache/incubator-mxnet/pull/19269#discussion_r506933734



##########
File path: src/imperative/pointwise_fusion_pass.cc
##########
@@ -57,281 +58,321 @@ void WarnFusionNotSupported() {
 #if MXNET_USE_CUDA
 
 namespace {
-  bool IsFusionCompatible(nnvm::Node* n) {
-    using namespace mxnet::fusion;
-    if (n->op() == nullptr)
-      return false;
-    std::string op_name = n->op()->name;
-    if (ops_desc.count(op_name))
-      return true;
-    if (slice_ops.count(op_name))
-      return false;
-    if (std::find(variable_io_ops.begin(),
-                  variable_io_ops.end(),
-                  op_name) !=
-        variable_io_ops.end())
-      return true;
-    if (op_name == "LeakyReLU") {
-        std::string act_type = n->attrs.dict.at("act_type");
-        if (LeakyReLU_ops.count(act_type))
-          return true;
-        else
-          return false;
-    }
-    if (op_name == "_backward_LeakyReLU") {
-        std::string act_type = n->attrs.dict.at("act_type");
-        if (LeakyReLU_bwd_ops.count(act_type))
-          return true;
-        else
-          return false;
-    }
+
+bool IsFusionCompatible(const nnvm::Node* n) {
+  using namespace mxnet::fusion;
+  if (n->op() == nullptr)
+    return false;
+  const std::string& op_name = n->op()->name;
+  if (ops_desc.count(op_name))
+    return true;
+  if (slice_ops.count(op_name))
     return false;
+  if (std::find(variable_io_ops.begin(),
+                variable_io_ops.end(),
+                op_name) !=
+      variable_io_ops.end())
+    return true;
+  if (op_name == "LeakyReLU") {
+      std::string act_type = n->attrs.dict.at("act_type");
+      if (LeakyReLU_ops.count(act_type))
+        return true;
+      else
+        return false;
   }
+  if (op_name == "_backward_LeakyReLU") {
+      std::string act_type = n->attrs.dict.at("act_type");
+      if (LeakyReLU_bwd_ops.count(act_type))
+        return true;
+      else
+        return false;
+  }
+  return false;
+}
 
-  bool IsInputsOnlyCompatible(nnvm::Node* n) {
-    using namespace mxnet::fusion;
-    if (n->op() == nullptr)
-      return false;
-    std::string op_name = n->op()->name;
-    if (slice_ops.count(op_name)) {
-      if (op_name == "slice") {
-        // slice with non-default step attribute is not supported
-        // currently
-        if (n->attrs.dict.count("step") &&
-            !(n->attrs.dict.at("step") == "()" ||
-              n->attrs.dict.at("step") == "[]")) {
-          return false;
-        }
+bool IsInputsOnlyCompatible(const nnvm::Node* n) {
+  using namespace mxnet::fusion;
+  if (n->op() == nullptr)
+    return false;
+  const std::string& op_name = n->op()->name;
+  if (slice_ops.count(op_name)) {
+    if (op_name == "slice") {
+      // slice with non-default step attribute is not supported
+      // currently
+      if (n->attrs.dict.count("step") &&
+          !(n->attrs.dict.at("step") == "()" ||
+            n->attrs.dict.at("step") == "[]")) {
+        return false;
       }
-      return true;
     }
-    return false;
+    return true;
   }
+  return false;
+}
+
+void CreateSubgraphNode(const nnvm::Graph& subgraph,
+                        size_t inputs_size,
+                        nnvm::Node* subgraph_node) {
+  static const Op* fused_op_ptr = Op::Get("_FusedOp");
+  subgraph_node->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>());
+  subgraph_node->attrs.subgraphs.back()->outputs = subgraph.outputs;
+  subgraph_node->attrs.dict["num_inputs"] = std::to_string(inputs_size);
+  subgraph_node->attrs.dict["num_outputs"] = std::to_string(subgraph.outputs.size());
+  subgraph_node->attrs.op = fused_op_ptr;
+  subgraph_node->op()->attr_parser(&(subgraph_node->attrs));
+}
+
+struct EntryInfo {
+  int source_node;
+  int index;
+};
 
-  nnvm::ObjectPtr CreateSubgraphNode(const Graph& subgraph, size_t inputs_size) {
-    nnvm::Symbol subgraph_sym;
-    auto node = nnvm::Node::Create();
-    subgraph_sym.outputs = subgraph.outputs;
-    node->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>(subgraph_sym));
-    node->attrs.name = "FusedOp";
-    node->attrs.dict["num_inputs"] = std::to_string(inputs_size);
-    node->attrs.dict["num_outputs"] = std::to_string(subgraph.outputs.size());
-    node->attrs.op = Op::Get("_FusedOp");
-    node->op()->attr_parser(&(node->attrs));
-    return node;
+inline int SetInsert(const EntryInfo& new_elem,
+                     std::vector<EntryInfo>* elements) {
+  for (size_t i = 0; i < elements->size(); ++i) {
+    if ((new_elem.source_node == elements->at(i).source_node) &&
+        (new_elem.index == elements->at(i).index)) {
+      return i;
+    }
   }
+  elements->emplace_back(new_elem);
+  return elements->size() - 1;
+}
+
 }  // namespace
 
-/*!
- * \brief Replace a set of nodes by a subgraph node.
- *        This function is used specifically in pointwise fusion.
+/* \brief Create (if necessary) copy of the graph, replacing subgraphs with
+ *        FusedOps. If there are no subgraphs to be replaced, the
+ *        original graph is returned.
+ * \param g original graph.
+ * \param subgraph_assignment assignment of nodes in g's IndexedGraphs to
+ *                            subgraphs. Values from -1 to num_subgraphs - 1
+ *                            are allowed, -1 means that the node is not in a
+ *                            subgraph.
+ * \param num_subgraphs number of subgraphs.
+ * \param create_subgraph_node function used to prepare the subgraph node.
  */
 template<typename FCreateNode>
-Graph ReplaceSubgraphsPointwise(Graph&& g, const std::vector<NodeRawPtrSet>& subgraph_sets,
-                                FCreateNode create_subgraph_node) {
-  for (auto subgraph_set : subgraph_sets) {
-    // Create MXNet subgraph
-    Graph subgraph;
-    const auto sub_outputs_in_main = GetSubgraphOutputs(g, subgraph_set);
-    subgraph.outputs.resize(sub_outputs_in_main.size());
-    for (auto p : sub_outputs_in_main) {
-      subgraph.outputs[p.second] = p.first;
-    }
-    // To generate a subgraph an input has to be replaced by data node (no op)
-    // and it has to be agnostic to the node from which it's an output
-    // (For example, even if two inputs are two different outputs from the same node,
-    // they need to be replaced by two completely separate data nodes)
-    auto inputs = GetSubgraphInputs(subgraph, subgraph_set);
-    auto subgraph_node = create_subgraph_node(subgraph, inputs.size());
-    subgraph_node->inputs = inputs;
-    // replug inputs of node out of subgraph to be output of the subgraph node
-    // if it was a node in the subgraph
-    DFSVisit(g.outputs,
-        [&subgraph_node, &subgraph_set, &sub_outputs_in_main](const nnvm::ObjectPtr node) {
-      if (!subgraph_set.count(node.get())) {
-        for (auto &e : node->inputs) {
-          auto it = sub_outputs_in_main.find(e);
-          if (it != sub_outputs_in_main.end()) {
-            e.node = subgraph_node;
-            e.index = it->second;
-          }
-        }
-      }
-    });
-    // replug outputs of the graph to be output of the subgraph node
-    // if it was a node in the subgraph
-    for (auto &e : g.outputs) {
-      auto it = sub_outputs_in_main.find(e);
-      if (it != sub_outputs_in_main.end()) {
-        e.node = subgraph_node;
-        e.index = it->second;
-      }
+Graph CopyAndReplaceSubgraphs(const Graph& g,
+                              const std::vector<int>& subgraph_assignment,
+                              const int num_subgraphs,
+                              FCreateNode create_subgraph_node) {
+  if (num_subgraphs == 0) {
+    return g;
+  }
+
+  Graph ret;
+
+  const auto& idx = g.indexed_graph();
+
+  CHECK_EQ(idx.num_nodes(), subgraph_assignment.size()) <<
+    "Every node in the graph needs to be included in subgraph assignment.";
+
+  std::vector<nnvm::ObjectPtr> new_nodes;
+  new_nodes.reserve(idx.num_nodes());
+  struct SubgraphInfo {
+    nnvm::Graph graph;
+    nnvm::ObjectPtr subgraph_node;
+    std::vector<EntryInfo> outputs;
+    std::vector<EntryInfo> inputs;
+    std::vector<nnvm::ObjectPtr> input_nodes;
+  };
+
+  std::vector<SubgraphInfo> subgraphs(num_subgraphs);
+
+  for (auto& info : subgraphs) {
+    info.subgraph_node = nnvm::Node::Create();
+  }
+
+  for (size_t i = 0; i < idx.num_nodes(); ++i) {
+    // First copy the node, it will be used
+    // either in the new graph or inside a
+    // subgraph. Variables are not copied.
+    if (idx[i].source->op() != nullptr) {
+      new_nodes.emplace_back(nnvm::Node::Create());
+      auto& node_copy = new_nodes.back();
+      node_copy->attrs = idx[i].source->attrs;
+      node_copy->info = idx[i].source->info;
+    } else {
+      new_nodes.emplace_back(idx[i].weak_ref.lock());
+      continue;
     }
-    // move control dependencies between nodes of the subgraph and out of the subgraph
-    // to a dependencies between the subgraph node and the nodes out of the subgraph
-    DFSVisit(subgraph.outputs, [&subgraph_node, &subgraph_set](const nnvm::ObjectPtr& node) {
-      if (subgraph_set.count(node.get())) {
-        auto it = node->control_deps.begin();
-        static auto& is_fusion = Op::GetAttr<exec::TIsFusionHelper>("TIsFusionHelper");
-        std::vector<nnvm::ObjectPtr> new_control_deps;
-        // Use the first control dependency to get the inferattr helper
-        if (it != node->control_deps.end()) {
-          if (subgraph_set.count(it->get())) {
-            new_control_deps.push_back(*it);
+    auto& node_copy = new_nodes.back();
+    const int subgraph_id = subgraph_assignment[i];
+    if (subgraph_id != -1) {
+      auto& info = subgraphs[subgraph_id];
+      for (const auto& input : idx[i].inputs) {
+        const int their_subgraph = subgraph_assignment[input.node_id];
+        if (their_subgraph == subgraph_id) {
+          node_copy->inputs.emplace_back(new_nodes[input.node_id],
+                                         input.index,
+                                         input.version);
+        } else {
+          int input_num;
+          int output_num;
+          if (their_subgraph == -1) {
+            input_num = SetInsert({static_cast<int>(input.node_id),
+                                   static_cast<int>(input.index)}, &(info.inputs));
           } else {
-            if ((*it)->is_variable() || !is_fusion.get((*it)->op(), false)) {
-              uint32_t node_id = subgraph_node->control_deps.size();
-              subgraph_node->control_deps.push_back(*it);
-              auto helper_node = op::MakeNode("_FusedOpOutHelper",
-                                              "FusedOp_" + node->attrs.name + "_outhelper",
-                                              nullptr,
-                                              nullptr,
-                                              nullptr);
-              helper_node->attrs.parsed =
-                FusedOpHelperParamPtr(new FusedOpHelperParam(
-                      nnvm::get<FusedOpPtr>(subgraph_node->attrs.parsed),
-                      node_id));
-              new_control_deps.push_back(helper_node);
+            auto& their_subgraph_info = subgraphs[their_subgraph];
+            output_num = SetInsert({static_cast<int>(input.node_id),
+                                    static_cast<int>(input.index)},
+                                   &(their_subgraph_info.outputs));
+            input_num = SetInsert({static_cast<int>(idx.num_nodes() + their_subgraph),
+                                   output_num},
+                                  &(info.inputs));
+          }
+          if (static_cast<size_t>(input_num) == info.input_nodes.size()) {
+            info.input_nodes.emplace_back(nnvm::Node::Create());
+            info.input_nodes.back()->attrs.name = "input_" + std::to_string(input_num);
+            if (their_subgraph == -1) {
+              info.subgraph_node->inputs.emplace_back(new_nodes[input.node_id],
+                                                      input.index,
+                                                      input.version);
             } else {
-              new_control_deps.push_back(*it);
+              info.subgraph_node->inputs.emplace_back(subgraphs[their_subgraph].subgraph_node,
+                                                      output_num,
+                                                      input.version);
             }
           }
-          ++it;
+          node_copy->inputs.emplace_back(info.input_nodes[input_num], 0, 0);
         }
-        node->control_deps = new_control_deps;
       }
-    });
-
-    std::ostringstream name_oss;
-    // the name of the new node will be the concatenation of all the node names in the subgraph
-    DFSVisit(subgraph.outputs, [&name_oss](const nnvm::ObjectPtr n) {
-      if (n->op() != nullptr) {
-        name_oss << n->op()->name << "_";
+    } else {
+      for (const auto& input : idx[i].inputs) {
+        const int subgraph_id = subgraph_assignment[input.node_id];
+        if (subgraph_id == -1) {
+          node_copy->inputs.emplace_back(new_nodes[input.node_id],
+                                         input.index,
+                                         input.version);
+        } else {
+          auto& info = subgraphs[subgraph_id];
+          const int output_num = SetInsert({static_cast<int>(input.node_id),
+                                            static_cast<int>(input.index)},
+                                           &(info.outputs));
+          node_copy->inputs.emplace_back(info.subgraph_node,
+                                         output_num,
+                                         input.version);
+        }
       }
-    });
-    auto subgraph_name = name_oss.str();
-    subgraph_name.pop_back();
-    subgraph_node->attrs.name = subgraph_name;
+    }
 
-    const auto& index = subgraph.indexed_graph();
-    DFSVisit(g.outputs, [&subgraph_node, &subgraph_set, &index](const nnvm::ObjectPtr& node) {
-      for (auto &e : node->control_deps) {
-        if (subgraph_set.count(e.get())) {
-          uint32_t node_id = index.node_id(e.get());
-          auto helper_node = op::MakeNode("_FusedOpHelper",
-                                          subgraph_node->attrs.name + "_"
-                                          + node->attrs.name + "_helper",
-                                          nullptr,
-                                          nullptr,
-                                          nullptr);
-          helper_node->attrs.parsed =
-            FusedOpHelperParamPtr(new FusedOpHelperParam(
-                  nnvm::get<FusedOpPtr>(subgraph_node->attrs.parsed),
-                  node_id));
-          e = helper_node;
-        }
+    // Control deps
+    for (const auto& dep : idx[i].control_deps) {
+      if (subgraph_id == subgraph_assignment[dep]) {
+        node_copy->control_deps.emplace_back(new_nodes[dep]);
       }
-    });
+    }
   }
-  Graph new_graph;
-  new_graph.outputs = g.outputs;
-  return new_graph;
-}
 
-/* \brief Add nodes as inputs to the subgraph. This is used for operations
- *        which are only compatible when they are the first nodes in the
- *        subgraph.
- */
-template <typename IsCompatible>
-void AddInputsOnlyCompatible(const Graph &g,
-                             std::vector<std::unordered_set<nnvm::Node*> >* subsets,
-                             IsCompatible is_compatible) {
-  std::unordered_map<nnvm::Node*, uint32_t> node2setidx;
-  size_t subgraphs_fullsize = 0;
-  for (auto& s : *subsets) {
-    subgraphs_fullsize += s.size();
-  }
-  node2setidx.reserve(subgraphs_fullsize);
-  for (size_t i = 0; i < subsets->size(); ++i) {
-    for (auto& n : (*subsets)[i]) {
-      node2setidx.insert({n, i});
+  ret.outputs.reserve(idx.outputs().size());
+  for (const auto& output : idx.outputs()) {
+    const int subgraph_id = subgraph_assignment[output.node_id];
+    if (subgraph_id == -1) {
+      ret.outputs.emplace_back(new_nodes[output.node_id],
+                               output.index,
+                               output.version);
+    } else {
+      const int output_num = SetInsert({static_cast<int>(output.node_id),
+                                        static_cast<int>(output.index)},
+                                       &(subgraphs[subgraph_id].outputs));
+      ret.outputs.emplace_back(subgraphs[subgraph_id].subgraph_node,
+                               output_num,
+                               output.version);
     }
   }
-  std::vector<std::vector<nnvm::Node*> > to_add(subsets->size());
-  DFSVisit(g.outputs, [&is_compatible, &node2setidx, &to_add](const nnvm::ObjectPtr& n) {
-    const auto& it = node2setidx.find(n.get());
-    if (it != node2setidx.end()) {
-      for (auto& e : n->inputs) {
-        if (is_compatible(e.node.get()))
-          to_add[it->second].push_back(e.node.get());
-      }
+
+  for (auto& info : subgraphs) {
+    info.graph.outputs.reserve(info.outputs.size());
+    for (const auto& entry_info : info.outputs) {
+      info.graph.outputs.emplace_back(new_nodes[entry_info.source_node],
+                                      entry_info.index,
+                                      0);
     }
-  });
+    create_subgraph_node(info.graph, info.inputs.size(), info.subgraph_node.get());
+  }
 
-  // Avoid duplicating the node that is input of two subsets
-  std::unordered_set<nnvm::Node*> added;
-  for (size_t i = 0; i < subsets->size(); ++i) {
-    std::vector<nnvm::NodeEntry> heads;
-    for (auto n : subsets->at(i)) {
-      for (auto e : n->inputs) {
-        if (!subsets->at(i).count(e.node.get()))
-          heads.push_back(e);
+  for (size_t i = 0; i < idx.num_nodes(); ++i) {
+    // Add _FusedOpHelper nodes
+    const int subgraph_id = subgraph_assignment[i];
+    for (size_t dep_num = 0; dep_num < idx[i].control_deps.size(); ++dep_num) {
+      const auto& dep = idx[i].control_deps[dep_num];
+      const int their_subgraph_id = subgraph_assignment[dep];
+      if (subgraph_assignment[i] != -1 && subgraph_assignment[dep] == -1) {
+        // Not in any subgraph, use FusedOpOutHelper
+        auto& info = subgraphs[subgraph_id];
+        size_t node_id = info.subgraph_node->control_deps.size();
+        info.subgraph_node->control_deps.emplace_back(new_nodes[dep]);
+        auto helper_node = op::MakeNode("_FusedOpOutHelper",
+                                        "FusedOp_" + new_nodes[i]->attrs.name + "_outhelper",
+                                        nullptr,
+                                        nullptr,
+                                        nullptr);
+        helper_node->attrs.parsed =
+          FusedOpHelperParamPtr(new FusedOpHelperParam(
+                nnvm::get<FusedOpPtr>(info.subgraph_node->attrs.parsed),
+                node_id));
+        new_nodes[i]->control_deps.insert(new_nodes[i]->control_deps.begin() + dep_num,
+                                          std::move(helper_node));
+      } else if (their_subgraph_id != subgraph_id &&

Review comment:
       Makes sense. In general, what are the control flow dependencies of a node? I know of only one: the dependency of a backward op on its corresponding forward op.




----------------------------------------------------------------



[GitHub] [incubator-mxnet] mseth10 commented on a change in pull request #19269: Faster pointwise fusion graph pass

Posted by GitBox <gi...@apache.org>.
mseth10 commented on a change in pull request #19269:
URL: https://github.com/apache/incubator-mxnet/pull/19269#discussion_r506933734



##########
File path: src/imperative/pointwise_fusion_pass.cc
##########
@@ -57,281 +58,321 @@ void WarnFusionNotSupported() {
 #if MXNET_USE_CUDA
 
 namespace {
-  bool IsFusionCompatible(nnvm::Node* n) {
-    using namespace mxnet::fusion;
-    if (n->op() == nullptr)
-      return false;
-    std::string op_name = n->op()->name;
-    if (ops_desc.count(op_name))
-      return true;
-    if (slice_ops.count(op_name))
-      return false;
-    if (std::find(variable_io_ops.begin(),
-                  variable_io_ops.end(),
-                  op_name) !=
-        variable_io_ops.end())
-      return true;
-    if (op_name == "LeakyReLU") {
-        std::string act_type = n->attrs.dict.at("act_type");
-        if (LeakyReLU_ops.count(act_type))
-          return true;
-        else
-          return false;
-    }
-    if (op_name == "_backward_LeakyReLU") {
-        std::string act_type = n->attrs.dict.at("act_type");
-        if (LeakyReLU_bwd_ops.count(act_type))
-          return true;
-        else
-          return false;
-    }
+
+bool IsFusionCompatible(const nnvm::Node* n) {
+  using namespace mxnet::fusion;
+  if (n->op() == nullptr)
+    return false;
+  const std::string& op_name = n->op()->name;
+  if (ops_desc.count(op_name))
+    return true;
+  if (slice_ops.count(op_name))
     return false;
+  if (std::find(variable_io_ops.begin(),
+                variable_io_ops.end(),
+                op_name) !=
+      variable_io_ops.end())
+    return true;
+  if (op_name == "LeakyReLU") {
+      std::string act_type = n->attrs.dict.at("act_type");
+      if (LeakyReLU_ops.count(act_type))
+        return true;
+      else
+        return false;
   }
+  if (op_name == "_backward_LeakyReLU") {
+      std::string act_type = n->attrs.dict.at("act_type");
+      if (LeakyReLU_bwd_ops.count(act_type))
+        return true;
+      else
+        return false;
+  }
+  return false;
+}
 
-  bool IsInputsOnlyCompatible(nnvm::Node* n) {
-    using namespace mxnet::fusion;
-    if (n->op() == nullptr)
-      return false;
-    std::string op_name = n->op()->name;
-    if (slice_ops.count(op_name)) {
-      if (op_name == "slice") {
-        // slice with non-default step attribute is not supported
-        // currently
-        if (n->attrs.dict.count("step") &&
-            !(n->attrs.dict.at("step") == "()" ||
-              n->attrs.dict.at("step") == "[]")) {
-          return false;
-        }
+bool IsInputsOnlyCompatible(const nnvm::Node* n) {
+  using namespace mxnet::fusion;
+  if (n->op() == nullptr)
+    return false;
+  const std::string& op_name = n->op()->name;
+  if (slice_ops.count(op_name)) {
+    if (op_name == "slice") {
+      // slice with non-default step attribute is not supported
+      // currently
+      if (n->attrs.dict.count("step") &&
+          !(n->attrs.dict.at("step") == "()" ||
+            n->attrs.dict.at("step") == "[]")) {
+        return false;
       }
-      return true;
     }
-    return false;
+    return true;
   }
+  return false;
+}
+
+void CreateSubgraphNode(const nnvm::Graph& subgraph,
+                        size_t inputs_size,
+                        nnvm::Node* subgraph_node) {
+  static const Op* fused_op_ptr = Op::Get("_FusedOp");
+  subgraph_node->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>());
+  subgraph_node->attrs.subgraphs.back()->outputs = subgraph.outputs;
+  subgraph_node->attrs.dict["num_inputs"] = std::to_string(inputs_size);
+  subgraph_node->attrs.dict["num_outputs"] = std::to_string(subgraph.outputs.size());
+  subgraph_node->attrs.op = fused_op_ptr;
+  subgraph_node->op()->attr_parser(&(subgraph_node->attrs));
+}
+
+struct EntryInfo {
+  int source_node;
+  int index;
+};
 
-  nnvm::ObjectPtr CreateSubgraphNode(const Graph& subgraph, size_t inputs_size) {
-    nnvm::Symbol subgraph_sym;
-    auto node = nnvm::Node::Create();
-    subgraph_sym.outputs = subgraph.outputs;
-    node->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>(subgraph_sym));
-    node->attrs.name = "FusedOp";
-    node->attrs.dict["num_inputs"] = std::to_string(inputs_size);
-    node->attrs.dict["num_outputs"] = std::to_string(subgraph.outputs.size());
-    node->attrs.op = Op::Get("_FusedOp");
-    node->op()->attr_parser(&(node->attrs));
-    return node;
+inline int SetInsert(const EntryInfo& new_elem,
+                     std::vector<EntryInfo>* elements) {
+  for (size_t i = 0; i < elements->size(); ++i) {
+    if ((new_elem.source_node == elements->at(i).source_node) &&
+        (new_elem.index == elements->at(i).index)) {
+      return i;
+    }
   }
+  elements->emplace_back(new_elem);
+  return elements->size() - 1;
+}
+
 }  // namespace
 
-/*!
- * \brief Replace a set of nodes by a subgraph node.
- *        This function is used specifically in pointwise fusion.
+/* \brief Create (if necessary) copy of the graph, replacing subgraphs with
+ *        FusedOps. If there are no subgraphs to be replaced, the
+ *        original graph is returned.
+ * \param g original graph.
+ * \param subgraph_assignment assignment of nodes in g's IndexedGraphs to
+ *                            subgraphs. Values from -1 to num_subgraphs - 1
+ *                            are allowed, -1 means that the node is not in a
+ *                            subgraph.
+ * \param num_subgraphs number of subgraphs.
+ * \param create_subgraph_node function used to prepare the subgraph node.
  */
 template<typename FCreateNode>
-Graph ReplaceSubgraphsPointwise(Graph&& g, const std::vector<NodeRawPtrSet>& subgraph_sets,
-                                FCreateNode create_subgraph_node) {
-  for (auto subgraph_set : subgraph_sets) {
-    // Create MXNet subgraph
-    Graph subgraph;
-    const auto sub_outputs_in_main = GetSubgraphOutputs(g, subgraph_set);
-    subgraph.outputs.resize(sub_outputs_in_main.size());
-    for (auto p : sub_outputs_in_main) {
-      subgraph.outputs[p.second] = p.first;
-    }
-    // To generate a subgraph an input has to be replaced by data node (no op)
-    // and it has to be agnostic to the node from which it's an output
-    // (For example, even if two inputs are two different outputs from the same node,
-    // they need to be replaced by two completely separate data nodes)
-    auto inputs = GetSubgraphInputs(subgraph, subgraph_set);
-    auto subgraph_node = create_subgraph_node(subgraph, inputs.size());
-    subgraph_node->inputs = inputs;
-    // replug inputs of node out of subgraph to be output of the subgraph node
-    // if it was a node in the subgraph
-    DFSVisit(g.outputs,
-        [&subgraph_node, &subgraph_set, &sub_outputs_in_main](const nnvm::ObjectPtr node) {
-      if (!subgraph_set.count(node.get())) {
-        for (auto &e : node->inputs) {
-          auto it = sub_outputs_in_main.find(e);
-          if (it != sub_outputs_in_main.end()) {
-            e.node = subgraph_node;
-            e.index = it->second;
-          }
-        }
-      }
-    });
-    // replug outputs of the graph to be output of the subgraph node
-    // if it was a node in the subgraph
-    for (auto &e : g.outputs) {
-      auto it = sub_outputs_in_main.find(e);
-      if (it != sub_outputs_in_main.end()) {
-        e.node = subgraph_node;
-        e.index = it->second;
-      }
+Graph CopyAndReplaceSubgraphs(const Graph& g,
+                              const std::vector<int>& subgraph_assignment,
+                              const int num_subgraphs,
+                              FCreateNode create_subgraph_node) {
+  if (num_subgraphs == 0) {
+    return g;
+  }
+
+  Graph ret;
+
+  const auto& idx = g.indexed_graph();
+
+  CHECK_EQ(idx.num_nodes(), subgraph_assignment.size()) <<
+    "Every node in the graph needs to be included in subgraph assignment.";
+
+  std::vector<nnvm::ObjectPtr> new_nodes;
+  new_nodes.reserve(idx.num_nodes());
+  struct SubgraphInfo {
+    nnvm::Graph graph;
+    nnvm::ObjectPtr subgraph_node;
+    std::vector<EntryInfo> outputs;
+    std::vector<EntryInfo> inputs;
+    std::vector<nnvm::ObjectPtr> input_nodes;
+  };
+
+  std::vector<SubgraphInfo> subgraphs(num_subgraphs);
+
+  for (auto& info : subgraphs) {
+    info.subgraph_node = nnvm::Node::Create();
+  }
+
+  for (size_t i = 0; i < idx.num_nodes(); ++i) {
+    // First copy the node, it will be used
+    // either in the new graph or inside a
+    // subgraph. Variables are not copied.
+    if (idx[i].source->op() != nullptr) {
+      new_nodes.emplace_back(nnvm::Node::Create());
+      auto& node_copy = new_nodes.back();
+      node_copy->attrs = idx[i].source->attrs;
+      node_copy->info = idx[i].source->info;
+    } else {
+      new_nodes.emplace_back(idx[i].weak_ref.lock());
+      continue;
     }
-    // move control dependencies between nodes of the subgraph and out of the subgraph
-    // to a dependencies between the subgraph node and the nodes out of the subgraph
-    DFSVisit(subgraph.outputs, [&subgraph_node, &subgraph_set](const nnvm::ObjectPtr& node) {
-      if (subgraph_set.count(node.get())) {
-        auto it = node->control_deps.begin();
-        static auto& is_fusion = Op::GetAttr<exec::TIsFusionHelper>("TIsFusionHelper");
-        std::vector<nnvm::ObjectPtr> new_control_deps;
-        // Use the first control dependency to get the inferattr helper
-        if (it != node->control_deps.end()) {
-          if (subgraph_set.count(it->get())) {
-            new_control_deps.push_back(*it);
+    auto& node_copy = new_nodes.back();
+    const int subgraph_id = subgraph_assignment[i];
+    if (subgraph_id != -1) {
+      auto& info = subgraphs[subgraph_id];
+      for (const auto& input : idx[i].inputs) {
+        const int their_subgraph = subgraph_assignment[input.node_id];
+        if (their_subgraph == subgraph_id) {
+          node_copy->inputs.emplace_back(new_nodes[input.node_id],
+                                         input.index,
+                                         input.version);
+        } else {
+          int input_num;
+          int output_num;
+          if (their_subgraph == -1) {
+            input_num = SetInsert({static_cast<int>(input.node_id),
+                                   static_cast<int>(input.index)}, &(info.inputs));
           } else {
-            if ((*it)->is_variable() || !is_fusion.get((*it)->op(), false)) {
-              uint32_t node_id = subgraph_node->control_deps.size();
-              subgraph_node->control_deps.push_back(*it);
-              auto helper_node = op::MakeNode("_FusedOpOutHelper",
-                                              "FusedOp_" + node->attrs.name + "_outhelper",
-                                              nullptr,
-                                              nullptr,
-                                              nullptr);
-              helper_node->attrs.parsed =
-                FusedOpHelperParamPtr(new FusedOpHelperParam(
-                      nnvm::get<FusedOpPtr>(subgraph_node->attrs.parsed),
-                      node_id));
-              new_control_deps.push_back(helper_node);
+            auto& their_subgraph_info = subgraphs[their_subgraph];
+            output_num = SetInsert({static_cast<int>(input.node_id),
+                                    static_cast<int>(input.index)},
+                                   &(their_subgraph_info.outputs));
+            input_num = SetInsert({static_cast<int>(idx.num_nodes() + their_subgraph),
+                                   output_num},
+                                  &(info.inputs));
+          }
+          if (static_cast<size_t>(input_num) == info.input_nodes.size()) {
+            info.input_nodes.emplace_back(nnvm::Node::Create());
+            info.input_nodes.back()->attrs.name = "input_" + std::to_string(input_num);
+            if (their_subgraph == -1) {
+              info.subgraph_node->inputs.emplace_back(new_nodes[input.node_id],
+                                                      input.index,
+                                                      input.version);
             } else {
-              new_control_deps.push_back(*it);
+              info.subgraph_node->inputs.emplace_back(subgraphs[their_subgraph].subgraph_node,
+                                                      output_num,
+                                                      input.version);
             }
           }
-          ++it;
+          node_copy->inputs.emplace_back(info.input_nodes[input_num], 0, 0);
         }
-        node->control_deps = new_control_deps;
       }
-    });
-
-    std::ostringstream name_oss;
-    // the name of the new node will be the concatenation of all the node names in the subgraph
-    DFSVisit(subgraph.outputs, [&name_oss](const nnvm::ObjectPtr n) {
-      if (n->op() != nullptr) {
-        name_oss << n->op()->name << "_";
+    } else {
+      for (const auto& input : idx[i].inputs) {
+        const int subgraph_id = subgraph_assignment[input.node_id];
+        if (subgraph_id == -1) {
+          node_copy->inputs.emplace_back(new_nodes[input.node_id],
+                                         input.index,
+                                         input.version);
+        } else {
+          auto& info = subgraphs[subgraph_id];
+          const int output_num = SetInsert({static_cast<int>(input.node_id),
+                                            static_cast<int>(input.index)},
+                                           &(info.outputs));
+          node_copy->inputs.emplace_back(info.subgraph_node,
+                                         output_num,
+                                         input.version);
+        }
       }
-    });
-    auto subgraph_name = name_oss.str();
-    subgraph_name.pop_back();
-    subgraph_node->attrs.name = subgraph_name;
+    }
 
-    const auto& index = subgraph.indexed_graph();
-    DFSVisit(g.outputs, [&subgraph_node, &subgraph_set, &index](const nnvm::ObjectPtr& node) {
-      for (auto &e : node->control_deps) {
-        if (subgraph_set.count(e.get())) {
-          uint32_t node_id = index.node_id(e.get());
-          auto helper_node = op::MakeNode("_FusedOpHelper",
-                                          subgraph_node->attrs.name + "_"
-                                          + node->attrs.name + "_helper",
-                                          nullptr,
-                                          nullptr,
-                                          nullptr);
-          helper_node->attrs.parsed =
-            FusedOpHelperParamPtr(new FusedOpHelperParam(
-                  nnvm::get<FusedOpPtr>(subgraph_node->attrs.parsed),
-                  node_id));
-          e = helper_node;
-        }
+    // Control deps
+    for (const auto& dep : idx[i].control_deps) {
+      if (subgraph_id == subgraph_assignment[dep]) {
+        node_copy->control_deps.emplace_back(new_nodes[dep]);
       }
-    });
+    }
   }
-  Graph new_graph;
-  new_graph.outputs = g.outputs;
-  return new_graph;
-}
 
-/* \brief Add nodes as inputs to the subgraph. This is used for operations
- *        which are only compatible when they are the first nodes in the
- *        subgraph.
- */
-template <typename IsCompatible>
-void AddInputsOnlyCompatible(const Graph &g,
-                             std::vector<std::unordered_set<nnvm::Node*> >* subsets,
-                             IsCompatible is_compatible) {
-  std::unordered_map<nnvm::Node*, uint32_t> node2setidx;
-  size_t subgraphs_fullsize = 0;
-  for (auto& s : *subsets) {
-    subgraphs_fullsize += s.size();
-  }
-  node2setidx.reserve(subgraphs_fullsize);
-  for (size_t i = 0; i < subsets->size(); ++i) {
-    for (auto& n : (*subsets)[i]) {
-      node2setidx.insert({n, i});
+  ret.outputs.reserve(idx.outputs().size());
+  for (const auto& output : idx.outputs()) {
+    const int subgraph_id = subgraph_assignment[output.node_id];
+    if (subgraph_id == -1) {
+      ret.outputs.emplace_back(new_nodes[output.node_id],
+                               output.index,
+                               output.version);
+    } else {
+      const int output_num = SetInsert({static_cast<int>(output.node_id),
+                                        static_cast<int>(output.index)},
+                                       &(subgraphs[subgraph_id].outputs));
+      ret.outputs.emplace_back(subgraphs[subgraph_id].subgraph_node,
+                               output_num,
+                               output.version);
     }
   }
-  std::vector<std::vector<nnvm::Node*> > to_add(subsets->size());
-  DFSVisit(g.outputs, [&is_compatible, &node2setidx, &to_add](const nnvm::ObjectPtr& n) {
-    const auto& it = node2setidx.find(n.get());
-    if (it != node2setidx.end()) {
-      for (auto& e : n->inputs) {
-        if (is_compatible(e.node.get()))
-          to_add[it->second].push_back(e.node.get());
-      }
+
+  for (auto& info : subgraphs) {
+    info.graph.outputs.reserve(info.outputs.size());
+    for (const auto& entry_info : info.outputs) {
+      info.graph.outputs.emplace_back(new_nodes[entry_info.source_node],
+                                      entry_info.index,
+                                      0);
     }
-  });
+    create_subgraph_node(info.graph, info.inputs.size(), info.subgraph_node.get());
+  }
 
-  // Avoid duplicating the node that is input of two subsets
-  std::unordered_set<nnvm::Node*> added;
-  for (size_t i = 0; i < subsets->size(); ++i) {
-    std::vector<nnvm::NodeEntry> heads;
-    for (auto n : subsets->at(i)) {
-      for (auto e : n->inputs) {
-        if (!subsets->at(i).count(e.node.get()))
-          heads.push_back(e);
+  for (size_t i = 0; i < idx.num_nodes(); ++i) {
+    // Add _FusedOpHelper nodes
+    const int subgraph_id = subgraph_assignment[i];
+    for (size_t dep_num = 0; dep_num < idx[i].control_deps.size(); ++dep_num) {
+      const auto& dep = idx[i].control_deps[dep_num];
+      const int their_subgraph_id = subgraph_assignment[dep];
+      if (subgraph_assignment[i] != -1 && subgraph_assignment[dep] == -1) {
+        // Not in any subgraph, use FusedOpOutHelper
+        auto& info = subgraphs[subgraph_id];
+        size_t node_id = info.subgraph_node->control_deps.size();
+        info.subgraph_node->control_deps.emplace_back(new_nodes[dep]);
+        auto helper_node = op::MakeNode("_FusedOpOutHelper",
+                                        "FusedOp_" + new_nodes[i]->attrs.name + "_outhelper",
+                                        nullptr,
+                                        nullptr,
+                                        nullptr);
+        helper_node->attrs.parsed =
+          FusedOpHelperParamPtr(new FusedOpHelperParam(
+                nnvm::get<FusedOpPtr>(info.subgraph_node->attrs.parsed),
+                node_id));
+        new_nodes[i]->control_deps.insert(new_nodes[i]->control_deps.begin() + dep_num,
+                                          std::move(helper_node));
+      } else if (their_subgraph_id != subgraph_id &&

Review comment:
       Makes sense. In general, what are the control dependencies of a node? Is it just the dependency of a backward op node on its corresponding forward op node?
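
A note for readers of the hunk above: the input and output slots of each subgraph are numbered by `SetInsert`, which the same diff defines next to `EntryInfo`. The sketch below restates that helper in standalone form to show why a repeated (node id, output index) pair always maps to the same slot. `NumberOutputs` is a hypothetical wrapper added only for illustration and is not part of the PR.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Same shape as EntryInfo in the diff: (producer node id, output index).
struct EntryInfo {
  int source_node;
  int index;
};

// Insert-if-absent with a linear search; returns the slot of the entry.
// The first time an entry is seen it is appended, later lookups reuse its slot.
inline int SetInsert(const EntryInfo& new_elem, std::vector<EntryInfo>* elements) {
  assert(elements != nullptr);
  for (std::size_t i = 0; i < elements->size(); ++i) {
    if (new_elem.source_node == (*elements)[i].source_node &&
        new_elem.index == (*elements)[i].index) {
      return static_cast<int>(i);
    }
  }
  elements->push_back(new_elem);
  return static_cast<int>(elements->size() - 1);
}

// Hypothetical helper (not in the PR): assign an output slot to every
// entry that some consumer reads from a subgraph.
inline std::vector<int> NumberOutputs(const std::vector<EntryInfo>& consumed,
                                      std::vector<EntryInfo>* outputs) {
  std::vector<int> slots;
  slots.reserve(consumed.size());
  for (const auto& entry : consumed) {
    slots.push_back(SetInsert(entry, outputs));
  }
  return slots;
}
```

Because the search is linear, this is only cheap while each subgraph has few distinct inputs and outputs, which appears to be the assumption behind using a plain vector here.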




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-mxnet] ptrendx commented on a change in pull request #19269: Faster pointwise fusion graph pass

Posted by GitBox <gi...@apache.org>.
ptrendx commented on a change in pull request #19269:
URL: https://github.com/apache/incubator-mxnet/pull/19269#discussion_r506725557



##########
File path: src/imperative/pointwise_fusion_pass.cc
##########
@@ -57,281 +58,321 @@ void WarnFusionNotSupported() {
 #if MXNET_USE_CUDA
 
 namespace {
-  bool IsFusionCompatible(nnvm::Node* n) {
-    using namespace mxnet::fusion;
-    if (n->op() == nullptr)
-      return false;
-    std::string op_name = n->op()->name;
-    if (ops_desc.count(op_name))
-      return true;
-    if (slice_ops.count(op_name))
-      return false;
-    if (std::find(variable_io_ops.begin(),
-                  variable_io_ops.end(),
-                  op_name) !=
-        variable_io_ops.end())
-      return true;
-    if (op_name == "LeakyReLU") {
-        std::string act_type = n->attrs.dict.at("act_type");
-        if (LeakyReLU_ops.count(act_type))
-          return true;
-        else
-          return false;
-    }
-    if (op_name == "_backward_LeakyReLU") {
-        std::string act_type = n->attrs.dict.at("act_type");
-        if (LeakyReLU_bwd_ops.count(act_type))
-          return true;
-        else
-          return false;
-    }
+
+bool IsFusionCompatible(const nnvm::Node* n) {
+  using namespace mxnet::fusion;
+  if (n->op() == nullptr)
+    return false;
+  const std::string& op_name = n->op()->name;
+  if (ops_desc.count(op_name))
+    return true;
+  if (slice_ops.count(op_name))
     return false;
+  if (std::find(variable_io_ops.begin(),
+                variable_io_ops.end(),
+                op_name) !=
+      variable_io_ops.end())
+    return true;
+  if (op_name == "LeakyReLU") {
+      std::string act_type = n->attrs.dict.at("act_type");
+      if (LeakyReLU_ops.count(act_type))
+        return true;
+      else
+        return false;
   }
+  if (op_name == "_backward_LeakyReLU") {
+      std::string act_type = n->attrs.dict.at("act_type");
+      if (LeakyReLU_bwd_ops.count(act_type))
+        return true;
+      else
+        return false;
+  }
+  return false;
+}
 
-  bool IsInputsOnlyCompatible(nnvm::Node* n) {
-    using namespace mxnet::fusion;
-    if (n->op() == nullptr)
-      return false;
-    std::string op_name = n->op()->name;
-    if (slice_ops.count(op_name)) {
-      if (op_name == "slice") {
-        // slice with non-default step attribute is not supported
-        // currently
-        if (n->attrs.dict.count("step") &&
-            !(n->attrs.dict.at("step") == "()" ||
-              n->attrs.dict.at("step") == "[]")) {
-          return false;
-        }
+bool IsInputsOnlyCompatible(const nnvm::Node* n) {
+  using namespace mxnet::fusion;
+  if (n->op() == nullptr)
+    return false;
+  const std::string& op_name = n->op()->name;
+  if (slice_ops.count(op_name)) {
+    if (op_name == "slice") {
+      // slice with non-default step attribute is not supported
+      // currently
+      if (n->attrs.dict.count("step") &&
+          !(n->attrs.dict.at("step") == "()" ||
+            n->attrs.dict.at("step") == "[]")) {
+        return false;
       }
-      return true;
     }
-    return false;
+    return true;
   }
+  return false;
+}
+
+void CreateSubgraphNode(const nnvm::Graph& subgraph,
+                        size_t inputs_size,
+                        nnvm::Node* subgraph_node) {
+  static const Op* fused_op_ptr = Op::Get("_FusedOp");
+  subgraph_node->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>());
+  subgraph_node->attrs.subgraphs.back()->outputs = subgraph.outputs;
+  subgraph_node->attrs.dict["num_inputs"] = std::to_string(inputs_size);
+  subgraph_node->attrs.dict["num_outputs"] = std::to_string(subgraph.outputs.size());
+  subgraph_node->attrs.op = fused_op_ptr;
+  subgraph_node->op()->attr_parser(&(subgraph_node->attrs));
+}
+
+struct EntryInfo {
+  int source_node;
+  int index;
+};
 
-  nnvm::ObjectPtr CreateSubgraphNode(const Graph& subgraph, size_t inputs_size) {
-    nnvm::Symbol subgraph_sym;
-    auto node = nnvm::Node::Create();
-    subgraph_sym.outputs = subgraph.outputs;
-    node->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>(subgraph_sym));
-    node->attrs.name = "FusedOp";
-    node->attrs.dict["num_inputs"] = std::to_string(inputs_size);
-    node->attrs.dict["num_outputs"] = std::to_string(subgraph.outputs.size());
-    node->attrs.op = Op::Get("_FusedOp");
-    node->op()->attr_parser(&(node->attrs));
-    return node;
+inline int SetInsert(const EntryInfo& new_elem,
+                     std::vector<EntryInfo>* elements) {
+  for (size_t i = 0; i < elements->size(); ++i) {
+    if ((new_elem.source_node == elements->at(i).source_node) &&
+        (new_elem.index == elements->at(i).index)) {
+      return i;
+    }
   }
+  elements->emplace_back(new_elem);
+  return elements->size() - 1;
+}
+
 }  // namespace
 
-/*!
- * \brief Replace a set of nodes by a subgraph node.
- *        This function is used specifically in pointwise fusion.
+/* \brief Create (if necessary) copy of the graph, replacing subgraphs with
+ *        FusedOps. If there are no subgraphs to be replaced, the
+ *        original graph is returned.
+ * \param g original graph.
+ * \param subgraph_assignment assignment of nodes in g's IndexedGraphs to
+ *                            subgraphs. Values from -1 to num_subgraphs - 1
+ *                            are allowed, -1 means that the node is not in a
+ *                            subgraph.
+ * \param num_subgraphs number of subgraphs.
+ * \param create_subgraph_node function used to prepare the subgraph node.
  */
 template<typename FCreateNode>
-Graph ReplaceSubgraphsPointwise(Graph&& g, const std::vector<NodeRawPtrSet>& subgraph_sets,
-                                FCreateNode create_subgraph_node) {
-  for (auto subgraph_set : subgraph_sets) {
-    // Create MXNet subgraph
-    Graph subgraph;
-    const auto sub_outputs_in_main = GetSubgraphOutputs(g, subgraph_set);
-    subgraph.outputs.resize(sub_outputs_in_main.size());
-    for (auto p : sub_outputs_in_main) {
-      subgraph.outputs[p.second] = p.first;
-    }
-    // To generate a subgraph an input has to be replaced by data node (no op)
-    // and it has to be agnostic to the node from which it's an output
-    // (For example, even if two inputs are two different outputs from the same node,
-    // they need to be replaced by two completely separate data nodes)
-    auto inputs = GetSubgraphInputs(subgraph, subgraph_set);
-    auto subgraph_node = create_subgraph_node(subgraph, inputs.size());
-    subgraph_node->inputs = inputs;
-    // replug inputs of node out of subgraph to be output of the subgraph node
-    // if it was a node in the subgraph
-    DFSVisit(g.outputs,
-        [&subgraph_node, &subgraph_set, &sub_outputs_in_main](const nnvm::ObjectPtr node) {
-      if (!subgraph_set.count(node.get())) {
-        for (auto &e : node->inputs) {
-          auto it = sub_outputs_in_main.find(e);
-          if (it != sub_outputs_in_main.end()) {
-            e.node = subgraph_node;
-            e.index = it->second;
-          }
-        }
-      }
-    });
-    // replug outputs of the graph to be output of the subgraph node
-    // if it was a node in the subgraph
-    for (auto &e : g.outputs) {
-      auto it = sub_outputs_in_main.find(e);
-      if (it != sub_outputs_in_main.end()) {
-        e.node = subgraph_node;
-        e.index = it->second;
-      }
+Graph CopyAndReplaceSubgraphs(const Graph& g,
+                              const std::vector<int>& subgraph_assignment,
+                              const int num_subgraphs,
+                              FCreateNode create_subgraph_node) {
+  if (num_subgraphs == 0) {
+    return g;
+  }
+
+  Graph ret;
+
+  const auto& idx = g.indexed_graph();
+
+  CHECK_EQ(idx.num_nodes(), subgraph_assignment.size()) <<
+    "Every node in the graph needs to be included in subgraph assignment.";
+
+  std::vector<nnvm::ObjectPtr> new_nodes;
+  new_nodes.reserve(idx.num_nodes());
+  struct SubgraphInfo {
+    nnvm::Graph graph;
+    nnvm::ObjectPtr subgraph_node;
+    std::vector<EntryInfo> outputs;
+    std::vector<EntryInfo> inputs;
+    std::vector<nnvm::ObjectPtr> input_nodes;
+  };
+
+  std::vector<SubgraphInfo> subgraphs(num_subgraphs);
+
+  for (auto& info : subgraphs) {
+    info.subgraph_node = nnvm::Node::Create();
+  }
+
+  for (size_t i = 0; i < idx.num_nodes(); ++i) {
+    // First copy the node, it will be used
+    // either in the new graph or inside a
+    // subgraph. Variables are not copied.
+    if (idx[i].source->op() != nullptr) {
+      new_nodes.emplace_back(nnvm::Node::Create());
+      auto& node_copy = new_nodes.back();
+      node_copy->attrs = idx[i].source->attrs;
+      node_copy->info = idx[i].source->info;
+    } else {
+      new_nodes.emplace_back(idx[i].weak_ref.lock());
+      continue;
     }
-    // move control dependencies between nodes of the subgraph and out of the subgraph
-    // to a dependencies between the subgraph node and the nodes out of the subgraph
-    DFSVisit(subgraph.outputs, [&subgraph_node, &subgraph_set](const nnvm::ObjectPtr& node) {
-      if (subgraph_set.count(node.get())) {
-        auto it = node->control_deps.begin();
-        static auto& is_fusion = Op::GetAttr<exec::TIsFusionHelper>("TIsFusionHelper");
-        std::vector<nnvm::ObjectPtr> new_control_deps;
-        // Use the first control dependency to get the inferattr helper
-        if (it != node->control_deps.end()) {
-          if (subgraph_set.count(it->get())) {
-            new_control_deps.push_back(*it);
+    auto& node_copy = new_nodes.back();
+    const int subgraph_id = subgraph_assignment[i];
+    if (subgraph_id != -1) {
+      auto& info = subgraphs[subgraph_id];
+      for (const auto& input : idx[i].inputs) {
+        const int their_subgraph = subgraph_assignment[input.node_id];
+        if (their_subgraph == subgraph_id) {
+          node_copy->inputs.emplace_back(new_nodes[input.node_id],
+                                         input.index,
+                                         input.version);
+        } else {
+          int input_num;
+          int output_num;
+          if (their_subgraph == -1) {
+            input_num = SetInsert({static_cast<int>(input.node_id),
+                                   static_cast<int>(input.index)}, &(info.inputs));
           } else {
-            if ((*it)->is_variable() || !is_fusion.get((*it)->op(), false)) {
-              uint32_t node_id = subgraph_node->control_deps.size();
-              subgraph_node->control_deps.push_back(*it);
-              auto helper_node = op::MakeNode("_FusedOpOutHelper",
-                                              "FusedOp_" + node->attrs.name + "_outhelper",
-                                              nullptr,
-                                              nullptr,
-                                              nullptr);
-              helper_node->attrs.parsed =
-                FusedOpHelperParamPtr(new FusedOpHelperParam(
-                      nnvm::get<FusedOpPtr>(subgraph_node->attrs.parsed),
-                      node_id));
-              new_control_deps.push_back(helper_node);
+            auto& their_subgraph_info = subgraphs[their_subgraph];
+            output_num = SetInsert({static_cast<int>(input.node_id),
+                                    static_cast<int>(input.index)},
+                                   &(their_subgraph_info.outputs));
+            input_num = SetInsert({static_cast<int>(idx.num_nodes() + their_subgraph),
+                                   output_num},
+                                  &(info.inputs));
+          }
+          if (static_cast<size_t>(input_num) == info.input_nodes.size()) {
+            info.input_nodes.emplace_back(nnvm::Node::Create());
+            info.input_nodes.back()->attrs.name = "input_" + std::to_string(input_num);
+            if (their_subgraph == -1) {
+              info.subgraph_node->inputs.emplace_back(new_nodes[input.node_id],
+                                                      input.index,
+                                                      input.version);
             } else {
-              new_control_deps.push_back(*it);
+              info.subgraph_node->inputs.emplace_back(subgraphs[their_subgraph].subgraph_node,
+                                                      output_num,
+                                                      input.version);
             }
           }
-          ++it;
+          node_copy->inputs.emplace_back(info.input_nodes[input_num], 0, 0);
         }
-        node->control_deps = new_control_deps;
       }
-    });
-
-    std::ostringstream name_oss;
-    // the name of the new node will be the concatenation of all the node names in the subgraph
-    DFSVisit(subgraph.outputs, [&name_oss](const nnvm::ObjectPtr n) {
-      if (n->op() != nullptr) {
-        name_oss << n->op()->name << "_";
+    } else {
+      for (const auto& input : idx[i].inputs) {
+        const int subgraph_id = subgraph_assignment[input.node_id];
+        if (subgraph_id == -1) {
+          node_copy->inputs.emplace_back(new_nodes[input.node_id],
+                                         input.index,
+                                         input.version);
+        } else {
+          auto& info = subgraphs[subgraph_id];
+          const int output_num = SetInsert({static_cast<int>(input.node_id),
+                                            static_cast<int>(input.index)},
+                                           &(info.outputs));
+          node_copy->inputs.emplace_back(info.subgraph_node,
+                                         output_num,
+                                         input.version);
+        }
       }
-    });
-    auto subgraph_name = name_oss.str();
-    subgraph_name.pop_back();
-    subgraph_node->attrs.name = subgraph_name;
+    }
 
-    const auto& index = subgraph.indexed_graph();
-    DFSVisit(g.outputs, [&subgraph_node, &subgraph_set, &index](const nnvm::ObjectPtr& node) {
-      for (auto &e : node->control_deps) {
-        if (subgraph_set.count(e.get())) {
-          uint32_t node_id = index.node_id(e.get());
-          auto helper_node = op::MakeNode("_FusedOpHelper",
-                                          subgraph_node->attrs.name + "_"
-                                          + node->attrs.name + "_helper",
-                                          nullptr,
-                                          nullptr,
-                                          nullptr);
-          helper_node->attrs.parsed =
-            FusedOpHelperParamPtr(new FusedOpHelperParam(
-                  nnvm::get<FusedOpPtr>(subgraph_node->attrs.parsed),
-                  node_id));
-          e = helper_node;
-        }
+    // Control deps
+    for (const auto& dep : idx[i].control_deps) {
+      if (subgraph_id == subgraph_assignment[dep]) {
+        node_copy->control_deps.emplace_back(new_nodes[dep]);
       }
-    });
+    }
   }
-  Graph new_graph;
-  new_graph.outputs = g.outputs;
-  return new_graph;
-}
 
-/* \brief Add nodes as inputs to the subgraph. This is used for operations
- *        which are only compatible when they are the first nodes in the
- *        subgraph.
- */
-template <typename IsCompatible>
-void AddInputsOnlyCompatible(const Graph &g,
-                             std::vector<std::unordered_set<nnvm::Node*> >* subsets,
-                             IsCompatible is_compatible) {
-  std::unordered_map<nnvm::Node*, uint32_t> node2setidx;
-  size_t subgraphs_fullsize = 0;
-  for (auto& s : *subsets) {
-    subgraphs_fullsize += s.size();
-  }
-  node2setidx.reserve(subgraphs_fullsize);
-  for (size_t i = 0; i < subsets->size(); ++i) {
-    for (auto& n : (*subsets)[i]) {
-      node2setidx.insert({n, i});
+  ret.outputs.reserve(idx.outputs().size());
+  for (const auto& output : idx.outputs()) {
+    const int subgraph_id = subgraph_assignment[output.node_id];
+    if (subgraph_id == -1) {
+      ret.outputs.emplace_back(new_nodes[output.node_id],
+                               output.index,
+                               output.version);
+    } else {
+      const int output_num = SetInsert({static_cast<int>(output.node_id),
+                                        static_cast<int>(output.index)},
+                                       &(subgraphs[subgraph_id].outputs));
+      ret.outputs.emplace_back(subgraphs[subgraph_id].subgraph_node,
+                               output_num,
+                               output.version);
     }
   }
-  std::vector<std::vector<nnvm::Node*> > to_add(subsets->size());
-  DFSVisit(g.outputs, [&is_compatible, &node2setidx, &to_add](const nnvm::ObjectPtr& n) {
-    const auto& it = node2setidx.find(n.get());
-    if (it != node2setidx.end()) {
-      for (auto& e : n->inputs) {
-        if (is_compatible(e.node.get()))
-          to_add[it->second].push_back(e.node.get());
-      }
+
+  for (auto& info : subgraphs) {
+    info.graph.outputs.reserve(info.outputs.size());
+    for (const auto& entry_info : info.outputs) {
+      info.graph.outputs.emplace_back(new_nodes[entry_info.source_node],
+                                      entry_info.index,
+                                      0);
     }
-  });
+    create_subgraph_node(info.graph, info.inputs.size(), info.subgraph_node.get());
+  }
 
-  // Avoid duplicating the node that is input of two subsets
-  std::unordered_set<nnvm::Node*> added;
-  for (size_t i = 0; i < subsets->size(); ++i) {
-    std::vector<nnvm::NodeEntry> heads;
-    for (auto n : subsets->at(i)) {
-      for (auto e : n->inputs) {
-        if (!subsets->at(i).count(e.node.get()))
-          heads.push_back(e);
+  for (size_t i = 0; i < idx.num_nodes(); ++i) {
+    // Add _FusedOpHelper nodes
+    const int subgraph_id = subgraph_assignment[i];
+    for (size_t dep_num = 0; dep_num < idx[i].control_deps.size(); ++dep_num) {
+      const auto& dep = idx[i].control_deps[dep_num];
+      const int their_subgraph_id = subgraph_assignment[dep];
+      if (subgraph_assignment[i] != -1 && subgraph_assignment[dep] == -1) {
+        // Not in any subgraph, use FusedOpOutHelper
+        auto& info = subgraphs[subgraph_id];
+        size_t node_id = info.subgraph_node->control_deps.size();
+        info.subgraph_node->control_deps.emplace_back(new_nodes[dep]);
+        auto helper_node = op::MakeNode("_FusedOpOutHelper",
+                                        "FusedOp_" + new_nodes[i]->attrs.name + "_outhelper",
+                                        nullptr,
+                                        nullptr,
+                                        nullptr);
+        helper_node->attrs.parsed =
+          FusedOpHelperParamPtr(new FusedOpHelperParam(
+                nnvm::get<FusedOpPtr>(info.subgraph_node->attrs.parsed),
+                node_id));
+        new_nodes[i]->control_deps.insert(new_nodes[i]->control_deps.begin() + dep_num,
+                                          std::move(helper_node));
+      } else if (their_subgraph_id != subgraph_id &&

Review comment:
       So if a node's control dependency is going to be fused into a subgraph, the node needs to depend on this helper instead.
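
The cases visible in the quoted hunk can be sketched as a small classifier over `subgraph_assignment` values, where -1 means "not in any subgraph". `DepKind` and `ClassifyControlDep` are hypothetical names used only for this illustration. The quoted hunk is cut off before the last condition is complete, so the sketch only labels that case and does not claim how the PR handles it.

```cpp
#include <cassert>

// Hypothetical labels for illustration; these names do not appear in the PR.
enum class DepKind {
  kKeepDirect,         // node and dependency share an assignment: copied as-is
  kOutHelper,          // fused node depends on an unfused node: _FusedOpOutHelper
  kDifferentSubgraph   // assignments differ otherwise: branch truncated in the quote
};

// node_subgraph and dep_subgraph are entries of subgraph_assignment.
inline DepKind ClassifyControlDep(int node_subgraph, int dep_subgraph) {
  assert(node_subgraph >= -1 && dep_subgraph >= -1);
  if (node_subgraph == dep_subgraph) {
    // Covers both "same subgraph" and "neither node is fused".
    return DepKind::kKeepDirect;
  }
  if (node_subgraph != -1 && dep_subgraph == -1) {
    // The dependency stays outside, so it is attached to the subgraph node
    // and the fused node refers to it through an out-helper.
    return DepKind::kOutHelper;
  }
  // The dependency lives in a subgraph other than the node's own one.
  return DepKind::kDifferentSubgraph;
}
```

For example, a node assigned to subgraph 0 with an unfused dependency falls into the out-helper case, while an unfused node whose dependency was fused into subgraph 2 falls into the last case, which is the one this comment refers to.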







[GitHub] [incubator-mxnet] szha commented on pull request #19269: Faster pointwise fusion graph pass

Posted by GitBox <gi...@apache.org>.
szha commented on pull request #19269:
URL: https://github.com/apache/incubator-mxnet/pull/19269#issuecomment-708089497


   @mseth10 any more comments on the PR?





[GitHub] [incubator-mxnet] mxnet-bot commented on pull request #19269: Faster pointwise fusion graph pass

Posted by GitBox <gi...@apache.org>.
mxnet-bot commented on pull request #19269:
URL: https://github.com/apache/incubator-mxnet/pull/19269#issuecomment-708048642


   Jenkins CI successfully triggered : [unix-cpu, unix-gpu]





[GitHub] [incubator-mxnet] mxnet-bot commented on pull request #19269: Faster pointwise fusion graph pass

Posted by GitBox <gi...@apache.org>.
mxnet-bot commented on pull request #19269:
URL: https://github.com/apache/incubator-mxnet/pull/19269#issuecomment-710646197


   Jenkins CI successfully triggered : [windows-gpu, unix-gpu, sanity, edge, windows-cpu, miscellaneous, centos-cpu, clang, website, unix-cpu, centos-gpu]





[GitHub] [incubator-mxnet] ptrendx commented on pull request #19269: Faster pointwise fusion graph pass

Posted by GitBox <gi...@apache.org>.
ptrendx commented on pull request #19269:
URL: https://github.com/apache/incubator-mxnet/pull/19269#issuecomment-704372201


   Ok, issue #19264 came at a perfect time - it actually exposed a corner case that breaks both the original and the new algorithm. The minimal example looks like this:
   ![Unfused graph](https://user-images.githubusercontent.com/8398980/95224823-ff5eb480-07af-11eb-9c91-387e2eb57515.png)
   where `a0`, `a1`, `b0` and `b1` are fusible, while `xa` and `xb` are not. According to the original algorithm, `a0` and `a1` can be fused together, and so can `b0` and `b1`. But if we do both, we get:
   ![Fused graph](https://user-images.githubusercontent.com/8398980/95225193-5d8b9780-07b0-11eb-91ff-fc48a9b4467d.png)
   This happens because the algorithm does not take into account that the exclusion set of every node connected to a subgraph needs to contain the excluded nodes of all nodes in that subgraph, including nodes that were added to the subgraph later in the subgraph creation process.
   
   I am working on a solution that will fix that.
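
   To make the corner case concrete, here is a minimal sketch of a check that collapses each subgraph into a single node and looks for a cycle. It is not code from this PR: `FusedGraphHasCycle` and `ExampleEdges` are hypothetical names, and the edge list is only an assumed topology that is consistent with the description above (the actual graph is in the linked images).

```cpp
#include <cassert>
#include <queue>
#include <utility>
#include <vector>

// Hypothetical helper, not part of the PR.
// Returns true if collapsing every subgraph into one node creates a cycle.
//   edges:      (producer, consumer) pairs over nodes 0..num_nodes-1
//   assignment: subgraph id of each node, -1 means "not fused"
bool FusedGraphHasCycle(int num_nodes,
                        const std::vector<std::pair<int, int>>& edges,
                        const std::vector<int>& assignment,
                        int num_subgraphs) {
  // Subgraph k becomes condensed node k; unfused nodes get fresh ids.
  std::vector<int> condensed(num_nodes);
  int next_id = num_subgraphs;
  for (int i = 0; i < num_nodes; ++i) {
    condensed[i] = (assignment[i] == -1) ? next_id++ : assignment[i];
  }

  std::vector<std::vector<int>> adj(next_id);
  std::vector<int> in_degree(next_id, 0);
  for (const auto& e : edges) {
    const int u = condensed[e.first];
    const int v = condensed[e.second];
    if (u == v) continue;  // edge is internal to a subgraph
    adj[u].push_back(v);
    ++in_degree[v];
  }

  // Kahn's algorithm: if some node is never released, there is a cycle.
  std::queue<int> ready;
  for (int n = 0; n < next_id; ++n) {
    if (in_degree[n] == 0) ready.push(n);
  }
  int visited = 0;
  while (!ready.empty()) {
    const int n = ready.front();
    ready.pop();
    ++visited;
    for (int m : adj[n]) {
      if (--in_degree[m] == 0) ready.push(m);
    }
  }
  return visited != next_id;
}

// Assumed topology (an illustration, not taken from the images):
// ids: a0=0, a1=1, b0=2, b1=3, xa=4, xb=5
std::vector<std::pair<int, int>> ExampleEdges() {
  return {{0, 1}, {0, 4}, {4, 3},   // a0->a1, a0->xa, xa->b1
          {2, 3}, {2, 5}, {5, 1}};  // b0->b1, b0->xb, xb->a1
}
```

   Under that assumed topology, fusing only `{a0, a1}` or only `{b0, b1}` keeps the graph acyclic, while fusing both pairs yields the cycle `A -> xa -> B -> xb -> A`. That is why the exclusion information has to be propagated as subgraphs grow.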





[GitHub] [incubator-mxnet] ptrendx commented on a change in pull request #19269: Faster pointwise fusion graph pass

Posted by GitBox <gi...@apache.org>.
ptrendx commented on a change in pull request #19269:
URL: https://github.com/apache/incubator-mxnet/pull/19269#discussion_r506689534



##########
File path: src/imperative/pointwise_fusion_pass.cc
##########
@@ -57,281 +58,321 @@ void WarnFusionNotSupported() {
 #if MXNET_USE_CUDA
 
 namespace {
-  bool IsFusionCompatible(nnvm::Node* n) {
-    using namespace mxnet::fusion;
-    if (n->op() == nullptr)
-      return false;
-    std::string op_name = n->op()->name;
-    if (ops_desc.count(op_name))
-      return true;
-    if (slice_ops.count(op_name))
-      return false;
-    if (std::find(variable_io_ops.begin(),
-                  variable_io_ops.end(),
-                  op_name) !=
-        variable_io_ops.end())
-      return true;
-    if (op_name == "LeakyReLU") {
-        std::string act_type = n->attrs.dict.at("act_type");
-        if (LeakyReLU_ops.count(act_type))
-          return true;
-        else
-          return false;
-    }
-    if (op_name == "_backward_LeakyReLU") {
-        std::string act_type = n->attrs.dict.at("act_type");
-        if (LeakyReLU_bwd_ops.count(act_type))
-          return true;
-        else
-          return false;
-    }
+
+bool IsFusionCompatible(const nnvm::Node* n) {
+  using namespace mxnet::fusion;
+  if (n->op() == nullptr)
+    return false;
+  const std::string& op_name = n->op()->name;
+  if (ops_desc.count(op_name))
+    return true;
+  if (slice_ops.count(op_name))
     return false;
+  if (std::find(variable_io_ops.begin(),
+                variable_io_ops.end(),
+                op_name) !=
+      variable_io_ops.end())
+    return true;
+  if (op_name == "LeakyReLU") {
+      std::string act_type = n->attrs.dict.at("act_type");
+      if (LeakyReLU_ops.count(act_type))
+        return true;
+      else
+        return false;
   }
+  if (op_name == "_backward_LeakyReLU") {
+      std::string act_type = n->attrs.dict.at("act_type");
+      if (LeakyReLU_bwd_ops.count(act_type))
+        return true;
+      else
+        return false;
+  }
+  return false;
+}
 
-  bool IsInputsOnlyCompatible(nnvm::Node* n) {
-    using namespace mxnet::fusion;
-    if (n->op() == nullptr)
-      return false;
-    std::string op_name = n->op()->name;
-    if (slice_ops.count(op_name)) {
-      if (op_name == "slice") {
-        // slice with non-default step attribute is not supported
-        // currently
-        if (n->attrs.dict.count("step") &&
-            !(n->attrs.dict.at("step") == "()" ||
-              n->attrs.dict.at("step") == "[]")) {
-          return false;
-        }
+bool IsInputsOnlyCompatible(const nnvm::Node* n) {
+  using namespace mxnet::fusion;
+  if (n->op() == nullptr)
+    return false;
+  const std::string& op_name = n->op()->name;
+  if (slice_ops.count(op_name)) {
+    if (op_name == "slice") {
+      // slice with non-default step attribute is not supported
+      // currently
+      if (n->attrs.dict.count("step") &&
+          !(n->attrs.dict.at("step") == "()" ||
+            n->attrs.dict.at("step") == "[]")) {
+        return false;
       }
-      return true;
     }
-    return false;
+    return true;
   }
+  return false;
+}
+
+void CreateSubgraphNode(const nnvm::Graph& subgraph,
+                        size_t inputs_size,
+                        nnvm::Node* subgraph_node) {
+  static const Op* fused_op_ptr = Op::Get("_FusedOp");
+  subgraph_node->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>());
+  subgraph_node->attrs.subgraphs.back()->outputs = subgraph.outputs;
+  subgraph_node->attrs.dict["num_inputs"] = std::to_string(inputs_size);
+  subgraph_node->attrs.dict["num_outputs"] = std::to_string(subgraph.outputs.size());
+  subgraph_node->attrs.op = fused_op_ptr;
+  subgraph_node->op()->attr_parser(&(subgraph_node->attrs));
+}
+
+struct EntryInfo {
+  int source_node;
+  int index;
+};
 
-  nnvm::ObjectPtr CreateSubgraphNode(const Graph& subgraph, size_t inputs_size) {
-    nnvm::Symbol subgraph_sym;
-    auto node = nnvm::Node::Create();
-    subgraph_sym.outputs = subgraph.outputs;
-    node->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>(subgraph_sym));
-    node->attrs.name = "FusedOp";
-    node->attrs.dict["num_inputs"] = std::to_string(inputs_size);
-    node->attrs.dict["num_outputs"] = std::to_string(subgraph.outputs.size());
-    node->attrs.op = Op::Get("_FusedOp");
-    node->op()->attr_parser(&(node->attrs));
-    return node;
+inline int SetInsert(const EntryInfo& new_elem,
+                     std::vector<EntryInfo>* elements) {
+  for (size_t i = 0; i < elements->size(); ++i) {
+    if ((new_elem.source_node == elements->at(i).source_node) &&
+        (new_elem.index == elements->at(i).index)) {
+      return i;
+    }
   }
+  elements->emplace_back(new_elem);
+  return elements->size() - 1;
+}
+
 }  // namespace
 
-/*!
- * \brief Replace a set of nodes by a subgraph node.
- *        This function is used specifically in pointwise fusion.
+/* \brief Create (if necessary) copy of the graph, replacing subgraphs with
+ *        FusedOps. If there are no subgraphs to be replaced, the
+ *        original graph is returned.
+ * \param g original graph.
+ * \param subgraph_assignment assignment of nodes in g's IndexedGraphs to
+ *                            subgraphs. Values from -1 to num_subgraphs - 1
+ *                            are allowed, -1 means that the node is not in a
+ *                            subgraph.
+ * \param num_subgraphs number of subgraphs.
+ * \param create_subgraph_node function used to prepare the subgraph node.
  */
 template<typename FCreateNode>
-Graph ReplaceSubgraphsPointwise(Graph&& g, const std::vector<NodeRawPtrSet>& subgraph_sets,
-                                FCreateNode create_subgraph_node) {
-  for (auto subgraph_set : subgraph_sets) {
-    // Create MXNet subgraph
-    Graph subgraph;
-    const auto sub_outputs_in_main = GetSubgraphOutputs(g, subgraph_set);
-    subgraph.outputs.resize(sub_outputs_in_main.size());
-    for (auto p : sub_outputs_in_main) {
-      subgraph.outputs[p.second] = p.first;
-    }
-    // To generate a subgraph an input has to be replaced by data node (no op)
-    // and it has to be agnostic to the node from which it's an output
-    // (For example, even if two inputs are two different outputs from the same node,
-    // they need to be replaced by two completely separate data nodes)
-    auto inputs = GetSubgraphInputs(subgraph, subgraph_set);
-    auto subgraph_node = create_subgraph_node(subgraph, inputs.size());
-    subgraph_node->inputs = inputs;
-    // replug inputs of node out of subgraph to be output of the subgraph node
-    // if it was a node in the subgraph
-    DFSVisit(g.outputs,
-        [&subgraph_node, &subgraph_set, &sub_outputs_in_main](const nnvm::ObjectPtr node) {
-      if (!subgraph_set.count(node.get())) {
-        for (auto &e : node->inputs) {
-          auto it = sub_outputs_in_main.find(e);
-          if (it != sub_outputs_in_main.end()) {
-            e.node = subgraph_node;
-            e.index = it->second;
-          }
-        }
-      }
-    });
-    // replug outputs of the graph to be output of the subgraph node
-    // if it was a node in the subgraph
-    for (auto &e : g.outputs) {
-      auto it = sub_outputs_in_main.find(e);
-      if (it != sub_outputs_in_main.end()) {
-        e.node = subgraph_node;
-        e.index = it->second;
-      }
+Graph CopyAndReplaceSubgraphs(const Graph& g,
+                              const std::vector<int>& subgraph_assignment,
+                              const int num_subgraphs,
+                              FCreateNode create_subgraph_node) {
+  if (num_subgraphs == 0) {
+    return g;
+  }
+
+  Graph ret;
+
+  const auto& idx = g.indexed_graph();
+
+  CHECK_EQ(idx.num_nodes(), subgraph_assignment.size()) <<
+    "Every node in the graph needs to be included in subgraph assignment.";
+
+  std::vector<nnvm::ObjectPtr> new_nodes;
+  new_nodes.reserve(idx.num_nodes());
+  struct SubgraphInfo {
+    nnvm::Graph graph;
+    nnvm::ObjectPtr subgraph_node;
+    std::vector<EntryInfo> outputs;
+    std::vector<EntryInfo> inputs;
+    std::vector<nnvm::ObjectPtr> input_nodes;
+  };
+
+  std::vector<SubgraphInfo> subgraphs(num_subgraphs);
+
+  for (auto& info : subgraphs) {
+    info.subgraph_node = nnvm::Node::Create();
+  }
+
+  for (size_t i = 0; i < idx.num_nodes(); ++i) {
+    // First copy the node, it will be used
+    // either in the new graph or inside a
+    // subgraph. Variables are not copied.
+    if (idx[i].source->op() != nullptr) {
+      new_nodes.emplace_back(nnvm::Node::Create());
+      auto& node_copy = new_nodes.back();
+      node_copy->attrs = idx[i].source->attrs;
+      node_copy->info = idx[i].source->info;
+    } else {
+      new_nodes.emplace_back(idx[i].weak_ref.lock());
+      continue;
     }
-    // move control dependencies between nodes of the subgraph and out of the subgraph
-    // to a dependencies between the subgraph node and the nodes out of the subgraph
-    DFSVisit(subgraph.outputs, [&subgraph_node, &subgraph_set](const nnvm::ObjectPtr& node) {
-      if (subgraph_set.count(node.get())) {
-        auto it = node->control_deps.begin();
-        static auto& is_fusion = Op::GetAttr<exec::TIsFusionHelper>("TIsFusionHelper");
-        std::vector<nnvm::ObjectPtr> new_control_deps;
-        // Use the first control dependency to get the inferattr helper
-        if (it != node->control_deps.end()) {
-          if (subgraph_set.count(it->get())) {
-            new_control_deps.push_back(*it);
+    auto& node_copy = new_nodes.back();
+    const int subgraph_id = subgraph_assignment[i];
+    if (subgraph_id != -1) {
+      auto& info = subgraphs[subgraph_id];
+      for (const auto& input : idx[i].inputs) {
+        const int their_subgraph = subgraph_assignment[input.node_id];
+        if (their_subgraph == subgraph_id) {
+          node_copy->inputs.emplace_back(new_nodes[input.node_id],
+                                         input.index,
+                                         input.version);
+        } else {
+          int input_num;
+          int output_num;
+          if (their_subgraph == -1) {
+            input_num = SetInsert({static_cast<int>(input.node_id),
+                                   static_cast<int>(input.index)}, &(info.inputs));
           } else {
-            if ((*it)->is_variable() || !is_fusion.get((*it)->op(), false)) {
-              uint32_t node_id = subgraph_node->control_deps.size();
-              subgraph_node->control_deps.push_back(*it);
-              auto helper_node = op::MakeNode("_FusedOpOutHelper",
-                                              "FusedOp_" + node->attrs.name + "_outhelper",
-                                              nullptr,
-                                              nullptr,
-                                              nullptr);
-              helper_node->attrs.parsed =
-                FusedOpHelperParamPtr(new FusedOpHelperParam(
-                      nnvm::get<FusedOpPtr>(subgraph_node->attrs.parsed),
-                      node_id));
-              new_control_deps.push_back(helper_node);
+            auto& their_subgraph_info = subgraphs[their_subgraph];
+            output_num = SetInsert({static_cast<int>(input.node_id),
+                                    static_cast<int>(input.index)},
+                                   &(their_subgraph_info.outputs));
+            input_num = SetInsert({static_cast<int>(idx.num_nodes() + their_subgraph),
+                                   output_num},
+                                  &(info.inputs));
+          }
+          if (static_cast<size_t>(input_num) == info.input_nodes.size()) {
+            info.input_nodes.emplace_back(nnvm::Node::Create());
+            info.input_nodes.back()->attrs.name = "input_" + std::to_string(input_num);
+            if (their_subgraph == -1) {
+              info.subgraph_node->inputs.emplace_back(new_nodes[input.node_id],
+                                                      input.index,
+                                                      input.version);
             } else {
-              new_control_deps.push_back(*it);
+              info.subgraph_node->inputs.emplace_back(subgraphs[their_subgraph].subgraph_node,
+                                                      output_num,
+                                                      input.version);
             }
           }
-          ++it;
+          node_copy->inputs.emplace_back(info.input_nodes[input_num], 0, 0);
         }
-        node->control_deps = new_control_deps;
       }
-    });
-
-    std::ostringstream name_oss;
-    // the name of the new node will be the concatenation of all the node names in the subgraph
-    DFSVisit(subgraph.outputs, [&name_oss](const nnvm::ObjectPtr n) {
-      if (n->op() != nullptr) {
-        name_oss << n->op()->name << "_";
+    } else {
+      for (const auto& input : idx[i].inputs) {
+        const int subgraph_id = subgraph_assignment[input.node_id];
+        if (subgraph_id == -1) {
+          node_copy->inputs.emplace_back(new_nodes[input.node_id],
+                                         input.index,
+                                         input.version);
+        } else {
+          auto& info = subgraphs[subgraph_id];
+          const int output_num = SetInsert({static_cast<int>(input.node_id),
+                                            static_cast<int>(input.index)},
+                                           &(info.outputs));
+          node_copy->inputs.emplace_back(info.subgraph_node,
+                                         output_num,
+                                         input.version);
+        }
       }
-    });
-    auto subgraph_name = name_oss.str();
-    subgraph_name.pop_back();
-    subgraph_node->attrs.name = subgraph_name;
+    }
 
-    const auto& index = subgraph.indexed_graph();
-    DFSVisit(g.outputs, [&subgraph_node, &subgraph_set, &index](const nnvm::ObjectPtr& node) {
-      for (auto &e : node->control_deps) {
-        if (subgraph_set.count(e.get())) {
-          uint32_t node_id = index.node_id(e.get());
-          auto helper_node = op::MakeNode("_FusedOpHelper",
-                                          subgraph_node->attrs.name + "_"
-                                          + node->attrs.name + "_helper",
-                                          nullptr,
-                                          nullptr,
-                                          nullptr);
-          helper_node->attrs.parsed =
-            FusedOpHelperParamPtr(new FusedOpHelperParam(
-                  nnvm::get<FusedOpPtr>(subgraph_node->attrs.parsed),
-                  node_id));
-          e = helper_node;
-        }
+    // Control deps
+    for (const auto& dep : idx[i].control_deps) {
+      if (subgraph_id == subgraph_assignment[dep]) {
+        node_copy->control_deps.emplace_back(new_nodes[dep]);
       }
-    });
+    }
   }
-  Graph new_graph;
-  new_graph.outputs = g.outputs;
-  return new_graph;
-}
 
-/* \brief Add nodes as inputs to the subgraph. This is used for operations
- *        which are only compatible when they are the first nodes in the
- *        subgraph.
- */
-template <typename IsCompatible>
-void AddInputsOnlyCompatible(const Graph &g,
-                             std::vector<std::unordered_set<nnvm::Node*> >* subsets,
-                             IsCompatible is_compatible) {
-  std::unordered_map<nnvm::Node*, uint32_t> node2setidx;
-  size_t subgraphs_fullsize = 0;
-  for (auto& s : *subsets) {
-    subgraphs_fullsize += s.size();
-  }
-  node2setidx.reserve(subgraphs_fullsize);
-  for (size_t i = 0; i < subsets->size(); ++i) {
-    for (auto& n : (*subsets)[i]) {
-      node2setidx.insert({n, i});
+  ret.outputs.reserve(idx.outputs().size());
+  for (const auto& output : idx.outputs()) {
+    const int subgraph_id = subgraph_assignment[output.node_id];
+    if (subgraph_id == -1) {
+      ret.outputs.emplace_back(new_nodes[output.node_id],
+                               output.index,
+                               output.version);
+    } else {
+      const int output_num = SetInsert({static_cast<int>(output.node_id),
+                                        static_cast<int>(output.index)},
+                                       &(subgraphs[subgraph_id].outputs));
+      ret.outputs.emplace_back(subgraphs[subgraph_id].subgraph_node,
+                               output_num,
+                               output.version);
     }
   }
-  std::vector<std::vector<nnvm::Node*> > to_add(subsets->size());
-  DFSVisit(g.outputs, [&is_compatible, &node2setidx, &to_add](const nnvm::ObjectPtr& n) {
-    const auto& it = node2setidx.find(n.get());
-    if (it != node2setidx.end()) {
-      for (auto& e : n->inputs) {
-        if (is_compatible(e.node.get()))
-          to_add[it->second].push_back(e.node.get());
-      }
+
+  for (auto& info : subgraphs) {
+    info.graph.outputs.reserve(info.outputs.size());
+    for (const auto& entry_info : info.outputs) {
+      info.graph.outputs.emplace_back(new_nodes[entry_info.source_node],
+                                      entry_info.index,
+                                      0);
     }
-  });
+    create_subgraph_node(info.graph, info.inputs.size(), info.subgraph_node.get());
+  }
 
-  // Avoid duplicating the node that is input of two subsets
-  std::unordered_set<nnvm::Node*> added;
-  for (size_t i = 0; i < subsets->size(); ++i) {
-    std::vector<nnvm::NodeEntry> heads;
-    for (auto n : subsets->at(i)) {
-      for (auto e : n->inputs) {
-        if (!subsets->at(i).count(e.node.get()))
-          heads.push_back(e);
+  for (size_t i = 0; i < idx.num_nodes(); ++i) {
+    // Add _FusedOpHelper nodes
+    const int subgraph_id = subgraph_assignment[i];
+    for (size_t dep_num = 0; dep_num < idx[i].control_deps.size(); ++dep_num) {
+      const auto& dep = idx[i].control_deps[dep_num];
+      const int their_subgraph_id = subgraph_assignment[dep];
+      if (subgraph_assignment[i] != -1 && subgraph_assignment[dep] == -1) {
+        // Not in any subgraph, use FusedOpOutHelper
+        auto& info = subgraphs[subgraph_id];
+        size_t node_id = info.subgraph_node->control_deps.size();
+        info.subgraph_node->control_deps.emplace_back(new_nodes[dep]);
+        auto helper_node = op::MakeNode("_FusedOpOutHelper",
+                                        "FusedOp_" + new_nodes[i]->attrs.name + "_outhelper",
+                                        nullptr,
+                                        nullptr,
+                                        nullptr);
+        helper_node->attrs.parsed =
+          FusedOpHelperParamPtr(new FusedOpHelperParam(
+                nnvm::get<FusedOpPtr>(info.subgraph_node->attrs.parsed),
+                node_id));
+        new_nodes[i]->control_deps.insert(new_nodes[i]->control_deps.begin() + dep_num,
+                                          std::move(helper_node));
+      } else if (their_subgraph_id != subgraph_id &&

Review comment:
       It can. The role of `FusedOpHelper` is to get information (for shape/type inference) from inside the `FusedOp` to the outside world.







[GitHub] [incubator-mxnet] ptrendx commented on pull request #19269: Faster pointwise fusion graph pass

Posted by GitBox <gi...@apache.org>.
ptrendx commented on pull request #19269:
URL: https://github.com/apache/incubator-mxnet/pull/19269#issuecomment-710675887


   @mxnet-bot run ci [centos-cpu]





[GitHub] [incubator-mxnet] ptrendx commented on pull request #19269: Faster pointwise fusion graph pass

Posted by GitBox <gi...@apache.org>.
ptrendx commented on pull request #19269:
URL: https://github.com/apache/incubator-mxnet/pull/19269#issuecomment-710646058


   @mxnet-bot run ci [all]





[GitHub] [incubator-mxnet] ptrendx commented on a change in pull request #19269: Faster pointwise fusion graph pass

Posted by GitBox <gi...@apache.org>.
ptrendx commented on a change in pull request #19269:
URL: https://github.com/apache/incubator-mxnet/pull/19269#discussion_r506962531



##########
File path: src/imperative/pointwise_fusion_pass.cc
##########
@@ -57,281 +58,321 @@ void WarnFusionNotSupported() {
 #if MXNET_USE_CUDA
 
 namespace {
-  bool IsFusionCompatible(nnvm::Node* n) {
-    using namespace mxnet::fusion;
-    if (n->op() == nullptr)
-      return false;
-    std::string op_name = n->op()->name;
-    if (ops_desc.count(op_name))
-      return true;
-    if (slice_ops.count(op_name))
-      return false;
-    if (std::find(variable_io_ops.begin(),
-                  variable_io_ops.end(),
-                  op_name) !=
-        variable_io_ops.end())
-      return true;
-    if (op_name == "LeakyReLU") {
-        std::string act_type = n->attrs.dict.at("act_type");
-        if (LeakyReLU_ops.count(act_type))
-          return true;
-        else
-          return false;
-    }
-    if (op_name == "_backward_LeakyReLU") {
-        std::string act_type = n->attrs.dict.at("act_type");
-        if (LeakyReLU_bwd_ops.count(act_type))
-          return true;
-        else
-          return false;
-    }
+
+bool IsFusionCompatible(const nnvm::Node* n) {
+  using namespace mxnet::fusion;
+  if (n->op() == nullptr)
+    return false;
+  const std::string& op_name = n->op()->name;
+  if (ops_desc.count(op_name))
+    return true;
+  if (slice_ops.count(op_name))
     return false;
+  if (std::find(variable_io_ops.begin(),
+                variable_io_ops.end(),
+                op_name) !=
+      variable_io_ops.end())
+    return true;
+  if (op_name == "LeakyReLU") {
+      std::string act_type = n->attrs.dict.at("act_type");
+      if (LeakyReLU_ops.count(act_type))
+        return true;
+      else
+        return false;
   }
+  if (op_name == "_backward_LeakyReLU") {
+      std::string act_type = n->attrs.dict.at("act_type");
+      if (LeakyReLU_bwd_ops.count(act_type))
+        return true;
+      else
+        return false;
+  }
+  return false;
+}
 
-  bool IsInputsOnlyCompatible(nnvm::Node* n) {
-    using namespace mxnet::fusion;
-    if (n->op() == nullptr)
-      return false;
-    std::string op_name = n->op()->name;
-    if (slice_ops.count(op_name)) {
-      if (op_name == "slice") {
-        // slice with non-default step attribute is not supported
-        // currently
-        if (n->attrs.dict.count("step") &&
-            !(n->attrs.dict.at("step") == "()" ||
-              n->attrs.dict.at("step") == "[]")) {
-          return false;
-        }
+bool IsInputsOnlyCompatible(const nnvm::Node* n) {
+  using namespace mxnet::fusion;
+  if (n->op() == nullptr)
+    return false;
+  const std::string& op_name = n->op()->name;
+  if (slice_ops.count(op_name)) {
+    if (op_name == "slice") {
+      // slice with non-default step attribute is not supported
+      // currently
+      if (n->attrs.dict.count("step") &&
+          !(n->attrs.dict.at("step") == "()" ||
+            n->attrs.dict.at("step") == "[]")) {
+        return false;
       }
-      return true;
     }
-    return false;
+    return true;
   }
+  return false;
+}
+
+void CreateSubgraphNode(const nnvm::Graph& subgraph,
+                        size_t inputs_size,
+                        nnvm::Node* subgraph_node) {
+  static const Op* fused_op_ptr = Op::Get("_FusedOp");
+  subgraph_node->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>());
+  subgraph_node->attrs.subgraphs.back()->outputs = subgraph.outputs;
+  subgraph_node->attrs.dict["num_inputs"] = std::to_string(inputs_size);
+  subgraph_node->attrs.dict["num_outputs"] = std::to_string(subgraph.outputs.size());
+  subgraph_node->attrs.op = fused_op_ptr;
+  subgraph_node->op()->attr_parser(&(subgraph_node->attrs));
+}
+
+struct EntryInfo {
+  int source_node;
+  int index;
+};
 
-  nnvm::ObjectPtr CreateSubgraphNode(const Graph& subgraph, size_t inputs_size) {
-    nnvm::Symbol subgraph_sym;
-    auto node = nnvm::Node::Create();
-    subgraph_sym.outputs = subgraph.outputs;
-    node->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>(subgraph_sym));
-    node->attrs.name = "FusedOp";
-    node->attrs.dict["num_inputs"] = std::to_string(inputs_size);
-    node->attrs.dict["num_outputs"] = std::to_string(subgraph.outputs.size());
-    node->attrs.op = Op::Get("_FusedOp");
-    node->op()->attr_parser(&(node->attrs));
-    return node;
+inline int SetInsert(const EntryInfo& new_elem,
+                     std::vector<EntryInfo>* elements) {
+  for (size_t i = 0; i < elements->size(); ++i) {
+    if ((new_elem.source_node == elements->at(i).source_node) &&
+        (new_elem.index == elements->at(i).index)) {
+      return i;
+    }
   }
+  elements->emplace_back(new_elem);
+  return elements->size() - 1;
+}
+
 }  // namespace
 
-/*!
- * \brief Replace a set of nodes by a subgraph node.
- *        This function is used specifically in pointwise fusion.
+/* \brief Create (if necessary) copy of the graph, replacing subgraphs with
+ *        FusedOps. If there are no subgraphs to be replaced, the
+ *        original graph is returned.
+ * \param g original graph.
+ * \param subgraph_assignment assignment of nodes in g's IndexedGraphs to
+ *                            subgraphs. Values from -1 to num_subgraphs - 1
+ *                            are allowed, -1 means that the node is not in a
+ *                            subgraph.
+ * \param num_subgraphs number of subgraphs.
+ * \param create_subgraph_node function used to prepare the subgraph node.
  */
 template<typename FCreateNode>
-Graph ReplaceSubgraphsPointwise(Graph&& g, const std::vector<NodeRawPtrSet>& subgraph_sets,
-                                FCreateNode create_subgraph_node) {
-  for (auto subgraph_set : subgraph_sets) {
-    // Create MXNet subgraph
-    Graph subgraph;
-    const auto sub_outputs_in_main = GetSubgraphOutputs(g, subgraph_set);
-    subgraph.outputs.resize(sub_outputs_in_main.size());
-    for (auto p : sub_outputs_in_main) {
-      subgraph.outputs[p.second] = p.first;
-    }
-    // To generate a subgraph an input has to be replaced by data node (no op)
-    // and it has to be agnostic to the node from which it's an output
-    // (For example, even if two inputs are two different outputs from the same node,
-    // they need to be replaced by two completely separate data nodes)
-    auto inputs = GetSubgraphInputs(subgraph, subgraph_set);
-    auto subgraph_node = create_subgraph_node(subgraph, inputs.size());
-    subgraph_node->inputs = inputs;
-    // replug inputs of node out of subgraph to be output of the subgraph node
-    // if it was a node in the subgraph
-    DFSVisit(g.outputs,
-        [&subgraph_node, &subgraph_set, &sub_outputs_in_main](const nnvm::ObjectPtr node) {
-      if (!subgraph_set.count(node.get())) {
-        for (auto &e : node->inputs) {
-          auto it = sub_outputs_in_main.find(e);
-          if (it != sub_outputs_in_main.end()) {
-            e.node = subgraph_node;
-            e.index = it->second;
-          }
-        }
-      }
-    });
-    // replug outputs of the graph to be output of the subgraph node
-    // if it was a node in the subgraph
-    for (auto &e : g.outputs) {
-      auto it = sub_outputs_in_main.find(e);
-      if (it != sub_outputs_in_main.end()) {
-        e.node = subgraph_node;
-        e.index = it->second;
-      }
+Graph CopyAndReplaceSubgraphs(const Graph& g,
+                              const std::vector<int>& subgraph_assignment,
+                              const int num_subgraphs,
+                              FCreateNode create_subgraph_node) {
+  if (num_subgraphs == 0) {
+    return g;
+  }
+
+  Graph ret;
+
+  const auto& idx = g.indexed_graph();
+
+  CHECK_EQ(idx.num_nodes(), subgraph_assignment.size()) <<
+    "Every node in the graph needs to be included in subgraph assignment.";
+
+  std::vector<nnvm::ObjectPtr> new_nodes;
+  new_nodes.reserve(idx.num_nodes());
+  struct SubgraphInfo {
+    nnvm::Graph graph;
+    nnvm::ObjectPtr subgraph_node;
+    std::vector<EntryInfo> outputs;
+    std::vector<EntryInfo> inputs;
+    std::vector<nnvm::ObjectPtr> input_nodes;
+  };
+
+  std::vector<SubgraphInfo> subgraphs(num_subgraphs);
+
+  for (auto& info : subgraphs) {
+    info.subgraph_node = nnvm::Node::Create();
+  }
+
+  for (size_t i = 0; i < idx.num_nodes(); ++i) {
+    // First copy the node, it will be used
+    // either in the new graph or inside a
+    // subgraph. Variables are not copied.
+    if (idx[i].source->op() != nullptr) {
+      new_nodes.emplace_back(nnvm::Node::Create());
+      auto& node_copy = new_nodes.back();
+      node_copy->attrs = idx[i].source->attrs;
+      node_copy->info = idx[i].source->info;
+    } else {
+      new_nodes.emplace_back(idx[i].weak_ref.lock());
+      continue;
     }
-    // move control dependencies between nodes of the subgraph and out of the subgraph
-    // to a dependencies between the subgraph node and the nodes out of the subgraph
-    DFSVisit(subgraph.outputs, [&subgraph_node, &subgraph_set](const nnvm::ObjectPtr& node) {
-      if (subgraph_set.count(node.get())) {
-        auto it = node->control_deps.begin();
-        static auto& is_fusion = Op::GetAttr<exec::TIsFusionHelper>("TIsFusionHelper");
-        std::vector<nnvm::ObjectPtr> new_control_deps;
-        // Use the first control dependency to get the inferattr helper
-        if (it != node->control_deps.end()) {
-          if (subgraph_set.count(it->get())) {
-            new_control_deps.push_back(*it);
+    auto& node_copy = new_nodes.back();
+    const int subgraph_id = subgraph_assignment[i];
+    if (subgraph_id != -1) {
+      auto& info = subgraphs[subgraph_id];
+      for (const auto& input : idx[i].inputs) {
+        const int their_subgraph = subgraph_assignment[input.node_id];
+        if (their_subgraph == subgraph_id) {
+          node_copy->inputs.emplace_back(new_nodes[input.node_id],
+                                         input.index,
+                                         input.version);
+        } else {
+          int input_num;
+          int output_num;
+          if (their_subgraph == -1) {
+            input_num = SetInsert({static_cast<int>(input.node_id),
+                                   static_cast<int>(input.index)}, &(info.inputs));
           } else {
-            if ((*it)->is_variable() || !is_fusion.get((*it)->op(), false)) {
-              uint32_t node_id = subgraph_node->control_deps.size();
-              subgraph_node->control_deps.push_back(*it);
-              auto helper_node = op::MakeNode("_FusedOpOutHelper",
-                                              "FusedOp_" + node->attrs.name + "_outhelper",
-                                              nullptr,
-                                              nullptr,
-                                              nullptr);
-              helper_node->attrs.parsed =
-                FusedOpHelperParamPtr(new FusedOpHelperParam(
-                      nnvm::get<FusedOpPtr>(subgraph_node->attrs.parsed),
-                      node_id));
-              new_control_deps.push_back(helper_node);
+            auto& their_subgraph_info = subgraphs[their_subgraph];
+            output_num = SetInsert({static_cast<int>(input.node_id),
+                                    static_cast<int>(input.index)},
+                                   &(their_subgraph_info.outputs));
+            input_num = SetInsert({static_cast<int>(idx.num_nodes() + their_subgraph),
+                                   output_num},
+                                  &(info.inputs));
+          }
+          if (static_cast<size_t>(input_num) == info.input_nodes.size()) {
+            info.input_nodes.emplace_back(nnvm::Node::Create());
+            info.input_nodes.back()->attrs.name = "input_" + std::to_string(input_num);
+            if (their_subgraph == -1) {
+              info.subgraph_node->inputs.emplace_back(new_nodes[input.node_id],
+                                                      input.index,
+                                                      input.version);
             } else {
-              new_control_deps.push_back(*it);
+              info.subgraph_node->inputs.emplace_back(subgraphs[their_subgraph].subgraph_node,
+                                                      output_num,
+                                                      input.version);
             }
           }
-          ++it;
+          node_copy->inputs.emplace_back(info.input_nodes[input_num], 0, 0);
         }
-        node->control_deps = new_control_deps;
       }
-    });
-
-    std::ostringstream name_oss;
-    // the name of the new node will be the concatenation of all the node names in the subgraph
-    DFSVisit(subgraph.outputs, [&name_oss](const nnvm::ObjectPtr n) {
-      if (n->op() != nullptr) {
-        name_oss << n->op()->name << "_";
+    } else {
+      for (const auto& input : idx[i].inputs) {
+        const int subgraph_id = subgraph_assignment[input.node_id];
+        if (subgraph_id == -1) {
+          node_copy->inputs.emplace_back(new_nodes[input.node_id],
+                                         input.index,
+                                         input.version);
+        } else {
+          auto& info = subgraphs[subgraph_id];
+          const int output_num = SetInsert({static_cast<int>(input.node_id),
+                                            static_cast<int>(input.index)},
+                                           &(info.outputs));
+          node_copy->inputs.emplace_back(info.subgraph_node,
+                                         output_num,
+                                         input.version);
+        }
       }
-    });
-    auto subgraph_name = name_oss.str();
-    subgraph_name.pop_back();
-    subgraph_node->attrs.name = subgraph_name;
+    }
 
-    const auto& index = subgraph.indexed_graph();
-    DFSVisit(g.outputs, [&subgraph_node, &subgraph_set, &index](const nnvm::ObjectPtr& node) {
-      for (auto &e : node->control_deps) {
-        if (subgraph_set.count(e.get())) {
-          uint32_t node_id = index.node_id(e.get());
-          auto helper_node = op::MakeNode("_FusedOpHelper",
-                                          subgraph_node->attrs.name + "_"
-                                          + node->attrs.name + "_helper",
-                                          nullptr,
-                                          nullptr,
-                                          nullptr);
-          helper_node->attrs.parsed =
-            FusedOpHelperParamPtr(new FusedOpHelperParam(
-                  nnvm::get<FusedOpPtr>(subgraph_node->attrs.parsed),
-                  node_id));
-          e = helper_node;
-        }
+    // Control deps
+    for (const auto& dep : idx[i].control_deps) {
+      if (subgraph_id == subgraph_assignment[dep]) {
+        node_copy->control_deps.emplace_back(new_nodes[dep]);
       }
-    });
+    }
   }
-  Graph new_graph;
-  new_graph.outputs = g.outputs;
-  return new_graph;
-}
 
-/* \brief Add nodes as inputs to the subgraph. This is used for operations
- *        which are only compatible when they are the first nodes in the
- *        subgraph.
- */
-template <typename IsCompatible>
-void AddInputsOnlyCompatible(const Graph &g,
-                             std::vector<std::unordered_set<nnvm::Node*> >* subsets,
-                             IsCompatible is_compatible) {
-  std::unordered_map<nnvm::Node*, uint32_t> node2setidx;
-  size_t subgraphs_fullsize = 0;
-  for (auto& s : *subsets) {
-    subgraphs_fullsize += s.size();
-  }
-  node2setidx.reserve(subgraphs_fullsize);
-  for (size_t i = 0; i < subsets->size(); ++i) {
-    for (auto& n : (*subsets)[i]) {
-      node2setidx.insert({n, i});
+  ret.outputs.reserve(idx.outputs().size());
+  for (const auto& output : idx.outputs()) {
+    const int subgraph_id = subgraph_assignment[output.node_id];
+    if (subgraph_id == -1) {
+      ret.outputs.emplace_back(new_nodes[output.node_id],
+                               output.index,
+                               output.version);
+    } else {
+      const int output_num = SetInsert({static_cast<int>(output.node_id),
+                                        static_cast<int>(output.index)},
+                                       &(subgraphs[subgraph_id].outputs));
+      ret.outputs.emplace_back(subgraphs[subgraph_id].subgraph_node,
+                               output_num,
+                               output.version);
     }
   }
-  std::vector<std::vector<nnvm::Node*> > to_add(subsets->size());
-  DFSVisit(g.outputs, [&is_compatible, &node2setidx, &to_add](const nnvm::ObjectPtr& n) {
-    const auto& it = node2setidx.find(n.get());
-    if (it != node2setidx.end()) {
-      for (auto& e : n->inputs) {
-        if (is_compatible(e.node.get()))
-          to_add[it->second].push_back(e.node.get());
-      }
+
+  for (auto& info : subgraphs) {
+    info.graph.outputs.reserve(info.outputs.size());
+    for (const auto& entry_info : info.outputs) {
+      info.graph.outputs.emplace_back(new_nodes[entry_info.source_node],
+                                      entry_info.index,
+                                      0);
     }
-  });
+    create_subgraph_node(info.graph, info.inputs.size(), info.subgraph_node.get());
+  }
 
-  // Avoid duplicating the node that is input of two subsets
-  std::unordered_set<nnvm::Node*> added;
-  for (size_t i = 0; i < subsets->size(); ++i) {
-    std::vector<nnvm::NodeEntry> heads;
-    for (auto n : subsets->at(i)) {
-      for (auto e : n->inputs) {
-        if (!subsets->at(i).count(e.node.get()))
-          heads.push_back(e);
+  for (size_t i = 0; i < idx.num_nodes(); ++i) {
+    // Add _FusedOpHelper nodes
+    const int subgraph_id = subgraph_assignment[i];
+    for (size_t dep_num = 0; dep_num < idx[i].control_deps.size(); ++dep_num) {
+      const auto& dep = idx[i].control_deps[dep_num];
+      const int their_subgraph_id = subgraph_assignment[dep];
+      if (subgraph_assignment[i] != -1 && subgraph_assignment[dep] == -1) {
+        // Not in any subgraph, use FusedOpOutHelper
+        auto& info = subgraphs[subgraph_id];
+        size_t node_id = info.subgraph_node->control_deps.size();
+        info.subgraph_node->control_deps.emplace_back(new_nodes[dep]);
+        auto helper_node = op::MakeNode("_FusedOpOutHelper",
+                                        "FusedOp_" + new_nodes[i]->attrs.name + "_outhelper",
+                                        nullptr,
+                                        nullptr,
+                                        nullptr);
+        helper_node->attrs.parsed =
+          FusedOpHelperParamPtr(new FusedOpHelperParam(
+                nnvm::get<FusedOpPtr>(info.subgraph_node->attrs.parsed),
+                node_id));
+        new_nodes[i]->control_deps.insert(new_nodes[i]->control_deps.begin() + dep_num,
+                                          std::move(helper_node));
+      } else if (their_subgraph_id != subgraph_id &&

Review comment:
      Yeah, that is the dependency we are changing, as it is used for infershape/infertype. Going forward I would like to get rid of it and never run inferattr on the fused graph (instead running it on the original graph and mapping the results to the fused one), but that is outside the scope of this PR.
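
      The hunk under review numbers subgraph inputs and outputs through a `SetInsert` helper whose definition is not part of this excerpt. The sketch below is only an assumption about its behavior, inferred from how it is called: the `EntryInfo` layout (two `int` members, of which only the names `source_node` and `index` appear in the diff) and the linear search are hypothetical, not taken from the PR.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Assumed layout: the diff only shows that EntryInfo has members named
// `source_node` and `index` and can be brace-initialized from two ints.
struct EntryInfo {
  int source_node;
  int index;
};

// Hypothetical helper: returns the position of `entry` in `*entries`,
// appending it first when it is not already present. The vector is
// used as a small ordered set, so positions stay stable once assigned.
inline int SetInsert(const EntryInfo& entry, std::vector<EntryInfo>* entries) {
  assert(entries != nullptr);
  for (std::size_t i = 0; i < entries->size(); ++i) {
    const EntryInfo& e = (*entries)[i];
    if (e.source_node == entry.source_node && e.index == entry.index) {
      return static_cast<int>(i);  // already known, reuse its slot
    }
  }
  entries->push_back(entry);
  return static_cast<int>(entries->size() - 1);  // slot of the new entry
}
```

      Under that assumption, the check `input_num == info.input_nodes.size()` in the diff reads as "this entry was just appended, so a new input node is needed", and repeated uses of the same entry map to the same input or output slot.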




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-mxnet] ptrendx commented on pull request #19269: Faster pointwise fusion graph pass

Posted by GitBox <gi...@apache.org>.
ptrendx commented on pull request #19269:
URL: https://github.com/apache/incubator-mxnet/pull/19269#issuecomment-702462814


   It seems the cpplint we use does not like structured bindings from C++17.
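
   For reference, a minimal sketch of the two spellings, assuming a hypothetical `OutputMap` from entry name to output slot (neither the type nor the function names come from the PR). The second form is the pre-C++17 equivalent that avoids the structured-binding syntax. Which form the PR ended up using is not shown in this thread.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Hypothetical data: maps an entry name to its output slot (0..n-1).
using OutputMap = std::map<std::string, int>;

// C++17 form: a structured binding unpacks each key/value pair.
inline std::vector<std::string> OrderWithStructuredBinding(const OutputMap& outputs) {
  std::vector<std::string> ordered(outputs.size());
  for (const auto& [name, slot] : outputs) {
    assert(slot >= 0 && static_cast<std::size_t>(slot) < ordered.size());
    ordered[slot] = name;
  }
  return ordered;
}

// Equivalent form without structured bindings: use the pair members directly.
inline std::vector<std::string> OrderWithPairMembers(const OutputMap& outputs) {
  std::vector<std::string> ordered(outputs.size());
  for (const auto& p : outputs) {
    assert(p.second >= 0 && static_cast<std::size_t>(p.second) < ordered.size());
    ordered[p.second] = p.first;
  }
  return ordered;
}
```

   Both functions produce the same ordering, so swapping one spelling for the other is a purely syntactic change.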





[GitHub] [incubator-mxnet] ptrendx commented on pull request #19269: Faster pointwise fusion graph pass

Posted by GitBox <gi...@apache.org>.
ptrendx commented on pull request #19269:
URL: https://github.com/apache/incubator-mxnet/pull/19269#issuecomment-708048590


   Unix-gpu failed due to test_countsketch (issue https://github.com/apache/incubator-mxnet/issues/10988)
   Unix-cpu failed due to instability in test_np_average (issue https://github.com/apache/incubator-mxnet/issues/19071)
   
   @mxnet-bot run ci [unix-gpu, unix-cpu]





[GitHub] [incubator-mxnet] sxjscience merged pull request #19269: Faster pointwise fusion graph pass

Posted by GitBox <gi...@apache.org>.
sxjscience merged pull request #19269:
URL: https://github.com/apache/incubator-mxnet/pull/19269


   

