Posted to commits@airflow.apache.org by vi...@apache.org on 2023/09/27 15:30:39 UTC
[airflow] branch main updated: Refactor dedent nested loops (#34409)
This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 07fe1d2a69 Refactor dedent nested loops (#34409)
07fe1d2a69 is described below
commit 07fe1d2a69cbe4f684a1989c047737c0686c4417
Author: Miroslav Šedivý <67...@users.noreply.github.com>
AuthorDate: Wed Sep 27 15:30:30 2023 +0000
Refactor dedent nested loops (#34409)
---
airflow/models/taskinstance.py | 48 +++++++++++-----------
.../providers/google/cloud/transfers/sql_to_gcs.py | 4 +-
airflow/www/security_manager.py | 6 +--
airflow/www/views.py | 6 +--
.../airflow_breeze/utils/exclude_from_matrix.py | 7 ++--
.../src/airflow_breeze/utils/selective_checks.py | 6 +--
dev/retag_docker_images.py | 26 ++++++------
docs/exts/airflow_intersphinx.py | 7 ++--
tests/models/test_dag.py | 28 +++++--------
tests/sensors/test_external_task_sensor.py | 8 ++--
10 files changed, 66 insertions(+), 80 deletions(-)
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 039fe68f98..3ac88d69c9 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -20,6 +20,7 @@ from __future__ import annotations
import collections.abc
import contextlib
import hashlib
+import itertools
import logging
import math
import operator
@@ -3073,32 +3074,31 @@ class TaskInstance(Base, LoggingMixin):
# this assumes that most dags have dag_id as the largest grouping, followed by run_id. even
# if it's not, this is still a significant optimization over querying for every single tuple key
- for cur_dag_id in dag_ids:
- for cur_run_id in run_ids:
- # we compare the group size between task_id and map_index and use the smaller group
- dag_task_id_groups = task_id_groups[(cur_dag_id, cur_run_id)]
- dag_map_index_groups = map_index_groups[(cur_dag_id, cur_run_id)]
-
- if len(dag_task_id_groups) <= len(dag_map_index_groups):
- for cur_task_id, cur_map_indices in dag_task_id_groups.items():
- filter_condition.append(
- and_(
- TaskInstance.dag_id == cur_dag_id,
- TaskInstance.run_id == cur_run_id,
- TaskInstance.task_id == cur_task_id,
- TaskInstance.map_index.in_(cur_map_indices),
- )
+ for cur_dag_id, cur_run_id in itertools.product(dag_ids, run_ids):
+ # we compare the group size between task_id and map_index and use the smaller group
+ dag_task_id_groups = task_id_groups[(cur_dag_id, cur_run_id)]
+ dag_map_index_groups = map_index_groups[(cur_dag_id, cur_run_id)]
+
+ if len(dag_task_id_groups) <= len(dag_map_index_groups):
+ for cur_task_id, cur_map_indices in dag_task_id_groups.items():
+ filter_condition.append(
+ and_(
+ TaskInstance.dag_id == cur_dag_id,
+ TaskInstance.run_id == cur_run_id,
+ TaskInstance.task_id == cur_task_id,
+ TaskInstance.map_index.in_(cur_map_indices),
)
- else:
- for cur_map_index, cur_task_ids in dag_map_index_groups.items():
- filter_condition.append(
- and_(
- TaskInstance.dag_id == cur_dag_id,
- TaskInstance.run_id == cur_run_id,
- TaskInstance.task_id.in_(cur_task_ids),
- TaskInstance.map_index == cur_map_index,
- )
+ )
+ else:
+ for cur_map_index, cur_task_ids in dag_map_index_groups.items():
+ filter_condition.append(
+ and_(
+ TaskInstance.dag_id == cur_dag_id,
+ TaskInstance.run_id == cur_run_id,
+ TaskInstance.task_id.in_(cur_task_ids),
+ TaskInstance.map_index == cur_map_index,
)
+ )
return or_(*filter_condition)
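
The hunk above is the heart of the commit: two stacked loops over independent iterables become one loop over itertools.product, removing a level of nesting while preserving iteration order (product varies its rightmost argument fastest, exactly like the inner loop). A minimal sketch of the equivalence, with illustrative names rather than the real Airflow data:

    import itertools

    dag_ids = ["dag_a", "dag_b"]
    run_ids = ["run_1", "run_2"]

    # Before: nested loops, one indentation level per iterable.
    pairs_nested = []
    for cur_dag_id in dag_ids:
        for cur_run_id in run_ids:
            pairs_nested.append((cur_dag_id, cur_run_id))

    # After: a single loop over the cartesian product, same element order.
    pairs_product = [
        (cur_dag_id, cur_run_id)
        for cur_dag_id, cur_run_id in itertools.product(dag_ids, run_ids)
    ]

    assert pairs_nested == pairs_product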
diff --git a/airflow/providers/google/cloud/transfers/sql_to_gcs.py b/airflow/providers/google/cloud/transfers/sql_to_gcs.py
index 05ce03f16d..89c1880321 100644
--- a/airflow/providers/google/cloud/transfers/sql_to_gcs.py
+++ b/airflow/providers/google/cloud/transfers/sql_to_gcs.py
@@ -222,8 +222,8 @@ class BaseSQLToGCSOperator(BaseOperator):
def _write_rows_to_parquet(parquet_writer: pq.ParquetWriter, rows):
rows_pydic: dict[str, list[Any]] = {col: [] for col in parquet_writer.schema.names}
for row in rows:
- for ind, col in enumerate(parquet_writer.schema.names):
- rows_pydic[col].append(row[ind])
+ for cell, col in zip(row, parquet_writer.schema.names):
+ rows_pydic[col].append(cell)
tbl = pa.Table.from_pydict(rows_pydic, parquet_writer.schema)
parquet_writer.write_table(tbl)
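
Here the refactor swaps index arithmetic for zip, pairing each cell with its column name directly. One behavioral nuance worth knowing: zip stops at the shorter input, so a row with fewer cells than the schema would be silently truncated where row[ind] would have raised IndexError. A toy sketch, not the operator's real schema:

    schema_names = ["id", "name"]
    row = (1, "airflow")

    # Before: enumerate the column names and index into the row.
    by_index = {col: row[ind] for ind, col in enumerate(schema_names)}

    # After: zip walks cells and column names in lockstep.
    by_zip = {col: cell for cell, col in zip(row, schema_names)}

    assert by_index == by_zip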
diff --git a/airflow/www/security_manager.py b/airflow/www/security_manager.py
index 87f845bb4f..f1a0a92919 100644
--- a/airflow/www/security_manager.py
+++ b/airflow/www/security_manager.py
@@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations
+import itertools
import warnings
from typing import TYPE_CHECKING, Any, Collection, Container, Iterable, Sequence
@@ -731,9 +732,8 @@ class AirflowSecurityManagerV2(SecurityManager, LoggingMixin):
def create_perm_vm_for_all_dag(self) -> None:
"""Create perm-vm if not exist and insert into FAB security model for all-dags."""
# create perm for global logical dag
- for resource_name in self.DAG_RESOURCES:
- for action_name in self.DAG_ACTIONS:
- self._merge_perm(action_name, resource_name)
+ for resource_name, action_name in itertools.product(self.DAG_RESOURCES, self.DAG_ACTIONS):
+ self._merge_perm(action_name, resource_name)
def check_authorization(
self,
diff --git a/airflow/www/views.py b/airflow/www/views.py
index 6fe8a32ea9..bd641e33f0 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -1223,10 +1223,8 @@ class Airflow(AirflowBaseView):
)
data = get_task_stats_from_query(qry)
payload: dict[str, list[dict[str, Any]]] = collections.defaultdict(list)
- for dag_id in filter_dag_ids:
- for state in State.task_states:
- count = data.get(dag_id, {}).get(state, 0)
- payload[dag_id].append({"state": state, "count": count})
+ for dag_id, state in itertools.product(filter_dag_ids, State.task_states):
+ payload[dag_id].append({"state": state, "count": data.get(dag_id, {}).get(state, 0)})
return flask.json.jsonify(payload)
@expose("/last_dagruns", methods=["POST"])
diff --git a/dev/breeze/src/airflow_breeze/utils/exclude_from_matrix.py b/dev/breeze/src/airflow_breeze/utils/exclude_from_matrix.py
index 62793ebadb..8b98a0b773 100644
--- a/dev/breeze/src/airflow_breeze/utils/exclude_from_matrix.py
+++ b/dev/breeze/src/airflow_breeze/utils/exclude_from_matrix.py
@@ -16,6 +16,8 @@
# under the License.
from __future__ import annotations
+import itertools
+
def representative_combos(list_1: list[str], list_2: list[str]) -> list[tuple[str, str]]:
"""
@@ -40,8 +42,5 @@ def excluded_combos(list_1: list[str], list_2: list[str]) -> list[tuple[str, str
:param list_2: second list
:return: list of exclusions = list 1 x list 2 - representative_combos
"""
- all_combos: list[tuple[str, str]] = []
- for item_1 in list_1:
- for item_2 in list_2:
- all_combos.append((item_1, item_2))
+ all_combos: list[tuple[str, str]] = list(itertools.product(list_1, list_2))
return [item for item in all_combos if item not in set(representative_combos(list_1, list_2))]
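
A caveat on the new one-liner: set(representative_combos(list_1, list_2)) sits inside the comprehension's condition, so the function is called and the set rebuilt once per combo. Hoisting it out keeps the result identical and makes each membership test a single O(1) lookup; a sketch of that variant, assuming representative_combos as defined earlier in this file:

    import itertools

    def excluded_combos(list_1: list[str], list_2: list[str]) -> list[tuple[str, str]]:
        # Build the exclusion set once, then filter the full cartesian product.
        representative = set(representative_combos(list_1, list_2))
        return [
            combo
            for combo in itertools.product(list_1, list_2)
            if combo not in representative
        ]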
diff --git a/dev/breeze/src/airflow_breeze/utils/selective_checks.py b/dev/breeze/src/airflow_breeze/utils/selective_checks.py
index 0c498bae7e..8251f3f4de 100644
--- a/dev/breeze/src/airflow_breeze/utils/selective_checks.py
+++ b/dev/breeze/src/airflow_breeze/utils/selective_checks.py
@@ -470,10 +470,8 @@ class SelectiveChecks:
def _match_files_with_regexps(self, matched_files, regexps):
for file in self._files:
- for regexp in regexps:
- if re.match(regexp, file):
- matched_files.append(file)
- break
+ if any(re.match(regexp, file) for regexp in regexps):
+ matched_files.append(file)
@lru_cache(maxsize=None)
def _matching_files(self, match_group: T, match_dict: dict[T, list[str]]) -> list[str]:
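
The removed for/if/break is a hand-rolled any(): the generator expression short-circuits on the first regexp that matches, just as break did, so no further regexps are evaluated. A small sketch of the equivalence with made-up patterns:

    import re

    regexps = [r"^airflow/", r"^tests/"]
    file = "tests/models/test_dag.py"

    # Before: explicit loop, break on first match.
    matched = False
    for regexp in regexps:
        if re.match(regexp, file):
            matched = True
            break

    # After: any() consumes the generator lazily and stops at the first match.
    assert matched == any(re.match(regexp, file) for regexp in regexps)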
diff --git a/dev/retag_docker_images.py b/dev/retag_docker_images.py
index ba71d3ac28..bbf63fab13 100755
--- a/dev/retag_docker_images.py
+++ b/dev/retag_docker_images.py
@@ -27,6 +27,7 @@ from __future__ import annotations
# * when starting new release branch (for example `v2-1-test`)
# * when renaming a branch
#
+import itertools
import subprocess
import rich_click as click
@@ -52,19 +53,18 @@ def pull_push_all_images(
target_branch: str,
target_repo: str,
):
- for python in PYTHON_VERSIONS:
- for image in images:
- source_image = image.format(
- prefix=source_prefix, branch=source_branch, repo=source_repo, python=python
- )
- target_image = image.format(
- prefix=target_prefix, branch=target_branch, repo=target_repo, python=python
- )
- print(f"Copying image: {source_image} -> {target_image}")
- subprocess.run(
- ["regctl", "image", "copy", "--force-recursive", "--digest-tags", source_image, target_image],
- check=True,
- )
+ for python, image in itertools.product(PYTHON_VERSIONS, images):
+ source_image = image.format(
+ prefix=source_prefix, branch=source_branch, repo=source_repo, python=python
+ )
+ target_image = image.format(
+ prefix=target_prefix, branch=target_branch, repo=target_repo, python=python
+ )
+ print(f"Copying image: {source_image} -> {target_image}")
+ subprocess.run(
+ ["regctl", "image", "copy", "--force-recursive", "--digest-tags", source_image, target_image],
+ check=True,
+ )
@click.group(invoke_without_command=True)
diff --git a/docs/exts/airflow_intersphinx.py b/docs/exts/airflow_intersphinx.py
index 790e601a4f..b0fecdec9b 100644
--- a/docs/exts/airflow_intersphinx.py
+++ b/docs/exts/airflow_intersphinx.py
@@ -151,10 +151,9 @@ if __name__ == "__main__":
def inspect_main(inv_data, name) -> None:
try:
for key in sorted(inv_data or {}):
- for entry, _ in sorted(inv_data[key].items()):
- domain, object_type = key.split(":")
- role_name = domain_and_object_type_to_role(domain, object_type)
-
+ domain, object_type = key.split(":")
+ role_name = domain_and_object_type_to_role(domain, object_type)
+ for entry in sorted(inv_data[key].keys()):
print(f":{role_name}:`{name}:{entry}`")
except ValueError as exc:
print(exc.args[0] % exc.args[1:])
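
Besides flattening the iteration, this hunk hoists key.split(":") and the role lookup out of the inner loop; both depend only on key, so computing them once per key is equivalent and skips redundant work per entry. One subtle difference: a key with no entries previously skipped the split entirely, whereas now the split runs for every key, though the surrounding try/except ValueError still catches a malformed key either way. A sketch of the hoisting pattern on toy inventory data:

    inv_data = {"py:class": {"airflow.models.DAG": None, "airflow.models.BaseOperator": None}}

    for key in sorted(inv_data):
        # Loop-invariant for the inner loop: depends only on key, so hoist it.
        domain, object_type = key.split(":")
        for entry in sorted(inv_data[key]):
            print(f"{domain} {object_type}: {entry}")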
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index 786e07aa45..f6ca5fd2e5 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -18,6 +18,7 @@
from __future__ import annotations
import datetime
+import itertools
import logging
import os
import pickle
@@ -342,12 +343,9 @@ class TestDag:
[EmptyOperator(task_id=f"stage{i}.{j}", priority_weight=weight) for j in range(width)]
for i in range(depth)
]
- for i, stage in enumerate(pipeline):
- if i == 0:
- continue
- for current_task in stage:
- for prev_task in pipeline[i - 1]:
- current_task.set_upstream(prev_task)
+ for upstream, downstream in zip(pipeline, pipeline[1:]):
+ for up_task, down_task in itertools.product(upstream, downstream):
+ down_task.set_upstream(up_task)
for task in dag.task_dict.values():
match = pattern.match(task.task_id)
@@ -376,12 +374,9 @@ class TestDag:
]
for i in range(depth)
]
- for i, stage in enumerate(pipeline):
- if i == 0:
- continue
- for current_task in stage:
- for prev_task in pipeline[i - 1]:
- current_task.set_upstream(prev_task)
+ for upstream, downstream in zip(pipeline, pipeline[1:]):
+ for up_task, down_task in itertools.product(upstream, downstream):
+ down_task.set_upstream(up_task)
for task in dag.task_dict.values():
match = pattern.match(task.task_id)
@@ -409,12 +404,9 @@ class TestDag:
]
for i in range(depth)
]
- for i, stage in enumerate(pipeline):
- if i == 0:
- continue
- for current_task in stage:
- for prev_task in pipeline[i - 1]:
- current_task.set_upstream(prev_task)
+ for upstream, downstream in zip(pipeline, pipeline[1:]):
+ for up_task, down_task in itertools.product(upstream, downstream):
+ down_task.set_upstream(up_task)
for task in dag.task_dict.values():
# the sum of each stages after this task + itself
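
The test refactor replaces the enumerate-and-skip idiom with zip(pipeline, pipeline[1:]), which walks consecutive stage pairs directly; itertools.product then wires every upstream task to every downstream one. A toy sketch of the pairing:

    import itertools

    pipeline = [["a1", "a2"], ["b1"], ["c1", "c2"]]

    edges = []
    # zip(pipeline, pipeline[1:]) yields (stage_0, stage_1), (stage_1, stage_2), ...
    for upstream, downstream in zip(pipeline, pipeline[1:]):
        # Full bipartite wiring between each pair of consecutive stages.
        for up_task, down_task in itertools.product(upstream, downstream):
            edges.append((up_task, down_task))

    assert edges == [("a1", "b1"), ("a2", "b1"), ("b1", "c1"), ("b1", "c2")]

On Python 3.10+, itertools.pairwise(pipeline) expresses the same consecutive pairing without copying the list slice.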
diff --git a/tests/sensors/test_external_task_sensor.py b/tests/sensors/test_external_task_sensor.py
index 4af96d2d40..6d8920f008 100644
--- a/tests/sensors/test_external_task_sensor.py
+++ b/tests/sensors/test_external_task_sensor.py
@@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations
+import itertools
import logging
import os
import tempfile
@@ -1159,10 +1160,9 @@ def test_external_task_marker_clear_activate(dag_bag_parent_child, session):
run_tasks(dag_bag, execution_date=day_2)
# Assert that dagruns of all the affected dags are set to SUCCESS before tasks are cleared.
- for dag in dag_bag.dags.values():
- for execution_date in [day_1, day_2]:
- dagrun = dag.get_dagrun(execution_date=execution_date, session=session)
- dagrun.set_state(State.SUCCESS)
+ for dag, execution_date in itertools.product(dag_bag.dags.values(), [day_1, day_2]):
+ dagrun = dag.get_dagrun(execution_date=execution_date, session=session)
+ dagrun.set_state(State.SUCCESS)
session.flush()
dag_0 = dag_bag.get_dag("parent_dag_0")