You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by mb...@apache.org on 2022/07/13 19:48:59 UTC

[tvm] branch main updated: [Collage] CombinerRule and CandidatePartition::EstimateCost (#12078)

This is an automated email from the ASF dual-hosted git repository.

mbaret pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 261de5302f [Collage] CombinerRule and CandidatePartition::EstimateCost (#12078)
261de5302f is described below

commit 261de5302f583cdbf09d7d1ef9718875b76db3eb
Author: Mark Shields <87...@users.noreply.github.com>
AuthorDate: Wed Jul 13 12:48:51 2022 -0700

    [Collage] CombinerRule and CandidatePartition::EstimateCost (#12078)
    
    * [Collage] CombinerRule and CandidatePartition::EstimateCost
    
    See https://github.com/apache/tvm-rfcs/blob/main/rfcs/0062-collage.md.
    
    We complete the PartitionRule sub-class hierarchy with the addition of
    CombinePartitionRule, which allows disjoint candidate partitions to be
    unioned based on simple rules.
     - By TOpPattern kind, e.g. a kOutEWiseFusable and a kBroadcast.
     - A tuple argument with injective fields.
     - The projection from an injective group (necessarily of tuple type).
     - Combinations of the above.
    These let us mimic many common fusion strategies, including TVM's, so that
    the candidates explored during Collage search are as large as possible to
    expose possible fusion opportunities but no larger.
    
    Also completes CandidatePartition with the EstimatedCost method, which is used during
    search to build a stand-alone IRModule for a candidate and estimate its latency.
    
    Finish unit tests for PartitionRule and CandidatePartition.
    
    * - fix relay.collage ffi prefix.
---
 src/relay/collage/candidate_function_cache.cc      |  49 ++
 src/relay/collage/candidate_function_cache.h       |  79 +++
 src/relay/collage/candidate_partition.cc           | 100 ++++
 src/relay/collage/candidate_partition.h            |  10 +
 src/relay/collage/combiner_rule.cc                 | 395 ++++++++++++++
 src/relay/collage/combiner_rule.h                  | 229 ++++++++
 src/relay/collage/cost.h                           |   5 +
 src/relay/collage/cost_estimator.cc                | 132 +++++
 src/relay/collage/cost_estimator.h                 | 104 ++++
 src/relay/collage/name_supply.cc                   |  90 ++++
 src/relay/collage/name_supply.h                    |  58 ++
 src/relay/collage/partition_rule.cc                |  60 +++
 src/relay/collage/partition_rule.h                 | 132 +++++
 .../cpp/relay/collage/candidate_partition_test.cc  | 220 ++++++++
 tests/cpp/relay/collage/partition_rule_test.cc     | 596 +++++++++++++++++----
 15 files changed, 2167 insertions(+), 92 deletions(-)

diff --git a/src/relay/collage/candidate_function_cache.cc b/src/relay/collage/candidate_function_cache.cc
new file mode 100644
index 0000000000..32982dc08f
--- /dev/null
+++ b/src/relay/collage/candidate_function_cache.cc
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/collage/candidate_function_cache.cc
+ * \brief A cache of the unique global name and costs for partitioned functions.
+ */
+
+#include "./candidate_function_cache.h"
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+CandidateFunctionCache::Entry& CandidateFunctionCache::GetEntry(const std::string& label,
+                                                                const Function& function) {
+  auto itr = cache_.find(function);
+  if (itr == cache_.end()) {
+    String compiler = function->GetAttr<String>(attr::kCompiler, String("tvm")).value();
+    std::string global_symbol_name = name_supply_->Fresh({compiler, label});
+    GlobalVar global_symbol(std::move(global_symbol_name), function->checked_type());
+    itr = cache_.emplace(function, Entry(std::move(global_symbol))).first;
+  }
+  return itr->second;
+}
+
+GlobalVar CandidateFunctionCache::GetGlobalSymbol(const Function& function) {
+  return GetEntry(/*label=*/"", function).global_symbol;
+}
+
+}  // namespace collage
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/collage/candidate_function_cache.h b/src/relay/collage/candidate_function_cache.h
new file mode 100644
index 0000000000..8734f5a8e1
--- /dev/null
+++ b/src/relay/collage/candidate_function_cache.h
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/collage/candidate_function_cache.h
+ * \brief A cache of the unique global symbol name and cost for partitioned functions.
+ */
+
+#ifndef TVM_RELAY_COLLAGE_CANDIDATE_FUNCTION_CACHE_H_
+#define TVM_RELAY_COLLAGE_CANDIDATE_FUNCTION_CACHE_H_
+
+#include <tvm/relay/function.h>
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
+
+#include "../transforms/compiler_function_utils.h"
+#include "./cost.h"
+#include "./name_supply.h"
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+/*!
+ * \brief A cache of the unique global symbol and cost for functions extracted to represent
+ * partitions. If two functions are structurally equal (which includes equality of their "Compiler"
+ * attributes) then they will share the same global symbol and estimated cost. We rely on the
+ * function's attributes to distinguish partitions which are structurally the same graph but
+ * intended for different targets.
+ */
+class CandidateFunctionCache : public transform::GlobalSymbolCache {
+ public:
+  explicit CandidateFunctionCache(std::shared_ptr<NameSupply> name_supply)
+      : name_supply_(std::move(name_supply)) {}
+
+  struct Entry {
+    GlobalVar global_symbol;
+    Cost cost = Cost::Unknown();  // Filled in when have estimated cost.
+
+    explicit Entry(GlobalVar global_symbol) : global_symbol(std::move(global_symbol)) {}
+  };
+
+  /*!
+   * \brief Returns the unique entry for \p function. If no such entry already exists, create it
+   * and assign it a unique global symbol name.
+   */
+  Entry& GetEntry(const std::string& label, const Function& function);
+
+  GlobalVar GetGlobalSymbol(const Function& function) final;
+
+ private:
+  std::shared_ptr<NameSupply> name_supply_;
+  std::unordered_map<Function, Entry, StructuralHash, StructuralEqual> cache_;
+};
+
+}  // namespace collage
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // TVM_RELAY_COLLAGE_CANDIDATE_FUNCTION_CACHE_H_
diff --git a/src/relay/collage/candidate_partition.cc b/src/relay/collage/candidate_partition.cc
index 9cccdf96d5..20e29a6d40 100644
--- a/src/relay/collage/candidate_partition.cc
+++ b/src/relay/collage/candidate_partition.cc
@@ -24,8 +24,12 @@
 
 #include "./candidate_partition.h"
 
+#include <tvm/relay/analysis.h>
 #include <tvm/relay/attrs/memory.h>
+#include <tvm/relay/transform.h>
 
+#include "../transforms/compiler_function_utils.h"
+#include "./candidate_function_cache.h"
 #include "./candidate_set.h"
 #include "./partition_rule.h"
 #include "./partition_spec.h"
@@ -106,6 +110,102 @@ std::string CandidatePartitionNode::ToString() const {
   return os.str();
 }
 
+namespace {
+/*!
+ * \brief If function's body is a call to an inlined "Primitive" function, return it.
+ * Otherwise return function directly.
+ */
+Function GetPrimitiveFunction(const Function& function) {
+  if (const auto* call_node = function->body.as<CallNode>()) {
+    if (const auto* function_node = call_node->op.as<FunctionNode>()) {
+      if (function_node->HasNonzeroAttr(attr::kPrimitive)) {
+        return GetRef<Function>(function_node);
+      }
+    }
+  }
+  return function;
+}
+
+/*!
+ * \brief Eta-expand any tuple arguments of \p function. Ie rewrite:
+ * \code
+ *   f(x: (t1, t2)) { ... x ... }
+ * \endcode
+ * to
+ * \code
+ *   f(x_1: t1, x_2: t2) { ... (x_1, x_2) ... }
+ * \endcode
+ */
+Function EtaExpandTuples(const Function& function) {
+  Map<Var, Expr> subst;
+  Array<Var> new_params;
+  for (const auto& param : function->params) {
+    std::vector<TensorType> tensor_types = FlattenTupleType(param->type_annotation);
+    if (tensor_types.size() == 1) {
+      new_params.push_back(param);
+    } else {
+      Array<Expr> fields;
+      for (size_t i = 0; i < tensor_types.size(); ++i) {
+        Var new_param(param->name_hint() + "_" + std::to_string(i), tensor_types[i], param->span);
+        new_param->checked_type_ = tensor_types[i];
+        new_params.push_back(new_param);
+        fields.push_back(new_param);
+      }
+      Tuple new_tuple(fields);
+      subst.Set(param, new_tuple);
+    }
+  }
+  if (subst.empty()) {
+    return function;
+  }
+  return WithFields(function, new_params, Bind(function->body, subst));
+}
+
+}  // namespace
+
+Cost CandidatePartitionNode::EstimatedCost(
+    const DataflowGraph& dataflow_graph, const CostEstimator& cost_estimator,
+    const std::shared_ptr<CandidateFunctionCache>& cache) const {
+  if (cost_.is_unknown()) {
+    VLOG_CONTEXT << "spec " << partition_spec_name();
+    Function extracted_function = sub_graph_->ExtractAsFunction(dataflow_graph);
+    VLOG(2) << "Extracted function:" << std::endl << PrettyPrint(extracted_function);
+    extracted_function = EtaExpandTuples(extracted_function);
+    VLOG(2) << "Validating function:" << std::endl << PrettyPrint(extracted_function);
+    String error = partition_spec()->validate_sub_graph_func_(extracted_function);
+    if (!error.empty()) {
+      cost_ = Cost::Invalid();
+      VLOG(1) << "Unable to rewrite function: " << error;
+    } else {
+      // The extracted function may be the eta-expansion of a "Primitive" function.
+      // If so we want the cached external name and cost to be w.r.t. that function
+      // rather than the outer so that we'll get a cache hit when we outline functions
+      // in the final program.
+      Function primitive_function = GetPrimitiveFunction(extracted_function);
+      CandidateFunctionCache::Entry& entry =
+          cache->GetEntry(sub_graph_->label_, primitive_function);
+      if (entry.cost.is_unknown()) {
+        IRModule mod = IRModule::FromExpr(extracted_function);
+        VLOG(1) << "Outlining:" << std::endl << PrettyPrint(mod);
+        mod = OutlineCompilerFunctions(cache)(mod);
+        VLOG(1) << "Estimating cost of:" << std::endl
+                << PrettyPrint(mod) << std::endl
+                << "using target " << target()->ToDebugString();
+        entry.cost = cost_estimator->Estimate(mod, target(),
+                                              /*needs_tvm_tuning=*/!target().IsExternalCodegen());
+        VLOG(1) << "Measured cost as " << entry.cost.ToString();
+      } else {
+        VLOG(1) << "Reusing cost " << entry.cost.ToString()
+                << " cached in candidate function cache";
+      }
+      cost_ = entry.cost;
+    }
+  } else {
+    VLOG(1) << "Reusing cost " << cost_.ToString() << " cached in candidate";
+  }
+  return cost_;
+}
+
 CandidatePartition::CandidatePartition(String rule_name, SubGraph sub_graph,
                                        ObjectRef /* actually PartitionSpec */ spec, Cost cost) {
   auto node = runtime::make_object<CandidatePartitionNode>();
diff --git a/src/relay/collage/candidate_partition.h b/src/relay/collage/candidate_partition.h
index 1265087f47..36a23f14bc 100644
--- a/src/relay/collage/candidate_partition.h
+++ b/src/relay/collage/candidate_partition.h
@@ -32,7 +32,10 @@
 #include <string>
 #include <vector>
 
+#include "./candidate_function_cache.h"
 #include "./cost.h"
+#include "./cost_estimator.h"
+#include "./name_supply.h"
 #include "./sub_graph.h"
 
 namespace tvm {
@@ -93,6 +96,13 @@ class CandidatePartitionNode : public Object {
    */
   Target target() const;
 
+  /*!
+   * \brief Return the estimated cost of the candidate partition, using \p cost_estimator and
+   * \p cache.
+   */
+  Cost EstimatedCost(const DataflowGraph& dataflow_graph, const CostEstimator& cost_estimator,
+                     const std::shared_ptr<CandidateFunctionCache>& cache) const;
+
   /*!
    * \brief Returns a brief description of candidate suitable for debugging output.
    */
diff --git a/src/relay/collage/combiner_rule.cc b/src/relay/collage/combiner_rule.cc
new file mode 100644
index 0000000000..bcfef04772
--- /dev/null
+++ b/src/relay/collage/combiner_rule.cc
@@ -0,0 +1,395 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/collage/combiner_rule.cc
+ * \brief Helpers for the \p CombinePartitionRule
+ */
+
+#include "./combiner_rule.h"
+
+#include "./partition_spec.h"
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+TVM_REGISTER_NODE_TYPE(SimpleCombinerRuleNode);
+
+void SimpleCombinerRuleNode::VisitAttrs(AttrVisitor* v) {
+  // TODO(mbs)
+}
+
+bool SimpleCombinerRuleNode::Fires(const DataflowGraph& dataflow_graph,
+                                   const CandidatePartition& upstream,
+                                   const CandidatePartition& downstream) const {
+  return false;
+}
+
+std::string SimpleCombinerRuleNode::ToString() const {
+  return "SimpleCombinerRule(" + rule_name_ + ")";
+}
+
+SimpleCombinerRule::SimpleCombinerRule(String rule_name) {
+  auto node = runtime::make_object<SimpleCombinerRuleNode>();
+  node->rule_name_ = std::move(rule_name);
+  data_ = std::move(node);
+}
+
+TVM_REGISTER_NODE_TYPE(ByKindSimpleCombinerRuleNode);
+
+void ByKindSimpleCombinerRuleNode::VisitAttrs(AttrVisitor* v) {
+  // TODO(mbs)
+}
+
+bool ByKindSimpleCombinerRuleNode::Fires(const DataflowGraph& dataflow_graph,
+                                         const CandidatePartition& upstream,
+                                         const CandidatePartition& downstream) const {
+  return upstream->sub_graph_->kind_ <= upstream_kind_ &&
+         downstream->sub_graph_->kind_ <= downstream_kind_;
+}
+
+std::string ByKindSimpleCombinerRuleNode::ToString() const {
+  std::ostringstream os;
+  os << "ByKindSimpleCombinerRule(" << rule_name_ << ")";
+  return os.str();
+}
+
+ByKindSimpleCombinerRule::ByKindSimpleCombinerRule(OpPatternKind upstream_kind,
+                                                   OpPatternKind downstream_kind) {
+  auto node = runtime::make_object<ByKindSimpleCombinerRuleNode>();
+  String rule_name = KindToString(upstream_kind) + "->" + KindToString(downstream_kind);
+  node->rule_name_ = std::move(rule_name);
+  node->upstream_kind_ = upstream_kind;
+  node->downstream_kind_ = downstream_kind;
+  data_ = std::move(node);
+}
+
+TVM_REGISTER_NODE_TYPE(CombinerRuleNode);
+
+void CombinerRuleNode::VisitAttrs(AttrVisitor* v) {
+  // TODO(mbs)
+}
+
+void CombinerRuleNode::AppendAllResults(AppendAllResultsContext* ctxt) const {}
+
+std::string CombinerRuleNode::ToString() const { return "CombinerRuleNode(" + rule_name_ + ")"; }
+
+CombinerRule::CombinerRule(String rule_name) {
+  auto node = runtime::make_object<CombinerRuleNode>();
+  node->rule_name_ = std::move(rule_name);
+  data_ = std::move(node);
+}
+
+TVM_REGISTER_NODE_TYPE(AllSimpleCombinerRuleNode);
+
+void AllSimpleCombinerRuleNode::VisitAttrs(AttrVisitor* v) {
+  // TODO(mbs)
+}
+
+void AllSimpleCombinerRuleNode::AppendAllResults(AppendAllResultsContext* ctxt) const {
+  VLOG(1) << "running AllSimpleCombinerRule(" << rule_name_ << ")";
+  // Build map from post-dfs indices to the indices of candidates with corresponding entry node.
+  // NOTE: the index set is over candidate indices not post-dfs indices!
+  std::vector<IndexSet> entry_map(ctxt->dataflow_graph->size(),
+                                  IndexSet(ctxt->candidate_set->size()));
+  for (size_t i = 0; i < ctxt->candidate_set->size(); ++i) {
+    CandidatePartition candidate = ctxt->candidate_set->at(i);
+    for (PostDfsIndex entry_index : candidate->sub_graph_->entry_) {
+      entry_map[entry_index].Add(i);
+    }
+  }
+
+  for (size_t i = 0; i < ctxt->candidate_set->size(); ++i) {
+    CandidatePartition upstream = ctxt->candidate_set->at(i);
+    // Narrow our search to just those candidates which could touch.
+    IndexSet possible_downstream(ctxt->candidate_set->size());
+    for (PostDfsIndex output_index : upstream->sub_graph_->output_) {
+      possible_downstream = possible_downstream | entry_map[output_index];
+    }
+    size_t start_j =
+        i < ctxt->candidate_set->first_new_index() ? ctxt->candidate_set->first_new_index() : 0;
+    for (size_t j : possible_downstream) {
+      if (i == j) {
+        continue;
+      }
+      if (j < start_j) {
+        // Both i and j are in [0, first_new_index), whose cross-product was already explored,
+        // so don't do it again.
+        continue;
+      }
+      // Note that the rules are not commutative so we can't just ignore if j < i.
+      CandidatePartition downstream = ctxt->candidate_set->at(j);
+      if (ctxt->max_depth > 0 &&
+          upstream->sub_graph_->depth_ + downstream->sub_graph_->depth_ > ctxt->max_depth) {
+        continue;
+      }
+      if (!upstream.AreTouching(*ctxt->dataflow_graph, downstream)) {
+        continue;
+      }
+      for (const auto& simple_rule : simple_rules_) {
+        if (simple_rule->Fires(*ctxt->dataflow_graph, upstream, downstream)) {
+          CandidatePartition new_candidate =
+              upstream.DisjointUnion(*ctxt->dataflow_graph, downstream);
+          VLOG(2) << "Fired " << simple_rule->rule_name_ << " on upstream candidate "
+                  << upstream->ToString() << " and downstream candidate " << downstream->ToString()
+                  << " to yield " << new_candidate->ToString();
+          ctxt->candidate_set->Add(*ctxt->dataflow_graph, new_candidate);
+        }
+      }
+    }
+  }
+}
+
+std::string AllSimpleCombinerRuleNode::ToString() const {
+  std::ostringstream os;
+  os << "AllSimpleCombinerRule(" << rule_name_;
+  for (const auto& simple : simple_rules_) {
+    os << ", " << simple->ToString();
+  }
+  os << ")";
+  return os.str();
+}
+
+AllSimpleCombinerRule::AllSimpleCombinerRule(String rule_name,
+                                             Array<SimpleCombinerRule> simple_rules) {
+  auto node = runtime::make_object<AllSimpleCombinerRuleNode>();
+  node->rule_name_ = std::move(rule_name);
+  node->simple_rules_ = std::move(simple_rules);
+  data_ = std::move(node);
+}
+
+TVM_REGISTER_NODE_TYPE(TupleArgCombinerRuleNode);
+
+void TupleArgCombinerRuleNode::VisitAttrs(AttrVisitor* v) {
+  // TODO(mbs)
+}
+
+void TupleArgCombinerRuleNode::AppendAllResults(AppendAllResultsContext* ctxt) const {
+  VLOG(1) << "running TupleArgCombinerRule(" << rule_name_ << ")";
+  // Build map from post-dfs index to the indices of injective candidates with corresponding exit
+  // node. NOTE: the index set is over candidate indices not post-dfs indices!
+  std::vector<IndexSet> exit_map(ctxt->dataflow_graph->size(),
+                                 IndexSet(ctxt->candidate_set->size()));
+  for (size_t i = 0; i < ctxt->candidate_set->size(); ++i) {
+    CandidatePartition candidate = ctxt->candidate_set->at(i);
+    if (candidate->sub_graph_->kind_ > kInjective) {
+      continue;
+    }
+    for (PostDfsIndex exit_index : candidate->sub_graph_->exit_) {
+      exit_map[exit_index].Add(i);
+    }
+  }
+
+  // The two-step I -> tuple -> I rule.
+  // Look at all possible tuple consumers...
+  for (size_t i = 0; i < ctxt->candidate_set->size(); ++i) {
+    CandidatePartition tuple_consumer_candidate = ctxt->candidate_set->at(i);
+    if (tuple_consumer_candidate->sub_graph_->kind_ > kInjective) {
+      continue;
+    }
+    // If consumer i predates this round, only pair it with producers from this round.
+    size_t start_j =
+        i < ctxt->candidate_set->first_new_index() ? ctxt->candidate_set->first_new_index() : 0;
+    // For all possible tuples feeding into candidate...
+    for (PostDfsIndex input_index : tuple_consumer_candidate->sub_graph_->input_) {
+      auto node = ctxt->dataflow_graph->index_to_node(input_index);
+      Expr sub_expr = node->ref();
+      const auto* tuple_node = sub_expr.as<TupleNode>();
+      if (tuple_node == nullptr) {
+        continue;
+      }
+      // The tuple_consumer_candidate candidate consumes (at least one) tuple, eg as an argument
+      // to an operator.
+      // eg: concatenate((field1, ..., fieldn))
+      auto tuple_dataflow_node = ctxt->dataflow_graph->item_to_node(tuple_node);
+
+      // Collect all the possible unions. There may be more than one if different candidates
+      // could supply the same tuple field.
+      std::vector<std::vector<CandidatePartition>> all_possible_unions;
+
+      // Obviously we must include the consumer.
+      all_possible_unions.emplace_back();
+      all_possible_unions.back().emplace_back(tuple_consumer_candidate);
+
+      // We must include the tuple itself.
+      SubGraph tuple_sub_graph(*ctxt->dataflow_graph,
+                               IndexSet(ctxt->dataflow_graph->size(), {node->index_}), kInjective,
+                               "tuple");
+      CandidatePartition tuple_candidate("", std::move(tuple_sub_graph),
+                                         tuple_consumer_candidate->partition_spec());
+      all_possible_unions.back().emplace_back(std::move(tuple_candidate));
+
+      // For all tuple fields...
+      bool all_tuple_fields_have_producer = true;
+      for (auto* tuple_field_dataflow_node : tuple_dataflow_node->inputs_) {
+        // Collect all the candidates which could produce this tuple field.
+        std::vector<CandidatePartition> to_appends;
+        for (size_t j : exit_map[tuple_field_dataflow_node->index_]) {
+          if (i == j) {
+            continue;
+          }
+          if (j < start_j) {
+            // Both i and j predate this round, so this pairing was already explored.
+            continue;
+          }
+          CandidatePartition tuple_field_producer = ctxt->candidate_set->at(j);
+          // The tuple_field_producer candidate can provide this tuple field.
+          // eg concatenate((..., producer, ...))
+          to_appends.emplace_back(tuple_field_producer);
+        }
+        if (to_appends.empty()) {
+          // At least one of the tuple's fields does not have a producer candidate we can
+          // union in, so we need to give up.
+          all_tuple_fields_have_producer = false;
+          break;
+        } else {
+          // If to_appends = [A, B] and we already have possible unions [C, D] and [E, F] then
+          // the new possible unions are [C, D, A], [C, D, B], [E, F, A] and [E, F, B].
+          std::vector<std::vector<CandidatePartition>> new_all_possible_unions;
+          for (const auto& to_append : to_appends) {
+            for (const auto& possible_union : all_possible_unions) {
+              new_all_possible_unions.emplace_back(possible_union);
+              new_all_possible_unions.back().emplace_back(to_append);
+            }
+          }
+          all_possible_unions = std::move(new_all_possible_unions);
+        }
+      }
+
+      if (!all_tuple_fields_have_producer) {
+        continue;
+      }
+
+      // Actually build the candidates which union according to all_possible_unions.
+      for (const auto& possible_union : all_possible_unions) {
+        if (possible_union.size() > 2) {
+          CandidatePartition new_candidate =
+              CandidatePartition::DisjointUnion(*ctxt->dataflow_graph, possible_union);
+#if TVM_LOG_DEBUG
+          std::ostringstream os;
+          bool first = true;
+          for (const auto& candidate : possible_union) {
+            if (first) {
+              first = false;
+            } else {
+              os << ", ";
+            }
+            os << candidate->ToString();
+          }
+          VLOG(2) << "Fired rule " << rule_name_ << " on {" << os.str() << "} to yield "
+                  << new_candidate->ToString();
+#endif
+          ctxt->candidate_set->Add(*ctxt->dataflow_graph, new_candidate);
+        }
+      }
+    }
+  }
+}
+
+std::string TupleArgCombinerRuleNode::ToString() const {
+  return "TupleArgCombinerRule(" + rule_name_ + ")";
+}
+
+TupleArgCombinerRule::TupleArgCombinerRule(String rule_name) {
+  auto node = runtime::make_object<TupleArgCombinerRuleNode>();
+  node->rule_name_ = std::move(rule_name);
+  data_ = std::move(node);
+}
+
+TVM_REGISTER_NODE_TYPE(TupleProjCombinerRuleNode);
+
+void TupleProjCombinerRuleNode::VisitAttrs(AttrVisitor* v) {
+  // TODO(mbs)
+}
+
+void TupleProjCombinerRuleNode::AppendAllResults(AppendAllResultsContext* ctxt) const {
+  VLOG(1) << "running TupleProjCombinerRule(" << rule_name_ << ")";
+  // We already explored [0, first_new_index), so don't do it again.
+  for (size_t i = ctxt->candidate_set->first_new_index(); i < ctxt->candidate_set->size(); ++i) {
+    CandidatePartition base = ctxt->candidate_set->at(i);
+    for (PostDfsIndex index : base->sub_graph_->output_) {
+      auto node = ctxt->dataflow_graph->index_to_node(index);
+      if (node->ref().as<TupleGetItemNode>()) {
+        IndexSet index_set(ctxt->dataflow_graph->size(), {node->index_});
+        SubGraph sub_graph(*ctxt->dataflow_graph, std::move(index_set), kInjective, "proj");
+        CandidatePartition proj_candidate("", std::move(sub_graph), base->spec_);
+        CandidatePartition new_candidate =
+            base.DisjointUnion(*ctxt->dataflow_graph, proj_candidate);
+        VLOG(2) << "Fired rule " << rule_name_ << " on " << proj_candidate->ToString() << " and "
+                << base->ToString() << " to yield " << new_candidate->ToString();
+        ctxt->candidate_set->Add(*ctxt->dataflow_graph, new_candidate);
+      }
+    }
+  }
+}
+
+std::string TupleProjCombinerRuleNode::ToString() const {
+  return "TupleProjCombinerRule(" + rule_name_ + ")";
+}
+
+TupleProjCombinerRule::TupleProjCombinerRule(String rule_name) {
+  auto node = runtime::make_object<TupleProjCombinerRuleNode>();
+  node->rule_name_ = std::move(rule_name);
+  data_ = std::move(node);
+}
+
+TVM_REGISTER_NODE_TYPE(ConstantCombinerRuleNode);
+
+void ConstantCombinerRuleNode::VisitAttrs(AttrVisitor* v) {
+  // TODO(mbs)
+}
+
+void ConstantCombinerRuleNode::AppendAllResults(AppendAllResultsContext* ctxt) const {
+  VLOG(1) << "running ConstantCombinerRule(" << rule_name_ << ")";
+  // We already explored [0, first_new_index), so don't do it again.
+  for (size_t i = ctxt->candidate_set->first_new_index(); i < ctxt->candidate_set->size(); ++i) {
+    CandidatePartition base = ctxt->candidate_set->at(i);
+    IndexSet new_constants(ctxt->dataflow_graph->size());
+    for (PostDfsIndex index : base->sub_graph_->input_) {
+      auto node = ctxt->dataflow_graph->index_to_node(index);
+      if (node->ref().as<ConstantNode>()) {
+        new_constants.Add(index);
+      }
+    }
+    if (!new_constants.IsZero()) {
+      SubGraph sub_graph(*ctxt->dataflow_graph, new_constants, kElemWise, "const");
+      CandidatePartition new_const_candidate("", std::move(sub_graph), base->spec_);
+      CandidatePartition new_candidate =
+          base.DisjointUnion(*ctxt->dataflow_graph, new_const_candidate);
+      VLOG(2) << "Fired rule " << rule_name_ << " on " << new_const_candidate->ToString() << " and "
+              << base->ToString() << " to yield " << new_candidate->ToString();
+      ctxt->candidate_set->Add(*ctxt->dataflow_graph, new_candidate);
+    }
+  }
+}
+
+std::string ConstantCombinerRuleNode::ToString() const {
+  return "ConstantCombinerRule(" + rule_name_ + ")";
+}
+
+ConstantCombinerRule::ConstantCombinerRule(String rule_name) {
+  auto node = runtime::make_object<ConstantCombinerRuleNode>();
+  node->rule_name_ = std::move(rule_name);
+  data_ = std::move(node);
+}
+
+}  // namespace collage
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/collage/combiner_rule.h b/src/relay/collage/combiner_rule.h
new file mode 100644
index 0000000000..04ea2a9cc1
--- /dev/null
+++ b/src/relay/collage/combiner_rule.h
@@ -0,0 +1,229 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/collage/combiner_rule.h
+ * \brief Helpers for the \p CombinePartitionRule
+ */
+
+#ifndef TVM_RELAY_COLLAGE_COMBINER_RULE_H_
+#define TVM_RELAY_COLLAGE_COMBINER_RULE_H_
+
+#include <tvm/relay/dataflow_pattern.h>
+#include <tvm/relay/expr.h>
+
+#include <string>
+
+#include "./candidate_partition.h"
+#include "./candidate_set.h"
+#include "./sub_graph.h"
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+/*!
+ * \brief Base class for all 'simple' combiner rules.
+ *
+ * Given \p upstream and \p downstream candidates which touch, a simple combiner rule returns
+ * true if their union should also be considered a candidate.
+ */
+class SimpleCombinerRuleNode : public Object {
+ public:
+  /*! \brief The rule's name (presumably used by \p ToString and logging — confirm in the .cc). */
+  String rule_name_;
+
+  void VisitAttrs(AttrVisitor* v);
+
+  /*!
+   * \brief Returns true if the union of \p upstream and \p downstream (both candidates over
+   * \p dataflow_graph) should also be considered a candidate.
+   */
+  virtual bool Fires(const DataflowGraph& dataflow_graph, const CandidatePartition& upstream,
+                     const CandidatePartition& downstream) const;
+
+  /*! \brief Returns a human-readable description of this rule. */
+  virtual std::string ToString() const;
+
+  static constexpr const char* _type_key = "relay.collage.SimpleCombinerRule";
+  // NOTE(review): presumably sized for the single sub-class declared in this header
+  // (ByKindSimpleCombinerRuleNode); keep in sync when adding sub-classes.
+  static constexpr const uint32_t _type_child_slots = 1;
+  TVM_DECLARE_BASE_OBJECT_INFO(SimpleCombinerRuleNode, Object);
+};
+
+/*! \brief Managed reference to \p SimpleCombinerRuleNode. */
+class SimpleCombinerRule : public ObjectRef {
+ public:
+  explicit SimpleCombinerRule(String rule_name);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(SimpleCombinerRule, ObjectRef, SimpleCombinerRuleNode);
+};
+
+/*!
+ * \brief A simple combiner rule which fires if the \p upstream and \p downstream candidates have
+ * the given \p upstream_kind and \p downstream_kind (or less) respectively.
+ */
+class ByKindSimpleCombinerRuleNode : public SimpleCombinerRuleNode {
+ public:
+  /*! \brief Maximum \p OpPatternKind allowed for the upstream candidate. */
+  OpPatternKind upstream_kind_;
+  /*! \brief Maximum \p OpPatternKind allowed for the downstream candidate. */
+  OpPatternKind downstream_kind_;
+
+  void VisitAttrs(AttrVisitor* v);
+
+  bool Fires(const DataflowGraph& dataflow_graph, const CandidatePartition& upstream,
+             const CandidatePartition& downstream) const override;
+  std::string ToString() const override;
+
+  static constexpr const char* _type_key = "relay.collage.ByKindSimpleCombinerRule";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ByKindSimpleCombinerRuleNode, SimpleCombinerRuleNode);
+};
+
+/*! \brief Managed reference to \p ByKindSimpleCombinerRuleNode. */
+class ByKindSimpleCombinerRule : public SimpleCombinerRule {
+ public:
+  ByKindSimpleCombinerRule(OpPatternKind upstream_kind, OpPatternKind downstream_kind);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(ByKindSimpleCombinerRule, SimpleCombinerRule,
+                                ByKindSimpleCombinerRuleNode);
+};
+
+/*! \brief Context required by CombinerRuleNode::AppendAllResults. */
+struct AppendAllResultsContext {
+  AppendAllResultsContext(const DataflowGraph* dataflow_graph, size_t max_depth,
+                          CandidateSet* candidate_set)
+      : dataflow_graph(dataflow_graph), max_depth(max_depth), candidate_set(candidate_set) {}
+
+  /*! \brief The dataflow graph all candidates are over. Not owned. */
+  const DataflowGraph* dataflow_graph;
+  /*! \brief Maximum depth for candidates, as configured on the enclosing CombinePartitionRule. */
+  size_t max_depth;
+  /*! \brief The candidates rules read from and append new candidates to. Not owned. */
+  CandidateSet* candidate_set;
+};
+
+/*!
+ * \brief Base class for all 'combiner' rules.
+ *
+ * Given the current candidate set, a combiner rule looks for opportunities to form larger
+ * candidates, optionally removing existing candidates in the process.
+ */
+class CombinerRuleNode : public Object {
+ public:
+  /*! \brief The rule's name, used in \p ToString and log messages. */
+  String rule_name_;
+
+  void VisitAttrs(AttrVisitor* v);
+
+  /*!
+   * \brief Appends any larger candidates this rule can form to \p ctxt->candidate_set.
+   * \p CombinePartitionRule invokes this once per round for each of its combiner rules.
+   */
+  virtual void AppendAllResults(AppendAllResultsContext* ctxt) const;
+  /*! \brief Returns a human-readable description of this rule. */
+  virtual std::string ToString() const;
+
+  static constexpr const char* _type_key = "relay.collage.CombinerRule";
+  // Matches the four sub-classes declared in this header.
+  static constexpr const uint32_t _type_child_slots = 4;
+  TVM_DECLARE_BASE_OBJECT_INFO(CombinerRuleNode, Object);
+};
+
+/*! \brief Managed reference to \p CombinerRuleNode. */
+class CombinerRule : public ObjectRef {
+ public:
+  explicit CombinerRule(String rule_name);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(CombinerRule, ObjectRef, CombinerRuleNode);
+};
+
+/*!
+ * \brief A combiner rule which runs one or more simple combiner rules over the current
+ * touching candidates.
+ */
+class AllSimpleCombinerRuleNode : public CombinerRuleNode {
+ public:
+  /*! \brief The simple rules to try on each pair of touching candidates. */
+  Array<SimpleCombinerRule> simple_rules_;
+
+  void VisitAttrs(AttrVisitor* v);
+
+  void AppendAllResults(AppendAllResultsContext* ctxt) const override;
+  std::string ToString() const override;
+
+  static constexpr const char* _type_key = "relay.collage.AllSimpleCombinerRule";
+  TVM_DECLARE_FINAL_OBJECT_INFO(AllSimpleCombinerRuleNode, CombinerRuleNode);
+};
+
+/*! \brief Managed reference to \p AllSimpleCombinerRuleNode. */
+class AllSimpleCombinerRule : public CombinerRule {
+ public:
+  AllSimpleCombinerRule(String rule_name, Array<SimpleCombinerRule> simple_rules);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(AllSimpleCombinerRule, CombinerRule, AllSimpleCombinerRuleNode);
+};
+
+/*!
+ * \brief A combiner rule which combines injective sub-groups which appear inside tuples which are
+ * themselves inputs to injective sub-groups.
+ */
+class TupleArgCombinerRuleNode : public CombinerRuleNode {
+ public:
+  void VisitAttrs(AttrVisitor* v);
+
+  void AppendAllResults(AppendAllResultsContext* ctxt) const override;
+  std::string ToString() const override;
+
+  static constexpr const char* _type_key = "relay.collage.TupleArgCombinerRule";
+  TVM_DECLARE_FINAL_OBJECT_INFO(TupleArgCombinerRuleNode, CombinerRuleNode);
+};
+
+/*! \brief Managed reference to \p TupleArgCombinerRuleNode. */
+class TupleArgCombinerRule : public CombinerRule {
+ public:
+  explicit TupleArgCombinerRule(String rule_name);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(TupleArgCombinerRule, CombinerRule, TupleArgCombinerRuleNode);
+};
+
+/*!
+ * \brief A combiner rule which combines tuple projection if it's an output of an injective
+ * group.
+ */
+class TupleProjCombinerRuleNode : public CombinerRuleNode {
+ public:
+  void VisitAttrs(AttrVisitor* v);
+
+  void AppendAllResults(AppendAllResultsContext* ctxt) const override;
+  std::string ToString() const override;
+
+  static constexpr const char* _type_key = "relay.collage.TupleProjCombinerRule";
+  TVM_DECLARE_FINAL_OBJECT_INFO(TupleProjCombinerRuleNode, CombinerRuleNode);
+};
+
+/*! \brief Managed reference to \p TupleProjCombinerRuleNode. */
+class TupleProjCombinerRule : public CombinerRule {
+ public:
+  explicit TupleProjCombinerRule(String rule_name);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(TupleProjCombinerRule, CombinerRule, TupleProjCombinerRuleNode);
+};
+
+/*!
+ * \brief A combiner rule which combines constants in argument positions to existing candidates.
+ * Note that scalars are always inlined, so this rule only combines tensor constant arguments.
+ */
+class ConstantCombinerRuleNode : public CombinerRuleNode {
+ public:
+  void VisitAttrs(AttrVisitor* v);
+
+  /*!
+   * \brief For each candidate at or after \p first_new_index() with constant dataflow inputs,
+   * appends the union of that candidate with all of those constants.
+   */
+  void AppendAllResults(AppendAllResultsContext* ctxt) const override;
+  std::string ToString() const override;
+
+  static constexpr const char* _type_key = "relay.collage.ConstantCombinerRule";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ConstantCombinerRuleNode, CombinerRuleNode);
+};
+
+/*! \brief Managed reference to \p ConstantCombinerRuleNode. */
+class ConstantCombinerRule : public CombinerRule {
+ public:
+  explicit ConstantCombinerRule(String rule_name);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(ConstantCombinerRule, CombinerRule, ConstantCombinerRuleNode);
+};
+
+}  // namespace collage
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // TVM_RELAY_COLLAGE_COMBINER_RULE_H_
diff --git a/src/relay/collage/cost.h b/src/relay/collage/cost.h
index 8ae276d220..723c5b58ac 100644
--- a/src/relay/collage/cost.h
+++ b/src/relay/collage/cost.h
@@ -71,6 +71,11 @@ class Cost {
 
   bool is_value() const { return !std::isnan(value_) && !std::isinf(value_); }
 
+  /*! \brief Returns the underlying value. ICHECKs that it is neither NaN nor infinite. */
+  double value() const {
+    ICHECK(is_value());
+    return value_;
+  }
+
   /*! \brief Return true if the less-than relation is defined for this and that. */
   bool are_comparable(Cost that) const { return !std::isnan(value_) && !std::isnan(that.value_); }
 
diff --git a/src/relay/collage/cost_estimator.cc b/src/relay/collage/cost_estimator.cc
new file mode 100644
index 0000000000..e2ea99ce9b
--- /dev/null
+++ b/src/relay/collage/cost_estimator.cc
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/collage/cost_estimator.cc
+ * \brief Interface for measuring candidate partition cost.
+ */
+
+#include "./cost_estimator.h"
+
+#include <math.h>
+#include <tvm/relay/expr_functor.h>
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+TVM_REGISTER_OBJECT_TYPE(CostEstimatorNode);
+TVM_REGISTER_OBJECT_TYPE(MockEstimatorNode);
+
+/*! \brief Builds the default estimator, which delegates to a registered global function. */
+CostEstimator::CostEstimator() {
+  ObjectPtr<Object> estimator_node = make_object<CostEstimatorNode>();
+  data_ = std::move(estimator_node);
+}
+
+/*!
+ * Delegates to the "tvm.relay.collage.estimate_seconds" global function. An infinite result maps
+ * to \p Cost::Invalid, a NaN result to \p Cost::Unknown, and anything else to \p Cost::Value.
+ */
+Cost CostEstimatorNode::Estimate(const IRModule& mod, const Target& target,
+                                 bool needs_tvm_tuning) const {
+  // Looked up on every call rather than cached in a function-local static. The global may be
+  // registered after this library is loaded (presumably from Python), and a cached null lookup
+  // would make every later call fail even after registration. The lookup is negligible next to
+  // the estimation itself.
+  const runtime::PackedFunc* estimate_seconds =
+      runtime::Registry::Get("tvm.relay.collage.estimate_seconds");
+  ICHECK(estimate_seconds)
+      << "Global function 'tvm.relay.collage.estimate_seconds' is not registered";
+  const double value = (*estimate_seconds)(mod, target, needs_tvm_tuning);
+  if (std::isinf(value)) {
+    return Cost::Invalid();
+  }
+  if (std::isnan(value)) {
+    return Cost::Unknown();
+  }
+  return Cost::Value(value);
+}
+
+/*!
+ * \brief Visitor which sums a mock cost over every call to a Relay operator in an expression.
+ *
+ * The k'th operator call visited (counting from zero) contributes op_cost * fusion_benefit^k.
+ * With a fusion_benefit below one, each further operator call in the same expression is cheaper.
+ */
+class MockEstimationVisitor : private ExprVisitor {
+ public:
+  MockEstimationVisitor(double op_cost, double fusion_benefit)
+      : op_cost_(op_cost), fusion_benefit_(fusion_benefit) {}
+
+  /*! \brief Visits \p body and returns the accumulated mock cost. */
+  double EstimateCost(const Expr& body) {
+    VisitExpr(body);
+    return cost_;
+  }
+
+ private:
+  void VisitExpr_(const CallNode* call_node) final {
+    const bool is_operator_call = call_node->op->IsInstance<OpNode>();
+    if (is_operator_call) {
+      const double discount = pow(fusion_benefit_, num_ops_);
+      cost_ += op_cost_ * discount;
+      ++num_ops_;
+    }
+    ExprVisitor::VisitExpr_(call_node);
+  }
+
+  void VisitExpr_(const FunctionNode* function_node) final {
+    // Functions carrying a "Compiler" attribute must have been outlined, never left inline.
+    ICHECK(!function_node->GetAttr<String>(attr::kCompiler).defined());
+    ExprVisitor::VisitExpr_(function_node);
+  }
+
+  /*! \brief Baseline cost of a single operator call. */
+  double op_cost_;
+  /*! \brief Multiplier applied once per operator call already seen. */
+  double fusion_benefit_;
+  /*! \brief Number of operator calls visited so far. */
+  size_t num_ops_ = 0;
+  /*! \brief Running total cost. */
+  double cost_ = 0.0;
+};
+
+/*!
+ * Sums the mock cost of every Relay function in \p mod. The baseline per-operator cost is taken
+ * from \p target_costs_ for the target's kind, and each function's operator calls are discounted
+ * by a fusion benefit of 0.9. \p needs_tvm_tuning is ignored.
+ */
+Cost MockEstimatorNode::Estimate(const IRModule& mod, const Target& target,
+                                 bool needs_tvm_tuning) const {
+  const String target_kind_name = target->kind->name;
+  // Report the missing target kind explicitly rather than relying on Map::at's generic error.
+  ICHECK(target_costs_.count(target_kind_name))
+      << "MockEstimator has no cost for target kind '" << target_kind_name << "'";
+  const double op_cost = static_cast<double>(target_costs_.at(target_kind_name)->value);
+  double cost = 0.0;
+  for (const auto& kv : mod->functions) {
+    const auto* function_node = kv.second.as<FunctionNode>();
+    if (function_node == nullptr) {
+      continue;  // Only Relay functions contribute to the mock cost.
+    }
+    Function function = GetRef<Function>(function_node);
+    if (kv.first->name_hint == "main") {
+      // Only tensor args are allowed to main. Check the annotation is present before inspecting
+      // it; an unannotated param would otherwise dereference a null type.
+      for (const auto& param : function->params) {
+        ICHECK(param->type_annotation.defined() &&
+               param->type_annotation->IsInstance<TensorTypeNode>())
+            << "@main parameter " << param->name_hint() << " must have a tensor type annotation";
+      }
+    }
+    cost += MockEstimationVisitor(op_cost, /*fusion_benefit=*/0.9).EstimateCost(function->body);
+  }
+  return Cost::Value(cost);
+}
+
+MockEstimator::MockEstimator(Map<String, Integer> target_costs) {
+  auto mock_node = make_object<MockEstimatorNode>();
+  mock_node->target_costs_ = std::move(target_costs);
+  data_ = std::move(mock_node);
+}
+
+// FFI constructors, exposed under the "relay.collage." prefix.
+TVM_REGISTER_GLOBAL("relay.collage.CostEstimator").set_body_typed([]() {
+  return CostEstimator();
+});
+
+TVM_REGISTER_GLOBAL("relay.collage.MockEstimator")
+    .set_body_typed([](Map<String, Integer> costs) { return MockEstimator(std::move(costs)); });
+
+}  // namespace collage
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/collage/cost_estimator.h b/src/relay/collage/cost_estimator.h
new file mode 100644
index 0000000000..f433fd5840
--- /dev/null
+++ b/src/relay/collage/cost_estimator.h
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/collage/cost_estimator.h
+ * \brief Interface for measuring candidate partition cost.
+ */
+
+#ifndef TVM_RELAY_COLLAGE_COST_ESTIMATOR_H_
+#define TVM_RELAY_COLLAGE_COST_ESTIMATOR_H_
+
+#include <tvm/relay/function.h>
+
+#include "./cost.h"
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+/*!
+ * \brief An (abstract) estimator for the cost of executing "main" in an \p IRModule representing
+ * a candidate partition, using the given target for lowering and codegen.
+ *
+ * Generally the implementation will compile to a \p runtime::Module (possibly on a target-specific
+ * worker if cross-compilation is not available), repeatedly invoke "main" with random data until
+ * measure variance is acceptable (on a target-specific worker), and return the summarized costs.
+ *
+ * If using a TVM native \p Target, it is possible compilation will itself invoke TVM tuning.
+ *
+ * TODO(mbs): Actually, currently not abstract so can get some local measurements.
+ */
+class CostEstimatorNode : public Object {
+ public:
+  /*!
+   * \brief Returns the estimated cost (possibly after many many minutes of training time) of
+   * running "main" in \p mod using \p target, which represents a possible partitioning of
+   * some overall Relay expression.
+   *
+   * \param mod The module whose "main" is to be costed.
+   * \param target The target used for lowering and codegen.
+   * \param needs_tvm_tuning Forwarded unchanged by the default implementation to the
+   * "tvm.relay.collage.estimate_seconds" global; presumably requests TVM tuning first — confirm.
+   * \return For the default implementation: \p Cost::Invalid for an infinite estimate,
+   * \p Cost::Unknown for NaN, otherwise \p Cost::Value of the estimate.
+   */
+  virtual Cost Estimate(const IRModule& mod, const Target& target, bool needs_tvm_tuning) const;
+
+  static constexpr const char* _type_key = "relay.collage.CostEstimator";
+  TVM_DECLARE_BASE_OBJECT_INFO(CostEstimatorNode, Object);
+};
+
+/*!
+ * \brief Managed, non-nullable reference to \p CostEstimatorNode. The default constructor yields
+ * the default estimator.
+ */
+class CostEstimator : public ObjectRef {
+ public:
+  CostEstimator();
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(CostEstimator, ObjectRef, CostEstimatorNode);
+};
+
+/*!
+ * \brief A mock cost estimator which can determine the cost of a candidate based on both
+ * the candidate's target and the number of operator calls inside it.
+ *
+ * The estimator also ICHECKs the given module has all "Compiler" functions outlined and @main
+ * takes only tensor arguments (ie no tuple types).
+ *
+ * To support testing only.
+ */
+class MockEstimatorNode : public CostEstimatorNode {
+ public:
+  /*!
+   * \brief Returns the sum over every Relay function in \p mod of its operator-call costs. Within
+   * a function, the k'th operator call (from zero) costs the target kind's entry in
+   * \p target_costs_ scaled by 0.9^k. \p needs_tvm_tuning is ignored.
+   */
+  Cost Estimate(const IRModule& mod, const Target& target, bool needs_tvm_tuning) const override;
+
+  static constexpr const char* _type_key = "relay.collage.MockEstimator";
+  TVM_DECLARE_FINAL_OBJECT_INFO(MockEstimatorNode, CostEstimatorNode);
+
+ protected:
+  friend class MockEstimator;
+
+  /*!
+   * \brief Map from target kind name to assumed baseline cost (in integer seconds) for all
+   * operator calls.
+   */
+  Map<String, Integer> target_costs_;
+};
+
+/*! \brief Managed reference to \p MockEstimatorNode. */
+class MockEstimator : public CostEstimator {
+ public:
+  explicit MockEstimator(Map<String, Integer> target_costs);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(MockEstimator, CostEstimator, MockEstimatorNode);
+};
+
+}  // namespace collage
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // TVM_RELAY_COLLAGE_COST_ESTIMATOR_H_
diff --git a/src/relay/collage/name_supply.cc b/src/relay/collage/name_supply.cc
new file mode 100644
index 0000000000..4b7d497b0d
--- /dev/null
+++ b/src/relay/collage/name_supply.cc
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/collage/name_supply.cc
+ * \brief A source of fresh variable names.
+ */
+
+#include "./name_supply.h"
+
+#include <algorithm>
+#include <cctype>
+#include <sstream>
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+namespace {
+/*!
+ * \brief Appends \p str to \p os, replacing every character which is not a letter, digit or '_'
+ * with '_'.
+ *
+ * While \p *first is true (ie nothing has been appended to the overall name yet), a leading '_'
+ * is emitted if \p str starts with a character which cannot start a C identifier. \p *first is
+ * cleared as soon as any character is appended. The leading '_' is therefore only considered for
+ * the very first character of the whole name, not for the start of every hint.
+ */
+void AppendCSafe(bool* first, std::ostringstream& os, const std::string& str) {
+  for (size_t i = 0; i < str.size(); ++i) {
+    // Go through unsigned char: passing a negative char to the <cctype> functions is undefined.
+    const unsigned char c = static_cast<unsigned char>(str[i]);
+    // Test the flag itself, not the (always non-null) pointer.
+    if (i == 0 && *first && !std::isalpha(c) && c != '_') {
+      os << "_";
+    }
+    if (c == '_' || std::isalnum(c)) {
+      os << static_cast<char>(c);
+    } else {
+      os << "_";
+    }
+    *first = false;
+  }
+}
+}  // namespace
+
+NameSupply NameSupply::MakeSubNameSupply() {
+  // The sub-supply shares our prefix and starts from a snapshot of every name known so far.
+  // After this the two supplies evolve independently.
+  NameSupply sub_supply(prefix_);
+  sub_supply.next_free_index_ = next_free_index_;
+  return sub_supply;
+}
+
+/*!
+ * \brief Returns a fresh name built from the prefix (if any) and the non-empty \p hints, joined
+ * with '_' and made C-identifier safe.
+ *
+ * If that base name has already been handed out or reserved, a numeric suffix "_<n>" is appended.
+ * Suffixed names are recorded too, so a name is never handed out twice. For example,
+ * Fresh({"x"}), Fresh({"x"}), Fresh({"x_1"}) yields "x", "x_1", "x_1_1".
+ */
+std::string NameSupply::Fresh(const std::initializer_list<std::string>& hints) {
+  std::ostringstream os;
+  bool first = true;
+  bool need_sep = false;
+  if (!prefix_.empty()) {
+    AppendCSafe(&first, os, prefix_);
+    need_sep = true;
+  }
+  for (const auto& hint : hints) {
+    if (hint.empty()) {
+      continue;
+    }
+    if (need_sep) {
+      os << "_";
+    }
+    AppendCSafe(&first, os, hint);
+    need_sep = true;
+  }
+  const std::string base_name = os.str();
+  auto itr = next_free_index_.find(base_name);
+  if (itr == next_free_index_.end()) {
+    next_free_index_.emplace(base_name, 1);
+    return base_name;
+  }
+  // Skip any suffixed name which is already taken, either because it was reserved or because
+  // it was itself returned as a base name earlier. The map is not modified inside the loop, so
+  // itr stays valid.
+  std::string name;
+  do {
+    name = base_name + "_" + std::to_string(itr->second++);
+  } while (next_free_index_.count(name) > 0);
+  next_free_index_.emplace(name, 1);
+  return name;
+}
+
+}  // namespace collage
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/collage/name_supply.h b/src/relay/collage/name_supply.h
new file mode 100644
index 0000000000..d37023ab6f
--- /dev/null
+++ b/src/relay/collage/name_supply.h
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/collage/name_supply.h
+ * \brief A source of fresh variable names.
+ */
+
+#ifndef TVM_RELAY_COLLAGE_NAME_SUPPLY_H_
+#define TVM_RELAY_COLLAGE_NAME_SUPPLY_H_
+
+#include <string>
+#include <unordered_map>
+#include <utility>
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+/*!
+ * \brief A supply of fresh names.
+ *
+ * Each name is an optional prefix followed by the given hints, made safe for use as a C
+ * identifier. A numeric suffix disambiguates it when the same base name is requested again.
+ */
+class NameSupply {
+ public:
+  /*! \brief Creates an empty supply whose names all start with \p prefix (which may be empty). */
+  explicit NameSupply(std::string prefix) : prefix_(std::move(prefix)) {}
+
+  /*! \brief Returns a new supply with the same prefix and a copy of all names known so far. */
+  NameSupply MakeSubNameSupply();
+
+  /*! \brief Records \p existing as taken. Has no effect if it is already known. */
+  void Reserve(const std::string& existing) { next_free_index_.insert({existing, 1}); }
+
+  /*! \brief Returns a name derived from \p hints. Empty hints are skipped. */
+  std::string Fresh(const std::initializer_list<std::string>& hints);
+
+ private:
+  /*! \brief Prefix for all names. May be empty. */
+  std::string prefix_;
+  /*! \brief Maps each base name handed out or reserved so far to the next numeric suffix to try. */
+  std::unordered_map<std::string, int> next_free_index_;
+};
+
+}  // namespace collage
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // TVM_RELAY_COLLAGE_NAME_SUPPLY_H_
diff --git a/src/relay/collage/partition_rule.cc b/src/relay/collage/partition_rule.cc
index 1cedbfc9d7..e11f740acf 100644
--- a/src/relay/collage/partition_rule.cc
+++ b/src/relay/collage/partition_rule.cc
@@ -285,6 +285,66 @@ OpCallByKindPartitionRule::OpCallByKindPartitionRule(String rule_name) {
   data_ = std::move(node);
 }
 
