You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by bb...@apache.org on 2022/04/20 19:00:52 UTC
[airflow] 10/19: Refactor to straighten up types
This is an automated email from the ASF dual-hosted git repository.
bbovenzi pushed a commit to branch mapped-instance-actions
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 3b12a914db3962b40141498f9eb43af91ecd80b2
Author: Tzu-ping Chung <tp...@astronomer.io>
AuthorDate: Mon Apr 18 22:13:10 2022 +0800
Refactor to straighten up types
---
airflow/models/dag.py | 122 +++++++++++++++++++++-----------------------------
1 file changed, 52 insertions(+), 70 deletions(-)
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 0694f37550..755505b5d0 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -39,6 +39,7 @@ from typing import (
Iterable,
List,
Optional,
+ Sequence,
Set,
Tuple,
Type,
@@ -1340,40 +1341,29 @@ class DAG(LoggingMixin):
start_date = (timezone.utcnow() - timedelta(30)).replace(
hour=0, minute=0, second=0, microsecond=0
)
-
- if state is None:
- state = []
-
- return (
- cast(
- Query,
- self._get_task_instances(
- task_ids=None,
- start_date=start_date,
- end_date=end_date,
- run_id=None,
- state=state,
- include_subdags=False,
- include_parentdag=False,
- include_dependent_dags=False,
- exclude_task_ids=cast(List[str], []),
- session=session,
- ),
- )
- .order_by(DagRun.execution_date)
- .all()
+ query = self._get_task_instances(
+ task_ids=None,
+ start_date=start_date,
+ end_date=end_date,
+ run_id=None,
+ state=state or (),
+ include_subdags=False,
+ include_parentdag=False,
+ include_dependent_dags=False,
+ exclude_task_ids=(),
+ session=session,
)
+ return cast(Query, query).order_by(DagRun.execution_date).all()
@overload
def _get_task_instances(
self,
*,
task_ids: Union[Collection[str], Collection[Tuple[str, int]], None],
- task_ids_and_map_indexes,
start_date: Optional[datetime],
end_date: Optional[datetime],
run_id: Optional[str],
- state: Union[TaskInstanceState, List[TaskInstanceState]],
+ state: Union[TaskInstanceState, Sequence[TaskInstanceState]],
include_subdags: bool,
include_parentdag: bool,
include_dependent_dags: bool,
@@ -1392,7 +1382,7 @@ class DAG(LoggingMixin):
start_date: Optional[datetime],
end_date: Optional[datetime],
run_id: Optional[str],
- state: Union[TaskInstanceState, List[TaskInstanceState]],
+ state: Union[TaskInstanceState, Sequence[TaskInstanceState]],
include_subdags: bool,
include_parentdag: bool,
include_dependent_dags: bool,
@@ -1413,7 +1403,7 @@ class DAG(LoggingMixin):
start_date: Optional[datetime],
end_date: Optional[datetime],
run_id: Optional[str],
- state: Union[TaskInstanceState, List[TaskInstanceState]],
+ state: Union[TaskInstanceState, Sequence[TaskInstanceState]],
include_subdags: bool,
include_parentdag: bool,
include_dependent_dags: bool,
@@ -1441,18 +1431,6 @@ class DAG(LoggingMixin):
tis = session.query(TaskInstance)
tis = tis.join(TaskInstance.dag_run)
- task_ids_and_map_indexes = None
- if task_ids is not None:
- task_ids_and_map_indexes = [item for item in task_ids if isinstance(item, tuple)]
- if task_ids_and_map_indexes:
- task_ids = None # nullify since we have indexes
-
- exclude_task_ids_and_map_indexes = None
- if exclude_task_ids is not None:
- exclude_task_ids_and_map_indexes = [item for item in exclude_task_ids if isinstance(item, tuple)]
- if exclude_task_ids_and_map_indexes:
- exclude_task_ids = None
-
if include_subdags:
# Crafting the right filter for dag_id and task_ids combo
conditions = []
@@ -1467,12 +1445,13 @@ class DAG(LoggingMixin):
tis = tis.filter(TaskInstance.run_id == run_id)
if start_date:
tis = tis.filter(DagRun.execution_date >= start_date)
- if task_ids:
- tis = tis.filter(TaskInstance.task_id.in_(task_ids))
- if task_ids_and_map_indexes:
- tis = tis.filter(
- tuple_(TaskInstance.task_id, TaskInstance.map_index).in_(task_ids_and_map_indexes)
- )
+
+ if task_ids is None:
+ pass # Disable filter if not set.
+ elif isinstance(next(iter(task_ids), None), str):
+ tis = tis.filter(TI.task_id.in_(task_ids))
+ else:
+ tis = tis.filter(tuple_(TI.task_id, TI.map_index).in_(task_ids))
# This allows the allow_trigger_in_future config to take effect, rather than mandating exec_date <= UTC
if end_date or not self.allow_future_exec_dates:
@@ -1610,33 +1589,29 @@ class DAG(LoggingMixin):
if as_pk_tuple:
result.update(TaskInstanceKey(*cols) for cols in tis.all())
else:
- result.update(ti.key for ti in tis.all())
+ result.update(ti.key for ti in tis)
if exclude_task_ids is not None:
- result = set(
- filter(
- lambda key: key.task_id not in exclude_task_ids,
- result,
- )
- )
-
- if exclude_task_ids_and_map_indexes is not None:
- result = set(
- filter(
- lambda key: (key.task_id, key.map_index) not in exclude_task_ids_and_map_indexes,
- result,
- )
- )
+ result = {
+ task
+ for task in result
+ if task.task_id not in exclude_task_ids
+ and (task.task_id, task.map_index) not in exclude_task_ids
+ }
if as_pk_tuple:
return result
- elif result:
+ if result:
# We've been asked for objects, let's combine it all back into a result set
- tis = tis.with_entities(TI.dag_id, TI.task_id, TI.run_id, TI.map_index)
-
- tis = session.query(TI).filter(TI.filter_for_tis(result))
- elif exclude_task_ids_and_map_indexes:
- tis = tis.filter(tuple_(TI.task_id, TI.map_index).notin_(exclude_task_ids_and_map_indexes))
+ ti_filters = TI.filter_for_tis(result)
+ if ti_filters is not None:
+ tis = session.query(TI).filter(ti_filters)
+ elif exclude_task_ids is None:
+ pass # Disable filter if not set.
+ elif isinstance(next(iter(exclude_task_ids), None), str):
+ tis = tis.filter(TI.task_id.notin_(exclude_task_ids))
+ else:
+ tis = tis.filter(tuple_(TI.task_id, TI.map_index).notin_(exclude_task_ids))
return tis
@@ -1687,11 +1662,18 @@ class DAG(LoggingMixin):
task = self.get_task(task_id)
task.dag = self
- task_map_indexes = [(task, map_index)] if map_index else [task]
- task_id_map_indexes = {(task_id, map_index)} if map_index else {task_id}
+
+ tasks_to_set_state: Union[List[Operator], List[Tuple[Operator, int]]]
+ task_ids_to_exclude_from_clear: Union[Set[str], Set[Tuple[str, int]]]
+ if map_index is None:
+ tasks_to_set_state = [task]
+ task_ids_to_exclude_from_clear = {task_id}
+ else:
+ tasks_to_set_state = [(task, map_index)]
+ task_ids_to_exclude_from_clear = {(task_id, map_index)}
altered = set_state(
- tasks=task_map_indexes,
+ tasks=tasks_to_set_state,
execution_date=execution_date,
run_id=run_id,
upstream=upstream,
@@ -1726,7 +1708,7 @@ class DAG(LoggingMixin):
only_failed=True,
session=session,
# Exclude the task itself from being cleared
- exclude_task_ids=task_id_map_indexes,
+ exclude_task_ids=task_ids_to_exclude_from_clear,
)
return altered
@@ -1784,7 +1766,7 @@ class DAG(LoggingMixin):
@provide_session
def clear(
self,
- task_ids: Union[Iterable[str], Iterable[Tuple[str, int]], None] = None,
+ task_ids: Union[Collection[str], Collection[Tuple[str, int]], None] = None,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
only_failed: bool = False,