You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2023/08/13 23:34:03 UTC

[airflow] branch main updated: Refactor: Simplify code in scripts (#33295)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 50a6385c7f Refactor: Simplify code in scripts (#33295)
50a6385c7f is described below

commit 50a6385c7fed0fd5457886d0ffdd5040f6a8d511
Author: Miroslav Šedivý <67...@users.noreply.github.com>
AuthorDate: Sun Aug 13 23:33:55 2023 +0000

    Refactor: Simplify code in scripts (#33295)
---
 kubernetes_tests/test_kubernetes_pod_operator.py   |  2 +-
 .../pre_commit_check_pre_commit_hooks.py           |  5 +----
 scripts/ci/pre_commit/pre_commit_json_schema.py    |  2 +-
 .../pre_commit_update_common_sql_api_stubs.py      | 22 ++++------------------
 .../pre_commit_update_example_dags_paths.py        |  5 ++---
 scripts/in_container/run_migration_reference.py    |  4 ++--
 .../in_container/update_quarantined_test_status.py |  5 ++---
 scripts/in_container/verify_providers.py           |  5 ++---
 8 files changed, 15 insertions(+), 35 deletions(-)

diff --git a/kubernetes_tests/test_kubernetes_pod_operator.py b/kubernetes_tests/test_kubernetes_pod_operator.py
index 7ba097f18d..18056977fc 100644
--- a/kubernetes_tests/test_kubernetes_pod_operator.py
+++ b/kubernetes_tests/test_kubernetes_pod_operator.py
@@ -927,7 +927,7 @@ class TestKubernetesPodOperatorSystem:
                 "  creation_timestamp: null",
                 "  deletion_grace_period_seconds: null",
             ]
-            actual = [x.getMessage() for x in caplog.records if x.msg == "Starting pod:\n%s"][0].splitlines()
+            actual = next(x.getMessage() for x in caplog.records if x.msg == "Starting pod:\n%s").splitlines()
             assert actual[: len(expected_lines)] == expected_lines
 
         actual_pod = self.api_client.sanitize_for_serialization(k.pod)
diff --git a/scripts/ci/pre_commit/pre_commit_check_pre_commit_hooks.py b/scripts/ci/pre_commit/pre_commit_check_pre_commit_hooks.py
index 113eeb8cf3..311c660d78 100755
--- a/scripts/ci/pre_commit/pre_commit_check_pre_commit_hooks.py
+++ b/scripts/ci/pre_commit/pre_commit_check_pre_commit_hooks.py
@@ -127,10 +127,7 @@ def render_template(
 
 def update_static_checks_array(hooks: dict[str, list[str]], image_hooks: list[str]):
     rows = []
-    hook_ids = list(hooks.keys())
-    hook_ids.sort()
-    for hook_id in hook_ids:
-        hook_description = hooks[hook_id]
+    for hook_id, hook_description in sorted(hooks.items()):
         formatted_hook_description = (
             hook_description[0] if len(hook_description) == 1 else "* " + "\n* ".join(hook_description)
         )
diff --git a/scripts/ci/pre_commit/pre_commit_json_schema.py b/scripts/ci/pre_commit/pre_commit_json_schema.py
index 5a82183e31..886ff13fe8 100755
--- a/scripts/ci/pre_commit/pre_commit_json_schema.py
+++ b/scripts/ci/pre_commit/pre_commit_json_schema.py
@@ -100,7 +100,7 @@ def load_file(file_path: str):
     if file_path.lower().endswith(".json"):
         with open(file_path) as input_file:
             return json.load(input_file)
-    elif file_path.lower().endswith(".yaml") or file_path.lower().endswith(".yml"):
+    elif file_path.lower().endswith((".yaml", ".yml")):
         with open(file_path) as input_file:
             return yaml.safe_load(input_file)
     raise _ValidatorError("Unknown file format. Supported extension: '.yaml', '.json'")
diff --git a/scripts/ci/pre_commit/pre_commit_update_common_sql_api_stubs.py b/scripts/ci/pre_commit/pre_commit_update_common_sql_api_stubs.py
index e07a4f6325..1a02ffad6a 100755
--- a/scripts/ci/pre_commit/pre_commit_update_common_sql_api_stubs.py
+++ b/scripts/ci/pre_commit/pre_commit_update_common_sql_api_stubs.py
@@ -69,25 +69,11 @@ def summarize_changes(results: list[str]) -> tuple[int, int]:
     """
     removals, additions = 0, 0
     for line in results:
-        if (
-            line.startswith("+")
-            or line.startswith("[green]+")
-            and not (
-                # Skip additions of comments in counting removals
-                line.startswith("+#")
-                or line.startswith("[green]+#")
-            )
-        ):
+        if line.startswith(("+", "[green]+")) and not line.startswith(("+#", "[green]+#")):
+            # Skip additions of comments in counting additions
             additions += 1
-        if (
-            line.startswith("-")
-            or line.startswith("[red]-")
-            and not (
-                # Skip removals of comments in counting removals
-                line.startswith("-#")
-                or line.startswith("[red]-#")
-            )
-        ):
+        if line.startswith(("-", "[red]-")) and not line.startswith(("-#", "[red]-#")):
+            # Skip removals of comments in counting removals
             removals += 1
     return removals, additions
 
diff --git a/scripts/ci/pre_commit/pre_commit_update_example_dags_paths.py b/scripts/ci/pre_commit/pre_commit_update_example_dags_paths.py
index 8a0a71ceb9..c7fd4aa834 100755
--- a/scripts/ci/pre_commit/pre_commit_update_example_dags_paths.py
+++ b/scripts/ci/pre_commit/pre_commit_update_example_dags_paths.py
@@ -52,11 +52,10 @@ def get_provider_and_version(url_path: str) -> tuple[str, str]:
                 provider_info = yaml.safe_load(f)
             version = provider_info["versions"][0]
             provider = "-".join(candidate_folders)
-            while provider.endswith("-"):
-                provider = provider[:-1]
+            provider = provider.rstrip("-")
             return provider, version
         except FileNotFoundError:
-            candidate_folders = candidate_folders[:-1]
+            candidate_folders.pop()
     console.print(
         f"[red]Bad example path: {url_path}. Missing "
         f"provider.yaml in any of the 'airflow/providers/{url_path}' folders. [/]"
diff --git a/scripts/in_container/run_migration_reference.py b/scripts/in_container/run_migration_reference.py
index 43692b2c45..47b8dbed25 100755
--- a/scripts/in_container/run_migration_reference.py
+++ b/scripts/in_container/run_migration_reference.py
@@ -187,6 +187,6 @@ if __name__ == "__main__":
     revisions = list(reversed(list(get_revisions())))
     ensure_airflow_version(revisions=revisions)
     revisions = list(reversed(list(get_revisions())))
-    ensure_filenames_are_sorted(revisions)
+    ensure_filenames_are_sorted(revisions=revisions)
     revisions = list(get_revisions())
-    update_docs(revisions)
+    update_docs(revisions=revisions)
diff --git a/scripts/in_container/update_quarantined_test_status.py b/scripts/in_container/update_quarantined_test_status.py
index 7f03de5c32..a9c155106c 100755
--- a/scripts/in_container/update_quarantined_test_status.py
+++ b/scripts/in_container/update_quarantined_test_status.py
@@ -63,7 +63,7 @@ status_map: dict[str, bool] = {
     ":x:": False,
 }
 
-reverse_status_map: dict[bool, str] = {status_map[key]: key for key in status_map.keys()}
+reverse_status_map: dict[bool, str] = {val: key for key, val in status_map.items()}
 
 
 def get_url(result: TestResult) -> str:
@@ -160,8 +160,7 @@ def get_history_status(history: TestHistory):
 def get_table(history_map: dict[str, TestHistory]) -> str:
     headers = ["Test", "Last run", f"Last {num_runs} runs", "Status", "Comment"]
     the_table: list[list[str]] = []
-    for ordered_key in sorted(history_map.keys()):
-        history = history_map[ordered_key]
+    for _, history in sorted(history_map.items()):
         the_table.append(
             [
                 history.url,
diff --git a/scripts/in_container/verify_providers.py b/scripts/in_container/verify_providers.py
index ea824e3e1c..6df9f9900f 100755
--- a/scripts/in_container/verify_providers.py
+++ b/scripts/in_container/verify_providers.py
@@ -187,7 +187,7 @@ def import_all_classes(
 
     for path, prefix in walkable_paths_and_prefixes.items():
         for modinfo in pkgutil.walk_packages(path=[path], prefix=prefix, onerror=onerror):
-            if not any(modinfo.name.startswith(provider_prefix) for provider_prefix in provider_prefixes):
+            if not modinfo.name.startswith(tuple(provider_prefixes)):
                 if print_skips:
                     console.print(f"Skipping module: {modinfo.name}")
                 continue
@@ -332,8 +332,7 @@ def get_details_about_classes(
     :param wrong_entities: wrong entities found for that type
     :param full_package_name: full package name
     """
-    all_entities = list(entities)
-    all_entities.sort()
+    all_entities = sorted(entities)
     TOTALS[entity_type] += len(all_entities)
     return EntityTypeSummary(
         entities=all_entities,