Posted to commits@tvm.apache.org by le...@apache.org on 2022/11/10 13:49:27 UTC

[tvm] branch main updated: [TVMC] Global pass context for compile and tune (#13309)

This is an automated email from the ASF dual-hosted git repository.

leandron pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 23ade0c14b [TVMC] Global pass context for compile and tune (#13309)
23ade0c14b is described below

commit 23ade0c14b29c2c2710f6580035878f130eea52b
Author: Luke Hutton <lu...@arm.com>
AuthorDate: Thu Nov 10 13:49:20 2022 +0000

    [TVMC] Global pass context for compile and tune (#13309)
    
    * [TVMC] Global pass context for compile and tune
    
    Follows up on conversations in #13216. By making the pass
    context a global value for both the `compile` and `tune` commands, we
    can ensure the pass context is exactly as the user expects, and also
    test components such as `convert_graph_layout` under a pass context
    suitable for testing (e.g. one with added instruments). With this
    change, it becomes the user's responsibility to ensure the PassContext
    they select is suitable for the passes that will be run. By default,
    `opt_level` remains at 3, so current workflows that do not alter the
    pass context from the command line / TVMC Python API should not be
    affected, as sketched below.
    
    Change-Id: I7a601daf6fbe664f77bce1b45efeb7ca29f621b3
    
    * fix vitis-ai test and typo
    
    Change-Id: I04f5bd031ae4717825f42e373bcb0e1e2c1c9d90
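
A minimal sketch of the user-facing effect (the model path and option values
below are placeholders, not part of this commit): pass context options given
to the TVMC Python API now apply to every pass run during compilation,
including layout conversion and partitioning.

    import tvm
    from tvm.driver import tvmc

    # "model.onnx" is a placeholder path; the opt_level and pass config
    # values shown here are illustrative.
    tvmc_model = tvmc.load("model.onnx")
    package = tvmc.compile(
        tvmc_model,
        target="llvm",
        opt_level=3,  # now applied globally, including to layout conversion
        pass_context_configs=["relay.FuseOps.max_depth=4"],
    )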
---
 python/tvm/driver/tvmc/autotuner.py        | 197 +++++++++++++++--------------
 python/tvm/driver/tvmc/compiler.py         | 118 ++++++++---------
 python/tvm/driver/tvmc/transform.py        |  11 +-
 tests/python/driver/tvmc/conftest.py       |  16 +++
 tests/python/driver/tvmc/test_autotuner.py |  16 +++
 tests/python/driver/tvmc/test_compiler.py  |   8 +-
 tests/python/driver/tvmc/test_frontends.py |  17 ++-
 tests/python/driver/tvmc/test_transform.py |  65 ++++++----
 8 files changed, 247 insertions(+), 201 deletions(-)

diff --git a/python/tvm/driver/tvmc/autotuner.py b/python/tvm/driver/tvmc/autotuner.py
index f9ba427ffa..98293e596b 100644
--- a/python/tvm/driver/tvmc/autotuner.py
+++ b/python/tvm/driver/tvmc/autotuner.py
@@ -389,110 +389,115 @@ def tune_model(
     # model is fixed. For now, creating a clone avoids the issue.
     mod = deepcopy(tvmc_model.mod)
     params = tvmc_model.params
-    if tuning_records is None:
-        tuning_records = tvmc_model.default_tuning_records_path()
-
-    for codegen_from_cli in extra_targets:
-        codegen = composite_target.get_codegen_by_target(codegen_from_cli["name"])
-        partition_function = codegen["pass_pipeline"]
-        mod = partition_function(mod, params, **codegen_from_cli["opts"])
-
-    # min_repeat_ms should be:
-    # a. the value provided by the user, if any, or
-    # b. 0ms in case target is "cpu"; otherwise 1000ms
-    if min_repeat_ms is None:
-        min_repeat_ms = 0 if target.keys[0] == "cpu" else 1000
-        logger.info("Default --min-repeat-ms for this target is %s", min_repeat_ms)
-
-    if rpc_key:
-        if hostname is None or port is None:
-            raise TVMCException(
-                "You must provide a hostname and port to connect to a remote RPC device."
-            )
-        if isinstance(port, str):
-            port = int(port)
-
-        logger.info("Tuning will be performed on device %s at %s:%d.", rpc_key, hostname, port)
-
-        runner_ctor = auto_scheduler.RPCRunner if enable_autoscheduler else autotvm.RPCRunner
-        runner = runner_ctor(
-            key=rpc_key,
-            host=hostname,
-            port=port,
-            number=number,
-            repeat=repeat,
-            n_parallel=parallel,
-            timeout=timeout,
-            min_repeat_ms=min_repeat_ms,
-        )
-    else:
-        logger.info("Starting localhost tuning.")
-        runner_ctor = (
-            auto_scheduler.LocalRPCMeasureContext if enable_autoscheduler else autotvm.LocalRunner
-        )
-        local_server = runner_ctor(
-            number=number,
-            repeat=repeat,
-            timeout=timeout,
-            min_repeat_ms=min_repeat_ms,
-        )
 
-        # For autoscheduling on some devices, we need to maintain a LocalRPCMeasureContext object.
-        if enable_autoscheduler:
-            runner = local_server.runner
+    with tvm.transform.PassContext(opt_level=3):
+        if tuning_records is None:
+            tuning_records = tvmc_model.default_tuning_records_path()
+
+        for codegen_from_cli in extra_targets:
+            codegen = composite_target.get_codegen_by_target(codegen_from_cli["name"])
+            partition_function = codegen["pass_pipeline"]
+            mod = partition_function(mod, params, **codegen_from_cli["opts"])
+
+        # min_repeat_ms should be:
+        # a. the value provided by the user, if any, or
+        # b. 0ms in case target is "cpu"; otherwise 1000ms
+        if min_repeat_ms is None:
+            min_repeat_ms = 0 if target.keys[0] == "cpu" else 1000
+            logger.info("Default --min-repeat-ms for this target is %s", min_repeat_ms)
+
+        if rpc_key:
+            if hostname is None or port is None:
+                raise TVMCException(
+                    "You must provide a hostname and port to connect to a remote RPC device."
+                )
+            if isinstance(port, str):
+                port = int(port)
+
+            logger.info("Tuning will be performed on device %s at %s:%d.", rpc_key, hostname, port)
+
+            runner_ctor = auto_scheduler.RPCRunner if enable_autoscheduler else autotvm.RPCRunner
+            runner = runner_ctor(
+                key=rpc_key,
+                host=hostname,
+                port=port,
+                number=number,
+                repeat=repeat,
+                n_parallel=parallel,
+                timeout=timeout,
+                min_repeat_ms=min_repeat_ms,
+            )
         else:
-            runner = local_server
+            logger.info("Starting localhost tuning.")
+            runner_ctor = (
+                auto_scheduler.LocalRPCMeasureContext
+                if enable_autoscheduler
+                else autotvm.LocalRunner
+            )
+            local_server = runner_ctor(
+                number=number,
+                repeat=repeat,
+                timeout=timeout,
+                min_repeat_ms=min_repeat_ms,
+            )
 
-    if enable_autoscheduler:
+            # For autoscheduling on some devices, we need to maintain a
+            # LocalRPCMeasureContext object.
+            if enable_autoscheduler:
+                runner = local_server.runner
+            else:
+                runner = local_server
 
-        tasks, weights = autoscheduler_get_tuning_tasks(
-            mod=mod,
-            params=params,
-            target=target,
-            alter_layout=desired_layout,
-            hardware_params=hardware_params,
-            include_simple_tasks=include_simple_tasks,
-        )
+        if enable_autoscheduler:
 
-        # Create the autoscheduler tuning options
-        tuning_options = auto_scheduler.TuningOptions(
-            num_measure_trials=trials,
-            measure_callbacks=[auto_scheduler.RecordToFile(tuning_records)],
-            runner=runner,
-            early_stopping=early_stopping,
-        )
+            tasks, weights = autoscheduler_get_tuning_tasks(
+                mod=mod,
+                params=params,
+                target=target,
+                alter_layout=desired_layout,
+                hardware_params=hardware_params,
+                include_simple_tasks=include_simple_tasks,
+            )
+
+            # Create the autoscheduler tuning options
+            tuning_options = auto_scheduler.TuningOptions(
+                num_measure_trials=trials,
+                measure_callbacks=[auto_scheduler.RecordToFile(tuning_records)],
+                runner=runner,
+                early_stopping=early_stopping,
+            )
 
-        logger.info("Autoscheduling with configuration: %s", tuning_options)
+            logger.info("Autoscheduling with configuration: %s", tuning_options)
 
-        # Schedule the tasks (i.e., produce a schedule for each task)
-        schedule_tasks(tasks, weights, tuning_options, prior_records, log_estimated_latency)
-    else:
-        tasks = autotvm_get_tuning_tasks(
-            mod=mod,
-            params=params,
-            target=target,
-            alter_layout=desired_layout,
-        )
+            # Schedule the tasks (i.e., produce a schedule for each task)
+            schedule_tasks(tasks, weights, tuning_options, prior_records, log_estimated_latency)
+        else:
+            tasks = autotvm_get_tuning_tasks(
+                mod=mod,
+                params=params,
+                target=target,
+                alter_layout=desired_layout,
+            )
 
-        # In autotvm, trials is specified per task. We can convert the per-model input
-        # provided to per-task trials by dividing by the number of tasks.
-        trials = int(trials / max(len(tasks), 1))
-        logger.info("Autotuning with %d trials per task.", trials)
-
-        tuning_options = {
-            "tuner": tuner,
-            "trials": trials,
-            "early_stopping": early_stopping,
-            "measure_option": autotvm.measure_option(
-                builder=autotvm.LocalBuilder(build_func="default"), runner=runner
-            ),
-            "tuning_records": prior_records,
-        }
-        logger.info("Autotuning with configuration: %s", tuning_options)
-
-        tune_tasks(tasks, tuning_records, **tuning_options)
-
-    return tuning_records
+            # In autotvm, trials is specified per task. We can convert the per-model input
+            # provided to per-task trials by dividing by the number of tasks.
+            trials = int(trials / max(len(tasks), 1))
+            logger.info("Autotuning with %d trials per task.", trials)
+
+            tuning_options = {
+                "tuner": tuner,
+                "trials": trials,
+                "early_stopping": early_stopping,
+                "measure_option": autotvm.measure_option(
+                    builder=autotvm.LocalBuilder(build_func="default"), runner=runner
+                ),
+                "tuning_records": prior_records,
+            }
+            logger.info("Autotuning with configuration: %s", tuning_options)
+
+            tune_tasks(tasks, tuning_records, **tuning_options)
+
+        return tuning_records
 
 
 def autotvm_get_tuning_tasks(
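
With this change, a tuning run from the Python API likewise executes entirely
under the global pass context created inside `tune_model`. A minimal sketch
(the model path is a placeholder):

    from tvm.driver import tvmc

    # "model.onnx" is a placeholder; partitioning, task extraction and
    # tuning all now run under PassContext(opt_level=3).
    tvmc_model = tvmc.load("model.onnx")
    tuning_records = tvmc.tune(tvmc_model, target="llvm", trials=64)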
diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py
index c24d36c432..eec80820cd 100644
--- a/python/tvm/driver/tvmc/compiler.py
+++ b/python/tvm/driver/tvmc/compiler.py
@@ -292,39 +292,42 @@ def compile_model(
 
     config = parse_configs(pass_context_configs)
 
-    if desired_layout:
-        mod = convert_graph_layout(mod, desired_layout)
-
     tvm_target, extra_targets = target_from_cli(target, additional_target_options)
     tvm_target, target_host = Target.canon_target_and_host(tvm_target, target_host)
 
+    partition_functions = []
+    partition_opts = []
     for codegen_from_cli in extra_targets:
         codegen = composite_target.get_codegen_by_target(codegen_from_cli["name"])
-        partition_function = codegen["pass_pipeline"]
-
+        partition_functions.append(codegen["pass_pipeline"])
+        partition_opts.append(codegen_from_cli["opts"])
         if codegen["config_key"] is not None:
             config[codegen["config_key"]] = codegen_from_cli["opts"]
-        with tvm.transform.PassContext(config=config):
-            mod = partition_function(mod, params, mod_name=mod_name, **codegen_from_cli["opts"])
-
-    if tuning_records and os.path.exists(tuning_records):
-        logger.debug("tuning records file provided: %s", tuning_records)
-
-        use_autoscheduler = True
-        try:
-            auto_scheduler.load_records(tuning_records)
-        except tvm._ffi.base.TVMError:
-            use_autoscheduler = False
-
-        if use_autoscheduler:
-            with auto_scheduler.ApplyHistoryBest(tuning_records):
-                config["relay.backend.use_auto_scheduler"] = True
-                with tvm.transform.PassContext(
-                    opt_level=opt_level,
-                    config=config,
-                    disabled_pass=disabled_pass,
-                    instruments=instruments,
-                ):
+
+    with tvm.transform.PassContext(
+        opt_level=opt_level,
+        config=config,
+        disabled_pass=disabled_pass,
+        instruments=instruments,
+    ):
+        if desired_layout:
+            mod = convert_graph_layout(mod, desired_layout)
+
+        for partition_function, opts in zip(partition_functions, partition_opts):
+            mod = partition_function(mod, params, mod_name=mod_name, **opts)
+
+        if tuning_records and os.path.exists(tuning_records):
+            logger.debug("tuning records file provided: %s", tuning_records)
+
+            use_autoscheduler = True
+            try:
+                auto_scheduler.load_records(tuning_records)
+            except tvm._ffi.base.TVMError:
+                use_autoscheduler = False
+
+            if use_autoscheduler:
+                with auto_scheduler.ApplyHistoryBest(tuning_records):
+                    config["relay.backend.use_auto_scheduler"] = True
                     logger.debug("building relay graph with autoscheduler")
                     graph_module = build(
                         mod,
@@ -336,14 +339,8 @@ def compile_model(
                         mod_name=mod_name,
                         workspace_pools=workspace_pools,
                     )
-        else:
-            with autotvm.apply_history_best(tuning_records):
-                with tvm.transform.PassContext(
-                    opt_level=opt_level,
-                    config=config,
-                    disabled_pass=disabled_pass,
-                    instruments=instruments,
-                ):
+            else:
+                with autotvm.apply_history_best(tuning_records):
                     logger.debug("building relay graph with tuning records")
                     graph_module = build(
                         mod,
@@ -355,10 +352,7 @@ def compile_model(
                         mod_name=mod_name,
                         workspace_pools=workspace_pools,
                     )
-    else:
-        with tvm.transform.PassContext(
-            opt_level=opt_level, config=config, disabled_pass=disabled_pass, instruments=instruments
-        ):
+        else:
             logger.debug("building relay graph (no tuning records provided)")
             graph_module = build(
                 mod,
@@ -371,32 +365,32 @@ def compile_model(
                 workspace_pools=workspace_pools,
             )
 
-    # Generate output dump files with sources
-    if dump_code is None:
-        dump_code = []
-    if not isinstance(dump_code, list):
-        dump_code = [dump_code]
-    dumps = {}
-    for source_type in dump_code:
-        if use_vm:
-            lib = graph_module.lib
-        else:
-            lib = graph_module.get_lib()
-        # TODO lib.get_source calls have inconsistent behavior for unsupported
-        #      formats (@leandron).
-        source = str(mod) if source_type == "relay" else lib.get_source(source_type)
-        dumps[source_type] = source
-
-    # Create a new tvmc model package object from the graph definition.
-    package_path = tvmc_model.export_package(
-        graph_module, package_path, cross, cross_options, output_format
-    )
+        # Generate output dump files with sources
+        if dump_code is None:
+            dump_code = []
+        if not isinstance(dump_code, list):
+            dump_code = [dump_code]
+        dumps = {}
+        for source_type in dump_code:
+            if use_vm:
+                lib = graph_module.lib
+            else:
+                lib = graph_module.get_lib()
+            # TODO lib.get_source calls have inconsistent behavior for unsupported
+            #      formats (@leandron).
+            source = str(mod) if source_type == "relay" else lib.get_source(source_type)
+            dumps[source_type] = source
+
+        # Create a new tvmc model package object from the graph definition.
+        package_path = tvmc_model.export_package(
+            graph_module, package_path, cross, cross_options, output_format
+        )
 
-    # Write dumps to file.
-    if dumps:
-        save_dumps(package_path, dumps)
+        # Write dumps to file.
+        if dumps:
+            save_dumps(package_path, dumps)
 
-    return TVMCPackage(package_path)
+        return TVMCPackage(package_path)
 
 
 def build(
diff --git a/python/tvm/driver/tvmc/transform.py b/python/tvm/driver/tvmc/transform.py
index 51c9e52f21..8527c48b6b 100644
--- a/python/tvm/driver/tvmc/transform.py
+++ b/python/tvm/driver/tvmc/transform.py
@@ -54,10 +54,7 @@ def convert_graph_layout(mod, desired_layout):
         ]
     )
 
-    with transform.PassContext(opt_level=3):
-        try:
-            return seq(mod)
-        except Exception as err:
-            raise TVMCException(
-                "Error converting layout to {0}: {1}".format(desired_layout, str(err))
-            )
+    try:
+        return seq(mod)
+    except Exception as err:
+        raise TVMCException("Error converting layout to {0}: {1}".format(desired_layout, str(err)))
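
Since `convert_graph_layout` no longer opens a pass context of its own, callers
must provide one, as the updated tests below do. A minimal sketch (assuming
`mod` is an existing Relay IRModule):

    import tvm
    from tvm.driver.tvmc.transform import convert_graph_layout

    # the caller now owns the pass context; `mod` is assumed to be a
    # Relay IRModule, e.g. tvmc.load("model.onnx").mod
    with tvm.transform.PassContext(opt_level=3):
        mod = convert_graph_layout(mod, "NHWC")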
diff --git a/tests/python/driver/tvmc/conftest.py b/tests/python/driver/tvmc/conftest.py
index 8009448bff..e0dbeebf98 100644
--- a/tests/python/driver/tvmc/conftest.py
+++ b/tests/python/driver/tvmc/conftest.py
@@ -23,6 +23,8 @@ import numpy as np
 
 from PIL import Image
 
+import tvm
+from tvm import relay
 from tvm.driver import tvmc
 
 from tvm.contrib.download import download_testdata
@@ -284,3 +286,17 @@ def relay_text_conv2d(tmpdir_factory):
     with open(file_path, "w") as relay_text:
         relay_text.write(RELAY_MODEL)
     return file_path
+
+
+@pytest.fixture(scope="session")
+def relay_conv2d():
+    """
+    Simple conv2d Relay implementation.
+    """
+    dtype = "float32"
+
+    x = relay.var("x", shape=(1, 4, 2, 2), dtype=dtype)
+    weight = relay.const(np.random.uniform(size=(2, 4, 2, 2)), dtype=dtype)
+    x = relay.nn.conv2d(x, weight)
+    func = relay.Function(relay.analysis.free_vars(x), x)
+    return tvm.IRModule.from_expr(func)
diff --git a/tests/python/driver/tvmc/test_autotuner.py b/tests/python/driver/tvmc/test_autotuner.py
index 7c05ff804f..eb6550e40c 100644
--- a/tests/python/driver/tvmc/test_autotuner.py
+++ b/tests/python/driver/tvmc/test_autotuner.py
@@ -23,6 +23,7 @@ from unittest import mock
 from os import path
 from pathlib import Path
 
+import tvm
 from tvm import autotvm
 from tvm.driver import tvmc
 
@@ -191,3 +192,18 @@ def test_tune_rpc_tracker_parsing(mock_load_model, mock_tune_model, mock_auto_sc
     assert "10.0.0.1" == kwargs["hostname"]
     assert "port" in kwargs
     assert 9999 == kwargs["port"]
+
+
+@mock.patch("tvm.transform.PassContext", return_value=tvm.transform.PassContext())
+def test_autotune_pass_context(mock_pc, onnx_mnist, tmpdir_factory):
+    """
+    Check that the pass context used while tuning is as expected.
+    """
+    pytest.importorskip("onnx")
+
+    tmpdir_name = tmpdir_factory.mktemp("data")
+    _tuner_test_helper(onnx_mnist, "gridsearch", tmpdir_name)
+
+    # AutoTVM overrides the pass context later in the pipeline to disable AlterOpLayout
+    assert mock_pc.call_count == 2
+    assert mock_pc.call_args_list[0][1]["opt_level"] == 3
diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py
index 7cb50dd0e3..3a3f297729 100644
--- a/tests/python/driver/tvmc/test_compiler.py
+++ b/tests/python/driver/tvmc/test_compiler.py
@@ -508,10 +508,7 @@ def test_compile_check_configs_composite_target(mock_pkg, mock_pc, mock_fe, mock
     tvmc_model = tvmc.load("no_file_needed")
     tvmc.compile(tvmc_model, target="mockcodegen -testopt=value, llvm")
 
-    assert mock_pc.call_count == 2
-    codegen_partition_context = mock.call(
-        config={"relay.ext.mock.options": {"testopt": "value"}},
-    )
+    assert mock_pc.call_count == 1
     codegen_compile_context = mock.call(
         config={"relay.ext.mock.options": {"testopt": "value"}},
         opt_level=3,
@@ -520,9 +517,6 @@ def test_compile_check_configs_composite_target(mock_pkg, mock_pc, mock_fe, mock
     )
     mock_pc.assert_has_calls(
         [
-            codegen_partition_context,
-            codegen_partition_context.__enter__(),
-            codegen_partition_context.__exit__(None, None, None),
             codegen_compile_context,
             codegen_compile_context.__enter__(),
             codegen_compile_context.__exit__(None, None, None),
diff --git a/tests/python/driver/tvmc/test_frontends.py b/tests/python/driver/tvmc/test_frontends.py
index c1a3be67c2..718babd15c 100644
--- a/tests/python/driver/tvmc/test_frontends.py
+++ b/tests/python/driver/tvmc/test_frontends.py
@@ -297,7 +297,8 @@ def test_compile_tflite_module_nhwc_to_nchw(tflite_mobilenet_v1_1_quant):
     before = tvmc_model.mod
 
     expected_layout = "NCHW"
-    after = tvmc.transform.convert_graph_layout(before, expected_layout)
+    with tvm.transform.PassContext(opt_level=3):
+        after = tvmc.transform.convert_graph_layout(before, expected_layout)
 
     layout_transform_calls = []
 
@@ -322,7 +323,8 @@ def test_compile_onnx_module_nchw_to_nhwc(onnx_resnet50):
     before = tvmc_model.mod
 
     expected_layout = "NHWC"
-    after = tvmc.transform.convert_graph_layout(before, expected_layout)
+    with tvm.transform.PassContext(opt_level=3):
+        after = tvmc.transform.convert_graph_layout(before, expected_layout)
 
     layout_transform_calls = []
 
@@ -347,7 +349,8 @@ def test_compile_paddle_module_nchw_to_nhwc(paddle_resnet50):
     before = tvmc_model.mod
 
     expected_layout = "NHWC"
-    after = tvmc.transform.convert_graph_layout(before, expected_layout)
+    with tvm.transform.PassContext(opt_level=3):
+        after = tvmc.transform.convert_graph_layout(before, expected_layout)
 
     layout_transform_calls = []
 
@@ -372,7 +375,9 @@ def test_compile_tflite_module__same_layout__nhwc_to_nhwc(tflite_mobilenet_v1_1_
     before = tvmc_model.mod
 
     expected_layout = "NHWC"
-    after = tvmc.transform.convert_graph_layout(before, expected_layout)
+
+    with tvm.transform.PassContext(opt_level=3):
+        after = tvmc.transform.convert_graph_layout(before, expected_layout)
 
     layout_transform_calls = []
 
@@ -397,7 +402,9 @@ def test_compile_onnx_module__same_layout__nchw_to_nchw(onnx_resnet50):
     before = tvmc_model.mod
 
     expected_layout = "NCHW"
-    after = tvmc.transform.convert_graph_layout(before, expected_layout)
+
+    with tvm.transform.PassContext(opt_level=3):
+        after = tvmc.transform.convert_graph_layout(before, expected_layout)
 
     layout_transform_calls = []
 
diff --git a/tests/python/driver/tvmc/test_transform.py b/tests/python/driver/tvmc/test_transform.py
index 98a0210a1b..98bd3b5f98 100644
--- a/tests/python/driver/tvmc/test_transform.py
+++ b/tests/python/driver/tvmc/test_transform.py
@@ -14,43 +14,60 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-import pytest
-import numpy as np
+
+from unittest.mock import MagicMock
 
 import tvm
 from tvm import relay
+from tvm.ir.instrument import pass_instrument
 from tvm.driver.tvmc.transform import convert_graph_layout
 
 
-def test_layout_transform():
+def test_layout_transform_fold_constant(relay_conv2d):
     """
     Test layout is correctly transformed and constant folding is applied.
     """
-    dtype = "int8"
-    iinfo = np.iinfo(dtype)
-    data_min = iinfo.min
-    data_max = iinfo.max
-
-    x = relay.var("x", shape=(1, 4, 2, 2), dtype=dtype)
-    weight = relay.const(
-        np.random.randint(data_min, data_max, size=(2, 4, 2, 2), dtype=dtype), dtype=dtype
-    )
-    x = relay.nn.conv2d(x, weight)
-    func = relay.Function(relay.analysis.free_vars(x), x)
-    mod = tvm.IRModule.from_expr(func)
+    desired_layout = "NHWC"
+
+    @pass_instrument
+    class CollectPassNames:
+        def __init__(self):
+            self.names = []
+
+        def run_after_pass(self, _, info):
+            self.names.append(info.name)
+
+    pass_names = CollectPassNames()
+    with tvm.transform.PassContext(opt_level=3, instruments=[pass_names]):
+        convert_graph_layout(relay_conv2d, desired_layout)
 
+    names = pass_names.names
+    assert "ConvertLayout" in names
+    assert "FoldConstant" in names
+    assert names.index("ConvertLayout") < names.index("FoldConstant")
+
+
+def test_layout_transform_convert_layout_pass_args(relay_conv2d, monkeypatch):
+    """
+    Check that the desired layouts argument passed to ConvertLayout is as
+    expected when a desired layout is provided.
+    """
     desired_layout = "NHWC"
-    mod = convert_graph_layout(mod, desired_layout)
 
-    main_expr = mod["main"].body
-    conv = main_expr.args[0]
-    assert conv.op.name == "nn.conv2d"
-    assert conv.attrs["data_layout"] == "NHWC"
-    assert conv.attrs["kernel_layout"] == "HWIO"
+    mock_convert_layout = MagicMock()
+    mock_convert_layout.return_value = relay.transform.ConvertLayout({})
+    monkeypatch.setattr(relay.transform, "ConvertLayout", mock_convert_layout)
+
+    with tvm.transform.PassContext(opt_level=3):
+        convert_graph_layout(relay_conv2d, desired_layout)
 
-    # Ensure transform has been folded into the constant
-    weights = conv.args[1]
-    assert isinstance(weights, relay.expr.Constant)
+    mock_convert_layout.assert_called_once_with(
+        {
+            "nn.conv2d": ["NHWC", "default"],
+            "nn.conv2d_transpose": ["NHWC", "default"],
+            "qnn.conv2d": ["NHWC", "default"],
+        }
+    )
 
 
 if __name__ == "__main__":