+TVM_REGISTER_NODE_TYPE(CombinePartitionRuleNode);
+
+void CombinePartitionRuleNode::VisitAttrs(AttrVisitor* v) {
+  // TODO(mbs): No attributes are visited yet.
+}
+
+std::vector<CandidatePartition> CombinePartitionRuleNode::AllCandidates(
+    const DataflowGraph& dataflow_graph, const PartitionSpec& spec) const {
+  // Seed the candidate set with everything the sub-rule finds. A candidate is never mutated once
+  // it has been added to the set.
+  std::vector<CandidatePartition> seed_candidates = sub_rule_->AllCandidates(dataflow_graph, spec);
+  VLOG(1) << "running CombinePartitionRule(" << rule_name_ << ") over " << seed_candidates.size()
+          << " sub-candidates";
+  CandidateSet candidate_set(std::move(seed_candidates));
+
+  // Each round applies every combiner rule in order. PrepareForNextRound decides whether another
+  // round is needed.
+  AppendAllResultsContext context(&dataflow_graph, max_depth_, &candidate_set);
+  for (size_t round_index = 1; candidate_set.PrepareForNextRound(); ++round_index) {
+    VLOG_CONTEXT << "round " << round_index;
+    VLOG(1) << "checking " << candidate_set.size() << " candidates ("
+            << candidate_set.first_new_index() << " existing)";
+    for (const CombinerRule& combiner_rule : combiner_rules_) {
+      combiner_rule->AppendAllResults(&context);
+    }
+  }
+
+  std::vector<CandidatePartition> result;
+  for (auto& candidate : candidate_set.MovedCurrentCandidates()) {
+    // Compute the nested name before moving the candidate away.
+    String nested_name = NestLabels(rule_name_, candidate->rule_name_);
+    CandidatePartition renamed = WithRuleName(std::move(candidate), std::move(nested_name));
+    VLOG(2) << "CombinePartitionRule(" << rule_name_ << ") yields " << renamed->ToString();
+    result.emplace_back(std::move(renamed));
+  }
+  VLOG(1) << "CombinePartitionRule(" << rule_name_ << ") produced " << result.size()
+          << " candidates";
+  return result;
+}
+
+void CombinePartitionRuleNode::AppendBodyItems(std::vector<Doc>* body_items) const {
+  PartitionRuleNode::AppendBodyItems(body_items);
+  // Appends a fresh, empty item and returns it for the caller to fill in.
+  auto new_item = [body_items]() -> Doc& {
+    body_items->emplace_back();
+    return body_items->back();
+  };
+  new_item() << "sub_rule=" << sub_rule_->ToDoc();
+  for (const CombinerRule& combiner_rule : combiner_rules_) {
+    new_item() << "combiner_rule=" << combiner_rule->ToString();
+  }
+  new_item() << "max_depth=" << max_depth_;
+}
+
+/*!
+ * Builds a rule which grows the candidates of \p sub_rule using \p combiner_rules.
+ * \p max_depth is the maximum candidate depth handed to each combiner rule.
+ */
+CombinePartitionRule::CombinePartitionRule(String rule_name, PartitionRule sub_rule,
+                                           Array<CombinerRule> combiner_rules, size_t max_depth) {
+  auto node = runtime::make_object<CombinePartitionRuleNode>();
+  node->rule_name_ = std::move(rule_name);
+  node->sub_rule_ = std::move(sub_rule);
+  node->combiner_rules_ = std::move(combiner_rules);
+  node->max_depth_ = max_depth;
+  data_ = std::move(node);
+}
+
 TVM_REGISTER_NODE_TYPE(OnlyValidPartitionRuleNode);
 
 void OnlyValidPartitionRuleNode::VisitAttrs(AttrVisitor* v) {
diff --git a/src/relay/collage/partition_rule.h b/src/relay/collage/partition_rule.h
index 13f5c0b01d..19e7f3cceb 100644
--- a/src/relay/collage/partition_rule.h
+++ b/src/relay/collage/partition_rule.h
@@ -33,6 +33,7 @@
 
 #include "../../printer/doc.h"
 #include "./candidate_partition.h"
+#include "./combiner_rule.h"
 #include "./sub_graph.h"
 
 namespace tvm {
@@ -88,6 +89,15 @@ bool DefaultPatternPredicate(const Expr& matched_sub_expr);
  *    delineate a partition (or kernel).
  *  - \p UnionPartitionRule: Simply unions all the candidates from all sub-rules together. Used to
  *    combine individual \p DFPatternPartitionRules.
+ *  - \p CombinePartitionRule: Given a sub-rule and a list of 'combiner' rules, finds
+ *    all possible ways of combining the sub-rule's candidates to yield even larger candidates.
+ *    Note that the sub-rule's candidates may also be directly included in the results. The
+ *    'combiner' rules allow combining by \p OpPatternKinds, combining the arguments to tuples
+ *    which themselves are arguments to Relay operator calls, and so on. This rule is intended to
+ *    mimic the existing TVM \p FuseOps pass, though:
+ *    i) all candidates are found rather than just the largest, ii) the starting set of candidates
+ *    can be provided by any other rule, and iii) we rely on \p SubGraph validity checking to weed
+ *    out infeasible candidates.
  *  - \p OnlyValidPartitionRule: Given a \p SubGraphConfig, ignores candidates with 'invalid'
  *    sub-graphs. Used to limit the maximum candidate depth, the number of independent outputs,
  *    and whether intermediate 'taps' are allowed.
