You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2023/01/10 11:39:36 UTC

[tvm] branch main updated: [COLLAGE] Add more customization to support more targets (#13450)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new db920ddcde [COLLAGE] Add more customization to support more targets (#13450)
db920ddcde is described below

commit db920ddcde0172c2abb5e2b6ec80cfd8583046f7
Author: krishnaraj36 <45...@users.noreply.github.com>
AuthorDate: Tue Jan 10 17:09:30 2023 +0530

    [COLLAGE] Add more customization to support more targets (#13450)
    
    * [COLLAGE] Add more customization to support more targets
    
    1. Added a custom cost estimator module so that a user-supplied Python
    cost function can be used instead of the default cost function.
       eg: cost_estimator = CustomCostEstimator(py_fn_estimator="tvm.relay.collage.opencl_cost_estimator")
           mod = CollagePartition(config, cost_estimator=cost_estimator)(mod)
    2. Added a pass config option to select the BYOC fusion style for each compiler target.
       eg : config = { "relay.collage.byoc_fusion_style": ["compiler.NoFusion", "compiler.TVMFusion"]}
            ctxt = tvm.transform.PassContext(config=config)
    
    * Fix the lint errors
    
    * Fix the lint error whitespace
    
    * Fix the lint error tabs
    
    * Fix the lint error tabs
    
    * Fix the lint error tabs
    
    * move the clml collage test case to test_clml
    
    * Fix lint error whitespace
    
    * Fix the import error
    
    * Fix the environment var and import
    
    * Add comments
    
    * Add clml preprocess module in cost estimator
    
    * Fix whitespace lint error
    
    * Fix whitespace lint error
    
    * Fix whitespace lint error
    
    * Fix the comments and removed unwanted code
    
    * Fix whitespace error
    
    * Removed Todo comments
    
    * Removed TODO comments
    
    * Updated naming convention
    
    * Fix typo error
    
    * Fix the typo error
    
    * Corrected typo error
    
    * Corrected typo error
    
    * Removed unused and fix typo error
    
    * Removed redundant code and optimized the code
    
    * Fix the lint error
    
    * Fix whitespace lint error
    
    * Removed Prints in file
    
    * Fix lint error
    
    * Fix lint error
    
    * Removed runner template in test script
    
    * Fix the lint error
    
    * Fix lint error
    
    * Fix lint error
    
    * Fix the lint error
    
    * Fix the lint error
    
    Co-authored-by: kvegiraj <kv...@qti.qualcomm.com>
---
 python/tvm/relay/collage/__init__.py               |   1 +
 python/tvm/relay/collage/collage.py                |   8 +
 python/tvm/relay/op/contrib/clml.py                |  20 ++
 src/relay/collage/collage_partitioner.cc           |   2 +-
 src/relay/collage/custom_cost_estimator.cc         |  60 ++++
 src/relay/collage/custom_cost_estimator.h          |  67 ++++
 src/relay/collage/gather_partition_specs.cc        |  35 +-
 src/relay/collage/utils.cc                         |  13 +
 src/relay/collage/utils.h                          |   6 +
 .../test_clml/test_adreno_collage_targets.py       | 354 +++++++++++++++++++++
 .../relay/collage/demo_collage_partitioner.py      |   6 +
 11 files changed, 567 insertions(+), 5 deletions(-)

diff --git a/python/tvm/relay/collage/__init__.py b/python/tvm/relay/collage/__init__.py
index ff0d486606..b3b485ead4 100644
--- a/python/tvm/relay/collage/__init__.py
+++ b/python/tvm/relay/collage/__init__.py
@@ -21,4 +21,5 @@ from .collage import (
     WARMUP_MIN_REPEAT_MS,
     CostEstimator,
     MockCostEstimator,
+    CustomCostEstimator,
 )
diff --git a/python/tvm/relay/collage/collage.py b/python/tvm/relay/collage/collage.py
index 632ab1746f..cfc527c2b9 100644
--- a/python/tvm/relay/collage/collage.py
+++ b/python/tvm/relay/collage/collage.py
@@ -52,6 +52,14 @@ class MockCostEstimator(Object):
         self.__init_handle_by_constructor__(_ffi_api.MockCostEstimator, target_costs, max_estimates)
 
 
+@register_object("relay.collage.CustomCostEstimator")
+class CustomCostEstimator(Object):
+    """Estimates costs by calling the registered global function py_fn_estimator(mod, target)."""
+
+    def __init__(self, py_fn_estimator="tvm.relay.collage.estimate_seconds_custom"):
+        self.__init_handle_by_constructor__(_ffi_api.CustomCostEstimator, py_fn_estimator)
+
+
 def arg_for(arg_type, device):
     """Returns a test argument of Relay arg_type on device"""
     assert isinstance(arg_type, tvm.ir.TensorType)
diff --git a/python/tvm/relay/op/contrib/clml.py b/python/tvm/relay/op/contrib/clml.py
index 77882917b1..e6e535edc0 100644
--- a/python/tvm/relay/op/contrib/clml.py
+++ b/python/tvm/relay/op/contrib/clml.py
@@ -23,6 +23,7 @@ from tvm.ir import Op
 from tvm._ffi import register_func
 from tvm.relay import transform
 from tvm.relay.build_module import bind_params_by_name
+from tvm.relay import function as _function
 from tvm.relay.expr_functor import ExprMutator
 from tvm.relay.expr import Call, TupleGetItem
 
@@ -161,6 +162,25 @@ def preprocess_module(mod):
     return preprocessed_mod
 
 
+def preprocess_for_clml(mod):
+    """Run preprocess_module over the body of every non-main function whose Compiler is "clml"."""
+
+    for gvar in mod.get_global_vars():
+        if gvar.name_hint == "main":
+            continue
+        fn = mod[gvar.name_hint]
+        # fn.attrs may be None for functions built without attributes; guard before keys().
+        if fn.attrs and "Compiler" in fn.attrs.keys() and fn.attrs["Compiler"] == "clml":
+            clml_mod = tvm.IRModule.from_expr(fn.body)
+            with tvm.transform.PassContext(opt_level=3):
+                clml_mod = preprocess_module(clml_mod)
+            new_body = clml_mod["main"].body
+            mod[gvar.name_hint] = _function.Function(
+                fn.params, new_body, fn.ret_type, fn.type_params, fn.attrs
+            )
+    return mod
+
+
 @register_pattern_table("clml")
 def clml_pattern_table():
     """Get the CLML pattern table."""
diff --git a/src/relay/collage/collage_partitioner.cc b/src/relay/collage/collage_partitioner.cc
index ac038fba2a..54fc6c45ca 100644
--- a/src/relay/collage/collage_partitioner.cc
+++ b/src/relay/collage/collage_partitioner.cc
@@ -55,7 +55,7 @@ namespace {
 
 TVM_REGISTER_PASS_CONFIG_OPTION("relay.collage.tvm_max_depth", Integer);
 TVM_REGISTER_PASS_CONFIG_OPTION("relay.collage.byoc_max_depth", Integer);
-
+TVM_REGISTER_PASS_CONFIG_OPTION("relay.collage.byoc_fusion_style", Array<String>);
 /*!
  * \brief Represents the overall expression after some number of non-overlapping candidate
  * partitions have been applied.
diff --git a/src/relay/collage/custom_cost_estimator.cc b/src/relay/collage/custom_cost_estimator.cc
new file mode 100644
index 0000000000..dea4df072c
--- /dev/null
+++ b/src/relay/collage/custom_cost_estimator.cc
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/collage/custom_cost_estimator.cc
+ * \brief A custom CostEstimator to support alternative cost functions.
+ */
+
+#include "./custom_cost_estimator.h"
+
+#include <tvm/relay/expr_functor.h>
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+TVM_REGISTER_OBJECT_TYPE(CustomCostEstimatorNode);
+
+Cost CustomCostEstimatorNode::Estimate(const IRModule& mod, const Target& target) const {
+  // Resolve per call: a function-local static would freeze the first instance's function and
+  // make every other CustomCostEstimator (with a different py_fn_estimator_) silently reuse it.
+  const runtime::PackedFunc* estimate_seconds = runtime::Registry::Get(py_fn_estimator_);
+  ICHECK(estimate_seconds) << "No global function registered as '" << py_fn_estimator_ << "'";
+  const double value = (*estimate_seconds)(mod, target);
+  // The function reports seconds; inf maps to Cost::Invalid() (e.g. the candidate failed to
+  // build) and nan to Cost::Unknown().
+  if (std::isinf(value)) return Cost::Invalid();
+  if (std::isnan(value)) return Cost::Unknown();
+  return Cost::Value(value);
+}
+
+CustomCostEstimator::CustomCostEstimator(String py_fn_estimator) {
+  auto node = make_object<CustomCostEstimatorNode>();
+  node->py_fn_estimator_ = std::move(py_fn_estimator);
+  data_ = std::move(node);
+}
+
+TVM_REGISTER_GLOBAL("relay.collage.CustomCostEstimator").set_body_typed([](String py_fn_estimator) {
+  return CustomCostEstimator(std::move(py_fn_estimator));
+});
+
+}  // namespace collage
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/collage/custom_cost_estimator.h b/src/relay/collage/custom_cost_estimator.h
new file mode 100644
index 0000000000..4e6b45832e
--- /dev/null
+++ b/src/relay/collage/custom_cost_estimator.h
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/collage/custom_cost_estimator.h
+ * \brief A custom CostEstimator to support target-specific cost functions.
+ */
+
+#ifndef TVM_RELAY_COLLAGE_CUSTOM_COST_ESTIMATOR_H_
+#define TVM_RELAY_COLLAGE_CUSTOM_COST_ESTIMATOR_H_
+
+#include <tvm/relay/function.h>
+
+#include "./cost.h"
+#include "./cost_estimator.h"
+
+namespace tvm {
+namespace relay {
+namespace collage {
+
+/*!
+ * \brief A cost estimator that uses a target-specific cost function.
+ */
+class CustomCostEstimatorNode : public CostEstimatorNode {
+ public:
+  Cost Estimate(const IRModule& mod, const Target& target) const override;
+
+  static constexpr const char* _type_key = "relay.collage.CustomCostEstimator";
+  TVM_DECLARE_FINAL_OBJECT_INFO(CustomCostEstimatorNode, CostEstimatorNode);
+
+ protected:
+  /*!
+   * \brief Python implemented cost function name.
+   */
+  String py_fn_estimator_;
+
+  friend class CustomCostEstimator;
+};
+
+class CustomCostEstimator : public CostEstimator {
+ public:
+  explicit CustomCostEstimator(String py_fn_estimator);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(CustomCostEstimator, CostEstimator, CustomCostEstimatorNode);
+};
+
+}  // namespace collage
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // TVM_RELAY_COLLAGE_CUSTOM_COST_ESTIMATOR_H_
diff --git a/src/relay/collage/gather_partition_specs.cc b/src/relay/collage/gather_partition_specs.cc
index 7e28367908..ad45167334 100644
--- a/src/relay/collage/gather_partition_specs.cc
+++ b/src/relay/collage/gather_partition_specs.cc
@@ -89,11 +89,9 @@ PartitionRule MakeTVMPartitionRule() {
 }
 
 /*!
- * \brief Returns the fusion style for \p compiler.
- *
- * TODO(mbs): Defer to per-BYOC integration definition.
+ * \brief Returns the fusion style for default compiler.
  */
-BYOCStyle BYOCFusionStyleForCompiler(const String& compiler) {
+BYOCStyle DefaultBYOCFusionStyleForCompiler(const String& compiler) {
   if (compiler == "cutlass" || compiler == "cublas" || compiler == "cudnn") {
     return kNoFusionBYOCStyle;
   } else if (compiler == "tensorrt") {
@@ -103,6 +101,35 @@ BYOCStyle BYOCFusionStyleForCompiler(const String& compiler) {
   }
 }
 
+/*!
+ * \brief Returns the fusion style for given compiler.
+ */
+BYOCStyle BYOCFusionStyleForCompiler(const String& compiler) {
+  tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();
+  Optional<Array<String>> byoc_configs =
+      ctxt->GetConfig("relay.collage.byoc_fusion_style", Optional<Array<String>>());
+  if (!byoc_configs) return DefaultBYOCFusionStyleForCompiler(compiler);
+  // Each entry has the form "<compiler>.<style>"; the first entry naming compiler wins.
+  for (const String& entry : byoc_configs.value()) {
+    std::vector<std::string> byoc_cfg = SplitString(entry, ".");
+    ICHECK_EQ(byoc_cfg.size(), 2U) << "Expected '<compiler>.<style>' in "
+                                   << "relay.collage.byoc_fusion_style, got '" << entry << "'";
+    if (byoc_cfg[0] != compiler) {
+      continue;
+    }
+    if (byoc_cfg[1] == "NoFusion") {
+      return kNoFusionBYOCStyle;
+    } else if (byoc_cfg[1] == "TVMFusion") {
+      return kTVMFusionBYOCStyle;
+    } else if (byoc_cfg[1] == "ArbitraryFusion") {
+      return kArbitraryFusionBYOCStyle;
+    }
+    LOG(FATAL) << "Invalid fusion style '" << byoc_cfg[1] << "' for compiler " << compiler
+               << " in pass context";
+  }
+  return DefaultBYOCFusionStyleForCompiler(compiler);
+}
+
 /*!
  * \brief Returns the primitive combiner rules which allow for any touching candidates
  * to be fused provided they don't have kind \p kOpaque.
diff --git a/src/relay/collage/utils.cc b/src/relay/collage/utils.cc
index cad29c4f6e..451e18c219 100644
--- a/src/relay/collage/utils.cc
+++ b/src/relay/collage/utils.cc
@@ -134,6 +134,19 @@ bool MustBeLowered(const Expr& expr) {
   return false;
 }
 
+std::vector<std::string> SplitString(std::string stmt, const char* del) {
+  const std::string delimiter(del);
+  std::vector<std::string> str_tokens;
+  size_t start = 0;  // Offset of the current field; an empty del yields {stmt}.
+  size_t end = delimiter.empty() ? std::string::npos : stmt.find(delimiter, start);
+  for (; end != std::string::npos; end = stmt.find(delimiter, start)) {
+    str_tokens.emplace_back(stmt.substr(start, end - start));
+    start = end + delimiter.size();  // Skip the whole delimiter, not just one character.
+  }
+  str_tokens.emplace_back(stmt.substr(start));
+  return str_tokens;
+}
+
 }  // namespace collage
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/relay/collage/utils.h b/src/relay/collage/utils.h
index 4c0493cdd6..630b3b22f1 100644
--- a/src/relay/collage/utils.h
+++ b/src/relay/collage/utils.h
@@ -31,6 +31,7 @@
 #include <tvm/runtime/container/string.h>
 
 #include <string>
+#include <vector>
 
 namespace tvm {
 namespace relay {
@@ -79,6 +80,11 @@ bool IsSpecialOp(const OpNode* op_node);
  */
 bool MustBeLowered(const Expr& expr);
 
+/*!
+ * \brief Returns the list of split strings of given statement with delimiter.
+ */
+std::vector<std::string> SplitString(std::string stmt, const char* del);
+
 }  // namespace collage
 }  // namespace relay
 }  // namespace tvm
diff --git a/tests/python/contrib/test_clml/test_adreno_collage_targets.py b/tests/python/contrib/test_clml/test_adreno_collage_targets.py
new file mode 100644
index 0000000000..d08b76c3b5
--- /dev/null
+++ b/tests/python/contrib/test_clml/test_adreno_collage_targets.py
@@ -0,0 +1,354 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Compares Collage with various other baselines."""
+
+import tvm
+import logging
+import tempfile
+import os
+import shutil
+import numpy as np
+from tvm.relay import testing
+from tvm import rpc
+from tvm.contrib import utils, ndk
+from tvm.relay.build_module import bind_params_by_name
+
+# The following are necessary to force global functions or pattern tables to be registered
+from tvm.relay.collage.collage import *
+from tvm.relay.op.contrib import clml
+import pytest
+
+logging.basicConfig(level=logging.INFO)
+
+
+########### Configuration ###########
+
+###
+### TVM Opencl AutoTvm log file name
+###
+TUNING_LOG = ""
+
+###
+### If true, run all models
+###
+ALL_MODELS = False
+
+###
+### If true, run all configurations
+###
+ALL_CONFIGS = False
+
+###
+### How aggressively to look for candidates?
+###
+TVM_MAX_DEPTH = 8
+BYOC_MAX_DEPTH = 8
+
+###
+### AutoTVM tuning parameters.
+###
+AUTOTVM_NUM_TRIALS = 1024
+AUTOTVM_EARLY_STOPPING = 600
+TIMEOUT = 10
+MEASURE_NUMBER = tvm.relay.collage.MEASURE_NUMBER
+MEASURE_REPEAT = tvm.relay.collage.MEASURE_REPEAT
+WARMUP_MIN_REPEAT_MS = tvm.relay.collage.WARMUP_MIN_REPEAT_MS
+
+##
+## RPC Build configuration
+##
+HOST = tvm.target.Target("llvm -mtriple=arm64-linux-android")
+OPENCL = tvm.target.Target("opencl", HOST)
+RPC_TRACKER_HOST = os.getenv("TVM_TRACKER_HOST", "localhost")
+RPC_TRACKER_PORT = int(os.getenv("TVM_TRACKER_PORT", 9090))
+RPC_KEY = os.getenv("RPC_DEVICE_KEY", "android")
+NDK_CROSS_COMPILER = os.getenv("TVM_NDK_CC", "aarch64-linux-android-g++")
+
+
+########### AutoTVM tuning helpers ###########
+
+
+def extract_autotvm_tasks(mod, target):
+    """Returns TVM kernels to tune for mod and target."""
+    return tvm.autotvm.task.extract_from_program(mod, target=target, params=None)
+
+
+def optional_tuning_records(log_filename):
+    """Returns existing tuning records, if any."""
+    if log_filename == "" or not os.path.exists(log_filename):
+        return tvm.autotvm.task.FallbackContext()
+    else:
+        return tvm.autotvm.task.ApplyHistoryBest(log_filename)
+
+
+def is_already_tuned(task, log_filename):
+    """Returns True if we already have a tuning record for task in turning logs in log_filename"""
+    if not os.path.exists(log_filename):
+        return False
+
+    dispatch_context = tvm.autotvm.task.ApplyHistoryBest(log_filename)
+    return dispatch_context._query_inside(task.target, task.workload)
+
+
+def tune_autotvm_tasks(tasks, log_filename):
+    """Appends to log filename the best strategies for tasks"""
+    if len(tasks) == 0:
+        return
+
+    measure_option = tvm.autotvm.measure_option(
+        builder=tvm.autotvm.LocalBuilder(build_func=ndk.create_shared, timeout=15),
+        runner=tvm.autotvm.RPCRunner(
+            RPC_KEY, host=RPC_TRACKER_HOST, port=RPC_TRACKER_PORT, number=100, timeout=15
+        ),
+    )
+
+    logging.info(
+        f"Using autotvm tuning for {len(tasks)} tasks with {AUTOTVM_NUM_TRIALS} trials, logging to {log_filename}"
+    )
+
+    # create tmp log file, starting with contents from existing log file
+    tmp_log_filename = log_filename + ".tmp"
+    if os.path.exists(tmp_log_filename):
+        os.remove(tmp_log_filename)
+    if os.path.exists(log_filename):
+        logging.info(f"Copying existing log {log_filename} to {tmp_log_filename}")
+        shutil.copy(log_filename, tmp_log_filename)
+
+    for i, task in enumerate(reversed(tasks)):
+        prefix = "[Task %2d/%2d] " % (i + 1, len(tasks))
+        logging.info(f"Considering task {task.name} {prefix}")
+        if is_already_tuned(task, tmp_log_filename):
+            logging.info(f"Re-using existing record for {task.name}")
+            continue
+
+        logging.info(f"Using autotvm to tune {task.name}")
+        tuner_obj = tvm.autotvm.tuner.XGBTuner(task, loss_type="rank")
+        if os.path.exists(tmp_log_filename):
+            tuner_obj.load_history(tvm.autotvm.record.load_from_file(tmp_log_filename))
+
+        # do tuning
+        n_trial = min(AUTOTVM_NUM_TRIALS, len(task.config_space))
+        tuner_obj.tune(
+            n_trial=n_trial,
+            early_stopping=AUTOTVM_EARLY_STOPPING,
+            measure_option=measure_option,
+            callbacks=[
+                tvm.autotvm.callback.progress_bar(n_trial, prefix=prefix),
+                tvm.autotvm.callback.log_to_file(tmp_log_filename),
+            ],
+        )
+
+    # Pick best records and copy back to main log file
+    tvm.autotvm.record.pick_best(tmp_log_filename, log_filename)
+    os.remove(tmp_log_filename)
+
+    logging.info("Done with autotvm tuning")
+
+
+def autotvm_tune_module(mod, target, log_filename):
+    if log_filename == "":
+        logging.info("Not tuning with autotvm since disabled")
+        return
+    # Extract and tune any TVM kernels. BYOC partitions will have no tasks extracted.
+    logging.info("Extracting tasks from overall module")
+    tasks = extract_autotvm_tasks(mod, target)
+    logging.info(f"Auto-tuning {len(tasks)} tasks from overall module")
+    tune_autotvm_tasks(tasks, log_filename)
+
+
+########### Drivers ###########
+
+
+def compile_and_benchmark(label, model, targets, tmp_dir):
+    """Compile model for target and run it with profiling."""
+    logging.info(f"Compiling {model['name']} using {label} with {targets}...")
+    mod = model["mod"]
+    mod = clml.preprocess_for_clml(mod)
+    exe = tvm.relay.vm.compile(mod, target=targets, params=model["params"])
+    lib = exe.mod
+    lib_path = os.path.join(tmp_dir, "lib.so")
+    logging.info(f"Exporting library to {lib_path}...")
+    lib.export_library(lib_path, cc=NDK_CROSS_COMPILER)
+    tracker = rpc.connect_tracker(RPC_TRACKER_HOST, RPC_TRACKER_PORT)
+    remote = tracker.request(RPC_KEY, priority=0, session_timeout=600)
+    ctx = remote.cl(0)
+    remote_path = "lib.so"
+    remote.upload(lib_path, target=remote_path)
+    lib = remote.load_module(remote_path)
+    vm_factory = tvm.runtime.vm.VirtualMachine(lib, ctx)
+    args = {v.name_hint: arg_for(v.checked_type, ctx) for v in mod["main"].params}
+    logging.info(f"Benchmarking for {model['name']} generated by {label}...")
+    profile = vm_factory.benchmark(
+        ctx, repeat=MEASURE_REPEAT, number=MEASURE_NUMBER, min_repeat_ms=0, **args
+    )
+    logging.info(f"Benchmarked for {model['name']} generated by {label}: {profile}")
+    logging.info(f"RESULT: {label} | {model['name']} | {profile.median * 1e3}ms")
+
+
+# Custom cost function for Opencl RPC targets.
+@register_func("tvm.relay.collage.opencl_cost_estimator")
+def opencl_cost_estimator(mod, target):
+    mod = clml.preprocess_for_clml(mod) if "clml" == target.kind.name else mod
+    try:
+        # Build the module.
+        logging.info("Compiling module to estimate")
+        exe = tvm.relay.vm.compile(mod, target)
+    except RuntimeError as err:
+        # A build failure indicates the partition is not supported.
+        # eg trying to build an nn.batch_norm on GPU, which has no schedule since we assume it
+        # is only ever used with a tuple projection which is rewritten away.
+        logging.info("Assigning module infinite cost since unable to build: %s", err)
+        return math.inf
+
+    lib = exe.mod
+    tracker = rpc.connect_tracker(RPC_TRACKER_HOST, RPC_TRACKER_PORT)
+    remote = tracker.request(RPC_KEY, priority=0, session_timeout=600)
+    temp = utils.tempdir()
+    dso_binary = "dev_lib_cl.so"
+    dso_binary_path = temp.relpath(dso_binary)
+    ctx = remote.cl(0)
+    lib.export_library(dso_binary_path, cc=NDK_CROSS_COMPILER)
+    remote_path = dso_binary
+    remote.upload(dso_binary_path, target=remote_path)
+    lib = remote.load_module(remote_path)
+
+    vm_factory = tvm.runtime.vm.VirtualMachine(lib, ctx)
+    func_name = "main"
+    main_args = {v.name_hint: arg_for(v.checked_type, ctx) for v in mod[func_name].params}
+    cost = vm_factory.benchmark(
+        ctx, repeat=5, number=20, min_repeat_ms=0, func_name=func_name, **main_args
+    )
+    return cost.mean
+
+
+def collage(model):
+    """Run Collage over the OpenCL and CLML targets, then benchmark the partitioned model."""
+    logging.info(f"collage | {model['name']}")
+    logging.info("-------------- BEGIN ORIGINAL --------------")
+    logging.info(model["mod"])
+    logging.info("-------------- END ORIGINAL ----------------")
+    autotvm_tune_module(model["mod"], OPENCL, TUNING_LOG)
+    with optional_tuning_records(TUNING_LOG):
+        # Candidate targets: TVM's own OpenCL codegen and the CLML BYOC compiler.
+        targets = [OPENCL, tvm.target.Target("clml", HOST)]
+        # Scratch directory for the exported library (not removed afterwards).
+        tmp_dir = tempfile.mkdtemp()
+
+        # Choose the BYOC fusion style per compiler with entries "<compiler>.<style>", where
+        # <style> is NoFusion, TVMFusion or ArbitraryFusion; compilers without an entry keep
+        # their built-in default style.
+        pass_config = {
+            "relay.collage.tvm_max_depth": TVM_MAX_DEPTH,
+            "relay.collage.byoc_max_depth": BYOC_MAX_DEPTH,
+            "relay.collage.byoc_fusion_style": ["clml.NoFusion"],
+        }
+        logging.info(f"Using PassContext(config={pass_config})")
+        ctxt = tvm.transform.PassContext(config=pass_config)
+        config = tvm.target.make_compilation_config(ctxt, targets)
+        with ctxt:
+            mod = model["mod"]
+            mod = tvm.relay.transform.CapturePostDfsIndexInSpans()(mod)
+            logging.info("-------------- BEGIN INDEXED --------------")
+            logging.info(mod)
+            logging.info("-------------- END INDEXED ----------------")
+            # Price every candidate with the Python function registered under this name, which
+            # builds the candidate and benchmarks it on the remote OpenCL device over RPC.
+            cost_estimator = CustomCostEstimator(
+                py_fn_estimator="tvm.relay.collage.opencl_cost_estimator"
+            )
+            mod = tvm.relay.transform.CollagePartition(config, cost_estimator=cost_estimator)(mod)
+            partitioned_model = model.copy()
+            partitioned_model["mod"] = mod
+            logging.info("-------------- BEGIN PARTITIONED --------------")
+            logging.info(partitioned_model["mod"])
+            logging.info("-------------- END PARTITIONED ----------------")
+            compile_and_benchmark("collage", partitioned_model, targets, tmp_dir)
+
+
+def just_clml(model):
+    """Run partition_for_clml, complete the compilation with TVM, and profile the result."""
+    logging.info(f"just_clml | {model['name']}")
+    logging.info("-------------- BEGIN ORIGINAL --------------")
+    logging.info(model["mod"])
+    logging.info("-------------- END ORIGINAL ----------------")
+    tmp_dir = tempfile.mkdtemp()
+    with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
+        logging.info("Partitioning for CLML...")
+        mod = tvm.relay.op.contrib.clml.partition_for_clml(model["mod"], model["params"])
+        partitioned_model = model.copy()
+        partitioned_model["mod"] = mod
+        logging.info("-------------- BEGIN PARTITIONED --------------")
+        logging.info(partitioned_model["mod"])
+        logging.info("-------------- END PARTITIONED ----------------")
+        targets = []
+        targets.append(OPENCL)
+        targets.append(tvm.target.Target("clml", HOST))
+        compile_and_benchmark("just_clml", partitioned_model, targets, tmp_dir)
+
+
+def just_tvm(model):
+    """Compile and profile using vanilla TVM."""
+    logging.info(f"just_tvm | {model['name']}")
+    logging.info("-------------- BEGIN ORIGINAL --------------")
+    logging.info(model["mod"])
+    logging.info("-------------- END ORIGINAL ----------------")
+    tmp_dir = tempfile.mkdtemp()
+    autotvm_tune_module(model["mod"], OPENCL, TUNING_LOG)
+    with optional_tuning_records(TUNING_LOG):
+        with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
+            compile_and_benchmark("just_tvm", model, OPENCL, tmp_dir)
+
+
+def get_model(model_name, dtype):
+
+    if "mobilenet" in model_name:
+        mod, params = testing.mobilenet.get_workload(batch_size=1, dtype=dtype)
+    elif "resnet" in model_name:
+        mod, params = testing.resnet.get_workload(num_layers=50, batch_size=1, dtype=dtype)
+    if params:
+        mod["main"] = bind_params_by_name(mod["main"], params)
+        mod = tvm.relay.transform.FoldConstant()(mod)
+    return {
+        "name": model_name,
+        "input_shapes": {"data": [1, 3, 224, 224]},
+        "input_dtypes": {"data": dtype},
+        "mod": mod,
+        "params": params,
+        "main_dtype": dtype,
+    }
+
+
+########### Runners ###########
+@pytest.mark.parametrize("dtype", ["float32"])
+@tvm.testing.requires_openclml
+def run_resnet50(dtype):
+    """Benchmark ResNet-50 with just CLML, just TVM, and Collage over TVM + CLML targets."""
+    # NOTE(review): named run_* rather than test_*, so pytest does not collect it by default.
+    just_clml(get_model("resnet-50", dtype))
+    just_tvm(get_model("resnet-50", dtype))
+    collage(get_model("resnet-50", dtype))
+
+
+@pytest.mark.parametrize("dtype", ["float32"])
+@tvm.testing.requires_openclml
+def run_mobilenetv1(dtype):
+    """Benchmark MobileNet with just CLML, just TVM, and Collage over TVM + CLML targets."""
+    # NOTE(review): named run_* rather than test_*, so pytest does not collect it by default.
+    just_clml(get_model("mobilenet", dtype))
+    just_tvm(get_model("mobilenet", dtype))
+    collage(get_model("mobilenet", dtype))
diff --git a/tests/python/relay/collage/demo_collage_partitioner.py b/tests/python/relay/collage/demo_collage_partitioner.py
index 47f2612d7f..2c93145167 100644
--- a/tests/python/relay/collage/demo_collage_partitioner.py
+++ b/tests/python/relay/collage/demo_collage_partitioner.py
@@ -264,6 +264,12 @@ def collage(model):
         config = {
             "relay.collage.tvm_max_depth": TVM_MAX_DEPTH,
             "relay.collage.byoc_max_depth": BYOC_MAX_DEPTH,
+            "relay.collage.byoc_fusion_style": [
+                "cutlass.NoFusion",
+                "cublas.NoFusion",
+                "cudnn.NoFusion",
+                "tensorrt.TVMFusion",
+            ],
         }
         logging.info(f"Using PassContext(config={config}")
         ctxt = tvm.transform.PassContext(config=config)