You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/06/30 23:51:55 UTC

[GitHub] [tvm] mbs-octoml opened a new pull request, #11981: [Collage] SubGraphs

mbs-octoml opened a new pull request, #11981:
URL: https://github.com/apache/tvm/pull/11981

   See https://github.com/apache/tvm-rfcs/blob/main/rfcs/0062-collage.md.
   
   Collage works in units of 'sub-graphs', which are potential partitions of the
   overall Relay model. This PR introduces SubGraph (an arbitrary partitioning, without
   any implication about how it is to be represented), its companion SubSubGraph
   (implying a representation as a function), and some supporting odds and ends.


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] SebastianBoblest commented on a diff in pull request #11981: [Collage] SubGraphs

Posted by GitBox <gi...@apache.org>.
SebastianBoblest commented on code in PR #11981:
URL: https://github.com/apache/tvm/pull/11981#discussion_r914622659


##########
src/relay/collage/sub_graph.cc:
##########
@@ -0,0 +1,1032 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/collage/sub_graph.cc
+ * \brief Represents a sub-graph of an overall Relay expression.
+ */
+
+#include "./sub_graph.h"
+
+#include <tvm/relay/transform.h>
+
+#include "../../support/scalars.h"
+#include "../transforms/pass_utils.h"
+#include "./utils.h"
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+namespace {
+
+class Extractor;
+
+/*!
+ * \brief Mutator which rewrites an overall expression so that every output of an extracted
+ * sub-graph is replaced by the expression the extractor assigned to it.
+ *
+ * The extractor is borrowed (not owned) and is expected to be already prepared.
+ */
+class Rewriter : public ExprMutator {
+ public:
+  /*! \param extractor Prepared extractor whose output substitution drives the rewrite. */
+  explicit Rewriter(const Extractor* extractor) : extractor_{extractor} {}
+
+  /*! \brief Substitutes sub-graph outputs, otherwise defers to the usual mutation. */
+  Expr VisitExpr(const Expr& expr) final;
+
+ private:
+  /*! \brief Non-owning pointer to the extractor guiding this rewrite. */
+  const Extractor* extractor_;
+};
+
+/*!
+ * \brief Helper class for extracting matched sub-graphs from the overall expression.
+ *
+ * Usage: construct, call \p Extract, then read \p extracted and \p output_substitution.
+ * Extract visits every node inside the sub-graph in reverse dataflow order, gathers the
+ * sub-graph's outputs into a single 'body' expression (a tuple when there is more than one
+ * output), re-applies any sub-sub-graph rewrites to that body, and, if function attributes were
+ * supplied, wraps the body in a new function whose parameters stand for the sub-graph's inputs.
+ */
+class Extractor : public ExprMutator {
+ public:
+  /*!
+   * \param dataflow_graph Dataflow graph of the overall expression. Borrowed, not owned.
+   * \param sub_graph Sub-graph of \p dataflow_graph to extract. Borrowed, not owned. Its
+   * overall size must equal the size of \p dataflow_graph (checked).
+   * \param opt_attrs If defined, extract the sub-graph as a function with these attributes.
+   */
+  Extractor(const DataflowGraph* dataflow_graph, const SubGraphNode* sub_graph,
+            FunctionAttrsMap opt_attrs)
+      : dataflow_graph_(dataflow_graph), sub_graph_(sub_graph), opt_attrs_(std::move(opt_attrs)) {
+    ICHECK_EQ(dataflow_graph_->size(), sub_graph_->overall_size());
+  }
+
+  /*! \brief Returns the dataflow graph of the overall expression. */
+  const DataflowGraph& dataflow_graph() const { return *dataflow_graph_; }
+
+  /*!
+   * \brief Collect the parameters and output expressions for the function representing
+   * the sub-graph, then build the extracted expression and the output substitution.
+   *
+   * Should be called at most once: params_ and args_ are moved from when building a function.
+   */
+  void Extract() {
+    ICHECK(!sub_graph_->IsEmpty());
+    VLOG(2) << "Extracting " << sub_graph_->ToString();
+    const bool for_function = opt_attrs_.defined();
+
+    //  In reverse dataflow order...
+    for (PostDfsIndex i = dataflow_graph_->size(); i > 0; --i) {
+      PostDfsIndex index = i - 1;
+      if (!sub_graph_->inside_[index]) {
+        // Node is outside sub-graph.
+        continue;
+      }
+      VLOG(2) << "index " << index;
+      auto node = dataflow_graph_->index_to_node(index);
+      if (sub_graph_->exit_[node->index_] || node->is_external_ || memo_.count(node->ref()) == 0) {
+        // This sub-expression is:
+        //  - inside the sub-graph and needed outside the sub-graph. So it must contribute to an
+        //    output (even if we've already visited it while constructing an output from a
+        //    downstream sub-expression).
+        //  - flagged is_external_ (NOTE(review): presumably referenced from outside the overall
+        //    expression -- confirm), in which case it must likewise be an output.
+        //  - not yet visited, in which case it must still be considered an 'output' so it will
+        //    be evaluated for any possible side effects.
+        Expr output = VisitExpr(GetRef<Expr>(node->node_ref_));
+        VLOG(2) << "index " << index << " added as output:\n"
+                << PrettyPrint(output) << "\nat " << outputs_.size();
+        expr_to_output_index_.emplace(node->node_ref_, outputs_.size());
+        outputs_.emplace_back(std::move(output));
+        output_types_.emplace_back(node->node_ref_->checked_type());
+      }
+    }
+    ICHECK(!outputs_.empty());
+
+    // Reverse the outputs so as to preserve the original evaluation order.
+    std::reverse(outputs_.begin(), outputs_.end());
+    std::reverse(output_types_.begin(), output_types_.end());
+    // Output indexes were assigned during the backwards sweep; remap them to match the
+    // reversed order.
+    for (auto& kv : expr_to_output_index_) {
+      kv.second = static_cast<int>(outputs_.size()) - 1 - kv.second;
+    }
+
+    // Build a 'body' expression to represent the extracted sub-graph. If we have multiple
+    // outputs we'll place them in a tuple.
+    Type body_type;
+    Expr body;
+    if (outputs_.size() > 1) {
+      body_type = TupleType(output_types_);
+      body = Tuple(outputs_);
+      body->checked_type_ = body_type;
+    } else {
+      body_type = output_types_.front();
+      body = outputs_.front();
+    }
+
+    // Re-express all the sub-sub-graphs in terms of the body.
+    DataflowGraph body_dataflow_graph(body);
+    std::vector<SubSubGraph> sub_sub_graphs;
+    IndexSubst subst = MakeIndexSubst(body_dataflow_graph);
+    for (const auto& sub_sub_graph : sub_graph_->sub_sub_graphs_) {
+      sub_sub_graphs.emplace_back(sub_sub_graph.Subst(body_dataflow_graph, subst));
+    }
+
+    // Sweep backwards through the body, rewriting to account for each sub-sub-graph.
+    body = SubSubGraph::ParallelRewrite(body_dataflow_graph, body, std::move(sub_sub_graphs));
+
+    if (for_function) {
+      // Rewrite so all input nodes are now conveyed via call arguments to a new function.
+      Array<Type> arg_types;
+      arg_types.reserve(params_.size());
+      for (const auto& param : params_) {
+        arg_types.push_back(param->checked_type());
+      }
+      extracted_ = Function(std::move(params_), std::move(body), body_type,
+                            /*ty_params=*/{}, DictAttrs(opt_attrs_));
+      extracted_->checked_type_ =
+          FuncType(std::move(arg_types), body_type, /*type_params=*/{}, /*type_constraints=*/{});
+      body = Call(extracted_, std::move(args_));
+      body->checked_type_ = body_type;
+    } else {
+      // Don't do anything with the inputs.
+      extracted_ = body;
+    }
+
+    // Setup the output substitution. Each output becomes: the whole body if there is only one
+    // output; otherwise a projection of the call's result when extracting a function; otherwise
+    // the matching field of the (tuple) body.
+    for (const auto& kv : expr_to_output_index_) {
+      Expr expr;
+      if (outputs_.size() == 1) {
+        expr = body;
+      } else if (for_function) {
+        expr = TupleGetItem(body, kv.second);
+        expr->checked_type_ = output_types_[kv.second];
+      } else {
+        const auto* tuple_node = body.as<TupleNode>();
+        ICHECK(tuple_node);
+        expr = tuple_node->fields[kv.second];
+      }
+      VLOG(2) << "output " << dataflow_graph_->item_to_node(kv.first)->index_ << " is at index "
+              << kv.second << " (of " << outputs_.size() << " outputs)";
+      output_substitution_.emplace(kv.first, std::move(expr));
+    }
+  }
+
+  ////// Following members are valid only after Extract() has returned.
+
+  /*!
+   * \brief Returns the expression representing the extracted sub-graph. If opt_attrs_ is
+   * defined then will be a function.
+   */
+  Expr extracted() const { return extracted_; }
+
+  /*!
+   * \brief Returns the substitution to apply to all expression nodes in the overall expression
+   * so as to replace references to outputs of the sub-graph with their rewritten form.
+   */
+  const std::unordered_map<const ExprNode*, Expr>& output_substitution() const {
+    return output_substitution_;
+  }
+
+ private:
+  /*!
+   * \brief Returns a map from original index to new index for each node inside the sub-graph. Only
+   * valid after \p Extract has made its backwards dataflow sweep.
+   */
+  IndexSubst MakeIndexSubst(const DataflowGraph& new_dataflow_graph) const {
+    VLOG(2) << "building extractor substitution";
+    IndexSubst subst;
+    for (PostDfsIndex index : sub_graph_->inside_) {
+      auto orig_node = dataflow_graph_->index_to_node(index);
+      ICHECK_EQ(orig_node->index_, index);
+      auto itr = memo_.find(orig_node->ref());
+      ICHECK(itr != memo_.end());
+      auto new_node = new_dataflow_graph.item_to_node(itr->second);
+      VLOG(2) << orig_node->index_ << " |-> " << new_node->index_;
+      subst.emplace(orig_node->index_, new_node->index_);
+    }
+    return subst;
+  }
+
+  /*! \brief Returns true if \p expr is inside the sub-graph. */
+  bool inside(const Expr& expr) {
+    return sub_graph_->inside_[dataflow_graph_->item_to_node(expr)->index_];
+  }
+
+  /*!
+   * \brief Returns the variable uniquely representing \p expr, which should be
+   * an input node (ie outside the sub-graph but feeding into a node inside the sub-graph).
+   *
+   * It is valid for:
+   *  - An expression outside the sub-graph to be used multiple times inside the sub-graph.
+   *  - An expression outside the sub-graph to be used both inside and outside the sub-graph.
+   *
+   * A fresh variable is created on first use (appending to params_ and args_); later uses of
+   * the same expression reuse it via expr_to_param_.
+   */
+  Var VarFor(const Expr& expr) {
+    ICHECK(!inside(expr));
+    ICHECK(opt_attrs_.defined());
+    auto itr = expr_to_param_.find(expr.get());
+    if (itr != expr_to_param_.end()) {
+      return itr->second;
+    }
+    auto fresh_var = Var("FunctionVar_" + std::to_string(params_.size()), expr->checked_type());
+    fresh_var->checked_type_ = expr->checked_type();
+    params_.push_back(fresh_var);
+    args_.push_back(expr);
+    expr_to_param_.emplace(expr.get(), fresh_var);
+    return fresh_var;
+  }
+
+  /*!
+   * \brief If \p expr is inside the sub-graph then return its rewritten form.
+   * If \p expr is outside the sub-graph then it must correspond to an input node.
+   *  - If it can be inlined (CanInline) return it unchanged, so it is implicitly included.
+   *  - Else if opt_attrs_ is defined return the variable to represent it.
+   *  - Otherwise just return the expression directly.
+   *
+   * Should be called only on inputs to nodes which are inside the sub-graph.
+   */
+  Expr VisitExpr(const Expr& expr) final {
+    if (inside(expr)) {
+      return ExprMutator::VisitExpr(expr);
+    } else if (CanInline(expr)) {
+      // Implicitly include inlinable input sub-expressions.
+      return expr;
+    } else if (opt_attrs_.defined()) {
+      // Map to a function parameter.
+      return VarFor(expr);
+    } else {
+      // Stop rewriting.
+      return expr;
+    }
+  }
+
+  /*! \brief Functions marked primitive are returned as-is; their bodies are not visited. */
+  Expr VisitExpr_(const FunctionNode* function_node) override {
+    if (function_node->HasNonzeroAttr(attr::kPrimitive)) {
+      return GetRef<Function>(function_node);
+    }
+    return ExprMutator::VisitExpr_(function_node);
+  }
+
+  //// Context fields, passed in constructor.
+
+  /*! \brief The dataflow graph corresponding to the overall expression. */
+  const DataflowGraph* dataflow_graph_;
+  /*! \brief The sub-graph of the above we are extracting. */
+  const SubGraphNode* sub_graph_;
+  /*! \brief Optional attributes if the sub-graph should be extracted as a function. */
+  FunctionAttrsMap opt_attrs_;
+
+  //// Result fields, available after Extract() called.
+
+  /*!
+   * \brief The extracted expression. If opt_attrs_ is defined this will be a function.
+   */
+  Expr extracted_;
+  /*!
+   * \brief Map from output nodes to corresponding expressions. If the sub-graph has more than
+   * one output then each entry will be a tuple projection of the call result (when extracted as
+   * a function) or the corresponding field of the tuple body (otherwise).
+   */
+  std::unordered_map<const ExprNode*, Expr> output_substitution_;
+
+  //// Accumulator fields, built as we visit expressions.
+
+  /*! \brief (If opt_attrs_ is defined) Parameters representing input expression nodes. */
+  Array<Var> params_;
+  /*!
+   * \brief (If opt_attrs_ is defined) The input expression nodes for each of the above params_.
+   */
+  Array<Expr> args_;
+  /*!
+   * \brief (If opt_attrs_ is defined) Map from existing input expression nodes to the parameters
+   * in params_ which now represent them.
+   */
+  std::unordered_map<const ExprNode*, Var> expr_to_param_;
+  /*!
+   * \brief Accumulated new expressions which represent the exit nodes of the rewritten sub-graph.
+   * It is possible to have multiple outputs. It is possible one output also contributes to other
+   * outputs (ie the output is a 'tap').
+   */
+  std::vector<Expr> outputs_;
+  /*!
+   * \brief Types of the original expressions corresponding to outputs_. Populated for every
+   * output, whether or not opt_attrs_ is defined.
+   */
+  std::vector<Type> output_types_;
+  /*!
+   * \brief Map from existing exit expression nodes to the index in outputs_ which should
+   * represent them in the rewritten overall expression.
+   */
+  std::unordered_map<const ExprNode*, int> expr_to_output_index_;
+};
+
+Expr Rewriter::VisitExpr(const Expr& expr) {
+  // Sub-graph outputs are swapped wholesale for their rewritten form; anything else is
+  // mutated recursively in the usual way.
+  const auto& substitution = extractor_->output_substitution();
+  const auto match = substitution.find(expr.get());
+  if (match != substitution.end()) {
+    return match->second;
+  }
+  return ExprMutator::VisitExpr(expr);
+}
+
+}  // namespace
+
+std::pair<OpPatternKind, std::string> SubExprKindAndLabel(const Expr& sub_expr) {
+  using KindAndLabel = std::pair<OpPatternKind, std::string>;
+
+  /*!
+   * \brief Classifies a single sub-expression by its operator pattern kind and produces a short
+   * label describing it. Anything not understood is treated as kOpaque.
+   */
+  class KindAndLabelVisitor : public ExprFunctor<KindAndLabel(const Expr&)> {
+   private:
+    /*! \brief Returns true if every field of \p tuple_type is a tensor type. */
+    static bool AllFieldsAreTensors(const TupleTypeNode* tuple_type) {
+      for (const Type& field : tuple_type->fields) {
+        if (field.as<TensorTypeNode>() == nullptr) {
+          return false;
+        }
+      }
+      return true;
+    }
+
+    /*! \brief Classifies a call to the primitive operator \p op. */
+    static KindAndLabel ClassifyOpCall(const CallNode* call_node, const Op& op) {
+      static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
+      if (fpattern.count(op) == 0) {
+        VLOG(1) << "no TOpPattern known for " << op->name << ", considering opaque";
+        return {kOpaque, op->name};
+      }
+      if (IsDynamic(call_node->checked_type()) && IsDataDependent(call_node)) {
+        VLOG(1) << "call has dynamic shape which is data-dependent, considering opaque";
+        return {kOpaque, op->name};
+      }
+      const auto kind = static_cast<OpPatternKind>(fpattern[op]);
+      VLOG(2) << "TOpPattern for " << op->name << " is " << KindToString(kind);
+      return {kind, op->name};
+    }
+
+    /*! \brief Classifies a call to an inline function, using its "TOpPattern" attribute. */
+    static KindAndLabel ClassifyFunctionCall(const FunctionNode* function_node) {
+      const Optional<Integer> opt_pattern =
+          function_node->GetAttr<Integer>("TOpPattern", Optional<Integer>());
+      if (!opt_pattern.defined()) {
+        VLOG(1) << "calling function without TOpPattern, considering opaque";
+        return {kOpaque, "call_fun"};
+      }
+      const auto kind = static_cast<OpPatternKind>(opt_pattern.value()->value);
+      VLOG(1) << "TOpPattern for function is " << KindToString(kind);
+      return {kind, "call_prim"};
+    }
+
+    KindAndLabel VisitExpr_(const CallNode* call_node) final {
+      if (const auto* op_node = call_node->op.as<OpNode>()) {
+        return ClassifyOpCall(call_node, GetRef<Op>(op_node));
+      }
+      if (const auto* function_node = call_node->op.as<FunctionNode>()) {
+        return ClassifyFunctionCall(function_node);
+      }
+      VLOG(1) << "unsupported call, considering opaque";
+      return {kOpaque, "call_any"};
+    }
+
+    KindAndLabel VisitExpr_(const ConstantNode* constant_node) final {
+      VLOG(2) << "TOpPattern for constant is " << KindToString(kElemWise);
+      return {kElemWise, support::IsSimpleScalar(constant_node) ? "scalar" : "const"};
+    }
+
+    KindAndLabel VisitExpr_(const TupleNode* tuple_node) final {
+      const auto* tuple_type_node = tuple_node->checked_type().as<TupleTypeNode>();
+      ICHECK(tuple_type_node != nullptr);
+      if (!AllFieldsAreTensors(tuple_type_node)) {
+        VLOG(1) << "tuple contains non-tensors, considering opaque";
+        return {kOpaque, "tuple"};
+      }
+      VLOG(2) << "TOpPattern for tuple is " << KindToString(kInjective);
+      return {kInjective, "tuple"};
+    }
+
+    KindAndLabel VisitExpr_(const TupleGetItemNode* tuple_get_item_node) final {
+      const auto* tuple_type_node = tuple_get_item_node->tuple->checked_type().as<TupleTypeNode>();
+      ICHECK(tuple_type_node != nullptr);
+      if (!AllFieldsAreTensors(tuple_type_node)) {
+        VLOG(1) << "tuple being projected contains non-tensors, considering opaque";
+        return {kOpaque, "proj"};
+      }
+      VLOG(2) << "TOpPattern for tuple projection is " << KindToString(kInjective);
+      return {kInjective, "proj"};
+    }
+
+    // TODO(mbs): We implement the following mostly so we have a lightweight way of describing
+    // the current sub-expression. If partitioning is ever extended beyond the usual call/tuple/proj
+    // sub-language we should revise the returned operator kinds to match.
+
+    KindAndLabel VisitExpr_(const VarNode* var_node) final {
+      return {kOpaque, "%" + var_node->name_hint()};
+    }
+    KindAndLabel VisitExpr_(const GlobalVarNode* global_var_node) final {
+      return {kOpaque, "@" + global_var_node->name_hint};
+    }
+    KindAndLabel VisitExpr_(const OpNode* op_node) final { return {kOpaque, "`" + op_node->name}; }
+    KindAndLabel VisitExpr_(const FunctionNode* /*function_node*/) final { return {kOpaque, "fn"}; }
+    KindAndLabel VisitExpr_(const LetNode* /*let_node*/) final { return {kOpaque, "let"}; }
+    KindAndLabel VisitExpr_(const IfNode* /*if_node*/) final { return {kOpaque, "if"}; }
+    KindAndLabel VisitExpr_(const RefCreateNode* /*ref_create_node*/) final {
+      return {kOpaque, "ref"};
+    }
+    KindAndLabel VisitExpr_(const RefReadNode* /*op*/) final { return {kOpaque, "ref_read"}; }
+    KindAndLabel VisitExpr_(const RefWriteNode* /*op*/) final { return {kOpaque, "ref_write"}; }
+    KindAndLabel VisitExpr_(const ConstructorNode* op) final {
+      return {kOpaque, "`" + op->name_hint};
+    }
+    KindAndLabel VisitExpr_(const MatchNode* /*op*/) final { return {kOpaque, "match"}; }
+  };
+
+  return KindAndLabelVisitor().VisitExpr(sub_expr);
+}
+
+std::pair<OpPatternKind, std::string> SubGraphKindAndLabel(const DataflowGraph& dataflow_graph,
+                                                           const IndexSet& inside) {
+  // The overall kind is accumulated with CombineKinds, seeded with kElemWise. The label joins
+  // every non-empty per-node label with '+'.
+  OpPatternKind combined_kind = kElemWise;
+  std::ostringstream label;
+  const char* separator = "";
+  for (PostDfsIndex index : inside) {
+    const std::pair<OpPatternKind, std::string> kind_and_label =
+        SubExprKindAndLabel(dataflow_graph.index_to_node(index)->ref());
+    if (!kind_and_label.second.empty()) {
+      label << separator << kind_and_label.second;
+      separator = "+";
+    }
+    combined_kind = CombineKinds(combined_kind, kind_and_label.first);
+  }
+  return {combined_kind, label.str()};
+}
+
+IndexSet MatcherToIndexSet(const DFPatternMatcher& matcher) {
+  IndexSet matched_indexes(matcher.size());
+  for (const auto& pattern_and_exprs : matcher.memo()) {
+    const bool is_wildcard = pattern_and_exprs.first.as<WildcardPatternNode>() != nullptr;
+    for (const auto& sub_expr : pattern_and_exprs.second) {
+      // Trivial (inlinable) sub-expressions are simply included in the extracted function body
+      // when it is constructed, and expressions matched by a wildcard are not considered part of
+      // the sub-graph, so neither is recorded.
+      const bool excluded = CanInline(sub_expr) || is_wildcard;
+      if (!excluded) {
+        matched_indexes.Add(matcher.expr_to_node(sub_expr)->index_);
+      }
+    }
+  }
+  return matched_indexes;
+}
+
+std::string SubGraphConfig::ToString() const {
+  // Renders e.g. "{max_exits=1,allow_taps=0,max_max_depth=3}".
+  std::ostringstream os;
+  os << "{max_exits=" << max_exits << ",allow_taps=" << allow_taps
+     << ",max_max_depth=" << max_max_depth << "}";
+  return os.str();
+}
+
+TVM_REGISTER_NODE_TYPE(SubSubGraphNode);
+
+/*!
+ * \brief Attribute visitor for SubSubGraphNode.
+ *
+ * NOTE(review): no fields are visited yet, so AttrVisitor-based reflection over this node
+ * presumably sees neither sub_graph_obj_ nor attrs_ -- confirm before relying on it.
+ */
+void SubSubGraphNode::VisitAttrs(AttrVisitor* v) {
+  // TODO(mbs)
+}
+
+SubGraph SubSubGraphNode::sub_graph() const {
+  // Recover the sub-graph from the stored object reference.
+  return Downcast<SubGraph>(sub_graph_obj_);
+}
+
+bool SubSubGraphNode::operator==(const SubSubGraphNode& that) const {
+  // Equality is decided by the sub-graphs alone; attrs_ does not participate.
+  const SubGraph lhs = sub_graph();
+  const SubGraph rhs = that.sub_graph();
+  return *lhs.get() == *rhs.get();
+}
+
+bool SubSubGraphNode::operator<(const SubSubGraphNode& that) const {
+  // Ordering is decided by the sub-graphs alone; attrs_ does not participate.
+  const SubGraph lhs = sub_graph();
+  const SubGraph rhs = that.sub_graph();
+  return *lhs.get() < *rhs.get();
+}
+
+/*!
+ * \brief Returns a hash consistent with operator==.
+ *
+ * operator== (and operator<) compare only the sub-graphs, so two sub-sub-graphs which differ
+ * only in attrs_ compare equal and must therefore hash equally. Previously attrs_ was mixed
+ * into the hash, which broke that contract for unordered containers; it is now excluded.
+ */
+size_t SubSubGraphNode::hash() const { return sub_graph()->hash(); }
+
+std::string SubSubGraphNode::ToString() const {
+  std::ostringstream os;
+  os << "{sub_graph=" << sub_graph()->ToString();
+  os << ",attrs=" << PrettyPrint(attrs_);

Review Comment:
   ```suggestion
     os << ", attrs=" << PrettyPrint(attrs_);
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] SebastianBoblest commented on a diff in pull request #11981: [Collage] SubGraphs

Posted by GitBox <gi...@apache.org>.
SebastianBoblest commented on code in PR #11981:
URL: https://github.com/apache/tvm/pull/11981#discussion_r914626850


##########
src/relay/collage/sub_graph.cc:
##########
@@ -0,0 +1,1032 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/collage/sub_graph.cc
+ * \brief Represents a sub-graph of an overall Relay expression.
+ */
+
+#include "./sub_graph.h"
+
+#include <tvm/relay/transform.h>
+
+#include "../../support/scalars.h"
+#include "../transforms/pass_utils.h"
+#include "./utils.h"
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+namespace {
+
+class Extractor;
+
+/*!
+ * \brief Mutator which rewrites an overall expression so that every output of an extracted
+ * sub-graph is replaced by the expression the extractor assigned to it.
+ *
+ * The extractor is borrowed (not owned) and is expected to be already prepared.
+ */
+class Rewriter : public ExprMutator {
+ public:
+  /*! \param extractor Prepared extractor whose output substitution drives the rewrite. */
+  explicit Rewriter(const Extractor* extractor) : extractor_{extractor} {}
+
+  /*! \brief Substitutes sub-graph outputs, otherwise defers to the usual mutation. */
+  Expr VisitExpr(const Expr& expr) final;
+
+ private:
+  /*! \brief Non-owning pointer to the extractor guiding this rewrite. */
+  const Extractor* extractor_;
+};
+
+/*! \brief Helper class for extracting matched sub-graphs from the overall expression. */
+class Extractor : public ExprMutator {
+ public:
+  Extractor(const DataflowGraph* dataflow_graph, const SubGraphNode* sub_graph,
+            FunctionAttrsMap opt_attrs)
+      : dataflow_graph_(dataflow_graph), sub_graph_(sub_graph), opt_attrs_(std::move(opt_attrs)) {
+    ICHECK_EQ(dataflow_graph_->size(), sub_graph_->overall_size());
+  }
+
+  const DataflowGraph& dataflow_graph() const { return *dataflow_graph_; }
+
+  /*!
+   * \brief Collect the parameters and output expressions for the function representing
+   * the sub-graph.
+   */
+  void Extract() {
+    ICHECK(!sub_graph_->IsEmpty());
+    VLOG(2) << "Extracting " << sub_graph_->ToString();
+    const bool for_function = opt_attrs_.defined();
+
+    //  In reverse dataflow order...
+    for (PostDfsIndex i = dataflow_graph_->size(); i > 0; --i) {
+      PostDfsIndex index = i - 1;
+      if (!sub_graph_->inside_[index]) {
+        // Node is outside sub-graph.
+        continue;
+      }
+      VLOG(2) << "index " << index;
+      auto node = dataflow_graph_->index_to_node(index);
+      if (sub_graph_->exit_[node->index_] || node->is_external_ || memo_.count(node->ref()) == 0) {
+        // This sub-expression is:
+        //  - inside the sub-graph and needed outside the sub-graph. So it must contribute to an
+        //    output (even if we've already visited it while constructing an output from a
+        //    downstream sub-expression).
+        //  - not yet visited, in which case it must still be considered an 'output' so it will
+        //    be evaluated for any possible side effects.
+        Expr output = VisitExpr(GetRef<Expr>(node->node_ref_));
+        VLOG(2) << "index " << index << " added as output:\n"
+                << PrettyPrint(output) << "\nat " << outputs_.size();
+        expr_to_output_index_.emplace(node->node_ref_, outputs_.size());
+        outputs_.emplace_back(std::move(output));
+        output_types_.emplace_back(node->node_ref_->checked_type());
+      }
+    }
+    ICHECK(!outputs_.empty());
+
+    // Reverse the outputs so as to preserve the original evaluation order.
+    std::reverse(outputs_.begin(), outputs_.end());
+    std::reverse(output_types_.begin(), output_types_.end());
+    for (auto& kv : expr_to_output_index_) {
+      kv.second = static_cast<int>(outputs_.size()) - 1 - kv.second;
+    }
+
+    // Build a 'body' expression to represent the extracted sub-graph. If we have multiple
+    // outputs we'll place them in a tuple.
+    Type body_type;
+    Expr body;
+    if (outputs_.size() > 1) {
+      body_type = TupleType(output_types_);
+      body = Tuple(outputs_);
+      body->checked_type_ = body_type;
+    } else {
+      body_type = output_types_.front();
+      body = outputs_.front();
+    }
+
+    // Re-express all the sub-sub-graphs in terms of the body.
+    DataflowGraph body_dataflow_graph(body);
+    std::vector<SubSubGraph> sub_sub_graphs;
+    IndexSubst subst = MakeIndexSubst(body_dataflow_graph);
+    for (const auto& sub_sub_graph : sub_graph_->sub_sub_graphs_) {
+      sub_sub_graphs.emplace_back(sub_sub_graph.Subst(body_dataflow_graph, subst));
+    }
+
+    // Sweep backwards through the body, rewriting to account for each sub-sub-graph.
+    body = SubSubGraph::ParallelRewrite(body_dataflow_graph, body, std::move(sub_sub_graphs));
+
+    if (for_function) {
+      // Rewrite so all input nodes are now conveyed via call arguments to a new function.
+      Array<Type> arg_types;
+      arg_types.reserve(params_.size());
+      for (const auto& param : params_) {
+        arg_types.push_back(param->checked_type());
+      }
+      extracted_ = Function(std::move(params_), std::move(body), body_type,
+                            /*ty_params=*/{}, DictAttrs(opt_attrs_));
+      extracted_->checked_type_ =
+          FuncType(std::move(arg_types), body_type, /*type_params=*/{}, /*type_constraints=*/{});
+      body = Call(extracted_, std::move(args_));
+      body->checked_type_ = body_type;
+    } else {
+      // Don't do anything with the inputs.
+      extracted_ = body;
+    }
+
+    // Setup the output substitution.
+    for (const auto& kv : expr_to_output_index_) {
+      Expr expr;
+      if (outputs_.size() == 1) {
+        expr = body;
+      } else if (for_function) {
+        expr = TupleGetItem(body, kv.second);
+        expr->checked_type_ = output_types_[kv.second];
+      } else {
+        const auto* tuple_node = body.as<TupleNode>();
+        ICHECK(tuple_node);
+        expr = tuple_node->fields[kv.second];
+      }
+      VLOG(2) << "output " << dataflow_graph_->item_to_node(kv.first)->index_ << " is at index "
+              << kv.second << " (of " << outputs_.size() << " outputs)";
+      output_substitution_.emplace(kv.first, std::move(expr));
+    }
+  }
+
+  ////// Following members are valid only after Extract() has returned.
+
+  /*!
+   * \brief Returns the expression representing the extracted sub-graph. If opt_attrs_ is
+   * defined then will be a function.
+   */
+  Expr extracted() const { return extracted_; }
+
+  /*!
+   * \brief Returns the substitution to apply to all expression nodes in the overall expression
+   * so as to replace references to outputs of the sub-graph with their rewritten form.
+   */
+  const std::unordered_map<const ExprNode*, Expr>& output_substitution() const {
+    return output_substitution_;
+  }
+
+ private:
+  /*!
+   * \brief Returns a map from original index to new index for each node inside the sub-graph. Only
+   * valid after \p Extract has made its backwards dataflow sweep.
+   */
+  IndexSubst MakeIndexSubst(const DataflowGraph& new_dataflow_graph) const {
+    VLOG(2) << "building extractor substitution";
+    IndexSubst subst;
+    for (PostDfsIndex index : sub_graph_->inside_) {
+      auto orig_node = dataflow_graph_->index_to_node(index);
+      ICHECK_EQ(orig_node->index_, index);
+      auto itr = memo_.find(orig_node->ref());
+      ICHECK(itr != memo_.end());
+      auto new_node = new_dataflow_graph.item_to_node(itr->second);
+      VLOG(2) << orig_node->index_ << " |-> " << new_node->index_;
+      subst.emplace(orig_node->index_, new_node->index_);
+    }
+    return subst;
+  }
+
+  /*! \brief Returns true if \p expr is inside the sub-graph. */
+  bool inside(const Expr& expr) {
+    return sub_graph_->inside_[dataflow_graph_->item_to_node(expr)->index_];
+  }
+
+  /*!
+   * \brief Returns the variable uniquely representing \p expr, which should be
+   * an input node (ie outside the sub-graph but feeding into a node inside the sub-graph).
+   *
+   * It is valid for:
+   *  - An expression outside the sub-graph to be used multiple times inside the sub-graph.
+   *  - An expression outside the sub-graph to be used both inside and outside the sub-graph.
+   */
+  Var VarFor(const Expr& expr) {
+    ICHECK(!inside(expr));
+    ICHECK(opt_attrs_.defined());
+    auto itr = expr_to_param_.find(expr.get());
+    if (itr != expr_to_param_.end()) {
+      return itr->second;
+    }
+    auto fresh_var = Var("FunctionVar_" + std::to_string(params_.size()), expr->checked_type());
+    fresh_var->checked_type_ = expr->checked_type();
+    params_.push_back(fresh_var);
+    args_.push_back(expr);
+    expr_to_param_.emplace(expr.get(), fresh_var);
+    return fresh_var;
+  }
+
+  /*!
+   * \brief Returns the rewritten form of \p expr.
+   *
+   * If \p expr is inside the sub-graph it is rewritten recursively. Otherwise it must be an
+   * input node, and:
+   *  - If it can be inlined it is returned unchanged (and so is implicitly included).
+   *  - Else if opt_attrs_ is defined it is replaced by its function parameter (see VarFor).
+   *  - Else it is returned unchanged and rewriting stops there.
+   *
+   * Should be called only on inputs to nodes which are inside the sub-graph.
+   */
+  Expr VisitExpr(const Expr& expr) final {
+    if (inside(expr)) {
+      return ExprMutator::VisitExpr(expr);
+    }
+    if (!CanInline(expr) && opt_attrs_.defined()) {
+      // Map to a function parameter.
+      return VarFor(expr);
+    }
+    // Either an inlinable input, or we are not building a function: stop rewriting.
+    return expr;
+  }
+
+  Expr VisitExpr_(const FunctionNode* function_node) override {
+    // Primitive functions are left exactly as they are; all others are mutated as usual.
+    const bool is_primitive = function_node->HasNonzeroAttr(attr::kPrimitive);
+    if (!is_primitive) {
+      return ExprMutator::VisitExpr_(function_node);
+    }
+    return GetRef<Function>(function_node);
+  }
+
+  //// Context fields, passed in constructor.
+
+  /*! \brief The dataflow graph corresponding to the overall expression (not owned). */
+  const DataflowGraph* dataflow_graph_;
+  /*! \brief The sub-graph of the above we are extracting (not owned). */
+  const SubGraphNode* sub_graph_;
+  /*!
+   * \brief Optional attributes if the sub-graph should be extracted as a function. When
+   * undefined, inputs are left in place rather than being mapped to parameters.
+   */
+  FunctionAttrsMap opt_attrs_;
+
+  //// Result fields, available after Extract() called.
+
+  /*!
+   * \brief The extracted expression. If opt_attrs_ is defined this will be a function.
+   */
+  Expr extracted_;
+  /*!
+   * \brief Map from output nodes to corresponding expressions. If the sub-graph has more than
+   * one exit node then each entry will be a tuple projection.
+   */
+  std::unordered_map<const ExprNode*, Expr> output_substitution_;
+
+  //// Accumulator fields, built as we visit expressions.
+
+  /*! \brief (If opt_attrs_ is defined) Parameters representing input expression nodes. */
+  Array<Var> params_;
+  /*!
+   * \brief (If opt_attrs_ is defined) The input expression nodes for each of the above params_,
+   * at the same positions.
+   */
+  Array<Expr> args_;
+  /*!
+   * \brief (If opt_attrs_ is defined) Map from existing input expression nodes to the parameters
+   * in params_ which now represent them. Consulted by VarFor so that each distinct input gets
+   * exactly one parameter.
+   */
+  std::unordered_map<const ExprNode*, Var> expr_to_param_;
+  /*!
+   * \brief Accumulated new expressions which represent the exit nodes of the rewritten sub-graph.
+   * It is possible to have multiple outputs. It is possible one output also contributes to other
+   * outputs (i.e. the output is a 'tap').
+   */
+  std::vector<Expr> outputs_;
+  /*! \brief (If opt_attrs_ is defined) Types of original expressions corresponding to outputs_. */
+  std::vector<Type> output_types_;
+  /*!
+   * \brief Map from existing exit expression nodes to the index in outputs_ which should
+   * represent them in the rewritten overall expression.
+   */
+  std::unordered_map<const ExprNode*, int> expr_to_output_index_;
+};
+
+Expr Rewriter::VisitExpr(const Expr& expr) {
+  // Output nodes of the extracted sub-graph are replaced by their substitutes; everything
+  // else is rewritten recursively as usual.
+  const auto& substitution = extractor_->output_substitution();
+  auto found = substitution.find(expr.get());
+  if (found != substitution.end()) {
+    return found->second;
+  }
+  return ExprMutator::VisitExpr(expr);
+}
+
+}  // namespace
+
+std::pair<OpPatternKind, std::string> SubExprKindAndLabel(const Expr& sub_expr) {
+  using KindAndLabel = std::pair<OpPatternKind, std::string>;
+
+  /*!
+   * \brief Classifies a single sub-expression (without descending into its sub-expressions)
+   * by operator pattern kind, and gives it a short human-readable label.
+   */
+  class Visitor : public ExprFunctor<KindAndLabel(const Expr&)> {
+   private:
+    /*! \brief Returns true if every field of \p tuple_type_node is a tensor type. */
+    static bool AllFieldsAreTensors(const TupleTypeNode* tuple_type_node) {
+      for (const Type& field_type : tuple_type_node->fields) {
+        if (field_type.as<TensorTypeNode>() == nullptr) {
+          return false;
+        }
+      }
+      return true;
+    }
+
+    /*! \brief Kind and label for \p call_node, which calls the primitive operator \p op. */
+    static KindAndLabel KindAndLabelForOpCall(const CallNode* call_node, const Op& op) {
+      static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
+      if (fpattern.count(op) == 0) {
+        VLOG(1) << "no TOpPattern known for " << op->name << ", considering opaque";
+        return {kOpaque, op->name};
+      }
+      if (IsDynamic(call_node->checked_type()) && IsDataDependent(call_node)) {
+        VLOG(1) << "call has dynamic shape which is data-dependent, considering opaque";
+        return {kOpaque, op->name};
+      }
+      const auto kind = static_cast<OpPatternKind>(fpattern[op]);
+      VLOG(2) << "TOpPattern for " << op->name << " is " << KindToString(kind);
+      return {kind, op->name};
+    }
+
+    /*! \brief Kind and label for a call to the function \p function_node. */
+    static KindAndLabel KindAndLabelForFunctionCall(const FunctionNode* function_node) {
+      Optional<Integer> opt_pattern =
+          function_node->GetAttr<Integer>("TOpPattern", Optional<Integer>());
+      if (!opt_pattern.defined()) {
+        VLOG(1) << "calling function without TOpPattern, considering opaque";
+        return {kOpaque, "call_fun"};
+      }
+      const auto kind = static_cast<OpPatternKind>(opt_pattern.value()->value);
+      VLOG(1) << "TOpPattern for function is " << KindToString(kind);
+      return {kind, "call_prim"};
+    }
+
+    KindAndLabel VisitExpr_(const CallNode* call_node) final {
+      if (const auto* op_node = call_node->op.as<OpNode>()) {
+        return KindAndLabelForOpCall(call_node, GetRef<Op>(op_node));
+      }
+      if (const auto* function_node = call_node->op.as<FunctionNode>()) {
+        return KindAndLabelForFunctionCall(function_node);
+      }
+      VLOG(1) << "unsupported call, considering opaque";
+      return {kOpaque, "call_any"};
+    }
+
+    KindAndLabel VisitExpr_(const ConstantNode* constant_node) final {
+      VLOG(2) << "TOpPattern for constant is " << KindToString(kElemWise);
+      const bool is_scalar = support::IsSimpleScalar(constant_node);
+      return {kElemWise, is_scalar ? "scalar" : "const"};
+    }
+
+    KindAndLabel VisitExpr_(const TupleNode* tuple_node) final {
+      const auto* tuple_type_node = tuple_node->checked_type().as<TupleTypeNode>();
+      ICHECK(tuple_type_node != nullptr);
+      if (!AllFieldsAreTensors(tuple_type_node)) {
+        VLOG(1) << "tuple contains non-tensors, considering opaque";
+        return {kOpaque, "tuple"};
+      }
+      VLOG(2) << "TOpPattern for tuple is " << KindToString(kInjective);
+      return {kInjective, "tuple"};
+    }
+
+    KindAndLabel VisitExpr_(const TupleGetItemNode* tuple_get_item_node) final {
+      const auto* tuple_type_node = tuple_get_item_node->tuple->checked_type().as<TupleTypeNode>();
+      ICHECK(tuple_type_node != nullptr);
+      if (!AllFieldsAreTensors(tuple_type_node)) {
+        VLOG(1) << "tuple being projected contains non-tensors, considering opaque";
+        return {kOpaque, "proj"};
+      }
+      VLOG(2) << "TOpPattern for tuple projection is " << KindToString(kInjective);
+      return {kInjective, "proj"};
+    }
+
+    // TODO(mbs): We implement the following mostly so we have a lightweight way of describing
+    // the current sub-expression. If partitioning is ever extended beyond the usual call/tuple/proj
+    // sub-language we should revise the returned operator kinds to match.
+
+    KindAndLabel VisitExpr_(const VarNode* var_node) final {
+      return {kOpaque, "%" + var_node->name_hint()};
+    }
+    KindAndLabel VisitExpr_(const GlobalVarNode* global_var_node) final {
+      return {kOpaque, "@" + global_var_node->name_hint};
+    }
+    KindAndLabel VisitExpr_(const OpNode* op_node) final { return {kOpaque, "`" + op_node->name}; }
+    KindAndLabel VisitExpr_(const FunctionNode*) final { return {kOpaque, "fn"}; }
+    KindAndLabel VisitExpr_(const LetNode*) final { return {kOpaque, "let"}; }
+    KindAndLabel VisitExpr_(const IfNode*) final { return {kOpaque, "if"}; }
+    KindAndLabel VisitExpr_(const RefCreateNode*) final { return {kOpaque, "ref"}; }
+    KindAndLabel VisitExpr_(const RefReadNode*) final { return {kOpaque, "ref_read"}; }
+    KindAndLabel VisitExpr_(const RefWriteNode*) final { return {kOpaque, "ref_write"}; }
+    KindAndLabel VisitExpr_(const ConstructorNode* constructor_node) final {
+      return {kOpaque, "`" + constructor_node->name_hint};
+    }
+    KindAndLabel VisitExpr_(const MatchNode*) final { return {kOpaque, "match"}; }
+  };
+
+  return Visitor().VisitExpr(sub_expr);
+}
+
+std::pair<OpPatternKind, std::string> SubGraphKindAndLabel(const DataflowGraph& dataflow_graph,
+                                                           const IndexSet& inside) {
+  // The overall kind combines every inside node's kind (starting from kElemWise). The overall
+  // label joins the non-empty labels of those nodes with '+', in index order.
+  OpPatternKind combined_kind = kElemWise;
+  std::string joined_label;
+  for (PostDfsIndex index : inside) {
+    const auto kind_and_label = SubExprKindAndLabel(dataflow_graph.index_to_node(index)->ref());
+    const std::string& label = kind_and_label.second;
+    if (!label.empty()) {
+      if (!joined_label.empty()) {
+        joined_label += "+";
+      }
+      joined_label += label;
+    }
+    combined_kind = CombineKinds(combined_kind, kind_and_label.first);
+  }
+  return {combined_kind, joined_label};
+}
+
+IndexSet MatcherToIndexSet(const DFPatternMatcher& matcher) {
+  IndexSet result(matcher.size());
+  for (const auto& kv : matcher.memo()) {
+    // Don't consider the expressions matched by a wildcard to be part of the sub-graph.
+    if (kv.first.as<WildcardPatternNode>()) {
+      continue;
+    }
+    for (const auto& matched_sub_expr : kv.second) {
+      // Trivial sub-expressions can just be included in the extracted function body when we
+      // construct it, so they need not be considered part of the sub-graph.
+      if (!CanInline(matched_sub_expr)) {
+        result.Add(matcher.expr_to_node(matched_sub_expr)->index_);
+      }
+    }
+  }
+  return result;
+}
+
+std::string SubGraphConfig::ToString() const {
+  // Booleans are streamed without std::boolalpha, so allow_taps renders as 0 or 1.
+  std::ostringstream out;
+  out << "{max_exits=" << max_exits << ",allow_taps=" << allow_taps
+      << ",max_max_depth=" << max_max_depth << "}";
+  return out.str();
+}
+
+// Registers SubSubGraphNode (type key "relay.collage.SubSubGraph") as a TVM node type.
+TVM_REGISTER_NODE_TYPE(SubSubGraphNode);
+
+void SubSubGraphNode::VisitAttrs(AttrVisitor* /*v*/) {
+  // TODO(mbs): No fields are visited yet.
+}
+
+SubGraph SubSubGraphNode::sub_graph() const {
+  // sub_graph_obj_ is held as a plain ObjectRef since SubGraph is only forward-declared where
+  // SubSubGraphNode is declared; recover the precise type here.
+  return Downcast<SubGraph>(sub_graph_obj_);
+}
+
+bool SubSubGraphNode::operator==(const SubSubGraphNode& that) const {
+  // Only the nested sub-graphs are compared; attrs_ do not participate.
+  const SubGraphNode& mine = *sub_graph().get();
+  const SubGraphNode& theirs = *that.sub_graph().get();
+  return mine == theirs;
+}
+
+bool SubSubGraphNode::operator<(const SubSubGraphNode& that) const {
+  // Ordered by the nested sub-graphs only; attrs_ do not participate.
+  const SubGraphNode& mine = *sub_graph().get();
+  const SubGraphNode& theirs = *that.sub_graph().get();
+  return mine < theirs;
+}
+
+/*!
+ * \brief Returns a hash of this sub-sub-graph.
+ *
+ * Must agree with operator== (and operator<), which compare only the nested sub-graphs:
+ * sub-sub-graphs which compare equal must hash equally even when their attrs_ differ, otherwise
+ * hashed containers of sub-sub-graphs (and of SubGraphs, whose hash folds this one in) behave
+ * inconsistently. Hence attrs_ deliberately do not contribute.
+ */
+size_t SubSubGraphNode::hash() const { return sub_graph()->hash(); }
+
+std::string SubSubGraphNode::ToString() const {
+  std::ostringstream out;
+  out << "{sub_graph=" << sub_graph()->ToString() << ",attrs=" << PrettyPrint(attrs_) << "}";
+  return out.str();
+}
+
+Function SubSubGraphNode::Extract(const DataflowGraph& dataflow_graph) const {
+  // The extractor is given attrs_, so the extracted expression is expected to be a function.
+  Extractor extractor(&dataflow_graph, sub_graph().get(), attrs_);
+  extractor.Extract();
+  Expr extracted = extractor.extracted();
+  return Downcast<Function>(extracted);
+}
+
+Expr SubSubGraphNode::Rewrite(const DataflowGraph& dataflow_graph, const Expr& expr) const {
+  Extractor extractor(&dataflow_graph, sub_graph().get(), attrs_);
+  extractor.Extract();
+  // Replace this sub-sub-graph's outputs within expr using the extractor's substitutions.
+  return Rewriter(&extractor).VisitExpr(expr);
+}
+
+SubSubGraph::SubSubGraph(SubGraph sub_graph, FunctionAttrsMap attrs) {
+  auto node = runtime::make_object<SubSubGraphNode>();
+  node->attrs_ = std::move(attrs);
+  node->sub_graph_obj_ = std::move(sub_graph);
+  data_ = std::move(node);
+}
+
+SubSubGraph SubSubGraph::Subst(const DataflowGraph& new_dataflow_graph,
+                               const std::unordered_map<PostDfsIndex, PostDfsIndex>& subst) const {
+  // Only the nested sub-graph refers to dataflow graph indexes; the attributes carry over as-is.
+  auto new_sub_graph = get()->sub_graph().Subst(new_dataflow_graph, subst);
+  return SubSubGraph(std::move(new_sub_graph), get()->attrs_);
+}
+
+bool SubSubGraph::TriviallyUnionable(const SubSubGraph& that) const {
+  // Unionable exactly when both sides carry structurally equal attributes under the same keys,
+  // and neither is a "Composite" function.
+  const FunctionAttrsMap& mine = get()->attrs_;
+  const FunctionAttrsMap& theirs = that->attrs_;
+  if (mine.size() != theirs.size()) {
+    return false;
+  }
+  for (const auto& kv : mine) {
+    // Even if all the attributes agree we don't consider "Composite" functions to
+    // ever be unionable.
+    // TODO(mbs): Find a cleaner way to do this.
+    if (kv.first == "Composite") {
+      return false;
+    }
+    auto match = theirs.find(kv.first);
+    if (match == theirs.end() || !StructuralEqual()(kv.second, (*match).second)) {
+      return false;
+    }
+  }
+  return true;
+}
+
+SubSubGraph SubSubGraph::DisjointUnion(const DataflowGraph& dataflow_graph,
+                                       const SubSubGraph& that) const {
+  ICHECK(TriviallyUnionable(that));
+  // The attributes agree (checked above), so this side's attrs_ serve for the union.
+  auto united = get()->sub_graph().DisjointUnion(dataflow_graph, that->sub_graph());
+  return SubSubGraph(std::move(united), get()->attrs_);
+}
+
+/*static*/
+Expr SubSubGraph::ParallelRewrite(const DataflowGraph& dataflow_graph, const Expr& expr,
+                                  std::vector<SubSubGraph> sub_sub_graphs) {
+  // IMPORTANT: See the corresponding comment in SubGraph::ParallelRewrite.
+  // Rewrites are applied in decreasing order of each sub-sub-graph's last inside index.
+  auto by_last_inside_index_desc = [](const SubSubGraph& lhs, const SubSubGraph& rhs) {
+    return rhs->sub_graph()->last_inside_index_ < lhs->sub_graph()->last_inside_index_;
+  };
+  std::sort(sub_sub_graphs.begin(), sub_sub_graphs.end(), by_last_inside_index_desc);
+
+  Expr rewritten = expr;
+  for (const SubSubGraph& sub_sub_graph : sub_sub_graphs) {
+    rewritten = sub_sub_graph->Rewrite(dataflow_graph, rewritten);
+  }
+  return rewritten;
+}
+
+// Registers SubGraphNode as a TVM node type.
+TVM_REGISTER_NODE_TYPE(SubGraphNode);
+
+void SubGraphNode::VisitAttrs(AttrVisitor* /*v*/) {
+  // TODO(mbs): No fields are visited yet.
+}
+
+IndexSet SubGraphNode::Downstream(const DataflowGraph& dataflow_graph) const {
+  // Everything reachable downstream of at least one exit node.
+  IndexSet reachable(dataflow_graph.size());
+  for (PostDfsIndex index : exit_) {
+    reachable = reachable | dataflow_graph.downstream_of(index);
+  }
+  return reachable;
+}
+
+/*!
+ * \brief Returns true if this sub-graph is acceptable under \p config and well-formed with
+ * respect to \p dataflow_graph. Checks are made in the order below; the first failing check is
+ * logged at VLOG(1) and false is returned.
+ */
+bool SubGraphNode::IsValid(const DataflowGraph& dataflow_graph,
+                           const SubGraphConfig& config) const {
+  // Check we don't have too many exit nodes.
+  if (config.max_exits > 0 && exit_.PopCount() > config.max_exits) {
+    VLOG(1) << "Subgraph " << ToString() << " is invalid: " << exit_.PopCount()
+            << " exits exceeds maximum " << config.max_exits;
+    return false;
+  }
+
+  // Check the maximum path depth is in limit.
+  if (config.max_max_depth > 0 && max_depth_ > config.max_max_depth) {
+    VLOG(1) << "Subgraph " << ToString() << " is invalid: maximum depth " << max_depth_
+            << " exceeds limit " << config.max_max_depth;
+    return false;
+  }
+
+  // All inside nodes must be in the same basic block.
+  const DataflowGraph::Node* basic_block = nullptr;
+  for (PostDfsIndex index : inside_) {
+    auto node = dataflow_graph.index_to_node(index);
+    if (basic_block == nullptr) {
+      // The first inside node fixes the basic block all others must share.
+      basic_block = node->basic_block_;
+    }
+    if (node->basic_block_ != basic_block) {
+      VLOG(1) << "Subgraph " << ToString() << " is invalid: nodes are from different basic blocks";
+      return false;
+    }
+  }
+
+  // The sub-sub-graphs must be subsets and non-overlapping.
+  IndexSet union_inside(dataflow_graph.size());
+  for (const auto& sub_sub_graph : sub_sub_graphs_) {
+    if (!sub_sub_graph->sub_graph()->inside_.AreDisjoint(union_inside)) {
+      VLOG(1) << "Subgraph " << ToString() << " is invalid: sub-sub-graphs overlap";
+      return false;
+    }
+    if (!sub_sub_graph->sub_graph()->inside_.IsSubset(inside_)) {
+      VLOG(1) << "Subgraph " << ToString()
+              << " is invalid: sub-sub-graph is not subset of overall sub-graph";
+      return false;
+    }
+    // Accumulate the nodes seen so far so that later sub-sub-graphs are checked for overlap
+    // against all earlier ones. (Without this union_inside stays empty and the overlap check
+    // above can never fail.)
+    union_inside = union_inside | sub_sub_graph->sub_graph()->inside_;
+  }
+
+  if (!config.allow_taps) {
+    // Exit nodes cannot also contribute to inside nodes (such an exit node is a 'tap').
+    for (PostDfsIndex index : exit_) {
+      auto node = dataflow_graph.index_to_node(index);
+      if (AnyOutputInside(node)) {
+        VLOG(1) << "Subgraph " << ToString()
+                << " is invalid: inner node is 'tapped' and also contributes to output, but taps "
+                   "are disabled";
+        return false;
+      }
+    }
+  }
+
+  // Check no output would end up feeding into any entry node. Otherwise the sub-graph would
+  // (indirectly) consume its own result, i.e. extracting it would introduce a cycle.
+  for (PostDfsIndex output_index : output_) {
+    if (dataflow_graph.downstream_of(output_index).Intersects(entry_)) {
+      VLOG(1) << "Subgraph " << ToString() << " is invalid: output node " << output_index
+              << " feeds back into this sub-graph";
+      return false;
+    }
+  }
+
+  // Looks legit!
+  return true;
+}
+
+Function SubGraphNode::ExtractAsFunction(const DataflowGraph& dataflow_graph) const {
+  // Wrap this sub-graph as a sub-sub-graph with an empty attribute map, and extract that.
+  return SubSubGraph(GetRef<SubGraph>(this), FunctionAttrsMap())->Extract(dataflow_graph);
+}
+
+Expr SubGraphNode::Rewrite(const DataflowGraph& dataflow_graph, const Expr& expr) const {
+  if (!sub_sub_graphs_.empty()) {
+    // Extract with undefined attributes, so no enclosing function is built for this sub-graph.
+    Extractor extractor(&dataflow_graph, this, NullValue<FunctionAttrsMap>());
+    extractor.Extract();
+    return Rewriter(&extractor).VisitExpr(expr);
+  }
+  // No sub-sub-graphs, so nothing to rewrite.
+  return expr;
+}
+
+std::string SubGraphNode::ToString() const {
+  std::ostringstream out;
+  out << "{inside=" << inside_.ToString() << ",entry=" << entry_.ToString()
+      << ",exit=" << exit_.ToString() << ",input=" << input_.ToString()
+      << ",output=" << output_.ToString() << ",max_depth=" << max_depth_
+      << ",kind=" << KindToString(kind_);
+  if (!label_.empty()) {
+    out << ",label=" << label_;
+  }
+  for (const auto& nested : sub_sub_graphs_) {
+    out << ",sub_sub_graph=" << nested->ToString();
+  }
+  out << "}";
+  return out.str();
+}
+
+bool SubGraphNode::operator==(const SubGraphNode& that) const {
+  ICHECK_EQ(inside_.end_index(), that.inside_.end_index());
+  // Equal when the same nodes are inside and the sub-sub-graphs agree pairwise, in order.
+  if (inside_ != that.inside_ || sub_sub_graphs_.size() != that.sub_sub_graphs_.size()) {
+    return false;
+  }
+  return std::equal(sub_sub_graphs_.begin(), sub_sub_graphs_.end(), that.sub_sub_graphs_.begin(),
+                    [](const auto& lhs, const auto& rhs) { return *lhs.get() == *rhs.get(); });
+}
+
+bool SubGraphNode::operator<(const SubGraphNode& that) const {
+  // Order primarily by first inside index, breaking ties by the inside sets themselves.
+  // NOTE: Unlike operator==, sub_sub_graphs_ do not participate in the ordering.
+  if (first_inside_index_ != that.first_inside_index_) {
+    return first_inside_index_ < that.first_inside_index_;
+  }
+  return inside_ < that.inside_;
+}
+
+size_t SubGraphNode::hash() const {
+  size_t combined = inside_.hash();
+  for (const auto& nested : sub_sub_graphs_) {
+    // boost::hash_combine-style mixing of each sub-sub-graph's hash.
+    const size_t nested_hash = nested->hash();
+    combined ^= nested_hash + 0x9e3779b9 + (combined << 6) + (combined >> 2);
+  }
+  return combined;
+}
+
+void SubGraphNode::Init(const DataflowGraph& dataflow_graph) {
+  // Classify every node of the dataflow graph relative to inside_, then compute max_depth_.
+  for (PostDfsIndex index = 0; index < inside_.end_index(); ++index) {
+    auto node = dataflow_graph.index_to_node(index);
+    if (!inside_[index]) {
+      // An outside node consuming some inside node is an output.
+      if (AnyInputInside(node)) {
+        output_.Add(index);
+      }
+      // An outside node feeding some inside node is an input, unless it can simply be inlined.
+      if (AnyOutputInside(node) && !CanInline(node->ref())) {
+        input_.Add(index);
+      }
+      continue;
+    }
+    // An inside node fed by some outside node is an entry.
+    if (AnyInputOutside(node)) {
+      entry_.Add(index);
+    }
+    // An inside node feeding some outside node, or which is external, is an exit.
+    if (AnyOutputOutside(node) || node->is_external_) {
+      exit_.Add(index);
+    }
+  }
+  max_depth_ = MaxDepth(dataflow_graph);
+}
+
+/*!
+ * \brief Returns the length of the longest path through the inside nodes, measured from any
+ * entry node (depth 0) and counting one further step when leaving the sub-graph via an exit node.
+ *
+ * Uses a worklist (explicit stack); a node is pushed again whenever a deeper path to it is
+ * discovered, so its recorded depth only ever increases.
+ */
+size_t SubGraphNode::MaxDepth(const DataflowGraph& dataflow_graph) const {
+  // Deepest known depth for each inside node reached so far.
+  std::unordered_map<const DataflowGraph::Node*, size_t> max_depths;
+  std::vector<const DataflowGraph::Node*> stack;
+  size_t max_depth = 0;
+  // All the entry nodes have max depth 0.
+  for (PostDfsIndex index : entry_) {
+    auto node = dataflow_graph.index_to_node(index);
+    max_depths.emplace(node, 0);
+    stack.push_back(node);
+  }
+  while (!stack.empty()) {
+    const DataflowGraph::Node* node = stack.back();
+    stack.pop_back();
+    size_t next_depth = max_depths[node] + 1;
+    if (exit_[node->index_]) {
+      // If this node is external then it will have no outputs but we still wish to consider
+      // the path to the implied output as requiring one more step.
+      // Otherwise we're accounting for reaching one of the outside output nodes.
+      max_depth = std::max(max_depth, next_depth);
+    }
+    for (const DataflowGraph::Node* output_node : node->outputs_) {
+      if (!inside_[output_node->index_]) {
+        // Only paths which stay inside the sub-graph are followed.
+        continue;
+      }
+      if (max_depths.count(output_node) == 0) {
+        max_depths.emplace(output_node, next_depth);
+        stack.push_back(output_node);
+      } else if (next_depth > max_depths[output_node]) {
+        // We found a deeper path to an already expanded node. We'll expand again.
+        max_depths[output_node] = next_depth;
+        stack.push_back(output_node);
+      }
+    }
+  }
+  return max_depth;
+}
+
+/*! \brief Returns true if any (input/output) of node is (outside/inside) the sub-graph. */

Review Comment:
   ```suggestion
   /*! \brief Returns true if any (input/output) of node is (outside/inside) the sub-graph.  */
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] mbaret commented on a diff in pull request #11981: [Collage] SubGraphs

Posted by GitBox <gi...@apache.org>.
mbaret commented on code in PR #11981:
URL: https://github.com/apache/tvm/pull/11981#discussion_r917919373


##########
src/relay/collage/README.md:
##########
@@ -0,0 +1,26 @@
+<!--- Licensed to the Apache Software Foundation (ASF) under one -->
+<!--- or more contributor license agreements.  See the NOTICE file -->
+<!--- distributed with this work for additional information -->
+<!--- regarding copyright ownership.  The ASF licenses this file -->
+<!--- to you under the Apache License, Version 2.0 (the -->
+<!--- "License"); you may not use this file except in compliance -->
+<!--- with the License.  You may obtain a copy of the License at -->
+
+<!---   http://www.apache.org/licenses/LICENSE-2.0 -->
+
+<!--- Unless required by applicable law or agreed to in writing, -->
+<!--- software distributed under the License is distributed on an -->
+<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
+<!--- KIND, either express or implied.  See the License for the -->
+<!--- specific language governing permissions and limitations -->
+<!--- under the License. -->
+
+The `CollagePartition` pass for finding optimal partitionings of Relay models.
+
+See the [RFC](https://github.com/apache/tvm-rfcs/blob/main/rfcs/0062-collage.md).
+
+Based on:
+> *Collage: Automated Integration of Deep Learning Backends*  
+> Byungsoo Jeon, Sunghyun Park, Peiyuan Liao, Sheng Xu, Tianqi Chen, Zhihao Jia
+
+CAUTION: This is a prototype; do not use it in production.

Review Comment:
   Might be wise to get rid of this now :)



##########
src/relay/collage/sub_graph.h:
##########
@@ -0,0 +1,451 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/collage/sub_graph.h
+ * \brief Represents a sub-graph of an overall Relay expression.
+ */
+
+#ifndef TVM_RELAY_COLLAGE_SUB_GRAPH_H_
+#define TVM_RELAY_COLLAGE_SUB_GRAPH_H_
+
+#include <tvm/ir/transform.h>
+#include <tvm/relay/op_attr_types.h>
+
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../ir/dataflow_matcher_impl.h"
+#include "../ir/indexed_graph.h"
+#include "./dataflow_graph.h"
+#include "./index_set.h"
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+/*! \brief Returns operator pattern kind as single-letter string. */
+std::string KindToString(OpPatternKind kind);
+
+/*!
+ * \brief Returns a kind and label for the single \p sub_expr, ignoring its own sub-expressions.
+ */
+std::pair<OpPatternKind, std::string> SubExprKindAndLabel(const Expr& sub_expr);
+
+/*!
+ * \brief Returns a kind and label for all the nodes in \p inside.
+ */
+std::pair<OpPatternKind, std::string> SubGraphKindAndLabel(const DataflowGraph& dataflow_graph,
+                                                           const IndexSet& inside);
+
+/*!
+ * \brief Returns the index set representing all the sub-expressions matched by \p matcher.
+ */
+IndexSet MatcherToIndexSet(const DFPatternMatcher& matcher);
+
+/*!
+ * \brief Configuration controlling which sub-graphs are considered valid.
+ */
+struct SubGraphConfig {
+  /*! \brief Maximum number of exit nodes in the sub-graph, or zero if no limit. */
+  size_t max_exits = 0;
+  /*!
+   * \brief Whether a node inside the sub-graph may flow to nodes both inside and outside
+   * the sub-graph (which we call a 'tap'). Note that it is still possible to have multiple outputs
+   * even with this flag false.
+   */
+  bool allow_taps = false;
+  /*!
+   * \brief Upper bound on a sub-graph's maximum depth (its longest path from an entry node
+   * through to an exit node), or zero if no limit.
+   */
+  size_t max_max_depth = 0;
+
+  /*! \brief Returns a compact description, e.g. "{max_exits=1,allow_taps=0,max_max_depth=3}". */
+  std::string ToString() const;
+};
+
+class SubGraph;
+using FunctionAttrsMap = Map<String, ObjectRef>;
+
+/*!
+ * \brief A sub-sub graph is a sub-graph which is to be nested inside a function as part of some
+ * enclosing sub-graph.
+ *
+ * Extraction yields a function with input nodes replaced by parameters and exit nodes in the
+ * function result. Rewriting replaces the sub-graph with a call to that function, and all
+ * outputs with (projections from) the call result.
+ *
+ * (Note that it's tempting to move attrs_ into \p SubGraphNode and thus avoid this class.
+ * However we found the implementation was easier to understand in this form since it makes
+ * the result of \p Extract unambiguous.)
+ */
+class SubSubGraphNode : public Object {

Review Comment:
   I think 'NestedSubGraph' is a clearer name for this class.



##########
src/relay/collage/sub_graph.h:
##########
@@ -0,0 +1,451 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/collage/sub_graph.h
+ * \brief Represents a sub-graph of an overall Relay expression.
+ */
+
+#ifndef TVM_RELAY_COLLAGE_SUB_GRAPH_H_
+#define TVM_RELAY_COLLAGE_SUB_GRAPH_H_
+
+#include <tvm/ir/transform.h>
+#include <tvm/relay/op_attr_types.h>
+
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../ir/dataflow_matcher_impl.h"
+#include "../ir/indexed_graph.h"
+#include "./dataflow_graph.h"
+#include "./index_set.h"
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+/*! \brief Returns operator pattern kind as single-letter string. */
+std::string KindToString(OpPatternKind kind);
+
+/*!
+ * \brief Returns a kind and label for the single \p sub_expr, ignoring its own sub-expressions.
+ */
+std::pair<OpPatternKind, std::string> SubExprKindAndLabel(const Expr& sub_expr);
+
+/*!
+ * \brief Returns a kind and label for all the nodes in \p inside.
+ */
+std::pair<OpPatternKind, std::string> SubGraphKindAndLabel(const DataflowGraph& dataflow_graph,
+                                                           const IndexSet& inside);
+
+/*!
+ * \brief Returns the index set representing all the sub-expressions matched by \p matcher.
+ */
+IndexSet MatcherToIndexSet(const DFPatternMatcher& matcher);
+
+/*!
+ * \brief Configuration controlling which sub-graphs are considered valid.
+ */
+struct SubGraphConfig {
+  /*! \brief Maximum number of exit nodes in the sub-graph, or zero if no limit. */
+  size_t max_exits = 0;
+  /*!
+   * \brief Whether a node inside the sub-graph may flow to nodes both inside and outside
+   * the sub-graph (which we call a 'tap'). Note that it is still possible to have multiple outputs
+   * even with this flag false.
+   */
+  bool allow_taps = false;
+  /*!
+   * \brief Upper bound on a sub-graph's maximum depth (its longest path from an entry node
+   * through to an exit node), or zero if no limit.
+   */
+  size_t max_max_depth = 0;
+
+  /*! \brief Returns a compact description, e.g. "{max_exits=1,allow_taps=0,max_max_depth=3}". */
+  std::string ToString() const;
+};
+
+class SubGraph;
+using FunctionAttrsMap = Map<String, ObjectRef>;
+
+/*!
+ * \brief A sub-sub graph is a sub-graph which is to be nested inside a function as part of some
+ * enclosing sub-graph.
+ *
+ * Extraction yields a function with input nodes replaced by parameters and exit nodes in the
+ * function result. Rewriting replaces the sub-graph with a call to that function, and all
+ * outputs with (projections from) the call result.
+ *
+ * (Note that it's tempting to move attrs_ into \p SubGraphNode and thus avoid this class.
+ * However we found the implementation was easier to understand in this form since it makes
+ * the result of \p Extract unambiguous.)
+ */
+class SubSubGraphNode : public Object {
+ public:
+  /*!
+   * \brief The nested sub-graph. Held as a plain ObjectRef because SubGraph is only
+   * forward-declared at this point; use sub_graph() to access it.
+   */
+  ObjectRef /* actually SubGraph */ sub_graph_obj_;
+  /*! \brief Attributes (possibly empty) to attach to the extracted function. */
+  FunctionAttrsMap attrs_;
+
+  void VisitAttrs(AttrVisitor* v);
+
+  /*! \brief Returns sub_graph_obj_ as a SubGraph. */
+  SubGraph sub_graph() const;
+
+  /*! \brief Compares only the nested sub-graphs; attrs_ do not participate. */
+  bool operator==(const SubSubGraphNode& that) const;
+  bool operator!=(const SubSubGraphNode& that) const { return !(*this == that); }
+  /*! \brief Orders by the nested sub-graphs only; attrs_ do not participate. */
+  bool operator<(const SubSubGraphNode& that) const;
+  size_t hash() const;
+
+  /*! \brief Returns a debugging description of the nested sub-graph and attrs_. */
+  std::string ToString() const;
+
+  /*!
+   * \brief Returns the function representing this sub-sub-graph within the overall expression
+   * represented by \p dataflow_graph:
+   *  - All sub-graph inputs become parameters.
+   *  - All sub-graph outputs become function results (either directly or as a field in a tuple).
+   *  - The function has attrs_ for attributes (which may be empty).
+   *  - The function body accounts for any rewrites implied by the nested sub-graph.
+   */
+  Function Extract(const DataflowGraph& dataflow_graph) const;
+
+  /*!
+   * \brief Returns \p expr rewritten to encode the partitioning implied by this sub-sub-graph.
+   *
+   * It is valid for \p expr to not be the same as \p dataflow_graph.expr(), however all nodes
+   * inside this sub-sub-graph must correspond to nodes shared between \p dataflow_graph.expr() and
+   * \p expr. See \p SubGraph::ParallelRewrite below.
+   */
+  Expr Rewrite(const DataflowGraph& dataflow_graph, const Expr& expr) const;
+
+  static constexpr const char* _type_key = "relay.collage.SubSubGraph";
+  TVM_DECLARE_FINAL_OBJECT_INFO(SubSubGraphNode, Object);
+};
+
+class SubSubGraph : public ObjectRef {

Review Comment:
   It feels like there's an opportunity here to have SubSubGraph and SubGraph derive from some BaseSubGraph. Any thoughts on whether this would be practical?



##########
src/relay/collage/sub_graph.h:
##########
@@ -0,0 +1,451 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/collage/sub_graph.h
+ * \brief Represents a sub-graph of an overall Relay expression.
+ */
+
+#ifndef TVM_RELAY_COLLAGE_SUB_GRAPH_H_
+#define TVM_RELAY_COLLAGE_SUB_GRAPH_H_
+
+#include <tvm/ir/transform.h>
+#include <tvm/relay/op_attr_types.h>
+
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../ir/dataflow_matcher_impl.h"
+#include "../ir/indexed_graph.h"
+#include "./dataflow_graph.h"
+#include "./index_set.h"
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+/*! \brief Returns operator pattern kind as single-letter string. */
+std::string KindToString(OpPatternKind kind);
+
+/*!
+ * \brief Returns a kind and label for the single \p sub_expr, ignoring its sub-sub expressions.
+ */
+std::pair<OpPatternKind, std::string> SubExprKindAndLabel(const Expr& sub_expr);
+
+/*!
+ * \brief Returns a kind and label for all the nodes in \p inside.
+ */
+std::pair<OpPatternKind, std::string> SubGraphKindAndLabel(const DataflowGraph& dataflow_graph,
+                                                           const IndexSet& inside);
+
+/*!
+ * \brief Returns the index set representing all the sub-expressions matched by \p matcher.
+ */
+IndexSet MatcherToIndexSet(const DFPatternMatcher& matcher);
+
+/*!
+ * \brief Configuration controlling which sub-graphs are considered valid.
+ */
+struct SubGraphConfig {
+  /*! \brief Maximum number of exit nodes in the sub-graph, or zero if no limit. */
+  size_t max_exits = 0;
+  /*!
+   * \brief Whether a node inside the sub-graph may flow to nodes both inside and outside
+   * the sub-graph (which we call a 'tap'). Note that it is still possible to have multiple outputs
+   * even with this flag false.
+   */
+  bool allow_taps = false;
+  /*!
+   * \brief Maximum allowed maximum depth, or zero if no-limit.
+   */
+  size_t max_max_depth = 0;
+
+  std::string ToString() const;
+};
+
+class SubGraph;
+using FunctionAttrsMap = Map<String, ObjectRef>;
+
+/*!
+ * \brief A sub-sub graph is a sub-graph which is to be nested inside a function as part of some
+ * enclosing sub-graph.
+ *
+ * Extraction yields a function with input nodes replaced by parameters and exit nodes in the
+ * function result. Rewriting replaces the sub-graph with a call to that function, and all
+ * outputs with (projections from) the call result.
+ *
+ * (Note that it's tempting to move attrs_ into \p SubGraphNode and thus avoid this class.
+ * However we found the implementation was easier to understand in this form since it makes
+ * the result of \p Extract unambiguous.)
+ */
+class SubSubGraphNode : public Object {
+ public:
+  /*! \brief The nested sub-graph. */
+  ObjectRef /* actually SubGraph */ sub_graph_obj_;
+  /*! \brief Attributes (possibly empty) to attach to the extracted function. */
+  FunctionAttrsMap attrs_;
+
+  void VisitAttrs(AttrVisitor* v);
+
+  SubGraph sub_graph() const;
+
+  bool operator==(const SubSubGraphNode& that) const;
+  bool operator!=(const SubSubGraphNode& that) const { return !(*this == that); }
+  bool operator<(const SubSubGraphNode& that) const;
+  size_t hash() const;
+
+  std::string ToString() const;
+
+  /*!
+   * \brief Returns the function representing this sub-sub-graph within the overall expression
+   * represented by \p dataflow_graph:
+   *  - All sub-graph inputs become parameters.
+   *  - All sub-graph outputs become function results (either directly or as a field in a tuple).
+   *  - The function has attrs_ for attributes (which may be empty).
+   *  - The function body accounts for any rewrites implied by the nested sub-graph.
+   */
+  Function Extract(const DataflowGraph& dataflow_graph) const;
+
+  /*!
+   * \brief Returns \p expr rewritten to encode the partitioning implied by this sub-sub-graph.
+   *
+   * It is valid for \p expr to not be the same as \p dataflow_graph.expr(), however all nodes
+   * inside this sub-sub-graph must correspond to nodes shared between \p dataflow_graph.expr() and
+   * \p expr. See \p SubGraph::ParallelRewrite below.
+   */
+  Expr Rewrite(const DataflowGraph& dataflow_graph, const Expr& expr) const;
+
+  static constexpr const char* _type_key = "relay.collage.SubSubGraph";
+  TVM_DECLARE_FINAL_OBJECT_INFO(SubSubGraphNode, Object);
+};
+
+class SubSubGraph : public ObjectRef {
+ public:
+  SubSubGraph(SubGraph sub_graph, FunctionAttrsMap attrs);
+
+  /*!
+   * \brief Returns copy of this sub-sub-graph with all indexes substituted according to \p subst,
+   * whose range is w.r.t. \p new_dataflow_graph.
+   */
+  SubSubGraph Subst(const DataflowGraph& new_dataflow_graph,
+                    const std::unordered_map<PostDfsIndex, PostDfsIndex>& subst) const;
+
+  /*!
+   * \brief Returns true if this can be safely unioned.
+   */
+  bool TriviallyUnionable(const SubSubGraph& that) const;
+
+  /*!
+   * \brief Returns the disjoint union of this and \p that sub-sub-graphs, which must agree on
+   * their attributes.
+   */
+  SubSubGraph DisjointUnion(const DataflowGraph& dataflow_graph, const SubSubGraph& that) const;
+
+  /*!
+   * \brief Returns \p expr rewritten according to all the given sub-sub-graphs. The sub-sub-graphs
+   * can be given in any order, but must be disjoint.
+   *
+   * It is valid for \p expr to not be the same as \p dataflow_graph.expr(), however all nodes
+   * inside the sub-sub-graphs must correspond to nodes shared between \p dataflow_graph.expr() and
+   * \p expr. See \p SubGraph::ParallelRewrite below.
+   */
+  static Expr ParallelRewrite(const DataflowGraph& dataflow_graph, const Expr& expr,
+                              std::vector<SubSubGraph> sub_sub_graphs);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(SubSubGraph, ObjectRef, SubSubGraphNode);
+};
+
+using SubSubGraphs = Array<SubSubGraph>;
+
+/*!
+ * \brief A compact representation of a sub-graph within an (implied) overall Relay expression.
+ *
+ * Sub-graphs can be used to represent partitions/kernels/composite functions without having to
+ * pay the cost of constructing or rewriting any expressions. We also allow 'extracting' a
+ * function to use for measuring a partition/kernel's latency independently from 'rewriting'
+ * the overall Relay expression since only a tiny subset of candidate partitions will end up being
+ * needed after Collage has completed its search.
+ *
+ * We expect O(thousands) of sub-graphs to be in flight while processing a given model, so are
+ * mindful of space overhead.
+ *
+ * A sub-graph classifies every dataflow node of the overall expression as either 'inside' or
+ * 'outside' the sub-graph. Obviously not all such divisions make sense, for example it is not
+ * valid for an inside node to feed into another inside node via outside nodes. We provide the
+ * \p IsValid method to check for validity, and \p SubGraphConfig to control which validity rules
+ * apply (such as maximum depth).
+ *
+ * We generally work with the \p DataflowGraph representation of the overall Relay expression
+ * rather than the expression itself. We use the post-dfs visit index to uniquely refer to
+ * expression nodes.
+ *
+ * As well as 'inside' and 'outside' we have four other flavors of dataflow nodes, all uniquely
+ * determined from the 'inside' nodes:
+ *  - 'entry' nodes are those inside with at least one dataflow input outside.
+ *  - 'exit' nodes are  those inside with at least one dataflow output outside, or which
+ *    are considered 'external' in the underlying dataflow graph (eg because they represent
+ *    the result of the overall function).
+ *  - 'input' nodes are those outside with at least one dataflow output inside.
+ *  - 'output' nodes are those outside with at least one dataflow input inside.
+ * Index sets for these are cached with the sub-graph for performance.
+ *
+ * It is valid to have multiple entry nodes (we can bind a parameter for each). It may be valid to
+ * have multiple exit nodes (we can build a tuple of all such). It may be valid to have exit nodes
+ * which also contribute to other inside nodes (ie represent a 'tap' on an intermediate result).
+ *
+ * Sub-graphs are closed under:
+ *  - Disjoint union.
+ *  - Wrapping by a function with given attributes (see \p SubSubGraph above). This can be used
+ *    to encode "Composite" functions, or to represent a candidate kernel within a "Primitive"
+ *    function. (By combining 'wrapping' with 'union' we can encode, eg, 'this sub-graph should
+ *    be placed inside a primitive function which itself may have calls to composite functions).
+ *  - Substitution, which allows a sub-graph w.r.t. one dataflow graph to be transformed to
+ *    match some other (typically smaller) dataflow graph.
+ *
+ * See the subclasses of \p PartitionRule for how sub-graphs are built and combined during Collage
+ * search.
+ *
+ * To support some of the \p OpPatternKind-based fusion rule processing we give sub-graphs
+ * a kind, which is generally the maximum of the kinds of all the operator calls appearing
+ * inside it. We also give sub-graphs a (not necessarily unique) label to help debugging
+ * and guide the selection of global symbol names.
+ */
+class SubGraphNode : public Object {
+ public:
+  /*!
+   * \brief Which sub-expressions are inside the sub-graph (using their post-dfs indexes w.r.t.
+   * the implied DataflowGraph).
+   */
+  IndexSet inside_;
+
+  /*!
+   * \brief Index of first and last inside nodes.
+   *
+   * Cached for performance, uniquely determined by inside_.
+   */
+  PostDfsIndex first_inside_index_ = 0;
+  PostDfsIndex last_inside_index_ = 0;
+
+  /*!
+   * \brief Which sub-expressions are entry/exit/input/output for this sub-graph.
+   *
+   * Cached for performance, uniquely determined by inside_.
+   */
+  IndexSet entry_;
+  IndexSet exit_;
+  IndexSet input_;
+  IndexSet output_;
+
+  /*!
+   * \brief Maximum depth of any dataflow path from an entry to an output sub-expression.
+   *
+   * Cached for performance, uniquely determined by inside_.
+   */
+  size_t max_depth_ = 0;
+
+  /*!
+   * \brief The \p OpPatternKind summarizing the input/output behavior of the sub-graph.
+   *
+   * A sub-graph consisting of a single Relay expression node is given kind:
+   *  - For Call to a Relay operator, the "TOpPattern" attribute of that operator (provided the
+   *    call does not involve data-dependent dynamic shapes).
+   *  - For Call to Relay Function, the "TOpPattern" attribute of the function (provided it has
+   *    that attribute)
+   *  - For Constants, \p kElemWise.
+   *  - For Tuple and tuple projections, \p kInjective (provided all tuple fields are of tensor
+   *    type)
+   *  - All other nodes \p kOpaque.
+   * Sub-graphs with more than one node have the maximum of the kind of each node.
+   *
+   * Cached for performance, uniquely determined by inside_.
+   */
+  OpPatternKind kind_ = kOpaque;
+
+  /*!
+   * \brief A label for the sub-graph. Not guaranteed to be unique, but is a human-readable summary
+   * of the sub-graph which can help with debugging and guide the selection of global symbol names.
+   */
+  String label_;
+
+  /*!
+   * \brief Sub-sub-graphs of this sub-graph which must be represented by functions. These must
+   * be disjoint, but it's ok for this sub-graph to have nodes not inside any sub-sub-graph.
+   */
+  SubSubGraphs sub_sub_graphs_;
+
+  void VisitAttrs(AttrVisitor* v);
+
+  // TODO(mbs): 'Anchor nodes' and rules for unioning them.
+  // In FuseOps it's just the unique kEWiseFusable node, if any.
+  // I'd like to allow writing vertical fusion rules, eg if two candidates are directly
+  // connected and have nn.conv2d anchors allow their join.
+  // I'd also like to allow horizontal fusion rules, eg if two candidates are not directly
+  // connected but could be joined without producing invalid (eg cyclic) and have nn.conv2d anchors
+  // then do so. Come back to this.
+
+  /*! \brief Number of nodes in overall dataflow graph. */
+  size_t overall_size() const { return inside_.end_index(); }
+
+  bool IsEmpty() const { return inside_.IsZero(); }
+
+  /*! \brief Number of nodes in sub-graph. */
+  size_t Size() const { return inside_.PopCount(); }
+
+  /*!
+   * \brief Returns the dataflow nodes downstream of all exit nodes.
+   */
+  IndexSet Downstream(const DataflowGraph& dataflow_graph) const;
+
+  /*!
+   * \brief Returns true if this sub-graph is valid. Ie:
+   *  - no output of the sub-graph can flow to any input of the sub-graph (otherwise we'd end up
+   *    with a dataflow cycle when we partition).
+   *  - all inputs and outputs of the sub-graph are in the same scope, ie not separated by
+   *    control flow (otherwise there'd be no consistent program point at which to eval the
+   *    partitioned function).
+   *  - no more than config.max_outputs outputs are require.
+   *  - if config.allow_taps is false, no inside node has outputs to nodes both inside and
+   *    outside the sub-graph.
+   */
+  bool IsValid(const DataflowGraph& dataflow_graph, const SubGraphConfig& config) const;
+
+  /*!
+   * \brief Returns this sub-graph extracted as a stand-alone function. The function will have
+   * no attributes, and is suitable for building and profiling by the \p CostEstimator.
+   */
+  Function ExtractAsFunction(const DataflowGraph& dataflow_graph) const;
+
+  /*!
+   * \brief Returns \p expr rewritten to encode the partitioning implied by this sub-graph.
+   *
+   * It is valid for \p expr to not be the same as \p dataflow_graph.expr(), however all nodes
+   * inside this sub-graph must correspond to nodes shared between \p dataflow_graph.expr() and
+   * \p expr. See \p SubGraph::ParallelRewrite below.
+   */
+  Expr Rewrite(const DataflowGraph& dataflow_graph, const Expr& expr) const;
+
+  std::string ToString() const;
+
+  bool operator==(const SubGraphNode& that) const;
+  bool operator!=(const SubGraphNode& that) const { return !(*this == that); }
+  bool operator<(const SubGraphNode& that) const;
+  size_t hash() const;
+
+ private:
+  /*! \brief Initialize the entry/exit/input/output sets given the inside and \p dataflow_graph. */
+  void Init(const DataflowGraph& dataflow_graph);
+
+  /*! \brief Calculates and returns the maximum path depth. */
+  size_t MaxDepth(const DataflowGraph& dataflow_graph) const;

Review Comment:
   I think this should be referred to as just Depth; 'MaxDepth' to me implies a restriction (and will later lead to the unusual MaxMaxDepth parameter). Equally, I think the depth of a graph/tree is commonly understood to be the distance from the root to the deepest node.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] masahi merged pull request #11981: [Collage] SubGraphs

Posted by GitBox <gi...@apache.org>.
masahi merged PR #11981:
URL: https://github.com/apache/tvm/pull/11981


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] SebastianBoblest commented on a diff in pull request #11981: [Collage] SubGraphs

Posted by GitBox <gi...@apache.org>.
SebastianBoblest commented on code in PR #11981:
URL: https://github.com/apache/tvm/pull/11981#discussion_r914505662


##########
src/relay/collage/dataflow_graph.h:
##########
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/collage/dataflow_graph.h
+ * \brief A representation of the dataflow for an overall Relay expression.
+ */
+#ifndef TVM_RELAY_COLLAGE_DATAFLOW_GRAPH_H_
+#define TVM_RELAY_COLLAGE_DATAFLOW_GRAPH_H_
+
+#include <tvm/relay/expr.h>
+
+#include <memory>
+#include <vector>
+
+#include "../ir/indexed_graph.h"
+#include "./index_set.h"
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+/*!
+ * \brief Represents the dataflow of an overall Relay expression.
+ */
+class DataflowGraph {
+ public:
+  using Node = IndexedGraph<Expr>::Node;
+
+  explicit DataflowGraph(Expr expr);
+
+  size_t size() const { return indexed_graph_->size(); }
+  const Node* index_to_node(PostDfsIndex index) const {
+    return indexed_graph_->index_to_node(index);
+  }
+  const Node* item_to_node(const Expr& expr) const { return indexed_graph_->item_to_node(expr); }
+  const Node* item_to_node(const ExprNode* expr_node) const {
+    return indexed_graph_->item_to_node(expr_node);
+  }
+  const Expr& expr() const { return expr_; }
+  const IndexedGraph<Expr>& indexed_graph() const { return *indexed_graph_; }
+
+  const IndexSet& downstream_of(PostDfsIndex index) const {
+    ICHECK_LT(index, indexed_graph_->size());
+    return downstream_map_[index];
+  }
+
+ private:
+  /*! \brief The overall expression. */
+  Expr expr_;
+  /*! \brief The indexed graph which captures the main dataflow. */
+  std::unique_ptr<IndexedGraph<Expr>> indexed_graph_;
+  /*! \brief Map from a node's PostDfsIndex to the set of it's downstream dataflow node indexes. */

Review Comment:
   ```suggestion
     /*! \brief Map from a node's PostDfsIndex to the set of its downstream dataflow node indexes. */
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] SebastianBoblest commented on a diff in pull request #11981: [Collage] SubGraphs

Posted by GitBox <gi...@apache.org>.
SebastianBoblest commented on code in PR #11981:
URL: https://github.com/apache/tvm/pull/11981#discussion_r914645778


##########
src/relay/collage/sub_graph.h:
##########
@@ -0,0 +1,451 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/collage/sub_graph.h
+ * \brief Represents a sub-graph of an overall Relay expression.
+ */
+
+#ifndef TVM_RELAY_COLLAGE_SUB_GRAPH_H_
+#define TVM_RELAY_COLLAGE_SUB_GRAPH_H_
+
+#include <tvm/ir/transform.h>
+#include <tvm/relay/op_attr_types.h>
+
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../ir/dataflow_matcher_impl.h"
+#include "../ir/indexed_graph.h"
+#include "./dataflow_graph.h"
+#include "./index_set.h"
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+/*! \brief Returns operator pattern kind as single-letter string. */
+std::string KindToString(OpPatternKind kind);
+
+/*!
+ * \brief Returns a kind and label for the single \p sub_expr, ignoring its sub-sub expressions.
+ */
+std::pair<OpPatternKind, std::string> SubExprKindAndLabel(const Expr& sub_expr);
+
+/*!
+ * \brief Returns a kind and label for all the nodes in \p inside.
+ */
+std::pair<OpPatternKind, std::string> SubGraphKindAndLabel(const DataflowGraph& dataflow_graph,
+                                                           const IndexSet& inside);
+
+/*!
+ * \brief Returns the index set representing all the sub-expressions matched by \p matcher.
+ */
+IndexSet MatcherToIndexSet(const DFPatternMatcher& matcher);
+
+/*!
+ * \brief Configuration controlling which sub-graphs are considered valid.
+ */
+struct SubGraphConfig {
+  /*! \brief Maximum number of exit nodes in the sub-graph, or zero if no limit. */
+  size_t max_exits = 0;
+  /*!
+   * \brief Whether a node inside the sub-graph may flow to nodes both inside and outside
+   * the sub-graph (which we call a 'tap'). Note that it is still possible to have multiple outputs
+   * even with this flag false.
+   */
+  bool allow_taps = false;
+  /*!
+   * \brief Maximum allowed maximum depth, or zero if no-limit.
+   */
+  size_t max_max_depth = 0;
+
+  std::string ToString() const;
+};
+
+class SubGraph;
+using FunctionAttrsMap = Map<String, ObjectRef>;
+
+/*!
+ * \brief A sub-sub graph is a sub-graph which is to be nested inside a function as part of some
+ * enclosing sub-graph.
+ *
+ * Extraction yields a function with input nodes replaced by parameters and exit nodes in the
+ * function result. Rewriting replaces the sub-graph with a call to that function, and all
+ * outputs with (projections from) the call result.
+ *
+ * (Note that it's tempting to move attrs_ into \p SubGraphNode and thus avoid this class.
+ * However we found the implementation was easier to understand in this form since it makes
+ * the result of \p Extract unambiguous.)
+ */
+class SubSubGraphNode : public Object {
+ public:
+  /*! \brief The nested sub-graph. */
+  ObjectRef /* actually SubGraph */ sub_graph_obj_;
+  /*! \brief Attributes (possibly empty) to attach to the extracted function. */
+  FunctionAttrsMap attrs_;
+
+  void VisitAttrs(AttrVisitor* v);
+
+  SubGraph sub_graph() const;
+
+  bool operator==(const SubSubGraphNode& that) const;
+  bool operator!=(const SubSubGraphNode& that) const { return !(*this == that); }
+  bool operator<(const SubSubGraphNode& that) const;
+  size_t hash() const;
+
+  std::string ToString() const;
+
+  /*!
+   * \brief Returns the function representing this sub-sub-graph within the overall expression
+   * represented by \p dataflow_graph:
+   *  - All sub-graph inputs become parameters.
+   *  - All sub-graph outputs become function results (either directly or as a field in a tuple).
+   *  - The function has attrs_ for attributes (which may be empty).
+   *  - The function body accounts for any rewrites implied by the nested sub-graph.
+   */
+  Function Extract(const DataflowGraph& dataflow_graph) const;
+
+  /*!
+   * \brief Returns \p expr rewritten to encode the partitioning implied by this sub-sub-graph.
+   *
+   * It is valid for \p expr to not be the same as \p dataflow_graph.expr(), however all nodes
+   * inside this sub-sub-graph must correspond to nodes shared between \p dataflow_graph.expr() and
+   * \p expr. See \p SubGraph::ParallelRewrite below.
+   */
+  Expr Rewrite(const DataflowGraph& dataflow_graph, const Expr& expr) const;
+
+  static constexpr const char* _type_key = "relay.collage.SubSubGraph";
+  TVM_DECLARE_FINAL_OBJECT_INFO(SubSubGraphNode, Object);
+};
+
+class SubSubGraph : public ObjectRef {
+ public:
+  SubSubGraph(SubGraph sub_graph, FunctionAttrsMap attrs);
+
+  /*!
+   * \brief Returns copy of this sub-sub-graph with all indexes substituted according to \p subst,
+   * whose range is w.r.t. \p new_dataflow_graph.
+   */
+  SubSubGraph Subst(const DataflowGraph& new_dataflow_graph,
+                    const std::unordered_map<PostDfsIndex, PostDfsIndex>& subst) const;
+
+  /*!
+   * \brief Returns true if this can be safely unioned.
+   */
+  bool TriviallyUnionable(const SubSubGraph& that) const;
+
+  /*!
+   * \brief Returns the disjoint union of this and \p that sub-sub-graphs, which must agree on
+   * their attributes.
+   */
+  SubSubGraph DisjointUnion(const DataflowGraph& dataflow_graph, const SubSubGraph& that) const;
+
+  /*!
+   * \brief Returns \p expr rewritten according to all the given sub-sub-graphs. The sub-sub-graphs
+   * can be given in any order, but must be disjoint.
+   *
+   * It is valid for \p expr to not be the same as \p dataflow_graph.expr(), however all nodes
+   * inside the sub-sub-graphs must correspond to nodes shared between \p dataflow_graph.expr() and
+   * \p expr. See \p SubGraph::ParallelRewrite below.
+   */
+  static Expr ParallelRewrite(const DataflowGraph& dataflow_graph, const Expr& expr,
+                              std::vector<SubSubGraph> sub_sub_graphs);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(SubSubGraph, ObjectRef, SubSubGraphNode);
+};
+
+using SubSubGraphs = Array<SubSubGraph>;
+
+/*!
+ * \brief A compact representation of a sub-graph within an (implied) overall Relay expression.
+ *
+ * Sub-graphs can be used to represent partitions/kernels/composite functions without having to
+ * pay the cost of constructing or rewriting any expressions. We also allow 'extracting' a
+ * function to use for measuring a partition/kernel's latency independently from 'rewriting'
+ * the overall Relay expression since only a tiny subset of candidate partitions will end up being
+ * needed after Collage has completed its search.
+ *
+ * We expect O(thousands) of sub-graphs to be in flight while processing a given model, so are
+ * mindful of space overhead.
+ *
+ * A sub-graph classifies every dataflow node of the overall expression as either 'inside' or
+ * 'outside' the sub-graph. Obviously not all such divisions make sense, for example it is not
+ * valid for an inside node to feed into another inside node via outside nodes. We provide the
+ * \p IsValid method to check for validity, and \p SubGraphConfig to control which validity rules
+ * apply (such as maximum depth).
+ *
+ * We generally work with the \p DataflowGraph representation of the overall Relay expression
+ * rather than the expression itself. We use the post-dfs visit index to uniquely refer to
+ * expression nodes.
+ *
+ * As well as 'inside' and 'outside' we have four other flavors of dataflow nodes, all uniquely
+ * determined from the 'inside' nodes:
+ *  - 'entry' nodes are those inside with at least one dataflow input outside.
+ *  - 'exit' nodes are  those inside with at least one dataflow output outside, or which
+ *    are considered 'external' in the underlying dataflow graph (eg because they represent
+ *    the result of the overall function).
+ *  - 'input' nodes are those outside with at least one dataflow output inside.
+ *  - 'output' nodes are those outside with at least one dataflow input inside.
+ * Index sets for these are cached with the sub-graph for performance.
+ *
+ * It is valid to have multiple entry nodes (we can bind a parameter for each). It may be valid to
+ * have multiple exit nodes (we can build a tuple of all such). It may be valid to have exit nodes
+ * which also contribute to other inside nodes (ie represent a 'tap' on an intermediate result).
+ *
+ * Sub-graphs are closed under:
+ *  - Disjoint union.
+ *  - Wrapping by a function with given attributes (see \p SubSubGraph above). This can be used
+ *    to encode "Composite" functions, or to represent a candidate kernel within a "Primitive"
+ *    function. (By combining 'wrapping' with 'union' we can encode, eg, 'this sub-graph should
+ *    be placed inside a primitive function which itself may have calls to composite functions).
+ *  - Substitution, which allows a sub-graph w.r.t. one dataflow graph to be transformed to
+ *    match some other (typically smaller) dataflow graph.
+ *
+ * See the subclasses of \p PartitionRule for how sub-graphs are built and combined during Collage
+ * search.
+ *
+ * To support some of the \p OpPatternKind-based fusion rule processing we give sub-graphs
+ * a kind, which is generally the maximum of the kinds of all the operator calls appearing
+ * inside it. We also give sub-graphs a (not necessarily unique) label to help debugging
+ * and guide the selection of global symbol names.
+ */
+class SubGraphNode : public Object {
+ public:
+  /*!
+   * \brief Which sub-expressions are inside the sub-graph (using their post-dfs indexes w.r.t.
+   * the implied DataflowGraph).
+   */
+  IndexSet inside_;
+
+  /*!
+   * \brief Index of first and last inside nodes.
+   *
+   * Cached for performance, uniquely determined by inside_.
+   */
+  PostDfsIndex first_inside_index_ = 0;
+  PostDfsIndex last_inside_index_ = 0;
+
+  /*!
+   * \brief Which sub-expressions are entry/exit/input/output for this sub-graph.
+   *
+   * Cached for performance, uniquely determined by inside_.
+   */
+  IndexSet entry_;
+  IndexSet exit_;
+  IndexSet input_;
+  IndexSet output_;
+
+  /*!
+   * \brief Maximum depth of any dataflow path from an entry to an output sub-expression.
+   *
+   * Cached for performance, uniquely determined by inside_.
+   */
+  size_t max_depth_ = 0;
+
+  /*!
+   * \brief The \p OpPatternKind summarizing the input/output behavior of the sub-graph.
+   *
+   * A sub-graph consisting of a single Relay expression node is given kind:
+   *  - For Call to a Relay operator, the "TOpPattern" attribute of that operator (provided the
+   *    call does not involve data-dependent dynamic shapes).
+   *  - For Call to Relay Function, the "TOpPattern" attribute of the function (provided it has
+   *    that attribute)
+   *  - For Constants, \p kElemWise.
+   *  - For Tuple and tuple projections, \p kInjective (provided all tuple fields are of tensor
+   *    type)
+   *  - All other nodes \p kOpaque.
+   * Sub-graphs with more than one node have the maximum of the kind of each node.
+   *
+   * Cached for performance, uniquely determined by inside_.
+   */
+  OpPatternKind kind_ = kOpaque;
+
+  /*!
+   * \brief A label for the sub-graph. Not guaranteed to be unique, but is a human-readable summary
+   * of the sub-graph which can help with debugging and guide the selection of global symbol names.
+   */
+  String label_;
+
+  /*!
+   * \brief Sub-sub-graphs of this sub-graph which must be represented by functions. These must
+   * be disjoint, but it's ok for this sub-graph to have nodes not inside any sub-sub-graph.
+   */
+  SubSubGraphs sub_sub_graphs_;
+
+  void VisitAttrs(AttrVisitor* v);
+
+  // TODO(mbs): 'Anchor nodes' and rules for unioning them.
+  // In FuseOps it's just the unique kEWiseFusable node, if any.
+  // I'd like to allow writing vertical fusion rules, eg if two candidates are directly
+  // connected and have nn.conv2d anchors allow their join.
+  // I'd also like to allow horizontal fusion rules, eg if two candidates are not directly
+  // connected but could be joined without producing invalid (eg cyclic) and have nn.conv2d anchors
+  // then do so. Come back to this.
+
+  /*! \brief Number of nodes in overall dataflow graph. */
+  size_t overall_size() const { return inside_.end_index(); }
+
+  bool IsEmpty() const { return inside_.IsZero(); }
+
+  /*! \brief Number of nodes in sub-graph. */
+  size_t Size() const { return inside_.PopCount(); }
+
+  /*!
+   * \brief Returns the dataflow nodes downstream of all exit nodes.
+   */
+  IndexSet Downstream(const DataflowGraph& dataflow_graph) const;
+
+  /*!
+   * \brief Returns true if this sub-graph is valid. Ie:
+   *  - no output of the sub-graph can flow to any input of the sub-graph (otherwise we'd end up
+   *    with a dataflow cycle when we partition).
+   *  - all inputs and outputs of the sub-graph are in the same scope, ie not separated by
+   *    control flow (otherwise there'd be no consistent program point at which to eval the
+   *    partitioned function).
+   *  - no more than config.max_outputs outputs are require.

Review Comment:
   ```suggestion
      *  - no more than config.max_outputs outputs are required.
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] SebastianBoblest commented on a diff in pull request #11981: [Collage] SubGraphs

Posted by GitBox <gi...@apache.org>.
SebastianBoblest commented on code in PR #11981:
URL: https://github.com/apache/tvm/pull/11981#discussion_r914642215


##########
src/relay/collage/sub_graph.h:
##########
@@ -0,0 +1,451 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/collage/sub_graph.h
+ * \brief Represents a sub-graph of an overall Relay expression.
+ */
+
+#ifndef TVM_RELAY_COLLAGE_SUB_GRAPH_H_
+#define TVM_RELAY_COLLAGE_SUB_GRAPH_H_
+
+#include <tvm/ir/transform.h>
+#include <tvm/relay/op_attr_types.h>
+
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../ir/dataflow_matcher_impl.h"
+#include "../ir/indexed_graph.h"
+#include "./dataflow_graph.h"
+#include "./index_set.h"
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+/*! \brief Returns operator pattern kind as single-letter string. */
+std::string KindToString(OpPatternKind kind);
+
+/*!
+ * \brief Returns a kind and label for the single \p sub_expr, ignoring it's sub-sub expressions.
+ */
+std::pair<OpPatternKind, std::string> SubExprKindAndLabel(const Expr& sub_expr);
+
+/*!
+ * \brief Returns a kind and label for all the nodes in \p inside.
+ */
+std::pair<OpPatternKind, std::string> SubGraphKindAndLabel(const DataflowGraph& dataflow_graph,
+                                                           const IndexSet& inside);
+
+/*!
+ * \brief Returns the index set representing all the sub-expression matched by \p matcher.
+ */
+IndexSet MatcherToIndexSet(const DFPatternMatcher& matcher);
+
+/*!
+ * \brief Configuration controlling which sub-graphs are considered valid.
+ */
+struct SubGraphConfig {
+  /*! \brief Maximum number of exit nodes in the sub-graph, or zero if no limit. */
+  size_t max_exits = 0;
+  /*!
+   * \brief Whether a node inside the sub-graph may flow to nodes both inside and outside
+   * the sub-graph (which we call a 'tap'). Note that it is still possible to have multiple outputs
+   * even with this flag false.
+   */
+  bool allow_taps = false;
+  /*!
+   * \brief Maximum allowed maximum depth, or zero if no-limit.
+   */
+  size_t max_max_depth = 0;
+
+  std::string ToString() const;
+};
+
+class SubGraph;
+using FunctionAttrsMap = Map<String, ObjectRef>;
+
+/*!
+ * \brief A sub-sub graph is a sub-graph which is to be nested inside a function as part of some
+ * enclosing sub-graph.
+ *
+ * Extraction yields a function with input nodes replaced by parameters and exit nodes in the
+ * function result. Rewriting replaces the sub-graph with a call to that function, and all
+ * outputs with (projections from) the call result.
+ *
+ * (Note that it's tempting to move attrs_ into \p SubGraphNode and thus avoid this class.
+ * However we found the implementation was easier to understand in this form since it makes
+ * the result of \p Extract unambiguous.)
+ */
+class SubSubGraphNode : public Object {
+ public:
+  /*! \brief The nested sub-graph. */
+  ObjectRef /* actually SubGraph */ sub_graph_obj_;
+  /*! \brief Attributes (possibly empty) to attach to the extracted function. */
+  FunctionAttrsMap attrs_;
+
+  void VisitAttrs(AttrVisitor* v);
+
+  SubGraph sub_graph() const;
+
+  bool operator==(const SubSubGraphNode& that) const;
+  bool operator!=(const SubSubGraphNode& that) const { return !(*this == that); }
+  bool operator<(const SubSubGraphNode& that) const;
+  size_t hash() const;
+
+  std::string ToString() const;
+
+  /*!
+   * \brief Returns the function representing this sub-sub-graph within the overall expression
+   * represented by \p dataflow_graph:
+   *  - All sub-graph inputs become parameters.
+   *  - All sub-graph outputs become function results (either directly or as a field in a tuple).
+   *  - The function has attrs_ for attributes (which may be empty).
+   *  - The function body accounts for any rewrites implied by the nested sub-graph.
+   */
+  Function Extract(const DataflowGraph& dataflow_graph) const;
+
+  /*!
+   * \brief Returns \p expr rewritten to encode the partitioning implied by this sub-sub-graph.
+   *
+   * It is valid for \p expr to not be the same as \p dataflow_graph.expr(), however all nodes
+   * inside this sub-sub-graph must correspond to nodes shared between \p dataflow_graph.expr() and
+   * \p expr. See \p SubGraph::ParallelRewrite below.
+   */
+  Expr Rewrite(const DataflowGraph& dataflow_graph, const Expr& expr) const;
+
+  static constexpr const char* _type_key = "relay.collage.SubSubGraph";
+  TVM_DECLARE_FINAL_OBJECT_INFO(SubSubGraphNode, Object);
+};
+
+class SubSubGraph : public ObjectRef {
+ public:
+  SubSubGraph(SubGraph sub_graph, FunctionAttrsMap attrs);
+
+  /*!
+   * \brief Returns copy of this sub-sub-graph with all indexes substituted according to \p subst,
+   * whose range is w.r.t. \p new_dataflow_graph.
+   */
+  SubSubGraph Subst(const DataflowGraph& new_dataflow_graph,
+                    const std::unordered_map<PostDfsIndex, PostDfsIndex>& subst) const;
+
+  /*!
+   * \brief Returns true if this can be safely unioned.
+   */
+  bool TriviallyUnionable(const SubSubGraph& that) const;
+
+  /*!
+   * \brief Returns the disjoin union of this and \p that sub-sub graphs, which must agree on
+   * their attributes.
+   */
+  SubSubGraph DisjointUnion(const DataflowGraph& dataflow_graph, const SubSubGraph& that) const;
+
+  /*!
+   * \brief Returns \p expr rewritten according to all the given sub-sub-graphs. The sub-sub-graphs
+   * can be given in any order, but must be disjoint.
+   *
+   * It is valid for \p expr to not be the same as \p dataflow_graph.expr(), however all nodes
+   * inside the sub-sub-graphs must correspond to nodes shared between \p dataflow_graph.expr() and
+   * \p expr. See \p SubGraph::ParallelRewrite below.
+   */
+  static Expr ParallelRewrite(const DataflowGraph& dataflow_graph, const Expr& expr,
+                              std::vector<SubSubGraph> sub_sub_graphs);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(SubSubGraph, ObjectRef, SubSubGraphNode);
+};
+
+using SubSubGraphs = Array<SubSubGraph>;
+
+/*!
+ * \brief A compact representation of a sub-graph within an (implied) overall Relay expression.
+ *
+ * Sub-graphs can be used to represent partitions/kernels/composite functions without having to
+ * pay the cost of constructing or rewriting any expressions. We also allow 'extracting' a
+ * function to use for measuring a partition/kernel's latency independently from 'rewriting'
+ * the overall Relay expression since only a tiny subset of candidate partitions will end up being
+ * needed after Collage has completed its search.
+ *
+ * We expect O(thousands) of sub-graphs to be in flight while processing a given model, so are

Review Comment:
   ```suggestion
    * We expect O(thousands) of sub-graphs to be in flight while processing a given model, so be
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] mbs-octoml commented on pull request #11981: [Collage] SubGraphs

Posted by GitBox <gi...@apache.org>.
mbs-octoml commented on PR #11981:
URL: https://github.com/apache/tvm/pull/11981#issuecomment-1180710973

   Thanks @SebastianBoblest; every little bit helps.
   
   Thanks and PTAL @mbaret.


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] mbs-octoml commented on pull request #11981: [Collage] SubGraphs

Posted by GitBox <gi...@apache.org>.
mbs-octoml commented on PR #11981:
URL: https://github.com/apache/tvm/pull/11981#issuecomment-1172683228

   CI is green; ready for review.


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] mbs-octoml commented on a diff in pull request #11981: [Collage] SubGraphs

Posted by GitBox <gi...@apache.org>.
mbs-octoml commented on code in PR #11981:
URL: https://github.com/apache/tvm/pull/11981#discussion_r918212033


##########
src/relay/collage/sub_graph.h:
##########
@@ -0,0 +1,451 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/collage/sub_graph.h
+ * \brief Represents a sub-graph of an overall Relay expression.
+ */
+
+#ifndef TVM_RELAY_COLLAGE_SUB_GRAPH_H_
+#define TVM_RELAY_COLLAGE_SUB_GRAPH_H_
+
+#include <tvm/ir/transform.h>
+#include <tvm/relay/op_attr_types.h>
+
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../ir/dataflow_matcher_impl.h"
+#include "../ir/indexed_graph.h"
+#include "./dataflow_graph.h"
+#include "./index_set.h"
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+/*! \brief Returns operator pattern kind as single-letter string. */
+std::string KindToString(OpPatternKind kind);
+
+/*!
+ * \brief Returns a kind and label for the single \p sub_expr, ignoring it's sub-sub expressions.
+ */
+std::pair<OpPatternKind, std::string> SubExprKindAndLabel(const Expr& sub_expr);
+
+/*!
+ * \brief Returns a kind and label for all the nodes in \p inside.
+ */
+std::pair<OpPatternKind, std::string> SubGraphKindAndLabel(const DataflowGraph& dataflow_graph,
+                                                           const IndexSet& inside);
+
+/*!
+ * \brief Returns the index set representing all the sub-expression matched by \p matcher.
+ */
+IndexSet MatcherToIndexSet(const DFPatternMatcher& matcher);
+
+/*!
+ * \brief Configuration controlling which sub-graphs are considered valid.
+ */
+struct SubGraphConfig {
+  /*! \brief Maximum number of exit nodes in the sub-graph, or zero if no limit. */
+  size_t max_exits = 0;
+  /*!
+   * \brief Whether a node inside the sub-graph may flow to nodes both inside and outside
+   * the sub-graph (which we call a 'tap'). Note that it is still possible to have multiple outputs
+   * even with this flag false.
+   */
+  bool allow_taps = false;
+  /*!
+   * \brief Maximum allowed maximum depth, or zero if no-limit.
+   */
+  size_t max_max_depth = 0;
+
+  std::string ToString() const;
+};
+
+class SubGraph;
+using FunctionAttrsMap = Map<String, ObjectRef>;
+
+/*!
+ * \brief A sub-sub graph is a sub-graph which is to be nested inside a function as part of some
+ * enclosing sub-graph.
+ *
+ * Extraction yields a function with input nodes replaced by parameters and exit nodes in the
+ * function result. Rewriting replaces the sub-graph with a call to that function, and all
+ * outputs with (projections from) the call result.
+ *
+ * (Note that it's tempting to move attrs_ into \p SubGraphNode and thus avoid this class.
+ * However we found the implementation was easier to understand in this form since it makes
+ * the result of \p Extract unambiguous.)
+ */
+class SubSubGraphNode : public Object {
+ public:
+  /*! \brief The nested sub-graph. */
+  ObjectRef /* actually SubGraph */ sub_graph_obj_;
+  /*! \brief Attributes (possibly empty) to attach to the extracted function. */
+  FunctionAttrsMap attrs_;
+
+  void VisitAttrs(AttrVisitor* v);
+
+  SubGraph sub_graph() const;
+
+  bool operator==(const SubSubGraphNode& that) const;
+  bool operator!=(const SubSubGraphNode& that) const { return !(*this == that); }
+  bool operator<(const SubSubGraphNode& that) const;
+  size_t hash() const;
+
+  std::string ToString() const;
+
+  /*!
+   * \brief Returns the function representing this sub-sub-graph within the overall expression
+   * represented by \p dataflow_graph:
+   *  - All sub-graph inputs become parameters.
+   *  - All sub-graph outputs become function results (either directly or as a field in a tuple).
+   *  - The function has attrs_ for attributes (which may be empty).
+   *  - The function body accounts for any rewrites implied by the nested sub-graph.
+   */
+  Function Extract(const DataflowGraph& dataflow_graph) const;
+
+  /*!
+   * \brief Returns \p expr rewritten to encode the partitioning implied by this sub-sub-graph.
+   *
+   * It is valid for \p expr to not be the same as \p dataflow_graph.expr(), however all nodes
+   * inside this sub-sub-graph must correspond to nodes shared between \p dataflow_graph.expr() and
+   * \p expr. See \p SubGraph::ParallelRewrite below.
+   */
+  Expr Rewrite(const DataflowGraph& dataflow_graph, const Expr& expr) const;
+
+  static constexpr const char* _type_key = "relay.collage.SubSubGraph";
+  TVM_DECLARE_FINAL_OBJECT_INFO(SubSubGraphNode, Object);
+};
+
+class SubSubGraph : public ObjectRef {

Review Comment:
   I hadn't thought of that, actually. At first I had SubGraph be directly recursive, until I realized things were much clearer with the intermediate NestedSubGraph, and I was so happy with that that I didn't push further. You are right that there's some signature sharing, but there's no implementation sharing that I can see, and I think making code polymorphic on SubGraph vs NestedSubGraph would only make things even more confusing. So let me leave it as is.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] SebastianBoblest commented on pull request #11981: [Collage] SubGraphs

Posted by GitBox <gi...@apache.org>.
SebastianBoblest commented on PR #11981:
URL: https://github.com/apache/tvm/pull/11981#issuecomment-1176033088

   I read this out of curiosity; I cannot review it on a technical level.
   On a fundamental level, though, the code looks great.


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] SebastianBoblest commented on a diff in pull request #11981: [Collage] SubGraphs

Posted by GitBox <gi...@apache.org>.
SebastianBoblest commented on code in PR #11981:
URL: https://github.com/apache/tvm/pull/11981#discussion_r914625838


##########
src/relay/collage/sub_graph.cc:
##########
@@ -0,0 +1,1032 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/collage/sub_graph.cc
+ * \brief Represents a sub-graph of an overall Relay expression.
+ */
+
+#include "./sub_graph.h"
+
+#include <tvm/relay/transform.h>
+
+#include "../../support/scalars.h"
+#include "../transforms/pass_utils.h"
+#include "./utils.h"
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+namespace {
+
+class Extractor;
+
+/*!
+ * \brief Helper class for rewriting expressions to replace a sub-graph according to the
+ * given extractor.
+ */
+class Rewriter : public ExprMutator {
+ public:
+  explicit Rewriter(const Extractor* extractor) : extractor_(extractor) {}
+
+  Expr VisitExpr(const Expr& expr) final;
+
+ private:
+  /*! \brief Already prepared extractor which will guide the rewrite. */
+  const Extractor* extractor_;
+};
+
+/*! \brief Helper class for extracting matched sub-graphs from the overall expression. */
+class Extractor : public ExprMutator {
+ public:
+  Extractor(const DataflowGraph* dataflow_graph, const SubGraphNode* sub_graph,
+            FunctionAttrsMap opt_attrs)
+      : dataflow_graph_(dataflow_graph), sub_graph_(sub_graph), opt_attrs_(std::move(opt_attrs)) {
+    ICHECK_EQ(dataflow_graph_->size(), sub_graph_->overall_size());
+  }
+
+  const DataflowGraph& dataflow_graph() const { return *dataflow_graph_; }
+
+  /*!
+   * \brief Collect the parameters and output expressions for the function representing
+   * the sub-graph.
+   */
+  void Extract() {
+    ICHECK(!sub_graph_->IsEmpty());
+    VLOG(2) << "Extracting " << sub_graph_->ToString();
+    const bool for_function = opt_attrs_.defined();
+
+    //  In reverse dataflow order...
+    for (PostDfsIndex i = dataflow_graph_->size(); i > 0; --i) {
+      PostDfsIndex index = i - 1;
+      if (!sub_graph_->inside_[index]) {
+        // Node is outside sub-graph.
+        continue;
+      }
+      VLOG(2) << "index " << index;
+      auto node = dataflow_graph_->index_to_node(index);
+      if (sub_graph_->exit_[node->index_] || node->is_external_ || memo_.count(node->ref()) == 0) {
+        // This sub-expression is:
+        //  - inside the sub-graph and needed outside the sub-graph. So it must contribute to an
+        //    output (even if we've already visited it while constructing an output from a
+        //    downstream sub-expression).
+        //  - not yet visited, in which case it must still be considered an 'output' so it will
+        //    be evaluated for any possible side effects.
+        Expr output = VisitExpr(GetRef<Expr>(node->node_ref_));
+        VLOG(2) << "index " << index << " added as output:\n"
+                << PrettyPrint(output) << "\nat " << outputs_.size();
+        expr_to_output_index_.emplace(node->node_ref_, outputs_.size());
+        outputs_.emplace_back(std::move(output));
+        output_types_.emplace_back(node->node_ref_->checked_type());
+      }
+    }
+    ICHECK(!outputs_.empty());
+
+    // Reverse the outputs so as to preserve the original evaluation order.
+    std::reverse(outputs_.begin(), outputs_.end());
+    std::reverse(output_types_.begin(), output_types_.end());
+    for (auto& kv : expr_to_output_index_) {
+      kv.second = static_cast<int>(outputs_.size()) - 1 - kv.second;
+    }
+
+    // Build a 'body' expression to represent the extracted sub-graph. If we have multiple
+    // outputs we'll place them in a tuple.
+    Type body_type;
+    Expr body;
+    if (outputs_.size() > 1) {
+      body_type = TupleType(output_types_);
+      body = Tuple(outputs_);
+      body->checked_type_ = body_type;
+    } else {
+      body_type = output_types_.front();
+      body = outputs_.front();
+    }
+
+    // Re-express all the sub-sub-graphs in terms of the body.
+    DataflowGraph body_dataflow_graph(body);
+    std::vector<SubSubGraph> sub_sub_graphs;
+    IndexSubst subst = MakeIndexSubst(body_dataflow_graph);
+    for (const auto& sub_sub_graph : sub_graph_->sub_sub_graphs_) {
+      sub_sub_graphs.emplace_back(sub_sub_graph.Subst(body_dataflow_graph, subst));
+    }
+
+    // Sweep backwards through the body, rewriting to account for each sub-sub-graph.
+    body = SubSubGraph::ParallelRewrite(body_dataflow_graph, body, std::move(sub_sub_graphs));
+
+    if (for_function) {
+      // Rewrite so all input nodes are now conveyed via call arguments to a new function.
+      Array<Type> arg_types;
+      arg_types.reserve(params_.size());
+      for (const auto& param : params_) {
+        arg_types.push_back(param->checked_type());
+      }
+      extracted_ = Function(std::move(params_), std::move(body), body_type,
+                            /*ty_params=*/{}, DictAttrs(opt_attrs_));
+      extracted_->checked_type_ =
+          FuncType(std::move(arg_types), body_type, /*type_params=*/{}, /*type_constraints=*/{});
+      body = Call(extracted_, std::move(args_));
+      body->checked_type_ = body_type;
+    } else {
+      // Don't do anything with the inputs.
+      extracted_ = body;
+    }
+
+    // Setup the output substitution.
+    for (const auto& kv : expr_to_output_index_) {
+      Expr expr;
+      if (outputs_.size() == 1) {
+        expr = body;
+      } else if (for_function) {
+        expr = TupleGetItem(body, kv.second);
+        expr->checked_type_ = output_types_[kv.second];
+      } else {
+        const auto* tuple_node = body.as<TupleNode>();
+        ICHECK(tuple_node);
+        expr = tuple_node->fields[kv.second];
+      }
+      VLOG(2) << "output " << dataflow_graph_->item_to_node(kv.first)->index_ << " is at index "
+              << kv.second << " (of " << outputs_.size() << " outputs)";
+      output_substitution_.emplace(kv.first, std::move(expr));
+    }
+  }
+
+  ////// Following members are valid only after Extract() has returned.
+
+  /*!
+   * \brief Returns the expression representing the extracted sub-graph. If opt_attrs_ is
+   * defined then will be a function.
+   */
+  Expr extracted() const { return extracted_; }
+
+  /*!
+   * \brief Returns the substitution to apply to all expression nodes in the overall expression
+   * so as to replace references to outputs of the sub-graph with their rewritten form.
+   */
+  const std::unordered_map<const ExprNode*, Expr>& output_substitution() const {
+    return output_substitution_;
+  }
+
+ private:
+  /*!
+   * \brief Returns a map from original index to new index for each node inside the sub-graph. Only
+   * valid after \p Extract has made its backwards dataflow sweep.
+   */
+  IndexSubst MakeIndexSubst(const DataflowGraph& new_dataflow_graph) const {
+    VLOG(2) << "building extractor substitution";
+    IndexSubst subst;
+    for (PostDfsIndex index : sub_graph_->inside_) {
+      auto orig_node = dataflow_graph_->index_to_node(index);
+      ICHECK_EQ(orig_node->index_, index);
+      auto itr = memo_.find(orig_node->ref());
+      ICHECK(itr != memo_.end());
+      auto new_node = new_dataflow_graph.item_to_node(itr->second);
+      VLOG(2) << orig_node->index_ << " |-> " << new_node->index_;
+      subst.emplace(orig_node->index_, new_node->index_);
+    }
+    return subst;
+  }
+
+  /*! \brief Returns true if \p expr is inside the sub-graph. */
+  bool inside(const Expr& expr) {
+    return sub_graph_->inside_[dataflow_graph_->item_to_node(expr)->index_];
+  }
+
+  /*!
+   * \brief Returns the variable uniquely representing \p expr, which should be
+   * an input node (ie outside the sub-graph but feeding into a node inside the sub-graph).
+   *
+   * It is valid for:
+   *  - An expression outside the sub-graph to be used multiple times inside the sub-graph.
+   *  - An expression outside the sub-graph to be used both inside and outside the sub-graph.
+   */
+  Var VarFor(const Expr& expr) {
+    ICHECK(!inside(expr));
+    ICHECK(opt_attrs_.defined());
+    auto itr = expr_to_param_.find(expr.get());
+    if (itr != expr_to_param_.end()) {
+      return itr->second;
+    }
+    auto fresh_var = Var("FunctionVar_" + std::to_string(params_.size()), expr->checked_type());
+    fresh_var->checked_type_ = expr->checked_type();
+    params_.push_back(fresh_var);
+    args_.push_back(expr);
+    expr_to_param_.emplace(expr.get(), fresh_var);
+    return fresh_var;
+  }
+
+  /*!
+   * \brief If \p expr is inside the sub-graph then return it's rewritten form.
+   * If \p expr is outside the sub-graph then it must correspond to an input node.
+   *  - If opt_attrs_ is defined return the variable to represent it.
+   *  - Otherwise just return the expression directly.
+   *
+   * Should be called only on inputs to nodes which are inside the sub-graph.
+   */
+  Expr VisitExpr(const Expr& expr) final {
+    if (inside(expr)) {
+      return ExprMutator::VisitExpr(expr);
+    } else if (CanInline(expr)) {
+      // Implicitly include inlinable input sub-expressions.
+      return expr;
+    } else if (opt_attrs_.defined()) {
+      // Map to a function parameter.
+      return VarFor(expr);
+    } else {
+      // Stop rewriting.
+      return expr;
+    }
+  }
+
+  Expr VisitExpr_(const FunctionNode* function_node) override {
+    if (function_node->HasNonzeroAttr(attr::kPrimitive)) {
+      return GetRef<Function>(function_node);
+    }
+    return ExprMutator::VisitExpr_(function_node);
+  }
+
+  //// Context fields, passed in constructor.
+
+  /*! \brief The dataflow graph corresponding to the overall expression. */
+  const DataflowGraph* dataflow_graph_;
+  /*! \brief The sub-graph of the above we are extracting. */
+  const SubGraphNode* sub_graph_;
+  /*! \brief Optional attributes if the sub-graph should be extracted as a function. */
+  FunctionAttrsMap opt_attrs_;
+
+  //// Result fields, available after Extract() called.
+
+  /*!
+   * \brief The extracted expression. If opt_attrs_ is defined this will be a function.
+   */
+  Expr extracted_;
+  /*!
+   * \brief Map from output nodes to corresponding expressions. If the sub-graph has more than
+   * one exit node then each entry will be a tuple projection.
+   */
+  std::unordered_map<const ExprNode*, Expr> output_substitution_;
+
+  //// Accumulator fields, built as we visit expressions.
+
+  /*! \brief (If opt_attrs_ is defined) Parameters representing input expression nodes. */
+  Array<Var> params_;
+  /*!
+   * \brief (If opt_attrs_ is defined) The input expression nodes for each of the above params_.
+   */
+  Array<Expr> args_;
+  /*!
+   * \brief (If opt_attrs_ is defined) Map from existing input expression nodes to the parameters
+   * in params_ which now representing them.
+   */
+  std::unordered_map<const ExprNode*, Var> expr_to_param_;
+  /*!
+   * \brief Accumulated new expressions which represent the exit nodes of the rewritten sub-graph.
+   * It is possible to have multiple outputs. It is possible one output also contributes to other
+   * outputs (ie the output is a 'tap').
+   */
+  std::vector<Expr> outputs_;
+  /*! \brief (If opt_attrs_ is defined) Types of original expressions corresponding to outputs_. */
+  std::vector<Type> output_types_;
+  /*!
+   * \brief Map from existing exit expression nodes to the index in outputs_ which should
+   * represent them in the rewritten overall expression.
+   */
+  std::unordered_map<const ExprNode*, int> expr_to_output_index_;
+};
+
+Expr Rewriter::VisitExpr(const Expr& expr) {
+  auto itr = extractor_->output_substitution().find(expr.get());
+  if (itr == extractor_->output_substitution().end()) {
+    return ExprMutator::VisitExpr(expr);
+  } else {
+    return itr->second;
+  }
+}
+
+}  // namespace
+
+std::pair<OpPatternKind, std::string> SubExprKindAndLabel(const Expr& sub_expr) {
+  class Visitor : public ExprFunctor<std::pair<OpPatternKind, std::string>(const Expr&)> {
+   private:
+    std::pair<OpPatternKind, std::string> VisitExpr_(const CallNode* call_node) final {
+      if (const auto* op_node = call_node->op.as<OpNode>()) {
+        auto op = GetRef<Op>(op_node);
+        static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
+        if (fpattern.count(op) == 0) {
+          VLOG(1) << "no TOpPattern known for " << op->name << ", considering opaque";
+          return {kOpaque, op->name};
+        } else if (IsDynamic(call_node->checked_type()) && IsDataDependent(call_node)) {
+          VLOG(1) << "call has dynamic shape which is data-dependent, considering opaque";
+          return {kOpaque, op->name};
+        } else {
+          OpPatternKind kind = static_cast<OpPatternKind>(fpattern[op]);
+          VLOG(2) << "TOpPattern for " << op->name << " is " << KindToString(kind);
+          return {kind, op->name};
+        }
+      } else if (const auto* function_node = call_node->op.as<FunctionNode>()) {
+        Optional<Integer> opt_i =
+            function_node->GetAttr<Integer>("TOpPattern", Optional<Integer>());
+        if (opt_i.defined()) {
+          OpPatternKind kind = static_cast<OpPatternKind>(opt_i.value()->value);
+          VLOG(1) << "TOpPattern for function is " << KindToString(kind);
+          return {kind, "call_prim"};
+        } else {
+          VLOG(1) << "calling function without TOpPattern, considering opaque";
+          return {kOpaque, "call_fun"};
+        }
+      } else {
+        VLOG(1) << "unsupported call, considering opaque";
+        return {kOpaque, "call_any"};
+      }
+    }
+
+    std::pair<OpPatternKind, std::string> VisitExpr_(const ConstantNode* constant_node) final {
+      VLOG(2) << "TOpPattern for constant is " << KindToString(kElemWise);
+      if (support::IsSimpleScalar(constant_node)) {
+        return {kElemWise, "scalar"};
+      } else {
+        return {kElemWise, "const"};
+      }
+    }
+
+    std::pair<OpPatternKind, std::string> VisitExpr_(const TupleNode* tuple_node) final {
+      const auto* tuple_type_node = tuple_node->checked_type().as<TupleTypeNode>();
+      ICHECK(tuple_type_node != nullptr);
+      if (std::all_of(tuple_type_node->fields.begin(), tuple_type_node->fields.end(),
+                      [](const Type& type) { return type.as<TensorTypeNode>() != nullptr; })) {
+        VLOG(2) << "TOpPattern for tuple is " << KindToString(kInjective);
+        return {kInjective, "tuple"};
+      } else {
+        VLOG(1) << "tuple contains non-tensors, considering opaque";
+        return {kOpaque, "tuple"};
+      }
+    }
+
+    std::pair<OpPatternKind, std::string> VisitExpr_(
+        const TupleGetItemNode* tuple_get_item_node) final {
+      const auto* tuple_type_node = tuple_get_item_node->tuple->checked_type().as<TupleTypeNode>();
+      ICHECK(tuple_type_node != nullptr);
+      if (std::all_of(tuple_type_node->fields.begin(), tuple_type_node->fields.end(),
+                      [](const Type& type) { return type.as<TensorTypeNode>() != nullptr; })) {
+        VLOG(2) << "TOpPattern for tuple projection is " << KindToString(kInjective);
+        return {kInjective, "proj"};
+      } else {
+        VLOG(1) << "tuple being projected contains non-tensors, considering opaque";
+        return {kOpaque, "proj"};
+      }
+    }
+
+    // TODO(mbs): We implement the following mostly so we have a lightweight way of describing
+    // the current sub-expression. If partitioning is ever extended beyond the usual call/tuple/proj
+    // sub-language we should revise the returned operator kinds to match.
+
+    std::pair<OpPatternKind, std::string> VisitExpr_(const VarNode* var_node) final {
+      return {kOpaque, "%" + var_node->name_hint()};
+    }
+    std::pair<OpPatternKind, std::string> VisitExpr_(const GlobalVarNode* global_var_node) final {
+      return {kOpaque, "@" + global_var_node->name_hint};
+    }
+    std::pair<OpPatternKind, std::string> VisitExpr_(const OpNode* op_node) final {
+      return {kOpaque, "`" + op_node->name};
+    }
+    std::pair<OpPatternKind, std::string> VisitExpr_(const FunctionNode* function_node) final {
+      return {kOpaque, "fn"};
+    }
+    std::pair<OpPatternKind, std::string> VisitExpr_(const LetNode* let_node) final {
+      return {kOpaque, "let"};
+    }
+    std::pair<OpPatternKind, std::string> VisitExpr_(const IfNode* if_node) final {
+      return {kOpaque, "if"};
+    }
+    std::pair<OpPatternKind, std::string> VisitExpr_(const RefCreateNode* ref_create_node) final {
+      return {kOpaque, "ref"};
+    }
+    std::pair<OpPatternKind, std::string> VisitExpr_(const RefReadNode* op) final {
+      return {kOpaque, "ref_read"};
+    }
+    std::pair<OpPatternKind, std::string> VisitExpr_(const RefWriteNode* op) final {
+      return {kOpaque, "ref_write"};
+    }
+    std::pair<OpPatternKind, std::string> VisitExpr_(const ConstructorNode* op) final {
+      return {kOpaque, "`" + op->name_hint};
+    }
+    std::pair<OpPatternKind, std::string> VisitExpr_(const MatchNode* op) final {
+      return {kOpaque, "match"};
+    }
+  };
+  return Visitor().VisitExpr(sub_expr);
+}
+
+std::pair<OpPatternKind, std::string> SubGraphKindAndLabel(const DataflowGraph& dataflow_graph,
+                                                           const IndexSet& inside) {
+  std::ostringstream os;
+  bool first = true;
+  OpPatternKind max_kind = kElemWise;
+  for (PostDfsIndex index : inside) {
+    OpPatternKind sub_kind;
+    std::string sub_label;
+    std::tie(sub_kind, sub_label) = SubExprKindAndLabel(dataflow_graph.index_to_node(index)->ref());
+    if (!sub_label.empty()) {
+      if (first) {
+        first = false;
+      } else {
+        os << "+";
+      }
+      os << sub_label;
+    }
+    max_kind = CombineKinds(max_kind, sub_kind);
+  }
+  return {max_kind, os.str()};
+}
+
+IndexSet MatcherToIndexSet(const DFPatternMatcher& matcher) {
+  IndexSet result(matcher.size());
+  for (const auto& kv : matcher.memo()) {
+    for (const auto& matched_sub_expr : kv.second) {
+      if (CanInline(matched_sub_expr)) {
+        // Trivial sub-expressions can just be included in the extracted function body
+        // when we construct it and don't need to be considered part of the sub-graph.
+        continue;
+      }
+      if (kv.first.as<WildcardPatternNode>()) {
+        // Don't consider the expressions matched by a wildcard to be part of the sub-graph.
+        continue;
+      }
+      result.Add(matcher.expr_to_node(matched_sub_expr)->index_);
+    }
+  }
+  return result;
+}
+
+std::string SubGraphConfig::ToString() const {
+  std::ostringstream os;
+  os << "{max_exits=" << max_exits;
+  os << ",allow_taps=" << allow_taps;
+  os << ",max_max_depth=" << max_max_depth;
+  os << "}";
+  return os.str();
+}
+
+TVM_REGISTER_NODE_TYPE(SubSubGraphNode);
+
+void SubSubGraphNode::VisitAttrs(AttrVisitor* v) {
+  // TODO(mbs)
+}
+
+SubGraph SubSubGraphNode::sub_graph() const { return Downcast<SubGraph>(sub_graph_obj_); }
+
+bool SubSubGraphNode::operator==(const SubSubGraphNode& that) const {
+  return *sub_graph().get() == *that.sub_graph().get();
+}
+
+bool SubSubGraphNode::operator<(const SubSubGraphNode& that) const {
+  return *sub_graph().get() < *that.sub_graph().get();
+}
+
+size_t SubSubGraphNode::hash() const {
+  size_t h = StructuralHash()(attrs_);
+  h ^= sub_graph()->hash() + 0x9e3779b9 + (h << 6) + (h >> 2);
+  return h;
+}
+
+std::string SubSubGraphNode::ToString() const {
+  std::ostringstream os;
+  os << "{sub_graph=" << sub_graph()->ToString();
+  os << ",attrs=" << PrettyPrint(attrs_);

Review Comment:
   It seems you intentionally omitted the whitespace, so please ignore my comment.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] SebastianBoblest commented on a diff in pull request #11981: [Collage] SubGraphs

Posted by GitBox <gi...@apache.org>.
SebastianBoblest commented on code in PR #11981:
URL: https://github.com/apache/tvm/pull/11981#discussion_r914620736


##########
src/relay/collage/sub_graph.cc:
##########
@@ -0,0 +1,1032 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/collage/sub_graph.cc
+ * \brief Represents a sub-graph of an overall Relay expression.
+ */
+
+#include "./sub_graph.h"
+
+#include <tvm/relay/transform.h>
+
+#include "../../support/scalars.h"
+#include "../transforms/pass_utils.h"
+#include "./utils.h"
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+namespace {
+
+class Extractor;
+
+/*!
+ * \brief Helper class for rewriting expressions to replace a sub-graph according to the
+ * given extractor.
+ */
+class Rewriter : public ExprMutator {
+ public:
+  explicit Rewriter(const Extractor* extractor) : extractor_(extractor) {}
+
+  Expr VisitExpr(const Expr& expr) final;
+
+ private:
+  /*! \brief Already prepared extractor which will guide the rewrite. */
+  const Extractor* extractor_;
+};
+
+/*! \brief Helper class for extracting matched sub-graphs from the overall expression. */
+class Extractor : public ExprMutator {
+ public:
+  Extractor(const DataflowGraph* dataflow_graph, const SubGraphNode* sub_graph,
+            FunctionAttrsMap opt_attrs)
+      : dataflow_graph_(dataflow_graph), sub_graph_(sub_graph), opt_attrs_(std::move(opt_attrs)) {
+    ICHECK_EQ(dataflow_graph_->size(), sub_graph_->overall_size());
+  }
+
+  const DataflowGraph& dataflow_graph() const { return *dataflow_graph_; }
+
+  /*!
+   * \brief Collect the parameters and output expressions for the function representing
+   * the sub-graph.
+   */
+  void Extract() {
+    ICHECK(!sub_graph_->IsEmpty());
+    VLOG(2) << "Extracting " << sub_graph_->ToString();
+    const bool for_function = opt_attrs_.defined();
+
+    //  In reverse dataflow order...
+    for (PostDfsIndex i = dataflow_graph_->size(); i > 0; --i) {
+      PostDfsIndex index = i - 1;
+      if (!sub_graph_->inside_[index]) {
+        // Node is outside sub-graph.
+        continue;
+      }
+      VLOG(2) << "index " << index;
+      auto node = dataflow_graph_->index_to_node(index);
+      if (sub_graph_->exit_[node->index_] || node->is_external_ || memo_.count(node->ref()) == 0) {
+        // This sub-expression is:
+        //  - inside the sub-graph and needed outside the sub-graph. So it must contribute to an
+        //    output (even if we've already visited it while constructing an output from a
+        //    downstream sub-expression).
+        //  - not yet visited, in which case it must still be considered an 'output' so it will
+        //    be evaluated for any possible side effects.
+        Expr output = VisitExpr(GetRef<Expr>(node->node_ref_));
+        VLOG(2) << "index " << index << " added as output:\n"
+                << PrettyPrint(output) << "\nat " << outputs_.size();
+        expr_to_output_index_.emplace(node->node_ref_, outputs_.size());
+        outputs_.emplace_back(std::move(output));
+        output_types_.emplace_back(node->node_ref_->checked_type());
+      }
+    }
+    ICHECK(!outputs_.empty());
+
+    // Reverse the outputs so as to preserve the original evaluation order.
+    std::reverse(outputs_.begin(), outputs_.end());
+    std::reverse(output_types_.begin(), output_types_.end());
+    for (auto& kv : expr_to_output_index_) {
+      kv.second = static_cast<int>(outputs_.size()) - 1 - kv.second;
+    }
+
+    // Build a 'body' expression to represent the extracted sub-graph. If we have multiple
+    // outputs we'll place them in a tuple.
+    Type body_type;
+    Expr body;
+    if (outputs_.size() > 1) {
+      body_type = TupleType(output_types_);
+      body = Tuple(outputs_);
+      body->checked_type_ = body_type;
+    } else {
+      body_type = output_types_.front();
+      body = outputs_.front();
+    }
+
+    // Re-express all the sub-sub-graphs in terms of the body.
+    DataflowGraph body_dataflow_graph(body);
+    std::vector<SubSubGraph> sub_sub_graphs;
+    IndexSubst subst = MakeIndexSubst(body_dataflow_graph);
+    for (const auto& sub_sub_graph : sub_graph_->sub_sub_graphs_) {
+      sub_sub_graphs.emplace_back(sub_sub_graph.Subst(body_dataflow_graph, subst));
+    }
+
+    // Sweep backwards through the body, rewriting to account for each sub-sub-graph.
+    body = SubSubGraph::ParallelRewrite(body_dataflow_graph, body, std::move(sub_sub_graphs));
+
+    if (for_function) {
+      // Rewrite so all input nodes are now conveyed via call arguments to a new function.
+      Array<Type> arg_types;
+      arg_types.reserve(params_.size());
+      for (const auto& param : params_) {
+        arg_types.push_back(param->checked_type());
+      }
+      extracted_ = Function(std::move(params_), std::move(body), body_type,
+                            /*ty_params=*/{}, DictAttrs(opt_attrs_));
+      extracted_->checked_type_ =
+          FuncType(std::move(arg_types), body_type, /*type_params=*/{}, /*type_constraints=*/{});
+      body = Call(extracted_, std::move(args_));
+      body->checked_type_ = body_type;
+    } else {
+      // Don't do anything with the inputs.
+      extracted_ = body;
+    }
+
+    // Setup the output substitution.
+    for (const auto& kv : expr_to_output_index_) {
+      Expr expr;
+      if (outputs_.size() == 1) {
+        expr = body;
+      } else if (for_function) {
+        expr = TupleGetItem(body, kv.second);
+        expr->checked_type_ = output_types_[kv.second];
+      } else {
+        const auto* tuple_node = body.as<TupleNode>();
+        ICHECK(tuple_node);
+        expr = tuple_node->fields[kv.second];
+      }
+      VLOG(2) << "output " << dataflow_graph_->item_to_node(kv.first)->index_ << " is at index "
+              << kv.second << " (of " << outputs_.size() << " outputs)";
+      output_substitution_.emplace(kv.first, std::move(expr));
+    }
+  }
+
+  ////// Following members are valid only after Extract() has returned.
+
+  /*!
+   * \brief Returns the expression representing the extracted sub-graph. If opt_attrs_ is
+   * defined then will be a function.
+   */
+  Expr extracted() const { return extracted_; }
+
+  /*!
+   * \brief Returns the substitution to apply to all expression nodes in the overall expression
+   * so as to replace references to outputs of the sub-graph with their rewritten form.
+   */
+  const std::unordered_map<const ExprNode*, Expr>& output_substitution() const {
+    return output_substitution_;
+  }
+
+ private:
+  /*!
+   * \brief Returns a map from original index to new index for each node inside the sub-graph. Only
+   * valid after \p Extract has made its backwards dataflow sweep.
+   */
+  IndexSubst MakeIndexSubst(const DataflowGraph& new_dataflow_graph) const {
+    VLOG(2) << "building extractor substitution";
+    IndexSubst subst;
+    for (PostDfsIndex index : sub_graph_->inside_) {
+      auto orig_node = dataflow_graph_->index_to_node(index);
+      ICHECK_EQ(orig_node->index_, index);
+      auto itr = memo_.find(orig_node->ref());
+      ICHECK(itr != memo_.end());
+      auto new_node = new_dataflow_graph.item_to_node(itr->second);
+      VLOG(2) << orig_node->index_ << " |-> " << new_node->index_;
+      subst.emplace(orig_node->index_, new_node->index_);
+    }
+    return subst;
+  }
+
+  /*! \brief Returns true if \p expr is inside the sub-graph. */
+  bool inside(const Expr& expr) {
+    return sub_graph_->inside_[dataflow_graph_->item_to_node(expr)->index_];
+  }
+
+  /*!
+   * \brief Returns the variable uniquely representing \p expr, which should be
+   * an input node (ie outside the sub-graph but feeding into a node inside the sub-graph).
+   *
+   * It is valid for:
+   *  - An expression outside the sub-graph to be used multiple times inside the sub-graph.
+   *  - An expression outside the sub-graph to be used both inside and outside the sub-graph.
+   */
+  Var VarFor(const Expr& expr) {
+    ICHECK(!inside(expr));
+    ICHECK(opt_attrs_.defined());
+    auto itr = expr_to_param_.find(expr.get());
+    if (itr != expr_to_param_.end()) {
+      return itr->second;
+    }
+    auto fresh_var = Var("FunctionVar_" + std::to_string(params_.size()), expr->checked_type());
+    fresh_var->checked_type_ = expr->checked_type();
+    params_.push_back(fresh_var);
+    args_.push_back(expr);
+    expr_to_param_.emplace(expr.get(), fresh_var);
+    return fresh_var;
+  }
+
+  /*!
+   * \brief If \p expr is inside the sub-graph then return it's rewritten form.
+   * If \p expr is outside the sub-graph then it must correspond to an input node.
+   *  - If opt_attrs_ is defined return the variable to represent it.
+   *  - Otherwise just return the expression directly.
+   *
+   * Should be called only on inputs to nodes which are inside the sub-graph.
+   */
+  Expr VisitExpr(const Expr& expr) final {
+    if (inside(expr)) {
+      return ExprMutator::VisitExpr(expr);
+    } else if (CanInline(expr)) {
+      // Implicitly include inlinable input sub-expressions.
+      return expr;
+    } else if (opt_attrs_.defined()) {
+      // Map to a function parameter.
+      return VarFor(expr);
+    } else {
+      // Stop rewriting.
+      return expr;
+    }
+  }
+
+  Expr VisitExpr_(const FunctionNode* function_node) override {
+    if (function_node->HasNonzeroAttr(attr::kPrimitive)) {
+      return GetRef<Function>(function_node);
+    }
+    return ExprMutator::VisitExpr_(function_node);
+  }
+
+  //// Context fields, passed in constructor.
+
+  /*! \brief The dataflow graph corresponding to the overall expression. */
+  const DataflowGraph* dataflow_graph_;
+  /*! \brief The sub-graph of the above we are extracting. */
+  const SubGraphNode* sub_graph_;
+  /*! \brief Optional attributes if the sub-graph should be extracted as a function. */
+  FunctionAttrsMap opt_attrs_;
+
+  //// Result fields, available after Extract() called.
+
+  /*!
+   * \brief The extracted expression. If opt_attrs_ is defined this will be a function.
+   */
+  Expr extracted_;
+  /*!
+   * \brief Map from output nodes to corresponding expressions. If the sub-graph has more than
+   * one exit node then each entry will be a tuple projection.
+   */
+  std::unordered_map<const ExprNode*, Expr> output_substitution_;
+
+  //// Accumulator fields, built as we visit expressions.
+
+  /*! \brief (If opt_attrs_ is defined) Parameters representing input expression nodes. */
+  Array<Var> params_;
+  /*!
+   * \brief (If opt_attrs_ is defined) The input expression nodes for each of the above params_.
+   */
+  Array<Expr> args_;
+  /*!
+   * \brief (If opt_attrs_ is defined) Map from existing input expression nodes to the parameters
+   * in params_ which now representing them.
+   */
+  std::unordered_map<const ExprNode*, Var> expr_to_param_;
+  /*!
+   * \brief Accumulated new expressions which represent the exit nodes of the rewritten sub-graph.
+   * It is possible to have multiple outputs. It is possible one output also contributes to other
+   * outputs (ie the output is a 'tap').
+   */
+  std::vector<Expr> outputs_;
+  /*! \brief (If opt_attrs_ is defined) Types of original expressions corresponding to outputs_. */
+  std::vector<Type> output_types_;
+  /*!
+   * \brief Map from existing exit expression nodes to the index in outputs_ which should
+   * represent them in the rewritten overall expression.
+   */
+  std::unordered_map<const ExprNode*, int> expr_to_output_index_;
+};
+
+Expr Rewriter::VisitExpr(const Expr& expr) {
+  auto itr = extractor_->output_substitution().find(expr.get());
+  if (itr == extractor_->output_substitution().end()) {
+    return ExprMutator::VisitExpr(expr);
+  } else {
+    return itr->second;
+  }
+}
+
+}  // namespace
+
+std::pair<OpPatternKind, std::string> SubExprKindAndLabel(const Expr& sub_expr) {
+  class Visitor : public ExprFunctor<std::pair<OpPatternKind, std::string>(const Expr&)> {
+   private:
+    std::pair<OpPatternKind, std::string> VisitExpr_(const CallNode* call_node) final {
+      if (const auto* op_node = call_node->op.as<OpNode>()) {
+        auto op = GetRef<Op>(op_node);
+        static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
+        if (fpattern.count(op) == 0) {
+          VLOG(1) << "no TOpPattern known for " << op->name << ", considering opaque";
+          return {kOpaque, op->name};
+        } else if (IsDynamic(call_node->checked_type()) && IsDataDependent(call_node)) {
+          VLOG(1) << "call has dynamic shape which is data-dependent, considering opaque";
+          return {kOpaque, op->name};
+        } else {
+          OpPatternKind kind = static_cast<OpPatternKind>(fpattern[op]);
+          VLOG(2) << "TOpPattern for " << op->name << " is " << KindToString(kind);
+          return {kind, op->name};
+        }
+      } else if (const auto* function_node = call_node->op.as<FunctionNode>()) {
+        Optional<Integer> opt_i =
+            function_node->GetAttr<Integer>("TOpPattern", Optional<Integer>());
+        if (opt_i.defined()) {
+          OpPatternKind kind = static_cast<OpPatternKind>(opt_i.value()->value);
+          VLOG(1) << "TOpPattern for function is " << KindToString(kind);
+          return {kind, "call_prim"};
+        } else {
+          VLOG(1) << "calling function without TOpPattern, considering opaque";
+          return {kOpaque, "call_fun"};
+        }
+      } else {
+        VLOG(1) << "unsupported call, considering opaque";
+        return {kOpaque, "call_any"};
+      }
+    }
+
+    std::pair<OpPatternKind, std::string> VisitExpr_(const ConstantNode* constant_node) final {
+      VLOG(2) << "TOpPattern for constant is " << KindToString(kElemWise);
+      if (support::IsSimpleScalar(constant_node)) {
+        return {kElemWise, "scalar"};
+      } else {
+        return {kElemWise, "const"};
+      }
+    }
+
+    std::pair<OpPatternKind, std::string> VisitExpr_(const TupleNode* tuple_node) final {
+      const auto* tuple_type_node = tuple_node->checked_type().as<TupleTypeNode>();
+      ICHECK(tuple_type_node != nullptr);
+      if (std::all_of(tuple_type_node->fields.begin(), tuple_type_node->fields.end(),
+                      [](const Type& type) { return type.as<TensorTypeNode>() != nullptr; })) {
+        VLOG(2) << "TOpPattern for tuple is " << KindToString(kInjective);
+        return {kInjective, "tuple"};
+      } else {
+        VLOG(1) << "tuple contains non-tensors, considering opaque";
+        return {kOpaque, "tuple"};
+      }
+    }
+
+    std::pair<OpPatternKind, std::string> VisitExpr_(
+        const TupleGetItemNode* tuple_get_item_node) final {
+      const auto* tuple_type_node = tuple_get_item_node->tuple->checked_type().as<TupleTypeNode>();
+      ICHECK(tuple_type_node != nullptr);
+      if (std::all_of(tuple_type_node->fields.begin(), tuple_type_node->fields.end(),
+                      [](const Type& type) { return type.as<TensorTypeNode>() != nullptr; })) {
+        VLOG(2) << "TOpPattern for tuple projection is " << KindToString(kInjective);
+        return {kInjective, "proj"};
+      } else {
+        VLOG(1) << "tuple being projected contains non-tensors, considering opaque";
+        return {kOpaque, "proj"};
+      }
+    }
+
+    // TODO(mbs): We implement the following mostly so we have a lightweight way of describing
+    // the current sub-expression. If partitioning is ever extended beyond the usual call/tuple/proj
+    // sub-language we should revise the returned operator kinds to match.
+
+    std::pair<OpPatternKind, std::string> VisitExpr_(const VarNode* var_node) final {
+      return {kOpaque, "%" + var_node->name_hint()};
+    }
+    std::pair<OpPatternKind, std::string> VisitExpr_(const GlobalVarNode* global_var_node) final {
+      return {kOpaque, "@" + global_var_node->name_hint};
+    }
+    std::pair<OpPatternKind, std::string> VisitExpr_(const OpNode* op_node) final {
+      return {kOpaque, "`" + op_node->name};
+    }
+    std::pair<OpPatternKind, std::string> VisitExpr_(const FunctionNode* function_node) final {
+      return {kOpaque, "fn"};
+    }
+    std::pair<OpPatternKind, std::string> VisitExpr_(const LetNode* let_node) final {
+      return {kOpaque, "let"};
+    }
+    std::pair<OpPatternKind, std::string> VisitExpr_(const IfNode* if_node) final {
+      return {kOpaque, "if"};
+    }
+    std::pair<OpPatternKind, std::string> VisitExpr_(const RefCreateNode* ref_create_node) final {
+      return {kOpaque, "ref"};
+    }
+    std::pair<OpPatternKind, std::string> VisitExpr_(const RefReadNode* op) final {
+      return {kOpaque, "ref_read"};
+    }
+    std::pair<OpPatternKind, std::string> VisitExpr_(const RefWriteNode* op) final {
+      return {kOpaque, "ref_write"};
+    }
+    std::pair<OpPatternKind, std::string> VisitExpr_(const ConstructorNode* op) final {
+      return {kOpaque, "`" + op->name_hint};
+    }
+    std::pair<OpPatternKind, std::string> VisitExpr_(const MatchNode* op) final {
+      return {kOpaque, "match"};
+    }
+  };
+  return Visitor().VisitExpr(sub_expr);
+}
+
+std::pair<OpPatternKind, std::string> SubGraphKindAndLabel(const DataflowGraph& dataflow_graph,
+                                                           const IndexSet& inside) {
+  std::ostringstream os;
+  bool first = true;
+  OpPatternKind max_kind = kElemWise;
+  for (PostDfsIndex index : inside) {
+    OpPatternKind sub_kind;
+    std::string sub_label;
+    std::tie(sub_kind, sub_label) = SubExprKindAndLabel(dataflow_graph.index_to_node(index)->ref());
+    if (!sub_label.empty()) {
+      if (first) {
+        first = false;
+      } else {
+        os << "+";
+      }
+      os << sub_label;
+    }
+    max_kind = CombineKinds(max_kind, sub_kind);
+  }
+  return {max_kind, os.str()};
+}
+
+IndexSet MatcherToIndexSet(const DFPatternMatcher& matcher) {
+  IndexSet result(matcher.size());
+  for (const auto& kv : matcher.memo()) {
+    for (const auto& matched_sub_expr : kv.second) {
+      if (CanInline(matched_sub_expr)) {
+        // Trivial sub-expressions can just be included in the extracted function body
+        // when we construct it and don't need to be considered part of the sub-graph.
+        continue;
+      }
+      if (kv.first.as<WildcardPatternNode>()) {
+        // Don't consider the expressions matched by a wildcard to be part of the sub-graph.
+        continue;
+      }
+      result.Add(matcher.expr_to_node(matched_sub_expr)->index_);
+    }
+  }
+  return result;
+}
+
+std::string SubGraphConfig::ToString() const {
+  std::ostringstream os;
+  os << "{max_exits=" << max_exits;
+  os << ",allow_taps=" << allow_taps;

Review Comment:
   ```suggestion
     os << ", allow_taps=" << allow_taps;
   ```



##########
src/relay/collage/sub_graph.cc:
##########
@@ -0,0 +1,1032 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/collage/sub_graph.cc
+ * \brief Represents a sub-graph of an overall Relay expression.
+ */
+
+#include "./sub_graph.h"
+
+#include <tvm/relay/transform.h>
+
+#include "../../support/scalars.h"
+#include "../transforms/pass_utils.h"
+#include "./utils.h"
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+namespace {
+
+class Extractor;
+
+/*!
+ * \brief Helper class for rewriting expressions to replace a sub-graph according to the
+ * given extractor.
+ */
+class Rewriter : public ExprMutator {
+ public:
+  explicit Rewriter(const Extractor* extractor) : extractor_(extractor) {}
+
+  Expr VisitExpr(const Expr& expr) final;
+
+ private:
+  /*! \brief Already prepared extractor which will guide the rewrite. */
+  const Extractor* extractor_;
+};
+
+/*! \brief Helper class for extracting matched sub-graphs from the overall expression. */
+class Extractor : public ExprMutator {
+ public:
+  Extractor(const DataflowGraph* dataflow_graph, const SubGraphNode* sub_graph,
+            FunctionAttrsMap opt_attrs)
+      : dataflow_graph_(dataflow_graph), sub_graph_(sub_graph), opt_attrs_(std::move(opt_attrs)) {
+    ICHECK_EQ(dataflow_graph_->size(), sub_graph_->overall_size());
+  }
+
+  const DataflowGraph& dataflow_graph() const { return *dataflow_graph_; }
+
+  /*!
+   * \brief Collect the parameters and output expressions for the function representing
+   * the sub-graph.
+   */
+  void Extract() {
+    ICHECK(!sub_graph_->IsEmpty());
+    VLOG(2) << "Extracting " << sub_graph_->ToString();
+    const bool for_function = opt_attrs_.defined();
+
+    //  In reverse dataflow order...
+    for (PostDfsIndex i = dataflow_graph_->size(); i > 0; --i) {
+      PostDfsIndex index = i - 1;
+      if (!sub_graph_->inside_[index]) {
+        // Node is outside sub-graph.
+        continue;
+      }
+      VLOG(2) << "index " << index;
+      auto node = dataflow_graph_->index_to_node(index);
+      if (sub_graph_->exit_[node->index_] || node->is_external_ || memo_.count(node->ref()) == 0) {
+        // This sub-expression is:
+        //  - inside the sub-graph and needed outside the sub-graph. So it must contribute to an
+        //    output (even if we've already visited it while constructing an output from a
+        //    downstream sub-expression).
+        //  - not yet visited, in which case it must still be considered an 'output' so it will
+        //    be evaluated for any possible side effects.
+        Expr output = VisitExpr(GetRef<Expr>(node->node_ref_));
+        VLOG(2) << "index " << index << " added as output:\n"
+                << PrettyPrint(output) << "\nat " << outputs_.size();
+        expr_to_output_index_.emplace(node->node_ref_, outputs_.size());
+        outputs_.emplace_back(std::move(output));
+        output_types_.emplace_back(node->node_ref_->checked_type());
+      }
+    }
+    ICHECK(!outputs_.empty());
+
+    // Reverse the outputs so as to preserve the original evaluation order.
+    std::reverse(outputs_.begin(), outputs_.end());
+    std::reverse(output_types_.begin(), output_types_.end());
+    for (auto& kv : expr_to_output_index_) {
+      kv.second = static_cast<int>(outputs_.size()) - 1 - kv.second;
+    }
+
+    // Build a 'body' expression to represent the extracted sub-graph. If we have multiple
+    // outputs we'll place them in a tuple.
+    Type body_type;
+    Expr body;
+    if (outputs_.size() > 1) {
+      body_type = TupleType(output_types_);
+      body = Tuple(outputs_);
+      body->checked_type_ = body_type;
+    } else {
+      body_type = output_types_.front();
+      body = outputs_.front();
+    }
+
+    // Re-express all the sub-sub-graphs in terms of the body.
+    DataflowGraph body_dataflow_graph(body);
+    std::vector<SubSubGraph> sub_sub_graphs;
+    IndexSubst subst = MakeIndexSubst(body_dataflow_graph);
+    for (const auto& sub_sub_graph : sub_graph_->sub_sub_graphs_) {
+      sub_sub_graphs.emplace_back(sub_sub_graph.Subst(body_dataflow_graph, subst));
+    }
+
+    // Sweep backwards through the body, rewriting to account for each sub-sub-graph.
+    body = SubSubGraph::ParallelRewrite(body_dataflow_graph, body, std::move(sub_sub_graphs));
+
+    if (for_function) {
+      // Rewrite so all input nodes are now conveyed via call arguments to a new function.
+      Array<Type> arg_types;
+      arg_types.reserve(params_.size());
+      for (const auto& param : params_) {
+        arg_types.push_back(param->checked_type());
+      }
+      extracted_ = Function(std::move(params_), std::move(body), body_type,
+                            /*ty_params=*/{}, DictAttrs(opt_attrs_));
+      extracted_->checked_type_ =
+          FuncType(std::move(arg_types), body_type, /*type_params=*/{}, /*type_constraints=*/{});
+      body = Call(extracted_, std::move(args_));
+      body->checked_type_ = body_type;
+    } else {
+      // Don't do anything with the inputs.
+      extracted_ = body;
+    }
+
+    // Setup the output substitution.
+    for (const auto& kv : expr_to_output_index_) {
+      Expr expr;
+      if (outputs_.size() == 1) {
+        expr = body;
+      } else if (for_function) {
+        expr = TupleGetItem(body, kv.second);
+        expr->checked_type_ = output_types_[kv.second];
+      } else {
+        const auto* tuple_node = body.as<TupleNode>();
+        ICHECK(tuple_node);
+        expr = tuple_node->fields[kv.second];
+      }
+      VLOG(2) << "output " << dataflow_graph_->item_to_node(kv.first)->index_ << " is at index "
+              << kv.second << " (of " << outputs_.size() << " outputs)";
+      output_substitution_.emplace(kv.first, std::move(expr));
+    }
+  }
+
+  ////// Following members are valid only after Extract() has returned.
+
+  /*!
+   * \brief Returns the expression representing the extracted sub-graph. If opt_attrs_ is
+   * defined then will be a function.
+   */
+  Expr extracted() const { return extracted_; }
+
+  /*!
+   * \brief Returns the substitution to apply to all expression nodes in the overall expression
+   * so as to replace references to outputs of the sub-graph with their rewritten form.
+   */
+  const std::unordered_map<const ExprNode*, Expr>& output_substitution() const {
+    return output_substitution_;
+  }
+
+ private:
+  /*!
+   * \brief Builds the map from each inside node's index in the original dataflow graph to the
+   * index of its rewritten counterpart in \p new_dataflow_graph. Relies on memo_ (inherited from
+   * ExprMutator) having been populated by the backwards sweep in \p Extract.
+   */
+  IndexSubst MakeIndexSubst(const DataflowGraph& new_dataflow_graph) const {
+    VLOG(2) << "building extractor substitution";
+    IndexSubst index_subst;
+    for (PostDfsIndex old_index : sub_graph_->inside_) {
+      const auto old_node = dataflow_graph_->index_to_node(old_index);
+      ICHECK_EQ(old_node->index_, old_index);
+      const auto memo_itr = memo_.find(old_node->ref());
+      ICHECK(memo_itr != memo_.end());
+      const auto rewritten_node = new_dataflow_graph.item_to_node(memo_itr->second);
+      VLOG(2) << old_node->index_ << " |-> " << rewritten_node->index_;
+      index_subst.emplace(old_node->index_, rewritten_node->index_);
+    }
+    return index_subst;
+  }
+
+  /*! \brief Returns whether the dataflow node for \p expr lies within the sub-graph. */
+  bool inside(const Expr& expr) {
+    const auto index = dataflow_graph_->item_to_node(expr)->index_;
+    return sub_graph_->inside_[index];
+  }
+
+  /*!
+   * \brief Returns the function parameter standing for the input node \p expr (outside the
+   * sub-graph but consumed by a node inside it), creating it on first use.
+   *
+   * The same input may be consumed several times inside the sub-graph, and may also be consumed
+   * outside the sub-graph; in every case a single parameter (and a single call argument) is used.
+   */
+  Var VarFor(const Expr& expr) {
+    ICHECK(!inside(expr));
+    ICHECK(opt_attrs_.defined());
+    const auto existing = expr_to_param_.find(expr.get());
+    if (existing == expr_to_param_.end()) {
+      // Parameters are numbered in order of first use.
+      Var param("FunctionVar_" + std::to_string(params_.size()), expr->checked_type());
+      param->checked_type_ = expr->checked_type();
+      params_.push_back(param);
+      args_.push_back(expr);
+      expr_to_param_.emplace(expr.get(), param);
+      return param;
+    }
+    return existing->second;
+  }
+
+  /*!
+   * \brief Rewrites \p expr, which should be an input to some node inside the sub-graph:
+   *  - nodes inside the sub-graph are rewritten recursively via ExprMutator;
+   *  - inlinable nodes outside the sub-graph are returned as-is, so they are included directly
+   *    in the extracted body;
+   *  - any other outside node becomes a function parameter when opt_attrs_ is defined, and is
+   *    otherwise returned unchanged.
+   */
+  Expr VisitExpr(const Expr& expr) final {
+    if (inside(expr)) {
+      return ExprMutator::VisitExpr(expr);
+    }
+    if (!CanInline(expr) && opt_attrs_.defined()) {
+      // A genuine input to the function: route it through a parameter.
+      return VarFor(expr);
+    }
+    // Either inlinable, or we are not building a function: stop rewriting here.
+    return expr;
+  }
+
+  Expr VisitExpr_(const FunctionNode* function_node) override {
+    // Functions marked primitive are returned unchanged rather than descended into.
+    if (!function_node->HasNonzeroAttr(attr::kPrimitive)) {
+      return ExprMutator::VisitExpr_(function_node);
+    }
+    return GetRef<Function>(function_node);
+  }
+
+  //// Context fields, passed in constructor.
+
+  /*! \brief The dataflow graph corresponding to the overall expression. */
+  const DataflowGraph* dataflow_graph_;
+  /*! \brief The sub-graph of the above we are extracting. */
+  const SubGraphNode* sub_graph_;
+  /*! \brief Optional attributes if the sub-graph should be extracted as a function. */
+  FunctionAttrsMap opt_attrs_;
+
+  //// Result fields, available after Extract() called.
+
+  /*!
+   * \brief The extracted expression. If opt_attrs_ is defined this will be a function.
+   */
+  Expr extracted_;
+  /*!
+   * \brief Map from output nodes to corresponding expressions. If the sub-graph has more than
+   * one exit node then each entry will be a tuple projection.
+   */
+  std::unordered_map<const ExprNode*, Expr> output_substitution_;
+
+  //// Accumulator fields, built as we visit expressions.
+
+  /*! \brief (If opt_attrs_ is defined) Parameters representing input expression nodes. */
+  Array<Var> params_;
+  /*!
+   * \brief (If opt_attrs_ is defined) The input expression nodes for each of the above params_.
+   */
+  Array<Expr> args_;
+  /*!
+   * \brief (If opt_attrs_ is defined) Map from existing input expression nodes to the parameters
+   * in params_ which now representing them.
+   */
+  std::unordered_map<const ExprNode*, Var> expr_to_param_;
+  /*!
+   * \brief Accumulated new expressions which represent the exit nodes of the rewritten sub-graph.
+   * It is possible to have multiple outputs. It is possible one output also contributes to other
+   * outputs (ie the output is a 'tap').
+   */
+  std::vector<Expr> outputs_;
+  /*! \brief (If opt_attrs_ is defined) Types of original expressions corresponding to outputs_. */
+  std::vector<Type> output_types_;
+  /*!
+   * \brief Map from existing exit expression nodes to the index in outputs_ which should
+   * represent them in the rewritten overall expression.
+   */
+  std::unordered_map<const ExprNode*, int> expr_to_output_index_;
+};
+
+Expr Rewriter::VisitExpr(const Expr& expr) {
+  // Outputs of the extracted sub-graph are replaced wholesale; everything else is rewritten
+  // recursively so that nested references to those outputs are also replaced.
+  const auto& substitution = extractor_->output_substitution();
+  const auto found = substitution.find(expr.get());
+  if (found != substitution.end()) {
+    return found->second;
+  }
+  return ExprMutator::VisitExpr(expr);
+}
+
+}  // namespace
+
+std::pair<OpPatternKind, std::string> SubExprKindAndLabel(const Expr& sub_expr) {
+  // An operator pattern kind paired with a short human-readable label for the sub-expression.
+  using KindAndLabel = std::pair<OpPatternKind, std::string>;
+
+  /*!
+   * \brief Classifies a single sub-expression. Calls, constants, tuples and projections get
+   * meaningful kinds; every other node is reported as opaque with a descriptive label.
+   */
+  class Visitor : public ExprFunctor<KindAndLabel(const Expr&)> {
+   private:
+    /*! \brief True if every field of \p tuple_type_node is a tensor type. */
+    static bool AllFieldsAreTensors(const TupleTypeNode* tuple_type_node) {
+      return std::all_of(tuple_type_node->fields.begin(), tuple_type_node->fields.end(),
+                         [](const Type& type) { return type.as<TensorTypeNode>() != nullptr; });
+    }
+
+    KindAndLabel VisitExpr_(const CallNode* call_node) final {
+      if (const auto* op_node = call_node->op.as<OpNode>()) {
+        // Operator calls take their kind from the operator's TOpPattern attribute, unless the
+        // result shape is dynamic and data-dependent.
+        const Op op = GetRef<Op>(op_node);
+        static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
+        if (fpattern.count(op) == 0) {
+          VLOG(1) << "no TOpPattern known for " << op->name << ", considering opaque";
+          return {kOpaque, op->name};
+        }
+        if (IsDynamic(call_node->checked_type()) && IsDataDependent(call_node)) {
+          VLOG(1) << "call has dynamic shape which is data-dependent, considering opaque";
+          return {kOpaque, op->name};
+        }
+        const auto kind = static_cast<OpPatternKind>(fpattern[op]);
+        VLOG(2) << "TOpPattern for " << op->name << " is " << KindToString(kind);
+        return {kind, op->name};
+      }
+      if (const auto* function_node = call_node->op.as<FunctionNode>()) {
+        // Calls to function literals use the function's own "TOpPattern" attribute, if any.
+        const Optional<Integer> opt_pattern =
+            function_node->GetAttr<Integer>("TOpPattern", Optional<Integer>());
+        if (!opt_pattern.defined()) {
+          VLOG(1) << "calling function without TOpPattern, considering opaque";
+          return {kOpaque, "call_fun"};
+        }
+        const auto kind = static_cast<OpPatternKind>(opt_pattern.value()->value);
+        VLOG(1) << "TOpPattern for function is " << KindToString(kind);
+        return {kind, "call_prim"};
+      }
+      VLOG(1) << "unsupported call, considering opaque";
+      return {kOpaque, "call_any"};
+    }
+
+    KindAndLabel VisitExpr_(const ConstantNode* constant_node) final {
+      VLOG(2) << "TOpPattern for constant is " << KindToString(kElemWise);
+      return {kElemWise, support::IsSimpleScalar(constant_node) ? "scalar" : "const"};
+    }
+
+    KindAndLabel VisitExpr_(const TupleNode* tuple_node) final {
+      const auto* tuple_type_node = tuple_node->checked_type().as<TupleTypeNode>();
+      ICHECK(tuple_type_node != nullptr);
+      if (!AllFieldsAreTensors(tuple_type_node)) {
+        VLOG(1) << "tuple contains non-tensors, considering opaque";
+        return {kOpaque, "tuple"};
+      }
+      VLOG(2) << "TOpPattern for tuple is " << KindToString(kInjective);
+      return {kInjective, "tuple"};
+    }
+
+    KindAndLabel VisitExpr_(const TupleGetItemNode* tuple_get_item_node) final {
+      const auto* tuple_type_node = tuple_get_item_node->tuple->checked_type().as<TupleTypeNode>();
+      ICHECK(tuple_type_node != nullptr);
+      if (!AllFieldsAreTensors(tuple_type_node)) {
+        VLOG(1) << "tuple being projected contains non-tensors, considering opaque";
+        return {kOpaque, "proj"};
+      }
+      VLOG(2) << "TOpPattern for tuple projection is " << KindToString(kInjective);
+      return {kInjective, "proj"};
+    }
+
+    // TODO(mbs): We implement the following mostly so we have a lightweight way of describing
+    // the current sub-expression. If partitioning is ever extended beyond the usual call/tuple/proj
+    // sub-language we should revise the returned operator kinds to match.
+
+    KindAndLabel VisitExpr_(const VarNode* var_node) final {
+      return {kOpaque, "%" + var_node->name_hint()};
+    }
+    KindAndLabel VisitExpr_(const GlobalVarNode* global_var_node) final {
+      return {kOpaque, "@" + global_var_node->name_hint};
+    }
+    KindAndLabel VisitExpr_(const OpNode* op_node) final { return {kOpaque, "`" + op_node->name}; }
+    KindAndLabel VisitExpr_(const FunctionNode* function_node) final { return {kOpaque, "fn"}; }
+    KindAndLabel VisitExpr_(const LetNode* let_node) final { return {kOpaque, "let"}; }
+    KindAndLabel VisitExpr_(const IfNode* if_node) final { return {kOpaque, "if"}; }
+    KindAndLabel VisitExpr_(const RefCreateNode* ref_create_node) final {
+      return {kOpaque, "ref"};
+    }
+    KindAndLabel VisitExpr_(const RefReadNode* op) final { return {kOpaque, "ref_read"}; }
+    KindAndLabel VisitExpr_(const RefWriteNode* op) final { return {kOpaque, "ref_write"}; }
+    KindAndLabel VisitExpr_(const ConstructorNode* op) final {
+      return {kOpaque, "`" + op->name_hint};
+    }
+    KindAndLabel VisitExpr_(const MatchNode* op) final { return {kOpaque, "match"}; }
+  };
+
+  Visitor visitor;
+  return visitor.VisitExpr(sub_expr);
+}
+
+std::pair<OpPatternKind, std::string> SubGraphKindAndLabel(const DataflowGraph& dataflow_graph,
+                                                           const IndexSet& inside) {
+  // The overall kind is the combination of every inside node's kind (starting from kElemWise);
+  // the label joins the non-empty per-node labels with '+'.
+  std::ostringstream label;
+  const char* separator = "";
+  OpPatternKind kind = kElemWise;
+  for (PostDfsIndex index : inside) {
+    const auto sub = SubExprKindAndLabel(dataflow_graph.index_to_node(index)->ref());
+    if (!sub.second.empty()) {
+      label << separator << sub.second;
+      separator = "+";
+    }
+    kind = CombineKinds(kind, sub.first);
+  }
+  return {kind, label.str()};
+}
+
+IndexSet MatcherToIndexSet(const DFPatternMatcher& matcher) {
+  // Collects the dataflow indices of every sub-expression the matcher bound to a pattern, except:
+  //  - trivially inlinable sub-expressions, which are simply included in the extracted function
+  //    body when it is built and so need not be part of the sub-graph;
+  //  - sub-expressions bound by a wildcard pattern.
+  IndexSet matched(matcher.size());
+  for (const auto& entry : matcher.memo()) {
+    const bool is_wildcard = entry.first.as<WildcardPatternNode>() != nullptr;
+    for (const auto& sub_expr : entry.second) {
+      if (CanInline(sub_expr) || is_wildcard) {
+        continue;
+      }
+      matched.Add(matcher.expr_to_node(sub_expr)->index_);
+    }
+  }
+  return matched;
+}
+
+std::string SubGraphConfig::ToString() const {
+  std::ostringstream os;
+  os << "{max_exits=" << max_exits;
+  os << ",allow_taps=" << allow_taps;
+  os << ",max_max_depth=" << max_max_depth;

Review Comment:
   ```suggestion
     os << ", max_max_depth=" << max_max_depth;
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] SebastianBoblest commented on a diff in pull request #11981: [Collage] SubGraphs

Posted by GitBox <gi...@apache.org>.
SebastianBoblest commented on code in PR #11981:
URL: https://github.com/apache/tvm/pull/11981#discussion_r914631285


##########
src/relay/collage/sub_graph.cc:
##########
@@ -0,0 +1,1032 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/collage/sub_graph.cc
+ * \brief Represents a sub-graph of an overall Relay expression.
+ */
+
+#include "./sub_graph.h"
+
+#include <tvm/relay/transform.h>
+
+#include "../../support/scalars.h"
+#include "../transforms/pass_utils.h"
+#include "./utils.h"
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+namespace {
+
+// Forward declaration: Rewriter (below) holds a pointer to an Extractor, which is defined after it.
+class Extractor;
+
+/*!
+ * \brief Rewrites an overall expression so that every reference to an output of an extracted
+ * sub-graph is replaced by that output's rewritten form, as recorded by an Extractor.
+ */
+class Rewriter : public ExprMutator {
+ public:
+  /*! \brief \p extractor is borrowed (not owned) and should already have run Extract(). */
+  explicit Rewriter(const Extractor* extractor) : extractor_{extractor} {}
+
+  /*! \brief Substitutes sub-graph outputs, recursing into everything else. */
+  Expr VisitExpr(const Expr& expr) final;
+
+ private:
+  /*! \brief Non-owning; its output substitution guides the rewrite. */
+  const Extractor* extractor_;
+};
+
+/*!
+ * \brief Helper class for extracting matched sub-graphs from the overall expression.
+ *
+ * Usage: construct, call Extract(), then read extracted() and output_substitution().
+ * When opt_attrs_ is defined the sub-graph is extracted as a Function whose parameters stand for
+ * the sub-graph's input nodes; otherwise the sub-graph's body expression is extracted as-is.
+ */
+class Extractor : public ExprMutator {
+ public:
+  /*!
+   * \brief Prepares to extract \p sub_graph from the expression underlying \p dataflow_graph.
+   * Both pointers are stored without taking ownership, so they must outlive this extractor.
+   * \p opt_attrs, if defined, become the attributes of the extracted function.
+   */
+  Extractor(const DataflowGraph* dataflow_graph, const SubGraphNode* sub_graph,
+            FunctionAttrsMap opt_attrs)
+      : dataflow_graph_(dataflow_graph), sub_graph_(sub_graph), opt_attrs_(std::move(opt_attrs)) {
+    // The sub-graph's index space must match the dataflow graph (only the sizes are checked).
+    ICHECK_EQ(dataflow_graph_->size(), sub_graph_->overall_size());
+  }
+
+  /*! \brief Returns the dataflow graph of the overall expression. */
+  const DataflowGraph& dataflow_graph() const { return *dataflow_graph_; }
+
+  /*!
+   * \brief Collect the parameters and output expressions for the function representing
+   * the sub-graph.
+   *
+   * Visits the nodes inside the sub-graph in reverse post-dfs order, turning each node which is an
+   * exit, is external, or has not already been reached into an output. The outputs are bundled
+   * into a body (a Tuple when there are several), nested sub-sub-graphs are rewritten within that
+   * body, and, when opt_attrs_ is defined, the body is wrapped in a Function which is then called
+   * on the collected arguments. Finally output_substitution_ is populated.
+   *
+   * The accumulator fields are never reset (and params_/args_ are moved from), so this is
+   * intended to be called only once per extractor.
+   */
+  void Extract() {
+    ICHECK(!sub_graph_->IsEmpty());
+    VLOG(2) << "Extracting " << sub_graph_->ToString();
+    const bool for_function = opt_attrs_.defined();
+
+    //  In reverse dataflow order...
+    // (i counts down from size() to 1 so the loop is safe even if PostDfsIndex is unsigned.)
+    for (PostDfsIndex i = dataflow_graph_->size(); i > 0; --i) {
+      PostDfsIndex index = i - 1;
+      if (!sub_graph_->inside_[index]) {
+        // Node is outside sub-graph.
+        continue;
+      }
+      VLOG(2) << "index " << index;
+      auto node = dataflow_graph_->index_to_node(index);
+      if (sub_graph_->exit_[node->index_] || node->is_external_ || memo_.count(node->ref()) == 0) {
+        // This sub-expression is:
+        //  - inside the sub-graph and needed outside the sub-graph. So it must contribute to an
+        //    output (even if we've already visited it while constructing an output from a
+        //    downstream sub-expression).
+        //  - not yet visited, in which case it must still be considered an 'output' so it will
+        //    be evaluated for any possible side effects.
+        Expr output = VisitExpr(GetRef<Expr>(node->node_ref_));
+        VLOG(2) << "index " << index << " added as output:\n"
+                << PrettyPrint(output) << "\nat " << outputs_.size();
+        expr_to_output_index_.emplace(node->node_ref_, outputs_.size());
+        outputs_.emplace_back(std::move(output));
+        output_types_.emplace_back(node->node_ref_->checked_type());
+      }
+    }
+    ICHECK(!outputs_.empty());
+
+    // Reverse the outputs so as to preserve the original evaluation order.
+    std::reverse(outputs_.begin(), outputs_.end());
+    std::reverse(output_types_.begin(), output_types_.end());
+    // Remap each recorded output index to its position after the reversal above.
+    for (auto& kv : expr_to_output_index_) {
+      kv.second = static_cast<int>(outputs_.size()) - 1 - kv.second;
+    }
+
+    // Build a 'body' expression to represent the extracted sub-graph. If we have multiple
+    // outputs we'll place them in a tuple.
+    Type body_type;
+    Expr body;
+    if (outputs_.size() > 1) {
+      body_type = TupleType(output_types_);
+      body = Tuple(outputs_);
+      body->checked_type_ = body_type;
+    } else {
+      body_type = output_types_.front();
+      body = outputs_.front();
+    }
+
+    // Re-express all the sub-sub-graphs in terms of the body.
+    DataflowGraph body_dataflow_graph(body);
+    std::vector<SubSubGraph> sub_sub_graphs;
+    IndexSubst subst = MakeIndexSubst(body_dataflow_graph);
+    for (const auto& sub_sub_graph : sub_graph_->sub_sub_graphs_) {
+      sub_sub_graphs.emplace_back(sub_sub_graph.Subst(body_dataflow_graph, subst));
+    }
+
+    // Sweep backwards through the body, rewriting to account for each sub-sub-graph.
+    body = SubSubGraph::ParallelRewrite(body_dataflow_graph, body, std::move(sub_sub_graphs));
+
+    if (for_function) {
+      // Rewrite so all input nodes are now conveyed via call arguments to a new function.
+      Array<Type> arg_types;
+      arg_types.reserve(params_.size());
+      for (const auto& param : params_) {
+        arg_types.push_back(param->checked_type());
+      }
+      // params_ and args_ are moved from below and must not be used afterwards.
+      extracted_ = Function(std::move(params_), std::move(body), body_type,
+                            /*ty_params=*/{}, DictAttrs(opt_attrs_));
+      extracted_->checked_type_ =
+          FuncType(std::move(arg_types), body_type, /*type_params=*/{}, /*type_constraints=*/{});
+      // From here on 'body' denotes the call to the extracted function, which is what the
+      // overall expression will refer to in place of the sub-graph's outputs.
+      body = Call(extracted_, std::move(args_));
+      body->checked_type_ = body_type;
+    } else {
+      // Don't do anything with the inputs.
+      extracted_ = body;
+    }
+
+    // Setup the output substitution.
+    for (const auto& kv : expr_to_output_index_) {
+      Expr expr;
+      if (outputs_.size() == 1) {
+        expr = body;
+      } else if (for_function) {
+        expr = TupleGetItem(body, kv.second);
+        expr->checked_type_ = output_types_[kv.second];
+      } else {
+        // Relies on 'body' still being the Tuple built above after ParallelRewrite (checked).
+        const auto* tuple_node = body.as<TupleNode>();
+        ICHECK(tuple_node);
+        expr = tuple_node->fields[kv.second];
+      }
+      VLOG(2) << "output " << dataflow_graph_->item_to_node(kv.first)->index_ << " is at index "
+              << kv.second << " (of " << outputs_.size() << " outputs)";
+      output_substitution_.emplace(kv.first, std::move(expr));
+    }
+  }
+
+  ////// Following members are valid only after Extract() has returned.
+
+  /*!
+   * \brief Returns the expression representing the extracted sub-graph. If opt_attrs_ is
+   * defined this is a Function; otherwise it is the (possibly tuple-valued) body expression.
+   */
+  Expr extracted() const { return extracted_; }
+
+  /*!
+   * \brief Returns the substitution to apply to all expression nodes in the overall expression
+   * so as to replace references to outputs of the sub-graph with their rewritten form.
+   */
+  const std::unordered_map<const ExprNode*, Expr>& output_substitution() const {
+    return output_substitution_;
+  }
+
+ private:
+  /*!
+   * \brief Returns a map from original index to new index for each node inside the sub-graph. Only
+   * valid after \p Extract has made its backwards dataflow sweep.
+   */
+  IndexSubst MakeIndexSubst(const DataflowGraph& new_dataflow_graph) const {
+    VLOG(2) << "building extractor substitution";
+    IndexSubst subst;
+    for (PostDfsIndex index : sub_graph_->inside_) {
+      auto orig_node = dataflow_graph_->index_to_node(index);
+      ICHECK_EQ(orig_node->index_, index);
+      // memo_ (inherited from ExprMutator) must hold the rewritten form of every inside node.
+      auto itr = memo_.find(orig_node->ref());
+      ICHECK(itr != memo_.end());
+      auto new_node = new_dataflow_graph.item_to_node(itr->second);
+      VLOG(2) << orig_node->index_ << " |-> " << new_node->index_;
+      subst.emplace(orig_node->index_, new_node->index_);
+    }
+    return subst;
+  }
+
+  /*! \brief Returns true if \p expr is inside the sub-graph. */
+  bool inside(const Expr& expr) {
+    return sub_graph_->inside_[dataflow_graph_->item_to_node(expr)->index_];
+  }
+
+  /*!
+   * \brief Returns the variable uniquely representing \p expr, which should be
+   * an input node (ie outside the sub-graph but feeding into a node inside the sub-graph).
+   *
+   * It is valid for:
+   *  - An expression outside the sub-graph to be used multiple times inside the sub-graph.
+   *  - An expression outside the sub-graph to be used both inside and outside the sub-graph.
+   *
+   * Fresh parameters are named "FunctionVar_<n>" in order of first use, with \p expr recorded as
+   * the matching call argument.
+   */
+  Var VarFor(const Expr& expr) {
+    ICHECK(!inside(expr));
+    ICHECK(opt_attrs_.defined());
+    auto itr = expr_to_param_.find(expr.get());
+    if (itr != expr_to_param_.end()) {
+      return itr->second;
+    }
+    auto fresh_var = Var("FunctionVar_" + std::to_string(params_.size()), expr->checked_type());
+    fresh_var->checked_type_ = expr->checked_type();
+    params_.push_back(fresh_var);
+    args_.push_back(expr);
+    expr_to_param_.emplace(expr.get(), fresh_var);
+    return fresh_var;
+  }
+
+  /*!
+   * \brief If \p expr is inside the sub-graph then return its rewritten form.
+   * If \p expr is outside the sub-graph then it must correspond to an input node.
+   *  - If it can be inlined, return it directly so it is included in the extracted body.
+   *  - Otherwise, if opt_attrs_ is defined return the variable to represent it.
+   *  - Otherwise just return the expression directly.
+   *
+   * Should be called only on inputs to nodes which are inside the sub-graph.
+   */
+  Expr VisitExpr(const Expr& expr) final {
+    if (inside(expr)) {
+      return ExprMutator::VisitExpr(expr);
+    } else if (CanInline(expr)) {
+      // Implicitly include inlinable input sub-expressions.
+      return expr;
+    } else if (opt_attrs_.defined()) {
+      // Map to a function parameter.
+      return VarFor(expr);
+    } else {
+      // Stop rewriting.
+      return expr;
+    }
+  }
+
+  /*! \brief Functions marked primitive are returned unchanged rather than descended into. */
+  Expr VisitExpr_(const FunctionNode* function_node) override {
+    if (function_node->HasNonzeroAttr(attr::kPrimitive)) {
+      return GetRef<Function>(function_node);
+    }
+    return ExprMutator::VisitExpr_(function_node);
+  }
+
+  //// Context fields, passed in constructor.
+
+  /*! \brief The dataflow graph corresponding to the overall expression. */
+  const DataflowGraph* dataflow_graph_;
+  /*! \brief The sub-graph of the above we are extracting. */
+  const SubGraphNode* sub_graph_;
+  /*! \brief Optional attributes if the sub-graph should be extracted as a function. */
+  FunctionAttrsMap opt_attrs_;
+
+  //// Result fields, available after Extract() called.
+
+  /*!
+   * \brief The extracted expression. If opt_attrs_ is defined this will be a function.
+   */
+  Expr extracted_;
+  /*!
+   * \brief Map from output nodes to corresponding expressions. With a single output each entry is
+   * the rewritten body (or the call to the extracted function). With multiple outputs each entry
+   * is a tuple projection of the call when extracting a function, and otherwise the corresponding
+   * field of the body tuple.
+   */
+  std::unordered_map<const ExprNode*, Expr> output_substitution_;
+
+  //// Accumulator fields, built as we visit expressions.
+
+  /*!
+   * \brief (If opt_attrs_ is defined) Parameters representing input expression nodes, in order
+   * of first use. Moved into the extracted Function by Extract().
+   */
+  Array<Var> params_;
+  /*!
+   * \brief (If opt_attrs_ is defined) The input expression nodes for each of the above params_.
+   * Moved into the call to the extracted Function by Extract().
+   */
+  Array<Expr> args_;
+  /*!
+   * \brief (If opt_attrs_ is defined) Map from existing input expression nodes to the parameters
+   * in params_ which now represent them.
+   */
+  std::unordered_map<const ExprNode*, Var> expr_to_param_;
+  /*!
+   * \brief Accumulated new expressions which represent the exit nodes of the rewritten sub-graph.
+   * It is possible to have multiple outputs. It is possible one output also contributes to other
+   * outputs (ie the output is a 'tap').
+   */
+  std::vector<Expr> outputs_;
+  /*!
+   * \brief Types of the original expressions corresponding to outputs_. Recorded whether or not
+   * opt_attrs_ is defined.
+   */
+  std::vector<Type> output_types_;
+  /*!
+   * \brief Map from existing exit expression nodes to the index in outputs_ which should
+   * represent them in the rewritten overall expression. Remapped by Extract() when outputs_ is
+   * reversed.
+   */
+  std::unordered_map<const ExprNode*, int> expr_to_output_index_;
+};
+
+Expr Rewriter::VisitExpr(const Expr& expr) {
+  // Outputs of the extracted sub-graph are replaced wholesale; everything else is rewritten
+  // recursively so that nested references to those outputs are also replaced.
+  const auto& substitution = extractor_->output_substitution();
+  const auto found = substitution.find(expr.get());
+  if (found != substitution.end()) {
+    return found->second;
+  }
+  return ExprMutator::VisitExpr(expr);
+}
+
+}  // namespace
+
+std::pair<OpPatternKind, std::string> SubExprKindAndLabel(const Expr& sub_expr) {
+  // An operator pattern kind paired with a short human-readable label for the sub-expression.
+  using KindAndLabel = std::pair<OpPatternKind, std::string>;
+
+  /*!
+   * \brief Classifies a single sub-expression. Calls, constants, tuples and projections get
+   * meaningful kinds; every other node is reported as opaque with a descriptive label.
+   */
+  class Visitor : public ExprFunctor<KindAndLabel(const Expr&)> {
+   private:
+    /*! \brief True if every field of \p tuple_type_node is a tensor type. */
+    static bool AllFieldsAreTensors(const TupleTypeNode* tuple_type_node) {
+      return std::all_of(tuple_type_node->fields.begin(), tuple_type_node->fields.end(),
+                         [](const Type& type) { return type.as<TensorTypeNode>() != nullptr; });
+    }
+
+    KindAndLabel VisitExpr_(const CallNode* call_node) final {
+      if (const auto* op_node = call_node->op.as<OpNode>()) {
+        // Operator calls take their kind from the operator's TOpPattern attribute, unless the
+        // result shape is dynamic and data-dependent.
+        const Op op = GetRef<Op>(op_node);
+        static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
+        if (fpattern.count(op) == 0) {
+          VLOG(1) << "no TOpPattern known for " << op->name << ", considering opaque";
+          return {kOpaque, op->name};
+        }
+        if (IsDynamic(call_node->checked_type()) && IsDataDependent(call_node)) {
+          VLOG(1) << "call has dynamic shape which is data-dependent, considering opaque";
+          return {kOpaque, op->name};
+        }
+        const auto kind = static_cast<OpPatternKind>(fpattern[op]);
+        VLOG(2) << "TOpPattern for " << op->name << " is " << KindToString(kind);
+        return {kind, op->name};
+      }
+      if (const auto* function_node = call_node->op.as<FunctionNode>()) {
+        // Calls to function literals use the function's own "TOpPattern" attribute, if any.
+        const Optional<Integer> opt_pattern =
+            function_node->GetAttr<Integer>("TOpPattern", Optional<Integer>());
+        if (!opt_pattern.defined()) {
+          VLOG(1) << "calling function without TOpPattern, considering opaque";
+          return {kOpaque, "call_fun"};
+        }
+        const auto kind = static_cast<OpPatternKind>(opt_pattern.value()->value);
+        VLOG(1) << "TOpPattern for function is " << KindToString(kind);
+        return {kind, "call_prim"};
+      }
+      VLOG(1) << "unsupported call, considering opaque";
+      return {kOpaque, "call_any"};
+    }
+
+    KindAndLabel VisitExpr_(const ConstantNode* constant_node) final {
+      VLOG(2) << "TOpPattern for constant is " << KindToString(kElemWise);
+      return {kElemWise, support::IsSimpleScalar(constant_node) ? "scalar" : "const"};
+    }
+
+    KindAndLabel VisitExpr_(const TupleNode* tuple_node) final {
+      const auto* tuple_type_node = tuple_node->checked_type().as<TupleTypeNode>();
+      ICHECK(tuple_type_node != nullptr);
+      if (!AllFieldsAreTensors(tuple_type_node)) {
+        VLOG(1) << "tuple contains non-tensors, considering opaque";
+        return {kOpaque, "tuple"};
+      }
+      VLOG(2) << "TOpPattern for tuple is " << KindToString(kInjective);
+      return {kInjective, "tuple"};
+    }
+
+    KindAndLabel VisitExpr_(const TupleGetItemNode* tuple_get_item_node) final {
+      const auto* tuple_type_node = tuple_get_item_node->tuple->checked_type().as<TupleTypeNode>();
+      ICHECK(tuple_type_node != nullptr);
+      if (!AllFieldsAreTensors(tuple_type_node)) {
+        VLOG(1) << "tuple being projected contains non-tensors, considering opaque";
+        return {kOpaque, "proj"};
+      }
+      VLOG(2) << "TOpPattern for tuple projection is " << KindToString(kInjective);
+      return {kInjective, "proj"};
+    }
+
+    // TODO(mbs): We implement the following mostly so we have a lightweight way of describing
+    // the current sub-expression. If partitioning is ever extended beyond the usual call/tuple/proj
+    // sub-language we should revise the returned operator kinds to match.
+
+    KindAndLabel VisitExpr_(const VarNode* var_node) final {
+      return {kOpaque, "%" + var_node->name_hint()};
+    }
+    KindAndLabel VisitExpr_(const GlobalVarNode* global_var_node) final {
+      return {kOpaque, "@" + global_var_node->name_hint};
+    }
+    KindAndLabel VisitExpr_(const OpNode* op_node) final { return {kOpaque, "`" + op_node->name}; }
+    KindAndLabel VisitExpr_(const FunctionNode* function_node) final { return {kOpaque, "fn"}; }
+    KindAndLabel VisitExpr_(const LetNode* let_node) final { return {kOpaque, "let"}; }
+    KindAndLabel VisitExpr_(const IfNode* if_node) final { return {kOpaque, "if"}; }
+    KindAndLabel VisitExpr_(const RefCreateNode* ref_create_node) final {
+      return {kOpaque, "ref"};
+    }
+    KindAndLabel VisitExpr_(const RefReadNode* op) final { return {kOpaque, "ref_read"}; }
+    KindAndLabel VisitExpr_(const RefWriteNode* op) final { return {kOpaque, "ref_write"}; }
+    KindAndLabel VisitExpr_(const ConstructorNode* op) final {
+      return {kOpaque, "`" + op->name_hint};
+    }
+    KindAndLabel VisitExpr_(const MatchNode* op) final { return {kOpaque, "match"}; }
+  };
+
+  Visitor visitor;
+  return visitor.VisitExpr(sub_expr);
+}
+
+std::pair<OpPatternKind, std::string> SubGraphKindAndLabel(const DataflowGraph& dataflow_graph,
+                                                           const IndexSet& inside) {
+  // The overall kind is the combination of every inside node's kind (starting from kElemWise);
+  // the label joins the non-empty per-node labels with '+'.
+  std::ostringstream label;
+  const char* separator = "";
+  OpPatternKind kind = kElemWise;
+  for (PostDfsIndex index : inside) {
+    const auto sub = SubExprKindAndLabel(dataflow_graph.index_to_node(index)->ref());
+    if (!sub.second.empty()) {
+      label << separator << sub.second;
+      separator = "+";
+    }
+    kind = CombineKinds(kind, sub.first);
+  }
+  return {kind, label.str()};
+}
+
+IndexSet MatcherToIndexSet(const DFPatternMatcher& matcher) {
+  // Collects the dataflow indices of every sub-expression the matcher bound to a pattern, except:
+  //  - trivially inlinable sub-expressions, which are simply included in the extracted function
+  //    body when it is built and so need not be part of the sub-graph;
+  //  - sub-expressions bound by a wildcard pattern.
+  IndexSet matched(matcher.size());
+  for (const auto& entry : matcher.memo()) {
+    const bool is_wildcard = entry.first.as<WildcardPatternNode>() != nullptr;
+    for (const auto& sub_expr : entry.second) {
+      if (CanInline(sub_expr) || is_wildcard) {
+        continue;
+      }
+      matched.Add(matcher.expr_to_node(sub_expr)->index_);
+    }
+  }
+  return matched;
+}
+
+std::string SubGraphConfig::ToString() const {
+  // Compact single-line rendering. Fields are comma-separated without spaces, matching the other
+  // ToString() methods in this file.
+  std::ostringstream out;
+  out << "{max_exits=" << max_exits << ",allow_taps=" << allow_taps
+      << ",max_max_depth=" << max_max_depth << "}";
+  return out.str();
+}
+
+// Register SubSubGraphNode with TVM's object type system.
+TVM_REGISTER_NODE_TYPE(SubSubGraphNode);
+
+void SubSubGraphNode::VisitAttrs(AttrVisitor* /*v*/) {
+  // TODO(mbs): Visit this node's fields. Until then nothing is exposed to the AttrVisitor.
+}
+
+SubGraph SubSubGraphNode::sub_graph() const {
+  // Recover the concrete SubGraph type from the stored sub_graph_obj_.
+  return Downcast<SubGraph>(sub_graph_obj_);
+}
+
+bool SubSubGraphNode::operator==(const SubSubGraphNode& that) const {
+  return *sub_graph().get() == *that.sub_graph().get();
+}
+
+bool SubSubGraphNode::operator<(const SubSubGraphNode& that) const {
+  return *sub_graph().get() < *that.sub_graph().get();
+}
+
+size_t SubSubGraphNode::hash() const {
+  size_t h = StructuralHash()(attrs_);
+  h ^= sub_graph()->hash() + 0x9e3779b9 + (h << 6) + (h >> 2);
+  return h;
+}
+
+/*! \brief Renders as {sub_graph=<nested sub-graph>,attrs=<pretty-printed attributes>}. */
+std::string SubSubGraphNode::ToString() const {
+  std::ostringstream out;
+  out << "{sub_graph=" << sub_graph()->ToString() << ",attrs=" << PrettyPrint(attrs_)
+      << "}";
+  return out.str();
+}
+
+/*! \brief Runs an Extractor over the nested sub-graph and returns the function it built. */
+Function SubSubGraphNode::Extract(const DataflowGraph& dataflow_graph) const {
+  Extractor function_extractor(&dataflow_graph, sub_graph().get(), attrs_);
+  function_extractor.Extract();
+  return Downcast<Function>(function_extractor.extracted());
+}
+
+/*!
+ * \brief Extracts the nested sub-graph (with attrs_) and then rewrites \p expr to use the
+ * extracted function.
+ */
+Expr SubSubGraphNode::Rewrite(const DataflowGraph& dataflow_graph, const Expr& expr) const {
+  Extractor function_extractor(&dataflow_graph, sub_graph().get(), attrs_);
+  function_extractor.Extract();
+  return Rewriter(&function_extractor).VisitExpr(expr);
+}
+
+/*! \brief Builds a sub-sub-graph wrapping \p sub_graph with function attributes \p attrs. */
+SubSubGraph::SubSubGraph(SubGraph sub_graph, FunctionAttrsMap attrs) {
+  auto node = runtime::make_object<SubSubGraphNode>();
+  node->attrs_ = std::move(attrs);
+  node->sub_graph_obj_ = std::move(sub_graph);
+  data_ = std::move(node);
+}
+
+/*!
+ * \brief Re-indexes the nested sub-graph w.r.t. \p new_dataflow_graph. The attributes carry
+ * over unchanged.
+ */
+SubSubGraph SubSubGraph::Subst(const DataflowGraph& new_dataflow_graph,
+                               const std::unordered_map<PostDfsIndex, PostDfsIndex>& subst) const {
+  SubGraph substituted = get()->sub_graph().Subst(new_dataflow_graph, subst);
+  return SubSubGraph(std::move(substituted), get()->attrs_);
+}
+
+/*!
+ * \brief Returns true if this and \p that carry exactly the same attribute keys with
+ * structurally equal values, and none of those keys is "Composite".
+ */
+bool SubSubGraph::TriviallyUnionable(const SubSubGraph& that) const {
+  const FunctionAttrsMap& mine = get()->attrs_;
+  const FunctionAttrsMap& theirs = that->attrs_;
+  if (mine.size() != theirs.size()) {
+    return false;
+  }
+  for (const auto& kv : mine) {
+    // Even if all the attributes agree we don't consider "Composite" functions to
+    // ever be unionable.
+    // TODO(mbs): Find a cleaner way to do this.
+    if (kv.first == "Composite") {
+      return false;
+    }
+    auto match = theirs.find(kv.first);
+    if (match == theirs.end() || !StructuralEqual()(kv.second, (*match).second)) {
+      return false;
+    }
+  }
+  return true;
+}
+
+/*!
+ * \brief Unions the nested sub-graphs of this and \p that. The attributes agree (checked
+ * below), so ours are kept.
+ */
+SubSubGraph SubSubGraph::DisjointUnion(const DataflowGraph& dataflow_graph,
+                                       const SubSubGraph& that) const {
+  ICHECK(TriviallyUnionable(that));
+  SubGraph unioned = get()->sub_graph().DisjointUnion(dataflow_graph, that->sub_graph());
+  return SubSubGraph(std::move(unioned), get()->attrs_);
+}
+
+/*static*/
+Expr SubSubGraph::ParallelRewrite(const DataflowGraph& dataflow_graph, const Expr& expr,
+                                  std::vector<SubSubGraph> sub_sub_graphs) {
+  // IMPORTANT: See the corresponding comment in SubGraph::ParallelRewrite.
+  // Rewrite in reverse dataflow order, ie starting with the sub-sub-graph whose last inside
+  // node has the greatest post-dfs index.
+  auto latest_first = [](const SubSubGraph& left, const SubSubGraph& right) {
+    return left->sub_graph()->last_inside_index_ > right->sub_graph()->last_inside_index_;
+  };
+  std::sort(sub_sub_graphs.begin(), sub_sub_graphs.end(), latest_first);
+
+  Expr rewritten = expr;
+  for (size_t i = 0; i < sub_sub_graphs.size(); ++i) {
+    rewritten = sub_sub_graphs[i]->Rewrite(dataflow_graph, rewritten);
+  }
+  return rewritten;
+}
+
+TVM_REGISTER_NODE_TYPE(SubGraphNode);
+
+void SubGraphNode::VisitAttrs(AttrVisitor* v) {
+  // TODO(mbs)
+}
+
+/*! \brief Returns the union, over every exit node, of the nodes downstream of that exit. */
+IndexSet SubGraphNode::Downstream(const DataflowGraph& dataflow_graph) const {
+  IndexSet all_downstream(dataflow_graph.size());
+  for (PostDfsIndex exit_index : exit_) {
+    all_downstream = all_downstream | dataflow_graph.downstream_of(exit_index);
+  }
+  return all_downstream;
+}
+
+/*!
+ * \brief Returns true if this sub-graph satisfies every validity rule, as limited by \p config:
+ *  - if config.max_exits is non-zero, there are at most that many exit nodes;
+ *  - if config.max_max_depth is non-zero, the maximum path depth does not exceed it;
+ *  - all inside nodes are in the same basic block;
+ *  - every sub-sub-graph is a subset of this sub-graph, and no two sub-sub-graphs overlap;
+ *  - if config.allow_taps is false, no exit node also feeds an inside node;
+ *  - no entry node is downstream of any output node (which would form a cycle).
+ * The reason for any rejection is logged at VLOG(1).
+ */
+bool SubGraphNode::IsValid(const DataflowGraph& dataflow_graph,
+                           const SubGraphConfig& config) const {
+  // Check we don't have too many exit nodes.
+  if (config.max_exits > 0 && exit_.PopCount() > config.max_exits) {
+    VLOG(1) << "Subgraph " << ToString() << " is invalid: " << exit_.PopCount()
+            << " exits exceeds maximum " << config.max_exits;
+    return false;
+  }
+
+  // Check the maximum path depth is in limit.
+  if (config.max_max_depth > 0 && max_depth_ > config.max_max_depth) {
+    VLOG(1) << "Subgraph " << ToString() << " is invalid: maximum depth " << max_depth_
+            << " exceeds limit " << config.max_max_depth;
+    return false;
+  }
+
+  // All inside nodes must be in the same basic block.
+  const DataflowGraph::Node* basic_block = nullptr;
+  for (PostDfsIndex index : inside_) {
+    auto node = dataflow_graph.index_to_node(index);
+    if (basic_block == nullptr) {
+      basic_block = node->basic_block_;
+    }
+    if (node->basic_block_ != basic_block) {
+      VLOG(1) << "Subgraph " << ToString() << " is invalid: nodes are from different basic blocks";
+      return false;
+    }
+  }
+
+  // The sub-sub-graphs must be subsets and non-overlapping.
+  // union_inside accumulates the inside nodes of every sub-sub-graph checked so far. Each
+  // sub-sub-graph is therefore checked for overlap against all those before it.
+  IndexSet union_inside(dataflow_graph.size());
+  for (const auto& sub_sub_graph : sub_sub_graphs_) {
+    // Hold the nested sub-graph by reference-counted handle while we inspect its inside_.
+    SubGraph nested = sub_sub_graph->sub_graph();
+    if (!nested->inside_.AreDisjoint(union_inside)) {
+      VLOG(1) << "Subgraph " << ToString() << " is invalid: sub-sub-graphs overlap";
+      return false;
+    }
+    if (!nested->inside_.IsSubset(inside_)) {
+      VLOG(1) << "Subgraph " << ToString()
+              << " is invalid: sub-sub-graph is not subset of overall sub-graph";
+      return false;
+    }
+    union_inside = union_inside | nested->inside_;
+  }
+
+  if (!config.allow_taps) {
+    // Exit nodes cannot also contribute to inside nodes.
+    for (PostDfsIndex index : exit_) {
+      auto node = dataflow_graph.index_to_node(index);
+      if (AnyOutputInside(node)) {
+        VLOG(1) << "Subgraph " << ToString()
+                << " is invalid: inner node is 'tapped' and also contributes to output, but taps "
+                   "are disabled";
+        return false;
+      }
+    }
+  }
+
+  // Check no output would end up feeding into any entry node.
+  for (PostDfsIndex output_index : output_) {
+    if (dataflow_graph.downstream_of(output_index).Intersects(entry_)) {
+      VLOG(1) << "Subgraph " << ToString() << " is invalid: output node " << output_index
+              << " feeds back into this sub-graph";
+      return false;
+    }
+  }
+
+  // Looks legit!
+  return true;
+}
+
+/*!
+ * \brief Wraps this whole sub-graph in a sub-sub-graph with no attributes and returns its
+ * extraction.
+ */
+Function SubGraphNode::ExtractAsFunction(const DataflowGraph& dataflow_graph) const {
+  return SubSubGraph(GetRef<SubGraph>(this), FunctionAttrsMap())->Extract(dataflow_graph);
+}
+
+/*!
+ * \brief Rewrites \p expr according to this sub-graph's sub-sub-graphs. With no sub-sub-graphs
+ * there is nothing to rewrite and \p expr is returned as-is.
+ */
+Expr SubGraphNode::Rewrite(const DataflowGraph& dataflow_graph, const Expr& expr) const {
+  if (!sub_sub_graphs_.empty()) {
+    Extractor extractor(&dataflow_graph, this, NullValue<FunctionAttrsMap>());
+    extractor.Extract();
+    Rewriter rewriter(&extractor);
+    return rewriter.VisitExpr(expr);
+  }
+  return expr;
+}
+
+/*!
+ * \brief Renders the index sets, max depth, kind, optional label and every sub-sub-graph as
+ * a brace-delimited, comma-separated list.
+ */
+std::string SubGraphNode::ToString() const {
+  std::ostringstream out;
+  out << "{inside=" << inside_.ToString() << ",entry=" << entry_.ToString()
+      << ",exit=" << exit_.ToString() << ",input=" << input_.ToString()
+      << ",output=" << output_.ToString() << ",max_depth=" << max_depth_
+      << ",kind=" << KindToString(kind_);
+  // The label field is omitted entirely when there is no label.
+  if (!label_.empty()) {
+    out << ",label=" << label_;
+  }
+  for (const auto& nested : sub_sub_graphs_) {
+    out << ",sub_sub_graph=" << nested->ToString();
+  }
+  out << "}";
+  return out.str();
+}
+
+/*!
+ * \brief Sub-graphs are equal when they have the same inside nodes and pairwise-equal
+ * sub-sub-graphs. Both must be w.r.t. the same overall dataflow graph.
+ */
+bool SubGraphNode::operator==(const SubGraphNode& that) const {
+  ICHECK_EQ(inside_.end_index(), that.inside_.end_index());
+  if (inside_ != that.inside_ || sub_sub_graphs_.size() != that.sub_sub_graphs_.size()) {
+    return false;
+  }
+  // Sub-sub-graphs are sorted when a SubGraph is constructed, so compare them positionally.
+  const size_t count = sub_sub_graphs_.size();
+  for (size_t pos = 0; pos < count; ++pos) {
+    if (*sub_sub_graphs_[pos].get() != *that.sub_sub_graphs_[pos].get()) {
+      return false;
+    }
+  }
+  return true;
+}
+
+/*! \brief Orders by first inside index, then by the inside sets themselves. */
+bool SubGraphNode::operator<(const SubGraphNode& that) const {
+  if (first_inside_index_ != that.first_inside_index_) {
+    return first_inside_index_ < that.first_inside_index_;
+  }
+  return inside_ < that.inside_;
+}
+
+/*! \brief Folds each sub-sub-graph's hash into the inside set's hash (hash_combine style). */
+size_t SubGraphNode::hash() const {
+  size_t seed = inside_.hash();
+  for (const auto& nested : sub_sub_graphs_) {
+    seed ^= nested->hash() + 0x9e3779b9 + (seed << 6) + (seed >> 2);
+  }
+  return seed;
+}
+
+/*!
+ * \brief Classifies every node of \p dataflow_graph relative to inside_, adding to entry_, exit_,
+ * input_ and output_, then caches the maximum path depth.
+ */
+void SubGraphNode::Init(const DataflowGraph& dataflow_graph) {
+  const auto end = inside_.end_index();
+  for (PostDfsIndex index = 0; index < end; ++index) {
+    auto node = dataflow_graph.index_to_node(index);
+    if (!inside_[index]) {
+      // Outside nodes consuming an inside node are outputs. Outside nodes feeding an inside
+      // node are inputs, unless they can simply be inlined.
+      if (AnyInputInside(node)) {
+        output_.Add(index);
+      }
+      if (AnyOutputInside(node) && !CanInline(node->ref())) {
+        input_.Add(index);
+      }
+      continue;
+    }
+    if (AnyInputOutside(node)) {
+      entry_.Add(index);
+    }
+    if (AnyOutputOutside(node) || node->is_external_) {
+      exit_.Add(index);
+    }
+  }
+  max_depth_ = MaxDepth(dataflow_graph);
+}
+
+/*!
+ * \brief Returns the length of the longest path through inside nodes which starts at an entry
+ * node, counting one extra step for leaving through an exit node. Paths are seeded only from
+ * entry nodes, so inside nodes not reachable from an entry node do not contribute.
+ */
+size_t SubGraphNode::MaxDepth(const DataflowGraph& dataflow_graph) const {
+  // Longest distance found so far from an entry node to each visited inside node.
+  std::unordered_map<const DataflowGraph::Node*, size_t> max_depths;
+  // Worklist of nodes whose depth must be propagated to their inside outputs.
+  std::vector<const DataflowGraph::Node*> stack;
+  size_t max_depth = 0;
+  // All the entry nodes have max depth 0.
+  for (PostDfsIndex index : entry_) {
+    auto node = dataflow_graph.index_to_node(index);
+    max_depths.emplace(node, 0);
+    stack.push_back(node);
+  }
+  while (!stack.empty()) {
+    const DataflowGraph::Node* node = stack.back();
+    stack.pop_back();
+    size_t next_depth = max_depths[node] + 1;
+    if (exit_[node->index_]) {
+      // If this node is external then it will have no outputs but we still wish to consider
+      // the path to the implied output as requiring one more step.
+      // Otherwise we're accounting for reaching one of the external outputs below.
+      max_depth = std::max(max_depth, next_depth);
+    }
+    for (const DataflowGraph::Node* output_node : node->outputs_) {
+      // Only paths which stay inside the sub-graph are measured.
+      if (!inside_[output_node->index_]) {
+        continue;
+      }
+      if (max_depths.count(output_node) == 0) {
+        max_depths.emplace(output_node, next_depth);
+        stack.push_back(output_node);
+      } else if (next_depth > max_depths[output_node]) {
+        // We found a deeper path to an already expanded node. We'll expand again.
+        max_depths[output_node] = next_depth;
+        stack.push_back(output_node);
+      }
+    }
+  }
+  return max_depth;
+}
+
+/*!
+ * \brief Returns true if any input of \p node is outside the sub-graph and cannot be inlined.
+ */
+bool SubGraphNode::AnyInputOutside(const DataflowGraph::Node* node) const {
+  for (const DataflowGraph::Node* input : node->inputs_) {
+    if (!inside_[input->index_] && !CanInline(input->ref())) {
+      return true;
+    }
+  }
+  return false;
+}
+
+/*! \brief Returns true if any input of \p node is inside the sub-graph. */
+bool SubGraphNode::AnyInputInside(const DataflowGraph::Node* node) const {
+  for (const DataflowGraph::Node* input : node->inputs_) {
+    if (inside_[input->index_]) {
+      return true;
+    }
+  }
+  return false;
+}
+
+/*! \brief Returns true if any output of \p node is outside the sub-graph. */
+bool SubGraphNode::AnyOutputOutside(const DataflowGraph::Node* node) const {
+  for (const DataflowGraph::Node* output : node->outputs_) {
+    if (!inside_[output->index_]) {
+      return true;
+    }
+  }
+  return false;
+}
+
+/*! \brief Returns true if any output of \p node is inside the sub-graph. */
+bool SubGraphNode::AnyOutputInside(const DataflowGraph::Node* node) const {
+  for (const DataflowGraph::Node* output : node->outputs_) {
+    if (inside_[output->index_]) {
+      return true;
+    }
+  }
+  return false;
+}
+
+/*!
+ * \brief Builds a sub-graph of \p dataflow_graph from its inside nodes. The cached first/last
+ * indexes and the entry/exit/input/output sets are derived from \p inside.
+ */
+SubGraph::SubGraph(const DataflowGraph& dataflow_graph, IndexSet inside, OpPatternKind kind,
+                   String label, std::vector<SubSubGraph> sub_sub_graphs) {
+  // Keep sub-sub-graphs sorted so equivalent sub-graphs store them in the same order.
+  auto by_node_order = [](const SubSubGraph& left, const SubSubGraph& right) {
+    return *left.get() < *right.get();
+  };
+  std::sort(sub_sub_graphs.begin(), sub_sub_graphs.end(), by_node_order);
+
+  const auto end_index = inside.end_index();
+  auto node = runtime::make_object<SubGraphNode>();
+  node->inside_ = std::move(inside);
+  node->first_inside_index_ = node->inside_.FirstInsideIndex();
+  node->last_inside_index_ = node->inside_.LastInsideIndex();
+  // The boundary sets start empty and are populated by Init.
+  node->entry_ = IndexSet(end_index);
+  node->exit_ = IndexSet(end_index);
+  node->input_ = IndexSet(end_index);
+  node->output_ = IndexSet(end_index);
+  node->kind_ = kind;
+  node->label_ = std::move(label);
+  node->sub_sub_graphs_ = sub_sub_graphs;
+  node->Init(dataflow_graph);
+  data_ = std::move(node);
+}
+
+/*! \brief The empty sub-graph of \p dataflow_graph: no node is inside. */
+SubGraph::SubGraph(const DataflowGraph& dataflow_graph)
+    : SubGraph(dataflow_graph, IndexSet(dataflow_graph.size())) {}
+
+/*! \brief Returns true if this and \p that share no inside nodes. */
+bool SubGraph::AreDisjoint(const SubGraph& that) const {
+  return get()->inside_.AreDisjoint(that->inside_);
+}
+
+namespace {
+/*! \brief Returns true if an output of \p left not in \p right ultimately flows into \p right. */
+bool FlowsInto(const DataflowGraph& dataflow_graph, const SubGraph& left, const SubGraph& right) {
+  for (PostDfsIndex output_index : left->output_) {
+    if (right->inside_[output_index]) {
+      continue;  // A direct hand-off into right, not an indirect flow.
+    }
+    if (dataflow_graph.downstream_of(output_index).Intersects(right->entry_)) {
+      return true;
+    }
+  }
+  return false;
+}
+}  // namespace
+
+/*!
+ * \brief Returns true if this and \p that are disjoint, some output of this is an entry of
+ * \p that, and unioning them would not create a cycle via nodes outside both.
+ */
+bool SubGraph::AreTouching(const DataflowGraph& dataflow_graph, const SubGraph& that) const {
+  return get()->inside_.AreDisjoint(that->inside_) && get()->output_.Intersects(that->entry_) &&
+         !FlowsInto(dataflow_graph, *this, that) && !FlowsInto(dataflow_graph, that, *this);
+}
+
+/*!
+ * \brief Returns true if every output of this is an entry of \p that, and every input of
+ * \p that is an exit of this.
+ */
+bool SubGraph::AreSelfContained(const SubGraph& that) const {
+  const bool outputs_are_entries = get()->output_.IsSubset(that->entry_);
+  return outputs_are_entries && that->input_.IsSubset(get()->exit_);
+}
+
+/*!
+ * \brief Returns the union of this and \p that, which must be disjoint. Each of \p that's
+ * sub-sub-graphs is merged into the first existing sub-sub-graph it is trivially unionable
+ * with, or else appended.
+ */
+SubGraph SubGraph::DisjointUnion(const DataflowGraph& dataflow_graph, const SubGraph& that) const {
+  ICHECK(AreDisjoint(that));
+  IndexSet inside = get()->inside_ | that->inside_;
+  std::vector<SubSubGraph> sub_sub_graphs;
+  for (const auto& mine : get()->sub_sub_graphs_) {
+    sub_sub_graphs.push_back(mine);
+  }
+  for (const auto& theirs : that->sub_sub_graphs_) {
+    bool merged = false;
+    for (auto& existing : sub_sub_graphs) {
+      if (existing.TriviallyUnionable(theirs)) {
+        existing = existing.DisjointUnion(dataflow_graph, theirs);
+        merged = true;
+        break;
+      }
+    }
+    if (!merged) {
+      sub_sub_graphs.push_back(theirs);
+    }
+  }
+  return SubGraph(dataflow_graph, std::move(inside), CombineKinds(get()->kind_, that->kind_),
+                  UnionLabels(get()->label_, that->label_), std::move(sub_sub_graphs));
+}
+
+/*!
+ * \brief Returns a sub-graph with the same inside nodes, kind and label, whose only
+ * sub-sub-graph wraps this entire sub-graph with \p attrs.
+ */
+SubGraph SubGraph::WithAttrs(const DataflowGraph& dataflow_graph, FunctionAttrsMap attrs) const {
+  std::vector<SubSubGraph> wrapped{SubSubGraph(*this, attrs)};
+  return SubGraph(dataflow_graph, get()->inside_, get()->kind_, get()->label_,
+                  std::move(wrapped));
+}
+
+/*!
+ * \brief Returns this sub-graph (and its sub-sub-graphs) re-indexed via \p subst w.r.t.
+ * \p new_dataflow_graph.
+ */
+SubGraph SubGraph::Subst(const DataflowGraph& new_dataflow_graph, const IndexSubst& subst) const {
+  IndexSet new_inside = get()->inside_.Subst(new_dataflow_graph.size(), subst);
+  std::vector<SubSubGraph> new_sub_sub_graphs;
+  new_sub_sub_graphs.reserve(get()->sub_sub_graphs_.size());
+  for (const auto& nested : get()->sub_sub_graphs_) {
+    new_sub_sub_graphs.push_back(nested.Subst(new_dataflow_graph, subst));
+  }
+  return SubGraph(new_dataflow_graph, std::move(new_inside), get()->kind_, get()->label_,
+                  std::move(new_sub_sub_graphs));
+}
+
+/*static*/
+Expr SubGraph::ParallelRewrite(const DataflowGraph& dataflow_graph,
+                               std::vector<SubGraph> sub_graphs) {
+  // IMPORTANT:
+  //  - All the sub-graphs will be w.r.t. the dataflow graph for the original expression.
+  //    Each time we call Rewrite on one of those graphs the result expression will be rewritten
+  //    from the final output back to the inputs. The inputs will then be shared with the original
+  //    expression. Thus it is safe to iteratively rewrite all the sub-graphs without redoing the
+  //    dataflow_graph and substituting indexes provided we work in reverse dataflow order.
+  //  - We rely on the dataflow_graph expression reference holding the original expression alive
+  //    so that the dataflow_graph will never contain dangling pointers (even though as per above
+  //    we'll never dereference them).
+  auto latest_first = [](const SubGraph& left, const SubGraph& right) {
+    return left->last_inside_index_ > right->last_inside_index_;
+  };
+  std::sort(sub_graphs.begin(), sub_graphs.end(), latest_first);
+
+  Expr rewritten = dataflow_graph.expr();
+  for (size_t i = 0; i < sub_graphs.size(); ++i) {
+    rewritten = sub_graphs[i]->Rewrite(dataflow_graph, rewritten);
+  }
+  return rewritten;
+}
+
+/*!
+ * \brief A pass which partitions (the unique) global function in the module according to the
+ * post-dfs indexes in \p indexes. The partitioning must respect the configuration with \p max_exits

Review Comment:
   ```suggestion
    * post-dfs indexes in \p indexes. The partitioning must respect the configuration with \p max_exits
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] SebastianBoblest commented on a diff in pull request #11981: [Collage] SubGraphs

Posted by GitBox <gi...@apache.org>.
SebastianBoblest commented on code in PR #11981:
URL: https://github.com/apache/tvm/pull/11981#discussion_r914646468


##########
src/relay/collage/sub_graph.h:
##########
@@ -0,0 +1,451 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/collage/sub_graph.h
+ * \brief Represents a sub-graph of an overall Relay expression.
+ */
+
+#ifndef TVM_RELAY_COLLAGE_SUB_GRAPH_H_
+#define TVM_RELAY_COLLAGE_SUB_GRAPH_H_
+
+#include <tvm/ir/transform.h>
+#include <tvm/relay/op_attr_types.h>
+
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../ir/dataflow_matcher_impl.h"
+#include "../ir/indexed_graph.h"
+#include "./dataflow_graph.h"
+#include "./index_set.h"
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+/*! \brief Returns operator pattern kind as single-letter string. */
+std::string KindToString(OpPatternKind kind);
+
+/*!
+ * \brief Returns a kind and label for the single \p sub_expr, ignoring its sub-sub-expressions.
+ */
+std::pair<OpPatternKind, std::string> SubExprKindAndLabel(const Expr& sub_expr);
+
+/*!
+ * \brief Returns a kind and label for all the nodes in \p inside. The kind combines (via
+ * CombineKinds, starting from kElemWise) the kind of every node. The label joins the
+ * non-empty node labels with "+".
+ */
+std::pair<OpPatternKind, std::string> SubGraphKindAndLabel(const DataflowGraph& dataflow_graph,
+                                                           const IndexSet& inside);
+
+/*!
+ * \brief Returns the index set representing the sub-expressions matched by \p matcher.
+ * Sub-expressions which can be inlined, or which appear only via a wildcard pattern match, are
+ * not included.
+ */
+IndexSet MatcherToIndexSet(const DFPatternMatcher& matcher);
+
+/*!
+ * \brief Configuration controlling which sub-graphs are considered valid. A zero limit means
+ * "no limit".
+ */
+struct SubGraphConfig {
+  /*! \brief Maximum number of exit nodes in the sub-graph, or zero if no limit. */
+  size_t max_exits{0};
+  /*!
+   * \brief Whether a node inside the sub-graph may flow to nodes both inside and outside
+   * the sub-graph (which we call a 'tap'). Note that it is still possible to have multiple
+   * outputs even with this flag false.
+   */
+  bool allow_taps{false};
+  /*! \brief Maximum allowed maximum path depth, or zero if no limit. */
+  size_t max_max_depth{0};
+
+  /*! \brief Returns a human-readable rendering of this configuration. */
+  std::string ToString() const;
+};
+
+class SubGraph;
+/*! \brief Attributes to attach to an extracted function. */
+using FunctionAttrsMap = Map<String, ObjectRef>;
+
+/*!
+ * \brief A sub-sub graph is a sub-graph which is to be nested inside a function as part of some
+ * enclosing sub-graph.
+ *
+ * Extraction yields a function with input nodes replaced by parameters and exit nodes in the
+ * function result. Rewriting replaces the sub-graph with a call to that function, and all
+ * outputs with (projections from) the call result.
+ *
+ * (Note that it's tempting to move attrs_ into \p SubGraphNode and thus avoid this class.
+ * However we found the implementation was easier to understand in this form since it makes
+ * the result of \p Extract unambiguous.)
+ */
+class SubSubGraphNode : public Object {
+ public:
+  /*!
+   * \brief The nested sub-graph. Held as an ObjectRef because SubGraph is only forward-declared
+   * at this point; use \p sub_graph() for typed access.
+   */
+  ObjectRef /* actually SubGraph */ sub_graph_obj_;
+  /*! \brief Attributes (possibly empty) to attach to the extracted function. */
+  FunctionAttrsMap attrs_;
+
+  void VisitAttrs(AttrVisitor* v);
+
+  /*! \brief Returns the nested sub-graph, downcast from \p sub_graph_obj_. */
+  SubGraph sub_graph() const;
+
+  /*! \brief Comparison and hashing. operator< is used to sort sub-sub-graphs. */
+  bool operator==(const SubSubGraphNode& that) const;
+  bool operator!=(const SubSubGraphNode& that) const { return !(*this == that); }
+  bool operator<(const SubSubGraphNode& that) const;
+  size_t hash() const;
+
+  std::string ToString() const;
+
+  /*!
+   * \brief Returns the function representing this sub-sub-graph within the overall expression
+   * represented by \p dataflow_graph:
+   *  - All sub-graph inputs become parameters.
+   *  - All sub-graph outputs become function results (either directly or as a field in a tuple).
+   *  - The function has attrs_ for attributes (which may be empty).
+   *  - The function body accounts for any rewrites implied by the nested sub-graph.
+   */
+  Function Extract(const DataflowGraph& dataflow_graph) const;
+
+  /*!
+   * \brief Returns \p expr rewritten to encode the partitioning implied by this sub-sub-graph.
+   *
+   * It is valid for \p expr to not be the same as \p dataflow_graph.expr(), however all nodes
+   * inside this sub-sub-graph must correspond to nodes shared between \p dataflow_graph.expr() and
+   * \p expr. See \p SubGraph::ParallelRewrite below.
+   */
+  Expr Rewrite(const DataflowGraph& dataflow_graph, const Expr& expr) const;
+
+  static constexpr const char* _type_key = "relay.collage.SubSubGraph";
+  TVM_DECLARE_FINAL_OBJECT_INFO(SubSubGraphNode, Object);
+};
+
+class SubSubGraph : public ObjectRef {
+ public:
+  /*! \brief Wraps \p sub_graph as a sub-sub-graph whose extracted function will carry \p attrs. */
+  SubSubGraph(SubGraph sub_graph, FunctionAttrsMap attrs);
+
+  /*!
+   * \brief Returns copy of this sub-sub-graph with all indexes substituted according to \p subst,
+   * whose range is w.r.t. \p new_dataflow_graph.
+   */
+  SubSubGraph Subst(const DataflowGraph& new_dataflow_graph,
+                    const std::unordered_map<PostDfsIndex, PostDfsIndex>& subst) const;
+
+  /*!
+   * \brief Returns true if this can be safely unioned with \p that, ie both carry the same
+   * attributes (compared structurally) and those attributes do not include "Composite".
+   */
+  bool TriviallyUnionable(const SubSubGraph& that) const;
+
+  /*!
+   * \brief Returns the disjoint union of this and \p that sub-sub graphs, which must agree on
+   * their attributes.
+   */
+  SubSubGraph DisjointUnion(const DataflowGraph& dataflow_graph, const SubSubGraph& that) const;
+
+  /*!
+   * \brief Returns \p expr rewritten according to all the given sub-sub-graphs. The sub-sub-graphs
+   * can be given in any order, but must be disjoint.
+   *
+   * It is valid for \p expr to not be the same as \p dataflow_graph.expr(), however all nodes
+   * inside the sub-sub-graphs must correspond to nodes shared between \p dataflow_graph.expr() and
+   * \p expr. See \p SubGraph::ParallelRewrite below.
+   */
+  static Expr ParallelRewrite(const DataflowGraph& dataflow_graph, const Expr& expr,
+                              std::vector<SubSubGraph> sub_sub_graphs);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(SubSubGraph, ObjectRef, SubSubGraphNode);
+};
+
+/*! \brief An array of sub-sub-graphs. */
+using SubSubGraphs = Array<SubSubGraph>;
+
+/*!
+ * \brief A compact representation of a sub-graph within an (implied) overall Relay expression.
+ *
+ * Sub-graphs can be used to represent partitions/kernels/composite functions without having to
+ * pay the cost of constructing or rewriting any expressions. We also allow 'extracting' a
+ * function to use for measuring a partition/kernel's latency independently from 'rewriting'
+ * the overall Relay expression since only a tiny subset of candidate partitions will end up being
+ * needed after Collage has completed its search.
+ *
+ * We expect O(thousands) of sub-graphs to be in flight while processing a given model, so are
+ * mindful of space overhead.
+ *
+ * A sub-graph classifies every dataflow node of the overall expression as either 'inside' or
+ * 'outside' the sub-graph. Obviously not all such divisions make sense, for example it is not
+ * valid for an inside node to feed into another inside node via outside nodes. We provide the
+ * \p IsValid method to check for validity, and \p SubGraphConfig to control which validity rules
+ * apply (such as maximum depth).
+ *
+ * We generally work with the \p DataflowGraph representation of the overall Relay expression
+ * rather than the expression itself. We use the post-dfs visit index to uniquely refer to
+ * expression nodes.
+ *
+ * As well as 'inside' and 'outside' we have four other flavors of dataflow nodes, all uniquely
+ * determined from the 'inside' nodes:
+ *  - 'entry' nodes are those inside with at least one dataflow input outside.
+ *  - 'exit' nodes are  those inside with at least one dataflow output outside, or which
+ *    are considered 'external' in the underlying dataflow graph (eg because they represent
+ *    the result of the overall function).
+ *  - 'input' nodes are those outside with at least one dataflow output inside.
+ *  - 'output' nodes are those outside with at least one dataflow input inside.
+ * Index sets for these are cached with the sub-graph for performance.
+ *
+ * It is valid to have multiple entry nodes (we can bind a parameter for each). It may be valid to
+ * have multiple exit nodes (we can build a tuple of all such). It may be valid to have exit nodes
+ * which also contribute to other inside nodes (ie represent a 'tap' on an intermediate result).
+ *
+ * Sub-graphs are closed under:
+ *  - Disjoint union.
+ *  - Wrapping by a function with given attributes (see \p SubSubGraph above). This can be used
+ *    to encode "Composite" functions, or to represent a candidate kernel within a "Primitive"
+ *    function. (By combining 'wrapping' with 'union' we can encode, eg, 'this sub-graph should
+ *    be placed inside a primitive function which itself may have calls to composite functions).
+ *  - Substitution, which allows a sub-graph w.r.t. one dataflow graph to be transformed to
+ *    match some other (typically smaller) dataflow graph.
+ *
+ * See the subclasses of \p PartitionRule for how sub-graphs are built and combined during Collage
+ * search.
+ *
+ * To support some of the \p OpPatternKind-based fusion rule processing we give sub-graphs
+ * a kind, which is generally the maximum of the kinds of all the operator calls appearing
+ * inside it. We also give sub-graphs a (not necessarily unique) label to help debugging
+ * and guide the selection of global symbol names.
+ */
+class SubGraphNode : public Object {
+ public:
+  /*!
+   * \brief Which sub-expressions are inside the sub-graph (using their post-dfs indexes w.r.t.
+   * the implied DataflowGraph).
+   */
+  IndexSet inside_;
+
+  /*!
+   * \brief Index of first and last inside nodes.
+   *
+   * Cached for performance, uniquely determined by inside_.
+   */
+  PostDfsIndex first_inside_index_ = 0;
+  PostDfsIndex last_inside_index_ = 0;
+
+  /*!
+   * \brief Which sub-expressions are entry/exit/input/output for this sub-graph.
+   *
+   * Cached for performance, uniquely determined by inside_.
+   */
+  IndexSet entry_;
+  IndexSet exit_;
+  IndexSet input_;
+  IndexSet output_;
+
+  /*!
+   * \brief Maximum depth of any dataflow path from an entry to an output sub-expression.
+   *
+   * Cached for performance, uniquely determined by inside_.
+   */
+  size_t max_depth_ = 0;
+
+  /*!
+   * \brief The \p OpPatternKind summarizing the input/output behavior of the sub-graph.
+   *
+   * A sub-graph consisting of a single Relay expression node is given kind:
+   *  - For Call to a Relay operator, the "TOpPattern" attribute of that operator (provided the
+   *    call does not involve data-dependent dynamic shapes).
+   *  - For Call to Relay Function, the "TOpPattern" attribute of the function (provided it has
+   *    that attribute)
+   *  - For Constants, \p kElemWise.
+   *  - For Tuple and tuple projections, \p kInjective (provided all tuple fields are of tensor
+   *    type)
+   *  - All other nodes \p kOpaque.
+   * Sub-graphs with more than one node have the maximum of the kind of each node.
+   *
+   * Cached for performance, uniquely determined by inside_.
+   */
+  OpPatternKind kind_ = kOpaque;
+
+  /*!
+   * \brief A label for the sub-graph. Not guaranteed to be unique, but is a human-readable summary
+   * of the sub-graph which can help with debugging and guide the selection of global symbol names.
+   */
+  String label_;
+
+  /*!
+   * \brief Sub-sub-graphs of this sub-graph which must be represented by functions. These must
+   * be disjoint, but it's ok for this sub-graph to have nodes not inside any sub-sub-graph.
+   */
+  SubSubGraphs sub_sub_graphs_;
+
+  void VisitAttrs(AttrVisitor* v);
+
+  // TODO(mbs): 'Anchor nodes' and rules for unioning them.
+  // In FuseOps it's just the unique kEWiseFusable node, if any.
+  // I'd like to allow writing vertical fusion rules, eg if two candidates are directly
+  // connected and have nn.conv2d anchors allow their join.
+  // I'd also like to allow horizontal fusion rules, eg if two candidates are not directly
+  // connected but could be joined without producing invalid (eg cyclic) and have nn.conv2d anchors
+  // then do so. Come back to this.
+
+  /*! \brief Number of nodes in overall dataflow graph. */
+  size_t overall_size() const { return inside_.end_index(); }
+
+  bool IsEmpty() const { return inside_.IsZero(); }
+
+  /*! \brief Number of nodes in sub-graph. */
+  size_t Size() const { return inside_.PopCount(); }
+
+  /*!
+   * \brief Returns the dataflow nodes downstream of all exit nodes.
+   */
+  IndexSet Downstream(const DataflowGraph& dataflow_graph) const;
+
+  /*!
+   * \brief Returns true if this sub-graph is valid. Ie:
+   *  - no output of the sub-graph can flow to any input of the sub-graph (otherwise we'd end up
+   *    with a dataflow cycle when we partition).
+   *  - all inputs and outputs of the sub-graph are in the same scope, ie not separated by
+   *    control flow (otherwise there'd be no consistent program point at which to eval the
+   *    partitioned function).
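+   *  - if config.max_exits is non-zero, no more than config.max_exits exit nodes are present.
+   *  - if config.max_max_depth is non-zero, the maximum path depth does not exceed it.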
+   *  - if config.allow_taps is false, no inside node has outputs to nodes both inside and
+   *    outside the sub-graph.
+   */
+  bool IsValid(const DataflowGraph& dataflow_graph, const SubGraphConfig& config) const;
+
+  /*!
+   * \brief Returns this sub-graph extracted as a stand-alone function. The function will have
+   * no attributes, and is suitable for building and profiling by the \p CostEstimator.
+   */
+  Function ExtractAsFunction(const DataflowGraph& dataflow_graph) const;
+
+  /*!
+   * \brief Returns \p expr rewritten to encode the partitioning implied by this sub-graph.
+   *
+   * It is valid for \p expr to not be the same as \p dataflow_graph.expr(), however all nodes
+   * inside this sub-graph must correspond to nodes shared between \p dataflow_graph.expr() and
+   * \p expr. See \p SubGraph::ParallelRewrite below.
+   */
+  Expr Rewrite(const DataflowGraph& dataflow_graph, const Expr& expr) const;
+
+  std::string ToString() const;
+
+  bool operator==(const SubGraphNode& that) const;
+  bool operator!=(const SubGraphNode& that) const { return !(*this == that); }
+  bool operator<(const SubGraphNode& that) const;
+  size_t hash() const;
+
+ private:
+  /*! \brief Initialize the entry/exit/input/output sets given the inside and \p dataflow_graph. */
+  void Init(const DataflowGraph& dataflow_graph);
+
+  /*! \brief Calculates and returns the maximum path depth. */
+  size_t MaxDepth(const DataflowGraph& dataflow_graph) const;
+
+  /*! \brief Returns true if any (input/output) of node is (outside/inside) the sub-graph. */

Review Comment:
   ```suggestion
     /*! \brief Returns true if any (input/output) of node is (outside/inside) the sub-graph. */
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] mbs-octoml commented on a diff in pull request #11981: [Collage] SubGraphs

Posted by GitBox <gi...@apache.org>.
mbs-octoml commented on code in PR #11981:
URL: https://github.com/apache/tvm/pull/11981#discussion_r918201673


##########
src/relay/collage/sub_graph.h:
##########
@@ -0,0 +1,451 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/collage/sub_graph.h
+ * \brief Represents a sub-graph of an overall Relay expression.
+ */
+
+#ifndef TVM_RELAY_COLLAGE_SUB_GRAPH_H_
+#define TVM_RELAY_COLLAGE_SUB_GRAPH_H_
+
+#include <tvm/ir/transform.h>
+#include <tvm/relay/op_attr_types.h>
+
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../ir/dataflow_matcher_impl.h"
+#include "../ir/indexed_graph.h"
+#include "./dataflow_graph.h"
+#include "./index_set.h"
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+/*! \brief Returns operator pattern kind as single-letter string. */
+std::string KindToString(OpPatternKind kind);
+
+/*!
+ * \brief Returns a kind and label for the single \p sub_expr, ignoring it's sub-sub expressions.
+ */
+std::pair<OpPatternKind, std::string> SubExprKindAndLabel(const Expr& sub_expr);
+
+/*!
+ * \brief Returns a kind and label for all the nodes in \p inside.
+ */
+std::pair<OpPatternKind, std::string> SubGraphKindAndLabel(const DataflowGraph& dataflow_graph,
+                                                           const IndexSet& inside);
+
+/*!
+ * \brief Returns the index set representing all the sub-expression matched by \p matcher.
+ */
+IndexSet MatcherToIndexSet(const DFPatternMatcher& matcher);
+
+/*!
+ * \brief Configuration controlling which sub-graphs are considered valid.
+ */
+struct SubGraphConfig {
+  /*! \brief Maximum number of exit nodes in the sub-graph, or zero if no limit. */
+  size_t max_exits = 0;
+  /*!
+   * \brief Whether a node inside the sub-graph may flow to nodes both inside and outside
+   * the sub-graph (which we call a 'tap'). Note that it is still possible to have multiple outputs
+   * even with this flag false.
+   */
+  bool allow_taps = false;
+  /*!
+   * \brief Maximum allowed maximum depth, or zero if no-limit.
+   */
+  size_t max_max_depth = 0;
+
+  std::string ToString() const;
+};
+
+class SubGraph;
+using FunctionAttrsMap = Map<String, ObjectRef>;
+
+/*!
+ * \brief A sub-sub graph is a sub-graph which is to be nested inside a function as part of some
+ * enclosing sub-graph.
+ *
+ * Extraction yields a function with input nodes replaced by parameters and exit nodes in the
+ * function result. Rewriting replaces the sub-graph with a call to that function, and all
+ * outputs with (projections from) the call result.
+ *
+ * (Note that it's tempting to move attrs_ into \p SubGraphNode and thus avoid this class.
+ * However we found the implementation was easier to understand in this form since it makes
+ * the result of \p Extract unambiguous.)
+ */
+class SubSubGraphNode : public Object {

Review Comment:
   Ah, thank you! I really didn't like that name either.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] mbs-octoml commented on a diff in pull request #11981: [Collage] SubGraphs

Posted by GitBox <gi...@apache.org>.
mbs-octoml commented on code in PR #11981:
URL: https://github.com/apache/tvm/pull/11981#discussion_r918212927


##########
src/relay/collage/sub_graph.h:
##########
@@ -0,0 +1,451 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/collage/sub_graph.h
+ * \brief Represents a sub-graph of an overall Relay expression.
+ */
+
+#ifndef TVM_RELAY_COLLAGE_SUB_GRAPH_H_
+#define TVM_RELAY_COLLAGE_SUB_GRAPH_H_
+
+#include <tvm/ir/transform.h>
+#include <tvm/relay/op_attr_types.h>
+
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../ir/dataflow_matcher_impl.h"
+#include "../ir/indexed_graph.h"
+#include "./dataflow_graph.h"
+#include "./index_set.h"
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+/*! \brief Returns operator pattern kind as single-letter string. */
+std::string KindToString(OpPatternKind kind);
+
+/*!
+ * \brief Returns a kind and label for the single \p sub_expr, ignoring it's sub-sub expressions.
+ */
+std::pair<OpPatternKind, std::string> SubExprKindAndLabel(const Expr& sub_expr);
+
+/*!
+ * \brief Returns a kind and label for all the nodes in \p inside.
+ */
+std::pair<OpPatternKind, std::string> SubGraphKindAndLabel(const DataflowGraph& dataflow_graph,
+                                                           const IndexSet& inside);
+
+/*!
+ * \brief Returns the index set representing all the sub-expression matched by \p matcher.
+ */
+IndexSet MatcherToIndexSet(const DFPatternMatcher& matcher);
+
+/*!
+ * \brief Configuration controlling which sub-graphs are considered valid.
+ */
+struct SubGraphConfig {
+  /*! \brief Maximum number of exit nodes in the sub-graph, or zero if no limit. */
+  size_t max_exits = 0;
+  /*!
+   * \brief Whether a node inside the sub-graph may flow to nodes both inside and outside
+   * the sub-graph (which we call a 'tap'). Note that it is still possible to have multiple outputs
+   * even with this flag false.
+   */
+  bool allow_taps = false;
+  /*!
+   * \brief Maximum allowed maximum depth, or zero if no-limit.
+   */
+  size_t max_max_depth = 0;
+
+  std::string ToString() const;
+};
+
+class SubGraph;
+using FunctionAttrsMap = Map<String, ObjectRef>;
+
+/*!
+ * \brief A sub-sub graph is a sub-graph which is to be nested inside a function as part of some
+ * enclosing sub-graph.
+ *
+ * Extraction yields a function with input nodes replaced by parameters and exit nodes in the
+ * function result. Rewriting replaces the sub-graph with a call to that function, and all
+ * outputs with (projections from) the call result.
+ *
+ * (Note that it's tempting to move attrs_ into \p SubGraphNode and thus avoid this class.
+ * However we found the implementation was easier to understand in this form since it makes
+ * the result of \p Extract unambiguous.)
+ */
+class SubSubGraphNode : public Object {
+ public:
+  /*! \brief The nested sub-graph. */
+  ObjectRef /* actually SubGraph */ sub_graph_obj_;
+  /*! \brief Attributes (possibly empty) to attach to the extracted function. */
+  FunctionAttrsMap attrs_;
+
+  void VisitAttrs(AttrVisitor* v);
+
+  SubGraph sub_graph() const;
+
+  bool operator==(const SubSubGraphNode& that) const;
+  bool operator!=(const SubSubGraphNode& that) const { return !(*this == that); }
+  bool operator<(const SubSubGraphNode& that) const;
+  size_t hash() const;
+
+  std::string ToString() const;
+
+  /*!
+   * \brief Returns the function representing this sub-sub-graph within the overall expression
+   * represented by \p dataflow_graph:
+   *  - All sub-graph inputs become parameters.
+   *  - All sub-graph outputs become function results (either directly or as a field in a tuple).
+   *  - The function has attrs_ for attributes (which may be empty).
+   *  - The function body accounts for any rewrites implied by the nested sub-graph.
+   */
+  Function Extract(const DataflowGraph& dataflow_graph) const;
+
+  /*!
+   * \brief Returns \p expr rewritten to encode the partitioning implied by this sub-sub-graph.
+   *
+   * It is valid for \p expr to not be the same as \p dataflow_graph.expr(), however all nodes
+   * inside this sub-sub-graph must correspond to nodes shared between \p dataflow_graph.expr() and
+   * \p expr. See \p SubGraph::ParallelRewrite below.
+   */
+  Expr Rewrite(const DataflowGraph& dataflow_graph, const Expr& expr) const;
+
+  static constexpr const char* _type_key = "relay.collage.SubSubGraph";
+  TVM_DECLARE_FINAL_OBJECT_INFO(SubSubGraphNode, Object);
+};
+
+class SubSubGraph : public ObjectRef {
+ public:
+  SubSubGraph(SubGraph sub_graph, FunctionAttrsMap attrs);
+
+  /*!
+   * \brief Returns copy of this sub-sub-graph with all indexes substituted according to \p subst,
+   * whose range is w.r.t. \p new_dataflow_graph.
+   */
+  SubSubGraph Subst(const DataflowGraph& new_dataflow_graph,
+                    const std::unordered_map<PostDfsIndex, PostDfsIndex>& subst) const;
+
+  /*!
+   * \brief Returns true if this can be safely unioned.
+   */
+  bool TriviallyUnionable(const SubSubGraph& that) const;
+
+  /*!
+   * \brief Returns the disjoin union of this and \p that sub-sub graphs, which must agree on
+   * their attributes.
+   */
+  SubSubGraph DisjointUnion(const DataflowGraph& dataflow_graph, const SubSubGraph& that) const;
+
+  /*!
+   * \brief Returns \p expr rewritten according to all the given sub-sub-graphs. The sub-sub-graphs
+   * can be given in any order, but must be disjoint.
+   *
+   * It is valid for \p expr to not be the same as \p dataflow_graph.expr(), however all nodes
+   * inside the sub-sub-graphs must correspond to nodes shared between \p dataflow_graph.expr() and
+   * \p expr. See \p SubGraph::ParallelRewrite below.
+   */
+  static Expr ParallelRewrite(const DataflowGraph& dataflow_graph, const Expr& expr,
+                              std::vector<SubSubGraph> sub_sub_graphs);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(SubSubGraph, ObjectRef, SubSubGraphNode);
+};
+
+using SubSubGraphs = Array<SubSubGraph>;
+
+/*!
+ * \brief A compact representation of a sub-graph within an (implied) overall Relay expression.
+ *
+ * Sub-graphs can be used to represent partitions/kernels/composite functions without having to
+ * pay the cost of constructing or rewriting any expressions. We also allow 'extracting' a
+ * function to use for measuring a partition/kernel's latency independently from 'rewriting'
+ * the overall Relay expression since only a tiny subset of candidate partitions will end up being
+ * needed after Collage has completed its search.
+ *
+ * We expect O(thousands) of sub-graphs to be in flight while processing a given model, so are
+ * mindful of space overhead.
+ *
+ * A sub-graph classifies every dataflow node of the overall expression as either 'inside' or
+ * 'outside' the sub-graph. Obviously not all such divisions make sense, for example it is not
+ * valid for an inside node to feed into another inside node via outside nodes. We provide the
+ * \p IsValid method to check for validity, and \p SubGraphConfig to control which validity rules
+ * apply (such as maximum depth).
+ *
+ * We generally work with the \p DataflowGraph representation of the overall Relay expression
+ * rather than the expression itself. We use the post-dfs visit index to uniquely refer to
+ * expression nodes.
+ *
+ * As well as 'inside' and 'outside' we have four other flavors of dataflow nodes, all uniquely
+ * determined from the 'inside' nodes:
+ *  - 'entry' nodes are those inside with at least one dataflow input outside.
+ *  - 'exit' nodes are  those inside with at least one dataflow output outside, or which
+ *    are considered 'external' in the underlying dataflow graph (eg because they represent
+ *    the result of the overall function).
+ *  - 'input' nodes are those outside with at least one dataflow output inside.
+ *  - 'output' nodes are those outside with at least one dataflow input inside.
+ * Index sets for these are cached with the sub-graph for performance.
+ *
+ * It is valid to have multiple entry nodes (we can bind a parameter for each). It may be valid to
+ * have multiple exit nodes (we can build a tuple of all such). It may be valid to have exit nodes
+ * which also contribute to other inside nodes (ie represent a 'tap' on an intermediate result).
+ *
+ * Sub-graphs are closed under:
+ *  - Disjoint union.
+ *  - Wrapping by a function with given attributes (see \p SubSubGraph above). This can be used
+ *    to encode "Composite" functions, or to represent a candidate kernel within a "Primitive"
+ *    function. (By combining 'wrapping' with 'union' we can encode, eg, 'this sub-graph should
+ *    be placed inside a primitive function which itself may have calls to composite functions).
+ *  - Substitution, which allows a sub-graph w.r.t. one dataflow graph to be transformed to
+ *    match some other (typically smaller) dataflow graph.
+ *
+ * See the subclasses of \p PartitionRule for how sub-graphs are built and combined during Collage
+ * search.
+ *
+ * To support some of the \p OpPatternKind-based fusion rule processing we give sub-graphs
+ * a kind, which is generally the maximum of the kinds of all the operator calls appearing
+ * inside it. We also given sub-graphs a (not necessarily unique) label to help debugging
+ * and guide the selection of global symbol names.
+ */
+class SubGraphNode : public Object {
+ public:
+  /*!
+   * \brief Which sub-expressions are inside the sub-graph (using their post-dfs indexes w.r.t.
+   * the implied DataflowGraph).
+   */
+  IndexSet inside_;
+
+  /*!
+   * \brief Index of first and last inside nodes.
+   *
+   * Cached for performance, uniquely determined by inside_.
+   */
+  PostDfsIndex first_inside_index_ = 0;
+  PostDfsIndex last_inside_index_ = 0;
+
+  /*!
+   * \brief Which sub-expressions are entry/exit/input/output for this sub-graph.
+   *
+   * Cached for performance, uniquely determined by inside_.
+   */
+  IndexSet entry_;
+  IndexSet exit_;
+  IndexSet input_;
+  IndexSet output_;
+
+  /*!
+   * \brief Maximum depth of any dataflow path from an entry to an output sub-expression.
+   *
+   * Cached for performance, uniquely determined by inside_.
+   */
+  size_t max_depth_ = 0;
+
+  /*!
+   * \brief The \p OpPatternKind summarizing the input/output behavior of the sub-graph.
+   *
+   * A sub-graph consisting of a single Relay expression node is given kind:
+   *  - For Call to a Relay operator, the "TOpPattern" attribute of that operator (provided the
+   *    call does not involve data-dependent dynamic shapes).
+   *  - For Call to Relay Function, the "TOpPattern" attribute of the function (provided it has
+   *    that attribute)
+   *  - For Constants, \p kElemWise.
+   *  - For Tuple and tuple projections, \p kInjective (provided all tuple fields are of tensor
+   *    type)
+   *  - All other nodes \p kOpaque.
+   * Sub-graphs with more than one node have the maximum of the kind of each node.
+   *
+   * Cached for performance, uniquely determined by inside_.
+   */
+  OpPatternKind kind_ = kOpaque;
+
+  /*!
+   * \brief A label for the sub-graph. Not guaranteed to be unique, but is a human-readable summary
+   * of the sub-graph which can help with debugging and guide the selection of global symbol names.
+   */
+  String label_;
+
+  /*!
+   * \brief Sub-sub-graphs of this sub-graph which must be represented by functions. These must
+   * be disjoint, but it's ok for this sub-graph to have nodes not inside any sub-sub-graph.
+   */
+  SubSubGraphs sub_sub_graphs_;
+
+  void VisitAttrs(AttrVisitor* v);
+
+  // TODO(mbs): 'Anchor nodes' and rules for unioning them.
+  // In FuseOps it's just the unique kEWiseFusable node, if any.
+  // I'd like to allow writing vertical fusion rules, eg if two candidates are directly
+  // connected and have nn.conv2d anchors allow their join.
+  // I'd also like to allow horizontal fusion rules, eg if two candidates are not directly
+  // connected but could be joined without producing invalid (eg cyclic) and have nn.conv2d anchors
+  // then do so. Come back to this.
+
+  /*! \brief Number of nodes in overall dataflow graph. */
+  size_t overall_size() const { return inside_.end_index(); }
+
+  bool IsEmpty() const { return inside_.IsZero(); }
+
+  /*! \brief Number of nodes in sub-graph. */
+  size_t Size() const { return inside_.PopCount(); }
+
+  /*!
+   * \brief Returns the dataflow nodes downstream of all exit nodes.
+   */
+  IndexSet Downstream(const DataflowGraph& dataflow_graph) const;
+
+  /*!
+   * \brief Returns true if this sub-graph is valid. Ie:
+   *  - no output of the sub-graph can flow to any input of the sub-graph (otherwise we'd end up
+   *    with a dataflow cycle when we partition).
+   *  - all inputs and outputs of the sub-graph are in the same scope, ie not separated by
+   *    control flow (otherwise there'd be no consistent program point at which to eval the
+   *    partitioned function).
+   *  - no more than config.max_outputs outputs are require.
+   *  - if config.allow_taps is false, no inside node has outputs to nodes both inside and
+   *    outside the sub-graph.
+   */
+  bool IsValid(const DataflowGraph& dataflow_graph, const SubGraphConfig& config) const;
+
+  /*!
+   * \brief Returns this sub-graph extracted as a stand-alone function. The function will have
+   * no attributes, and is suitable for building and profiling by the \p CostEstimator.
+   */
+  Function ExtractAsFunction(const DataflowGraph& dataflow_graph) const;
+
+  /*!
+   * \brief Returns \p expr rewritten to encode the partitioning implied by this sub-graph.
+   *
+   * It is valid for \p expr to not be the same as \p dataflow_graph.expr(), however all nodes
+   * inside this sub-graph must correspond to nodes shared between \p dataflow_graph.expr() and
+   * \p expr. See \p SubGraph::ParallelRewrite below.
+   */
+  Expr Rewrite(const DataflowGraph& dataflow_graph, const Expr& expr) const;
+
+  std::string ToString() const;
+
+  bool operator==(const SubGraphNode& that) const;
+  bool operator!=(const SubGraphNode& that) const { return !(*this == that); }
+  bool operator<(const SubGraphNode& that) const;
+  size_t hash() const;
+
+ private:
+  /*! \brief Initialize the entry/exit/input/output sets given the inside and \p dataflow_graph. */
+  void Init(const DataflowGraph& dataflow_graph);
+
+  /*! \brief Calculates and returns the maximum path depth. */
+  size_t MaxDepth(const DataflowGraph& dataflow_graph) const;

Review Comment:
   From now on every time I come up with some awkward stuttering variable name I'll think "what would Matthew write?" :-)



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] SebastianBoblest commented on a diff in pull request #11981: [Collage] SubGraphs

Posted by GitBox <gi...@apache.org>.
SebastianBoblest commented on code in PR #11981:
URL: https://github.com/apache/tvm/pull/11981#discussion_r914638359


##########
src/relay/collage/sub_graph.h:
##########
@@ -0,0 +1,451 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/collage/sub_graph.h
+ * \brief Represents a sub-graph of an overall Relay expression.
+ */
+
+#ifndef TVM_RELAY_COLLAGE_SUB_GRAPH_H_
+#define TVM_RELAY_COLLAGE_SUB_GRAPH_H_
+
+#include <tvm/ir/transform.h>
+#include <tvm/relay/op_attr_types.h>
+
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../ir/dataflow_matcher_impl.h"
+#include "../ir/indexed_graph.h"
+#include "./dataflow_graph.h"
+#include "./index_set.h"
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+/*! \brief Returns operator pattern kind as single-letter string. */
+std::string KindToString(OpPatternKind kind);
+
+/*!
+ * \brief Returns a kind and label for the single \p sub_expr, ignoring it's sub-sub expressions.

Review Comment:
   ```suggestion
    * \brief Returns a kind and label for the single \p sub_expr, ignoring its sub-sub expressions.
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] mbs-octoml commented on a diff in pull request #11981: [Collage] SubGraphs

Posted by GitBox <gi...@apache.org>.
mbs-octoml commented on code in PR #11981:
URL: https://github.com/apache/tvm/pull/11981#discussion_r918198870


##########
src/relay/collage/sub_graph.cc:
##########
@@ -0,0 +1,1032 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/collage/sub_graph.cc
+ * \brief Represents a sub-graph of an overall Relay expression.
+ */
+
+#include "./sub_graph.h"
+
+#include <tvm/relay/transform.h>
+
+#include "../../support/scalars.h"
+#include "../transforms/pass_utils.h"
+#include "./utils.h"
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+namespace {
+
+class Extractor;
+
+/*!
+ * \brief Helper class for rewriting expressions to replace a sub-graph according to the
+ * given extractor.
+ */
+class Rewriter : public ExprMutator {
+ public:
+  explicit Rewriter(const Extractor* extractor) : extractor_(extractor) {}
+
+  Expr VisitExpr(const Expr& expr) final;
+
+ private:
+  /*! \brief Already prepared extractor which will guide the rewrite. */
+  const Extractor* extractor_;
+};
+
+/*! \brief Helper class for extracting matched sub-graphs from the overall expression. */
+class Extractor : public ExprMutator {
+ public:
+  Extractor(const DataflowGraph* dataflow_graph, const SubGraphNode* sub_graph,
+            FunctionAttrsMap opt_attrs)
+      : dataflow_graph_(dataflow_graph), sub_graph_(sub_graph), opt_attrs_(std::move(opt_attrs)) {
+    ICHECK_EQ(dataflow_graph_->size(), sub_graph_->overall_size());
+  }
+
+  const DataflowGraph& dataflow_graph() const { return *dataflow_graph_; }
+
+  /*!
+   * \brief Collect the parameters and output expressions for the function representing
+   * the sub-graph.
+   */
+  void Extract() {
+    ICHECK(!sub_graph_->IsEmpty());
+    VLOG(2) << "Extracting " << sub_graph_->ToString();
+    const bool for_function = opt_attrs_.defined();
+
+    //  In reverse dataflow order...
+    for (PostDfsIndex i = dataflow_graph_->size(); i > 0; --i) {
+      PostDfsIndex index = i - 1;
+      if (!sub_graph_->inside_[index]) {
+        // Node is outside sub-graph.
+        continue;
+      }
+      VLOG(2) << "index " << index;
+      auto node = dataflow_graph_->index_to_node(index);
+      if (sub_graph_->exit_[node->index_] || node->is_external_ || memo_.count(node->ref()) == 0) {
+        // This sub-expression is:
+        //  - inside the sub-graph and needed outside the sub-graph. So it must contribute to an
+        //    output (even if we've already visited it while constructing an output from a
+        //    downstream sub-expression).
+        //  - not yet visited, in which case it must still be considered an 'output' so it will
+        //    be evaluated for any possible side effects.
+        Expr output = VisitExpr(GetRef<Expr>(node->node_ref_));
+        VLOG(2) << "index " << index << " added as output:\n"
+                << PrettyPrint(output) << "\nat " << outputs_.size();
+        expr_to_output_index_.emplace(node->node_ref_, outputs_.size());
+        outputs_.emplace_back(std::move(output));
+        output_types_.emplace_back(node->node_ref_->checked_type());
+      }
+    }
+    ICHECK(!outputs_.empty());
+
+    // Reverse the outputs so as to preserve the original evaluation order.
+    std::reverse(outputs_.begin(), outputs_.end());
+    std::reverse(output_types_.begin(), output_types_.end());
+    for (auto& kv : expr_to_output_index_) {
+      kv.second = static_cast<int>(outputs_.size()) - 1 - kv.second;
+    }
+
+    // Build a 'body' expression to represent the extracted sub-graph. If we have multiple
+    // outputs we'll place them in a tuple.
+    Type body_type;
+    Expr body;
+    if (outputs_.size() > 1) {
+      body_type = TupleType(output_types_);
+      body = Tuple(outputs_);
+      body->checked_type_ = body_type;
+    } else {
+      body_type = output_types_.front();
+      body = outputs_.front();
+    }
+
+    // Re-express all the sub-sub-graphs in terms of the body.
+    DataflowGraph body_dataflow_graph(body);
+    std::vector<SubSubGraph> sub_sub_graphs;
+    IndexSubst subst = MakeIndexSubst(body_dataflow_graph);
+    for (const auto& sub_sub_graph : sub_graph_->sub_sub_graphs_) {
+      sub_sub_graphs.emplace_back(sub_sub_graph.Subst(body_dataflow_graph, subst));
+    }
+
+    // Sweep backwards through the body, rewriting to account for each sub-sub-graph.
+    body = SubSubGraph::ParallelRewrite(body_dataflow_graph, body, std::move(sub_sub_graphs));
+
+    if (for_function) {
+      // Rewrite so all input nodes are now conveyed via call arguments to a new function.
+      Array<Type> arg_types;
+      arg_types.reserve(params_.size());
+      for (const auto& param : params_) {
+        arg_types.push_back(param->checked_type());
+      }
+      extracted_ = Function(std::move(params_), std::move(body), body_type,
+                            /*ty_params=*/{}, DictAttrs(opt_attrs_));
+      extracted_->checked_type_ =
+          FuncType(std::move(arg_types), body_type, /*type_params=*/{}, /*type_constraints=*/{});
+      body = Call(extracted_, std::move(args_));
+      body->checked_type_ = body_type;
+    } else {
+      // Don't do anything with the inputs.
+      extracted_ = body;
+    }
+
+    // Setup the output substitution.
+    for (const auto& kv : expr_to_output_index_) {
+      Expr expr;
+      if (outputs_.size() == 1) {
+        expr = body;
+      } else if (for_function) {
+        expr = TupleGetItem(body, kv.second);
+        expr->checked_type_ = output_types_[kv.second];
+      } else {
+        const auto* tuple_node = body.as<TupleNode>();
+        ICHECK(tuple_node);
+        expr = tuple_node->fields[kv.second];
+      }
+      VLOG(2) << "output " << dataflow_graph_->item_to_node(kv.first)->index_ << " is at index "
+              << kv.second << " (of " << outputs_.size() << " outputs)";
+      output_substitution_.emplace(kv.first, std::move(expr));
+    }
+  }
+
+  ////// Following members are valid only after Extract() has returned.
+
+  /*!
+   * \brief Returns the expression representing the extracted sub-graph. If opt_attrs_ is
+   * defined then will be a function.
+   */
+  Expr extracted() const { return extracted_; }
+
+  /*!
+   * \brief Returns the substitution to apply to all expression nodes in the overall expression
+   * so as to replace references to outputs of the sub-graph with their rewritten form.
+   */
+  const std::unordered_map<const ExprNode*, Expr>& output_substitution() const {
+    return output_substitution_;
+  }
+
+ private:
+  /*!
+   * \brief Returns a map from original index to new index for each node inside the sub-graph. Only
+   * valid after \p Extract has made its backwards dataflow sweep.
+   */
+  IndexSubst MakeIndexSubst(const DataflowGraph& new_dataflow_graph) const {
+    VLOG(2) << "building extractor substitution";
+    IndexSubst subst;
+    for (PostDfsIndex index : sub_graph_->inside_) {
+      auto orig_node = dataflow_graph_->index_to_node(index);
+      ICHECK_EQ(orig_node->index_, index);
+      auto itr = memo_.find(orig_node->ref());
+      ICHECK(itr != memo_.end());
+      auto new_node = new_dataflow_graph.item_to_node(itr->second);
+      VLOG(2) << orig_node->index_ << " |-> " << new_node->index_;
+      subst.emplace(orig_node->index_, new_node->index_);
+    }
+    return subst;
+  }
+
+  /*! \brief Returns true if \p expr is inside the sub-graph. */
+  bool inside(const Expr& expr) {
+    return sub_graph_->inside_[dataflow_graph_->item_to_node(expr)->index_];
+  }
+
+  /*!
+   * \brief Returns the variable uniquely representing \p expr, which should be
+   * an input node (ie outside the sub-graph but feeding into a node inside the sub-graph).
+   *
+   * It is valid for:
+   *  - An expression outside the sub-graph to be used multiple times inside the sub-graph.
+   *  - An expression outside the sub-graph to be used both inside and outside the sub-graph.
+   */
+  Var VarFor(const Expr& expr) {
+    ICHECK(!inside(expr));
+    ICHECK(opt_attrs_.defined());
+    auto itr = expr_to_param_.find(expr.get());
+    if (itr != expr_to_param_.end()) {
+      return itr->second;
+    }
+    auto fresh_var = Var("FunctionVar_" + std::to_string(params_.size()), expr->checked_type());
+    fresh_var->checked_type_ = expr->checked_type();
+    params_.push_back(fresh_var);
+    args_.push_back(expr);
+    expr_to_param_.emplace(expr.get(), fresh_var);
+    return fresh_var;
+  }
+
+  /*!
+   * \brief If \p expr is inside the sub-graph then return its rewritten form.
+   * If \p expr is outside the sub-graph then it must correspond to an input node:
+   *  - Inlinable inputs are returned directly (implicitly included in the extracted body).
+   *  - Otherwise, if opt_attrs_ is defined, return the variable representing it.
+   *  - Otherwise just return the expression directly.
+   *
+   * Should be called only on inputs to nodes which are inside the sub-graph.
+   */
+  Expr VisitExpr(const Expr& expr) final {
+    if (inside(expr)) {
+      return ExprMutator::VisitExpr(expr);
+    }
+    // Inlinable inputs, and all inputs when not extracting a function, stop the rewrite here.
+    if (CanInline(expr) || !opt_attrs_.defined()) {
+      return expr;
+    }
+    // Map to a function parameter.
+    return VarFor(expr);
+  }
+
+  /*!
+   * \brief Functions carrying a non-zero attr::kPrimitive attribute are returned untouched;
+   * all other functions are rewritten by the default ExprMutator traversal.
+   */
+  Expr VisitExpr_(const FunctionNode* function_node) override {
+    if (!function_node->HasNonzeroAttr(attr::kPrimitive)) {
+      return ExprMutator::VisitExpr_(function_node);
+    }
+    return GetRef<Function>(function_node);
+  }
+
+  //// Context fields, passed in constructor.
+
+  /*! \brief The dataflow graph corresponding to the overall expression. */
+  const DataflowGraph* dataflow_graph_;
+  /*! \brief The sub-graph of the above we are extracting. */
+  const SubGraphNode* sub_graph_;
+  /*!
+   * \brief Optional attributes if the sub-graph should be extracted as a function. When
+   * undefined, non-inlinable inputs are left in place rather than mapped to parameters
+   * (see VisitExpr).
+   */
+  FunctionAttrsMap opt_attrs_;
+
+  //// Result fields, available after Extract() called.
+
+  /*!
+   * \brief The extracted expression. If opt_attrs_ is defined this will be a function.
+   */
+  Expr extracted_;
+  /*!
+   * \brief Map from output nodes to corresponding expressions. If the sub-graph has more than
+   * one exit node then each entry will be a tuple projection.
+   */
+  std::unordered_map<const ExprNode*, Expr> output_substitution_;
+
+  //// Accumulator fields, built as we visit expressions.
+
+  /*! \brief (If opt_attrs_ is defined) Parameters representing input expression nodes. */
+  Array<Var> params_;
+  /*!
+   * \brief (If opt_attrs_ is defined) The input expression nodes for each of the above params_.
+   * Index-aligned with params_: VarFor appends to both together.
+   */
+  Array<Expr> args_;
+  /*!
+   * \brief (If opt_attrs_ is defined) Map from existing input expression nodes to the parameters
+   * in params_ which now represent them.
+   */
+  std::unordered_map<const ExprNode*, Var> expr_to_param_;
+  /*!
+   * \brief Accumulated new expressions which represent the exit nodes of the rewritten sub-graph.
+   * It is possible to have multiple outputs. It is possible one output also contributes to other
+   * outputs (ie the output is a 'tap').
+   */
+  std::vector<Expr> outputs_;
+  /*! \brief (If opt_attrs_ is defined) Types of original expressions corresponding to outputs_. */
+  std::vector<Type> output_types_;
+  /*!
+   * \brief Map from existing exit expression nodes to the index in outputs_ which should
+   * represent them in the rewritten overall expression.
+   */
+  std::unordered_map<const ExprNode*, int> expr_to_output_index_;
+};
+
+/*!
+ * \brief Replaces any output node of the extracted sub-graph with its substitute expression
+ * from the extractor; all other expressions are rewritten by the default ExprMutator.
+ */
+Expr Rewriter::VisitExpr(const Expr& expr) {
+  const auto& substitution = extractor_->output_substitution();
+  const auto hit = substitution.find(expr.get());
+  if (hit != substitution.end()) {
+    return hit->second;
+  }
+  return ExprMutator::VisitExpr(expr);
+}
+
+}  // namespace
+
+std::pair<OpPatternKind, std::string> SubExprKindAndLabel(const Expr& sub_expr) {
+  using KindAndLabel = std::pair<OpPatternKind, std::string>;
+
+  /*!
+   * \brief Computes the operator pattern kind and a short display label for a single
+   * sub-expression. Anything which cannot be classified is reported as kOpaque.
+   */
+  class Visitor : public ExprFunctor<KindAndLabel(const Expr&)> {
+   private:
+    /*! \brief True if every field of \p tuple_type_node is a tensor type. */
+    static bool AllFieldsAreTensors(const TupleTypeNode* tuple_type_node) {
+      return std::all_of(tuple_type_node->fields.begin(), tuple_type_node->fields.end(),
+                         [](const Type& type) { return type.as<TensorTypeNode>() != nullptr; });
+    }
+
+    /*! \brief Classifies a call whose callee is the operator \p op. */
+    static KindAndLabel KindAndLabelForOpCall(const CallNode* call_node, const Op& op) {
+      static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
+      if (fpattern.count(op) == 0) {
+        VLOG(1) << "no TOpPattern known for " << op->name << ", considering opaque";
+        return {kOpaque, op->name};
+      }
+      if (IsDynamic(call_node->checked_type()) && IsDataDependent(call_node)) {
+        VLOG(1) << "call has dynamic shape which is data-dependent, considering opaque";
+        return {kOpaque, op->name};
+      }
+      const OpPatternKind kind = static_cast<OpPatternKind>(fpattern[op]);
+      VLOG(2) << "TOpPattern for " << op->name << " is " << KindToString(kind);
+      return {kind, op->name};
+    }
+
+    /*! \brief Classifies a call whose callee is the function \p function_node. */
+    static KindAndLabel KindAndLabelForFunctionCall(const FunctionNode* function_node) {
+      Optional<Integer> opt_i =
+          function_node->GetAttr<Integer>("TOpPattern", Optional<Integer>());
+      if (!opt_i.defined()) {
+        VLOG(1) << "calling function without TOpPattern, considering opaque";
+        return {kOpaque, "call_fun"};
+      }
+      const OpPatternKind kind = static_cast<OpPatternKind>(opt_i.value()->value);
+      VLOG(1) << "TOpPattern for function is " << KindToString(kind);
+      return {kind, "call_prim"};
+    }
+
+    KindAndLabel VisitExpr_(const CallNode* call_node) final {
+      if (const auto* op_node = call_node->op.as<OpNode>()) {
+        return KindAndLabelForOpCall(call_node, GetRef<Op>(op_node));
+      }
+      if (const auto* function_node = call_node->op.as<FunctionNode>()) {
+        return KindAndLabelForFunctionCall(function_node);
+      }
+      VLOG(1) << "unsupported call, considering opaque";
+      return {kOpaque, "call_any"};
+    }
+
+    KindAndLabel VisitExpr_(const ConstantNode* constant_node) final {
+      VLOG(2) << "TOpPattern for constant is " << KindToString(kElemWise);
+      return {kElemWise, support::IsSimpleScalar(constant_node) ? "scalar" : "const"};
+    }
+
+    KindAndLabel VisitExpr_(const TupleNode* tuple_node) final {
+      const auto* tuple_type_node = tuple_node->checked_type().as<TupleTypeNode>();
+      ICHECK(tuple_type_node != nullptr);
+      if (!AllFieldsAreTensors(tuple_type_node)) {
+        VLOG(1) << "tuple contains non-tensors, considering opaque";
+        return {kOpaque, "tuple"};
+      }
+      VLOG(2) << "TOpPattern for tuple is " << KindToString(kInjective);
+      return {kInjective, "tuple"};
+    }
+
+    KindAndLabel VisitExpr_(const TupleGetItemNode* tuple_get_item_node) final {
+      const auto* tuple_type_node = tuple_get_item_node->tuple->checked_type().as<TupleTypeNode>();
+      ICHECK(tuple_type_node != nullptr);
+      if (!AllFieldsAreTensors(tuple_type_node)) {
+        VLOG(1) << "tuple being projected contains non-tensors, considering opaque";
+        return {kOpaque, "proj"};
+      }
+      VLOG(2) << "TOpPattern for tuple projection is " << KindToString(kInjective);
+      return {kInjective, "proj"};
+    }
+
+    // TODO(mbs): We implement the following mostly so we have a lightweight way of describing
+    // the current sub-expression. If partitioning is ever extended beyond the usual call/tuple/proj
+    // sub-language we should revise the returned operator kinds to match.
+
+    KindAndLabel VisitExpr_(const VarNode* var_node) final {
+      return {kOpaque, "%" + var_node->name_hint()};
+    }
+    KindAndLabel VisitExpr_(const GlobalVarNode* global_var_node) final {
+      return {kOpaque, "@" + global_var_node->name_hint};
+    }
+    KindAndLabel VisitExpr_(const OpNode* op_node) final { return {kOpaque, "`" + op_node->name}; }
+    KindAndLabel VisitExpr_(const FunctionNode* /*function_node*/) final { return {kOpaque, "fn"}; }
+    KindAndLabel VisitExpr_(const LetNode* /*let_node*/) final { return {kOpaque, "let"}; }
+    KindAndLabel VisitExpr_(const IfNode* /*if_node*/) final { return {kOpaque, "if"}; }
+    KindAndLabel VisitExpr_(const RefCreateNode* /*ref_create_node*/) final {
+      return {kOpaque, "ref"};
+    }
+    KindAndLabel VisitExpr_(const RefReadNode* /*op*/) final { return {kOpaque, "ref_read"}; }
+    KindAndLabel VisitExpr_(const RefWriteNode* /*op*/) final { return {kOpaque, "ref_write"}; }
+    KindAndLabel VisitExpr_(const ConstructorNode* op) final {
+      return {kOpaque, "`" + op->name_hint};
+    }
+    KindAndLabel VisitExpr_(const MatchNode* /*op*/) final { return {kOpaque, "match"}; }
+  };
+
+  return Visitor().VisitExpr(sub_expr);
+}
+
+/*!
+ * \brief Returns the combined operator pattern kind of all nodes in \p inside, together with
+ * the '+'-joined labels of those nodes (nodes with an empty label contribute nothing).
+ */
+std::pair<OpPatternKind, std::string> SubGraphKindAndLabel(const DataflowGraph& dataflow_graph,
+                                                           const IndexSet& inside) {
+  std::ostringstream label_stream;
+  // Empty before the first emitted label, "+" afterwards.
+  const char* separator = "";
+  // Start from kElemWise and fold in each node's kind with CombineKinds.
+  OpPatternKind combined_kind = kElemWise;
+  for (PostDfsIndex index : inside) {
+    const std::pair<OpPatternKind, std::string> kind_and_label =
+        SubExprKindAndLabel(dataflow_graph.index_to_node(index)->ref());
+    if (!kind_and_label.second.empty()) {
+      label_stream << separator << kind_and_label.second;
+      separator = "+";
+    }
+    combined_kind = CombineKinds(combined_kind, kind_and_label.first);
+  }
+  return {combined_kind, label_stream.str()};
+}
+
+/*!
+ * \brief Returns the set of dataflow node indexes for the sub-expressions recorded in
+ * \p matcher's memo. Expressions matched by a wildcard pattern and trivially inlinable
+ * expressions are excluded.
+ */
+IndexSet MatcherToIndexSet(const DFPatternMatcher& matcher) {
+  IndexSet result(matcher.size());
+  for (const auto& kv : matcher.memo()) {
+    // The wildcard test depends only on the pattern, so decide it once per memo entry
+    // rather than once per matched expression.
+    if (kv.first.as<WildcardPatternNode>()) {
+      // Don't consider the expressions matched by a wildcard to be part of the sub-graph.
+      continue;
+    }
+    for (const auto& matched_sub_expr : kv.second) {
+      if (CanInline(matched_sub_expr)) {
+        // Trivial sub-expressions can just be included in the extracted function body
+        // when we construct it and don't need to be considered part of the sub-graph.
+        continue;
+      }
+      result.Add(matcher.expr_to_node(matched_sub_expr)->index_);
+    }
+  }
+  return result;
+}
+
+/*! \brief Renders the config as "{max_exits=..,allow_taps=..,max_max_depth=..}". */
+std::string SubGraphConfig::ToString() const {
+  std::ostringstream out;
+  out << "{max_exits=" << max_exits << ",allow_taps=" << allow_taps
+      << ",max_max_depth=" << max_max_depth << "}";
+  return out.str();
+}
+
+// Register SubSubGraphNode as a TVM node type.
+TVM_REGISTER_NODE_TYPE(SubSubGraphNode);
+
+void SubSubGraphNode::VisitAttrs(AttrVisitor* /*v*/) {
+  // TODO(mbs): Not yet implemented; no fields are currently reported to the visitor.
+}
+
+/*! \brief Returns sub_graph_obj_ downcast to SubGraph. */
+SubGraph SubSubGraphNode::sub_graph() const {
+  return Downcast<SubGraph>(sub_graph_obj_);
+}
+
+/*! \brief Equality compares only the underlying sub-graphs; attrs_ do not participate. */
+bool SubSubGraphNode::operator==(const SubSubGraphNode& that) const {
+  const SubGraph lhs = sub_graph();
+  const SubGraph rhs = that.sub_graph();
+  return *lhs.get() == *rhs.get();
+}
+
+/*! \brief Ordering delegates to the underlying sub-graphs; attrs_ do not participate. */
+bool SubSubGraphNode::operator<(const SubSubGraphNode& that) const {
+  const SubGraph lhs = sub_graph();
+  const SubGraph rhs = that.sub_graph();
+  return *lhs.get() < *rhs.get();
+}
+
+/*!
+ * \brief Hash consistent with operator==.
+ *
+ * operator== (and operator<) compare only the underlying sub-graphs, so the hash must depend
+ * only on the sub-graph as well. Previously attrs_ were mixed in, which allowed two
+ * SubSubGraphs that compare equal to hash differently and broke hashed-container lookups.
+ */
+size_t SubSubGraphNode::hash() const { return sub_graph()->hash(); }
+
+std::string SubSubGraphNode::ToString() const {
+  std::ostringstream os;
+  os << "{sub_graph=" << sub_graph()->ToString();
+  os << ",attrs=" << PrettyPrint(attrs_);

Review Comment:
   The debug output is pretty overwhelming so I was trying to compress it a bit but I went ahead and took your suggestion throughout anyway :-)



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] mbs-octoml commented on a diff in pull request #11981: [Collage] SubGraphs

Posted by GitBox <gi...@apache.org>.
mbs-octoml commented on code in PR #11981:
URL: https://github.com/apache/tvm/pull/11981#discussion_r918201354


##########
src/relay/collage/README.md:
##########
@@ -0,0 +1,26 @@
+<!--- Licensed to the Apache Software Foundation (ASF) under one -->
+<!--- or more contributor license agreements.  See the NOTICE file -->
+<!--- distributed with this work for additional information -->
+<!--- regarding copyright ownership.  The ASF licenses this file -->
+<!--- to you under the Apache License, Version 2.0 (the -->
+<!--- "License"); you may not use this file except in compliance -->
+<!--- with the License.  You may obtain a copy of the License at -->
+
+<!---   http://www.apache.org/licenses/LICENSE-2.0 -->
+
+<!--- Unless required by applicable law or agreed to in writing, -->
+<!--- software distributed under the License is distributed on an -->
+<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
+<!--- KIND, either express or implied.  See the License for the -->
+<!--- specific language governing permissions and limitations -->
+<!--- under the License. -->
+
+The `CollagePartition` pass for finding optimal partitionings of Relay models.
+
+See the [RFC](https://github.com/apache/tvm-rfcs/blob/main/rfcs/0062-collage.md).
+
+Based on:
+> *Collage: Automated Integration of Deep Learning Backends*  
+> Byungsoo Jeon, Sunghyun Park, Peiyuan Liao, Sheng Xu, Tianqi Chen, Zhihao Jia
+
+CAUTION: This is a prototype, do not use in prod.

Review Comment:
   Decided to leave it in there given the unit testing is pretty thin on the ground and we've only been exercising on a handful of models.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] mbs-octoml commented on pull request #11981: [Collage] SubGraphs

Posted by GitBox <gi...@apache.org>.
mbs-octoml commented on PR #11981:
URL: https://github.com/apache/tvm/pull/11981#issuecomment-1171778704

   @mbaret, thanks.


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org