@@ -100,6 +110,54 @@ bool DefaultPatternPredicate(const Expr& matched_sub_expr);
  * partition on more primitive candidates. Note that the \p SubGraph machinery supports
  * multiple-input and -output sub-graphs and their validation, so horizontal partition is easy
  * implement.)
+ *
+ * Here are some typical ways to combine \p PartitionRules for different partition/fusion
+ * strategies:
+ *
+ *  - Classic pattern-based BYOC with \p MergeComposite/AnnotateTarget/PartitionGraph passes:
+ *    \code
+ *    PrimitivePartitionRule
+ *      OnlyValidPartitionRule
+ *        CombinePartitionRule (with join-anything combiner rule)
+ *          UnionPartitionRule
+ *            CompositePartitionRule(label1)
+ *              DFPatternPartitionRule(pattern1)
+ *                        :
+ *            CompositePartitionRule(labeln)
+ *              DFPatternPartitionRule(patternn)
+ *    \endcode
+ *
+ *  - "Consider this library implementation for these sub-expressions", using \p DFPatterns to
+ *    pick out which Relay operators are supported:
+ *    \code
+ *    OnlyValidPartitionRule
+ *      CombinePartitionRule (with default TVM combiner rules)
+ *        UnionPartitionRule
+ *          OpCallByKindPartitionRule
+ *          CompositePartitionRule(label1)
+ *            DFPatternPartitionRule(pattern1)
+ *                       :
+ *          CompositePartitionRule(labeln)
+ *            DFPatternPartitionRule(patternn)
+ *    \endcode
+ *
+ *  - Classic TVM \p FuseOps
+ *    \code
+ *    PrimitivePartitionRule
+ *      OnlyValidPartitionRule
+ *        CombinePartitionRule (with default TVM combiner rules)
+ *          OpCallByKindPartitionRule
+ *    \endcode
+ *
+ *  - "Just fuse what I tell you to fuse", using \p DFPatterns to directly select candidates:
+ *    \code
+ *    PrimitivePartitionRule
+ *      OnlyValidPartitionRule
+ *        UnionPartitionRule
+ *          DFPatternPartitionRule(pattern1)
+ *                       :
+ *          DFPatternPartitionRule(patternn)
+ *    \endcode
  */
 class PartitionRuleNode : public Object {
  public:
@@ -293,6 +351,80 @@ class OpCallByKindPartitionRule : public PartitionRule {
                                 OpCallByKindPartitionRuleNode);
 };
 
