You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by le...@apache.org on 2022/07/25 09:43:05 UTC
[tvm] branch main updated: [TVMC] Workspace Pools Parameters (#11427)
This is an automated email from the ASF dual-hosted git repository.
leandron pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 75ec1cffa9 [TVMC] Workspace Pools Parameters (#11427)
75ec1cffa9 is described below
commit 75ec1cffa9f160dd3165fedbf4408731ebfa797a
Author: Dhruv Chauhan <89...@users.noreply.github.com>
AuthorDate: Mon Jul 25 10:43:00 2022 +0100
[TVMC] Workspace Pools Parameters (#11427)
* [TVMC] Workspace Pools Parameters
Workspace pool attributes given to tvmc are now passed into the PoolInfo objects
created inside WorkspaceMemoryPools. The resulting WorkspaceMemoryPools is passed
to relay.build, which attaches it to the IRModule as an attribute.
* [TVMC] Workspace Pools Parameters
Address comments, fix linting. Testing improved.
Change-Id: Iea79329b6b9ec1cbc51e5c293449bf6dd43b00c5
* [TVMC] Workspace Pools Parameters
Update workspace pools test naming
Change-Id: Ib698d6248be1e6f44340f27db3641c985bc5c5d8
* [TVMC] Workspace Pools Parameters
Add test for parameter overrides.
Change-Id: I67d5470dcfbfbc9ab27f34e20a9269d2070193ca
* [TVMC] Workspace Pools Parameters
Rebasing over #10189
Updates to the way a WorkspaceMemoryPool object is created
Change-Id: I1f0e1d240343af311ddb3ed5c564cc1ab329f463
* [TVMC] Workspace Pools Parameters
Fix linting, fix CI
Change-Id: If75f8709ac4ad925655eca54b3e5c1bb09d025e8
* [TVMC] Workspace Pools Parameters
Add mcpu and mattr to target registry for cmsis-nn
Change-Id: I15257b8d01624c071c738cab6d12ecb84ed6cb16
* [TVMC] Workspace Pools Parameters
Added test for override on single pool when multiple pools are present
Updated functionality of parsing multiple attributes
Change-Id: I2c0745051b7a923dd7f75040bfb89bbc99376a11
---
include/tvm/ir/memory_pools.h | 1 +
python/tvm/driver/tvmc/compiler.py | 30 +-
python/tvm/driver/tvmc/workspace_pools.py | 237 +++++++++++++
python/tvm/ir/memory_pools.py | 2 +-
src/relay/backend/contrib/cmsisnn/target.cc | 4 +-
tests/python/driver/tvmc/test_command_line.py | 22 ++
tests/python/driver/tvmc/test_compiler.py | 22 ++
tests/python/driver/tvmc/test_workspace_pools.py | 404 +++++++++++++++++++++++
8 files changed, 717 insertions(+), 5 deletions(-)
diff --git a/include/tvm/ir/memory_pools.h b/include/tvm/ir/memory_pools.h
index ee07841de4..ebab13cf3a 100644
--- a/include/tvm/ir/memory_pools.h
+++ b/include/tvm/ir/memory_pools.h
@@ -65,6 +65,7 @@ struct PoolInfoNode : public Object {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("pool_name", &pool_name);
+ v->Visit("targets", &targets);
v->Visit("size_hint_bytes", &size_hint_bytes);
v->Visit("clock_frequency_hz", &clock_frequency_hz);
v->Visit("read_bandwidth_bytes_per_cycle", &read_bandwidth_bytes_per_cycle);
diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py
index 1385044704..2955df5543 100644
--- a/python/tvm/driver/tvmc/compiler.py
+++ b/python/tvm/driver/tvmc/compiler.py
@@ -26,6 +26,7 @@ import tvm
from tvm import autotvm, auto_scheduler
from tvm import relay
from tvm.driver.tvmc.registry import generate_registry_args, reconstruct_registry_entity
+from tvm.ir.memory_pools import WorkspaceMemoryPools
from tvm.target import Target
from tvm.relay.backend import Executor, Runtime
@@ -37,6 +38,7 @@ from .pass_config import parse_configs
from .pass_list import parse_pass_list_str
from .transform import convert_graph_layout
from .shape_parser import parse_shape_string
+from .workspace_pools import generate_workspace_pools_args, workspace_pools_recombobulate
# pylint: disable=invalid-name
logger = logging.getLogger("TVMC")
@@ -142,10 +144,11 @@ def add_compile_parser(subparsers, _, json_params):
default="default",
help="The output module name. Defaults to 'default'.",
)
-
for one_entry in json_params:
parser.set_defaults(**one_entry)
+ generate_workspace_pools_args(parser)
+
def drive_compile(args):
"""Invoke tvmc.compiler module with command line arguments
@@ -161,6 +164,7 @@ def drive_compile(args):
Zero if successfully completed
"""
+
if not os.path.isfile(args.FILE):
raise TVMCException(
f"Input file '{args.FILE}' doesn't exist, is a broken symbolic link, or a directory."
@@ -170,6 +174,9 @@ def drive_compile(args):
dump_code = [x.strip() for x in args.dump_code.split(",")] if args.dump_code else None
+ additional_targets = reconstruct_target_args(args)
+ workspace_pools_target, extra_targets = target_from_cli(args.target, additional_targets)
+
compile_model(
tvmc_model,
args.target,
@@ -186,8 +193,11 @@ def drive_compile(args):
desired_layout=args.desired_layout,
disabled_pass=args.disabled_pass,
pass_context_configs=args.pass_config,
- additional_target_options=reconstruct_target_args(args),
mod_name=args.module_name,
+ additional_target_options=additional_targets,
+ workspace_pools=(
+ workspace_pools_recombobulate(args, [workspace_pools_target], extra_targets)
+ ),
)
return 0
@@ -212,6 +222,7 @@ def compile_model(
additional_target_options: Optional[Dict[str, Dict[str, Any]]] = None,
use_vm: bool = False,
mod_name: Optional[str] = "default",
+ workspace_pools: Optional[WorkspaceMemoryPools] = None,
):
"""Compile a model from a supported framework into a TVM module.
@@ -263,6 +274,9 @@ def compile_model(
Whether to use the VM to compile the model as opposed to the graph executor
mod_name: str, optional
The module name
+ workspace_pools: WorkspaceMemoryPools, optional
+ Specification of WorkspacePoolInfo objects to be used as workspace memory in the
+ compilation.
Returns
-------
@@ -313,6 +327,7 @@ def compile_model(
params=params,
use_vm=use_vm,
mod_name=mod_name,
+ workspace_pools=workspace_pools,
)
else:
with autotvm.apply_history_best(tuning_records):
@@ -328,6 +343,7 @@ def compile_model(
params=params,
use_vm=use_vm,
mod_name=mod_name,
+ workspace_pools=workspace_pools,
)
else:
with tvm.transform.PassContext(
@@ -342,6 +358,7 @@ def compile_model(
params=params,
use_vm=use_vm,
mod_name=mod_name,
+ workspace_pools=workspace_pools,
)
# Generate output dump files with sources
@@ -380,6 +397,7 @@ def build(
params: Dict[str, tvm.nd.NDArray],
use_vm: bool,
mod_name: str,
+ workspace_pools: Optional[WorkspaceMemoryPools],
):
"""
Builds the model with the provided executor.
@@ -408,7 +426,13 @@ def build(
return relay.vm.compile(mod, target=tvm_target, params=params)
logger.debug("building with relay build")
return relay.build(
- mod, target=tvm_target, executor=executor, runtime=runtime, params=params, mod_name=mod_name
+ mod,
+ target=tvm_target,
+ executor=executor,
+ runtime=runtime,
+ params=params,
+ mod_name=mod_name,
+ workspace_memory_pools=workspace_pools,
)
diff --git a/python/tvm/driver/tvmc/workspace_pools.py b/python/tvm/driver/tvmc/workspace_pools.py
new file mode 100644
index 0000000000..2c91488fb4
--- /dev/null
+++ b/python/tvm/driver/tvmc/workspace_pools.py
@@ -0,0 +1,237 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Functions for processing dynamic workspace pool TVMC args
+"""
+
+
+import logging
+import re
+
+from tvm.driver.tvmc import TVMCException
+from tvm.target import Target
+from tvm.ir.memory_pools import PoolInfoProperties, WorkspaceMemoryPools, WorkspacePoolInfo
+
+
+# pylint: disable=invalid-name
+logger = logging.getLogger("TVMC")
+
+
+def generate_workspace_pools_args(parser):
+    """Adds the workspace pool command line options to ``parser``.
+
+    ``--workspace-pools`` takes a comma-separated list of pool names. Every other option
+    configures one attribute of one pool, takes a single ``<pool>:<value>`` entry
+    (``<pool>:<target>:<value>`` for per-target attributes) and may be repeated; all
+    occurrences are collected into a list.
+    """
+    parser.add_argument(
+        "--workspace-pools",
+        help="""Comma-separated list of memory pool names
+        Example usage: --workspace-pools=flash,sram""",
+    )
+
+    # (option, help) for every repeatable per-pool option, in --help order. None of them
+    # uses nargs="?": that would let a bare flag append None to the list, which the
+    # attribute parsers cannot interpret.
+    per_pool_options = [
+        (
+            "--workspace-pools-targets",
+            """Comma-separated names of the targets that can use the memory pool
+        Example usage: --workspace-pools-targets=flash:llvm""",
+        ),
+        (
+            "--workspace-pools-size-hint-bytes",
+            """The expected size hint to be used by the allocator.
+        Example usage: --workspace-pools-size-hint-bytes=flash:8""",
+        ),
+        (
+            "--workspace-pools-clock-frequency-hz",
+            """The clock frequency that the memory pool runs at in Hz.
+        Example usage: --workspace-pools-clock-frequency-hz=flash:70000000""",
+        ),
+        (
+            "--workspace-pools-read-bandwidth-bytes-per-cycle",
+            """The read bandwidth of the memory pool in bytes/cycle.
+        Example usage: --workspace-pools-read-bandwidth-bytes-per-cycle=flash:4""",
+        ),
+        (
+            "--workspace-pools-write-bandwidth-bytes-per-cycle",
+            """The write bandwidth of the memory pool in bytes/cycle.
+        Example usage: --workspace-pools-write-bandwidth-bytes-per-cycle=flash:8""",
+        ),
+        (
+            "--workspace-pools-read-latency-cycles",
+            """The read latency of the memory pool in cycles.
+        Example usage: --workspace-pools-read-latency-cycles=flash:4""",
+        ),
+        (
+            "--workspace-pools-write-latency-cycles",
+            """The write latency of the memory pool in cycles.
+        Example usage: --workspace-pools-write-latency-cycles=flash:8""",
+        ),
+        (
+            "--workspace-pools-target-burst-bytes",
+            """The burst length of the memory pool in bytes per target.
+        Example usage: --workspace-pools-target-burst-bytes=flash:accel:1""",
+        ),
+    ]
+    for option, help_text in per_pool_options:
+        parser.add_argument(option, help=help_text, action="append")
+
+
+def _parse_target_burst(attr_str, pool_name):
+    """Return the ``{Target: int}`` burst-size map for ``pool_name``.
+
+    ``attr_str`` is the ``{pool_name: {Target: value}}`` mapping built by
+    ``_parse_target_attributes_of_pool_name``; a pool without entries yields ``{}``.
+    """
+    per_target = attr_str.get(pool_name, {})
+    return {target: int(burst) for target, burst in per_target.items()}
+
+
+def _parse_target_string(attr_str, targets, pool_name):
+    """Resolve a pool's comma-separated target list to the matching Target objects.
+
+    Each non-empty comma-separated entry is compared, using only its first
+    whitespace-delimited word, for exact equality with ``target.kind.name``. An entry
+    "c" therefore selects the ``c`` target but not ``cmsis-nn``.
+
+    Raises
+    ------
+    TVMCException
+        If no target list was given for the pool, or none of its entries matches.
+    """
+    if attr_str is None:
+        raise TVMCException(f'No target specified for Workspace Pool "{pool_name}"')
+
+    requested = [entry.split()[0] for entry in attr_str.split(",") if entry.strip()]
+    matched_targets = [target for target in targets if target.kind.name in requested]
+    if not matched_targets:
+        raise TVMCException(f'Workspace Pool "{pool_name}" using undefined Target "{attr_str}"')
+
+    matched_names = {target.kind.name for target in matched_targets}
+    unmatched = [name for name in requested if name not in matched_names]
+    if unmatched:
+        # Keep using the targets that did match, but make typos visible to the user.
+        logger.warning(
+            'Workspace Pool "%s" ignores undefined Target(s) "%s"', pool_name, ",".join(unmatched)
+        )
+    return matched_targets
+
+
+def _split_pools_to_pool_names(attr_str):
+    """Split the ``--workspace-pools`` value into a list of pool names.
+
+    Surrounding whitespace is stripped and empty entries (e.g. from a trailing comma)
+    are dropped, so ``"sram, flash,"`` yields ``["sram", "flash"]``. A missing or empty
+    value yields ``[]``.
+    """
+    if not attr_str:
+        return []
+    return [name.strip() for name in attr_str.split(",") if name.strip()]
+
+
+def _parse_target_attributes_of_pool_name(attr_str, targets):
+    """Parse ``pool:target:value`` entries into ``{pool_name: {Target: value}}``.
+
+    ``value`` is kept as a string. ``target`` must equal the kind name of one of
+    ``targets``; the first such Target is used as the key. Later entries for the same
+    pool and target override earlier ones. Returns ``{}`` when there are no entries or
+    no targets.
+
+    Raises
+    ------
+    TVMCException
+        If an entry is not of the form ``pool:target:value`` or names an unknown target.
+    """
+    if not targets or attr_str is None:
+        return {}
+
+    target_attributes = {}
+    for pool_values in attr_str:
+        parts = pool_values.split(":")
+        if len(parts) != 3:
+            raise TVMCException(
+                f'Invalid workspace pool target attribute "{pool_values}"; '
+                "expected the form <pool>:<target>:<value>"
+            )
+        pool_name, target_name, target_value = parts
+
+        matched_targets = [target for target in targets if target_name == target.kind.name]
+        if not matched_targets:
+            raise TVMCException(
+                "The workspace pool target specification "
+                "needs to contain a subset of the same TVM "
+                "targets as when specifying targets to use. "
+                f'Unknown target "{target_name}" in "{pool_values}".'
+            )
+        target_attributes.setdefault(pool_name, {})[matched_targets[0]] = target_value
+    return target_attributes
+
+
+def _parse_attribute_of_pool_name(attr_str):
+    """Parse ``pool:value`` entries into a ``{pool_name: value}`` dict.
+
+    Only the first ``:`` separates the pool name from the value, so the value may itself
+    contain ``:``. When a pool appears more than once, the last entry wins. A missing or
+    empty list yields ``{}``.
+
+    Raises
+    ------
+    TVMCException
+        If an entry is missing (``None``) or has no ``:`` separator.
+    """
+    if not attr_str:
+        return {}
+
+    attributes = {}
+    for pool in attr_str:
+        if pool is None or ":" not in pool:
+            raise TVMCException(
+                f'Invalid workspace pool attribute "{pool}"; expected the form <pool>:<value>'
+            )
+        pool_name, value = pool.split(":", maxsplit=1)
+        attributes[pool_name] = value
+    return attributes
+
+
+def workspace_pools_recombobulate(parsed, targets, extra_target):
+    """Reconstructs the Workspace Pools args and returns a WorkspaceMemoryPools object
+
+    Parameters
+    ----------
+    parsed : argparse.Namespace
+        Parsed arguments from a parser set up by ``generate_workspace_pools_args``.
+    targets : list of Target
+        The compilation targets. ``targets[0]`` is the primary target. Extra targets are
+        given its host (or the target itself when it has no host), and its host, if set,
+        may also be assigned to a pool.
+    extra_target : list of dict
+        Additional targets from the command line; each entry's ``"raw"`` value is the
+        target string.
+
+    Returns
+    -------
+    WorkspaceMemoryPools or None
+        ``None`` when ``--workspace-pools`` was not given or named no pools.
+
+    Raises
+    ------
+    TVMCException
+        If a pool has no matching target or a numeric attribute is not an integer
+        (other malformed input is reported by the parsing helpers).
+    """
+    WORKSPACE_POOL_PARAMS = [
+        "workspace_pools_size_hint_bytes",
+        "workspace_pools_targets",
+        "workspace_pools_clock_frequency_hz",
+        "workspace_pools_read_bandwidth_bytes_per_cycle",
+        "workspace_pools_write_bandwidth_bytes_per_cycle",
+        "workspace_pools_read_latency_cycles",
+        "workspace_pools_write_latency_cycles",
+    ]
+
+    pool_names = _split_pools_to_pool_names(parsed.workspace_pools)
+    if not pool_names:
+        return None
+
+    # Candidate targets a pool may be bound to: the CLI targets, then the extra targets,
+    # then the primary target's host. Burst sizes are resolved against the same list so
+    # that every target a pool can use can also be given a burst size.
+    primary_target = targets[0]
+    all_targets = list(targets)
+    all_targets.extend(
+        Target(extra["raw"], host=primary_target.host or primary_target) for extra in extra_target
+    )
+    if primary_target.host:
+        all_targets.append(primary_target.host)
+
+    attributes = {
+        param: _parse_attribute_of_pool_name(getattr(parsed, param))
+        for param in WORKSPACE_POOL_PARAMS
+    }
+    target_burst_bytes = _parse_target_attributes_of_pool_name(
+        parsed.workspace_pools_target_burst_bytes, all_targets
+    )
+
+    def _int_attribute(pool_name, param, default):
+        """Return ``param`` for ``pool_name`` as an int, or ``default`` if not given."""
+        value = attributes[param].get(pool_name, default)
+        try:
+            return int(value)
+        except ValueError as err:
+            raise TVMCException(
+                f'Workspace Pool "{pool_name}" has non-integer value "{value}" for "{param}"'
+            ) from err
+
+    return WorkspaceMemoryPools(
+        [
+            WorkspacePoolInfo(
+                pool_name,
+                targets=_parse_target_string(
+                    attributes["workspace_pools_targets"].get(pool_name), all_targets, pool_name
+                ),
+                pool_info_properties=PoolInfoProperties(
+                    size_hint_bytes=_int_attribute(
+                        pool_name, "workspace_pools_size_hint_bytes", -1
+                    ),
+                    clock_frequency_hz=_int_attribute(
+                        pool_name, "workspace_pools_clock_frequency_hz", -1
+                    ),
+                    read_bandwidth_bytes_per_cycle=_int_attribute(
+                        pool_name, "workspace_pools_read_bandwidth_bytes_per_cycle", -1
+                    ),
+                    write_bandwidth_bytes_per_cycle=_int_attribute(
+                        pool_name, "workspace_pools_write_bandwidth_bytes_per_cycle", -1
+                    ),
+                    read_latency_cycles=_int_attribute(
+                        pool_name, "workspace_pools_read_latency_cycles", 0
+                    ),
+                    write_latency_cycles=_int_attribute(
+                        pool_name, "workspace_pools_write_latency_cycles", 0
+                    ),
+                    target_burst_bytes=_parse_target_burst(target_burst_bytes, pool_name),
+                ),
+            )
+            for pool_name in pool_names
+        ]
+    )
diff --git a/python/tvm/ir/memory_pools.py b/python/tvm/ir/memory_pools.py
index 0186a89f84..553bb49e3c 100644
--- a/python/tvm/ir/memory_pools.py
+++ b/python/tvm/ir/memory_pools.py
@@ -189,7 +189,7 @@ class WorkspaceMemoryPools(Object):
def __init__(
self,
- pools: List[PoolInfo],
+ pools: List[WorkspacePoolInfo],
):
self.__init_handle_by_constructor__(
_ffi_api.WorkspaceMemoryPools, pools # type: ignore # pylint: disable=no-member
diff --git a/src/relay/backend/contrib/cmsisnn/target.cc b/src/relay/backend/contrib/cmsisnn/target.cc
index fd2f18aa99..9a238fba3b 100644
--- a/src/relay/backend/contrib/cmsisnn/target.cc
+++ b/src/relay/backend/contrib/cmsisnn/target.cc
@@ -32,7 +32,9 @@ runtime::Module TIRToRuntime(IRModule mod, Target target);
TVM_REGISTER_TARGET_KIND("cmsis-nn", kDLCPU)
.set_attr<FTVMRelayToTIR>(tvm::attr::kRelayToTIR, RelayToTIR())
- .set_attr<FTVMTIRToRuntime>("TIRToRuntime", TIRToRuntime);
+ .set_attr<FTVMTIRToRuntime>("TIRToRuntime", TIRToRuntime)
+ .add_attr_option<Array<String>>("mattr")
+ .add_attr_option<String>("mcpu");
} // namespace cmsisnn
} // namespace contrib
diff --git a/tests/python/driver/tvmc/test_command_line.py b/tests/python/driver/tvmc/test_command_line.py
index 0fddb7073f..af45f0bb7e 100644
--- a/tests/python/driver/tvmc/test_command_line.py
+++ b/tests/python/driver/tvmc/test_command_line.py
@@ -21,6 +21,8 @@ import shutil
from pytest_lazyfixture import lazy_fixture
from unittest import mock
+
+import tvm
from tvm.driver.tvmc.main import _main
from tvm.driver.tvmc.model import TVMCException
from tvm.driver.tvmc import compiler
@@ -159,6 +161,26 @@ def test_tvmc_tune_file_check(capsys, invalid_input):
assert captured.err == expected_err, on_assert_error
+@mock.patch("tvm.relay.build", side_effect=tvm.relay.build)
+@mock.patch("tvm.driver.tvmc.model.TVMCPackage.__init__", return_value=None)
+def test_tvmc_workspace_pools_check(mock_pkg, mock_relay, keras_simple, tmpdir_factory):
+    """`tvmc compile --workspace-pools=...` forwards the pool to relay.build."""
+    pytest.importorskip("tensorflow")
+    tmpdir = tmpdir_factory.mktemp("data")
+
+    # Test model compilation. Arguments are passed as a list (not a space-split string)
+    # so that paths containing spaces stay intact.
+    package_path = os.path.join(tmpdir, "keras-tvm.tar")
+    compile_args = [
+        "compile",
+        "--target=llvm",
+        "--workspace-pools=sram",
+        "--workspace-pools-targets=sram:llvm",
+        f"--output={package_path}",
+        f"{keras_simple}",
+    ]
+    _main(compile_args)
+    assert os.path.exists(package_path)
+    assert mock_relay.call_count == 1
+    assert mock_relay.call_args_list[0][1]["workspace_memory_pools"].pools[0].pool_name == "sram"
+
@pytest.fixture
def paddle_model(paddle_resnet50):
# If we can't import "paddle" module, skip testing paddle as the input model.
diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py
index e8e93a6c75..27cd78d436 100644
--- a/tests/python/driver/tvmc/test_compiler.py
+++ b/tests/python/driver/tvmc/test_compiler.py
@@ -24,6 +24,8 @@ from unittest import mock
import pytest
import tvm
+from tvm.ir.memory_pools import WorkspacePoolInfo, WorkspaceMemoryPools
+from tvm.target import Target
import tvm.testing
from tvm.relay.op.contrib.ethosn import ethosn_available
from tvm.relay.backend import Runtime, Executor
@@ -674,5 +676,25 @@ def test_compile_tflite_module_with_mod_name_and_ethosu(
assert b"tvmgen_classify_ethos_u_main_" in content
+@mock.patch("tvm.relay.build")
+@mock.patch("tvm.driver.tvmc.load")
+@mock.patch("tvm.driver.tvmc.model.TVMCPackage.__init__", return_value=None)
+def test_compile_check_workspace_pools(mock_pkg, mock_fe, mock_relay):
+    """tvmc.compile passes a caller-supplied WorkspaceMemoryPools through to relay.build."""
+    mock_fe.return_value = mock.MagicMock()
+    mock_relay.return_value = mock.MagicMock()
+
+    sram_pool = WorkspacePoolInfo(pool_name="sram", targets=[Target("llvm")])
+    memory_pools = WorkspaceMemoryPools([sram_pool])
+
+    model = tvmc.load("no_file_needed")
+    tvmc.compile(model, target="llvm,c", workspace_pools=memory_pools)
+
+    assert mock_relay.call_count == 1
+    build_kwargs = mock_relay.call_args_list[0][1]
+    assert build_kwargs["workspace_memory_pools"] == memory_pools
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/driver/tvmc/test_workspace_pools.py b/tests/python/driver/tvmc/test_workspace_pools.py
new file mode 100644
index 0000000000..386181aaf2
--- /dev/null
+++ b/tests/python/driver/tvmc/test_workspace_pools.py
@@ -0,0 +1,404 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+import argparse
+
+from tvm.driver.tvmc.workspace_pools import (
+ generate_workspace_pools_args,
+ workspace_pools_recombobulate,
+)
+from tvm.target import Target
+from tvm.driver.tvmc import TVMCException
+
+
+def test_workspace_pools_argparse():
+    """Each workspace-pools option is collected verbatim; unknown options stay unparsed."""
+    cli_args = [
+        "--workspace-pools=sram,flash",
+        "--workspace-pools-targets=sram:c,llvm",
+        "--workspace-pools-targets=flash:c",
+        "--workspace-pools-size-hint-bytes=sram:400",
+        "--workspace-pools-size-hint-bytes=sram:500",
+        "--workspace-pools-clock-frequency-hz=sram:500",
+        "--workspace-pools-read-bandwidth-bytes-per-cycle=sram:200",
+        "--workspace-pools-write-bandwidth-bytes-per-cycle=sram:100",
+        "--workspace-pools-read-latency-cycles=sram:50",
+        "--workspace-pools-read-latency-cycles=flash:30",
+        "--workspace-pools-write-latency-cycles=sram:9001",
+        "--workspace-pools-target-burst-bytes=sram:c:2",
+        "--workspace-pools-is-internal=sram:0",
+    ]
+    arg_parser = argparse.ArgumentParser()
+    generate_workspace_pools_args(arg_parser)
+    parsed, unparsed = arg_parser.parse_known_args(cli_args)
+
+    expected = {
+        "workspace_pools": "sram,flash",
+        "workspace_pools_targets": ["sram:c,llvm", "flash:c"],
+        "workspace_pools_size_hint_bytes": ["sram:400", "sram:500"],
+        "workspace_pools_clock_frequency_hz": ["sram:500"],
+        "workspace_pools_read_bandwidth_bytes_per_cycle": ["sram:200"],
+        "workspace_pools_write_bandwidth_bytes_per_cycle": ["sram:100"],
+        "workspace_pools_read_latency_cycles": ["sram:50", "flash:30"],
+        "workspace_pools_write_latency_cycles": ["sram:9001"],
+        "workspace_pools_target_burst_bytes": ["sram:c:2"],
+    }
+    for attr_name, value in expected.items():
+        assert getattr(parsed, attr_name) == value, attr_name
+
+    assert unparsed == ["--workspace-pools-is-internal=sram:0"]
+
+
+def test_workspace_pools_recombobulate_empty():
+    """Without --workspace-pools no WorkspaceMemoryPools object is created."""
+    parser = argparse.ArgumentParser()
+    generate_workspace_pools_args(parser)
+    parsed, _ = parser.parse_known_args([])
+
+    targets = [Target("llvm")]
+    # Pass an explicit empty list rather than the unparsed-args list returned above.
+    extra_targets = []
+    memory_pools = workspace_pools_recombobulate(parsed, targets, extra_targets)
+    assert memory_pools is None
+
+
+def test_workspace_pools_recombobulate():
+    """A single pool picks up its size hint and clock frequency."""
+    parser = argparse.ArgumentParser()
+    generate_workspace_pools_args(parser)
+    parsed, _ = parser.parse_known_args(
+        [
+            "--workspace-pools=sram",
+            "--workspace-pools-targets=sram:llvm",
+            "--workspace-pools-size-hint-bytes=sram:400",
+            "--workspace-pools-clock-frequency-hz=sram:500",
+        ]
+    )
+
+    targets = [Target("llvm")]
+    # Pass an explicit empty list rather than the unparsed-args list returned above.
+    extra_targets = []
+    memory_pools = workspace_pools_recombobulate(parsed, targets, extra_targets)
+    assert len(memory_pools.pools) == 1
+    assert memory_pools.pools[0].pool_name == "sram"
+    assert memory_pools.pools[0].size_hint_bytes == 400
+    assert memory_pools.pools[0].clock_frequency_hz == 500
+
+
+def test_workspace_pools_defaults():
+    """Attributes that are not given fall back to -1 (sizes/bandwidths/clock) or 0 (latencies)."""
+    parser = argparse.ArgumentParser()
+    targets = [Target("llvm")]
+    extra_targets = []
+    generate_workspace_pools_args(parser)
+    parsed, _ = parser.parse_known_args(
+        [
+            "--workspace-pools=sram",
+            "--workspace-pools-targets=sram:llvm",
+        ]
+    )
+
+    memory_pools = workspace_pools_recombobulate(parsed, targets, extra_targets)
+    assert len(memory_pools.pools) == 1
+    assert memory_pools.pools[0].pool_name == "sram"
+    assert memory_pools.pools[0].size_hint_bytes == -1
+    assert memory_pools.pools[0].clock_frequency_hz == -1
+    assert memory_pools.pools[0].read_bandwidth_bytes_per_cycle == -1
+    assert memory_pools.pools[0].write_bandwidth_bytes_per_cycle == -1
+    assert memory_pools.pools[0].read_latency_cycles == 0
+    assert memory_pools.pools[0].write_latency_cycles == 0
+    assert len(memory_pools.pools[0].target_burst_bytes) == 0
+
+
+def test_workspace_pools_recombobulate_multi_fields():
+    """Every pool attribute given on the command line reaches the pool."""
+    parser = argparse.ArgumentParser()
+    targets = [Target("c")]
+    extra_targets = []
+    generate_workspace_pools_args(parser)
+    parsed, _ = parser.parse_known_args(
+        [
+            "--workspace-pools=sram",
+            "--workspace-pools-targets=sram:c",
+            "--workspace-pools-size-hint-bytes=sram:400",
+            "--workspace-pools-clock-frequency-hz=sram:500",
+            "--workspace-pools-read-bandwidth-bytes-per-cycle=sram:200",
+            "--workspace-pools-write-bandwidth-bytes-per-cycle=sram:100",
+            "--workspace-pools-read-latency-cycles=sram:50",
+            "--workspace-pools-write-latency-cycles=sram:9001",
+            "--workspace-pools-target-burst-bytes=sram:c:2",
+        ]
+    )
+
+    memory_pools = workspace_pools_recombobulate(parsed, targets, extra_targets)
+    assert len(memory_pools.pools) == 1
+    assert memory_pools.pools[0].pool_name == "sram"
+    assert memory_pools.pools[0].size_hint_bytes == 400
+    assert memory_pools.pools[0].clock_frequency_hz == 500
+    assert memory_pools.pools[0].read_bandwidth_bytes_per_cycle == 200
+    assert memory_pools.pools[0].write_bandwidth_bytes_per_cycle == 100
+    assert memory_pools.pools[0].read_latency_cycles == 50
+    assert memory_pools.pools[0].write_latency_cycles == 9001
+    assert len(memory_pools.pools[0].target_burst_bytes) == 1
+    assert memory_pools.pools[0].target_burst_bytes[targets[0]] == 2
+
+
+def test_workspace_pools_recombobulate_multi_fields_variant():
+    """Same as the multi-fields test, with a different pool name and values."""
+    parser = argparse.ArgumentParser()
+    generate_workspace_pools_args(parser)
+    parsed, _ = parser.parse_known_args(
+        [
+            "--workspace-pools=flash",
+            "--workspace-pools-targets=flash:c",
+            "--workspace-pools-size-hint-bytes=flash:2048",
+            "--workspace-pools-clock-frequency-hz=flash:2000000",
+            "--workspace-pools-read-bandwidth-bytes-per-cycle=flash:4",
+            "--workspace-pools-write-bandwidth-bytes-per-cycle=flash:1",
+            "--workspace-pools-read-latency-cycles=flash:2000",
+            "--workspace-pools-write-latency-cycles=flash:1000",
+            "--workspace-pools-target-burst-bytes=flash:c:4",
+        ]
+    )
+
+    targets = [Target("c")]
+    extra_targets = []
+    memory_pools = workspace_pools_recombobulate(parsed, targets, extra_targets)
+    assert len(memory_pools.pools) == 1
+    assert memory_pools.pools[0].pool_name == "flash"
+    assert memory_pools.pools[0].size_hint_bytes == 2048
+    assert memory_pools.pools[0].clock_frequency_hz == 2000000
+    assert memory_pools.pools[0].read_bandwidth_bytes_per_cycle == 4
+    assert memory_pools.pools[0].write_bandwidth_bytes_per_cycle == 1
+    assert memory_pools.pools[0].read_latency_cycles == 2000
+    assert memory_pools.pools[0].write_latency_cycles == 1000
+    assert len(memory_pools.pools[0].target_burst_bytes) == 1
+    assert memory_pools.pools[0].target_burst_bytes[targets[0]] == 4
+
+
+def test_workspace_pools_recombobulate_multi_fields_multi_pools():
+    """Attributes are routed to the right pool when several pools are configured."""
+    parser = argparse.ArgumentParser()
+    generate_workspace_pools_args(parser)
+    parsed, _ = parser.parse_known_args(
+        [
+            "--workspace-pools=sram,flash",
+            "--workspace-pools-targets=sram:c",
+            "--workspace-pools-targets=flash:c",
+            "--workspace-pools-size-hint-bytes=sram:1024",
+            "--workspace-pools-size-hint-bytes=flash:2048",
+            "--workspace-pools-clock-frequency-hz=sram:4000000",
+            "--workspace-pools-clock-frequency-hz=flash:2000000",
+            "--workspace-pools-read-bandwidth-bytes-per-cycle=sram:8",
+            "--workspace-pools-read-bandwidth-bytes-per-cycle=flash:4",
+            "--workspace-pools-write-bandwidth-bytes-per-cycle=sram:4",
+            "--workspace-pools-write-bandwidth-bytes-per-cycle=flash:1",
+            "--workspace-pools-read-latency-cycles=sram:250",
+            "--workspace-pools-read-latency-cycles=flash:2000",
+            "--workspace-pools-write-latency-cycles=sram:500",
+            "--workspace-pools-write-latency-cycles=flash:1000",
+            "--workspace-pools-target-burst-bytes=sram:c:8",
+            "--workspace-pools-target-burst-bytes=flash:c:4",
+        ]
+    )
+
+    targets = [Target("c")]
+    extra_targets = []
+    memory_pools = workspace_pools_recombobulate(parsed, targets, extra_targets)
+    assert len(memory_pools.pools) == 2
+
+    assert memory_pools.pools[0].pool_name == "sram"
+    assert memory_pools.pools[0].size_hint_bytes == 1024
+    assert memory_pools.pools[0].clock_frequency_hz == 4000000
+    assert memory_pools.pools[0].read_bandwidth_bytes_per_cycle == 8
+    assert memory_pools.pools[0].write_bandwidth_bytes_per_cycle == 4
+    assert memory_pools.pools[0].read_latency_cycles == 250
+    assert memory_pools.pools[0].write_latency_cycles == 500
+    assert len(memory_pools.pools[0].target_burst_bytes) == 1
+    assert memory_pools.pools[0].target_burst_bytes[targets[0]] == 8
+
+    assert memory_pools.pools[1].pool_name == "flash"
+    assert memory_pools.pools[1].size_hint_bytes == 2048
+    assert memory_pools.pools[1].clock_frequency_hz == 2000000
+    assert memory_pools.pools[1].read_bandwidth_bytes_per_cycle == 4
+    assert memory_pools.pools[1].write_bandwidth_bytes_per_cycle == 1
+    assert memory_pools.pools[1].read_latency_cycles == 2000
+    assert memory_pools.pools[1].write_latency_cycles == 1000
+    assert len(memory_pools.pools[1].target_burst_bytes) == 1
+    assert memory_pools.pools[1].target_burst_bytes[targets[0]] == 4
+
+
+def test_workspace_pools_recombobulate_multi_fields_ordering():
+    """Pool order follows --workspace-pools, not the order attributes are given in."""
+    parser = argparse.ArgumentParser()
+    generate_workspace_pools_args(parser)
+    parsed, _ = parser.parse_known_args(
+        [
+            "--workspace-pools=sram,flash",
+            "--workspace-pools-targets=flash:c",
+            "--workspace-pools-targets=sram:c",
+            "--workspace-pools-size-hint-bytes=flash:2048",
+            "--workspace-pools-size-hint-bytes=sram:1024",
+            "--workspace-pools-clock-frequency-hz=sram:4000000",
+            "--workspace-pools-clock-frequency-hz=flash:2000000",
+            "--workspace-pools-read-bandwidth-bytes-per-cycle=sram:8",
+            "--workspace-pools-read-bandwidth-bytes-per-cycle=flash:4",
+            "--workspace-pools-write-bandwidth-bytes-per-cycle=sram:4",
+            "--workspace-pools-write-bandwidth-bytes-per-cycle=flash:1",
+            "--workspace-pools-read-latency-cycles=sram:250",
+            "--workspace-pools-read-latency-cycles=flash:2000",
+            "--workspace-pools-write-latency-cycles=flash:1000",
+            "--workspace-pools-write-latency-cycles=sram:500",
+            "--workspace-pools-target-burst-bytes=sram:c:8",
+            "--workspace-pools-target-burst-bytes=flash:c:4",
+        ]
+    )
+
+    targets = [Target("c")]
+    extra_targets = []
+    memory_pools = workspace_pools_recombobulate(parsed, targets, extra_targets)
+    assert len(memory_pools.pools) == 2
+
+    assert memory_pools.pools[0].pool_name == "sram"
+    assert memory_pools.pools[0].size_hint_bytes == 1024
+    assert memory_pools.pools[0].write_latency_cycles == 500
+
+    assert memory_pools.pools[1].pool_name == "flash"
+    assert memory_pools.pools[1].size_hint_bytes == 2048
+    assert memory_pools.pools[1].write_latency_cycles == 1000
+
+
+def test_workspace_pools_recombobulate_multi_target():
+    """One pool may span several targets, each with its own burst size."""
+    arg_parser = argparse.ArgumentParser()
+    generate_workspace_pools_args(arg_parser)
+    parsed, _ = arg_parser.parse_known_args(
+        [
+            "--workspace-pools=sram",
+            "--workspace-pools-targets=sram:c,llvm",
+            "--workspace-pools-target-burst-bytes=sram:c:8",
+            "--workspace-pools-target-burst-bytes=sram:llvm:4",
+        ]
+    )
+
+    c_target, llvm_target = Target("c"), Target("llvm")
+    memory_pools = workspace_pools_recombobulate(parsed, [c_target, llvm_target], [])
+
+    assert len(memory_pools.pools) == 1
+
+    burst = memory_pools.pools[0].target_burst_bytes
+    assert len(burst) == 2
+    assert burst[c_target] == 8
+    assert burst[llvm_target] == 4
+
+
+def test_workspace_pools_recombobulate_no_target_burst_bytes():
+    """A pool given no --workspace-pools-target-burst-bytes has an empty burst map."""
+    parser = argparse.ArgumentParser()
+    generate_workspace_pools_args(parser)
+    parsed, _ = parser.parse_known_args(
+        [
+            "--workspace-pools=sram",
+            "--workspace-pools-targets=sram:c",
+        ]
+    )
+
+    c_target = Target("c")
+    targets = [c_target]
+    extra_targets = []
+
+    memory_pools = workspace_pools_recombobulate(parsed, targets, extra_targets)
+
+    assert len(memory_pools.pools) == 1
+    assert len(memory_pools.pools[0].target_burst_bytes) == 0
+
+
+def test_workspace_pools_recombobulate_missing_target():
+    """A pool without --workspace-pools-targets is rejected with a clear error."""
+    parser = argparse.ArgumentParser()
+    generate_workspace_pools_args(parser)
+    parsed, _ = parser.parse_known_args(
+        [
+            "--workspace-pools=sram",
+        ]
+    )
+
+    c_target = Target("c")
+    extra_targets = []
+    with pytest.raises(TVMCException, match="No target specified"):
+        workspace_pools_recombobulate(parsed, [c_target], extra_targets)
+
+
+def test_workspace_pools_recombobulate_multi_target_multi_pool():
+    """Per-target burst sizes are kept separate for each of several pools."""
+    parser = argparse.ArgumentParser()
+    generate_workspace_pools_args(parser)
+    parsed, _ = parser.parse_known_args(
+        [
+            "--workspace-pools=sram,flash",
+            "--workspace-pools-targets=sram:c,llvm",
+            "--workspace-pools-targets=flash:llvm",
+            "--workspace-pools-target-burst-bytes=sram:c:8",
+            "--workspace-pools-target-burst-bytes=sram:llvm:4",
+            "--workspace-pools-target-burst-bytes=flash:llvm:2",
+        ]
+    )
+
+    c_target = Target("c")
+    llvm_target = Target("llvm")
+
+    targets = [c_target, llvm_target]
+    extra_targets = []
+    memory_pools = workspace_pools_recombobulate(parsed, targets, extra_targets)
+
+    assert len(memory_pools.pools) == 2
+
+    sram = memory_pools.pools[0]
+    assert sram.pool_name == "sram"
+    assert len(sram.targets) == 2
+    assert len(sram.target_burst_bytes) == 2
+    assert sram.target_burst_bytes[llvm_target] == 4
+    assert sram.target_burst_bytes[c_target] == 8
+
+    flash = memory_pools.pools[1]
+    assert flash.pool_name == "flash"
+    assert len(flash.targets) == 1
+    assert len(flash.target_burst_bytes) == 1
+    assert flash.target_burst_bytes[llvm_target] == 2
+
+
+def test_workspace_pools_recombobulate_parameter_overrides():
+    """When an attribute is repeated for the same pool, the last value wins."""
+    parser = argparse.ArgumentParser()
+    generate_workspace_pools_args(parser)
+    parsed, _ = parser.parse_known_args(
+        [
+            "--workspace-pools=sram",
+            "--workspace-pools-targets=sram:c",
+            "--workspace-pools-size-hint-bytes=sram:800",
+            "--workspace-pools-size-hint-bytes=sram:400",
+            "--workspace-pools-clock-frequency-hz=sram:4000000",
+            "--workspace-pools-clock-frequency-hz=sram:3600000",
+        ]
+    )
+
+    c_target = Target("c")
+
+    targets = [c_target]
+    extra_targets = []
+    memory_pools = workspace_pools_recombobulate(parsed, targets, extra_targets)
+
+    assert len(memory_pools.pools) == 1
+
+    assert memory_pools.pools[0].size_hint_bytes == 400
+    assert memory_pools.pools[0].clock_frequency_hz == 3600000
+
+
+def test_workspace_pools_recombobulate_single_pool_overrides():
+    """Overriding one pool's attributes leaves the other pools untouched."""
+    parser = argparse.ArgumentParser()
+    generate_workspace_pools_args(parser)
+    parsed, _ = parser.parse_known_args(
+        [
+            "--workspace-pools=sram,flash",
+            "--workspace-pools-targets=sram:c",
+            "--workspace-pools-targets=flash:c",
+            "--workspace-pools-targets=sram:c,llvm",  # Override on one pool
+            "--workspace-pools-size-hint-bytes=sram:800",
+            "--workspace-pools-size-hint-bytes=flash:1200",
+            "--workspace-pools-size-hint-bytes=sram:400",  # Override on one pool
+        ]
+    )
+
+    c_target = Target("c")
+    llvm_target = Target("llvm")
+
+    targets = [c_target, llvm_target]
+    extra_targets = []
+    memory_pools = workspace_pools_recombobulate(parsed, targets, extra_targets)
+
+    assert len(memory_pools.pools) == 2
+
+    assert memory_pools.pools[0].size_hint_bytes == 400
+    assert memory_pools.pools[1].size_hint_bytes == 1200
+
+    assert len(memory_pools.pools[0].targets) == 2
+    assert len(memory_pools.pools[1].targets) == 1