Posted to commits@tvm.apache.org by mb...@apache.org on 2022/07/14 21:30:49 UTC

[tvm] branch main updated: [Collage] CollagePartition pass (#12086)

This is an automated email from the ASF dual-hosted git repository.

mbaret pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 7661ba8910 [Collage] CollagePartition pass (#12086)
7661ba8910 is described below

commit 7661ba89106fbc67ea3d2e7efe3e846c6d149436
Author: Mark Shields <87...@users.noreply.github.com>
AuthorDate: Thu Jul 14 14:30:44 2022 -0700

    [Collage] CollagePartition pass (#12086)
    
    * [Collage] CollagePartition pass
    
    See https://github.com/apache/tvm-rfcs/blob/main/rfcs/0062-collage.md.
    
    This adds the main CollagePartition pass, which:
     1. Inspects all the targets in the CompilationConfig and builds
        PartitionSpecs describing how to generate speculative CandidatePartitions
        for them.
     2. Runs the above rules on the model to collect all the candidates.
     3. Eliminates candidates whose target contradicts any constraints already
        imposed by, eg, device planning.
     4. Eagerly estimates the cost of each candidate.
     5. Performs a shortest path search to choose an 'optimal' set of candidate
        partitions so as to minimize estimated model latency, such that every sub-expression
        node is contained in exactly one candidate partition.
     6. Coalesces adjacent optimal candidates which ended up on the same target.
     7. Rewrites the model according to the chosen optimal partitioning.
    
    As with the existing partition_for_<external codegen name> methods, the result of
    CollagePartition can then be built using regular TVM.
    
    Very special thanks to @mbaret for authoring test_pass_collage_partition.py.
    
    Logic to prune the candidates after step 3 will be in a follow-up PR since it
    deserves its own testing. A demonstration driver will also come as a follow-up.
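
    As a rough sketch only (not part of this commit), the pass might be driven from
    Python along the following lines. The target list, the max-depth values and the
    use of make_compilation_config are illustrative assumptions, not part of this
    change:

        # mod is an existing relay.IRModule (not shown); the targets are illustrative.
        targets = [tvm.target.Target("llvm"), tvm.target.Target("cutlass")]
        pass_ctxt = tvm.transform.PassContext(
            config={
                "relay.collage.tvm_max_depth": 8,
                "relay.collage.byoc_max_depth": 8,
            }
        )
        with pass_ctxt:
            # Gather the available targets into a CompilationConfig (assumed helper).
            config = tvm.target.make_compilation_config(pass_ctxt, targets)
            # Partition the functions in mod according to the chosen targets.
            mod = tvm.relay.transform.CollagePartition(config)(mod)
            # The partitioned module can then be compiled with the regular TVM build flow.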
    
    * - lints
    
    * - more lints
    
    * - use the _ffi_api properly
---
 python/tvm/relay/__init__.py                      |   1 +
 python/tvm/relay/collage/__init__.py              |  24 +
 python/tvm/relay/collage/_ffi_api.py              |  21 +
 python/tvm/relay/collage/collage.py               | 146 +++++
 python/tvm/relay/transform/transform.py           |  23 +
 src/relay/collage/candidate_partition.cc          |   3 +-
 src/relay/collage/candidate_partition_index.cc    | 150 ++++++
 src/relay/collage/candidate_partition_index.h     | 102 ++++
 src/relay/collage/collage_partitioner.cc          | 352 ++++++++++++
 src/relay/collage/collage_partitioner.h           |  50 ++
 src/relay/collage/cost_estimator.cc               |   8 +-
 src/relay/collage/cost_estimator.h                |   4 +-
 src/relay/collage/gather_partition_specs.cc       | 214 ++++++++
 src/relay/collage/gather_partition_specs.h        |  71 +++
 src/relay/collage/priority_queue.h                |  72 +++
 src/relay/collage/utils.cc                        |   2 +-
 src/runtime/vm/vm.cc                              |   8 +-
 tests/python/relay/test_pass_collage_partition.py | 617 ++++++++++++++++++++++
 18 files changed, 1854 insertions(+), 14 deletions(-)

diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py
index 89c8fcb17d..97842738e5 100644
--- a/python/tvm/relay/__init__.py
+++ b/python/tvm/relay/__init__.py
@@ -32,6 +32,7 @@ from . import scope_builder
 
 from . import transform
 from . import analysis
+from . import collage
 from .build_module import build, create_executor, optimize
 from .transform import build_config
 from . import debug
diff --git a/python/tvm/relay/collage/__init__.py b/python/tvm/relay/collage/__init__.py
new file mode 100644
index 0000000000..18461f25df
--- /dev/null
+++ b/python/tvm/relay/collage/__init__.py
@@ -0,0 +1,24 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""relay.collage exports"""
+from .collage import (
+    MEASURE_NUMBER,
+    MEASURE_REPEAT,
+    WARMUP_MIN_REPEAT_MS,
+    CostEstimator,
+    MockEstimator,
+)
diff --git a/python/tvm/relay/collage/_ffi_api.py b/python/tvm/relay/collage/_ffi_api.py
new file mode 100644
index 0000000000..bb5be46c7a
--- /dev/null
+++ b/python/tvm/relay/collage/_ffi_api.py
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""FFI APIs for the Collage partitioner."""
+import tvm._ffi
+
+
+tvm._ffi._init_api("relay.collage", __name__)
diff --git a/python/tvm/relay/collage/collage.py b/python/tvm/relay/collage/collage.py
new file mode 100644
index 0000000000..8d1caa9c85
--- /dev/null
+++ b/python/tvm/relay/collage/collage.py
@@ -0,0 +1,146 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Mostly helper methods which interface the main C++ Collage implementation with Python.
+   See relay.transform.CollagePartition for the main Collage entrypoint."""
+
+import logging
+import os
+import math
+import tempfile
+
+import numpy as np
+
+import tvm
+from tvm._ffi.registry import register_func, register_object
+from tvm.runtime import Object
+from . import _ffi_api
+
+# Parameters to use when estimating latency (of both partitions and overall models).
+MEASURE_NUMBER = 20
+MEASURE_REPEAT = 5
+WARMUP_MIN_REPEAT_MS = 250
+
+
+@register_object("relay.collage.CostEstimator")
+class CostEstimator(Object):
+    """CostEstimator class"""
+
+    def __init__(self):
+        self.__init_handle_by_constructor__(_ffi_api.CostEstimator)
+
+
+@register_object("relay.collage.MockEstimator")
+class MockEstimator(Object):
+    """MockEstimator class"""
+
+    def __init__(self, target_costs):
+        self.__init_handle_by_constructor__(_ffi_api.MockEstimator, target_costs)
+
+
+def arg_for(arg_type, device):
+    """Returns a test argument of Relay arg_type on device"""
+    assert isinstance(arg_type, tvm.ir.TensorType)
+    return tvm.nd.array(
+        np.random.uniform(-1.0, 1.0, size=arg_type.concrete_shape).astype(arg_type.dtype),
+        device=device,
+    )
+
+
+def vm_estimate_seconds(device, the_vm, func_name, args):
+    """Returns the estimated latency, in seconds, of running func_name with args on the_vm."""
+    # Warmup
+    the_vm.benchmark(
+        device, repeat=1, number=1, min_repeat_ms=WARMUP_MIN_REPEAT_MS, func_name=func_name, **args
+    )
+    # One more time, with feeling
+    return the_vm.benchmark(
+        device,
+        repeat=MEASURE_REPEAT,
+        number=MEASURE_NUMBER,
+        min_repeat_ms=0,
+        func_name=func_name,
+        **args,
+    )
+
+
+@register_func("tvm.relay.collage.estimate_seconds")
+def estimate_seconds(mod, target):
+    """Returns the mean execution time of "main" in mod on target with params. The module
+    may contain "Primitive" functions, possibly with "Compiler" attributes."""
+    device = tvm.device(target.kind.device_type)
+
+    try:
+        # Build the module.
+        logging.info("Compiling module to estimate")
+        exe = tvm.relay.vm.compile(mod, target)
+    except RuntimeError as err:
+        # A build failure indicates the partition is not supported.
+        # e.g. trying to build an nn.batch_norm on GPU, which has no schedule since we assume it
+        # is only ever used with a tuple projection which is rewritten away.
+        logging.info("Assigning module infinite cost since unable to build: %s", err)
+        return math.inf
+
+    # Finalize compilation
+    tmp_dir = tempfile.mkdtemp()
+    code, lib = exe.save()
+    lib_path = os.path.join(tmp_dir, "library.so")
+    # TODO(mbs): Avoid nvcc dependency?
+    lib.export_library(lib_path, workspace_dir=tmp_dir, cc="nvcc")
+    lib = tvm.runtime.load_module(lib_path)
+    exe = tvm.runtime.vm.Executable.load_exec(code, lib)
+
+    # Benchmark the module.
+    the_vm = tvm.runtime.vm.VirtualMachine(exe, device)
+    func_name = "main"
+    main_args = {v.name_hint: arg_for(v.checked_type, device) for v in mod[func_name].params}
+    logging.info("Benchmarking module to estimate")
+    profile = vm_estimate_seconds(device, the_vm, func_name, main_args)
+    logging.info("profile: %s", profile)
+    return profile.median  # seconds
+
+
+def make_labelled_dfpattern_partition_rule_wrapper(compiler, pattern_tuple):
+    """Returns a DFPatternPartitionRule representing one (label, pattern, predicate) entry from
+    the pattern table for external codegen compiler"""
+    if len(pattern_tuple) == 2:
+        rule_name, dataflow_pattern = pattern_tuple
+        return _ffi_api.MakeLabelledDFPatternPartitionRule(compiler, rule_name, dataflow_pattern)
+    else:
+        rule_name, dataflow_pattern, predicate = pattern_tuple
+        return _ffi_api.MakeLabelledDFPatternPartitionRuleWithPredicate(
+            compiler, rule_name, dataflow_pattern, predicate
+        )
+
+
+@register_func("tvm.relay.collage.make_byoc_partition_rule")
+def make_byoc_partition_rule(compiler):
+    """Returns the PartitionRule for external codegen compiler"""
+    pattern_table = tvm.relay.op.contrib.get_pattern_table(compiler)
+    assert (
+        pattern_table is not None
+    ), f"No pattern table entry was found for BYOC compiler {compiler}"
+    logging.info(
+        "Converting %s rules for %s for use in pattern style BYOC lowering/codegen",
+        len(pattern_table),
+        compiler,
+    )
+    sub_rules = [
+        make_labelled_dfpattern_partition_rule_wrapper(compiler, pattern_tuple)
+        for pattern_tuple in pattern_table
+    ]
+    return _ffi_api.MakePatternBYOCPartitionRule(compiler, sub_rules)
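
For orientation only, a hypothetical sketch of using the MockEstimator exported above,
which costs candidates deterministically by charging a fixed cost per target kind (the
target kind names, the costs and the 'config' object below are illustrative, not part
of this commit):

    from tvm.relay.collage import MockEstimator
    from tvm.relay.transform import CollagePartition

    # Map each target kind name to a fixed integer cost.
    mock_estimator = MockEstimator({"llvm": 2, "cutlass": 1})

    # 'config' is a CompilationConfig describing the available targets (not shown).
    partition_pass = CollagePartition(config, cost_estimator=mock_estimator)
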
diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py
index d7979a7571..196f7ef812 100644
--- a/python/tvm/relay/transform/transform.py
+++ b/python/tvm/relay/transform/transform.py
@@ -1461,3 +1461,26 @@ def InlineCompilerFunctionsBoundTo(global_vars):
         The pass.
     """
     return _ffi_api.InlineCompilerFunctionsBoundTo(global_vars)
+
+
+def CollagePartition(config, cost_estimator=None):
+    """Partition the bodies of all functions according to the available targets so as to
+    minimize model latency. See https://github.com/apache/tvm-rfcs/blob/main/rfcs/0062-collage.md.
+
+    Parameters
+    ----------
+    config : CompilationConfig
+        The available targets.
+    cost_estimator : CostEstimator, optional
+        The custom cost estimator to use for costing each candidate partition.
+
+    Returns
+    -------
+    ret : tvm.transform.Pass
+        The pass.
+
+    """
+    if cost_estimator is None:
+        cost_estimator = relay.collage.CostEstimator()
+
+    return _ffi_api.CollagePartition(config, cost_estimator)
diff --git a/src/relay/collage/candidate_partition.cc b/src/relay/collage/candidate_partition.cc
index 20e29a6d40..2050fbddb1 100644
--- a/src/relay/collage/candidate_partition.cc
+++ b/src/relay/collage/candidate_partition.cc
@@ -191,8 +191,7 @@ Cost CandidatePartitionNode::EstimatedCost(
         VLOG(1) << "Estimating cost of:" << std::endl
                 << PrettyPrint(mod) << std::endl
                 << "using target " << target()->ToDebugString();
-        entry.cost = cost_estimator->Estimate(mod, target(),
-                                              /*needs_tvm_tuning=*/!target().IsExternalCodegen());
+        entry.cost = cost_estimator->Estimate(mod, target());
         VLOG(1) << "Measured cost as " << entry.cost.ToString();
       } else {
         VLOG(1) << "Reusing cost " << entry.cost.ToString()
diff --git a/src/relay/collage/candidate_partition_index.cc b/src/relay/collage/candidate_partition_index.cc
new file mode 100644
index 0000000000..4e90e8829a
--- /dev/null
+++ b/src/relay/collage/candidate_partition_index.cc
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file relay/collage/candidate_partition_index.h
+ * \brief Index for finding relevant candidate partitions for a particular search state.
+ */
+
+#include "./candidate_partition_index.h"
+
+#include "./gather_partition_specs.h"
+#include "./utils.h"
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+CandidatePartitionIndex::CandidatePartitionIndex(
+    const std::unordered_map<const ExprNode*, VirtualDevice>* virtual_devices,
+    DataflowGraph* dataflow_graph)
+    : virtual_devices_(virtual_devices),
+      dataflow_graph_(dataflow_graph),
+      first_inside_index_to_candidates_(dataflow_graph->size()) {}
+
+void CandidatePartitionIndex::Index(const Array<PartitionSpec>& partition_specs) {
+  std::vector<CandidatePartition> candidates = Collect(partition_specs);
+
+  // (The candidates could be pruned at this point to eliminate those which are heuristically
+  //  unlikely to appear in the optimal partitioning.)
+
+  // Index the candidates by their first inside index.
+  for (auto& candidate : candidates) {
+    first_inside_index_to_candidates_[candidate->sub_graph_->first_inside_index_].emplace_back(
+        candidate);
+  }
+  size_ = candidates.size();
+}
+
+void CandidatePartitionIndex::EstimateAllCosts(
+    const CostEstimator cost_estimator, const std::shared_ptr<CandidateFunctionCache>& cache) {
+  size_t n = 0;
+  for (PostDfsIndex index = 0; index < dataflow_graph_->size(); ++index) {
+    for (const auto& candidate : first_inside_index_to_candidates_[index]) {
+      LOG(INFO) << "Estimating cost of candidate " << candidate->ToSummary(*dataflow_graph_) << " ["
+                << n++ << "/" << size_ << "]";
+      // Cost will be cached in candidate as a side effect.
+      Cost cost = candidate->EstimatedCost(*dataflow_graph_, cost_estimator, cache);
+      LOG(INFO) << "Candidate has cost " << cost.ToString();
+    }
+  }
+}
+
+std::string CandidatePartitionIndex::ToSummary() const {
+  std::vector<std::string> lines;
+  for (const auto& candidates : first_inside_index_to_candidates_) {
+    for (const auto& candidate : candidates) {
+      if (candidate->partition_spec_name() == kHostSpecName) {
+        continue;
+      }
+      lines.emplace_back(candidate->ToSummary(*dataflow_graph_));
+    }
+  }
+  std::sort(lines.begin(), lines.end());
+  std::ostringstream os;
+  bool first = true;
+  for (const auto& line : lines) {
+    if (first) {
+      first = false;
+    } else {
+      os << std::endl;
+    }
+    os << line;
+  }
+  return os.str();
+}
+
+bool CandidatePartitionIndex::IsCompatibleWithVirtualDevice(const CandidatePartition& candidate) {
+  for (PostDfsIndex index : candidate->sub_graph_->inside_) {
+    const ExprNode* sub_expr_node = dataflow_graph_->index_to_node(index)->node_ref_;
+    if (sub_expr_node->IsInstance<OpNode>() || sub_expr_node->IsInstance<ConstructorNode>()) {
+      // These nodes are target/device polymorphic.
+      continue;
+    }
+    auto itr = virtual_devices_->find(sub_expr_node);
+    ICHECK(itr != virtual_devices_->end()) << PrettyPrint(GetRef<Expr>(sub_expr_node));
+    const Target& existing_target = itr->second->target;
+    if (!existing_target.defined()) {
+      // No constraint.
+      continue;
+    }
+    if (StructuralEqual()(existing_target, candidate->target())) {
+      // No disagreement.
+      continue;
+    }
+    if (!candidate->target().IsExternalCodegenFor(itr->second->target)) {
+      // The candidate's target is not an external codegen target compatible with the existing
+      // target.
+      // TODO(mbs): There's a conflict here between Collage's desire to leave some expression nodes
+      // 'behind' on the VM and PlanDevice's desire to assign a primitive Target to every node.
+      // I think PlanDevices is the one that needs to give here by leaving such nodes
+      // unconstrained.
+      VLOG(1) << "Ignoring candidate " << candidate->ToString()
+              << " since incompatible with existing virtual device assignment of:" << std::endl
+              << itr->second << std::endl
+              << "to sub-graph:" << std::endl
+              << PrettyPrint(GetRef<Expr>(sub_expr_node));
+      return false;
+    }
+  }
+  return true;
+}
+
+std::vector<CandidatePartition> CandidatePartitionIndex::Collect(
+    const Array<PartitionSpec>& partition_specs) {
+  VLOG_CONTEXT << "collecting";
+  std::vector<CandidatePartition> result;
+  for (const auto& spec : partition_specs) {
+    VLOG_CONTEXT << "spec " << spec->spec_name_;
+    VLOG(1) << "collecting candidates";
+    std::vector<CandidatePartition> candidates = spec->AllCandidates(*dataflow_graph_);
+    for (auto& candidate : candidates) {
+      if (!IsCompatibleWithVirtualDevice(candidate)) {
+        continue;
+      }
+      result.push_back(candidate);
+    }
+  }
+  VLOG(1) << "Found " << result.size() << " candidates";
+  return result;
+}
+
+}  // namespace collage
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/collage/candidate_partition_index.h b/src/relay/collage/candidate_partition_index.h
new file mode 100644
index 0000000000..aa3f7d4fcd
--- /dev/null
+++ b/src/relay/collage/candidate_partition_index.h
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file relay/collage/candidate_partition_index.h
+ * \brief Index for finding relevant candidate partitions for a particular search state.
+ */
+#ifndef TVM_RELAY_COLLAGE_CANDIDATE_PARTITION_INDEX_H_
+#define TVM_RELAY_COLLAGE_CANDIDATE_PARTITION_INDEX_H_
+
+#include <tvm/relay/expr.h>
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "./partition_spec.h"
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+/*!
+ * \brief Collects and indexes all the candidate partitions for the overall expression. This index
+ * is used during partitioning search to find the next valid candidate partition to explore from the
+ * current search state. We do not yet attempt to estimate the cost of each candidate partition, and
+ * when we do so during the search we may discover it to be infeasible.
+ */
+class CandidatePartitionIndex {
+ public:
+  CandidatePartitionIndex(const std::unordered_map<const ExprNode*, VirtualDevice>* virtual_devices,
+                          DataflowGraph* dataflow_graph);
+
+  /*! \brief Constructs the index. */
+  void Index(const Array<PartitionSpec>& partition_specs);
+
+  /*! \brief Returns all the candidates which may begin at \p index. */
+  const std::vector<CandidatePartition>& candidates_at(PostDfsIndex index) const {
+    ICHECK_LT(index, dataflow_graph_->size());
+    return first_inside_index_to_candidates_[index];
+  }
+
+  /*! \brief Estimates the costs of all candidates in the index. Each candidate caches its cost. */
+  void EstimateAllCosts(const CostEstimator cost_estimator,
+                        const std::shared_ptr<CandidateFunctionCache>& cache);
+
+  size_t size() const { return size_; }
+
+  std::string ToSummary() const;
+
+ private:
+  /*!
+   * \brief Returns true if \p candidate's desired target is compatible with any existing target
+   * constraints on the candidate's sub-expressions.
+   */
+  bool IsCompatibleWithVirtualDevice(const CandidatePartition& candidate);
+
+  /*! \brief Returns all valid candidates found from \p partition_specs. */
+  std::vector<CandidatePartition> Collect(const Array<PartitionSpec>& partition_specs);
+
+  /*!
+   * \brief The \p VirtualDevice for every sub-expression in the overall expression. Needed to
+   * ensure candidates do not contradict the target/device placement already determined by
+   * device planning.
+   */
+  const std::unordered_map<const ExprNode*, VirtualDevice>* virtual_devices_;
+
+  /*! \brief Dataflow graph for overall expression. */
+  DataflowGraph* dataflow_graph_;
+
+  /*!
+   * \brief Maps post-dfs indexes to all the candidates which have that as their first inside
+   * index, and which should be considered in the Collage search.
+   */
+  std::vector<std::vector<CandidatePartition>> first_inside_index_to_candidates_;
+
+  /*! \brief Total number of candidates indexed above. */
+  size_t size_ = 0;
+};
+
+}  // namespace collage
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // TVM_RELAY_COLLAGE_CANDIDATE_PARTITION_INDEX_H_
diff --git a/src/relay/collage/collage_partitioner.cc b/src/relay/collage/collage_partitioner.cc
new file mode 100644
index 0000000000..ac038fba2a
--- /dev/null
+++ b/src/relay/collage/collage_partitioner.cc
@@ -0,0 +1,352 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/collage/collage_partitioner.cc
+ * \brief Search for an optimal partitioning of a Relay model.
+ */
+
+#include "./collage_partitioner.h"
+
+#include <math.h>
+#include <tvm/ir/attrs.h>
+#include <tvm/ir/function.h>
+#include <tvm/ir/transform.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/function.h>
+#include <tvm/relay/transform.h>
+#include <tvm/target/target.h>
+
+#include "../ir/dataflow_matcher_impl.h"
+#include "../transforms/compiler_function_utils.h"
+#include "../transforms/device_aware_visitors.h"
+#include "./candidate_partition.h"
+#include "./candidate_partition_index.h"
+#include "./cost.h"
+#include "./cost_estimator.h"
+#include "./gather_partition_specs.h"
+#include "./name_supply.h"
+#include "./partition_rule.h"
+#include "./partition_spec.h"
+#include "./priority_queue.h"
+#include "./sub_graph.h"
+#include "./utils.h"
+
+namespace tvm {
+namespace relay {
+namespace collage {
+namespace {
+
+TVM_REGISTER_PASS_CONFIG_OPTION("relay.collage.tvm_max_depth", Integer);
+TVM_REGISTER_PASS_CONFIG_OPTION("relay.collage.byoc_max_depth", Integer);
+
+/*!
+ * \brief Represents the overall expression after some number of non-overlapping candidate
+ * partitions have been applied.
+ */
+class SearchState {
+ public:
+  explicit SearchState(IndexSet covered) : covered_(std::move(covered)) {}
+
+  /*!
+   * \brief Order states by increasing best cost, breaking ties by lexicographic order on
+   * the covering sub-graph.
+   */
+  bool operator<(const SearchState& that) const {
+    return std::tie(best_cost_, covered_) < std::tie(that.best_cost_, that.covered_);
+  }
+
+  const IndexSet& covered() const { return covered_; }
+
+  std::string ToString() const {
+    std::ostringstream os;
+    os << "State(";
+    os << "covered=" << covered_.ToString();
+    os << ",best_cost=" << best_cost_.ToString();
+    if (best_candidate_.defined()) {
+      os << ",best_candidate=" << best_candidate_->ToString();
+    }
+    os << ")";
+    return os.str();
+  }
+
+ private:
+  /*! \brief Which nodes of the overall expression have been placed on all paths to this state. */
+  IndexSet covered_;
+  /*! \brief Predecessor state for sequence of candidates reaching this state with least
+   * cost. Null if initial search state. */
+  SearchState* pred_state_ = nullptr;
+  /*!
+   * \brief Cost of reaching this state using placement implied by path given by pred_state fields.
+   * Includes estimated/measured cost of all candidates plus any candidate launch penalty.
+   * Initially invalid cost.
+   */
+  Cost best_cost_ = Cost::Invalid();
+  /*! \brief Candidate partition selected in transition from pred_state to this state. */
+  CandidatePartition best_candidate_;
+
+  friend class Partitioner;
+};
+
+struct CompareSearchStatePtrs {
+  bool operator()(const SearchState* left, const SearchState* right) const {
+    return *left < *right;
+  }
+};
+
+struct EqualSearchStatePtrs {
+  bool operator()(const SearchState* left, const SearchState* right) const {
+    return left->covered() == right->covered();
+  }
+};
+
+/*!
+ * \brief Finds the optimal partitioning of an expression into candidate partitions.
+ * Though no candidate partitions overlap, it is possible some sub-expressions end up in
+ * no candidate. Those sub-expressions must be evaluated by the host executor (e.g. the VM).
+ */
+class Partitioner {
+ public:
+  explicit Partitioner(Array<PartitionSpec> partition_specs,
+                       const std::unordered_map<const ExprNode*, VirtualDevice>* virtual_devices,
+                       CostEstimator cost_estimator, std::shared_ptr<CandidateFunctionCache> cache,
+                       Expr expr)
+      : partition_specs_(std::move(partition_specs)),
+        virtual_devices_(virtual_devices),
+        cost_estimator_(std::move(cost_estimator)),
+        cache_(std::move(cache)),
+        expr_(std::move(expr)) {}
+
+  Expr Partition() {
+    // Establish core data structures.
+    dataflow_graph_ = std::make_unique<DataflowGraph>(expr_);
+    VLOG(1) << "Created dataflow graph with " << dataflow_graph_->size() << " nodes";
+
+    // Build the candidate index. This is where all the partition rules are invoked.
+    index_ = std::make_unique<CandidatePartitionIndex>(virtual_devices_, dataflow_graph_.get());
+    index_->Index(partition_specs_);
+    VLOG(1) << "All candidates before search:" << std::endl << index_->ToSummary();
+
+    // 'Eagerly' estimate the cost of all candidates.
+    //
+    // Note that if this is not done, costs are simply estimated 'lazily' as the search proceeds.
+    // Typically, some candidates are never explored during the search because:
+    //  - There are no paths in which the candidate does not intersect candidates already
+    //    applied on the path.
+    //  - The Dijkstra search terminates early with a least cost path.
+    // So eager estimation may incur more overhead. However, eager estimation could be made
+    // embarrassingly parallel.
+    VLOG(1) << "Beginning eager cost estimation";
+    index_->EstimateAllCosts(cost_estimator_, cache_);
+    VLOG(1) << "Finished eager cost estimation";
+
+    // Setup initial state.
+    SearchState* init_state = GetState(IndexSet(dataflow_graph_->size()));
+    init_state->best_cost_ = Cost::Zero();
+    pq_.Push(init_state);
+
+    size_t num_transitions = 0;
+
+    VLOG(1) << "#### Commencing Collage search over " << index_->size() << " candidates ####";
+    while (!pq_.empty()) {
+      SearchState* curr_state = pq_.Pop();
+      VLOG(1) << "Looking at state " << curr_state->covered_.ToString();
+      PostDfsIndex next_index = curr_state->covered_.FirstOutsideIndex();
+
+      if (next_index >= dataflow_graph_->size()) {
+        // The entire expression has been explored. Collect the candidates on the optimal path.
+        VLOG(1) << "#### Finished Collage search after exploring " << num_transitions
+                << " transitions ####";
+        std::vector<CandidatePartition> best_candidates;
+        while (curr_state != init_state) {
+          ICHECK(curr_state->best_candidate_.defined());
+          best_candidates.emplace_back(curr_state->best_candidate_);
+          curr_state = curr_state->pred_state_;
+          ICHECK(curr_state != nullptr);
+        }
+        return Finalize(best_candidates);
+      }
+
+      size_t num_fires = 0;
+      Expr sub_expr = dataflow_graph_->index_to_node(next_index)->ref();
+      VLOG(1) << "Looking at index " << next_index << " for sub-expression "
+              << SubExprKindAndLabel(sub_expr).second << " out of " << dataflow_graph_->size()
+              << " total dataflow nodes";
+
+      // Explore all the outgoing candidates from the current state.
+      for (const auto& candidate : index_->candidates_at(next_index)) {
+        VLOG(1) << "Considering candidate " << candidate->ToSummary(*dataflow_graph_)
+                << " for transition " << ++num_transitions << " over " << index_->size()
+                << " total candidates";
+        if (!candidate->sub_graph_->inside_.AreDisjoint(curr_state->covered_)) {
+          LOG(INFO) << "Candidate overlaps with already partitioned nodes";
+          continue;
+        }
+        IndexSet next_covered = curr_state->covered_ | candidate->sub_graph_->inside_;
+        SearchState* next_state = GetState(next_covered);
+        Relax(curr_state, next_state, candidate);
+        ++num_fires;
+      }
+      ICHECK_GT(num_fires, 0)
+          << "No candidate was found covering sub-expression at index " << next_index
+          << ", suggesting the partition rules are incomplete for the given targets.";
+    }
+
+    ICHECK(false) << "should have reached end state in which all sub-expressions are covered";
+    return {};
+  }
+
+  /*! \brief Returns the unique state corresponding to the \p covered sub-graph. */
+  SearchState* GetState(const IndexSet& covered) {
+    auto itr = covered_to_state_.find(covered);
+    if (itr != covered_to_state_.end()) {
+      return itr->second.get();
+    }
+    auto state = std::make_unique<SearchState>(covered);
+    SearchState* raw_ptr = state.get();
+    covered_to_state_.emplace(covered, std::move(state));
+    return raw_ptr;
+  }
+
+  /*!
+   * \brief Record that it is possible to reach \p next_state by choosing \p candidate
+   * in \p curr_state. If the resulting cost is better than the best known so far, update
+   * \p next_state's best cost, predecessor and candidate to match.
+   */
+  void Relax(SearchState* curr_state, SearchState* next_state,
+             const CandidatePartition& candidate) {
+    // Note this may already be cached if the candidate partition costs were 'eagerly' estimated.
+    Cost candidate_cost = candidate->EstimatedCost(*dataflow_graph_, cost_estimator_, cache_);
+    VLOG(1) << "Candidate has cost " << candidate_cost.ToString();
+    Cost new_state_cost = candidate_cost + curr_state->best_cost_;
+    const bool is_new = next_state->best_cost_.is_invalid();
+    CandidatePartition previously_best_candidate = next_state->best_candidate_;
+    if (is_new || new_state_cost < next_state->best_cost_) {
+      next_state->pred_state_ = curr_state;
+      Cost previously_best_cost = next_state->best_cost_;
+      next_state->best_cost_ = new_state_cost;
+      next_state->best_candidate_ = candidate;
+      if (is_new) {
+        VLOG(1) << "transition " << curr_state->ToString() << " --> " << next_state->ToString()
+                << " (New state for spec " << candidate->partition_spec_name() << ")";
+        pq_.Push(next_state);
+      } else {
+        VLOG(1) << "transition " << curr_state->ToString() << " --> " << next_state->ToString()
+                << " (Spec " << candidate->partition_spec_name() << " beats previous spec "
+                << previously_best_candidate->partition_spec_name() << " by "
+                << (previously_best_cost - curr_state->best_cost_).ToString() << ")";
+        pq_.Update(next_state);
+      }
+    } else {
+      VLOG(1) << "transition " << curr_state->ToString() << " --> " << next_state->ToString()
+              << " (Spec " << candidate->partition_spec_name() << " does not beat existing spec "
+              << previously_best_candidate->partition_spec_name() << ")";
+    }
+  }
+
+  /*!
+   * \brief Returns the result of partitioning \p expr according to 'optimal' candidates found
+   * by the search.
+   */
+  Expr Finalize(std::vector<CandidatePartition> best_candidates) {
+    best_candidates = CandidatePartition::MaxCoalesce(*dataflow_graph_, best_candidates);
+
+    Cost total_cost = Cost::Zero();
+    std::ostringstream os;
+    os << "Optimal partitioning:" << std::endl;
+    for (const auto& best_candidate : best_candidates) {
+      if (best_candidate->partition_spec_name() == kHostSpecName) {
+        continue;
+      }
+      os << best_candidate->ToSummary(*dataflow_graph_);
+      os << std::endl;
+      total_cost = total_cost + best_candidate->cost_;
+    }
+    os << "Estimated overall cost is " << total_cost.ToString();
+    LOG(INFO) << os.str();
+
+    LOG(INFO) << "All candidates after search:" << std::endl << index_->ToSummary();
+
+    return CandidatePartition::ParallelRewrite(*dataflow_graph_, best_candidates);
+  }
+
+ private:
+  /*! \brief Available partition specs to use during search. */
+  Array<PartitionSpec> partition_specs_;
+  /*!
+   * \brief The virtual devices for every sub-expression so we can respect any existing target
+   * constraints.
+   */
+  const std::unordered_map<const ExprNode*, VirtualDevice>* virtual_devices_;
+  /*! \brief Cost estimator to use for candidates. */
+  CostEstimator cost_estimator_;
+  /*! \brief Cached names and costs for all partition functions. */
+  std::shared_ptr<CandidateFunctionCache> cache_;
+  /*! \brief The expression we will be partitioning. */
+  Expr expr_;
+  /*! \brief Dataflow graph for overall expression. */
+  std::unique_ptr<DataflowGraph> dataflow_graph_;
+  /*! \brief Index of all available candidates we are searching over. */
+  std::unique_ptr<CandidatePartitionIndex> index_;
+  /*! \brief Map from covered sub-graphs to the corresponding state. */
+  std::unordered_map<IndexSet, std::unique_ptr<SearchState>, IndexSetHash, IndexSetEqual>
+      covered_to_state_;
+  /*! \brief Priority queue of states, ordered by increasing cost. */
+  PriorityQueue<SearchState, CompareSearchStatePtrs, EqualSearchStatePtrs> pq_;
+};
+
+}  // namespace
+
+transform::Pass CollagePartition(CompilationConfig config, CostEstimator cost_estimator) {
+  runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func =
+      [config = std::move(config), cost_estimator = std::move(cost_estimator)](
+          IRModule mod, transform::PassContext ctxt) {
+        VLOG(1) << "CollagePartition input:" << std::endl << PrettyPrint(mod);
+
+        Array<PartitionSpec> partition_specs = GatherPartitionSpecs(config);
+        VLOG(1) << "Gathered " << partition_specs.size() << " partition specs";
+
+        auto cache =
+            std::make_shared<CandidateFunctionCache>(std::make_shared<NameSupply>("collage"));
+
+        IRModule out_mod = mod->ShallowCopy();
+        for (const auto& kv : mod->functions) {
+          if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) {
+            auto function = GetRef<Function>(function_node);
+            std::unordered_map<const ExprNode*, VirtualDevice> virtual_devices =
+                transform::RecoverVirtualDeviceMap(mod, function);
+            Partitioner partitioner(partition_specs, &virtual_devices, cost_estimator, cache,
+                                    function);
+            Function result = Downcast<Function>(partitioner.Partition());
+            out_mod->Add(kv.first, result);
+          }
+        }
+
+        out_mod = OutlineCompilerFunctions(cache)(std::move(out_mod));
+        VLOG(1) << "CollagePartition result:" << std::endl << PrettyPrint(out_mod);
+        return out_mod;
+      };
+  return tvm::transform::CreateModulePass(pass_func, /*opt_level=*/0, "CollagePartition", {});
+}
+
+TVM_REGISTER_GLOBAL("relay._transform.CollagePartition").set_body_typed(CollagePartition);
+
+}  // namespace collage
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/collage/collage_partitioner.h b/src/relay/collage/collage_partitioner.h
new file mode 100644
index 0000000000..7c8de87ffe
--- /dev/null
+++ b/src/relay/collage/collage_partitioner.h
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file relay/collage/collage_partitioner.h
+ * \brief Search for an optimal partitioning of a Relay model.
+ *
+ * See:
+ *   Collage: Automated Integration of Deep Learning Backends
+ *   Byungsoo Jeon, Sunghyun Park, Peiyuan Liao, Sheng Xu, Tianqi Chen, Zhihao Jia
+ *   https://arxiv.org/pdf/2111.00655.pdf
+ */
+#ifndef TVM_RELAY_COLLAGE_COLLAGE_PARTITIONER_H_
+#define TVM_RELAY_COLLAGE_COLLAGE_PARTITIONER_H_
+
+#include <tvm/relay/transform.h>
+
+#include "./cost_estimator.h"
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+/*!
+ * \brief Explores the space of all possible (sub-graph, target) pairs which cover the
+ * model, and applies the globally optimal choice (assuming partition costs are additive).
+ */
+transform::Pass CollagePartition(CompilationConfig config, CostEstimator cost_estimator);
+
+}  // namespace collage
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // TVM_RELAY_COLLAGE_COLLAGE_PARTITIONER_H_
diff --git a/src/relay/collage/cost_estimator.cc b/src/relay/collage/cost_estimator.cc
index e2ea99ce9b..f8bd0867a3 100644
--- a/src/relay/collage/cost_estimator.cc
+++ b/src/relay/collage/cost_estimator.cc
@@ -39,12 +39,11 @@ CostEstimator::CostEstimator() {
   data_ = std::move(node);
 }
 
-Cost CostEstimatorNode::Estimate(const IRModule& mod, const Target& target,
-                                 bool needs_tvm_turning) const {
+Cost CostEstimatorNode::Estimate(const IRModule& mod, const Target& target) const {
   static const runtime::PackedFunc* estimate_seconds =
       runtime::Registry::Get("tvm.relay.collage.estimate_seconds");
   ICHECK(estimate_seconds);
-  const double value = (*estimate_seconds)(mod, target, needs_tvm_turning);
+  const double value = (*estimate_seconds)(mod, target);
   if (std::isinf(value)) {
     return Cost::Invalid();
   } else if (std::isnan(value)) {
@@ -95,8 +94,7 @@ class MockEstimationVisitor : private ExprVisitor {
   }
 };
 
-Cost MockEstimatorNode::Estimate(const IRModule& mod, const Target& target,
-                                 bool needs_tvm_tuning) const {
+Cost MockEstimatorNode::Estimate(const IRModule& mod, const Target& target) const {
   double op_cost = static_cast<double>(target_costs_.at(target->kind->name)->value);
   double cost = 0.0;
   for (const auto& kv : mod->functions) {
diff --git a/src/relay/collage/cost_estimator.h b/src/relay/collage/cost_estimator.h
index f433fd5840..15f383a4cd 100644
--- a/src/relay/collage/cost_estimator.h
+++ b/src/relay/collage/cost_estimator.h
@@ -52,7 +52,7 @@ class CostEstimatorNode : public Object {
    * running "main" in \p mod using \p target, which represents a possible partitioning of
    * some overall Relay expression.
    */
-  virtual Cost Estimate(const IRModule& mod, const Target& target, bool needs_tvm_tuning) const;
+  virtual Cost Estimate(const IRModule& mod, const Target& target) const;
 
   static constexpr const char* _type_key = "relay.collage.CostEstimator";
   TVM_DECLARE_BASE_OBJECT_INFO(CostEstimatorNode, Object);
@@ -75,7 +75,7 @@ class CostEstimator : public ObjectRef {
  */
 class MockEstimatorNode : public CostEstimatorNode {
  public:
-  Cost Estimate(const IRModule& mod, const Target& target, bool needs_tvm_tuning) const override;
+  Cost Estimate(const IRModule& mod, const Target& target) const override;
 
   static constexpr const char* _type_key = "relay.collage.MockEstimator";
   TVM_DECLARE_FINAL_OBJECT_INFO(MockEstimatorNode, CostEstimatorNode);
diff --git a/src/relay/collage/gather_partition_specs.cc b/src/relay/collage/gather_partition_specs.cc
new file mode 100644
index 0000000000..7e28367908
--- /dev/null
+++ b/src/relay/collage/gather_partition_specs.cc
@@ -0,0 +1,214 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/collage/gather_partition_specs.cc
+ * \brief Gather the relevant \p PartitionSpecs from the available \p Targets.
+ */
+
+#include "./gather_partition_specs.h"
+
+#include "./utils.h"
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+namespace {
+
+PartitionRule MakeCombinePartitionRule(PartitionRule sub_rule, Array<CombinerRule> combiner_rules,
+                                       size_t max_depth) {
+  if (combiner_rules.empty()) {
+    return sub_rule;
+  } else {
+    return CombinePartitionRule("", std::move(sub_rule), std::move(combiner_rules), max_depth);
+  }
+}
+
+/*! \brief Returns the primitive combiner rules which mimic TVM's \p FuseOps. */
+Array<CombinerRule> TVMCombinerRules() {
+  Array<SimpleCombinerRule> simple_rules;
+  // Mimic the FuseOps rules.
+  simple_rules.push_back(ByKindSimpleCombinerRule(kOutEWiseFusable, kBroadcast));
+  simple_rules.push_back(ByKindSimpleCombinerRule(kBroadcast, kCommReduce));
+  simple_rules.push_back(ByKindSimpleCombinerRule(kInjective, kInjective));
+
+  Array<CombinerRule> combiner_rules;
+  // Fire the simple fusion rules
+  combiner_rules.push_back(AllSimpleCombinerRule("combiner", std::move(simple_rules)));
+  // Fuse tuple arguments
+  combiner_rules.push_back(TupleArgCombinerRule("tuple"));
+  // Fuse tuple projection
+  combiner_rules.push_back(TupleProjCombinerRule("proj"));
+
+  return combiner_rules;
+}
+
+size_t GetMaxDepth(std::string key) {
+  tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();
+  std::string config_key = "relay.collage." + key;
+  Optional<Integer> opt_max_depth = ctxt->GetConfig(config_key, Optional<Integer>());
+  ICHECK(opt_max_depth.defined()) << "missing binding for '" << config_key << "' in pass context";
+  ICHECK(opt_max_depth.value()->value > 0)
+      << "invalid value for '" << config_key << " in pass context";
+  return static_cast<size_t>(opt_max_depth.value()->value);
+}
+
+/*! \brief Returns partition rule mimicking TVM FuseOps. */
+PartitionRule MakeTVMPartitionRule() {
+  size_t max_depth = GetMaxDepth("tvm_max_depth");
+  // Build singleton candidates for all calls to ops <= kOutEWiseFusable.
+  OpCallByKindPartitionRule op_call_by_kind("");
+  // Combine candidates according to the TVM fusion rules.
+  PartitionRule combine =
+      MakeCombinePartitionRule(std::move(op_call_by_kind), TVMCombinerRules(), max_depth);
+  // Discard invalid candidates.
+  SubGraphConfig sub_graph_config;
+  sub_graph_config.allow_taps = false;
+  sub_graph_config.max_depth = max_depth;
+  sub_graph_config.max_exits = 1;
+  return OnlyValidPartitionRule("", std::move(combine), sub_graph_config);
+  // NOTE: We don't wrap by a "Primitive" since we want to defer making TVM fusion decisions until
+  // after running more Relay passes.
+}
+
+/*!
+ * \brief Returns the fusion style for \p compiler.
+ *
+ * TODO(mbs): Defer to per-BYOC integration definition.
+ */
+BYOCStyle BYOCFusionStyleForCompiler(const String& compiler) {
+  if (compiler == "cutlass" || compiler == "cublas" || compiler == "cudnn") {
+    return kNoFusionBYOCStyle;
+  } else if (compiler == "tensorrt") {
+    return kTVMFusionBYOCStyle;
+  } else {
+    return kArbitraryFusionBYOCStyle;
+  }
+}
+
+/*!
+ * \brief Returns the primitive combiner rules which allow for any touching candidates
+ * to be fused provided they don't have kind \p kOpaque.
+ */
+Array<CombinerRule> BYOCCombinerRules(const String& compiler) {
+  Array<SimpleCombinerRule> simple_rules;
+  Array<CombinerRule> combiner_rules;
+  switch (BYOCFusionStyleForCompiler(compiler)) {
+    case kNoFusionBYOCStyle:
+      break;
+    case kTVMFusionBYOCStyle:
+      // Conservatively assume the BYOC toolchain follows the same rules as for TVM's FuseOps.
+      simple_rules.push_back(ByKindSimpleCombinerRule(kOutEWiseFusable, kBroadcast));
+      simple_rules.push_back(ByKindSimpleCombinerRule(kBroadcast, kCommReduce));
+      simple_rules.push_back(ByKindSimpleCombinerRule(kInjective, kInjective));
+      combiner_rules.push_back(AllSimpleCombinerRule("combiner", std::move(simple_rules)));
+      break;
+    case kArbitraryFusionBYOCStyle:
+      // Just try all combinations up to the max_depth limit.
+      simple_rules.push_back(ByKindSimpleCombinerRule(kOutEWiseFusable, kOutEWiseFusable));
+      combiner_rules.push_back(AllSimpleCombinerRule("combiner", std::move(simple_rules)));
+      break;
+  }
+  return combiner_rules;
+}
+
+/*!
+ * \brief Returns partition rule mimicking one entry in the patterns list passed to the
+ * MergeComposite pass.
+ */
+PartitionRule MakeLabelledDFPatternPartitionRule(
+    const std::string& compiler, String rule_name, DFPattern dataflow_pattern,
+    TPatternPredicate predicate = DefaultPatternPredicate) {
+  DFPatternPartitionRule patterns("", std::move(dataflow_pattern), std::move(predicate));
+  return CompositePartitionRule(std::move(rule_name), std::move(patterns));
+}
+
+/*!
+ * \brief Returns partition rule mimicking
+ * MergeComposite/AnnotateTarget/MergeCompilerRegions/PartitionGraph passes for "compiler"
+ * attribute of \p target.
+ */
+PartitionRule MakePatternBYOCPartitionRule(const std::string& compiler,
+                                           Array<PartitionRule> sub_rules) {
+  size_t max_depth = GetMaxDepth("byoc_max_depth");
+  // Union all the individual pattern rules.
+  UnionPartitionRule unioned("", std::move(sub_rules));
+  PartitionRule combine =
+      MakeCombinePartitionRule(std::move(unioned), BYOCCombinerRules(compiler), max_depth);
+  // Ignore invalid candidates.
+  SubGraphConfig sub_graph_config;
+  sub_graph_config.allow_taps = false;
+  sub_graph_config.max_depth = max_depth;
+  sub_graph_config.max_exits = 1;
+  OnlyValidPartitionRule valid("", std::move(combine), sub_graph_config);
+  // Wrap the candidates in a "Primitive" function with a "Compiler" attribute.
+  return PrimitivePartitionRule("", std::move(valid));
+}
+
+TVM_REGISTER_GLOBAL("relay.collage.MakeLabelledDFPatternPartitionRule")
+    .set_body_typed(MakeLabelledDFPatternPartitionRule);
+
+TVM_REGISTER_GLOBAL("relay.collage.MakeLabelledDFPatternPartitionRuleWithPredicate")
+    .set_body_typed(MakeLabelledDFPatternPartitionRule);
+
+TVM_REGISTER_GLOBAL("relay.collage.MakePatternBYOCPartitionRule")
+    .set_body_typed(MakePatternBYOCPartitionRule);
+
+/*!
+ * \brief Returns the rule to pick out expression nodes which can be 'left behind' for execution
+ * on the host.
+ */
+PartitionRule MakeHostPartitionRule() { return HostPartitionRule(""); }
+
+}  // namespace
+
+Array<PartitionSpec> GatherPartitionSpecs(const CompilationConfig& config) {
+  Array<PartitionSpec> result;
+  for (const auto& primitive_target : config->primitive_targets) {
+    String spec_name = GetSpecName(primitive_target);
+    PartitionRule rule;
+    if (primitive_target.IsExternalCodegen()) {
+      // Transition to the Python side so we can get access to the BYOC pattern registry.
+      // That will bounce right back into the above construction helpers.
+      static const runtime::PackedFunc* make_byoc_partition_rule =
+          runtime::Registry::Get("tvm.relay.collage.make_byoc_partition_rule");
+      ICHECK(make_byoc_partition_rule);
+      rule = (*make_byoc_partition_rule)(spec_name);  // spec_name == primitive_target->kind->name
+      VLOG(1) << "Target " << primitive_target->ToDebugString() << " is for BYOC spec_name "
+              << spec_name << " and has default partition rule:\n"
+              << rule->ToString();
+    } else {
+      rule = MakeTVMPartitionRule();
+      VLOG(1) << "Target " << primitive_target->ToDebugString() << " is for TVM spec_name "
+              << spec_name << " and has default partition rule:\n"
+              << rule->ToString();
+    }
+    result.push_back(PartitionSpec(spec_name, primitive_target, rule));
+  }
+
+  // Add one more spec to cover the host target.
+  result.push_back(PartitionSpec(kHostSpecName, config->host_target, MakeHostPartitionRule()));
+
+  return result;
+}
+
+}  // namespace collage
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/collage/gather_partition_specs.h b/src/relay/collage/gather_partition_specs.h
new file mode 100644
index 0000000000..62ffca27d6
--- /dev/null
+++ b/src/relay/collage/gather_partition_specs.h
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/collage/gather_partition_specs.h
+ * \brief Gather the relevant \p PartitionSpecs from the available \p Targets.
+ */
+#ifndef TVM_RELAY_COLLAGE_GATHER_PARTITION_SPECS_H_
+#define TVM_RELAY_COLLAGE_GATHER_PARTITION_SPECS_H_
+
+#include <tvm/target/compilation_config.h>
+
+#include "./partition_spec.h"
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+/*!
+ * \brief The 'styles' of BYOC integrations. Used to influence how their corresponding
+ * partition rule is constructed.
+ */
+enum BYOCStyle {
+  /*!
+   * \brief The BYOC patterns pick out 'ideal' candidates directly, either because:
+   *  - the BYOC toolchain does not perform any fusion so each matched sub-expression maps 1:1 to a
+   *    BYOC-provided operator, or
+   *  - the BYOC toolchain does perform fusion, however the patterns have been written to pick out
+   *    fusable sub-graphs.
+   */
+  kNoFusionBYOCStyle,
+
+  /*!
+   * \brief The BYOC patterns pick out supported operators, but the BYOC backend may perform
+   * fusion over those operators in much the same way TVM does.
+   */
+  kTVMFusionBYOCStyle,
+
+  /*!
+   * \brief The BYOC patterns pick out supported operators, but the BYOC backend may perform
+   * arbitrary fusion over those operators.
+   */
+  kArbitraryFusionBYOCStyle,
+};
+
+/*!
+ * \brief Returns all the partition specifications gathered from the \p Targets in \p config.
+ */
+Array<PartitionSpec> GatherPartitionSpecs(const CompilationConfig& config);
+
+}  // namespace collage
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // TVM_RELAY_COLLAGE_GATHER_PARTITION_SPECS_H_
diff --git a/src/relay/collage/priority_queue.h b/src/relay/collage/priority_queue.h
new file mode 100644
index 0000000000..1d30fe5d96
--- /dev/null
+++ b/src/relay/collage/priority_queue.h
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/collage/priority_queue.h
+ * \brief An updatable priority queue.
+ */
+
+#ifndef TVM_RELAY_COLLAGE_PRIORITY_QUEUE_H_
+#define TVM_RELAY_COLLAGE_PRIORITY_QUEUE_H_
+
+#include <tvm/runtime/logging.h>
+
+#include <algorithm>
+#include <set>
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+/*! \brief Priority queue of search states, ordered by increasing cost. */
+template <typename T, typename CmpTPtr, typename EqTPtr>
+class PriorityQueue {
+ public:
+  PriorityQueue() = default;
+
+  /*! \brief Pushes \p item onto the queue. */
+  void Push(T* item) { set_.emplace(item); }
+
+  /*! \brief Pops the item with the least cost off the queue. */
+  T* Pop() {
+    ICHECK(!set_.empty());
+    T* item = *set_.begin();
+    set_.erase(set_.begin());
+    return item;
+  }
+
+  /*! \brief Updates the queue to account for \p item's best cost being lowered. */
+  void Update(T* item) {
+    auto itr = std::find_if(set_.begin(), set_.end(),
+                            [item](const T* that) { return EqTPtr()(that, item); });
+    ICHECK(itr != set_.end());
+    set_.erase(itr);
+    set_.emplace(item);
+  }
+
+  bool empty() const { return set_.empty(); }
+  size_t size() const { return set_.size(); }
+
+ private:
+  // TODO(mbs): Actually use a priority queue data structure!
+  std::set<T*, CmpTPtr> set_;
+};
+
+}  // namespace collage
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // TVM_RELAY_COLLAGE_PRIORITY_QUEUE_H_
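
Because the queue holds raw pointers and orders them with caller-supplied comparison and equality functors, the intended protocol when a search state gets cheaper is: lower the item's cost in place, then call Update so the queue can re-seat it under the new cost. A small usage sketch follows; the SearchState type and both functors are illustrative stand-ins rather than types introduced by this commit.

    // Sketch only: exercising PriorityQueue with a hypothetical search-state type.
    #include <cstddef>
    #include <iostream>

    #include "priority_queue.h"

    struct SearchState {
      size_t index;      // identifies the state
      double best_cost;  // may be lowered as cheaper candidate partitions are found
    };

    // Order by increasing best cost, breaking ties by index so the ordering is total.
    struct CompareSearchStatePtrs {
      bool operator()(const SearchState* a, const SearchState* b) const {
        if (a->best_cost != b->best_cost) return a->best_cost < b->best_cost;
        return a->index < b->index;
      }
    };

    // Two pointers denote the same state iff their indexes agree.
    struct EqSearchStatePtrs {
      bool operator()(const SearchState* a, const SearchState* b) const {
        return a->index == b->index;
      }
    };

    int main() {
      tvm::relay::collage::PriorityQueue<SearchState, CompareSearchStatePtrs, EqSearchStatePtrs> queue;
      SearchState a{0, 10.0};
      SearchState b{1, 5.0};
      queue.Push(&a);
      queue.Push(&b);

      a.best_cost = 3.0;  // a cheaper path to state 0 was found ...
      queue.Update(&a);   // ... so re-seat it under the new cost

      while (!queue.empty()) {
        std::cout << "state " << queue.Pop()->index << std::endl;  // prints 0 then 1
      }
      return 0;
    }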
diff --git a/src/relay/collage/utils.cc b/src/relay/collage/utils.cc
index 03af980e8c..cad29c4f6e 100644
--- a/src/relay/collage/utils.cc
+++ b/src/relay/collage/utils.cc
@@ -32,7 +32,7 @@ namespace relay {
 namespace collage {
 
 String GetSpecName(const Target& target) {
-  if (TargetKind::GetAttrMap<Bool>(tvm::attr::kIsExternalCodegen).get(target->kind, Bool(false))) {
+  if (target.IsExternalCodegen()) {
     return target->kind->name;
   } else {
     return std::string(kTVMSpecNamePrefix) + target->kind->name;
diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc
index 8d03dbf210..6f52f4b83c 100644
--- a/src/runtime/vm/vm.cc
+++ b/src/runtime/vm/vm.cc
@@ -359,11 +359,11 @@ void VirtualMachine::InvokeGlobal(const VMFunction& func, const std::vector<Obje
 }
 
 ObjectRef VirtualMachine::Invoke(const VMFunction& func, const std::vector<ObjectRef>& args) {
-  DLOG(INFO) << "Executing Function: " << std::endl << func;
+  VLOG(2) << "Executing Function: " << std::endl << func;
   for (int i = 0; i < static_cast<int>(devices_.size()); ++i) {
-    DLOG(INFO) << "Device " << i << " has device type " << devices_[i].device_type
-               << " and device id " << devices_[i].device_id
-               << (i == exec_->host_device_index ? " (using as host device)" : "");
+    VLOG(2) << "Device " << i << " has device type " << devices_[i].device_type << " and device id "
+            << devices_[i].device_id
+            << (i == exec_->host_device_index ? " (using as host device)" : "");
   }
 
   InvokeGlobal(func, args);
diff --git a/tests/python/relay/test_pass_collage_partition.py b/tests/python/relay/test_pass_collage_partition.py
new file mode 100644
index 0000000000..3a8f249af2
--- /dev/null
+++ b/tests/python/relay/test_pass_collage_partition.py
@@ -0,0 +1,617 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.testing
+import pytest
+from tvm.relay.transform import CollagePartition, InferType, CapturePostDfsIndexInSpans
+from tvm.target import make_compilation_config
+from tvm.relay.collage import MockEstimator
+from unittest.mock import patch
+from tvm.relay.dataflow_pattern import is_op, wildcard
+
+
+# We'll reuse the target kind "example_target_hook" (registered in
+# src/relay/backend/contrib/example_target_hooks/target.cc) as our
+# example external codegen target.
+
+
+def test_pattern_table():
+    def relu_pattern():
+        return is_op("nn.relu")(wildcard())
+
+    def add_pattern():
+        return is_op("add")(wildcard(), wildcard())
+
+    def concatenate_pattern():
+        return is_op("concatenate")(wildcard())
+
+    def predicate(expr):
+        return True
+
+    return [
+        ("relu", relu_pattern(), predicate),
+        ("add", add_pattern(), predicate),
+        ("concatenate", concatenate_pattern(), predicate),
+    ]
+
+
+def _mock_get_pattern_table(target):
+    if target == "example_target_hook":
+        return test_pattern_table()
+
+
+def run_collage(
+    input_mod, targets, cost_estimator, expected_mod, tvm_max_depth=8, byoc_max_depth=8
+):
+    ctxt = {
+        "relay.collage.tvm_max_depth": tvm_max_depth,
+        "relay.collage.byoc_max_depth": byoc_max_depth,
+    }
+    expected_mod = InferType()(expected_mod)
+    pass_ctxt = tvm.transform.PassContext(config=ctxt)
+    with pass_ctxt:
+        config = make_compilation_config(pass_ctxt, targets)
+        actual_mod = InferType()(input_mod)
+        # Capture indexes only to help debug failing tests
+        actual_mod = CapturePostDfsIndexInSpans()(actual_mod)
+        actual_mod = CollagePartition(config, cost_estimator)(actual_mod)
+
+        if not tvm.ir.structural_equal(actual_mod, expected_mod, map_free_vars=True):
+            # Print everything in full so we can see what's going on when things fail.
+            print("Input module:")
+            print(input_mod)
+            print("Actual module:")
+            print(actual_mod)
+            print("Expected module:")
+            print(expected_mod)
+            # Assert again so as to see the actual disagreeing sub-expressions.
+            tvm.ir.assert_structural_equal(actual_mod, expected_mod, map_free_vars=True)
+
+
+@patch("tvm.relay.op.contrib.get_pattern_table", wraps=_mock_get_pattern_table)
+def test_partition_single_op_llvm(mock_get_pattern_table):
+    mod_txt = """
+      #[version = "0.0.5"]
+      def @main(%x: Tensor[(10, 10), float32]) {
+        nn.relu(%x)
+      }
+    """
+    mod = tvm.parser.fromtext(mod_txt)
+
+    expected_txt = """
+      #[version = "0.0.5"]
+      def @main(%x: Tensor[(10, 10), float32]) -> Tensor[(10, 10), float32] {
+        nn.relu(%x)
+      }
+    """
+    expected_mod = tvm.parser.fromtext(expected_txt)
+
+    targets = [
+        tvm.target.Target("llvm"),
+        tvm.target.Target("example_target_hook"),
+    ]
+    cost_estimator = MockEstimator(
+        {
+            "llvm": 1,
+            "example_target_hook": 2,
+        }
+    )
+    run_collage(mod, targets, cost_estimator, expected_mod)
+
+
+@patch("tvm.relay.op.contrib.get_pattern_table", wraps=_mock_get_pattern_table)
+def test_partition_single_op_byoc(mock_get_pattern_table):
+    mod_txt = """
+      #[version = "0.0.5"]
+      def @main(%x: Tensor[(10, 10), float32]) {
+        nn.relu(%x)
+      }
+    """
+    mod = tvm.parser.fromtext(mod_txt)
+
+    expected_txt = """
+      #[version = "0.0.5"]
+      def @collage_example_target_hook_nn_relu(%FunctionVar_0: Tensor[(10, 10), float32], Primitive=1, Compiler="example_target_hook", global_symbol="collage_example_target_hook_nn_relu") -> Tensor[(10, 10), float32] {
+        %0 = fn (%FunctionVar_01: Tensor[(10, 10), float32], Composite="relu") -> Tensor[(10, 10), float32] {
+          nn.relu(%FunctionVar_01)
+        };
+        %0(%FunctionVar_0)
+      }
+
+      def @main(%x: Tensor[(10, 10), float32]) -> Tensor[(10, 10), float32] {
+        @collage_example_target_hook_nn_relu(%x)
+      }
+    """
+    expected_mod = tvm.parser.fromtext(expected_txt)
+
+    targets = [
+        tvm.target.Target("llvm"),
+        tvm.target.Target("example_target_hook"),
+    ]
+    cost_estimator = MockEstimator(
+        {
+            "llvm": 2,
+            "example_target_hook": 1,
+        }
+    )
+    run_collage(mod, targets, cost_estimator, expected_mod)
+
+
+@pytest.mark.parametrize("byoc_max_depth", [1, 3])
+@patch("tvm.relay.op.contrib.get_pattern_table", wraps=_mock_get_pattern_table)
+def test_partition_diamond_valid_topology(mock_get_pattern_table, byoc_max_depth):
+    mod_txt = """
+      #[version = "0.0.5"]
+      def @main(%x: Tensor[(10, 10), float32]) {
+        %0 = nn.relu(%x);
+        %1 = abs(%0);
+        %2 = nn.relu(%1);
+        add(%1, %2)
+      }
+    """
+    mod = tvm.parser.fromtext(mod_txt)
+
+    expected_3_txt = """
+      #[version = "0.0.5"]
+      def @collage_example_target_hook_nn_relu(%FunctionVar_0: Tensor[(10, 10), float32], Primitive=1, Compiler="example_target_hook", global_symbol="collage_example_target_hook_nn_relu") -> Tensor[(10, 10), float32] {
+        %0 = fn (%FunctionVar_01: Tensor[(10, 10), float32], Composite="relu") -> Tensor[(10, 10), float32] {
+          nn.relu(%FunctionVar_01)
+        };
+        %0(%FunctionVar_0)
+      }
+
+      def @collage_example_target_hook_nn_relu_add(%FunctionVar_02: Tensor[(10, 10), float32], Primitive=1, Compiler="example_target_hook", global_symbol="collage_example_target_hook_nn_relu_add") -> Tensor[(10, 10), float32] {
+        %1 = fn (%FunctionVar_04: Tensor[(10, 10), float32], Composite="relu") -> Tensor[(10, 10), float32] {
+          nn.relu(%FunctionVar_04)
+        };
+        %2 = %1(%FunctionVar_02);
+        %3 = fn (%FunctionVar_03: Tensor[(10, 10), float32], %FunctionVar_1: Tensor[(10, 10), float32], Composite="add") -> Tensor[(10, 10), float32] {
+          add(%FunctionVar_03, %FunctionVar_1)
+        };
+        %3(%FunctionVar_02, %2)
+      }
+
+      def @main(%x: Tensor[(10, 10), float32]) -> Tensor[(10, 10), float32] {
+        %4 = @collage_example_target_hook_nn_relu(%x);
+        %5 = abs(%4);
+        @collage_example_target_hook_nn_relu_add(%5)
+      }
+    """
+    expected_1_txt = """
+      #[version = "0.0.5"]
+      def @collage_example_target_hook(%FunctionVar_0: Tensor[(10, 10), float32], Primitive=1, Compiler="example_target_hook", global_symbol="collage_example_target_hook") -> Tensor[(10, 10), float32] {
+        %0 = fn (%FunctionVar_02: Tensor[(10, 10), float32], Composite="relu") -> Tensor[(10, 10), float32] {
+          nn.relu(%FunctionVar_02)
+        };
+        %1 = %0(%FunctionVar_0);
+        %2 = fn (%FunctionVar_01: Tensor[(10, 10), float32], %FunctionVar_1: Tensor[(10, 10), float32], Composite="add") -> Tensor[(10, 10), float32] {
+          add(%FunctionVar_01, %FunctionVar_1)
+        };
+        %2(%FunctionVar_0, %1)
+      }
+
+      def @collage_example_target_hook_nn_relu(%FunctionVar_03: Tensor[(10, 10), float32], Primitive=1, Compiler="example_target_hook", global_symbol="collage_example_target_hook_nn_relu") -> Tensor[(10, 10), float32] {
+        %3 = fn (%FunctionVar_04: Tensor[(10, 10), float32], Composite="relu") -> Tensor[(10, 10), float32] {
+          nn.relu(%FunctionVar_04)
+        };
+        %3(%FunctionVar_03)
+      }
+
+      def @main(%x: Tensor[(10, 10), float32]) -> Tensor[(10, 10), float32] {
+        %4 = @collage_example_target_hook_nn_relu(%x);
+        %5 = abs(%4);
+        @collage_example_target_hook(%5)
+      }
+    """
+    expected_mod = tvm.parser.fromtext(expected_1_txt if byoc_max_depth == 1 else expected_3_txt)
+
+    targets = [
+        tvm.target.Target("llvm"),
+        tvm.target.Target("example_target_hook"),
+    ]
+    cost_estimator = MockEstimator(
+        {
+            "llvm": 2,
+            "example_target_hook": 1,
+        }
+    )
+    run_collage(
+        mod, targets, cost_estimator, expected_mod, tvm_max_depth=1, byoc_max_depth=byoc_max_depth
+    )
+
+
+@pytest.mark.parametrize("tvm_max_depth", [1, 2, 3])
+@patch("tvm.relay.op.contrib.get_pattern_table", wraps=_mock_get_pattern_table)
+def test_tvm_max_depth(mock_get_pattern_table, tvm_max_depth):
+    mod_txt = """
+      #[version = "0.0.5"]
+      def @main(%x: Tensor[(10, 10), float32]) {
+        %0 = nn.relu(%x);
+        %1 = nn.relu(%0);
+        nn.relu(%1)
+      }
+    """
+    mod = tvm.parser.fromtext(mod_txt)
+
+    expected_txts = {
+        1: """
+          #[version = "0.0.5"]
+          def @collage_example_target_hook(%FunctionVar_0: Tensor[(10, 10), float32], Primitive=1, Compiler="example_target_hook", global_symbol="collage_example_target_hook") -> Tensor[(10, 10), float32] {
+            %0 = fn (%FunctionVar_03: Tensor[(10, 10), float32], Composite="relu") -> Tensor[(10, 10), float32] {
+              nn.relu(%FunctionVar_03)
+            };
+            %1 = %0(%FunctionVar_0);
+            %2 = fn (%FunctionVar_02: Tensor[(10, 10), float32], Composite="relu") -> Tensor[(10, 10), float32] {
+              nn.relu(%FunctionVar_02)
+            };
+            %3 = %2(%1);
+            %4 = fn (%FunctionVar_01: Tensor[(10, 10), float32], Composite="relu") -> Tensor[(10, 10), float32] {
+              nn.relu(%FunctionVar_01)
+            };
+            %4(%3)
+          }
+
+          def @main(%x: Tensor[(10, 10), float32]) -> Tensor[(10, 10), float32] {
+            @collage_example_target_hook(%x)
+          }
+        """,
+        2: """
+          #[version = "0.0.5"]
+          def @collage_example_target_hook_nn_relu(%FunctionVar_0: Tensor[(10, 10), float32], Primitive=1, Compiler="example_target_hook", global_symbol="collage_example_target_hook_nn_relu") -> Tensor[(10, 10), float32] {
+            %0 = fn (%FunctionVar_01: Tensor[(10, 10), float32], Composite="relu") -> Tensor[(10, 10), float32] {
+              nn.relu(%FunctionVar_01)
+            };
+            %0(%FunctionVar_0)
+          }
+
+          def @main(%x: Tensor[(10, 10), float32]) -> Tensor[(10, 10), float32] {
+            %1 = @collage_example_target_hook_nn_relu(%x);
+            %2 = nn.relu(%1);
+            nn.relu(%2)
+          }
+        """,
+        3: """
+          #[version = "0.0.5"]
+          def @main(%x: Tensor[(10, 10), float32]) -> Tensor[(10, 10), float32] {
+            %0 = nn.relu(%x);
+            %1 = nn.relu(%0);
+            nn.relu(%1)
+          }
+        """,
+    }
+    expected_mod = tvm.parser.fromtext(expected_txts[tvm_max_depth])
+
+    targets = [
+        tvm.target.Target("llvm"),
+        tvm.target.Target("example_target_hook"),
+    ]
+    cost_estimator = MockEstimator(
+        {
+            "llvm": 100,
+            "example_target_hook": 99,
+        }
+    )
+    run_collage(
+        mod, targets, cost_estimator, expected_mod, tvm_max_depth=tvm_max_depth, byoc_max_depth=1
+    )
+
+
+@pytest.mark.parametrize("byoc_max_depth", [1, 2, 3])
+@patch("tvm.relay.op.contrib.get_pattern_table", wraps=_mock_get_pattern_table)
+def test_byoc_max_depth(mock_get_pattern_table, byoc_max_depth):
+    mod_txt = """
+      #[version = "0.0.5"]
+      def @main(%x: Tensor[(10, 10), float32]) {
+        %0 = nn.relu(%x);
+        %1 = nn.relu(%0);
+        nn.relu(%1)
+      }
+    """
+    mod = tvm.parser.fromtext(mod_txt)
+
+    expected_txts = {
+        1: """
+          #[version = "0.0.5"]
+          def @main(%x: Tensor[(10, 10), float32]) -> Tensor[(10, 10), float32] {
+            %0 = nn.relu(%x);
+            %1 = nn.relu(%0);
+            nn.relu(%1)
+          }
+        """,
+        2: """
+          #[version = "0.0.5"]
+          def @collage_example_target_hook_nn_relu_nn_relu(%FunctionVar_0: Tensor[(10, 10), float32], Primitive=1, Compiler="example_target_hook", global_symbol="collage_example_target_hook_nn_relu_nn_relu") -> Tensor[(10, 10), float32] {
+            %0 = fn (%FunctionVar_02: Tensor[(10, 10), float32], Composite="relu") -> Tensor[(10, 10), float32] {
+              nn.relu(%FunctionVar_02)
+            };
+            %1 = %0(%FunctionVar_0);
+            %2 = fn (%FunctionVar_01: Tensor[(10, 10), float32], Composite="relu") -> Tensor[(10, 10), float32] {
+              nn.relu(%FunctionVar_01)
+            };
+            %2(%1)
+          }
+
+          def @main(%x: Tensor[(10, 10), float32]) -> Tensor[(10, 10), float32] {
+            %3 = nn.relu(%x);
+            @collage_example_target_hook_nn_relu_nn_relu(%3)
+          }
+        """,
+        3: """
+          #[version = "0.0.5"]
+          def @collage_example_target_hook_nn_relu_nn_relu_nn_relu(%FunctionVar_0: Tensor[(10, 10), float32], Primitive=1, Compiler="example_target_hook", global_symbol="collage_example_target_hook_nn_relu_nn_relu_nn_relu") -> Tensor[(10, 10), float32] {
+            %0 = fn (%FunctionVar_03: Tensor[(10, 10), float32], Composite="relu") -> Tensor[(10, 10), float32] {
+              nn.relu(%FunctionVar_03)
+            };
+            %1 = %0(%FunctionVar_0);
+            %2 = fn (%FunctionVar_02: Tensor[(10, 10), float32], Composite="relu") -> Tensor[(10, 10), float32] {
+              nn.relu(%FunctionVar_02)
+            };
+            %3 = %2(%1);
+            %4 = fn (%FunctionVar_01: Tensor[(10, 10), float32], Composite="relu") -> Tensor[(10, 10), float32] {
+              nn.relu(%FunctionVar_01)
+            };
+            %4(%3)
+          }
+
+          def @main(%x: Tensor[(10, 10), float32]) -> Tensor[(10, 10), float32] {
+            @collage_example_target_hook_nn_relu_nn_relu_nn_relu(%x)
+          }
+        """,
+    }
+    expected_mod = tvm.parser.fromtext(expected_txts[byoc_max_depth])
+
+    targets = [
+        tvm.target.Target("llvm"),
+        tvm.target.Target("example_target_hook"),
+    ]
+    cost_estimator = MockEstimator(
+        {
+            "llvm": 99,
+            "example_target_hook": 100,
+        }
+    )
+    run_collage(
+        mod, targets, cost_estimator, expected_mod, tvm_max_depth=1, byoc_max_depth=byoc_max_depth
+    )
+
+
+@patch("tvm.relay.op.contrib.get_pattern_table", wraps=_mock_get_pattern_table)
+def test_partition_output_tuple(mock_get_pattern_table):
+    mod_txt = """
+      #[version = "0.0.5"]
+      def @main(%x: Tensor[(10, 10), float32]) {
+        %0 = nn.relu(%x);
+        %1 = nn.relu(%0);
+        %2 = abs(%1);
+        (%0, %1, %2)
+      }
+    """
+    mod = tvm.parser.fromtext(mod_txt)
+
+    expected_txt = """
+      #[version = "0.0.5"]
+      def @collage_example_target_hook(%FunctionVar_0: Tensor[(10, 10), float32], Primitive=1, Compiler="example_target_hook", global_symbol="collage_example_target_hook") -> (Tensor[(10, 10), float32], Tensor[(10, 10), float32]) {
+        %0 = fn (%FunctionVar_01: Tensor[(10, 10), float32], Composite="relu") -> Tensor[(10, 10), float32] {
+          nn.relu(%FunctionVar_01)
+        };
+        %1 = %0(%FunctionVar_0);
+        %2 = fn (%FunctionVar_02: Tensor[(10, 10), float32], Composite="relu") -> Tensor[(10, 10), float32] {
+          nn.relu(%FunctionVar_02)
+        };
+        %3 = %2(%1);
+        (%1, %3)
+      }
+
+      def @main(%x: Tensor[(10, 10), float32]) -> (Tensor[(10, 10), float32], Tensor[(10, 10), float32], Tensor[(10, 10), float32]) {
+        %4 = @collage_example_target_hook(%x);
+        %5 = %4.1;
+        %6 = %4.0;
+        %7 = abs(%5);
+        (%6, %5, %7)
+      }
+    """
+    expected_mod = tvm.parser.fromtext(expected_txt)
+
+    targets = [
+        tvm.target.Target("llvm"),
+        tvm.target.Target("example_target_hook"),
+    ]
+    cost_estimator = MockEstimator(
+        {
+            "llvm": 2,
+            "example_target_hook": 1,
+        }
+    )
+    run_collage(mod, targets, cost_estimator, expected_mod, tvm_max_depth=2, byoc_max_depth=2)
+
+
+@patch("tvm.relay.op.contrib.get_pattern_table", wraps=_mock_get_pattern_table)
+def test_partition_intermediate_tuple(mock_get_pattern_table):
+    mod_txt = """
+      #[version = "0.0.5"]
+      def @main(%x: Tensor[(10, 10), float32]) {
+        %0 = nn.relu(%x);
+        %1 = nn.relu(%0);
+        %2 = (%0, %1);
+        concatenate(%2)
+      }
+    """
+    mod = tvm.parser.fromtext(mod_txt)
+
+    expected_txt = """
+      #[version = "0.0.5"]
+      def @collage_example_target_hook(%FunctionVar_0: Tensor[(10, 10), float32], Primitive=1, Compiler="example_target_hook", global_symbol="collage_example_target_hook") -> (Tensor[(10, 10), float32], Tensor[(10, 10), float32]) {
+        %0 = fn (%FunctionVar_01: Tensor[(10, 10), float32], Composite="relu") -> Tensor[(10, 10), float32] {
+          nn.relu(%FunctionVar_01)
+        };
+        %1 = %0(%FunctionVar_0);
+        %2 = fn (%FunctionVar_02: Tensor[(10, 10), float32], Composite="relu") -> Tensor[(10, 10), float32] {
+          nn.relu(%FunctionVar_02)
+        };
+        %3 = %2(%1);
+        (%1, %3)
+      }
+
+      def @collage_example_target_hook_concatenate(%FunctionVar_03: (Tensor[(10, 10), float32], Tensor[(10, 10), float32]), Primitive=1, Compiler="example_target_hook", global_symbol="collage_example_target_hook_concatenate") -> Tensor[(20, 10), float32] {
+        %4 = fn (%FunctionVar_04: (Tensor[(10, 10), float32], Tensor[(10, 10), float32]), Composite="concatenate") -> Tensor[(20, 10), float32] {
+          concatenate(%FunctionVar_04)
+        };
+        %4(%FunctionVar_03)
+      }
+        
+      def @main(%x: Tensor[(10, 10), float32]) -> Tensor[(20, 10), float32] {
+        %5 = @collage_example_target_hook(%x);
+        %6 = %5.0;
+        %7 = %5.1;
+        %8 = (%6, %7);
+        @collage_example_target_hook_concatenate(%8)
+      }
+    """
+    expected_mod = tvm.parser.fromtext(expected_txt)
+
+    targets = [
+        tvm.target.Target("llvm"),
+        tvm.target.Target("example_target_hook"),
+    ]
+    cost_estimator = MockEstimator(
+        {
+            "llvm": 2,
+            "example_target_hook": 1,
+        }
+    )
+    run_collage(mod, targets, cost_estimator, expected_mod, tvm_max_depth=3, byoc_max_depth=5)
+
+
+@patch("tvm.relay.op.contrib.get_pattern_table", wraps=_mock_get_pattern_table)
+def test_fusion_benefit(mock_get_pattern_table):
+    mod_txt = """
+      #[version = "0.0.5"]
+      def @main(%x: Tensor[(10, 10), float32]) {
+        %0 = nn.relu(%x);
+        %1 = nn.relu(%0);
+        %2 = abs(%x);
+        %3 = nn.relu(%2);
+        %4 = add(%1, %3);
+        %5 = nn.relu(%4);
+        abs(%5)
+      }
+    """
+    mod = tvm.parser.fromtext(mod_txt)
+
+    expected_txt = """
+      #[version = "0.0.5"]
+      def @collage_example_target_hook_nn_relu_nn_relu_nn_relu_add_nn_relu(%FunctionVar_0: Tensor[(10, 10), float32], %FunctionVar_1: Tensor[(10, 10), float32], Primitive=1, Compiler="example_target_hook", global_symbol="collage_example_target_hook_nn_relu_nn_relu_nn_relu_add_nn_relu") -> Tensor[(10, 10), float32] {
+        %0 = fn (%FunctionVar_04: Tensor[(10, 10), float32], Composite="relu") -> Tensor[(10, 10), float32] {
+          nn.relu(%FunctionVar_04)
+        };
+        %1 = %0(%FunctionVar_0);
+        %2 = fn (%FunctionVar_03: Tensor[(10, 10), float32], Composite="relu") -> Tensor[(10, 10), float32] {
+          nn.relu(%FunctionVar_03)
+        };
+        %3 = fn (%FunctionVar_05: Tensor[(10, 10), float32], Composite="relu") -> Tensor[(10, 10), float32] {
+          nn.relu(%FunctionVar_05)
+        };
+        %4 = %2(%1);
+        %5 = %3(%FunctionVar_1);
+        %6 = fn (%FunctionVar_02: Tensor[(10, 10), float32], %FunctionVar_11: Tensor[(10, 10), float32], Composite="add") -> Tensor[(10, 10), float32] {
+          add(%FunctionVar_02, %FunctionVar_11)
+        };
+        %7 = %6(%4, %5);
+        %8 = fn (%FunctionVar_01: Tensor[(10, 10), float32], Composite="relu") -> Tensor[(10, 10), float32] {
+          nn.relu(%FunctionVar_01)
+        };
+        %8(%7)
+      }
+        
+      def @main(%x: Tensor[(10, 10), float32]) -> Tensor[(10, 10), float32] {
+        %9 = abs(%x);
+        %10 = @collage_example_target_hook_nn_relu_nn_relu_nn_relu_add_nn_relu(%x, %9);
+        abs(%10)
+      }
+    """
+    expected_mod = tvm.parser.fromtext(expected_txt)
+
+    targets = [
+        tvm.target.Target("llvm"),
+        tvm.target.Target("example_target_hook"),
+    ]
+    cost_estimator = MockEstimator(
+        {
+            "llvm": 5,
+            "example_target_hook": 6,
+        }
+    )
+    run_collage(mod, targets, cost_estimator, expected_mod, tvm_max_depth=1, byoc_max_depth=5)
+
+
+@patch("tvm.relay.op.contrib.get_pattern_table", wraps=_mock_get_pattern_table)
+def test_double_residual(mock_get_pattern_table):
+    mod_txt = """
+      #[version = "0.0.5"]
+      def @main(%x: Tensor[(10, 10), float32]) {
+        %0 = nn.relu(%x);
+        %1 = abs(%0);
+        %2 = add(%0, %1);
+        add(%1, %2)
+      }
+    """
+    mod = tvm.parser.fromtext(mod_txt)
+
+    expected_txt = """
+      #[version = "0.0.5"]
+      def @collage_example_target_hook_add_add(%FunctionVar_0: Tensor[(10, 10), float32], %FunctionVar_1: Tensor[(10, 10), float32], Primitive=1, Compiler="example_target_hook", global_symbol="collage_example_target_hook_add_add") -> Tensor[(10, 10), float32] {
+        %0 = fn (%FunctionVar_02: Tensor[(10, 10), float32], %FunctionVar_12: Tensor[(10, 10), float32], Composite="add") -> Tensor[(10, 10), float32] {
+          add(%FunctionVar_02, %FunctionVar_12)
+        };
+        %1 = %0(%FunctionVar_1, %FunctionVar_0);
+        %2 = fn (%FunctionVar_01: Tensor[(10, 10), float32], %FunctionVar_11: Tensor[(10, 10), float32], Composite="add") -> Tensor[(10, 10), float32] {
+          add(%FunctionVar_01, %FunctionVar_11)
+        };
+        %2(%FunctionVar_0, %1)
+      }
+        
+      def @collage_example_target_hook_nn_relu(%FunctionVar_03: Tensor[(10, 10), float32], Primitive=1, Compiler="example_target_hook", global_symbol="collage_example_target_hook_nn_relu") -> Tensor[(10, 10), float32] {
+        %3 = fn (%FunctionVar_04: Tensor[(10, 10), float32], Composite="relu") -> Tensor[(10, 10), float32] {
+          nn.relu(%FunctionVar_04)
+        };
+        %3(%FunctionVar_03)
+      }
+        
+      def @main(%x: Tensor[(10, 10), float32]) -> Tensor[(10, 10), float32] {
+        %4 = @collage_example_target_hook_nn_relu(%x);
+        %5 = abs(%4);
+        @collage_example_target_hook_add_add(%5, %4)
+      } 
+    """
+    expected_mod = tvm.parser.fromtext(expected_txt)
+
+    targets = [
+        tvm.target.Target("llvm"),
+        tvm.target.Target("example_target_hook"),
+    ]
+    cost_estimator = MockEstimator(
+        {
+            "llvm": 2,
+            "example_target_hook": 1,
+        }
+    )
+    run_collage(mod, targets, cost_estimator, expected_mod, tvm_max_depth=4, byoc_max_depth=4)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()