+/*!
+ * \brief Partition rule which combines sub-graphs to exploit optimizations commonly available in
+ * backends (including the TVM lowering backend). Those optimization rules are in turn described by
+ * one or more primitive \p CombinerRules.
+ *
+ * For TVM these primitive combiner rules are guided by the \p OpPatternKind associated with every
+ * sub-graph. That in turn is the maximum of the kind of each expression node in the sub-graph,
+ * using the rules:
+ *  - Constants are \p kElemwise.
+ *  - A call to a Relay operator has the kind of its callee.
+ *  - Tuple construction and projection are injective provided all tuple fields are of tensor type.
+ *  - All other sub-expressions are opaque.
+ *
+ * The available \p OpPatternKinds (and our abbreviations for them) are:
+ *  - E: kElemWise, eg nn.relu
+ *  - B: kBroadcast, eg add
+ *  - I: kInjective, eg concatenate
+ *  - R: kCommReduce, eg sum
+ *  - A: kOutEWiseFusable, eg nn.conv2d (often called 'anchor nodes', hence the A abbreviation)
+ *  - O: kOpaque, everything else
+ * (The kTuple kind is not used by this machinery.)
+ *
+ * Kinds are ordered as above from least- to most-constraining w.r.t. possible partition
+ * opportunities. When we write a kind abbreviation below we intend it to mean that kind *or less*.
+ * And when we write 'kl -> kr' we mean it to match a sub-expression of kind kr or less whose
+ * dataflow inputs are all of kind kl or less.
+ *
+ * We can then mimic the classic \p FuseOps TVM Pass with the following more primitive combiner
+ * rules:
+ *  - Sub-groups cannot have taps. In the classic \p FuseOps pass taps are avoided by construction
+ *    by always considering all node->dominator paths. Here we naively allow taps on all candidates,
+ *    but reject them using SubGraph::IsValid with a SubGraphConfig with allow_taps = false.
+ *  - Combine A -> B
+ *  - Combine B -> R
+ *  - Combine I -> I
+ *  - Combine I -> tuple -> I. That is, if an I sub-graph has a tuple as input, and at least one
+ *    tuple field can be provided by an I sub-graph exit, then both the tuple and all such fields
+ *    may be joined.
+ *
+ * Note that \p FuseOps only considers the largest possible sub-graphs. However this partition rule
+ * considers all possibilities so as to 'make room' for other targets supplying other
+ * overlapping candidates.
+ *
+ * See combiner_rule.h for the more primitive combiner rules which implement the above.
+ */
+class CombinePartitionRuleNode : public PartitionRuleNode {
+ public:
+  /*! \brief The sub-rule supplying the initial set of candidates. */
+  PartitionRule sub_rule_;
+  /*! \brief The primitive rules applied, round after round, to grow the sub-rule's candidates. */
+  Array<CombinerRule> combiner_rules_;
+  /*! \brief Maximum depth for candidates, handed to every combiner rule via its context. */
+  size_t max_depth_;
+
+  void VisitAttrs(AttrVisitor* v);
+
+  std::vector<CandidatePartition> AllCandidates(const DataflowGraph& dataflow_graph,
+                                                const PartitionSpec& spec) const override;
+
+  void AppendBodyItems(std::vector<Doc>* body_items) const override;
+
+  static constexpr const char* _type_key = "relay.collage.CombinePartitionRule";
+  TVM_DECLARE_FINAL_OBJECT_INFO(CombinePartitionRuleNode, PartitionRuleNode);
+};
+
+/*! \brief Managed reference to \p CombinePartitionRuleNode. */
+class CombinePartitionRule : public PartitionRule {
+ public:
+  CombinePartitionRule(String rule_name, PartitionRule sub_rule, Array<CombinerRule> combiner_rules,
+                       size_t max_depth);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(CombinePartitionRule, PartitionRule, CombinePartitionRuleNode);
+};
+
 /*!
  * \brief Partition rules which keeps only candidates from the sub-rule whose sub-groups are valid
  * w.r.t. the given \p SubGraphConfig.
diff --git a/tests/cpp/relay/collage/candidate_partition_test.cc b/tests/cpp/relay/collage/candidate_partition_test.cc
new file mode 100644
index 0000000000..c4f81e18ec
--- /dev/null
+++ b/tests/cpp/relay/collage/candidate_partition_test.cc
@@ -0,0 +1,220 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "../../../src/relay/collage/candidate_partition.h"
+
+#include <gtest/gtest.h>
+#include <tvm/parser/parser.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/function.h>
+#include <tvm/relay/transform.h>
+
+#include "../../../src/relay/collage/partition_spec.h"
+
+namespace tvm {
+namespace relay {
+namespace collage {
+namespace {
+
+// NOTE: CandidatePartition::ParallelRewrite is effectively tested in partition_rule_test.cc,
+// so it is not re-tested here. The only other non-trivial code is EstimatedCost, tested below.
+
+Function MakeTestFunction(const std::string& mod_text) {
+  IRModule ir_module = parser::ParseModule("string", mod_text, {}, {});
+  ir_module = transform::CapturePostDfsIndexInSpans()(ir_module);
+  Function main_func = Downcast<Function>(ir_module->Lookup("main"));
+  LOG(INFO) << "------- input function -------";
+  LOG(INFO) << PrettyPrint(main_func);
+  LOG(INFO) << "------------------------------";
+  return main_func;
+}
+
+PartitionSpec StandardSpec() { return PartitionSpec(String("test_spec"), Target("llvm"), {}); }
+
+String AlwaysInvalid(const Function& /*function*/) { return String("invalid"); }
+
+PartitionSpec AlwaysInvalidSpec() {
+  return PartitionSpec(String("test_spec"), Target("llvm"), {}, AlwaysInvalid);
+}
+
+/*!
+ * \brief Builds a candidate covering the nodes at \p indexes, where those same nodes are also
+ * grouped into a single nested sub-graph whose function carries the "Primitive" attribute and
+ * a "Compiler" attribute of "llvm".
+ */
+CandidatePartition MakeCandidate(const DataflowGraph& graph, const PartitionSpec& spec,
+                                 const std::vector<PostDfsIndex>& indexes) {
+  FunctionAttrsMap nested_attrs;
+  nested_attrs.Set(attr::kPrimitive, Integer(1));
+  nested_attrs.Set(attr::kCompiler, String("llvm"));
+  IndexSet node_set(graph.size(), indexes);
+  SubGraph inner(graph, node_set);
+  NestedSubGraph nested(inner, nested_attrs);
+  SubGraph outer(graph, node_set, inner->kind_, inner->label_, {nested});
+  return CandidatePartition(/*rule_name=*/"", outer, spec);
+}
+
+CostEstimator StandardEstimator() {
+  // Mock estimator with a target cost of 3 for "llvm".
+  return MockEstimator(Map<String, Integer>{{"llvm", 3}});
+}
+
+CostEstimator AlternateEstimator() {
+  // Mock estimator with a target cost of 7 for "llvm".
+  return MockEstimator(Map<String, Integer>{{"llvm", 7}});
+}
+
+std::shared_ptr<CandidateFunctionCache> Cache() {
+  // Every call yields a brand new cache, so tests never share cached entries.
+  auto name_supply = std::make_shared<NameSupply>("test");
+  return std::make_shared<CandidateFunctionCache>(std::move(name_supply));
+}
+
+TEST(CandidatePartition, EstimateCost_Simple) {
+  constexpr const char* kMod = R"(
+    #[version = "0.0.5"]
+    def @main(%x: Tensor[(10, 10), float32]) {
+      %0 = abs(%x);                      //  3
+      %1 = nn.relu(%0);                  //  4
+      nn.relu(%1)                        //  5
+    }
+  )";
+  auto func = MakeTestFunction(kMod);
+  auto graph = DataflowGraph(func);
+  auto spec = StandardSpec();
+  // The candidate covers abs (3) and the first nn.relu (4).
+  auto candidate = MakeCandidate(graph, spec, {3, 4});
+  auto estimator = StandardEstimator();
+  auto cache = Cache();
+
+  auto cost = candidate->EstimatedCost(graph, estimator, cache);
+  ASSERT_TRUE(cost.is_value());
+  // Cost is 3 for nn.relu plus 3 * 0.9 for the nested abs. The value is a floating-point
+  // result, so compare within a few ULPs rather than with exact equality.
+  ASSERT_DOUBLE_EQ(cost.value(), 5.7);
+}
+
+TEST(CandidatePartition, EstimateCost_AlreadyCached) {
+  constexpr const char* kMod = R"(
+    #[version = "0.0.5"]
+    def @main(%x: Tensor[(10, 10), float32]) {
+      %0 = abs(%x);                      //  3
+      %1 = nn.relu(%0);                  //  4
+      nn.relu(%1)                        //  5
+    }
+  )";
+  Function function = MakeTestFunction(kMod);
+  DataflowGraph dataflow_graph(function);
+  PartitionSpec spec = StandardSpec();
+  CandidatePartition candidate = MakeCandidate(dataflow_graph, spec, {3, 4});
+  // Seed the candidate with a known cost: EstimatedCost should hand it back as-is rather
+  // than producing a fresh estimate.
+  candidate->cost_ = Cost::Value(42.0);
+  CostEstimator estimator = StandardEstimator();
+  std::shared_ptr<CandidateFunctionCache> cache = Cache();
+
+  auto cost = candidate->EstimatedCost(dataflow_graph, estimator, cache);
+  ASSERT_TRUE(cost.is_value());
+  ASSERT_EQ(cost.value(), 42.0);
+}
+
+TEST(CandidatePartition, EstimateCost_Invalid) {
+  constexpr const char* kMod = R"(
+    #[version = "0.0.5"]
+    def @main(%x: Tensor[(10, 10), float32]) {
+      %0 = abs(%x);                      //  3
+      %1 = nn.relu(%0);                  //  4
+      nn.relu(%1)                        //  5
+    }
+  )";
+  Function function = MakeTestFunction(kMod);
+  DataflowGraph dataflow_graph(function);
+  // The spec's AlwaysInvalid validator returns "invalid" for any function.
+  PartitionSpec spec = AlwaysInvalidSpec();
+  CandidatePartition candidate = MakeCandidate(dataflow_graph, spec, {3, 4});
+  CostEstimator estimator = StandardEstimator();
+  std::shared_ptr<CandidateFunctionCache> cache = Cache();
+
+  // An invalid candidate gets an invalid cost rather than a numeric value.
+  auto cost = candidate->EstimatedCost(dataflow_graph, estimator, cache);
+  ASSERT_TRUE(cost.is_invalid());
+}
+
+TEST(CandidatePartition, EstimateCost_Cached) {
+  constexpr const char* kMod = R"(
+    #[version = "0.0.5"]
+    def @main(%x: Tensor[(10, 10), float32]) {
+      %0 = abs(%x);                      //  4
+      %1 = nn.relu(%0);                  //  5
+      %2 = abs(%1);                      //  6
+      %3 = nn.relu(%2);                  //  7
+      add(%1, %3)                        //  8
+    }
+  )";
+  auto func = MakeTestFunction(kMod);
+  auto graph = DataflowGraph(func);
+  auto spec = StandardSpec();
+  auto candidate_a = MakeCandidate(graph, spec, {4, 5});
+  auto candidate_b = MakeCandidate(graph, spec, {6, 7});
+  auto standard_estimator = StandardEstimator();
+  auto alternate_estimator = AlternateEstimator();
+  auto cache = Cache();
+
+  // First candidate estimated as per usual. Its cost is computed in floating point, so compare
+  // within a few ULPs rather than exactly.
+  auto cost_a = candidate_a->EstimatedCost(graph, standard_estimator, cache);
+  ASSERT_TRUE(cost_a.is_value());
+  ASSERT_DOUBLE_EQ(cost_a.value(), 5.7);
+
+  // Second candidate is structurally equal to first, so reuse first's cost even though
+  // estimator has different weights. A reused cost is the very same double, so this
+  // comparison is deliberately exact.
+  auto cost_b = candidate_b->EstimatedCost(graph, alternate_estimator, cache);
+  ASSERT_TRUE(cost_b.is_value());
+  ASSERT_EQ(cost_b.value(), cost_a.value());
+}
+
+TEST(CandidatePartition, EstimateCost_EtaExpandTuples) {
+  constexpr const char* kMod = R"(
+    #[version = "0.0.5"]
+    def @main(%x: Tensor[(10, 10), float32]) {
+      %0 = abs(%x);                      //  3
+      %1 = nn.relu(%0);                  //  5
+      %2 = (%0, %1);                     //  6
+      concatenate(%2)                    //  7
+    }
+  )";
+  auto func = MakeTestFunction(kMod);
+  auto graph = DataflowGraph(func);
+  auto spec = StandardSpec();
+  // Only the concatenate (7) is inside the candidate; its tuple argument (6) is outside.
+  auto candidate = MakeCandidate(graph, spec, {7});
+  auto estimator = StandardEstimator();
+  auto cache = Cache();
+
+  auto cost = candidate->EstimatedCost(graph, estimator, cache);
+  ASSERT_TRUE(cost.is_value());
+  // Cost values are doubles; compare against a double literal within a few ULPs.
+  ASSERT_DOUBLE_EQ(cost.value(), 3.0);
+}
+
+}  // namespace
+}  // namespace collage
+}  // namespace relay
+}  // namespace tvm
diff --git a/tests/cpp/relay/collage/partition_rule_test.cc b/tests/cpp/relay/collage/partition_rule_test.cc
index fab34cd3d3..51a4970c7e 100644
--- a/tests/cpp/relay/collage/partition_rule_test.cc
+++ b/tests/cpp/relay/collage/partition_rule_test.cc
@@ -38,7 +38,8 @@ Constant MakeConstant(std::initializer_list<ShapeTuple::index_type> shape) {
 
 Function MakeTestFunction(
     const std::string& mod_text,
-    std::initializer_list<std::initializer_list<ShapeTuple::index_type>> constant_shapes) {
+    const std::initializer_list<std::initializer_list<ShapeTuple::index_type>>& constant_shapes =
+        {}) {
   Array<ObjectRef> constants;
   for (const auto& shape : constant_shapes) {
     constants.push_back(MakeConstant(shape));
@@ -58,12 +59,73 @@ Function StandardTestFunction() {
   constexpr const char* kMod = R"(
     #[version = "0.0.5"]
     def @main(%x: Tensor[(10, 10), float32]) {
-      %0 = abs(%x);                      //  3
-      %1 = nn.relu(%0);                  //  4
-      nn.relu(%1)                        //  5
+                                         //  index, kind
+      %0 = abs(%x);                      //  3, E
+      %1 = nn.relu(%0);                  //  4, E
+      nn.relu(%1)                        //  5, E
     }
   )";
-  return MakeTestFunction(kMod, /*constant_shapes=*/{});
+  return MakeTestFunction(kMod);
+}
+
+Function VariantTestFunction() {  // abs and add feeding an opaque shape_of.
+  constexpr const char* kMod = R"(
+    #[version = "0.0.5"]
+    def @main(%x: Tensor[(10, 10), float32]) {
+                                         // index, kind
+      %0 = abs(%x);                      // 4, E
+      %1 = add(%0, %x);                  // 5, B
+      shape_of(%1)                       // 6, O
+    }
+  )";
+  return MakeTestFunction(kMod);
+}
+
+Function GPT2ExtractOps() {  // GPT-2 excerpt: dense, bias add, then tanh-approximated GELU.
+  static constexpr char kMod[] = R"(
+    #[version = "0.0.5"]
+    def @main(%x: Tensor[(1600, 768), float32]) {
+                                                                               // index, kind
+      %60 = nn.dense(%x, meta[relay.Constant][0] /*(3072, 768)*/, units=3072); // 6,  A
+      %61 = add(%60, meta[relay.Constant][1] /*(3072)*/);                      // 8,  B
+      %62 = reshape(%61, newshape=[50, 32, 3072]);                             // 9,  I
+      %63 = power(%62, 3f);                                                    // 15, B
+      %64 = multiply(%63, 0.044715f);                                          // 17, B
+      %65 = add(%62, %64);                                                     // 18, B
+      %66 = multiply(%65, 0.797885f);                                          // 20, B
+      %67 = tanh(%66);                                                         // 21, E
+      %68 = multiply(%62, 0.5f);                                               // 11, B
+      %69 = add(%67, 1f);                                                      // 23, B
+      multiply(%68, %69)                                                       // 24, B
+    }
+  )";
+  return MakeTestFunction(kMod, /*constant_shapes=*/{{3072, 768}, {3072}});
+}
+
+Function GPT2ExtractTuples() {  // GPT-2 excerpt: split, projections, and concatenate of a tuple.
+  static constexpr char kMod[] = R"(
+    #[version = "0.0.5"]
+    def @main(%x: Tensor[(50, 32, 2304), float32]) {
+                                                                           // index, kind
+      %19 = split(%x, indices_or_sections=[768, 1536], axis=2);            // 6,  I
+      %23 = %19.1;                                                         // 7
+      %24 = reshape(%23, newshape=[50, 32, 12, 64]);                       // 8,  I
+      %35 = %19.2;                                                         // 11
+      %36 = reshape(%35, newshape=[50, 32, 12, 64]);                       // 12, I
+      %37 = transpose(%36, axes=[0, 2, 1, 3]);                             // 13, I
+      %855 = transpose(%24, axes=[0, 2, 1, 3]);                            // 9,  I
+      %856 = expand_dims(%855, axis=0);                                    // 10, B
+      %857 = expand_dims(%37, axis=0);                                     // 14, B
+      %858 = (%856, %857);                                                 // 15, B
+      concatenate(%858)                                                    // 16, I
+    }
+  )";
+  return MakeTestFunction(kMod);
+}
+
+PartitionSpec StandardSpec(const std::string& spec_name = "test_spec",
+                           const std::string& target = "llvm") {
+  return PartitionSpec(String(spec_name), Target(String(target)), {});
 }
 
 std::vector<CandidatePartition> ActualCandidates(const DataflowGraph& graph, const Function& func,
@@ -79,12 +141,12 @@ std::vector<CandidatePartition> ActualCandidates(const DataflowGraph& graph, con
 }
 
 std::vector<CandidatePartition> ExpectedCandidates(
-    const DataflowGraph& graph, const runtime::String rule_name, const PartitionSpec& spec,
-    const std::vector<std::vector<PostDfsIndex>> index_sets) {
+    const DataflowGraph& graph, const PartitionSpec& spec,
+    const std::vector<std::vector<PostDfsIndex>>& index_sets) {
   std::vector<CandidatePartition> candidate_partitions;
   for (const auto& indexes : index_sets) {
     auto subgraph = SubGraph(graph, IndexSet(graph.size(), indexes));
-    auto candidate = CandidatePartition(rule_name, subgraph, spec);
+    auto candidate = CandidatePartition(/*rule_name=*/"", subgraph, spec);
     candidate_partitions.emplace_back(std::move(candidate));
   }
   return candidate_partitions;
@@ -98,66 +160,53 @@ void AssertEqual(const std::vector<CandidatePartition>& actual,
                                                                        expected.end());
   ASSERT_EQ(actual_set.size(), expected_set.size());
   for (const auto& actual_candidate : actual_set) {
-    ASSERT_EQ(expected_set.count(actual_candidate), 1);
+    ASSERT_EQ(expected_set.count(actual_candidate), 1) << actual_candidate->ToString();
   }
 }
 
+void AssertEqual(const Expr& actual, const Expr& expected) {  // Structural, not pointer, equality.
+  ASSERT_TRUE(tvm::StructuralEqual()(actual, expected)) << PrettyPrint(actual);
+}
+
 TEST(PartitionRule, DFPatternSingleOp) {
   auto func = StandardTestFunction();
   auto graph = DataflowGraph(func);
-  Target target("llvm");
-  auto spec = PartitionSpec("test_spec", target, {});
+  auto spec = StandardSpec();
 
   {
     auto pattern = IsOp("nn.relu")({IsWildcard()});
     auto rule = DFPatternPartitionRule("relu_pattern", pattern);
-    auto expected_candidates = ExpectedCandidates(graph, "relu_pattern", spec, {{4}, {5}});
 
-    auto candidates = ActualCandidates(graph, func, spec, rule);
+    auto actual_candidates = ActualCandidates(graph, func, spec, rule);
 
-    ICHECK_EQ(candidates.size(), 2);
-    for (size_t i = 0; i < candidates.size(); i++) {
-      ICHECK(CandidatePartitionEquals()(candidates[i], expected_candidates[i]));
-    }
+    auto expected_candidates = ExpectedCandidates(graph, spec, {{4}, {5}});  // Both nn.relu calls.
+    AssertEqual(actual_candidates, expected_candidates);
   }
 }
 
 TEST(PartitionRule, DFPatternOverlap) {
   auto func = StandardTestFunction();
   auto graph = DataflowGraph(func);
-  Target target("llvm");
-  auto spec = PartitionSpec("test_spec", target, {});
+  auto spec = StandardSpec();
 
   {
     auto pattern =
         IsOp("nn.relu")({IsOp("nn.relu")({IsWildcard()}) || IsOp("abs")({IsWildcard()})});
     auto rule = DFPatternPartitionRule("relu+abs_pattern", pattern);
 
-    auto candidates = ActualCandidates(graph, func, spec, rule);
+    auto actual_candidates = ActualCandidates(graph, func, spec, rule);  // Matches may overlap.
 
-    auto expected_candidates =
-        ExpectedCandidates(graph, "relu+abs_pattern", spec, {{3, 4}, {4, 5}});
-    AssertEqual(candidates, expected_candidates);
+    auto expected_candidates = ExpectedCandidates(graph, spec, {{3, 4}, {4, 5}});
+    AssertEqual(actual_candidates, expected_candidates);
   }
 }
 
 TEST(PartitionRule, Composite) {
   auto func = StandardTestFunction();
   auto graph = DataflowGraph(func);
-  Target target("llvm");
-  auto spec = PartitionSpec("test_spec", target, {});
-
-  {
-    auto pattern = IsOp("nn.relu")({IsWildcard()});
-    auto df_rule = DFPatternPartitionRule("relu_pattern", pattern);
-    auto composite_rule = CompositePartitionRule("composite", df_rule);
-
-    auto candidates = ActualCandidates(graph, func, spec, composite_rule);
-    auto rewrite_expr = CandidatePartition::ParallelRewrite(graph, candidates);
+  auto spec = StandardSpec();
 
-    ICHECK_EQ(candidates.size(), 2);
-
-    constexpr const char* kExpectedMod = R"(
+  constexpr const char* kExpectedMod = R"(
       #[version = "0.0.5"]
       def @main(%x: Tensor[(10, 10), float32]) {
         %0 = abs(%x);
@@ -171,27 +220,28 @@ TEST(PartitionRule, Composite) {
         %3(%2)
       }
     )";
-    Expr expected_expr = MakeTestFunction(kExpectedMod, /*constant_shapes=*/{});
-    ICHECK(StructuralEqual()(rewrite_expr, expected_expr));
+  Expr expected_expr = MakeTestFunction(kExpectedMod);
+
+  {
+    auto pattern = IsOp("nn.relu")({IsWildcard()});
+    auto df_rule = DFPatternPartitionRule("relu_pattern", pattern);
+    auto composite_rule = CompositePartitionRule("composite", df_rule);
+
+    auto actual_candidates = ActualCandidates(graph, func, spec, composite_rule);
+    auto actual_expr = CandidatePartition::ParallelRewrite(graph, actual_candidates);
+
+    auto expected_candidates = ExpectedCandidates(graph, spec, {{4}, {5}});
+    AssertEqual(actual_candidates, expected_candidates);
+    AssertEqual(actual_expr, expected_expr);
   }
 }
 
 TEST(PartitionRule, PrimitiveTVM) {
   auto func = StandardTestFunction();
   auto graph = DataflowGraph(func);
-  Target target("llvm");
-  auto spec = PartitionSpec("test_spec", target, {});
-
-  {
-    auto pattern = IsOp("nn.relu")({IsWildcard()});
-    auto df_rule = DFPatternPartitionRule("relu_pattern", pattern);
-    auto primitive_rule = PrimitivePartitionRule("primitive", df_rule);
-
-    auto candidates = ActualCandidates(graph, func, spec, primitive_rule);
-    auto rewrite_expr = CandidatePartition::ParallelRewrite(graph, candidates);
+  auto spec = StandardSpec();
 
-    ICHECK_EQ(candidates.size(), 2);
-    constexpr const char* kExpectedMod = R"(
+  constexpr const char* kExpectedMod = R"(
       #[version = "0.0.5"]
       def @main(%x: Tensor[(10, 10), float32]) {
         %0 = abs(%x);
@@ -205,8 +255,19 @@ TEST(PartitionRule, PrimitiveTVM) {
         %3(%2)
       }
     )";
-    Expr expected_expr = MakeTestFunction(kExpectedMod, /*constant_shapes=*/{});
-    ICHECK(StructuralEqual()(rewrite_expr, expected_expr));
+  Expr expected_expr = MakeTestFunction(kExpectedMod);
+
+  {
+    auto pattern = IsOp("nn.relu")({IsWildcard()});
+    auto df_rule = DFPatternPartitionRule("relu_pattern", pattern);
+    auto primitive_rule = PrimitivePartitionRule("primitive", df_rule);
+
+    auto actual_candidates = ActualCandidates(graph, func, spec, primitive_rule);
+    auto actual_expr = CandidatePartition::ParallelRewrite(graph, actual_candidates);
+
+    auto expected_candidates = ExpectedCandidates(graph, spec, {{4}, {5}});
+    AssertEqual(actual_candidates, expected_candidates);
+    AssertEqual(actual_expr, expected_expr);
   }
 }
 
@@ -216,19 +277,9 @@ TVM_REGISTER_TARGET_KIND("test_ext_codegen", kDLCUDA)
 TEST(PartitionRule, PrimitiveExternal) {
   auto func = StandardTestFunction();
   auto graph = DataflowGraph(func);
-  Target target("test_ext_codegen");
-  auto spec = PartitionSpec("test_ext_codegen", target, {});
-
-  {
-    auto pattern = IsOp("nn.relu")({IsWildcard()});
-    auto df_rule = DFPatternPartitionRule("relu_pattern", pattern);
-    auto primitive_rule = PrimitivePartitionRule("primitive", df_rule);
+  auto spec = StandardSpec("test_ext_codegen", "test_ext_codegen");
 
-    auto candidates = ActualCandidates(graph, func, spec, primitive_rule);
-    auto rewrite_expr = CandidatePartition::ParallelRewrite(graph, candidates);
-
-    ICHECK_EQ(candidates.size(), 2);
-    constexpr const char* kExpectedMod = R"(
+  constexpr const char* kExpectedMod = R"(
       #[version = "0.0.5"]
       def @main(%x: Tensor[(10, 10), float32]) {
         %0 = abs(%x);
@@ -242,16 +293,26 @@ TEST(PartitionRule, PrimitiveExternal) {
         %3(%2)
       }
     )";
-    Expr expected_expr = MakeTestFunction(kExpectedMod, /*constant_shapes=*/{});
-    ICHECK(StructuralEqual()(rewrite_expr, expected_expr));
+  Expr expected_expr = MakeTestFunction(kExpectedMod);
+
+  {
+    auto pattern = IsOp("nn.relu")({IsWildcard()});
+    auto df_rule = DFPatternPartitionRule("relu_pattern", pattern);
+    auto primitive_rule = PrimitivePartitionRule("primitive", df_rule);
+
+    auto actual_candidates = ActualCandidates(graph, func, spec, primitive_rule);
+    auto actual_expr = CandidatePartition::ParallelRewrite(graph, actual_candidates);
+
+    auto expected_candidates = ExpectedCandidates(graph, spec, {{4}, {5}});
+    AssertEqual(actual_candidates, expected_candidates);
+    AssertEqual(actual_expr, expected_expr);
   }
 }
 
 TEST(PartitionRule, Union) {
   auto func = StandardTestFunction();
   auto graph = DataflowGraph(func);
-  Target target("llvm");
-  auto spec = PartitionSpec("test_spec", target, {});
+  auto spec = StandardSpec();
 
   {
     auto abs_pattern = IsOp("abs")({IsWildcard()});
@@ -260,40 +321,391 @@ TEST(PartitionRule, Union) {
     auto relu_rule = DFPatternPartitionRule("relu_pattern", relu_pattern);
     auto union_rule = UnionPartitionRule("union", {abs_rule, relu_rule});
 
-    auto abs_candidates = ExpectedCandidates(graph, "abs_pattern", spec, {{3}});
-    auto relu_candidates = ExpectedCandidates(graph, "relu_pattern", spec, {{4}, {5}});
-
-    auto candidates = ActualCandidates(graph, func, spec, union_rule);
+    auto actual_candidates = ActualCandidates(graph, func, spec, union_rule);
 
-    std::vector<CandidatePartition> expected_candidates;
-    expected_candidates.insert(expected_candidates.end(), abs_candidates.begin(),
-                               abs_candidates.end());
-    expected_candidates.insert(expected_candidates.end(), relu_candidates.begin(),
-                               relu_candidates.end());
-    AssertEqual(candidates, expected_candidates);
+    auto expected_candidates = ExpectedCandidates(graph, spec, {{3}, {4}, {5}});
+    AssertEqual(actual_candidates, expected_candidates);
   }
 }
 
 TEST(PartitionRule, OpCallByKind) {
-  constexpr const char* kMod = R"(
-    #[version = "0.0.5"]
-    def @main(%x: Tensor[(10, 10), float32]) {
-      %0 = abs(%x);                      //  4
-      %1 = add(%0, %x);                  //  5
-      shape_of(%1)                       //  6
-    }
-  )";
-  auto func = MakeTestFunction(kMod, {});
+  auto func = VariantTestFunction();
   auto graph = DataflowGraph(func);
-  Target target("llvm");
-  auto spec = PartitionSpec("test_spec", target, {});
+  auto spec = StandardSpec();
 
   {
     auto rule = OpCallByKindPartitionRule("op_call_by_kind");
-    auto candidates = ActualCandidates(graph, func, spec, rule);
+    auto actual_candidates = ActualCandidates(graph, func, spec, rule);
+    // Only abs (4) and add (5); the opaque shape_of (6) yields no candidate.
+    auto expected_candidates = ExpectedCandidates(graph, spec, {{4}, {5}});
+    AssertEqual(actual_candidates, expected_candidates);
+  }
+}
+
+TEST(PartitionRule, Combine_ByKind) {
+  auto func = GPT2ExtractOps();
+  auto graph = DataflowGraph(func);
+  auto spec = StandardSpec();
+
+  {
+    // Prime the system by picking out all 11 calls to non-opaque ops.
+    auto sub_rule = OpCallByKindPartitionRule("op_call_by_kind");
+    // Combine all <= kOutEWiseFusable (A) candidates (i.e., everything seeded above) with
+    // downstream <= kBroadcast (B) candidates (i.e., B or E).
+    Array<SimpleCombinerRule> simple_rules;
+    simple_rules.push_back(ByKindSimpleCombinerRule(/*upstream_kind=*/kOutEWiseFusable,
+                                                    /*downstream_kind=*/kBroadcast));
+    Array<CombinerRule> combiner_rules;
+    combiner_rules.push_back(AllSimpleCombinerRule("all_simple", std::move(simple_rules)));
+    // Build the overall partition rule.
+    auto rule = CombinePartitionRule("combine_by_kind_A_B", std::move(sub_rule),
+                                     std::move(combiner_rules), /*max_depth=*/3);
+
+    auto actual_candidates = ActualCandidates(graph, func, spec, rule);
+
+    // The original calls.
+    std::vector<std::vector<PostDfsIndex>> expected;
+    expected.push_back({6});
+    expected.push_back({8});
+    expected.push_back({9});
+    expected.push_back({11});
+    expected.push_back({15});
+    expected.push_back({17});
+    expected.push_back({18});
+    expected.push_back({20});
+    expected.push_back({21});
+    expected.push_back({23});
+    expected.push_back({24});
+
+    // nn.dense (A) and the following add (B)
+    expected.push_back({6, 8});
+
+    // reshape (I) and the following power or multiply or both
+    expected.push_back({9, 11});
+    expected.push_back({9, 15});
+    expected.push_back({9, 11, 15});
+
+    // reshape (I) and the following power and multiply
+    expected.push_back({9, 15, 17});
+
+    // reshape (I) and everything after it to the max depth of 3
+    expected.push_back({9, 11, 15, 17});
+
+    // pairs of broadcasts
+    expected.push_back({11, 24});  // multiply / multiply
+    expected.push_back({15, 17});  // power / multiply
+    expected.push_back({17, 18});  // multiply / add
+    expected.push_back({18, 20});  // add / multiply
+    expected.push_back({20, 21});  // multiply / tanh
+    expected.push_back({21, 23});  // tanh / add
+    expected.push_back({23, 24});  // add / multiply
+
+    // triples of broadcasts
+    expected.push_back({15, 17, 18});  // power / multiply / add
+    expected.push_back({17, 18, 20});  // multiply / add / multiply
+    expected.push_back({18, 20, 21});  // add / multiply / tanh
+    expected.push_back({20, 21, 23});  // multiply / tanh / add
+    expected.push_back({21, 23, 24});  // tanh / add / multiply
+
+    auto expected_candidates = ExpectedCandidates(graph, spec, expected);
+    AssertEqual(actual_candidates, expected_candidates);
+  }
+}
+
+TEST(PartitionRule, Combine_TupleArg) {
+  auto func = GPT2ExtractTuples();
+  auto graph = DataflowGraph(func);
+  auto spec = StandardSpec();
+
+  {
+    // Prime the system by picking out all 8 calls to non-opaque ops.
+    auto sub_rule = OpCallByKindPartitionRule("op_call_by_kind");
+    // Merge args of tuples of <= injective (I) fields into the call's group.
+    Array<CombinerRule> combiner_rules;
+    combiner_rules.push_back(TupleArgCombinerRule("tuple_arg"));
+    // Build the overall partition rule.
+    auto rule = CombinePartitionRule("combine_tuple_arg", std::move(sub_rule),
+                                     std::move(combiner_rules), /*max_depth=*/3);
+
+    auto actual_candidates = ActualCandidates(graph, func, spec, rule);
+
+    // The original calls
+    std::vector<std::vector<PostDfsIndex>> expected;
+    expected.push_back({6});
+    expected.push_back({8});
+    expected.push_back({9});
+    expected.push_back({10});
+    expected.push_back({12});
+    expected.push_back({13});
+    expected.push_back({14});
+    expected.push_back({16});
+
+    // concatenate((expand_dims(...), expand_dims(...))) is grouped with its tuple and fields.
+    expected.push_back({10, 14, 15, 16});
+
+    auto expected_candidates = ExpectedCandidates(graph, spec, expected);
+    AssertEqual(actual_candidates, expected_candidates);
+  }
+}
+
+TEST(PartitionRule, Combine_TupleProj) {
+  auto func = GPT2ExtractTuples();
+  auto graph = DataflowGraph(func);
+  auto spec = StandardSpec();
+
+  // Seed with one candidate per call to a non-opaque op (8 in total).
+  auto base_rule = OpCallByKindPartitionRule("op_call_by_kind");
+  // Fold tuple projections into the injective group that produced the tuple.
+  Array<CombinerRule> combiners;
+  combiners.push_back(TupleProjCombinerRule("tuple_proj"));
+  // Combine the seed candidates, up to a depth of 3.
+  auto combine_rule = CombinePartitionRule("combine_tuple_proj", std::move(base_rule),
+                                           std::move(combiners), /*max_depth=*/3);
+
+  // Collect every candidate the rule yields for this function.
+  auto actual_candidates = ActualCandidates(graph, func, spec, combine_rule);
+
+  const std::vector<std::vector<PostDfsIndex>> expected_index_sets = {
+      // The original calls: split, reshapes, transposes, expand_dims and concatenate.
+      {6},
+      {8},
+      {9},
+      {10},
+      {12},
+      {13},
+      {14},
+      {16},
+
+      // split together with projection 1.
+      {6, 7},
+      // split together with projection 2.
+      {6, 11},
+      // split together with both projections.
+      {6, 7, 11},
+  };
+
+  auto expected_candidates = ExpectedCandidates(graph, spec, expected_index_sets);
+  AssertEqual(actual_candidates, expected_candidates);
+}
+
+TEST(PartitionRule, Combine_Constant) {
+  auto func = GPT2ExtractOps();
+  auto graph = DataflowGraph(func);
+  auto spec = StandardSpec();
+
+  // Seed with one candidate per call to a non-opaque op (11 in total).
+  auto base_rule = OpCallByKindPartitionRule("op_call_by_kind");
+  // Merge constant arguments into the groups which consume them.
+  Array<CombinerRule> combiners;
+  combiners.push_back(ConstantCombinerRule("constant"));
+  // Combine the seed candidates, up to a depth of 3.
+  auto combine_rule = CombinePartitionRule("combine_constant", std::move(base_rule),
+                                           std::move(combiners), /*max_depth=*/3);
+
+  // Collect every candidate the rule yields for this function.
+  auto actual_candidates = ActualCandidates(graph, func, spec, combine_rule);
+
+  const std::vector<std::vector<PostDfsIndex>> expected_index_sets = {
+      // The original calls.
+      {6},
+      {8},
+      {9},
+      {11},
+      {15},
+      {17},
+      {18},
+      {20},
+      {21},
+      {23},
+      {24},
+
+      // nn.dense (6) with its constant weight (5).
+      {5, 6},
+
+      // The bias add (8) with its constant (7).
+      {7, 8},
+  };
+
+  auto expected_candidates = ExpectedCandidates(graph, spec, expected_index_sets);
+  AssertEqual(actual_candidates, expected_candidates);
+}
+
+TEST(PartitionRule, Combine_Mixed) {
+  auto func = GPT2ExtractOps();
+  auto graph = DataflowGraph(func);
+  auto spec = StandardSpec();
+
+  {
+    // Prime the system by picking out all 11 calls to non-opaque ops.
+    auto sub_rule = OpCallByKindPartitionRule("op_call_by_kind");
+
+    // Mimic the FuseOps rules.
+    Array<SimpleCombinerRule> simple_rules;
+    simple_rules.push_back(ByKindSimpleCombinerRule(kOutEWiseFusable, kBroadcast));
+    simple_rules.push_back(ByKindSimpleCombinerRule(kBroadcast, kCommReduce));
+    simple_rules.push_back(ByKindSimpleCombinerRule(kInjective, kInjective));
+    Array<CombinerRule> combiner_rules;
+    combiner_rules.push_back(AllSimpleCombinerRule("all_simple", std::move(simple_rules)));
+
+    // Merge constant args into the groups which consume them.
+    combiner_rules.push_back(ConstantCombinerRule("constant"));
+
+    // Build the overall partition rule.
+    auto rule = CombinePartitionRule("combine_mixed", std::move(sub_rule),
+                                     std::move(combiner_rules), /*max_depth=*/3);
+
+    auto actual_candidates = ActualCandidates(graph, func, spec, rule);
+
+    // The original calls
+    std::vector<std::vector<PostDfsIndex>> expected;
+    expected.push_back({6});
+    expected.push_back({8});
+    expected.push_back({9});
+    expected.push_back({11});
+    expected.push_back({15});
+    expected.push_back({17});
+    expected.push_back({18});
+    expected.push_back({20});
+    expected.push_back({21});
+    expected.push_back({23});
+    expected.push_back({24});
+
+    // A -> B merging
+    expected.push_back({6, 8});
+    expected.push_back({9, 11});
+    expected.push_back({9, 15});
+    expected.push_back({9, 11, 15});
+    expected.push_back({9, 15, 17});
+    expected.push_back({9, 11, 15, 17});
+    expected.push_back({11, 24});
+    expected.push_back({15, 17});
+    expected.push_back({17, 18});
+    expected.push_back({18, 20});
+    expected.push_back({20, 21});
+    expected.push_back({21, 23});
+    expected.push_back({23, 24});
+    expected.push_back({15, 17, 18});
+    expected.push_back({17, 18, 20});
+    expected.push_back({18, 20, 21});
+    expected.push_back({20, 21, 23});
+    expected.push_back({21, 23, 24});
+
+    // Constant args
+    expected.push_back({5, 6});
+    expected.push_back({7, 8});
+
+    // B -> R
+    expected.push_back({8, 9});
+    expected.push_back({8, 9, 11});
+    expected.push_back({8, 9, 15});
+
+    // Constants and A -> B
+    expected.push_back({5, 6, 8});
+    expected.push_back({5, 6, 7, 8});
+
+    // Constants and B -> R
+    expected.push_back({7, 8, 9});
+    expected.push_back({7, 8, 9, 11});
+    expected.push_back({7, 8, 9, 15});
+
+    auto expected_candidates = ExpectedCandidates(graph, spec, expected);
+    AssertEqual(actual_candidates, expected_candidates);
+  }
+}
+
+TEST(PartitionRule, OnlyValid) {
+  // Checks that OnlyValidPartitionRule keeps only those candidates from an inner combining rule
+  // which satisfy its SubGraphConfig (here: no taps, max depth 2, at most 1 exit).
+  auto func = GPT2ExtractOps();
+  auto graph = DataflowGraph(func);
+  auto spec = StandardSpec();
 
-    auto expected_candidates = ExpectedCandidates(graph, "op_call_by_kind", spec, {{4}, {5}});
-    AssertEqual(candidates, expected_candidates);
+  {
+    // Prime the system by picking out all 11 calls to non-opaque ops.
+    auto sub_rule = OpCallByKindPartitionRule("op_call_by_kind");
+    // Combine all <= kOutEWiseFusable (A) candidates (ie anything) with downstream
+    // <= kBroadcast (B) candidates (ie B or E).
+    Array<SimpleCombinerRule> simple_rules;
+    simple_rules.push_back(ByKindSimpleCombinerRule(/*upstream_kind=*/kOutEWiseFusable,
+                                                    /*downstream_kind=*/kBroadcast));
+    Array<CombinerRule> combiner_rules;
+    combiner_rules.push_back(AllSimpleCombinerRule("all_simple", std::move(simple_rules)));
+    // The combiner is allowed to go deeper (max_depth=3) than the validity config below
+    // (max_depth=2), so the OnlyValid wrapper is what limits the final candidate set.
+    auto combine_rule = CombinePartitionRule("combine_by_kind_A_B", std::move(sub_rule),
+                                             std::move(combiner_rules), /*max_depth=*/3);
+    // Only allow up to depth 2, no taps and 1 exit.
+    SubGraphConfig config;
+    config.allow_taps = false;
+    config.max_depth = 2;
+    config.max_exits = 1;
+
+    // Build the overall partition rule.
+    auto rule = OnlyValidPartitionRule("only_valid", std::move(combine_rule), config);
+
+    auto actual_candidates = ActualCandidates(graph, func, spec, rule);
+
+    // The original calls.
+    std::vector<std::vector<PostDfsIndex>> expected;
+    expected.push_back({6});
+    expected.push_back({8});
+    expected.push_back({9});
+    expected.push_back({11});
+    expected.push_back({15});
+    expected.push_back({17});
+    expected.push_back({18});
+    expected.push_back({20});
+    expected.push_back({21});
+    expected.push_back({23});
+    expected.push_back({24});
+
+    // nn.dense (A) and the following add (B)
+    expected.push_back({6, 8});
+
+    // pairs of broadcasts
+    expected.push_back({11, 24});  // multiply / multiply
+    expected.push_back({15, 17});  // power / multiply
+    expected.push_back({17, 18});  // multiply / add
+    expected.push_back({18, 20});  // add / multiply
+    expected.push_back({20, 21});  // multiply / tanh
+    expected.push_back({21, 23});  // tanh / add
+    expected.push_back({23, 24});  // add / multiply
+
+    // The following candidates are filtered out because they have 2 or 3 exits:
+    // {9, 11}, {9, 15}, {9,11,15}, {9,15,17}, {15,17,18}, {17,18,20},
+    // {18,20,21}, {20,21,23}, {21,23,24}, {9,11,15,17}
+
+    auto expected_candidates = ExpectedCandidates(graph, spec, expected);
+    AssertEqual(actual_candidates, expected_candidates);
+  }
+}
+
+TEST(PartitionRule, Host) {
+  // Checks that HostPartitionRule, applied to the function returned by GPT2ExtractTuples(),
+  // yields exactly the single-node candidates listed below. These are the function argument,
+  // each operator call, the tuple projections, the tuple construction and the overall function.
+  auto func = GPT2ExtractTuples();
+  auto graph = DataflowGraph(func);
+  auto spec = StandardSpec();
+
+  {
+    auto rule = HostPartitionRule("host");
+
+    auto actual_candidates = ActualCandidates(graph, func, spec, rule);
+
+    // Every expected candidate below is a singleton (one post-dfs index each).
+    std::vector<std::vector<PostDfsIndex>> expected;
+
+    // Function arg %x
+    expected.push_back({0});
+    // Operators
+    expected.push_back({1});  // concatenate
+    expected.push_back({2});  // expand_dims
+    expected.push_back({3});  // transpose
+    expected.push_back({4});  // reshape
+    expected.push_back({5});  // split
+    // Tuple projection
+    expected.push_back({7});
+    expected.push_back({11});
+    // Tuple construction
+    expected.push_back({15});
+    // The overall @main function
+    expected.push_back({17});
+
+    auto expected_candidates = ExpectedCandidates(graph, spec, expected);
+    AssertEqual(actual_candidates, expected_candidates);
   }
 }