You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ar...@apache.org on 2022/07/05 17:36:33 UTC
[tvm] branch main updated: [microTVM] Autotuning performance tests (#11782)
This is an automated email from the ASF dual-hosted git repository.
areusch pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 3cca6465ba [microTVM] Autotuning performance tests (#11782)
3cca6465ba is described below
commit 3cca6465ba685921a2f5dbe711d10f5b5ee33d33
Author: Gavin Uberti <gu...@users.noreply.github.com>
AuthorDate: Tue Jul 5 10:36:26 2022 -0700
[microTVM] Autotuning performance tests (#11782)
* Common autotuning test
* Autotuned model evaluation utilities
* Bugfixes and more enablement
* Working autotune profiling test
* Refactoring based on PR comments
Bugfixes to get tests passing
Refactor to remove tflite model for consistency
Black formatting
Linting and bugfixes
Add Apache license header
Use larger chunk size to read files
Explicitly specify LRU cache size for compatibility with Python 3.7
Pass platform to microTVM common tests
Better comment for runtime bound
Stop directory from being removed after session creation
* Use the actual Zephyr timing library
Use unsigned integer
Additional logging
Try negation
Try 64 bit timer
Use Zephyr's timing library
Fix linting
Enable timing utilities
---
.../zephyr/template_project/microtvm_api_server.py | 1 +
.../zephyr/template_project/src/host_driven/main.c | 52 ++-----
python/tvm/micro/testing/__init__.py | 20 +++
python/tvm/micro/testing/aot_test_utils.py | 13 +-
python/tvm/micro/testing/evaluation.py | 150 +++++++++++++++++++++
python/tvm/micro/testing/utils.py | 19 ++-
python/tvm/testing/utils.py | 45 ++++++-
tests/lint/check_file_type.py | 1 -
tests/micro/arduino/test_utils.py | 20 +--
tests/micro/common/conftest.py | 13 +-
tests/micro/common/test_autotune.py | 96 +++++++++++++
tests/micro/common/test_tvmc.py | 27 +---
tests/micro/testdata/kws/yes_no.tflite | Bin 18712 -> 0 bytes
tests/scripts/task_python_microtvm.sh | 4 +-
14 files changed, 373 insertions(+), 88 deletions(-)
diff --git a/apps/microtvm/zephyr/template_project/microtvm_api_server.py b/apps/microtvm/zephyr/template_project/microtvm_api_server.py
index d3559cc5f7..7b9538f6ce 100644
--- a/apps/microtvm/zephyr/template_project/microtvm_api_server.py
+++ b/apps/microtvm/zephyr/template_project/microtvm_api_server.py
@@ -393,6 +393,7 @@ class Handler(server.ProjectAPIHandler):
if options["project_type"] == "host_driven":
f.write(
+ "CONFIG_TIMING_FUNCTIONS=y\n"
"# For RPC server C++ bindings.\n"
"CONFIG_CPLUSPLUS=y\n"
"CONFIG_LIB_CPLUSPLUS=y\n"
diff --git a/apps/microtvm/zephyr/template_project/src/host_driven/main.c b/apps/microtvm/zephyr/template_project/src/host_driven/main.c
index 623266c0ca..ff02b3cb1d 100644
--- a/apps/microtvm/zephyr/template_project/src/host_driven/main.c
+++ b/apps/microtvm/zephyr/template_project/src/host_driven/main.c
@@ -38,6 +38,7 @@
#include <sys/printk.h>
#include <sys/reboot.h>
#include <sys/ring_buffer.h>
+#include <timing/timing.h>
#include <tvm/runtime/crt/logging.h>
#include <tvm/runtime/crt/microtvm_rpc_server.h>
#include <unistd.h>
@@ -144,11 +145,7 @@ tvm_crt_error_t TVMPlatformMemoryFree(void* ptr, DLDevice dev) {
return kTvmErrorNoError;
}
-#define MILLIS_TIL_EXPIRY 200
-#define TIME_TIL_EXPIRY (K_MSEC(MILLIS_TIL_EXPIRY))
-K_TIMER_DEFINE(g_microtvm_timer, /* expiry func */ NULL, /* stop func */ NULL);
-
-uint32_t g_microtvm_start_time;
+volatile timing_t g_microtvm_start_time, g_microtvm_end_time;
int g_microtvm_timer_running = 0;
// Called to start system timer.
@@ -161,8 +158,7 @@ tvm_crt_error_t TVMPlatformTimerStart() {
#ifdef CONFIG_LED
gpio_pin_set(led0_pin, LED0_PIN, 1);
#endif
- k_timer_start(&g_microtvm_timer, TIME_TIL_EXPIRY, TIME_TIL_EXPIRY);
- g_microtvm_start_time = k_cycle_get_32();
+ g_microtvm_start_time = timing_counter_get();
g_microtvm_timer_running = 1;
return kTvmErrorNoError;
}
@@ -174,43 +170,14 @@ tvm_crt_error_t TVMPlatformTimerStop(double* elapsed_time_seconds) {
return kTvmErrorSystemErrorMask | 2;
}
- uint32_t stop_time = k_cycle_get_32();
#ifdef CONFIG_LED
gpio_pin_set(led0_pin, LED0_PIN, 0);
#endif
- // compute how long the work took
- uint32_t cycles_spent = stop_time - g_microtvm_start_time;
- if (stop_time < g_microtvm_start_time) {
- // we rolled over *at least* once, so correct the rollover it was *only*
- // once, because we might still use this result
- cycles_spent = ~((uint32_t)0) - (g_microtvm_start_time - stop_time);
- }
-
- uint32_t ns_spent = (uint32_t)k_cyc_to_ns_floor64(cycles_spent);
- double hw_clock_res_us = ns_spent / 1000.0;
-
- // need to grab time remaining *before* stopping. when stopped, this function
- // always returns 0.
- int32_t time_remaining_ms = k_timer_remaining_get(&g_microtvm_timer);
- k_timer_stop(&g_microtvm_timer);
- // check *after* stopping to prevent extra expiries on the happy path
- if (time_remaining_ms < 0) {
- TVMLogf("negative time remaining");
- return kTvmErrorSystemErrorMask | 3;
- }
- uint32_t num_expiries = k_timer_status_get(&g_microtvm_timer);
- uint32_t timer_res_ms = ((num_expiries * MILLIS_TIL_EXPIRY) + time_remaining_ms);
- double approx_num_cycles =
- (double)k_ticks_to_cyc_floor32(1) * (double)k_ms_to_ticks_ceil32(timer_res_ms);
- // if we approach the limits of the HW clock datatype (uint32_t), use the
- // coarse-grained timer result instead
- if (approx_num_cycles > (0.5 * (~((uint32_t)0)))) {
- *elapsed_time_seconds = timer_res_ms / 1000.0;
- } else {
- *elapsed_time_seconds = hw_clock_res_us / 1e6;
- }
-
+ g_microtvm_end_time = timing_counter_get();
+ uint64_t cycles = timing_cycles_get(&g_microtvm_start_time, &g_microtvm_end_time);
+ uint64_t ns_spent = timing_cycles_to_ns(cycles);
+ *elapsed_time_seconds = ns_spent / (double)1e9;
g_microtvm_timer_running = 0;
return kTvmErrorNoError;
}
@@ -278,6 +245,11 @@ void main(void) {
tvm_uart = device_get_binding(DT_LABEL(DT_CHOSEN(zephyr_console)));
uart_rx_init(&uart_rx_rbuf, tvm_uart);
+ // Initialize system timing. We could stop and start it every time, but we'll
+ // be using it enough we should just keep it enabled.
+ timing_init();
+ timing_start();
+
// Initialize microTVM RPC server, which will receive commands from the UART and execute them.
microtvm_rpc_server_t server = MicroTVMRpcServerInit(write_serial, NULL);
TVMLogf("microTVM Zephyr runtime - running");
diff --git a/python/tvm/micro/testing/__init__.py b/python/tvm/micro/testing/__init__.py
new file mode 100644
index 0000000000..9062f061bd
--- /dev/null
+++ b/python/tvm/micro/testing/__init__.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Allows the tools specified below to be imported directly from tvm.micro.testing"""
+from .evaluation import tune_model, create_aot_session, evaluate_model_accuracy
+from .utils import get_supported_boards, get_target
diff --git a/python/tvm/micro/testing/aot_test_utils.py b/python/tvm/micro/testing/aot_test_utils.py
index 82ac1ac68e..89c08395de 100644
--- a/python/tvm/micro/testing/aot_test_utils.py
+++ b/python/tvm/micro/testing/aot_test_utils.py
@@ -15,17 +15,22 @@
# specific language governing permissions and limitations
# under the License.
+"""
+This file provides utilities for running AOT tests, especially for Corstone.
+
+"""
+
import logging
import itertools
import shutil
import pytest
-pytest.importorskip("tvm.micro")
-
import tvm
from tvm.testing.aot import AOTTestRunner
+pytest.importorskip("tvm.micro")
+
_LOG = logging.getLogger(__name__)
@@ -97,9 +102,9 @@ def parametrize_aot_options(test):
valid_combinations,
)
- fn = pytest.mark.parametrize(
+ func = pytest.mark.parametrize(
["interface_api", "use_unpacked_api", "test_runner"],
marked_combinations,
)(test)
- return tvm.testing.skip_if_32bit(reason="Reference system unavailable in i386 container")(fn)
+ return tvm.testing.skip_if_32bit(reason="Reference system unavailable in i386 container")(func)
diff --git a/python/tvm/micro/testing/evaluation.py b/python/tvm/micro/testing/evaluation.py
new file mode 100644
index 0000000000..c60f0fc482
--- /dev/null
+++ b/python/tvm/micro/testing/evaluation.py
@@ -0,0 +1,150 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Provides high-level functions for instantiating and timing AOT models. Used
+by autotuning tests in tests/micro, and may be used for more performance
+tests in the future.
+
+"""
+
+from io import StringIO
+from pathlib import Path
+from contextlib import ExitStack
+import tempfile
+
+import tvm
+
+
+def tune_model(
+ platform, board, target, mod, params, num_trials, tuner_cls=tvm.autotvm.tuner.GATuner
+):
+ """Autotunes a model with microTVM and returns a StringIO with the tuning logs"""
+ with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
+ tasks = tvm.autotvm.task.extract_from_program(mod["main"], {}, target)
+ assert len(tasks) > 0
+ assert isinstance(params, dict)
+
+ module_loader = tvm.micro.AutoTvmModuleLoader(
+ template_project_dir=tvm.micro.get_microtvm_template_projects(platform),
+ project_options={
+ f"{platform}_board": board,
+ "project_type": "host_driven",
+ },
+ )
+
+ builder = tvm.autotvm.LocalBuilder(
+ n_parallel=1,
+ build_kwargs={"build_option": {"tir.disable_vectorize": True}},
+ do_fork=False,
+ build_func=tvm.micro.autotvm_build_func,
+ runtime=tvm.relay.backend.Runtime("crt", {"system-lib": True}),
+ )
+ runner = tvm.autotvm.LocalRunner(number=1, repeat=1, timeout=100, module_loader=module_loader)
+ measure_option = tvm.autotvm.measure_option(builder=builder, runner=runner)
+
+ results = StringIO()
+ for task in tasks:
+ tuner = tuner_cls(task)
+
+ tuner.tune(
+ n_trial=num_trials,
+ measure_option=measure_option,
+ callbacks=[
+ tvm.autotvm.callback.log_to_file(results),
+ tvm.autotvm.callback.progress_bar(num_trials, si_prefix="M"),
+ ],
+ si_prefix="M",
+ )
+ assert tuner.best_flops > 1
+
+ return results
+
+
+def create_aot_session(
+ platform,
+ board,
+ target,
+ mod,
+ params,
+ build_dir=Path(tempfile.mkdtemp()),
+ tune_logs=None,
+ use_cmsis_nn=False,
+):
+ """AOT-compiles and uploads a model to a microcontroller, and returns the RPC session"""
+
+ executor = tvm.relay.backend.Executor("aot")
+ crt_runtime = tvm.relay.backend.Runtime("crt", {"system-lib": True})
+
+ with ExitStack() as stack:
+ config = {"tir.disable_vectorize": True}
+ if use_cmsis_nn:
+ config["relay.ext.cmsisnn.options"] = {"mcpu": target.mcpu}
+ stack.enter_context(tvm.transform.PassContext(opt_level=3, config=config))
+ if tune_logs is not None:
+ stack.enter_context(tvm.autotvm.apply_history_best(tune_logs))
+
+ lowered = tvm.relay.build(
+ mod,
+ target=target,
+ params=params,
+ runtime=crt_runtime,
+ executor=executor,
+ )
+ parameter_size = len(tvm.runtime.save_param_dict(lowered.get_params()))
+ print(f"Model parameter size: {parameter_size}")
+
+ # Once the project has been uploaded, we don't need to keep it
+ project = tvm.micro.generate_project(
+ str(tvm.micro.get_microtvm_template_projects(platform)),
+ lowered,
+ build_dir / "project",
+ {
+ f"{platform}_board": board,
+ "project_type": "host_driven",
+ },
+ )
+ project.build()
+ project.flash()
+
+ return tvm.micro.Session(project.transport())
+
+
+# This utility function was designed ONLY for one input / one output models
+# where the outputs are confidences for different classes.
+def evaluate_model_accuracy(session, aot_executor, input_data, true_labels, runs_per_sample=1):
+ """Evaluates an AOT-compiled model's accuracy and runtime over an RPC session. Works well
+ when used with create_aot_session."""
+
+ assert aot_executor.get_num_inputs() == 1
+ assert aot_executor.get_num_outputs() == 1
+ assert runs_per_sample > 0
+
+ predicted_labels = []
+ aot_runtimes = []
+ for sample in input_data:
+ aot_executor.get_input(0).copyfrom(sample)
+ result = aot_executor.module.time_evaluator("run", session.device, number=runs_per_sample)()
+ runtime = result.mean
+ output = aot_executor.get_output(0).numpy()
+ predicted_labels.append(output.argmax())
+ aot_runtimes.append(runtime)
+
+ num_correct = sum(u == v for u, v in zip(true_labels, predicted_labels))
+ average_time = sum(aot_runtimes) / len(aot_runtimes)
+ accuracy = num_correct / len(predicted_labels)
+ return average_time, accuracy
diff --git a/python/tvm/micro/testing/utils.py b/python/tvm/micro/testing/utils.py
index a48c8dc323..820b649c74 100644
--- a/python/tvm/micro/testing/utils.py
+++ b/python/tvm/micro/testing/utils.py
@@ -17,9 +17,10 @@
"""Defines the test methods used with microTVM."""
-import pathlib
+from functools import lru_cache
import json
import logging
+from pathlib import Path
import tarfile
import time
from typing import Union
@@ -32,7 +33,19 @@ from tvm.micro.project_api.server import IoTimeoutError
TIMEOUT_SEC = 10
-def check_tune_log(log_path: Union[pathlib.Path, str]):
+@lru_cache(maxsize=None)
+def get_supported_boards(platform: str):
+ template = Path(tvm.micro.get_microtvm_template_projects(platform))
+ with open(template / "boards.json") as f:
+ return json.load(f)
+
+
+def get_target(platform: str, board: str):
+ model = get_supported_boards(platform)[board]["model"]
+ return str(tvm.target.target.micro(model))
+
+
+def check_tune_log(log_path: Union[Path, str]):
"""Read the tuning log and check each result."""
with open(log_path, "r") as f:
lines = f.readlines()
@@ -76,7 +89,7 @@ def _read_line(transport, timeout_sec: int) -> str:
return data.decode(encoding="utf-8")
-def mlf_extract_workspace_size_bytes(mlf_tar_path: Union[pathlib.Path, str]) -> int:
+def mlf_extract_workspace_size_bytes(mlf_tar_path: Union[Path, str]) -> int:
"""Extract an MLF archive file and read workspace size from metadata file."""
workspace_size = 0
diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py
index d7c2adaa86..47bdab5828 100644
--- a/python/tvm/testing/utils.py
+++ b/python/tvm/testing/utils.py
@@ -67,6 +67,7 @@ import copy
import copyreg
import ctypes
import functools
+import hashlib
import itertools
import logging
import os
@@ -77,7 +78,7 @@ import sys
import time
from pathlib import Path
-from typing import Optional, Callable, Union, List
+from typing import Optional, Callable, Union, List, Tuple
import pytest
import numpy as np
@@ -90,6 +91,7 @@ import tvm._ffi
from tvm.contrib import nvcc, cudnn
import tvm.contrib.hexagon._ci_env_check as hexagon
+from tvm.driver.tvmc.frontends import load_model
from tvm.error import TVMError
@@ -1661,6 +1663,47 @@ def install_request_hook(depth: int) -> None:
request_hook.init()
+def fetch_model_from_url(
+ url: str,
+ model_format: str,
+ sha256: str,
+) -> Tuple[tvm.ir.module.IRModule, dict]:
+ """Testing function to fetch a model from a URL and return it as a Relay
+ model. Downloaded files are cached for future re-use.
+
+ Parameters
+ ----------
+ url : str
+ The URL or list of URLs to try downloading the model from.
+
+ model_format: str
+ The file extension of the model format used.
+
+ sha256 : str
+ The sha256 hex hash to compare the downloaded model against.
+
+ Returns
+ -------
+ (mod, params) : object
+ The Relay representation of the downloaded model.
+ """
+
+ rel_path = f"model_{sha256}.{model_format}"
+ file = tvm.contrib.download.download_testdata(url, rel_path, overwrite=False)
+
+ # Check SHA-256 hash
+ file_hash = hashlib.sha256()
+ with open(file, "rb") as f:
+ for block in iter(lambda: f.read(2**24), b""):
+ file_hash.update(block)
+
+ if file_hash.hexdigest() != sha256:
+ raise FileNotFoundError("SHA-256 hash for model does not match")
+
+ tvmc_model = load_model(file, model_format)
+ return tvmc_model.mod, tvmc_model.params
+
+
def main():
test_file = inspect.getsourcefile(sys._getframe(1))
sys.exit(pytest.main([test_file] + sys.argv[1:]))
diff --git a/tests/lint/check_file_type.py b/tests/lint/check_file_type.py
index d26b047e81..37b64433b2 100644
--- a/tests/lint/check_file_type.py
+++ b/tests/lint/check_file_type.py
@@ -140,7 +140,6 @@ ALLOW_SPECIFIC_FILE = {
"tests/micro/testdata/mnist/digit-2.jpg",
"tests/micro/testdata/mnist/digit-9.jpg",
"tests/micro/testdata/mnist/mnist-8.onnx",
- "tests/micro/testdata/kws/yes_no.tflite",
# microTVM Zephyr runtime
"apps/microtvm/zephyr/template_project/CMakeLists.txt.template",
"apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-arm",
diff --git a/tests/micro/arduino/test_utils.py b/tests/micro/arduino/test_utils.py
index c107d5b1fe..20e7d9e750 100644
--- a/tests/micro/arduino/test_utils.py
+++ b/tests/micro/arduino/test_utils.py
@@ -25,7 +25,7 @@ import tvm.target.target
from tvm.micro import project
from tvm import relay
from tvm.relay.backend import Executor, Runtime
-
+from tvm.testing.utils import fetch_model_from_url
TEMPLATE_PROJECT_DIR = pathlib.Path(tvm.micro.get_microtvm_template_projects("arduino"))
@@ -66,20 +66,12 @@ def make_kws_project(board, arduino_cli_cmd, tvm_debug, workspace_dir):
model = ARDUINO_BOARDS[board]
build_config = {"debug": tvm_debug}
- with open(this_dir.parent / "testdata" / "kws" / "yes_no.tflite", "rb") as f:
- tflite_model_buf = f.read()
-
- # TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1
- try:
- import tflite.Model
-
- tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)
- except AttributeError:
- import tflite
-
- tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0)
+ mod, params = fetch_model_from_url(
+ url="https://github.com/tensorflow/tflite-micro/raw/main/tensorflow/lite/micro/examples/micro_speech/micro_speech.tflite",
+ model_format="tflite",
+ sha256="09e5e2a9dfb2d8ed78802bf18ce297bff54281a66ca18e0c23d69ca14f822a83",
+ )
- mod, params = relay.frontend.from_tflite(tflite_model)
target = tvm.target.target.micro(model)
runtime = Runtime("crt")
executor = Executor("aot", {"unpacked-api": True})
diff --git a/tests/micro/common/conftest.py b/tests/micro/common/conftest.py
index 3fbfdbcbc8..10dda8774b 100644
--- a/tests/micro/common/conftest.py
+++ b/tests/micro/common/conftest.py
@@ -21,11 +21,17 @@ from ..arduino.test_utils import ARDUINO_BOARDS
def pytest_addoption(parser):
+ parser.addoption(
+ "--platform",
+ required=True,
+ choices=["arduino", "zephyr"],
+ help="Platform to run tests with",
+ )
parser.addoption(
"--board",
required=True,
choices=list(ARDUINO_BOARDS.keys()) + list(ZEPHYR_BOARDS.keys()),
- help="microTVM boards for tests.",
+ help="microTVM boards for tests",
)
parser.addoption(
"--test-build-only",
@@ -34,6 +40,11 @@ def pytest_addoption(parser):
)
+@pytest.fixture
+def platform(request):
+ return request.config.getoption("--platform")
+
+
@pytest.fixture
def board(request):
return request.config.getoption("--board")
diff --git a/tests/micro/common/test_autotune.py b/tests/micro/common/test_autotune.py
new file mode 100644
index 0000000000..37836563a0
--- /dev/null
+++ b/tests/micro/common/test_autotune.py
@@ -0,0 +1,96 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from io import StringIO
+import json
+from pathlib import Path
+import sys
+import tempfile
+from typing import Union
+
+import numpy as np
+import pytest
+
+import tvm
+import tvm.testing
+import tvm.micro.testing
+from tvm.testing.utils import fetch_model_from_url
+
+TUNING_RUNS_PER_OPERATOR = 2
+
+
+@pytest.mark.requires_hardware
+@tvm.testing.requires_micro
+def test_kws_autotune_workflow(platform, board, tmp_path):
+ mod, params = fetch_model_from_url(
+ url="https://github.com/tensorflow/tflite-micro/raw/main/tensorflow/lite/micro/examples/micro_speech/micro_speech.tflite",
+ model_format="tflite",
+ sha256="09e5e2a9dfb2d8ed78802bf18ce297bff54281a66ca18e0c23d69ca14f822a83",
+ )
+ target = tvm.micro.testing.get_target(platform, board)
+
+ str_io_logs = tvm.micro.testing.tune_model(
+ platform, board, target, mod, params, TUNING_RUNS_PER_OPERATOR
+ )
+ assert isinstance(str_io_logs, StringIO)
+
+ str_logs = str_io_logs.getvalue().rstrip().split("\n")
+ logs = list(map(json.loads, str_logs))
+ assert len(logs) == 2 * TUNING_RUNS_PER_OPERATOR # Two operators
+
+ # Check we tested both operators
+ op_names = list(map(lambda x: x["input"][1], logs))
+ assert op_names[0] == op_names[1] == "dense_nopack.x86"
+ assert op_names[2] == op_names[3] == "dense_pack.x86"
+
+ # Make sure we tested different code. != does deep comparison in Python 3
+ assert logs[0]["config"]["index"] != logs[1]["config"]["index"]
+ assert logs[0]["config"]["entity"] != logs[1]["config"]["entity"]
+ assert logs[2]["config"]["index"] != logs[3]["config"]["index"]
+ assert logs[2]["config"]["entity"] != logs[3]["config"]["entity"]
+
+ # Compile the best model with AOT and connect to it
+ with tvm.micro.testing.create_aot_session(
+ platform,
+ board,
+ target,
+ mod,
+ params,
+ build_dir=tmp_path,
+ tune_logs=str_io_logs,
+ ) as session:
+ aot_executor = tvm.runtime.executor.aot_executor.AotModule(session.create_aot_executor())
+
+ samples = (
+ np.random.randint(low=-127, high=128, size=(1, 1960), dtype=np.int8) for x in range(3)
+ )
+
+ labels = [0, 0, 0]
+
+ # Validate performance across random runs
+ time, acc = tvm.micro.testing.evaluate_model_accuracy(
+ session, aot_executor, samples, labels, runs_per_sample=20
+ )
+ # `time` is the average time taken to execute model inference on the
+ # device, measured in seconds. It does not include the time to upload
+ # the input data via RPC. On slow boards like the Arduino Due, time
+ # is around 0.12 (120 ms), so this gives us plenty of buffer.
+ assert time < 1
+
+
+if __name__ == "__main__":
+ tvm.testing.main()
diff --git a/tests/micro/common/test_tvmc.py b/tests/micro/common/test_tvmc.py
index 24d0213b77..096e12393d 100644
--- a/tests/micro/common/test_tvmc.py
+++ b/tests/micro/common/test_tvmc.py
@@ -29,9 +29,6 @@ import tvm
import tvm.testing
from tvm.contrib.download import download_testdata
-from ..zephyr.test_utils import ZEPHYR_BOARDS
-from ..arduino.test_utils import ARDUINO_BOARDS
-
TVMC_COMMAND = [sys.executable, "-m", "tvm.driver.tvmc"]
MODEL_URL = "https://github.com/tensorflow/tflite-micro/raw/main/tensorflow/lite/micro/examples/micro_speech/micro_speech.tflite"
@@ -47,22 +44,8 @@ def _run_tvmc(cmd_args: list, *args, **kwargs):
return subprocess.check_call(cmd_args_list, *args, **kwargs)
-def _get_target_and_platform(board: str):
- if board in ZEPHYR_BOARDS.keys():
- target_model = ZEPHYR_BOARDS[board]
- platform = "zephyr"
- elif board in ARDUINO_BOARDS.keys():
- target_model = ARDUINO_BOARDS[board]
- platform = "arduino"
- else:
- raise ValueError(f"Board {board} is not supported.")
-
- target = tvm.target.target.micro(target_model)
- return str(target), platform
-
-
@tvm.testing.requires_micro
-def test_tvmc_exist(board):
+def test_tvmc_exist(platform, board):
cmd_result = _run_tvmc(["micro", "-h"])
assert cmd_result == 0
@@ -72,8 +55,8 @@ def test_tvmc_exist(board):
"output_dir,",
[pathlib.Path("./tvmc_relative_path_test"), pathlib.Path(tempfile.mkdtemp())],
)
-def test_tvmc_model_build_only(board, output_dir):
- target, platform = _get_target_and_platform(board)
+def test_tvmc_model_build_only(platform, board, output_dir):
+ target = tvm.micro.testing.get_target(platform, board)
if not os.path.isabs(output_dir):
out_dir_temp = os.path.abspath(output_dir)
@@ -138,8 +121,8 @@ def test_tvmc_model_build_only(board, output_dir):
"output_dir,",
[pathlib.Path("./tvmc_relative_path_test"), pathlib.Path(tempfile.mkdtemp())],
)
-def test_tvmc_model_run(board, output_dir):
- target, platform = _get_target_and_platform(board)
+def test_tvmc_model_run(platform, board, output_dir):
+ target = tvm.micro.testing.get_target(platform, board)
if not os.path.isabs(output_dir):
out_dir_temp = os.path.abspath(output_dir)
diff --git a/tests/micro/testdata/kws/yes_no.tflite b/tests/micro/testdata/kws/yes_no.tflite
deleted file mode 100644
index 4f533dac84..0000000000
Binary files a/tests/micro/testdata/kws/yes_no.tflite and /dev/null differ
diff --git a/tests/scripts/task_python_microtvm.sh b/tests/scripts/task_python_microtvm.sh
index 2274c6ca6b..e057883776 100755
--- a/tests/scripts/task_python_microtvm.sh
+++ b/tests/scripts/task_python_microtvm.sh
@@ -38,8 +38,8 @@ run_pytest ctypes python-microtvm-arduino-due tests/micro/arduino --test-build-
run_pytest ctypes python-microtvm-stm32 tests/micro/stm32
# Common Tests
-run_pytest ctypes python-microtvm-common-qemu_x86 tests/micro/common --board=qemu_x86
-run_pytest ctypes python-microtvm-common-due tests/micro/common --test-build-only --board=due
+run_pytest ctypes python-microtvm-common-qemu_x86 tests/micro/common --platform=zephyr --board=qemu_x86
+run_pytest ctypes python-microtvm-common-due tests/micro/common --platform=arduino --test-build-only --board=due
# Tutorials
python3 gallery/how_to/work_with_microtvm/micro_tflite.py