You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by dr...@apache.org on 2022/06/23 21:18:41 UTC

[tvm] branch main updated: [ci] Enable pylint for tests/python/ci (#11666)

This is an automated email from the ASF dual-hosted git repository.

driazati pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 0d14519e14 [ci] Enable pylint for tests/python/ci (#11666)
0d14519e14 is described below

commit 0d14519e14ba538c01eb57a83f0633fd3159e9ca
Author: driazati <94...@users.noreply.github.com>
AuthorDate: Thu Jun 23 14:18:35 2022 -0700

    [ci] Enable pylint for tests/python/ci (#11666)
    
    This fixes the pylint issues in the CI tests (tests/python/ci) as part of #11414
---
 tests/lint/pylint.sh                           |   2 +-
 tests/python/ci/{test_utils.py => __init__.py} |   5 +-
 tests/python/ci/test_ci.py                     | 137 +++++++++++++++----------
 tests/python/ci/test_mergebot.py               |  36 +++----
 tests/python/ci/test_script_converter.py       |  22 +++-
 tests/python/ci/test_utils.py                  |  23 ++++-
 6 files changed, 140 insertions(+), 85 deletions(-)

diff --git a/tests/lint/pylint.sh b/tests/lint/pylint.sh
index 3e55168f26..39568fd341 100755
--- a/tests/lint/pylint.sh
+++ b/tests/lint/pylint.sh
@@ -22,4 +22,4 @@ python3 -m pylint vta/python/vta --rcfile="$(dirname "$0")"/pylintrc
 python3 -m pylint tests/python/unittest/test_tvmscript_type.py --rcfile="$(dirname "$0")"/pylintrc
 python3 -m pylint tests/python/contrib/test_cmsisnn --rcfile="$(dirname "$0")"/pylintrc
 python3 -m pylint tests/python/relay/aot/*.py --rcfile="$(dirname "$0")"/pylintrc
-
+python3 -m pylint tests/python/ci --rcfile="$(dirname "$0")"/pylintrc
diff --git a/tests/python/ci/test_utils.py b/tests/python/ci/__init__.py
similarity index 89%
copy from tests/python/ci/test_utils.py
copy to tests/python/ci/__init__.py
index 0ad88f19f4..0c5f28c1f2 100644
--- a/tests/python/ci/test_utils.py
+++ b/tests/python/ci/__init__.py
@@ -14,7 +14,4 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
-import pathlib
-
-REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent.parent
+"""Infrastructure and tests for CI scripts"""
diff --git a/tests/python/ci/test_ci.py b/tests/python/ci/test_ci.py
index 8cfc9bf625..8adf77d500 100644
--- a/tests/python/ci/test_ci.py
+++ b/tests/python/ci/test_ci.py
@@ -14,17 +14,15 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-import os
+"""Test various CI scripts and GitHub Actions workflows"""
 import subprocess
-import sys
 import json
-from tempfile import tempdir
 import textwrap
-import pytest
-import tvm.testing
 from pathlib import Path
 
-from test_utils import REPO_ROOT
+import pytest
+import tvm.testing
+from .test_utils import REPO_ROOT, TempGit
 
 
 def parameterize_named(*values):
@@ -35,18 +33,6 @@ def parameterize_named(*values):
     return pytest.mark.parametrize(",".join(keys), [tuple(d.values()) for d in values])
 
 
-class TempGit:
-    def __init__(self, cwd):
-        self.cwd = cwd
-
-    def run(self, *args, **kwargs):
-        proc = subprocess.run(["git"] + list(args), encoding="utf-8", cwd=self.cwd, **kwargs)
-        if proc.returncode != 0:
-            raise RuntimeError(f"git command failed: '{args}'")
-
-        return proc
-
-
 @pytest.mark.parametrize(
     "target_url,base_url,commit_sha,expected_url,expected_body",
     [
@@ -55,13 +41,17 @@ class TempGit:
             "https://pr-docs.tlcpack.ai",
             "SHA",
             "issues/11594/comments",
-            "Built docs for commit SHA can be found [here](https://pr-docs.tlcpack.ai/PR-11594/3/docs/index.html).",
+            "Built docs for commit SHA can be found "
+            "[here](https://pr-docs.tlcpack.ai/PR-11594/3/docs/index.html).",
         )
     ],
 )
 def test_docs_comment(
     tmpdir_factory, target_url, base_url, commit_sha, expected_url, expected_body
 ):
+    """
+    Test that a comment with a link to the docs is successfully left on PRs
+    """
     docs_comment_script = REPO_ROOT / "tests" / "scripts" / "github_docs_comment.py"
 
     git = TempGit(tmpdir_factory.mktemp("tmp_git_dir"))
@@ -75,6 +65,7 @@ def test_docs_comment(
         env={"TARGET_URL": target_url, "COMMIT_SHA": commit_sha},
         encoding="utf-8",
         cwd=git.cwd,
+        check=False,
     )
     if proc.returncode != 0:
         raise RuntimeError(f"Process failed:\nstdout:\n{proc.stdout}\n\nstderr:\n{proc.stderr}")
@@ -84,6 +75,9 @@ def test_docs_comment(
 
 @tvm.testing.skip_if_wheel_test
 def test_cc_reviewers(tmpdir_factory):
+    """
+    Test that reviewers are added from 'cc @someone' messages in PRs
+    """
     reviewers_script = REPO_ROOT / "tests" / "scripts" / "github_cc_reviewers.py"
 
     def run(pr_body, requested_reviewers, existing_review_users, expected_reviewers):
@@ -104,6 +98,7 @@ def test_cc_reviewers(tmpdir_factory):
             },
             encoding="utf-8",
             cwd=git.cwd,
+            check=False,
         )
         if proc.returncode != 0:
             raise RuntimeError(f"Process failed:\nstdout:\n{proc.stdout}\n\nstderr:\n{proc.stderr}")
@@ -157,6 +152,9 @@ def test_cc_reviewers(tmpdir_factory):
 
 
 def test_update_branch(tmpdir_factory):
+    """
+    Test that the last-successful branch script updates successfully
+    """
     update_script = REPO_ROOT / "tests" / "scripts" / "update_branch.py"
 
     def run(statuses, expected_rc, expected_output):
@@ -182,6 +180,7 @@ def test_update_branch(tmpdir_factory):
             stderr=subprocess.PIPE,
             encoding="utf-8",
             cwd=git.cwd,
+            check=False,
         )
 
         if proc.returncode != expected_rc:
@@ -258,6 +257,9 @@ def test_update_branch(tmpdir_factory):
 
 
 def test_skip_ci(tmpdir_factory):
+    """
+    Test that CI is skipped when it should be
+    """
     skip_ci_script = REPO_ROOT / "tests" / "scripts" / "git_skip_ci.py"
 
     def test(commands, should_skip, pr_title, why):
@@ -273,7 +275,9 @@ def test_skip_ci(tmpdir_factory):
             git.run(*command)
         pr_number = "1234"
         proc = subprocess.run(
-            [str(skip_ci_script), "--pr", pr_number, "--pr-title", pr_title], cwd=git.cwd
+            [str(skip_ci_script), "--pr", pr_number, "--pr-title", pr_title],
+            cwd=git.cwd,
+            check=False,
         )
         expected = 0 if should_skip else 1
         assert proc.returncode == expected, why
@@ -311,7 +315,8 @@ def test_skip_ci(tmpdir_factory):
         ],
         should_skip=False,
         pr_title="[no skip ci] test",
-        why="ci should not be skipped on a branch with [skip ci] in the last commit but not the PR title",
+        why="ci should not be skipped on a branch with "
+        "[skip ci] in the last commit but not the PR title",
     )
 
     test(
@@ -351,6 +356,9 @@ def test_skip_ci(tmpdir_factory):
 
 
 def test_skip_globs(tmpdir_factory):
+    """
+    Test that CI is skipped if only certain files are edited
+    """
     script = REPO_ROOT / "tests" / "scripts" / "git_skip_ci_globs.py"
 
     def run(files, should_skip):
@@ -370,6 +378,7 @@ def test_skip_globs(tmpdir_factory):
             stderr=subprocess.PIPE,
             encoding="utf-8",
             cwd=git.cwd,
+            check=False,
         )
 
         if should_skip:
@@ -386,9 +395,12 @@ def test_skip_globs(tmpdir_factory):
 
 
 def test_ping_reviewers(tmpdir_factory):
+    """
+    Test that reviewers are messaged after a time period of inactivity
+    """
     reviewers_script = REPO_ROOT / "tests" / "scripts" / "ping_reviewers.py"
 
-    def run(pr, check):
+    def run(pull_request, check):
         git = TempGit(tmpdir_factory.mktemp("tmp_git_dir"))
         # Jenkins git is too old and doesn't have 'git init --initial-branch'
         git.run("init")
@@ -399,7 +411,7 @@ def test_ping_reviewers(tmpdir_factory):
             "data": {
                 "repository": {
                     "pullRequests": {
-                        "nodes": [pr],
+                        "nodes": [pull_request],
                         "edges": [],
                     }
                 }
@@ -424,6 +436,7 @@ def test_ping_reviewers(tmpdir_factory):
             stderr=subprocess.PIPE,
             encoding="utf-8",
             cwd=git.cwd,
+            check=False,
         )
         if proc.returncode != 0:
             raise RuntimeError(f"Process failed:\nstdout:\n{proc.stdout}\n\nstderr:\n{proc.stderr}")
@@ -530,15 +543,21 @@ def test_ping_reviewers(tmpdir_factory):
 
 
 def assert_in(needle: str, haystack: str):
+    """
+    Check that 'needle' is in 'haystack'
+    """
     if needle not in haystack:
         raise AssertionError(f"item not found:\n{needle}\nin:\n{haystack}")
 
 
 @tvm.testing.skip_if_wheel_test
 def test_github_tag_teams(tmpdir_factory):
+    """
+    Check that individuals are tagged from team headers
+    """
     tag_script = REPO_ROOT / "tests" / "scripts" / "github_tag_teams.py"
 
-    def run(type, data, check):
+    def run(source_type, data, check):
         git = TempGit(tmpdir_factory.mktemp("tmp_git_dir"))
         git.run("init")
         git.run("checkout", "-b", "main")
@@ -573,7 +592,7 @@ def test_github_tag_teams(tmpdir_factory):
             }
         }
         env = {
-            type: json.dumps(data),
+            source_type: json.dumps(data),
         }
         proc = subprocess.run(
             [
@@ -587,6 +606,7 @@ def test_github_tag_teams(tmpdir_factory):
             encoding="utf-8",
             cwd=git.cwd,
             env=env,
+            check=False,
         )
         if proc.returncode != 0:
             raise RuntimeError(f"Process failed:\nstdout:\n{proc.stdout}\n\nstderr:\n{proc.stderr}")
@@ -594,8 +614,8 @@ def test_github_tag_teams(tmpdir_factory):
         assert_in(check, proc.stdout)
 
     run(
-        "ISSUE",
-        {
+        source_type="ISSUE",
+        data={
             "title": "A title",
             "number": 1234,
             "user": {
@@ -608,12 +628,12 @@ def test_github_tag_teams(tmpdir_factory):
             """.strip()
             ),
         },
-        "No one to cc, exiting",
+        check="No one to cc, exiting",
     )
 
     run(
-        "ISSUE",
-        {
+        source_type="ISSUE",
+        data={
             "title": "A title",
             "number": 1234,
             "user": {
@@ -628,11 +648,11 @@ def test_github_tag_teams(tmpdir_factory):
             """.strip()
             ),
         },
-        "No one to cc, exiting",
+        check="No one to cc, exiting",
     )
 
     run(
-        type="ISSUE",
+        source_type="ISSUE",
         data={
             "title": "A title",
             "number": 1234,
@@ -647,11 +667,12 @@ def test_github_tag_teams(tmpdir_factory):
                 something"""
             ),
         },
-        check="would have updated issues/1234 with {'body': '\\nhello\\n\\nsomething\\n\\ncc @person1 @person2 @person4'}",
+        check="would have updated issues/1234 with {'body': "
+        "'\\nhello\\n\\nsomething\\n\\ncc @person1 @person2 @person4'}",
     )
 
     run(
-        type="ISSUE",
+        source_type="ISSUE",
         data={
             "title": "A title",
             "number": 1234,
@@ -670,7 +691,7 @@ def test_github_tag_teams(tmpdir_factory):
     )
 
     run(
-        type="ISSUE",
+        source_type="ISSUE",
         data={
             "title": "[something] A title",
             "number": 1234,
@@ -685,11 +706,12 @@ def test_github_tag_teams(tmpdir_factory):
                 something"""
             ),
         },
-        check="would have updated issues/1234 with {'body': '\\nhello\\n\\nsomething\\n\\ncc @person1 @person2 @person4'}",
+        check="would have updated issues/1234 with {'body': "
+        "'\\nhello\\n\\nsomething\\n\\ncc @person1 @person2 @person4'}",
     )
 
     run(
-        type="ISSUE",
+        source_type="ISSUE",
         data={
             "title": "[something] A title",
             "number": 1234,
@@ -708,7 +730,7 @@ def test_github_tag_teams(tmpdir_factory):
     )
 
     run(
-        type="PR",
+        source_type="PR",
         data={
             "title": "[something] A title",
             "number": 1234,
@@ -728,7 +750,7 @@ def test_github_tag_teams(tmpdir_factory):
     )
 
     run(
-        type="PR",
+        source_type="PR",
         data={
             "title": "[something] A title",
             "number": 1234,
@@ -748,7 +770,7 @@ def test_github_tag_teams(tmpdir_factory):
     )
 
     run(
-        type="ISSUE",
+        source_type="ISSUE",
         data={
             "title": "[something] A title",
             "number": 1234,
@@ -756,19 +778,17 @@ def test_github_tag_teams(tmpdir_factory):
                 "login": "person5",
             },
             "labels": [{"name": "something2"}],
-            "body": textwrap.dedent(
-                """
-                `mold` and `lld` can be a much faster alternative to `ld` from gcc. We should modify our CMakeLists.txt to detect and use these when possible. cc @person1
-
-                cc @person4
-                """
-            ),
+            "body": "`mold` and `lld` can be a much faster alternative to `ld` from gcc. "
+            "We should modify our CMakeLists.txt to detect and use these when possible. cc @person1"
+            "\n\ncc @person4",
         },
-        check="would have updated issues/1234 with {'body': '\\n`mold` and `lld` can be a much faster alternative to `ld` from gcc. We should modify our CMakeLists.txt to detect and use these when possible. cc @person1\\n\\ncc @person2 @person4\\n'}",
+        check="would have updated issues/1234 with {'body': '`mold` and `lld` can be a much"
+        " faster alternative to `ld` from gcc. We should modify our CMakeLists.txt to "
+        "detect and use these when possible. cc @person1\\n\\ncc @person2 @person4'}",
     )
 
     run(
-        type="ISSUE",
+        source_type="ISSUE",
         data={
             "title": "[something3] A title",
             "number": 1234,
@@ -778,11 +798,12 @@ def test_github_tag_teams(tmpdir_factory):
             "labels": [{"name": "something2"}],
             "body": "@person2 @SOME1-ONE-",
         },
-        check="Dry run, would have updated issues/1234 with {'body': '@person2 @SOME1-ONE-\\n\\ncc @person1'}",
+        check="Dry run, would have updated issues/1234 with"
+        " {'body': '@person2 @SOME1-ONE-\\n\\ncc @person1'}",
     )
 
     run(
-        type="ISSUE",
+        source_type="ISSUE",
         data={
             "title": "[] A title",
             "number": 1234,
@@ -856,6 +877,7 @@ def test_github_tag_teams(tmpdir_factory):
     ),
 )
 def test_open_docker_update_pr(tmpdir_factory, tlcpackstaging_body, tlcpack_body, expected):
+    """Test workflow to open a PR to update Docker images"""
     tag_script = REPO_ROOT / "tests" / "scripts" / "open_docker_update_pr.py"
 
     git = TempGit(tmpdir_factory.mktemp("tmp_git_dir"))
@@ -920,9 +942,10 @@ def test_open_docker_update_pr(tmpdir_factory, tlcpackstaging_body, tlcpack_body
     ],
 )
 def test_determine_docker_images(tmpdir_factory, images, expected):
+    """Test script to decide whether to use tlcpack or tlcpackstaging for images"""
     tag_script = REPO_ROOT / "tests" / "scripts" / "determine_docker_images.py"
 
-    dir = tmpdir_factory.mktemp("tmp_git_dir")
+    git_dir = tmpdir_factory.mktemp("tmp_git_dir")
 
     docker_data = {
         "repositories/tlcpack/ci-arm/tags/abc-abc-123": {},
@@ -935,20 +958,20 @@ def test_determine_docker_images(tmpdir_factory, images, expected):
             "--testing-docker-data",
             json.dumps(docker_data),
             "--base-dir",
-            dir,
+            git_dir,
         ]
         + images,
         stdout=subprocess.PIPE,
         stderr=subprocess.STDOUT,
         encoding="utf-8",
-        cwd=dir,
+        cwd=git_dir,
         check=False,
     )
     if proc.returncode != 0:
         raise RuntimeError(f"Failed to run script:\n{proc.stdout}")
 
     for expected_filename, expected_image in expected.items():
-        with open(Path(dir) / expected_filename) as f:
+        with open(Path(git_dir) / expected_filename) as f:
             actual_image = f.read()
 
         assert actual_image == expected_image
@@ -984,6 +1007,9 @@ def test_determine_docker_images(tmpdir_factory, images, expected):
     ],
 )
 def test_should_rebuild_docker(tmpdir_factory, changed_files, name, check, expected_code):
+    """
+    Check that the Docker images are built when necessary
+    """
     tag_script = REPO_ROOT / "tests" / "scripts" / "should_rebuild_docker.py"
 
     git = TempGit(tmpdir_factory.mktemp("tmp_git_dir"))
@@ -1037,6 +1063,7 @@ def test_should_rebuild_docker(tmpdir_factory, changed_files, name, check, expec
         stderr=subprocess.STDOUT,
         encoding="utf-8",
         cwd=git.cwd,
+        check=False,
     )
 
     assert_in(check, proc.stdout)
diff --git a/tests/python/ci/test_mergebot.py b/tests/python/ci/test_mergebot.py
index 75f56eee56..ccdfdc6539 100644
--- a/tests/python/ci/test_mergebot.py
+++ b/tests/python/ci/test_mergebot.py
@@ -14,27 +14,17 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+"""
+Test the @tvm-bot merge code
+"""
 
-import os
 import subprocess
 import json
-import sys
-import pytest
-
 from pathlib import Path
 
-import tvm.testing
-from test_utils import REPO_ROOT
-
-
-class TempGit:
-    def __init__(self, cwd):
-        self.cwd = cwd
-
-    def run(self, *args, **kwargs):
-        proc = subprocess.run(["git"] + list(args), cwd=self.cwd, **kwargs)
-        if proc.returncode != 0:
-            raise RuntimeError(f"git command failed: '{args}'")
+import pytest
+import tvm
+from .test_utils import REPO_ROOT, TempGit
 
 
 SUCCESS_EXPECTED_OUTPUT = """
@@ -47,7 +37,7 @@ Dry run, would have merged with url=pulls/10786/merge and data={
 """.strip()
 
 
-test_data = {
+TEST_DATA = {
     "successful-merge": {
         "number": 10786,
         "filename": "pr10786-merges.json",
@@ -118,7 +108,7 @@ test_data = {
         "expected": "Cannot merge, found [this review]",
         "comment": "@tvm-bot merge",
         "user": "abc",
-        "detail": "Check that a merge request with a 'Changes Requested' review on HEAD is rejected",
+        "detail": "Check that a merge request with a 'Changes Requested' review is rejected",
     },
     "co-authors": {
         "number": 10786,
@@ -142,10 +132,13 @@ test_data = {
 @tvm.testing.skip_if_wheel_test
 @pytest.mark.parametrize(
     ["number", "filename", "expected", "comment", "user", "detail"],
-    [tuple(d.values()) for d in test_data.values()],
-    ids=test_data.keys(),
+    [tuple(d.values()) for d in TEST_DATA.values()],
+    ids=TEST_DATA.keys(),
 )
 def test_mergebot(tmpdir_factory, number, filename, expected, comment, user, detail):
+    """
+    Test the mergebot test cases
+    """
     mergebot_script = REPO_ROOT / "tests" / "scripts" / "github_tvmbot.py"
     test_json_dir = Path(__file__).resolve().parent / "sample_prs"
 
@@ -187,6 +180,7 @@ def test_mergebot(tmpdir_factory, number, filename, expected, comment, user, det
             "TVM_BOT_JENKINS_TOKEN": "123",
         },
         cwd=git.cwd,
+        check=False,
     )
     if proc.returncode != 0:
         raise RuntimeError(f"Process failed:\nstdout:\n{proc.stdout}\n\nstderr:\n{proc.stderr}")
@@ -196,4 +190,4 @@ def test_mergebot(tmpdir_factory, number, filename, expected, comment, user, det
 
 
 if __name__ == "__main__":
-    sys.exit(pytest.main([__file__] + sys.argv[1:]))
+    tvm.testing.main()
diff --git a/tests/python/ci/test_script_converter.py b/tests/python/ci/test_script_converter.py
index a792c13581..e249827afe 100644
--- a/tests/python/ci/test_script_converter.py
+++ b/tests/python/ci/test_script_converter.py
@@ -14,14 +14,18 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+"""
+Test the conversion of bash to rst
+"""
 
 import sys
 
-import pytest
-
+import tvm
 from tvm.contrib import utils
 
-from test_utils import REPO_ROOT
+# this has to be after the sys.path patching, so ignore pylint
+# pylint: disable=wrong-import-position,wrong-import-order
+from .test_utils import REPO_ROOT
 
 sys.path.insert(0, str(REPO_ROOT / "docs"))
 from script_convert import (
@@ -32,8 +36,11 @@ from script_convert import (
     BASH_MULTILINE_COMMENT_END,
 )
 
+# pylint: enable=wrong-import-position,wrong-import-order
+
 
 def test_bash_cmd():
+    """Test that a bash command gets turned into a rst code block"""
     temp = utils.tempdir()
     src_path = temp / "src.sh"
     dest_path = temp / "dest.py"
@@ -56,6 +63,7 @@ def test_bash_cmd():
 
 
 def test_bash_ignore_cmd():
+    """Test that ignored bash commands are not turned into code blocks"""
     temp = utils.tempdir()
     src_path = temp / "src.sh"
     dest_path = temp / "dest.py"
@@ -79,6 +87,7 @@ def test_bash_ignore_cmd():
 
 
 def test_no_command():
+    """Test a file with no code blocks"""
     temp = utils.tempdir()
     src_path = temp / "src.sh"
     dest_path = temp / "dest.py"
@@ -98,6 +107,7 @@ def test_no_command():
 
 
 def test_text_and_bash_command():
+    """Test a file with a bash code block"""
     temp = utils.tempdir()
     src_path = temp / "src.sh"
     dest_path = temp / "dest.py"
@@ -122,6 +132,7 @@ def test_text_and_bash_command():
 
 
 def test_last_line_break():
+    """Test that line endings are correct"""
     temp = utils.tempdir()
     src_path = temp / "src.sh"
     dest_path = temp / "dest.py"
@@ -141,6 +152,7 @@ def test_last_line_break():
 
 
 def test_multiline_comment():
+    """Test that bash comments are inserted correctly"""
     temp = utils.tempdir()
     src_path = temp / "src.sh"
     dest_path = temp / "dest.py"
@@ -160,3 +172,7 @@ def test_multiline_comment():
     expected_cmd = '"""\n' "comment\n" '"""\n'
 
     assert generated_cmd == expected_cmd
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/ci/test_utils.py b/tests/python/ci/test_utils.py
index 0ad88f19f4..513601aa1b 100644
--- a/tests/python/ci/test_utils.py
+++ b/tests/python/ci/test_utils.py
@@ -14,7 +14,28 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
+"""
+Constants used in various CI tests
+"""
+import subprocess
 import pathlib
 
 REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent.parent
+
+
+class TempGit:
+    """
+    A wrapper to run commands in a directory
+    """
+
+    def __init__(self, cwd):
+        self.cwd = cwd
+
+    def run(self, *args, **kwargs):
+        proc = subprocess.run(
+            ["git"] + list(args), encoding="utf-8", cwd=self.cwd, check=False, **kwargs
+        )
+        if proc.returncode != 0:
+            raise RuntimeError(f"git command failed: '{args}'")
+
+        return proc