Posted to commits@airflow.apache.org by vi...@apache.org on 2023/09/27 15:30:39 UTC

[airflow] branch main updated: Refactor dedent nested loops (#34409)

This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 07fe1d2a69 Refactor dedent nested loops (#34409)
07fe1d2a69 is described below

commit 07fe1d2a69cbe4f684a1989c047737c0686c4417
Author: Miroslav Šedivý <67...@users.noreply.github.com>
AuthorDate: Wed Sep 27 15:30:30 2023 +0000

    Refactor dedent nested loops (#34409)
---
 airflow/models/taskinstance.py                     | 48 +++++++++++-----------
 .../providers/google/cloud/transfers/sql_to_gcs.py |  4 +-
 airflow/www/security_manager.py                    |  6 +--
 airflow/www/views.py                               |  6 +--
 .../airflow_breeze/utils/exclude_from_matrix.py    |  7 ++--
 .../src/airflow_breeze/utils/selective_checks.py   |  6 +--
 dev/retag_docker_images.py                         | 26 ++++++------
 docs/exts/airflow_intersphinx.py                   |  7 ++--
 tests/models/test_dag.py                           | 28 +++++--------
 tests/sensors/test_external_task_sensor.py         |  8 ++--
 10 files changed, 66 insertions(+), 80 deletions(-)
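
The recurring change in this commit replaces a doubly nested for loop with a single loop over itertools.product, which yields the same (outer, inner) pairs in the same order with one level of indentation removed. A minimal sketch of the equivalence, using dag_ids and run_ids only as stand-ins for any two iterables:

    import itertools

    dag_ids = ["dag_a", "dag_b"]
    run_ids = ["run_1", "run_2"]

    # Nested form: the inner loop runs once per element of the outer loop.
    pairs_nested = []
    for cur_dag_id in dag_ids:
        for cur_run_id in run_ids:
            pairs_nested.append((cur_dag_id, cur_run_id))

    # Flattened form: itertools.product yields the same pairs lazily,
    # in the same order, without the extra indentation level.
    pairs_product = list(itertools.product(dag_ids, run_ids))

    assert pairs_nested == pairs_product

One nuance worth remembering when applying this pattern: itertools.product fully consumes its input iterables before yielding anything, so a generator passed as an argument is materialized up front rather than re-evaluated on every outer pass.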

diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 039fe68f98..3ac88d69c9 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -20,6 +20,7 @@ from __future__ import annotations
 import collections.abc
 import contextlib
 import hashlib
+import itertools
 import logging
 import math
 import operator
@@ -3073,32 +3074,31 @@ class TaskInstance(Base, LoggingMixin):
 
         # this assumes that most dags have dag_id as the largest grouping, followed by run_id. even
         # if its not, this is still  a significant optimization over querying for every single tuple key
-        for cur_dag_id in dag_ids:
-            for cur_run_id in run_ids:
-                # we compare the group size between task_id and map_index and use the smaller group
-                dag_task_id_groups = task_id_groups[(cur_dag_id, cur_run_id)]
-                dag_map_index_groups = map_index_groups[(cur_dag_id, cur_run_id)]
-
-                if len(dag_task_id_groups) <= len(dag_map_index_groups):
-                    for cur_task_id, cur_map_indices in dag_task_id_groups.items():
-                        filter_condition.append(
-                            and_(
-                                TaskInstance.dag_id == cur_dag_id,
-                                TaskInstance.run_id == cur_run_id,
-                                TaskInstance.task_id == cur_task_id,
-                                TaskInstance.map_index.in_(cur_map_indices),
-                            )
+        for cur_dag_id, cur_run_id in itertools.product(dag_ids, run_ids):
+            # we compare the group size between task_id and map_index and use the smaller group
+            dag_task_id_groups = task_id_groups[(cur_dag_id, cur_run_id)]
+            dag_map_index_groups = map_index_groups[(cur_dag_id, cur_run_id)]
+
+            if len(dag_task_id_groups) <= len(dag_map_index_groups):
+                for cur_task_id, cur_map_indices in dag_task_id_groups.items():
+                    filter_condition.append(
+                        and_(
+                            TaskInstance.dag_id == cur_dag_id,
+                            TaskInstance.run_id == cur_run_id,
+                            TaskInstance.task_id == cur_task_id,
+                            TaskInstance.map_index.in_(cur_map_indices),
                         )
-                else:
-                    for cur_map_index, cur_task_ids in dag_map_index_groups.items():
-                        filter_condition.append(
-                            and_(
-                                TaskInstance.dag_id == cur_dag_id,
-                                TaskInstance.run_id == cur_run_id,
-                                TaskInstance.task_id.in_(cur_task_ids),
-                                TaskInstance.map_index == cur_map_index,
-                            )
+                    )
+            else:
+                for cur_map_index, cur_task_ids in dag_map_index_groups.items():
+                    filter_condition.append(
+                        and_(
+                            TaskInstance.dag_id == cur_dag_id,
+                            TaskInstance.run_id == cur_run_id,
+                            TaskInstance.task_id.in_(cur_task_ids),
+                            TaskInstance.map_index == cur_map_index,
                         )
+                    )
 
         return or_(*filter_condition)
 
diff --git a/airflow/providers/google/cloud/transfers/sql_to_gcs.py b/airflow/providers/google/cloud/transfers/sql_to_gcs.py
index 05ce03f16d..89c1880321 100644
--- a/airflow/providers/google/cloud/transfers/sql_to_gcs.py
+++ b/airflow/providers/google/cloud/transfers/sql_to_gcs.py
@@ -222,8 +222,8 @@ class BaseSQLToGCSOperator(BaseOperator):
     def _write_rows_to_parquet(parquet_writer: pq.ParquetWriter, rows):
         rows_pydic: dict[str, list[Any]] = {col: [] for col in parquet_writer.schema.names}
         for row in rows:
-            for ind, col in enumerate(parquet_writer.schema.names):
-                rows_pydic[col].append(row[ind])
+            for cell, col in zip(row, parquet_writer.schema.names):
+                rows_pydic[col].append(cell)
         tbl = pa.Table.from_pydict(rows_pydic, parquet_writer.schema)
         parquet_writer.write_table(tbl)
 
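The sql_to_gcs change pairs each cell with its column name positionally via zip instead of indexing the row by position. One behavioral nuance: zip stops at the shorter of its two arguments, so a row with fewer cells than the schema would now be silently truncated where the old row[ind] lookup raised an IndexError. A small sketch with made-up column names:

    # Made-up column names and row for illustration only.
    names = ["name", "age"]
    row = ("Ada", 36)

    rows_pydic = {col: [] for col in names}

    # Old: rows_pydic[col].append(row[ind]) with ind from enumerate(names).
    # New: zip pairs each cell with its column by position.
    for cell, col in zip(row, names):
        rows_pydic[col].append(cell)

    assert rows_pydic == {"name": ["Ada"], "age": [36]}
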
diff --git a/airflow/www/security_manager.py b/airflow/www/security_manager.py
index 87f845bb4f..f1a0a92919 100644
--- a/airflow/www/security_manager.py
+++ b/airflow/www/security_manager.py
@@ -16,6 +16,7 @@
 # under the License.
 from __future__ import annotations
 
+import itertools
 import warnings
 from typing import TYPE_CHECKING, Any, Collection, Container, Iterable, Sequence
 
@@ -731,9 +732,8 @@ class AirflowSecurityManagerV2(SecurityManager, LoggingMixin):
     def create_perm_vm_for_all_dag(self) -> None:
         """Create perm-vm if not exist and insert into FAB security model for all-dags."""
         # create perm for global logical dag
-        for resource_name in self.DAG_RESOURCES:
-            for action_name in self.DAG_ACTIONS:
-                self._merge_perm(action_name, resource_name)
+        for resource_name, action_name in itertools.product(self.DAG_RESOURCES, self.DAG_ACTIONS):
+            self._merge_perm(action_name, resource_name)
 
     def check_authorization(
         self,
diff --git a/airflow/www/views.py b/airflow/www/views.py
index 6fe8a32ea9..bd641e33f0 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -1223,10 +1223,8 @@ class Airflow(AirflowBaseView):
         )
         data = get_task_stats_from_query(qry)
         payload: dict[str, list[dict[str, Any]]] = collections.defaultdict(list)
-        for dag_id in filter_dag_ids:
-            for state in State.task_states:
-                count = data.get(dag_id, {}).get(state, 0)
-                payload[dag_id].append({"state": state, "count": count})
+        for dag_id, state in itertools.product(filter_dag_ids, State.task_states):
+            payload[dag_id].append({"state": state, "count": data.get(dag_id, {}).get(state, 0)})
         return flask.json.jsonify(payload)
 
     @expose("/last_dagruns", methods=["POST"])
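
The views.py hunk keeps the defaultdict shape of the payload: iterating the product of dag ids and task states appends one state/count entry per state to each dag's list, in the same order the nested loops produced. A small sketch with hypothetical values standing in for filter_dag_ids, State.task_states and the query result:

    import collections
    import itertools

    # Hypothetical stand-ins for filter_dag_ids, State.task_states and `data`.
    filter_dag_ids = ["example_dag"]
    task_states = ["success", "failed"]
    data = {"example_dag": {"success": 3}}

    payload = collections.defaultdict(list)
    for dag_id, state in itertools.product(filter_dag_ids, task_states):
        payload[dag_id].append({"state": state, "count": data.get(dag_id, {}).get(state, 0)})

    assert payload["example_dag"] == [
        {"state": "success", "count": 3},
        {"state": "failed", "count": 0},
    ]
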
diff --git a/dev/breeze/src/airflow_breeze/utils/exclude_from_matrix.py b/dev/breeze/src/airflow_breeze/utils/exclude_from_matrix.py
index 62793ebadb..8b98a0b773 100644
--- a/dev/breeze/src/airflow_breeze/utils/exclude_from_matrix.py
+++ b/dev/breeze/src/airflow_breeze/utils/exclude_from_matrix.py
@@ -16,6 +16,8 @@
 # under the License.
 from __future__ import annotations
 
+import itertools
+
 
 def representative_combos(list_1: list[str], list_2: list[str]) -> list[tuple[str, str]]:
     """
@@ -40,8 +42,5 @@ def excluded_combos(list_1: list[str], list_2: list[str]) -> list[tuple[str, str
     :param list_2: second list
     :return: list of exclusions = list 1 x list 2 - representative_combos
     """
-    all_combos: list[tuple[str, str]] = []
-    for item_1 in list_1:
-        for item_2 in list_2:
-            all_combos.append((item_1, item_2))
+    all_combos: list[tuple[str, str]] = list(itertools.product(list_1, list_2))
     return [item for item in all_combos if item not in set(representative_combos(list_1, list_2))]
diff --git a/dev/breeze/src/airflow_breeze/utils/selective_checks.py b/dev/breeze/src/airflow_breeze/utils/selective_checks.py
index 0c498bae7e..8251f3f4de 100644
--- a/dev/breeze/src/airflow_breeze/utils/selective_checks.py
+++ b/dev/breeze/src/airflow_breeze/utils/selective_checks.py
@@ -470,10 +470,8 @@ class SelectiveChecks:
 
     def _match_files_with_regexps(self, matched_files, regexps):
         for file in self._files:
-            for regexp in regexps:
-                if re.match(regexp, file):
-                    matched_files.append(file)
-                    break
+            if any(re.match(regexp, file) for regexp in regexps):
+                matched_files.append(file)
 
     @lru_cache(maxsize=None)
     def _matching_files(self, match_group: T, match_dict: dict[T, list[str]]) -> list[str]:
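
The selective_checks hunk collapses a loop-and-break idiom into any() over a generator expression. any() short-circuits on the first truthy value, so evaluation still stops at the first matching regexp, exactly like the original break. A sketch with hypothetical patterns and file names:

    import re

    # Hypothetical patterns and file list for illustration only.
    regexps = [r"^airflow/providers/", r"^docs/"]
    files = ["airflow/providers/google/cloud/x.py", "README.md", "docs/index.rst"]

    matched_files = []
    for file in files:
        # any() stops at the first pattern that matches, like the original break.
        if any(re.match(regexp, file) for regexp in regexps):
            matched_files.append(file)

    assert matched_files == ["airflow/providers/google/cloud/x.py", "docs/index.rst"]
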
diff --git a/dev/retag_docker_images.py b/dev/retag_docker_images.py
index ba71d3ac28..bbf63fab13 100755
--- a/dev/retag_docker_images.py
+++ b/dev/retag_docker_images.py
@@ -27,6 +27,7 @@ from __future__ import annotations
 # * when starting new release branch (for example `v2-1-test`)
 # * when renaming a branch
 #
+import itertools
 import subprocess
 
 import rich_click as click
@@ -52,19 +53,18 @@ def pull_push_all_images(
     target_branch: str,
     target_repo: str,
 ):
-    for python in PYTHON_VERSIONS:
-        for image in images:
-            source_image = image.format(
-                prefix=source_prefix, branch=source_branch, repo=source_repo, python=python
-            )
-            target_image = image.format(
-                prefix=target_prefix, branch=target_branch, repo=target_repo, python=python
-            )
-            print(f"Copying image: {source_image} -> {target_image}")
-            subprocess.run(
-                ["regctl", "image", "copy", "--force-recursive", "--digest-tags", source_image, target_image],
-                check=True,
-            )
+    for python, image in itertools.product(PYTHON_VERSIONS, images):
+        source_image = image.format(
+            prefix=source_prefix, branch=source_branch, repo=source_repo, python=python
+        )
+        target_image = image.format(
+            prefix=target_prefix, branch=target_branch, repo=target_repo, python=python
+        )
+        print(f"Copying image: {source_image} -> {target_image}")
+        subprocess.run(
+            ["regctl", "image", "copy", "--force-recursive", "--digest-tags", source_image, target_image],
+            check=True,
+        )
 
 
 @click.group(invoke_without_command=True)
diff --git a/docs/exts/airflow_intersphinx.py b/docs/exts/airflow_intersphinx.py
index 790e601a4f..b0fecdec9b 100644
--- a/docs/exts/airflow_intersphinx.py
+++ b/docs/exts/airflow_intersphinx.py
@@ -151,10 +151,9 @@ if __name__ == "__main__":
         def inspect_main(inv_data, name) -> None:
             try:
                 for key in sorted(inv_data or {}):
-                    for entry, _ in sorted(inv_data[key].items()):
-                        domain, object_type = key.split(":")
-                        role_name = domain_and_object_type_to_role(domain, object_type)
-
+                    domain, object_type = key.split(":")
+                    role_name = domain_and_object_type_to_role(domain, object_type)
+                    for entry in sorted(inv_data[key].keys()):
                         print(f":{role_name}:`{name}:{entry}`")
             except ValueError as exc:
                 print(exc.args[0] % exc.args[1:])
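
The airflow_intersphinx hunk hoists the two computations that depend only on key (the key.split(":") unpacking and the role lookup) out of the inner loop, so they run once per key instead of once per entry. A sketch with a made-up inventory and a trivial stand-in for domain_and_object_type_to_role:

    # Made-up inventory and a trivial stand-in for domain_and_object_type_to_role().
    inv_data = {"py:class": {"airflow.DAG": None, "airflow.models.BaseOperator": None}}
    name = "apache-airflow"

    def domain_and_object_type_to_role(domain: str, object_type: str) -> str:
        return object_type

    for key in sorted(inv_data):
        # Per-key work happens once, outside the per-entry loop.
        domain, object_type = key.split(":")
        role_name = domain_and_object_type_to_role(domain, object_type)
        for entry in sorted(inv_data[key].keys()):
            print(f":{role_name}:`{name}:{entry}`")
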
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index 786e07aa45..f6ca5fd2e5 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -18,6 +18,7 @@
 from __future__ import annotations
 
 import datetime
+import itertools
 import logging
 import os
 import pickle
@@ -342,12 +343,9 @@ class TestDag:
                 [EmptyOperator(task_id=f"stage{i}.{j}", priority_weight=weight) for j in range(width)]
                 for i in range(depth)
             ]
-            for i, stage in enumerate(pipeline):
-                if i == 0:
-                    continue
-                for current_task in stage:
-                    for prev_task in pipeline[i - 1]:
-                        current_task.set_upstream(prev_task)
+            for upstream, downstream in zip(pipeline, pipeline[1:]):
+                for up_task, down_task in itertools.product(upstream, downstream):
+                    down_task.set_upstream(up_task)
 
             for task in dag.task_dict.values():
                 match = pattern.match(task.task_id)
@@ -376,12 +374,9 @@ class TestDag:
                 ]
                 for i in range(depth)
             ]
-            for i, stage in enumerate(pipeline):
-                if i == 0:
-                    continue
-                for current_task in stage:
-                    for prev_task in pipeline[i - 1]:
-                        current_task.set_upstream(prev_task)
+            for upstream, downstream in zip(pipeline, pipeline[1:]):
+                for up_task, down_task in itertools.product(upstream, downstream):
+                    down_task.set_upstream(up_task)
 
             for task in dag.task_dict.values():
                 match = pattern.match(task.task_id)
@@ -409,12 +404,9 @@ class TestDag:
                 ]
                 for i in range(depth)
             ]
-            for i, stage in enumerate(pipeline):
-                if i == 0:
-                    continue
-                for current_task in stage:
-                    for prev_task in pipeline[i - 1]:
-                        current_task.set_upstream(prev_task)
+            for upstream, downstream in zip(pipeline, pipeline[1:]):
+                for up_task, down_task in itertools.product(upstream, downstream):
+                    down_task.set_upstream(up_task)
 
             for task in dag.task_dict.values():
                 # the sum of each stages after this task + itself
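
In the test_dag.py hunks, zip(pipeline, pipeline[1:]) pairs every stage with the stage that follows it, and itertools.product then wires each upstream task to each downstream task, matching what the original enumerate-with-continue version did via pipeline[i - 1]. A minimal sketch with plain strings standing in for the EmptyOperator tasks:

    import itertools

    # Strings stand in for the EmptyOperator instances built in the test.
    pipeline = [["a1", "a2"], ["b1", "b2"], ["c1"]]

    edges = []
    # zip(pipeline, pipeline[1:]) yields consecutive (upstream, downstream) stage pairs.
    for upstream, downstream in zip(pipeline, pipeline[1:]):
        # product wires every upstream task to every downstream task.
        for up_task, down_task in itertools.product(upstream, downstream):
            edges.append((up_task, down_task))

    assert edges == [
        ("a1", "b1"), ("a1", "b2"), ("a2", "b1"), ("a2", "b2"),
        ("b1", "c1"), ("b2", "c1"),
    ]
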
diff --git a/tests/sensors/test_external_task_sensor.py b/tests/sensors/test_external_task_sensor.py
index 4af96d2d40..6d8920f008 100644
--- a/tests/sensors/test_external_task_sensor.py
+++ b/tests/sensors/test_external_task_sensor.py
@@ -17,6 +17,7 @@
 # under the License.
 from __future__ import annotations
 
+import itertools
 import logging
 import os
 import tempfile
@@ -1159,10 +1160,9 @@ def test_external_task_marker_clear_activate(dag_bag_parent_child, session):
     run_tasks(dag_bag, execution_date=day_2)
 
     # Assert that dagruns of all the affected dags are set to SUCCESS before tasks are cleared.
-    for dag in dag_bag.dags.values():
-        for execution_date in [day_1, day_2]:
-            dagrun = dag.get_dagrun(execution_date=execution_date, session=session)
-            dagrun.set_state(State.SUCCESS)
+    for dag, execution_date in itertools.product(dag_bag.dags.values(), [day_1, day_2]):
+        dagrun = dag.get_dagrun(execution_date=execution_date, session=session)
+        dagrun.set_state(State.SUCCESS)
     session.flush()
 
     dag_0 = dag_bag.get_dag("parent_dag_0")