You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by mo...@apache.org on 2022/04/25 14:47:11 UTC

[tvm] branch main updated: [TVMC] compile/tune: Check if FILE exists (#10865)

This is an automated email from the ASF dual-hosted git repository.

mousius pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 24e5498021 [TVMC] compile/tune: Check if FILE exists (#10865)
24e5498021 is described below

commit 24e5498021cecca2fe7d44149ce90efe28b6d930
Author: Gustavo Romero <gr...@users.noreply.github.com>
AuthorDate: Mon Apr 25 11:47:01 2022 -0300

    [TVMC] compile/tune: Check if FILE exists (#10865)
    
    Currently, when a nonexistent FILE is passed to 'tvmc tune', it exits
    with a traceback because the resulting FileNotFoundError exception is
    not handled. Since there is no need for such an abrupt exit, and the
    traceback can also confuse users, this commit fixes it by checking that
    FILE exists and is a regular file (not a directory or a broken symbolic
    link), and informing the user about the invalid FILE before exiting.
    
    Add tests verifying that the 'tvmc compile' and 'tvmc tune' commands
    correctly handle an invalid FILE argument (e.g. a missing file, a
    directory, or a broken symbolic link).
    
    With this check, the test_tune_rpc_tracker_parsing test would raise a
    TVMCException: its cli_args is a MagicMock, so cli_args.FILE evaluates
    to another MagicMock object rather than a path, which is not a valid
    input. Since the FILE argument is irrelevant to that test, circumvent
    this by explicitly setting FILE to an existing (empty) file before
    calling drive_tune.
    
    Signed-off-by: Gustavo Romero <gu...@linaro.org>
---
 python/tvm/driver/tvmc/autotuner.py           |  5 +++
 tests/python/driver/tvmc/test_autotuner.py    |  8 ++++
 tests/python/driver/tvmc/test_command_line.py | 65 ++++++++++++++++++++++++++-
 3 files changed, 77 insertions(+), 1 deletion(-)

diff --git a/python/tvm/driver/tvmc/autotuner.py b/python/tvm/driver/tvmc/autotuner.py
index 97cd3bfbc1..c279b04f49 100644
--- a/python/tvm/driver/tvmc/autotuner.py
+++ b/python/tvm/driver/tvmc/autotuner.py
@@ -236,6 +236,11 @@ def drive_tune(args):
     args: argparse.Namespace
         Arguments from command line parser.
     """
+    if not os.path.isfile(args.FILE):
+        raise TVMCException(
+            f"Input file '{args.FILE}' doesn't exist, is a broken symbolic link, or a directory."
+        )
+
     tvmc_model = frontends.load_model(args.FILE, args.model_format, shape_dict=args.input_shapes)
 
     # Specify hardware parameters, although they'll only be used if autoscheduling.
diff --git a/tests/python/driver/tvmc/test_autotuner.py b/tests/python/driver/tvmc/test_autotuner.py
index a1915a0251..66017823a6 100644
--- a/tests/python/driver/tvmc/test_autotuner.py
+++ b/tests/python/driver/tvmc/test_autotuner.py
@@ -20,6 +20,7 @@ import os
 from unittest import mock
 
 from os import path
+from pathlib import Path
 
 from tvm import autotvm
 from tvm.driver import tvmc
@@ -163,9 +164,16 @@ def test_tune_tasks__invalid_tuner(onnx_mnist, tmpdir_factory):
 def test_tune_rpc_tracker_parsing(mock_load_model, mock_tune_model, mock_auto_scheduler):
     cli_args = mock.MagicMock()
     cli_args.rpc_tracker = "10.0.0.1:9999"
+    # FILE is not used but it's set to a valid value here to avoid it being set
+    # by mock to a MagicMock class, which won't pass the checks for valid FILE.
+    fake_input_file = "./fake_input_file.tflite"
+    Path(fake_input_file).touch()
+    cli_args.FILE = fake_input_file
 
     tvmc.autotuner.drive_tune(cli_args)
 
+    os.remove(fake_input_file)
+
     mock_tune_model.assert_called_once()
 
     # inspect the mock call, to search for specific arguments
diff --git a/tests/python/driver/tvmc/test_command_line.py b/tests/python/driver/tvmc/test_command_line.py
index bbf608a5f2..5b15492aa4 100644
--- a/tests/python/driver/tvmc/test_command_line.py
+++ b/tests/python/driver/tvmc/test_command_line.py
@@ -14,11 +14,14 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import os
 import platform
 import pytest
-import os
+import shutil
 
+from pytest_lazyfixture import lazy_fixture
 from tvm.driver.tvmc.main import _main
+from tvm.driver.tvmc.model import TVMCException
 
 
 @pytest.mark.skipif(
@@ -92,3 +95,63 @@ def test_tvmc_cl_workflow_json_config(keras_simple, tmpdir_factory):
     run_args = run_str.split(" ")[1:]
     _main(run_args)
     assert os.path.exists(output_path)
+
+
+@pytest.fixture
+def missing_file():
+    missing_file_name = "missing_file_as_invalid_input.tfite"
+    return missing_file_name
+
+
+@pytest.fixture
+def broken_symlink(tmp_path):
+    broken_symlink = "broken_symlink_as_invalid_input.tflite"
+    os.symlink("non_existing_file", tmp_path / broken_symlink)
+    yield broken_symlink
+    os.unlink(tmp_path / broken_symlink)
+
+
+@pytest.fixture
+def fake_directory(tmp_path):
+    dir_as_invalid = "dir_as_invalid_input.tflite"
+    os.mkdir(tmp_path / dir_as_invalid)
+    yield dir_as_invalid
+    shutil.rmtree(tmp_path / dir_as_invalid)
+
+
+@pytest.mark.parametrize(
+    "invalid_input",
+    [lazy_fixture("missing_file"), lazy_fixture("broken_symlink"), lazy_fixture("fake_directory")],
+)
+def test_tvmc_compile_file_check(capsys, invalid_input):
+    compile_cmd = f"tvmc compile --target 'c' {invalid_input}"
+    run_arg = compile_cmd.split(" ")[1:]
+
+    _main(run_arg)
+
+    captured = capsys.readouterr()
+    expected_err = (
+        f"Error: Input file '{invalid_input}' doesn't exist, "
+        "is a broken symbolic link, or a directory.\n"
+    )
+    on_assert_error = f"'tvmc compile' failed to check invalid FILE: {invalid_input}"
+    assert captured.err == expected_err, on_assert_error
+
+
+@pytest.mark.parametrize(
+    "invalid_input",
+    [lazy_fixture("missing_file"), lazy_fixture("broken_symlink"), lazy_fixture("fake_directory")],
+)
+def test_tvmc_tune_file_check(capsys, invalid_input):
+    tune_cmd = f"tvmc tune --target 'llvm' --output output.json {invalid_input}"
+    run_arg = tune_cmd.split(" ")[1:]
+
+    _main(run_arg)
+
+    captured = capsys.readouterr()
+    expected_err = (
+        f"Error: Input file '{invalid_input}' doesn't exist, "
+        "is a broken symbolic link, or a directory.\n"
+    )
+    on_assert_error = f"'tvmc tune' failed to check invalid FILE: {invalid_input}"
+    assert captured.err == expected_err, on_assert_error