You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink was not preserved in this plain-text copy).
Posted to commits@tvm.apache.org by le...@apache.org on 2021/12/08 12:34:17 UTC

[tvm] branch main updated: [TVMC][MicroTVM] Fix tvmc micro `project_dir` arg relative path (#9663)

This is an automated email from the ASF dual-hosted git repository.

leandron pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b54beed  [TVMC][MicroTVM] Fix tvmc micro `project_dir` arg relative path (#9663)
b54beed is described below

commit b54beed37ca2baad6002990b014a2119223e0900
Author: Mehrdad Hessar <mh...@octoml.ai>
AuthorDate: Wed Dec 8 04:33:54 2021 -0800

    [TVMC][MicroTVM] Fix tvmc micro `project_dir` arg relative path (#9663)
    
    * Add fix for project dir path
    
    * address @gromero comments
---
 python/tvm/driver/tvmc/common.py | 11 +++++++++--
 python/tvm/driver/tvmc/micro.py  | 27 ++++++++++++++-------------
 python/tvm/driver/tvmc/runner.py |  6 ++++--
 tests/micro/common/test_tvmc.py  | 37 +++++++++++++++++++++++++++++--------
 4 files changed, 56 insertions(+), 25 deletions(-)

diff --git a/python/tvm/driver/tvmc/common.py b/python/tvm/driver/tvmc/common.py
index 97b7c52..5319193 100644
--- a/python/tvm/driver/tvmc/common.py
+++ b/python/tvm/driver/tvmc/common.py
@@ -22,12 +22,12 @@ import json
 import logging
 import os.path
 import argparse
-
+import pathlib
+from typing import Union
 from collections import defaultdict
 from urllib.parse import urlparse
 
 import tvm
-
 from tvm.driver import tvmc
 from tvm import relay
 from tvm import transform
@@ -786,3 +786,10 @@ def get_and_check_options(passed_options, valid_options):
     check_options_choices(opts, valid_options)
 
     return opts
+
+
+def get_project_dir(project_dir: Union[pathlib.Path, str]) -> str:
+    """Get project directory path"""
+    if not os.path.isabs(project_dir):
+        return os.path.abspath(project_dir)
+    return project_dir
diff --git a/python/tvm/driver/tvmc/micro.py b/python/tvm/driver/tvmc/micro.py
index ef72446..a9c17b8 100644
--- a/python/tvm/driver/tvmc/micro.py
+++ b/python/tvm/driver/tvmc/micro.py
@@ -29,6 +29,7 @@ from .common import (
     TVMCSuppressedArgumentParser,
     get_project_options,
     get_and_check_options,
+    get_project_dir,
 )
 
 try:
@@ -238,16 +239,16 @@ def drive_micro(args):
 
 def create_project_handler(args):
     """Creates a new project dir."""
+    project_dir = get_project_dir(args.project_dir)
 
-    if os.path.exists(args.project_dir):
+    if os.path.exists(project_dir):
         if args.force:
-            shutil.rmtree(args.project_dir)
+            shutil.rmtree(project_dir)
         else:
             raise TVMCException(
                 "The specified project dir already exists. "
                 "To force overwriting it use '-f' or '--force'."
             )
-    project_dir = args.project_dir
 
     template_dir = str(Path(args.template_dir).resolve())
     if not os.path.exists(template_dir):
@@ -268,21 +269,20 @@ def create_project_handler(args):
 
 def build_handler(args):
     """Builds a firmware image given a project dir."""
+    project_dir = get_project_dir(args.project_dir)
 
-    if not os.path.exists(args.project_dir):
-        raise TVMCException(f"{args.project_dir} doesn't exist.")
+    if not os.path.exists(project_dir):
+        raise TVMCException(f"{project_dir} doesn't exist.")
 
-    if os.path.exists(args.project_dir + "/build"):
+    if os.path.exists(project_dir + "/build"):
         if args.force:
-            shutil.rmtree(args.project_dir + "/build")
+            shutil.rmtree(project_dir + "/build")
         else:
             raise TVMCException(
-                f"There is already a build in {args.project_dir}. "
+                f"There is already a build in {project_dir}. "
                 "To force rebuild it use '-f' or '--force'."
             )
 
-    project_dir = args.project_dir
-
     options = get_and_check_options(args.project_option, args.valid_options)
 
     try:
@@ -295,10 +295,11 @@ def build_handler(args):
 
 def flash_handler(args):
     """Flashes a firmware image to a target device given a project dir."""
-    if not os.path.exists(args.project_dir + "/build"):
-        raise TVMCException(f"Could not find a build in {args.project_dir}")
 
-    project_dir = args.project_dir
+    project_dir = get_project_dir(args.project_dir)
+
+    if not os.path.exists(project_dir + "/build"):
+        raise TVMCException(f"Could not find a build in {project_dir}")
 
     options = get_and_check_options(args.project_option, args.valid_options)
 
diff --git a/python/tvm/driver/tvmc/runner.py b/python/tvm/driver/tvmc/runner.py
index b140cf6..4a37906 100644
--- a/python/tvm/driver/tvmc/runner.py
+++ b/python/tvm/driver/tvmc/runner.py
@@ -40,6 +40,7 @@ from .common import (
     TVMCSuppressedArgumentParser,
     get_project_options,
     get_and_check_options,
+    get_project_dir,
 )
 from .main import register_parser
 from .model import TVMCPackage, TVMCResult
@@ -147,7 +148,7 @@ def add_run_parser(subparsers, main_parser):
             "Please build TVM with micro support (USE_MICRO ON)!"
         )
 
-    project_dir = known_args.PATH
+    project_dir = get_project_dir(known_args.PATH)
 
     try:
         project_ = project.GeneratedProject.from_directory(project_dir, None)
@@ -496,7 +497,8 @@ def run_module(
             if tvmc_package.type != "mlf":
                 raise TVMCException(f"Model {tvmc_package.package_path} is not a MLF archive.")
 
-            project_dir = os.path.dirname(tvmc_package.package_path)
+            project_dir = get_project_dir(tvmc_package.package_path)
+            project_dir = os.path.dirname(project_dir)
 
             # This is guaranteed to work since project_dir was already checked when
             # building the dynamic parser to accommodate the project options, so no
diff --git a/tests/micro/common/test_tvmc.py b/tests/micro/common/test_tvmc.py
index d462b3f..eb0b3a6 100644
--- a/tests/micro/common/test_tvmc.py
+++ b/tests/micro/common/test_tvmc.py
@@ -23,6 +23,7 @@ import tempfile
 import pathlib
 import sys
 import os
+import shutil
 
 import tvm
 from tvm.contrib.download import download_testdata
@@ -66,13 +67,22 @@ def test_tvmc_exist(board):
 
 
 @tvm.testing.requires_micro
-def test_tvmc_model_build_only(board):
+@pytest.mark.parametrize(
+    "output_dir,",
+    [pathlib.Path("./tvmc_relative_path_test"), pathlib.Path(tempfile.mkdtemp())],
+)
+def test_tvmc_model_build_only(board, output_dir):
     target, platform = _get_target_and_platform(board)
 
+    if not os.path.isabs(output_dir):
+        out_dir_temp = os.path.abspath(output_dir)
+        if os.path.isdir(out_dir_temp):
+            shutil.rmtree(out_dir_temp)
+        os.mkdir(out_dir_temp)
+
     model_path = model_path = download_testdata(MODEL_URL, MODEL_FILE, module="data")
-    temp_dir = pathlib.Path(tempfile.mkdtemp())
-    tar_path = str(temp_dir / "model.tar")
-    project_dir = str(temp_dir / "project")
+    tar_path = str(output_dir / "model.tar")
+    project_dir = str(output_dir / "project")
 
     runtime = "crt"
     executor = "graph"
@@ -118,17 +128,27 @@ def test_tvmc_model_build_only(board):
         ["micro", "build", project_dir, platform, "--project-option", f"{platform}_board={board}"]
     )
     assert cmd_result == 0, "tvmc micro failed in step: build"
+    shutil.rmtree(output_dir)
 
 
 @pytest.mark.requires_hardware
 @tvm.testing.requires_micro
-def test_tvmc_model_run(board):
+@pytest.mark.parametrize(
+    "output_dir,",
+    [pathlib.Path("./tvmc_relative_path_test"), pathlib.Path(tempfile.mkdtemp())],
+)
+def test_tvmc_model_run(board, output_dir):
     target, platform = _get_target_and_platform(board)
 
+    if not os.path.isabs(output_dir):
+        out_dir_temp = os.path.abspath(output_dir)
+        if os.path.isdir(out_dir_temp):
+            shutil.rmtree(out_dir_temp)
+        os.mkdir(out_dir_temp)
+
     model_path = model_path = download_testdata(MODEL_URL, MODEL_FILE, module="data")
-    temp_dir = pathlib.Path(tempfile.mkdtemp())
-    tar_path = str(temp_dir / "model.tar")
-    project_dir = str(temp_dir / "project")
+    tar_path = str(output_dir / "model.tar")
+    project_dir = str(output_dir / "project")
 
     runtime = "crt"
     executor = "graph"
@@ -193,6 +213,7 @@ def test_tvmc_model_run(board):
         ]
     )
     assert cmd_result == 0, "tvmc micro failed in step: run"
+    shutil.rmtree(output_dir)
 
 
 if __name__ == "__main__":