You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by le...@apache.org on 2022/07/25 09:43:05 UTC

[tvm] branch main updated: [TVMC] Workspace Pools Parameters (#11427)

This is an automated email from the ASF dual-hosted git repository.

leandron pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 75ec1cffa9 [TVMC] Workspace Pools Parameters (#11427)
75ec1cffa9 is described below

commit 75ec1cffa9f160dd3165fedbf4408731ebfa797a
Author: Dhruv Chauhan <89...@users.noreply.github.com>
AuthorDate: Mon Jul 25 10:43:00 2022 +0100

    [TVMC] Workspace Pools Parameters (#11427)
    
    * [TVMC] Workspace Pools Parameters
    
    Attributes from tvmc can now be passed into the WorkspacePoolInfo objects
    created inside WorkspaceMemoryPools. The resulting WorkspaceMemoryPools
    object is passed to relay.build, which attaches it to the IRModule as an attribute.
    
    * [TVMC] Workspace Pools Parameters
    
    Address review comments and fix linting. Improve testing.
    Change-Id: Iea79329b6b9ec1cbc51e5c293449bf6dd43b00c5
    
    * [TVMC] Workspace Pools Parameters
    
    Update workspace pools test naming
    Change-Id: Ib698d6248be1e6f44340f27db3641c985bc5c5d8
    
    * [TVMC] Workspace Pools Parameters
    
    Add test for parameter overrides.
    
    Change-Id: I67d5470dcfbfbc9ab27f34e20a9269d2070193ca
    
    * [TVMC] Workspace Pools Parameters
    
    Rebase over #10189 and update the way a WorkspaceMemoryPools object is
    created.
    Change-Id: I1f0e1d240343af311ddb3ed5c564cc1ab329f463
    
    * [TVMC] Workspace Pools Parameters
    
    Fix linting, fix CI
    Change-Id: If75f8709ac4ad925655eca54b3e5c1bb09d025e8
    
    * [TVMC] Workspace Pools Parameters
    
    Add mcpu and mattr to target registry for cmsis-nn
    Change-Id: I15257b8d01624c071c738cab6d12ecb84ed6cb16
    
    * [TVMC] Workspace Pools Parameters
    
    Add a test for overriding a single pool when multiple pools are present.
    Update how multiple attributes are parsed.
    Change-Id: I2c0745051b7a923dd7f75040bfb89bbc99376a11
---
 include/tvm/ir/memory_pools.h                    |   1 +
 python/tvm/driver/tvmc/compiler.py               |  30 +-
 python/tvm/driver/tvmc/workspace_pools.py        | 237 +++++++++++++
 python/tvm/ir/memory_pools.py                    |   2 +-
 src/relay/backend/contrib/cmsisnn/target.cc      |   4 +-
 tests/python/driver/tvmc/test_command_line.py    |  22 ++
 tests/python/driver/tvmc/test_compiler.py        |  22 ++
 tests/python/driver/tvmc/test_workspace_pools.py | 404 +++++++++++++++++++++++
 8 files changed, 717 insertions(+), 5 deletions(-)

diff --git a/include/tvm/ir/memory_pools.h b/include/tvm/ir/memory_pools.h
index ee07841de4..ebab13cf3a 100644
--- a/include/tvm/ir/memory_pools.h
+++ b/include/tvm/ir/memory_pools.h
@@ -65,6 +65,7 @@ struct PoolInfoNode : public Object {
 
   void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("pool_name", &pool_name);
+    v->Visit("targets", &targets);
     v->Visit("size_hint_bytes", &size_hint_bytes);
     v->Visit("clock_frequency_hz", &clock_frequency_hz);
     v->Visit("read_bandwidth_bytes_per_cycle", &read_bandwidth_bytes_per_cycle);
diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py
index 1385044704..2955df5543 100644
--- a/python/tvm/driver/tvmc/compiler.py
+++ b/python/tvm/driver/tvmc/compiler.py
@@ -26,6 +26,7 @@ import tvm
 from tvm import autotvm, auto_scheduler
 from tvm import relay
 from tvm.driver.tvmc.registry import generate_registry_args, reconstruct_registry_entity
+from tvm.ir.memory_pools import WorkspaceMemoryPools
 from tvm.target import Target
 from tvm.relay.backend import Executor, Runtime
 
@@ -37,6 +38,7 @@ from .pass_config import parse_configs
 from .pass_list import parse_pass_list_str
 from .transform import convert_graph_layout
 from .shape_parser import parse_shape_string
+from .workspace_pools import generate_workspace_pools_args, workspace_pools_recombobulate
 
 # pylint: disable=invalid-name
 logger = logging.getLogger("TVMC")
@@ -142,10 +144,11 @@ def add_compile_parser(subparsers, _, json_params):
         default="default",
         help="The output module name. Defaults to 'default'.",
     )
-
     for one_entry in json_params:
         parser.set_defaults(**one_entry)
 
+    generate_workspace_pools_args(parser)
+
 
 def drive_compile(args):
     """Invoke tvmc.compiler module with command line arguments
@@ -161,6 +164,7 @@ def drive_compile(args):
         Zero if successfully completed
 
     """
+
     if not os.path.isfile(args.FILE):
         raise TVMCException(
             f"Input file '{args.FILE}' doesn't exist, is a broken symbolic link, or a directory."
@@ -170,6 +174,9 @@ def drive_compile(args):
 
     dump_code = [x.strip() for x in args.dump_code.split(",")] if args.dump_code else None
 
+    additional_targets = reconstruct_target_args(args)
+    workspace_pools_target, extra_targets = target_from_cli(args.target, additional_targets)
+
     compile_model(
         tvmc_model,
         args.target,
@@ -186,8 +193,11 @@ def drive_compile(args):
         desired_layout=args.desired_layout,
         disabled_pass=args.disabled_pass,
         pass_context_configs=args.pass_config,
-        additional_target_options=reconstruct_target_args(args),
         mod_name=args.module_name,
+        additional_target_options=additional_targets,
+        workspace_pools=(
+            workspace_pools_recombobulate(args, [workspace_pools_target], extra_targets)
+        ),
     )
 
     return 0
@@ -212,6 +222,7 @@ def compile_model(
     additional_target_options: Optional[Dict[str, Dict[str, Any]]] = None,
     use_vm: bool = False,
     mod_name: Optional[str] = "default",
+    workspace_pools: Optional[WorkspaceMemoryPools] = None,
 ):
     """Compile a model from a supported framework into a TVM module.
 
@@ -263,6 +274,9 @@ def compile_model(
         Whether to use the VM to compile the model as opposed to the graph executor
     mod_name: str, optional
         The module name
+    workspace_pools: WorkspaceMemoryPools, optional
+        Specification of WorkspacePoolInfo objects to be used as workspace memory in the
+        compilation.
 
     Returns
     -------
@@ -313,6 +327,7 @@ def compile_model(
                         params=params,
                         use_vm=use_vm,
                         mod_name=mod_name,
+                        workspace_pools=workspace_pools,
                     )
         else:
             with autotvm.apply_history_best(tuning_records):
@@ -328,6 +343,7 @@ def compile_model(
                         params=params,
                         use_vm=use_vm,
                         mod_name=mod_name,
+                        workspace_pools=workspace_pools,
                     )
     else:
         with tvm.transform.PassContext(
@@ -342,6 +358,7 @@ def compile_model(
                 params=params,
                 use_vm=use_vm,
                 mod_name=mod_name,
+                workspace_pools=workspace_pools,
             )
 
     # Generate output dump files with sources
@@ -380,6 +397,7 @@ def build(
     params: Dict[str, tvm.nd.NDArray],
     use_vm: bool,
     mod_name: str,
+    workspace_pools: Optional[WorkspaceMemoryPools],
 ):
     """
     Builds the model with the provided executor.
@@ -408,7 +426,13 @@ def build(
         return relay.vm.compile(mod, target=tvm_target, params=params)
     logger.debug("building with relay build")
     return relay.build(
-        mod, target=tvm_target, executor=executor, runtime=runtime, params=params, mod_name=mod_name
+        mod,
+        target=tvm_target,
+        executor=executor,
+        runtime=runtime,
+        params=params,
+        mod_name=mod_name,
+        workspace_memory_pools=workspace_pools,
     )
 
 
diff --git a/python/tvm/driver/tvmc/workspace_pools.py b/python/tvm/driver/tvmc/workspace_pools.py
new file mode 100644
index 0000000000..2c91488fb4
--- /dev/null
+++ b/python/tvm/driver/tvmc/workspace_pools.py
@@ -0,0 +1,237 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Functions for processing dynamic workspace pool TVMC args
+"""
+
+
+import logging
+import re
+
+from tvm.driver.tvmc import TVMCException
+from tvm.target import Target
+from tvm.ir.memory_pools import PoolInfoProperties, WorkspaceMemoryPools, WorkspacePoolInfo
+
+
+# pylint: disable=invalid-name
+logger = logging.getLogger("TVMC")
+
+
+def generate_workspace_pools_args(parser):
+    """Generates arguments for each Workspace Pools's options"""
+    parser.add_argument(
+        "--workspace-pools",
+        help="""The name of the memory pool
+                Example usage: --workspace-pools=flash""",
+    )
+    parser.add_argument(
+        "--workspace-pools-targets",
+        help="""The name of the targets specified for the memory pool
+                Example usage: --workspace-pools-targets=flash:llvm""",
+        action="append",
+    )
+    parser.add_argument(
+        "--workspace-pools-size-hint-bytes",
+        nargs="?",
+        help="""The expected size hint to be used by the allocator.
+                Example usage: --workspace-pools-size-hint-bytes=flash:8""",
+        action="append",
+    )
+    parser.add_argument(
+        "--workspace-pools-clock-frequency-hz",
+        nargs="?",
+        help="""The clock frequency that the memory pool runs at in Hz.
+                Example usage: --workspace-pools-clock-frequency-hz=flash:70000000""",
+        action="append",
+    )
+    parser.add_argument(
+        "--workspace-pools-read-bandwidth-bytes-per-cycle",
+        nargs="?",
+        help="""The read bandwidth of the memory pool in bytes/cycle.
+                Example usage: --workspace-pools-read-bandwidth-bytes-per-cycle=flash:4""",
+        action="append",
+    )
+    parser.add_argument(
+        "--workspace-pools-write-bandwidth-bytes-per-cycle",
+        nargs="?",
+        help="""The write bandwidth of the memory pool in bytes/cycle.
+                Example usage: --workspace-pools-write-bandwidth-bytes-per-cycle=flash:8""",
+        action="append",
+    )
+    parser.add_argument(
+        "--workspace-pools-read-latency-cycles",
+        nargs="?",
+        help="""The read latency of the memory pool in cycles.
+                Example usage: --workspace-pools-read-latency-cycles=flash:4""",
+        action="append",
+    )
+    parser.add_argument(
+        "--workspace-pools-write-latency-cycles",
+        nargs="?",
+        help="""The write latency of the memory pool in cycles.
+                Example usage: --workspace-pools-write-latency-cycles=flash:8""",
+        action="append",
+    )
+    parser.add_argument(
+        "--workspace-pools-target-burst-bytes",
+        help="""The burst length of the memory pool in bytes per target.
+                Example usage: --workspace-pools-target-burst-bytes=flash:accel:1""",
+        action="append",
+    )
+
+
+def _parse_target_burst(attr_str, pool_name):
+    if pool_name not in attr_str:
+        return {}
+
+    return {target: int(attr_str[pool_name][target]) for target in attr_str[pool_name]}
+
+
+def _parse_target_string(attr_str, targets, pool_name):
+    if attr_str is None:
+        raise TVMCException(f'No target specified for Workspace Pool "{pool_name}"')
+
+    target_name = [re.split(",", attr_str)]
+    matched_targets = [
+        target
+        for target in targets
+        if any(target.kind.name in target_string_match for target_string_match in target_name[0])
+    ]
+    if not matched_targets:
+        raise TVMCException(f'Workspace Pool "{pool_name}" using undefined Target "{target_name}"')
+    return matched_targets
+
+
+def _split_pools_to_pool_names(attr_str):
+    return re.split(",", attr_str) if attr_str else []
+
+
+def _parse_target_attributes_of_pool_name(attr_str, targets):
+    if not targets or attr_str is None:
+        return {}
+
+    target_attributes = {}
+    for pool_values in attr_str:
+        pool_name, target_name, target_value = re.split(":", pool_values)
+        if pool_name not in target_attributes:
+            target_attributes[pool_name] = {}
+
+        matched_targets = [target for target in targets if target_name == target.kind.name]
+        if matched_targets:
+            target_attributes[pool_name][matched_targets[0]] = target_value
+        else:
+            raise TVMCException(
+                "The workspace pool target specification "
+                "needs to contain a subset of the same TVM "
+                "targets as when specifying targets to use."
+            )
+    return target_attributes
+
+
+def _parse_attribute_of_pool_name(attr_str):
+    return dict(pool.split(":", maxsplit=1) for pool in attr_str) if attr_str else {}
+
+
+def workspace_pools_recombobulate(parsed, targets, extra_target):
+    """Reconstructs the Workspace Pools args and returns a WorkspaceMemoryPool object"""
+    WORKSPACE_POOL_PARAMS = [
+        "workspace_pools_size_hint_bytes",
+        "workspace_pools_targets",
+        "workspace_pools_clock_frequency_hz",
+        "workspace_pools_read_bandwidth_bytes_per_cycle",
+        "workspace_pools_write_bandwidth_bytes_per_cycle",
+        "workspace_pools_read_latency_cycles",
+        "workspace_pools_write_latency_cycles",
+    ]
+    WORKSPACE_POOL_TARGET_PARAMS = [
+        "workspace_pools_target_burst_bytes",
+    ]
+
+    # Load extra targets from CLI
+    additional_targets = []
+
+    for t in extra_target:
+        additional_targets.append(Target(t["raw"], host=targets[0].host or targets[0]))
+
+    target = targets + additional_targets
+    if targets[0].host:
+        target.append(targets[0].host)
+
+    workspace_pools = _split_pools_to_pool_names(parsed.workspace_pools)
+    if not workspace_pools:
+        return None
+
+    parse_attribute_to_pool_name = {
+        workspace_pool_param: _parse_attribute_of_pool_name(getattr(parsed, workspace_pool_param))
+        for workspace_pool_param in WORKSPACE_POOL_PARAMS
+    }
+    parse_target_burst_bytes_to_pool = {
+        workspace_pool_param: _parse_target_attributes_of_pool_name(
+            getattr(parsed, workspace_pool_param), targets
+        )
+        for workspace_pool_param in WORKSPACE_POOL_TARGET_PARAMS
+    }
+
+    return WorkspaceMemoryPools(
+        [
+            WorkspacePoolInfo(
+                pool_name,
+                targets=_parse_target_string(
+                    parse_attribute_to_pool_name["workspace_pools_targets"].get(pool_name),
+                    target,
+                    pool_name,
+                ),
+                pool_info_properties=PoolInfoProperties(
+                    size_hint_bytes=int(
+                        parse_attribute_to_pool_name["workspace_pools_size_hint_bytes"].get(
+                            pool_name, -1
+                        )
+                    ),
+                    clock_frequency_hz=int(
+                        parse_attribute_to_pool_name["workspace_pools_clock_frequency_hz"].get(
+                            pool_name, -1
+                        )
+                    ),
+                    read_bandwidth_bytes_per_cycle=int(
+                        parse_attribute_to_pool_name[
+                            "workspace_pools_read_bandwidth_bytes_per_cycle"
+                        ].get(pool_name, -1)
+                    ),
+                    write_bandwidth_bytes_per_cycle=int(
+                        parse_attribute_to_pool_name[
+                            "workspace_pools_write_bandwidth_bytes_per_cycle"
+                        ].get(pool_name, -1)
+                    ),
+                    read_latency_cycles=int(
+                        parse_attribute_to_pool_name["workspace_pools_read_latency_cycles"].get(
+                            pool_name, 0
+                        )
+                    ),
+                    write_latency_cycles=int(
+                        parse_attribute_to_pool_name["workspace_pools_write_latency_cycles"].get(
+                            pool_name, 0
+                        )
+                    ),
+                    target_burst_bytes=_parse_target_burst(
+                        parse_target_burst_bytes_to_pool["workspace_pools_target_burst_bytes"],
+                        pool_name,
+                    ),
+                ),
+            )
+            for pool_name in workspace_pools
+        ]
+    )
diff --git a/python/tvm/ir/memory_pools.py b/python/tvm/ir/memory_pools.py
index 0186a89f84..553bb49e3c 100644
--- a/python/tvm/ir/memory_pools.py
+++ b/python/tvm/ir/memory_pools.py
@@ -189,7 +189,7 @@ class WorkspaceMemoryPools(Object):
 
     def __init__(
         self,
-        pools: List[PoolInfo],
+        pools: List[WorkspacePoolInfo],
     ):
         self.__init_handle_by_constructor__(
             _ffi_api.WorkspaceMemoryPools, pools  # type: ignore # pylint: disable=no-member
diff --git a/src/relay/backend/contrib/cmsisnn/target.cc b/src/relay/backend/contrib/cmsisnn/target.cc
index fd2f18aa99..9a238fba3b 100644
--- a/src/relay/backend/contrib/cmsisnn/target.cc
+++ b/src/relay/backend/contrib/cmsisnn/target.cc
@@ -32,7 +32,9 @@ runtime::Module TIRToRuntime(IRModule mod, Target target);
 
 TVM_REGISTER_TARGET_KIND("cmsis-nn", kDLCPU)
     .set_attr<FTVMRelayToTIR>(tvm::attr::kRelayToTIR, RelayToTIR())
-    .set_attr<FTVMTIRToRuntime>("TIRToRuntime", TIRToRuntime);
+    .set_attr<FTVMTIRToRuntime>("TIRToRuntime", TIRToRuntime)
+    .add_attr_option<Array<String>>("mattr")
+    .add_attr_option<String>("mcpu");
 
 }  // namespace cmsisnn
 }  // namespace contrib
diff --git a/tests/python/driver/tvmc/test_command_line.py b/tests/python/driver/tvmc/test_command_line.py
index 0fddb7073f..af45f0bb7e 100644
--- a/tests/python/driver/tvmc/test_command_line.py
+++ b/tests/python/driver/tvmc/test_command_line.py
@@ -21,6 +21,8 @@ import shutil
 
 from pytest_lazyfixture import lazy_fixture
 from unittest import mock
+
+import tvm
 from tvm.driver.tvmc.main import _main
 from tvm.driver.tvmc.model import TVMCException
 from tvm.driver.tvmc import compiler
@@ -159,6 +161,26 @@ def test_tvmc_tune_file_check(capsys, invalid_input):
     assert captured.err == expected_err, on_assert_error
 
 
+@mock.patch("tvm.relay.build", side_effect=tvm.relay.build)
+@mock.patch("tvm.driver.tvmc.model.TVMCPackage.__init__", return_value=None)
+def test_tvmc_workspace_pools_check(mock_pkg, mock_relay, keras_simple, tmpdir_factory):
+    pytest.importorskip("tensorflow")
+    tmpdir = tmpdir_factory.mktemp("data")
+
+    # Test model compilation
+    package_path = os.path.join(tmpdir, "keras-tvm.tar")
+    compile_str = (
+        f"tvmc compile --target=llvm --workspace-pools=sram "
+        f"--workspace-pools-targets=sram:llvm "
+        f"--output={package_path} {keras_simple}"
+    )
+    compile_args = compile_str.split(" ")[1:]
+    _main(compile_args)
+    assert os.path.exists(package_path)
+    assert mock_relay.call_count == 1
+    assert mock_relay.call_args_list[0][1]["workspace_memory_pools"].pools[0].pool_name == "sram"
+
+
 @pytest.fixture
 def paddle_model(paddle_resnet50):
     # If we can't import "paddle" module, skip testing paddle as the input model.
diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py
index e8e93a6c75..27cd78d436 100644
--- a/tests/python/driver/tvmc/test_compiler.py
+++ b/tests/python/driver/tvmc/test_compiler.py
@@ -24,6 +24,8 @@ from unittest import mock
 import pytest
 
 import tvm
+from tvm.ir.memory_pools import WorkspacePoolInfo, WorkspaceMemoryPools
+from tvm.target import Target
 import tvm.testing
 from tvm.relay.op.contrib.ethosn import ethosn_available
 from tvm.relay.backend import Runtime, Executor
@@ -674,5 +676,25 @@ def test_compile_tflite_module_with_mod_name_and_ethosu(
             assert b"tvmgen_classify_ethos_u_main_" in content
 
 
+@mock.patch("tvm.relay.build")
+@mock.patch("tvm.driver.tvmc.load")
+@mock.patch("tvm.driver.tvmc.model.TVMCPackage.__init__", return_value=None)
+def test_compile_check_workspace_pools(mock_pkg, mock_fe, mock_relay):
+    mock_fe.return_value = mock.MagicMock()
+    mock_relay.return_value = mock.MagicMock()
+    memory_pools = WorkspaceMemoryPools(
+        [WorkspacePoolInfo(pool_name="sram", targets=[Target("llvm")])]
+    )
+    tvmc_model = tvmc.load("no_file_needed")
+    tvmc.compile(
+        tvmc_model,
+        target="llvm,c",
+        workspace_pools=memory_pools,
+    )
+
+    assert mock_relay.call_count == 1
+    assert mock_relay.call_args_list[0][1]["workspace_memory_pools"] == memory_pools
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/driver/tvmc/test_workspace_pools.py b/tests/python/driver/tvmc/test_workspace_pools.py
new file mode 100644
index 0000000000..386181aaf2
--- /dev/null
+++ b/tests/python/driver/tvmc/test_workspace_pools.py
@@ -0,0 +1,404 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+import argparse
+
+from tvm.driver.tvmc.workspace_pools import (
+    generate_workspace_pools_args,
+    workspace_pools_recombobulate,
+)
+from tvm.target import Target
+from tvm.driver.tvmc import TVMCException
+
+
+def test_workspace_pools_argparse():
+    parser = argparse.ArgumentParser()
+    generate_workspace_pools_args(parser)
+    parsed, unparsed = parser.parse_known_args(
+        [
+            "--workspace-pools=sram,flash",
+            "--workspace-pools-targets=sram:c,llvm",
+            "--workspace-pools-targets=flash:c",
+            "--workspace-pools-size-hint-bytes=sram:400",
+            "--workspace-pools-size-hint-bytes=sram:500",
+            "--workspace-pools-clock-frequency-hz=sram:500",
+            "--workspace-pools-read-bandwidth-bytes-per-cycle=sram:200",
+            "--workspace-pools-write-bandwidth-bytes-per-cycle=sram:100",
+            "--workspace-pools-read-latency-cycles=sram:50",
+            "--workspace-pools-read-latency-cycles=flash:30",
+            "--workspace-pools-write-latency-cycles=sram:9001",
+            "--workspace-pools-target-burst-bytes=sram:c:2",
+            "--workspace-pools-is-internal=sram:0",
+        ]
+    )
+
+    assert parsed.workspace_pools == "sram,flash"
+    assert parsed.workspace_pools_targets == ["sram:c,llvm", "flash:c"]
+    assert parsed.workspace_pools_size_hint_bytes == ["sram:400", "sram:500"]
+    assert parsed.workspace_pools_clock_frequency_hz == ["sram:500"]
+    assert parsed.workspace_pools_read_bandwidth_bytes_per_cycle == ["sram:200"]
+    assert parsed.workspace_pools_write_bandwidth_bytes_per_cycle == ["sram:100"]
+    assert parsed.workspace_pools_read_latency_cycles == ["sram:50", "flash:30"]
+    assert parsed.workspace_pools_write_latency_cycles == ["sram:9001"]
+    assert parsed.workspace_pools_target_burst_bytes == ["sram:c:2"]
+
+    assert unparsed == ["--workspace-pools-is-internal=sram:0"]
+
+
+def test_workspace_pools_recombobulate_empty():
+    parser = argparse.ArgumentParser()
+    generate_workspace_pools_args(parser)
+    parsed, _ = parser.parse_known_args([])
+
+    targets = [Target("llvm")]
+    memory_pools = workspace_pools_recombobulate(parsed, targets, _)
+    assert memory_pools is None
+
+
+def test_workspace_pools_recombobulate():
+    parser = argparse.ArgumentParser()
+    generate_workspace_pools_args(parser)
+    parsed, _ = parser.parse_known_args(
+        [
+            "--workspace-pools=sram",
+            "--workspace-pools-targets=sram:llvm",
+            "--workspace-pools-size-hint-bytes=sram:400",
+            "--workspace-pools-clock-frequency-hz=sram:500",
+        ]
+    )
+
+    targets = [Target("llvm")]
+    memory_pools = workspace_pools_recombobulate(parsed, targets, _)
+    assert len(memory_pools.pools) == 1
+    assert memory_pools.pools[0].pool_name == "sram"
+    assert memory_pools.pools[0].size_hint_bytes == 400
+    assert memory_pools.pools[0].clock_frequency_hz == 500
+
+
+def test_workspace_pools_defaults():
+    parser = argparse.ArgumentParser()
+    targets = [Target("llvm")]
+    generate_workspace_pools_args(parser)
+    parsed, _ = parser.parse_known_args(
+        [
+            "--workspace-pools=sram",
+            "--workspace-pools-targets=sram:llvm",
+        ]
+    )
+
+    memory_pools = workspace_pools_recombobulate(parsed, targets, _)
+    assert len(memory_pools.pools) == 1
+    assert memory_pools.pools[0].pool_name == "sram"
+    assert memory_pools.pools[0].size_hint_bytes == -1
+    assert memory_pools.pools[0].clock_frequency_hz == -1
+    assert memory_pools.pools[0].read_bandwidth_bytes_per_cycle == -1
+    assert memory_pools.pools[0].write_bandwidth_bytes_per_cycle == -1
+    assert memory_pools.pools[0].read_latency_cycles == 0
+    assert memory_pools.pools[0].write_latency_cycles == 0
+    assert len(memory_pools.pools[0].target_burst_bytes) == 0
+
+
+def test_workspace_pools_recombobulate_multi_fields():
+    parser = argparse.ArgumentParser()
+    targets = [Target("c")]
+    generate_workspace_pools_args(parser)
+    parsed, _ = parser.parse_known_args(
+        [
+            "--workspace-pools=sram",
+            "--workspace-pools-targets=sram:c",
+            "--workspace-pools-size-hint-bytes=sram:400",
+            "--workspace-pools-clock-frequency-hz=sram:500",
+            "--workspace-pools-read-bandwidth-bytes-per-cycle=sram:200",
+            "--workspace-pools-write-bandwidth-bytes-per-cycle=sram:100",
+            "--workspace-pools-read-latency-cycles=sram:50",
+            "--workspace-pools-write-latency-cycles=sram:9001",
+            "--workspace-pools-target-burst-bytes=sram:c:2",
+        ]
+    )
+
+    memory_pools = workspace_pools_recombobulate(parsed, targets, _)
+    assert len(memory_pools.pools) == 1
+    assert memory_pools.pools[0].pool_name == "sram"
+    assert memory_pools.pools[0].size_hint_bytes == 400
+    assert memory_pools.pools[0].clock_frequency_hz == 500
+    assert memory_pools.pools[0].read_bandwidth_bytes_per_cycle == 200
+    assert memory_pools.pools[0].write_bandwidth_bytes_per_cycle == 100
+    assert memory_pools.pools[0].read_latency_cycles == 50
+    assert memory_pools.pools[0].write_latency_cycles == 9001
+    assert len(memory_pools.pools[0].target_burst_bytes) == 1
+    assert memory_pools.pools[0].target_burst_bytes[targets[0]] == 2
+
+
+def test_workspace_pools_recombobulate_multi_fields_variant():
+    parser = argparse.ArgumentParser()
+    generate_workspace_pools_args(parser)
+    parsed, _ = parser.parse_known_args(
+        [
+            "--workspace-pools=flash",
+            "--workspace-pools-targets=flash:c",
+            "--workspace-pools-size-hint-bytes=flash:2048",
+            "--workspace-pools-clock-frequency-hz=flash:2000000",
+            "--workspace-pools-read-bandwidth-bytes-per-cycle=flash:4",
+            "--workspace-pools-write-bandwidth-bytes-per-cycle=flash:1",
+            "--workspace-pools-read-latency-cycles=flash:2000",
+            "--workspace-pools-write-latency-cycles=flash:1000",
+            "--workspace-pools-target-burst-bytes=flash:c:4",
+        ]
+    )
+
+    targets = [Target("c")]
+    memory_pools = workspace_pools_recombobulate(parsed, targets, _)
+    assert len(memory_pools.pools) == 1
+    assert memory_pools.pools[0].pool_name == "flash"
+    assert memory_pools.pools[0].size_hint_bytes == 2048
+    assert memory_pools.pools[0].clock_frequency_hz == 2000000
+    assert memory_pools.pools[0].read_bandwidth_bytes_per_cycle == 4
+    assert memory_pools.pools[0].write_bandwidth_bytes_per_cycle == 1
+    assert memory_pools.pools[0].read_latency_cycles == 2000
+    assert memory_pools.pools[0].write_latency_cycles == 1000
+    assert len(memory_pools.pools[0].target_burst_bytes) == 1
+    assert memory_pools.pools[0].target_burst_bytes[targets[0]] == 4
+
+
+def test_workspace_pools_recombobulate_multi_fields_multi_pools():
+    parser = argparse.ArgumentParser()
+    generate_workspace_pools_args(parser)
+    parsed, _ = parser.parse_known_args(
+        [
+            "--workspace-pools=sram,flash",
+            "--workspace-pools-targets=sram:c",
+            "--workspace-pools-targets=flash:c",
+            "--workspace-pools-size-hint-bytes=sram:1024",
+            "--workspace-pools-size-hint-bytes=flash:2048",
+            "--workspace-pools-clock-frequency-hz=sram:4000000",
+            "--workspace-pools-clock-frequency-hz=flash:2000000",
+            "--workspace-pools-read-bandwidth-bytes-per-cycle=sram:8",
+            "--workspace-pools-read-bandwidth-bytes-per-cycle=flash:4",
+            "--workspace-pools-write-bandwidth-bytes-per-cycle=sram:4",
+            "--workspace-pools-write-bandwidth-bytes-per-cycle=flash:1",
+            "--workspace-pools-read-latency-cycles=sram:250",
+            "--workspace-pools-read-latency-cycles=flash:2000",
+            "--workspace-pools-write-latency-cycles=sram:500",
+            "--workspace-pools-write-latency-cycles=flash:1000",
+            "--workspace-pools-target-burst-bytes=sram:c:8",
+            "--workspace-pools-target-burst-bytes=flash:c:4",
+        ]
+    )
+
+    targets = [Target("c")]
+    memory_pools = workspace_pools_recombobulate(parsed, targets, _)
+    assert len(memory_pools.pools) == 2
+
+    assert memory_pools.pools[0].pool_name == "sram"
+    assert memory_pools.pools[0].size_hint_bytes == 1024
+    assert memory_pools.pools[0].clock_frequency_hz == 4000000
+    assert memory_pools.pools[0].read_bandwidth_bytes_per_cycle == 8
+    assert memory_pools.pools[0].write_bandwidth_bytes_per_cycle == 4
+    assert memory_pools.pools[0].read_latency_cycles == 250
+    assert memory_pools.pools[0].write_latency_cycles == 500
+    assert len(memory_pools.pools[0].target_burst_bytes) == 1
+    assert memory_pools.pools[0].target_burst_bytes[targets[0]] == 8
+
+    assert memory_pools.pools[1].pool_name == "flash"
+    assert memory_pools.pools[1].size_hint_bytes == 2048
+    assert memory_pools.pools[1].clock_frequency_hz == 2000000
+    assert memory_pools.pools[1].read_bandwidth_bytes_per_cycle == 4
+    assert memory_pools.pools[1].write_bandwidth_bytes_per_cycle == 1
+    assert memory_pools.pools[1].read_latency_cycles == 2000
+    assert memory_pools.pools[1].write_latency_cycles == 1000
+    assert len(memory_pools.pools[1].target_burst_bytes) == 1
+    assert memory_pools.pools[1].target_burst_bytes[targets[0]] == 4
+
+
+def test_workspace_pools_recombobulate_multi_fields_ordering():
+    """Per-pool flags are matched to their pool regardless of flag order.
+
+    The targets, size-hint and write-latency flags for ``flash`` are given
+    before those for ``sram``, yet the resulting pools must follow the order
+    of ``--workspace-pools`` and each carry their own values.
+    """
+    parser = argparse.ArgumentParser()
+    generate_workspace_pools_args(parser)
+    parsed, unknown_args = parser.parse_known_args(
+        [
+            "--workspace-pools=sram,flash",
+            "--workspace-pools-targets=flash:c",
+            "--workspace-pools-targets=sram:c",
+            "--workspace-pools-size-hint-bytes=flash:2048",
+            "--workspace-pools-size-hint-bytes=sram:1024",
+            "--workspace-pools-clock-frequency-hz=sram:4000000",
+            "--workspace-pools-clock-frequency-hz=flash:2000000",
+            "--workspace-pools-read-bandwidth-bytes-per-cycle=sram:8",
+            "--workspace-pools-read-bandwidth-bytes-per-cycle=flash:4",
+            "--workspace-pools-write-bandwidth-bytes-per-cycle=sram:4",
+            "--workspace-pools-write-bandwidth-bytes-per-cycle=flash:1",
+            "--workspace-pools-read-latency-cycles=sram:250",
+            "--workspace-pools-read-latency-cycles=flash:2000",
+            "--workspace-pools-write-latency-cycles=flash:1000",
+            "--workspace-pools-write-latency-cycles=sram:500",
+            "--workspace-pools-target-burst-bytes=sram:c:8",
+            "--workspace-pools-target-burst-bytes=flash:c:4",
+        ]
+    )
+    # parse_known_args hands back unrecognised flags instead of raising, so a
+    # misspelt flag would otherwise be ignored without any failure.
+    assert not unknown_args
+
+    targets = [Target("c")]
+    # No extra targets here: pass an explicit empty list rather than reusing
+    # the parser's leftover-argument list.
+    memory_pools = workspace_pools_recombobulate(parsed, targets, [])
+    assert len(memory_pools.pools) == 2
+
+    assert memory_pools.pools[0].pool_name == "sram"
+    assert memory_pools.pools[0].size_hint_bytes == 1024
+    assert memory_pools.pools[0].write_latency_cycles == 500
+
+    assert memory_pools.pools[1].pool_name == "flash"
+    assert memory_pools.pools[1].size_hint_bytes == 2048
+    assert memory_pools.pools[1].write_latency_cycles == 1000
+
+
+def test_workspace_pools_recombobulate_multi_target():
+    """One pool shared by two targets records a separate burst size per target."""
+    cli_args = [
+        "--workspace-pools=sram",
+        "--workspace-pools-targets=sram:c,llvm",
+        "--workspace-pools-target-burst-bytes=sram:c:8",
+        "--workspace-pools-target-burst-bytes=sram:llvm:4",
+    ]
+    parser = argparse.ArgumentParser()
+    generate_workspace_pools_args(parser)
+    parsed, _ = parser.parse_known_args(cli_args)
+
+    c_target, llvm_target = Target("c"), Target("llvm")
+    memory_pools = workspace_pools_recombobulate(parsed, [c_target, llvm_target], [])
+
+    assert len(memory_pools.pools) == 1
+
+    burst_bytes = memory_pools.pools[0].target_burst_bytes
+    assert len(burst_bytes) == 2
+    assert burst_bytes[c_target] == 8
+    assert burst_bytes[llvm_target] == 4
+
+
+def test_workspace_pools_recombobulate_no_target_burst_bytes():
+    """A pool is built when ``--workspace-pools-target-burst-bytes`` is omitted.
+
+    Only the pool and its target are given on the command line, so this
+    exercises the path where no burst size is supplied for any target.
+    """
+    parser = argparse.ArgumentParser()
+    generate_workspace_pools_args(parser)
+    parsed, unknown_args = parser.parse_known_args(
+        [
+            "--workspace-pools=sram",
+            "--workspace-pools-targets=sram:c",
+        ]
+    )
+    # parse_known_args hands back unrecognised flags instead of raising.
+    assert not unknown_args
+
+    c_target = Target("c")
+    targets = [c_target]
+
+    memory_pools = workspace_pools_recombobulate(parsed, targets, [])
+
+    assert len(memory_pools.pools) == 1
+    assert memory_pools.pools[0].pool_name == "sram"
+    assert len(memory_pools.pools[0].targets) == 1
+    # NOTE(review): the default contents of target_burst_bytes when no burst
+    # flag is supplied are not pinned here; add an assertion once the
+    # expected default is confirmed against workspace_pools_recombobulate.
+
+
+def test_workspace_pools_recombobulate_missing_target():
+    """A pool declared without ``--workspace-pools-targets`` raises TVMCException."""
+    parser = argparse.ArgumentParser()
+    generate_workspace_pools_args(parser)
+    parsed, unknown_args = parser.parse_known_args(
+        [
+            "--workspace-pools=sram",
+        ]
+    )
+    # parse_known_args hands back unrecognised flags instead of raising; make
+    # sure the exception below is not caused by a misspelt flag being dropped.
+    assert not unknown_args
+
+    c_target = Target("c")
+    with pytest.raises(TVMCException):
+        workspace_pools_recombobulate(parsed, [c_target], [])
+
+
+def test_workspace_pools_recombobulate_multi_target_multi_pool():
+    """Burst sizes stay attached to the right (pool, target) pair across pools.
+
+    ``sram`` is shared by the ``c`` and ``llvm`` targets while ``flash`` is
+    only used by ``c``; each pool must report its own burst sizes.
+    """
+    parser = argparse.ArgumentParser()
+    generate_workspace_pools_args(parser)
+    parsed, unknown_args = parser.parse_known_args(
+        [
+            "--workspace-pools=sram,flash",
+            "--workspace-pools-targets=sram:c,llvm",
+            "--workspace-pools-targets=flash:c",
+            "--workspace-pools-target-burst-bytes=sram:c:8",
+            "--workspace-pools-target-burst-bytes=sram:llvm:4",
+            "--workspace-pools-target-burst-bytes=flash:c:2",
+        ]
+    )
+    # parse_known_args hands back unrecognised flags instead of raising.
+    assert not unknown_args
+
+    c_target = Target("c")
+    llvm_target = Target("llvm")
+
+    targets = [c_target, llvm_target]
+    memory_pools = workspace_pools_recombobulate(parsed, targets, [])
+
+    assert len(memory_pools.pools) == 2
+
+    assert memory_pools.pools[0].pool_name == "sram"
+    assert len(memory_pools.pools[0].target_burst_bytes) == 2
+    assert memory_pools.pools[0].target_burst_bytes[llvm_target] == 4
+    assert memory_pools.pools[0].target_burst_bytes[c_target] == 8
+
+    # flash must get its own burst size for c, not sram's value.
+    assert memory_pools.pools[1].pool_name == "flash"
+    assert memory_pools.pools[1].target_burst_bytes[c_target] == 2
+
+
+def test_workspace_pools_recombobulate_parameter_overrides():
+    """When a pool parameter is given more than once, the last value wins."""
+    parser = argparse.ArgumentParser()
+    generate_workspace_pools_args(parser)
+    parsed, unknown_args = parser.parse_known_args(
+        [
+            "--workspace-pools=sram",
+            "--workspace-pools-targets=sram:c",
+            "--workspace-pools-size-hint-bytes=sram:800",
+            "--workspace-pools-size-hint-bytes=sram:400",
+            "--workspace-pools-clock-frequency-hz=sram:4000000",
+            "--workspace-pools-clock-frequency-hz=sram:3600000",
+        ]
+    )
+    # parse_known_args hands back unrecognised flags instead of raising.
+    assert not unknown_args
+
+    c_target = Target("c")
+
+    targets = [c_target]
+    # No extra targets: pass an explicit empty list rather than reusing the
+    # parser's leftover-argument list.
+    memory_pools = workspace_pools_recombobulate(parsed, targets, [])
+
+    assert len(memory_pools.pools) == 1
+
+    assert memory_pools.pools[0].size_hint_bytes == 400
+    assert memory_pools.pools[0].clock_frequency_hz == 3600000
+
+
+def test_workspace_pools_recombobulate_single_pool_overrides():
+    parser = argparse.ArgumentParser()
+    generate_workspace_pools_args(parser)
+    parsed, _ = parser.parse_known_args(
+        [
+            "--workspace-pools=sram,flash",
+            "--workspace-pools-targets=sram:c",
+            "--workspace-pools-targets=flash:c",
+            "--workspace-pools-targets=sram:c,llvm",  # Override on one pool
+            "--workspace-pools-size-hint-bytes=sram:800",
+            "--workspace-pools-size-hint-bytes=flash:1200",
+            "--workspace-pools-size-hint-bytes=sram:400",  # Override on one pool
+        ]
+    )
+
+    c_target = Target("c")
+    llvm_target = Target("llvm")
+
+    targets = [c_target, llvm_target]
+    memory_pools = workspace_pools_recombobulate(parsed, targets, _)
+
+    assert len(memory_pools.pools) == 2
+
+    assert memory_pools.pools[0].size_hint_bytes == 400
+    assert memory_pools.pools[1].size_hint_bytes == 1200
+
+    assert len(memory_pools.pools[0].targets) == 2
+    assert len(memory_pools.pools[1].targets) == 1