You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ep...@apache.org on 2022/06/29 15:19:51 UTC

[airflow] 06/45: Refactor `DagRun.verify_integrity` (#24114)

This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v2-3-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 5e174a12b73b9737bf48d3097c1fd5ca45a9b0e2
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Fri Jun 10 14:44:19 2022 +0100

    Refactor `DagRun.verify_integrity` (#24114)
    
    This refactoring became necessary because additional code needs to be added
    to the already existing code to handle mapped task immutability during run. The additional
    code would make this method difficult to read. Refactoring the code will aid understanding and
    help in debugging.
    
    (cherry picked from commit 12638d2310d962986b43af8f1584a405e280badf)
---
 airflow/models/dagrun.py | 102 ++++++++++++++++++++++++++++++++++++++++-------
 1 file changed, 88 insertions(+), 14 deletions(-)

diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index fdb566e467..b71cd03eec 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -23,6 +23,7 @@ from datetime import datetime
 from typing import (
     TYPE_CHECKING,
     Any,
+    Callable,
     Dict,
     Generator,
     Iterable,
@@ -30,6 +31,7 @@ from typing import (
     NamedTuple,
     Optional,
     Sequence,
+    Set,
     Tuple,
     Union,
     cast,
@@ -818,13 +820,50 @@ class DagRun(Base, LoggingMixin):
         """
         from airflow.settings import task_instance_mutation_hook
 
+        # Set for the empty default in airflow.settings -- if it's not set this means it has been changed
+        hook_is_noop = getattr(task_instance_mutation_hook, 'is_noop', False)
+
         dag = self.get_dag()
+        task_ids = self._check_for_removed_or_restored_tasks(
+            dag, task_instance_mutation_hook, session=session
+        )
+
+        def task_filter(task: "Operator") -> bool:
+            return task.task_id not in task_ids and (
+                self.is_backfill
+                or task.start_date <= self.execution_date
+                and (task.end_date is None or self.execution_date <= task.end_date)
+            )
+
+        created_counts: Dict[str, int] = defaultdict(int)
+
+        # Get task creator function
+        task_creator = self._get_task_creator(created_counts, task_instance_mutation_hook, hook_is_noop)
+
+        # Create the missing tasks, including mapped tasks
+        tasks = self._create_missing_tasks(dag, task_creator, task_filter, session=session)
+
+        self._create_task_instances(dag.dag_id, tasks, created_counts, hook_is_noop, session=session)
+
+    def _check_for_removed_or_restored_tasks(
+        self, dag: "DAG", ti_mutation_hook, *, session: Session
+    ) -> Set[str]:
+        """
+        Check for removed tasks/restored tasks.
+
+        :param dag: DAG object corresponding to the dagrun
+        :param ti_mutation_hook: task_instance_mutation_hook function
+        :param session: Sqlalchemy ORM Session
+
+        :return: Set of task_ids in the dagrun
+
+        """
         tis = self.get_task_instances(session=session)
 
         # check for removed or restored tasks
         task_ids = set()
         for ti in tis:
-            task_instance_mutation_hook(ti)
+            ti_mutation_hook(ti)
             task_ids.add(ti.task_id)
             task = None
             try:
@@ -885,19 +924,21 @@ class DagRun(Base, LoggingMixin):
                     )
                     ti.state = State.REMOVED
                     ...
+        return task_ids
 
-        def task_filter(task: "Operator") -> bool:
-            return task.task_id not in task_ids and (
-                self.is_backfill
-                or task.start_date <= self.execution_date
-                and (task.end_date is None or self.execution_date <= task.end_date)
-            )
+    def _get_task_creator(
+        self, created_counts: Dict[str, int], ti_mutation_hook: Callable, hook_is_noop: bool
+    ) -> Callable:
+        """
+        Get the task creator function.
 
-        created_counts: Dict[str, int] = defaultdict(int)
+        This function also updates the created_counts dictionary with the number of tasks created.
 
-        # Set for the empty default in airflow.settings -- if it's not set this means it has been changed
-        hook_is_noop = getattr(task_instance_mutation_hook, 'is_noop', False)
+        :param created_counts: Dictionary of task_type -> count of created TIs
+        :param ti_mutation_hook: task_instance_mutation_hook function
+        :param hook_is_noop: Whether the task_instance_mutation_hook is a noop
 
+        """
         if hook_is_noop:
 
             def create_ti_mapping(task: "Operator", indexes: Tuple[int, ...]) -> Generator:
@@ -912,13 +953,25 @@ class DagRun(Base, LoggingMixin):
             def create_ti(task: "Operator", indexes: Tuple[int, ...]) -> Generator:
                 for map_index in indexes:
                     ti = TI(task, run_id=self.run_id, map_index=map_index)
-                    task_instance_mutation_hook(ti)
+                    ti_mutation_hook(ti)
                     created_counts[ti.operator] += 1
                     yield ti
 
             creator = create_ti
+        return creator
+
+    def _create_missing_tasks(
+        self, dag: "DAG", task_creator: Callable, task_filter: Callable, *, session: Session
+    ) -> Iterable["Operator"]:
+        """
+        Create missing tasks -- and expand any MappedOperator that _only_ have literals as input
+
+        :param dag: DAG object corresponding to the dagrun
+        :param task_creator: a function that creates tasks
+        :param task_filter: a function that filters tasks to create
+        :param session: the session to use
+        """
 
-        # Create missing tasks -- and expand any MappedOperator that _only_ have literals as input
         def expand_mapped_literals(task: "Operator") -> Tuple["Operator", Sequence[int]]:
             if not task.is_mapped:
                 return (task, (-1,))
@@ -931,8 +984,29 @@ class DagRun(Base, LoggingMixin):
             return (task, range(count))
 
         tasks_and_map_idxs = map(expand_mapped_literals, filter(task_filter, dag.task_dict.values()))
-        tasks = itertools.chain.from_iterable(itertools.starmap(creator, tasks_and_map_idxs))
 
+        tasks = itertools.chain.from_iterable(itertools.starmap(task_creator, tasks_and_map_idxs))
+        return tasks
+
+    def _create_task_instances(
+        self,
+        dag_id: str,
+        tasks: Iterable["Operator"],
+        created_counts: Dict[str, int],
+        hook_is_noop: bool,
+        *,
+        session: Session,
+    ) -> None:
+        """
+        Create the necessary task instances from the given tasks.
+
+        :param dag_id: DAG ID associated with the dagrun
+        :param tasks: the tasks to create the task instances from
+        :param created_counts: a dictionary of task type -> number of TIs created by the task creator
+        :param hook_is_noop: whether the task_instance_mutation_hook is noop
+        :param session: the session to use
+
+        """
         try:
             if hook_is_noop:
                 session.bulk_insert_mappings(TI, tasks)
@@ -945,7 +1019,7 @@ class DagRun(Base, LoggingMixin):
         except IntegrityError:
             self.log.info(
                 'Hit IntegrityError while creating the TIs for %s- %s',
-                dag.dag_id,
+                dag_id,
                 self.run_id,
                 exc_info=True,
             )