Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/11/26 22:07:12 UTC

[GitHub] [tvm] gromero commented on a change in pull request #9584: [microtvm] Add TVMC test for Arduino and Zephyr

gromero commented on a change in pull request #9584:
URL: https://github.com/apache/tvm/pull/9584#discussion_r757676911



##########
File path: apps/microtvm/arduino/template_project/microtvm_api_server.py
##########
@@ -248,7 +249,10 @@ def _convert_includes(self, project_dir, source_dir):
         for ext in ("c", "h", "cpp"):
             for filename in source_dir.rglob(f"*.{ext}"):
                 with filename.open() as file:
-                    lines = file.readlines()
+                    try:
+                        lines = file.readlines()
+                    except:

Review comment:
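       What exception exactly are you trying to catch and ignore here? AFAICS `open()`/`readlines()` will throw only in quite rare cases, like "permission denied", which should not happen in this context. Also note that a bare `except:` catches every exception, including `KeyboardInterrupt`, so if something really needs to be caught it's better to name the specific exception.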
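   If the idea was to skip source files that are not valid UTF-8 (just a guess on my side, the PR doesn't say), a narrower handler would make that explicit. A minimal sketch, where `read_source_lines` is a hypothetical helper and UTF-8 is assumed as the expected encoding:

```python
import pathlib
import tempfile
from typing import List, Optional


def read_source_lines(filename: pathlib.Path) -> Optional[List[str]]:
    # Hypothetical helper: tolerates only non-UTF-8 files (UnicodeDecodeError).
    # Anything else, e.g. a PermissionError from open(), still propagates
    # instead of being hidden by a bare "except:".
    try:
        with filename.open(encoding="utf-8") as file:
            return file.readlines()
    except UnicodeDecodeError:
        return None


with tempfile.TemporaryDirectory() as tmp:
    good = pathlib.Path(tmp) / "ok.c"
    good.write_text("int x;\n", encoding="utf-8")
    bad = pathlib.Path(tmp) / "latin1.h"
    bad.write_bytes(b"/* caf\xe9 */\n")  # Latin-1 byte, invalid as UTF-8
    good_lines = read_source_lines(good)
    bad_lines = read_source_lines(bad)

print(good_lines)  # ['int x;\n']
print(bad_lines)  # None
```

   With this, a `PermissionError` raised by `open()` still surfaces instead of being silently ignored.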

##########
File path: python/tvm/micro/project.py
##########
@@ -56,6 +56,17 @@ def read(self, n, timeout_sec):
         return self._api_client.read_transport(n, timeout_sec)["data"]
 
 
+def prepare_options(received_options: dict, all_options: dict) -> dict:

Review comment:
       I think I see the problem you're trying to solve here. It's problematic that when a default value is set for an option in its `ProjectOption`, that default is not applied accordingly in the server code that actually uses the option. It gets a bit more complicated when the default value is "probed" from the environment, for instance via `os.getenv()`.
   
   However, using `prepare_options` here creates a kind of loop: it grabs the default values returned by the API and feeds them back via another API call. I think that should be avoided.
   
   I think the defaults should be handled on the server side. I wonder if something like this would fix your use case:
   
   ```
   diff --git a/apps/microtvm/arduino/template_project/microtvm_api_server.py b/apps/microtvm/arduino/template_project/microtvm_api_server.py
   index 3039eb313..5f476a2d2 100644
   --- a/apps/microtvm/arduino/template_project/microtvm_api_server.py
   +++ b/apps/microtvm/arduino/template_project/microtvm_api_server.py
   @@ -68,6 +68,10 @@ class BoardAutodetectFailed(Exception):
    
    PROJECT_TYPES = ["example_project", "host_driven"]
    
   +
   +ARDUINO_CLI_CMD = shutil.which("arduino-cli")
   +
   +
    PROJECT_OPTIONS = [
        server.ProjectOption(
            "arduino_board",
   @@ -78,7 +82,9 @@ PROJECT_OPTIONS = [
        ),
        server.ProjectOption(
            "arduino_cli_cmd",
   -        required=["build", "flash", "open_transport"],
   +        required=["generate_project", "build", "flash", "open_transport"] if not ARDUINO_CLI_CMD else None,
   +        optional=["generate_project", "build", "flash", "open_transport"] if ARDUINO_CLI_CMD else None,
   +        default=ARDUINO_CLI_CMD,
            type="str",
            help="Path to the arduino-cli tool.",
        ),
   @@ -305,10 +311,12 @@ class Handler(server.ProjectAPIHandler):
            # It's probably a standard C/C++ header
            return include_path
    
   -    def _get_platform_version(self, arduino_cli_path: str) -> float:
   +    def _get_platform_version(self, options: dict) -> float:
            # sample output of this command:
            # 'arduino-cli alpha Version: 0.18.3 Commit: d710b642 Date: 2021-05-14T12:36:58Z\n'
   -        version_output = subprocess.check_output([arduino_cli_path, "version"], encoding="utf-8")
   +        arduino_cli_cmd = options.get("arduino_cli_cmd", ARDUINO_CLI_CMD)
   +        assert arduino_cli_cmd, "arduino-cli command not passed and not found by default!"
   +        version_output = subprocess.check_output([arduino_cli_cmd, "version"], encoding="utf-8")
            full_version = re.findall("version: ([\.0-9]*)", version_output.lower())
            full_version = full_version[0].split(".")
            version = float(f"{full_version[0]}.{full_version[1]}")
   @@ -317,7 +325,7 @@ class Handler(server.ProjectAPIHandler):
    
        def generate_project(self, model_library_format_path, standalone_crt_dir, project_dir, options):
            # Check Arduino version
   -        version = self._get_platform_version(options["arduino_cli_cmd"])
   +        version = self._get_platform_version(options)
            if version != ARDUINO_CLI_VERSION:
                message = f"Arduino CLI version found is not supported: found {version}, expected {ARDUINO_CLI_VERSION}."
                if options.get("warning_as_error") is not None and options["warning_as_error"]:
   @@ -366,8 +374,11 @@ class Handler(server.ProjectAPIHandler):
        def build(self, options):
            BUILD_DIR.mkdir()
    
   +        arduino_cli_cmd = options.get("arduino_cli_cmd", ARDUINO_CLI_CMD)
   +        assert arduino_cli_cmd, "arduino-cli command not passed and not found by default!"
   +
            compile_cmd = [
   -            options["arduino_cli_cmd"],
   +            arduino_cli_cmd,
                "compile",
                "./project/",
                "--fqbn",
   @@ -416,7 +427,9 @@ class Handler(server.ProjectAPIHandler):
                yield parsed_row
    
        def _auto_detect_port(self, options):
   -        list_cmd = [options["arduino_cli_cmd"], "board", "list"]
   +        arduino_cli_cmd = options.get("arduino_cli_cmd", ARDUINO_CLI_CMD)
   +        assert arduino_cli_cmd, "arduino-cli command not passed and not found by default!"
   +        list_cmd = [arduino_cli_cmd, "board", "list"]
            list_cmd_output = subprocess.run(
                list_cmd, check=True, stdout=subprocess.PIPE
            ).stdout.decode("utf-8")
   @@ -441,8 +454,11 @@ class Handler(server.ProjectAPIHandler):
        def flash(self, options):
            port = self._get_arduino_port(options)
    
   +        arduino_cli_cmd = options.get("arduino_cli_cmd", ARDUINO_CLI_CMD)
   +        assert arduino_cli_cmd, "arduino-cli command not passed and not found by default!"
   +
            upload_cmd = [
   -            options["arduino_cli_cmd"],
   +            arduino_cli_cmd,
                "upload",
                "./project",
                "--fqbn",
   diff --git a/apps/microtvm/zephyr/template_project/microtvm_api_server.py b/apps/microtvm/zephyr/template_project/microtvm_api_server.py
   index 3c96f31df..ca7a5ab40 100644
   --- a/apps/microtvm/zephyr/template_project/microtvm_api_server.py
   +++ b/apps/microtvm/zephyr/template_project/microtvm_api_server.py
   @@ -231,6 +231,12 @@ if IS_TEMPLATE:
                PROJECT_TYPES.append(d.name)
    
    
   +WEST_CMD = sys.executable + " -m west" if sys.executable else None
   +
   +
   +ZEPHYR_BASE = os.getenv("ZEPHYR_BASE")
   +
   +
    PROJECT_OPTIONS = [
        server.ProjectOption(
            "extra_files_tar",
   @@ -271,8 +277,8 @@ PROJECT_OPTIONS = [
        ),
        server.ProjectOption(
            "west_cmd",
   -        optional=["generate_project"],
   -        default=sys.executable + " -m west" if sys.executable else None,
   +        optional=["build"],
   +        default=WEST_CMD,
            type="str",
            help=(
                "Path to the west tool. If given, supersedes both the zephyr_base "
   @@ -281,8 +287,9 @@ PROJECT_OPTIONS = [
        ),
        server.ProjectOption(
            "zephyr_base",
   -        optional=["build", "open_transport"],
   -        default=os.getenv("ZEPHYR_BASE"),
   +        required=["generate_project", "open_transport"] if not ZEPHYR_BASE else None,
   +        optional=["generate_project", "open_transport", "build"] if ZEPHYR_BASE else ["build"],
   +        default=ZEPHYR_BASE,
            type="str",
            help="Path to the zephyr base directory.",
        ),
   @@ -388,8 +395,8 @@ class Handler(server.ProjectAPIHandler):
            "aot_demo": "memory microtvm_rpc_common common",
        }
    
   -    def _get_platform_version(self) -> float:
   -        with open(pathlib.Path(os.getenv("ZEPHYR_BASE")) / "VERSION", "r") as f:
   +    def _get_platform_version(self, options: dict) -> float:
   +        with open(pathlib.Path(options.get("zephyr_base", ZEPHYR_BASE)) / "VERSION", "r") as f:
                lines = f.readlines()
                for line in lines:
                    line = line.replace(" ", "").replace("\n", "").replace("\r", "")
   @@ -402,7 +409,7 @@ class Handler(server.ProjectAPIHandler):
    
        def generate_project(self, model_library_format_path, standalone_crt_dir, project_dir, options):
            # Check Zephyr version
   -        version = self._get_platform_version()
   +        version = self._get_platform_version(options)
            if version != ZEPHYR_VERSION:
                message = f"Zephyr version found is not supported: found {version}, expected {ZEPHYR_VERSION}."
                if options.get("warning_as_error") is not None and options["warning_as_error"]:
   @@ -574,7 +581,7 @@ def _set_nonblock(fd):
    class ZephyrSerialTransport:
        @classmethod
        def _lookup_baud_rate(cls, options):
   -        zephyr_base = options.get("zephyr_base", os.environ["ZEPHYR_BASE"])
   +        zephyr_base = options.get("zephyr_base", ZEPHYR_BASE)
            sys.path.insert(0, os.path.join(zephyr_base, "scripts", "dts"))
            try:
                import dtlib  # pylint: disable=import-outside-toplevel
   ```
   
   Here the `ProjectOption`s are set with the correct default values, and the server code uses them accordingly. If a default value could not be determined by any means (via `os.getenv()`, for instance), the option is marked as required, so CLI tools like `tvmc` know about it and adapt their command-line arguments accordingly, asking the user to provide it. If a required option is not given when the methods are called outside `tvmc`, the `assert`s take care of reporting what's missing. In that sense, neither the server API nor the wrappers in `project.py` should try to fill in required options with their defaults: it's the user's responsibility to know about them and provide them, or to query the API to learn which options are required and pass them accordingly.
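   To make the idea a bit more concrete outside the Arduino server, here's a minimal sketch of the pattern in plain Python. `OptionSpec`, `make_cli_option` and `resolve_cli_cmd` are hypothetical names for illustration only (`OptionSpec` is not `server.ProjectOption`); the only assumption carried over from the diff is that the default is probed with `shutil.which()`:

```python
import shutil
from dataclasses import dataclass
from typing import List, Optional

API_METHODS = ["generate_project", "build", "flash", "open_transport"]


@dataclass
class OptionSpec:
    """Hypothetical stand-in for server.ProjectOption; not TVM's class."""

    name: str
    default: Optional[str] = None
    required: Optional[List[str]] = None
    optional: Optional[List[str]] = None


def make_cli_option(probed: Optional[str]) -> OptionSpec:
    # No default could be probed: mark the option as required, so tools
    # like tvmc know they must ask the user for it.
    if probed is None:
        return OptionSpec("arduino_cli_cmd", required=list(API_METHODS))
    # A default was found: the option becomes optional and carries it.
    return OptionSpec("arduino_cli_cmd", default=probed, optional=list(API_METHODS))


def resolve_cli_cmd(options: dict, probed: Optional[str]) -> str:
    # Server-side resolution: a user-provided value wins, else the probed default.
    cmd = options.get("arduino_cli_cmd") or probed
    assert cmd, "arduino-cli command not passed and not found by default!"
    return cmd


# Probed once at import time; None when arduino-cli is not on PATH.
ARDUINO_CLI_CMD = shutil.which("arduino-cli")
spec = make_cli_option(ARDUINO_CLI_CMD)
```

   One small deviation from the diff: `resolve_cli_cmd` uses `options.get(...) or probed`, so an option that is present but set to `None` also falls back to the probed default, whereas `options.get(key, default)` would return `None` in that case.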

##########
File path: tests/micro/common/test_tvmc.py
##########
@@ -0,0 +1,197 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+import subprocess
+import shlex
+import sys
+import logging
+import tempfile
+import pathlib
+import sys
+import os
+
+import tvm
+from tvm.contrib.download import download_testdata
+from tvm.relay.backend import Executor, Runtime
+
+from ..zephyr.test_utils import ZEPHYR_BOARDS
+from ..arduino.test_utils import ARDUINO_BOARDS
+
+TVMC_COMMAND = [sys.executable, "-m", "tvm.driver.tvmc"]
+
+MODEL_URL = "https://github.com/tensorflow/tflite-micro/raw/main/tensorflow/lite/micro/examples/micro_speech/micro_speech.tflite"
+MODEL_FILE = "micro_speech.tflite"
+
+
+def _run_tvmc(cmd_args: list, *args, **kwargs):
+    """Run a tvmc command and return the results"""
+    cmd_args_list = TVMC_COMMAND + cmd_args
+    cwd_str = "" if "cwd" not in kwargs else f" (in cwd: {kwargs['cwd']})"
+    logging.debug("run%s: %s", cwd_str, " ".join(shlex.quote(a) for a in cmd_args_list))
+    return subprocess.check_call(cmd_args_list, *args, **kwargs)
+
+
+def _get_target_and_platform(board: str):
+    if board in ZEPHYR_BOARDS.keys():
+        target_model = ZEPHYR_BOARDS[board]
+        platform = "zephyr"
+    elif board in ARDUINO_BOARDS.keys():
+        target_model = ARDUINO_BOARDS[board]
+        platform = "arduino"
+    else:
+        raise ValueError(f"Board {board} is not supported.")
+
+    target = tvm.target.target.micro(target_model)
+    return str(target), platform
+
+
+@tvm.testing.requires_micro
+def test_tvmc_exist(board):
+    cmd_result = _run_tvmc(["micro", "-h"])
+    assert cmd_result == 0
+
+
+@tvm.testing.requires_micro
+def test_tvmc_model_build_only(board):
+    target, platform = _get_target_and_platform(board)
+
+    model_path = model_path = download_testdata(MODEL_URL, MODEL_FILE, module="data")
+    temp_dir = pathlib.Path(tempfile.mkdtemp())
+    tar_path = str(temp_dir / "model.tar")
+    project_dir = str(temp_dir / "project")
+
+    runtime = str(Runtime("crt"))

Review comment:
       It's not necessary to instantiate the `Runtime` and `Executor` classes here. In both classes, the name passed to the constructor (the `name` attribute) is also what `__repr__` prints:
   
   https://github.com/apache/tvm/blob/main/src/relay/backend/runtime.cc#L35-L37
   
   https://github.com/apache/tvm/blob/main/src/relay/backend/executor.cc#L34-L36
   
   so `str()` will return the same name that was given when the class was instantiated.
   
   How about using `runtime = "crt"` for clarity, or even passing `"crt"` directly on the command line?

##########
File path: tests/micro/arduino/test_utils.py
##########
@@ -0,0 +1,101 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import json
+import pathlib
+import requests
+import datetime
+
+import tvm.micro
+import tvm.target.target
+from tvm.micro import project
+from tvm import relay
+from tvm.relay.backend import Executor, Runtime
+
+
+TEMPLATE_PROJECT_DIR = pathlib.Path(tvm.micro.get_microtvm_template_projects("arduino"))
+
+BOARDS = TEMPLATE_PROJECT_DIR / "boards.json"
+
+
+def arduino_boards() -> dict:
+    """Returns a dict mapping board to target model"""
+    with open(BOARDS) as f:
+        board_properties = json.load(f)
+
+    boards_model = {board: info["model"] for board, info in board_properties.items()}
+    return boards_model
+
+
+ARDUINO_BOARDS = arduino_boards()
+
+
+def make_workspace_dir(test_name, board):
+    filepath = pathlib.Path(__file__)
+    board_workspace = (
+        filepath.parent
+        / f"workspace_{test_name}_{board}"
+        / datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
+    )
+
+    number = 0
+    while board_workspace.exists():
+        number += 1
+        board_workspace = pathlib.Path(str(board_workspace) + f"-{number}")
+    board_workspace.parent.mkdir(exist_ok=True, parents=True)
+    t = tvm.contrib.utils.tempdir(board_workspace)
+    # time.sleep(200)

Review comment:
       Is this a leftover that should be removed?

##########
File path: apps/microtvm/arduino/template_project/microtvm_api_server.py
##########
@@ -71,14 +71,15 @@ class BoardAutodetectFailed(Exception):
 PROJECT_OPTIONS = [
     server.ProjectOption(
         "arduino_board",
-        required=["build", "flash", "open_transport"],
+        required=["generate_project", "build", "flash", "open_transport"],

Review comment:
       I don't see where the `arduino_board` option is needed in the `generate_project` API method. Could you please double-check it and, if it is indeed needed, point me to where it's used?

##########
File path: tests/micro/common/test_tvmc.py
##########
@@ -0,0 +1,197 @@
+def _run_tvmc(cmd_args: list, *args, **kwargs):

Review comment:
       Have you considered using the `tvmc` module directly to run the command, as is already done here: https://github.com/apache/tvm/blob/main/tests/python/driver/tvmc/test_mlf.py#L45

##########
File path: apps/microtvm/arduino/template_project/microtvm_api_server.py
##########
@@ -71,14 +71,15 @@ class BoardAutodetectFailed(Exception):
 PROJECT_OPTIONS = [
     server.ProjectOption(
         "arduino_board",
-        required=["build", "flash", "open_transport"],
+        required=["generate_project", "build", "flash", "open_transport"],
         choices=list(BOARD_PROPERTIES),
         type="str",
         help="Name of the Arduino board to build for.",
     ),
     server.ProjectOption(
         "arduino_cli_cmd",
-        required=["build", "flash", "open_transport"],
+        optional=["build", "flash", "open_transport"],

Review comment:
       On the other hand, I see that the `arduino_cli_cmd` option is actually also needed in the `generate_project` API method, since it's passed as an argument to `_get_platform_version()`.
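   For context, the version check in `generate_project` boils down to roughly the following. This is a sketch mirroring `_get_platform_version()`; `parse_cli_version` is a hypothetical helper split out so the parsing can be tested without arduino-cli installed, and the sample string is the one quoted in the server code. It also uses a raw string for the regex, since `"\."` in a regular string literal triggers an invalid escape sequence `DeprecationWarning`:

```python
import re
import subprocess

# Sample output quoted in the Arduino server code.
SAMPLE_OUTPUT = "arduino-cli alpha Version: 0.18.3 Commit: d710b642 Date: 2021-05-14T12:36:58Z\n"


def parse_cli_version(version_output: str) -> float:
    """Extract "major.minor" as a float, mirroring _get_platform_version()."""
    match = re.search(r"version: ([\.0-9]+)", version_output.lower())
    if match is None:
        raise ValueError(f"Could not parse a version from: {version_output!r}")
    major, minor = match.group(1).split(".")[:2]
    return float(f"{major}.{minor}")


def get_platform_version(arduino_cli_cmd: str) -> float:
    # This is the step that needs the CLI path, hence arduino_cli_cmd
    # is needed by generate_project too.
    output = subprocess.check_output([arduino_cli_cmd, "version"], encoding="utf-8")
    return parse_cli_version(output)
```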



