You are viewing a plain text version of this content. The canonical (HTML) version is available in the original mailing-list archive; its hyperlink was lost in this plain-text conversion.
Posted to commits@airflow.apache.org by bb...@apache.org on 2022/04/20 19:00:52 UTC

[airflow] 10/19: Refactor to straighten up types

This is an automated email from the ASF dual-hosted git repository.

bbovenzi pushed a commit to branch mapped-instance-actions
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 3b12a914db3962b40141498f9eb43af91ecd80b2
Author: Tzu-ping Chung <tp...@astronomer.io>
AuthorDate: Mon Apr 18 22:13:10 2022 +0800

    Refactor to straighten up types
---
 airflow/models/dag.py | 122 +++++++++++++++++++++-----------------------------
 1 file changed, 52 insertions(+), 70 deletions(-)

diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 0694f37550..755505b5d0 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -39,6 +39,7 @@ from typing import (
     Iterable,
     List,
     Optional,
+    Sequence,
     Set,
     Tuple,
     Type,
@@ -1340,40 +1341,29 @@ class DAG(LoggingMixin):
             start_date = (timezone.utcnow() - timedelta(30)).replace(
                 hour=0, minute=0, second=0, microsecond=0
             )
-
-        if state is None:
-            state = []
-
-        return (
-            cast(
-                Query,
-                self._get_task_instances(
-                    task_ids=None,
-                    start_date=start_date,
-                    end_date=end_date,
-                    run_id=None,
-                    state=state,
-                    include_subdags=False,
-                    include_parentdag=False,
-                    include_dependent_dags=False,
-                    exclude_task_ids=cast(List[str], []),
-                    session=session,
-                ),
-            )
-            .order_by(DagRun.execution_date)
-            .all()
+        query = self._get_task_instances(
+            task_ids=None,
+            start_date=start_date,
+            end_date=end_date,
+            run_id=None,
+            state=state or (),
+            include_subdags=False,
+            include_parentdag=False,
+            include_dependent_dags=False,
+            exclude_task_ids=(),
+            session=session,
         )
+        return cast(Query, query).order_by(DagRun.execution_date).all()
 
     @overload
     def _get_task_instances(
         self,
         *,
         task_ids: Union[Collection[str], Collection[Tuple[str, int]], None],
-        task_ids_and_map_indexes,
         start_date: Optional[datetime],
         end_date: Optional[datetime],
         run_id: Optional[str],
-        state: Union[TaskInstanceState, List[TaskInstanceState]],
+        state: Union[TaskInstanceState, Sequence[TaskInstanceState]],
         include_subdags: bool,
         include_parentdag: bool,
         include_dependent_dags: bool,
@@ -1392,7 +1382,7 @@ class DAG(LoggingMixin):
         start_date: Optional[datetime],
         end_date: Optional[datetime],
         run_id: Optional[str],
-        state: Union[TaskInstanceState, List[TaskInstanceState]],
+        state: Union[TaskInstanceState, Sequence[TaskInstanceState]],
         include_subdags: bool,
         include_parentdag: bool,
         include_dependent_dags: bool,
@@ -1413,7 +1403,7 @@ class DAG(LoggingMixin):
         start_date: Optional[datetime],
         end_date: Optional[datetime],
         run_id: Optional[str],
-        state: Union[TaskInstanceState, List[TaskInstanceState]],
+        state: Union[TaskInstanceState, Sequence[TaskInstanceState]],
         include_subdags: bool,
         include_parentdag: bool,
         include_dependent_dags: bool,
@@ -1441,18 +1431,6 @@ class DAG(LoggingMixin):
             tis = session.query(TaskInstance)
         tis = tis.join(TaskInstance.dag_run)
 
-        task_ids_and_map_indexes = None
-        if task_ids is not None:
-            task_ids_and_map_indexes = [item for item in task_ids if isinstance(item, tuple)]
-        if task_ids_and_map_indexes:
-            task_ids = None  # nullify since we have indexes
-
-        exclude_task_ids_and_map_indexes = None
-        if exclude_task_ids is not None:
-            exclude_task_ids_and_map_indexes = [item for item in exclude_task_ids if isinstance(item, tuple)]
-        if exclude_task_ids_and_map_indexes:
-            exclude_task_ids = None
-
         if include_subdags:
             # Crafting the right filter for dag_id and task_ids combo
             conditions = []
@@ -1467,12 +1445,13 @@ class DAG(LoggingMixin):
             tis = tis.filter(TaskInstance.run_id == run_id)
         if start_date:
             tis = tis.filter(DagRun.execution_date >= start_date)
-        if task_ids:
-            tis = tis.filter(TaskInstance.task_id.in_(task_ids))
-        if task_ids_and_map_indexes:
-            tis = tis.filter(
-                tuple_(TaskInstance.task_id, TaskInstance.map_index).in_(task_ids_and_map_indexes)
-            )
+
+        if task_ids is None:
+            pass  # Disable filter if not set.
+        elif isinstance(next(iter(task_ids), None), str):
+            tis = tis.filter(TI.task_id.in_(task_ids))
+        else:
+            tis = tis.filter(tuple_(TI.task_id, TI.map_index).in_(task_ids))
 
         # This allows allow_trigger_in_future config to take affect, rather than mandating exec_date <= UTC
         if end_date or not self.allow_future_exec_dates:
@@ -1610,33 +1589,29 @@ class DAG(LoggingMixin):
             if as_pk_tuple:
                 result.update(TaskInstanceKey(*cols) for cols in tis.all())
             else:
-                result.update(ti.key for ti in tis.all())
+                result.update(ti.key for ti in tis)
 
             if exclude_task_ids is not None:
-                result = set(
-                    filter(
-                        lambda key: key.task_id not in exclude_task_ids,
-                        result,
-                    )
-                )
-
-            if exclude_task_ids_and_map_indexes is not None:
-                result = set(
-                    filter(
-                        lambda key: (key.task_id, key.map_index) not in exclude_task_ids_and_map_indexes,
-                        result,
-                    )
-                )
+                result = {
+                    task
+                    for task in result
+                    if task.task_id not in exclude_task_ids
+                    and (task.task_id, task.map_index) not in exclude_task_ids
+                }
 
         if as_pk_tuple:
             return result
-        elif result:
+        if result:
             # We've been asked for objects, lets combine it all back in to a result set
-            tis = tis.with_entities(TI.dag_id, TI.task_id, TI.run_id, TI.map_index)
-
-            tis = session.query(TI).filter(TI.filter_for_tis(result))
-        elif exclude_task_ids_and_map_indexes:
-            tis = tis.filter(tuple_(TI.task_id, TI.map_index).notin_(exclude_task_ids_and_map_indexes))
+            ti_filters = TI.filter_for_tis(result)
+            if ti_filters is not None:
+                tis = session.query(TI).filter(ti_filters)
+        elif exclude_task_ids is None:
+            pass  # Disable filter if not set.
+        elif isinstance(next(iter(exclude_task_ids), None), str):
+            tis = tis.filter(TI.task_id.notin_(exclude_task_ids))
+        else:
+            tis = tis.filter(tuple_(TI.task_id, TI.map_index).notin_(exclude_task_ids))
 
         return tis
 
@@ -1687,11 +1662,18 @@ class DAG(LoggingMixin):
 
         task = self.get_task(task_id)
         task.dag = self
-        task_map_indexes = [(task, map_index)] if map_index else [task]
-        task_id_map_indexes = {(task_id, map_index)} if map_index else {task_id}
+
+        tasks_to_set_state: Union[List[Operator], List[Tuple[Operator, int]]]
+        task_ids_to_exclude_from_clear: Union[Set[str], Set[Tuple[str, int]]]
+        if map_index is None:
+            tasks_to_set_state = [task]
+            task_ids_to_exclude_from_clear = {task_id}
+        else:
+            tasks_to_set_state = [(task, map_index)]
+            task_ids_to_exclude_from_clear = {(task_id, map_index)}
 
         altered = set_state(
-            tasks=task_map_indexes,
+            tasks=tasks_to_set_state,
             execution_date=execution_date,
             run_id=run_id,
             upstream=upstream,
@@ -1726,7 +1708,7 @@ class DAG(LoggingMixin):
             only_failed=True,
             session=session,
             # Exclude the task itself from being cleared
-            exclude_task_ids=task_id_map_indexes,
+            exclude_task_ids=task_ids_to_exclude_from_clear,
         )
 
         return altered
@@ -1784,7 +1766,7 @@ class DAG(LoggingMixin):
     @provide_session
     def clear(
         self,
-        task_ids: Union[Iterable[str], Iterable[Tuple[str, int]], None] = None,
+        task_ids: Union[Collection[str], Collection[Tuple[str, int]], None] = None,
         start_date: Optional[datetime] = None,
         end_date: Optional[datetime] = None,
         only_failed: bool = False,