Posted to commits@airflow.apache.org by ep...@apache.org on 2022/06/29 15:19:52 UTC

[airflow] 07/45: Fix mapped task immutability after clear (#23667)

This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v2-3-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 8892587cce270aa504fc1a9e25d8d2279f0c71b8
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Sat Jun 18 08:32:38 2022 +0100

    Fix mapped task immutability after clear (#23667)
    
    We should be able to detect if the structure of a mapped task has
    changed and verify the integrity of the dag run.
    
    This PR ensures this.
    Co-authored-by: Tzu-ping Chung <ur...@gmail.com>
    
    (cherry picked from commit b692517ce3aafb276e9d23570e9734c30a5f3d1f)
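
In essence, the fix computes a set difference between the map indexes that
already have task instances and the full index range the task should now
expand to, and creates task instances only for the difference. A minimal
sketch of that idea in plain Python (task_id strings stand in for the real
MappedOperator keys; this is not the actual Airflow code):

    from collections import defaultdict
    from typing import Dict, List, Sequence

    def find_missing_indexes(
        existing: Dict[str, List[int]],      # task_id -> map indexes that already have TIs
        expected: Dict[str, Sequence[int]],  # task_id -> expected index range after re-expansion
    ) -> Dict[str, List[int]]:
        """Return, per mapped task, the map indexes that still need task instances."""
        missing: Dict[str, List[int]] = defaultdict(list)
        for task_id, have in existing.items():
            missing[task_id] = sorted(set(expected[task_id]).difference(have))
        return missing

    # A mapped task that grew from 3 to 5 expansions after a clear:
    assert find_missing_indexes({'task_2': [0, 1, 2]}, {'task_2': range(5)}) == {'task_2': [3, 4]}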
---
 airflow/models/dagrun.py    | 114 +++++++++++++++++++++++++------
 tests/models/test_dagrun.py | 161 +++++++++++++++++++++++++++++++++++++++++++-
 2 files changed, 251 insertions(+), 24 deletions(-)
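
One API note before the diffs: the parameters of verify_integrity become
keyword-only (note the bare * in the new signature), which is why the test
file below changes dag_run.verify_integrity(session) to
dag_run.verify_integrity(session=session). A tiny illustration of the
keyword-only semantics (toy function, not the Airflow one):

    # A bare `*` makes every following parameter keyword-only.
    def verify_integrity(*, missing_indexes=None, session=None):
        return missing_indexes, session

    verify_integrity(session='sa_session')    # OK
    try:
        verify_integrity('sa_session')        # positional call now fails
    except TypeError as err:
        print(err)  # verify_integrity() takes 0 positional arguments but 1 was given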

diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index b71cd03eec..3be82b9b6d 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -642,15 +642,9 @@ class DagRun(Base, LoggingMixin):
         tis = list(self.get_task_instances(session=session, state=State.task_states))
         self.log.debug("number of tis tasks for %s: %s task(s)", self, len(tis))
         dag = self.get_dag()
-        for ti in tis:
-            try:
-                ti.task = dag.get_task(ti.task_id)
-            except TaskNotFound:
-                self.log.warning(
-                    "Failed to get task '%s' for dag '%s'. Marking it as removed.", ti, ti.dag_id
-                )
-                ti.state = State.REMOVED
-                session.flush()
+        missing_indexes = self._find_missing_task_indexes(dag, tis, session=session)
+        if missing_indexes:
+            self.verify_integrity(missing_indexes=missing_indexes, session=session)
 
         unfinished_tis = [t for t in tis if t.state in State.unfinished]
         finished_tis = [t for t in tis if t.state in State.finished]
@@ -811,11 +805,17 @@ class DagRun(Base, LoggingMixin):
             Stats.timing(f'dagrun.duration.failed.{self.dag_id}', duration)
 
     @provide_session
-    def verify_integrity(self, session: Session = NEW_SESSION):
+    def verify_integrity(
+        self,
+        *,
+        missing_indexes: Optional[Dict["MappedOperator", Sequence[int]]] = None,
+        session: Session = NEW_SESSION,
+    ):
         """
         Verifies the DagRun by checking for removed tasks or tasks that are not in the
         database yet. It will set state to removed or add the task if required.
 
+        :param missing_indexes: A dictionary mapping each mapped task to its missing map indexes.
         :param session: Sqlalchemy ORM Session
         """
         from airflow.settings import task_instance_mutation_hook
@@ -824,9 +824,16 @@ class DagRun(Base, LoggingMixin):
         hook_is_noop = getattr(task_instance_mutation_hook, 'is_noop', False)
 
         dag = self.get_dag()
-        task_ids = self._check_for_removed_or_restored_tasks(
-            dag, task_instance_mutation_hook, session=session
-        )
+        task_ids: Set[str] = set()
+        if missing_indexes:
+            tis = self.get_task_instances(session=session)
+            for ti in tis:
+                task_instance_mutation_hook(ti)
+                task_ids.add(ti.task_id)
+        else:
+            task_ids, missing_indexes = self._check_for_removed_or_restored_tasks(
+                dag, task_instance_mutation_hook, session=session
+            )
 
         def task_filter(task: "Operator") -> bool:
             return task.task_id not in task_ids and (
@@ -841,27 +848,29 @@ class DagRun(Base, LoggingMixin):
         task_creator = self._get_task_creator(created_counts, task_instance_mutation_hook, hook_is_noop)
 
         # Create the missing tasks, including mapped tasks
-        tasks = self._create_missing_tasks(dag, task_creator, task_filter, session=session)
+        tasks = self._create_missing_tasks(dag, task_creator, task_filter, missing_indexes, session=session)
 
         self._create_task_instances(dag.dag_id, tasks, created_counts, hook_is_noop, session=session)
 
     def _check_for_removed_or_restored_tasks(
         self, dag: "DAG", ti_mutation_hook, *, session: Session
-    ) -> Set[str]:
+    ) -> Tuple[Set[str], Dict["MappedOperator", Sequence[int]]]:
         """
-        Check for removed tasks/restored tasks.
+        Check for removed, restored, or missing tasks.
 
         :param dag: DAG object corresponding to the dagrun
         :param ti_mutation_hook: task_instance_mutation_hook function
         :param session: Sqlalchemy ORM Session
 
-        :return: List of task_ids in the dagrun
+        :return: Set of task_ids in the dagrun and a dict of mapped tasks to their missing indexes
 
         """
         tis = self.get_task_instances(session=session)
 
         # check for removed or restored tasks
         task_ids = set()
+        existing_indexes: Dict["MappedOperator", List[int]] = defaultdict(list)
+        expected_indexes: Dict["MappedOperator", Sequence[int]] = defaultdict(list)
         for ti in tis:
             ti_mutation_hook(ti)
             task_ids.add(ti.task_id)
@@ -902,7 +911,8 @@ class DagRun(Base, LoggingMixin):
                 else:
                     self.log.info("Restoring mapped task '%s'", ti)
                     Stats.incr(f"task_restored_to_dag.{dag.dag_id}", 1, 1)
-                    ti.state = State.NONE
+                    existing_indexes[task].append(ti.map_index)
+                    expected_indexes[task] = range(num_mapped_tis)
             else:
                 #  What if it is _now_ dynamically mapped, but wasn't before?
                 total_length = task.run_time_mapped_ti_count(self.run_id, session=session)
@@ -923,8 +933,16 @@ class DagRun(Base, LoggingMixin):
                         total_length,
                     )
                     ti.state = State.REMOVED
-                    ...
-        return task_ids
+                else:
+                    self.log.info("Restoring mapped task '%s'", ti)
+                    Stats.incr(f"task_restored_to_dag.{dag.dag_id}", 1, 1)
+                    existing_indexes[task].append(ti.map_index)
+                    expected_indexes[task] = range(total_length)
+        # Check if we have some missing indexes to create ti for
+        missing_indexes: Dict["MappedOperator", Sequence[int]] = defaultdict(list)
+        for k, v in existing_indexes.items():
+            missing_indexes.update({k: list(set(expected_indexes[k]).difference(v))})
+        return task_ids, missing_indexes
 
     def _get_task_creator(
         self, created_counts: Dict[str, int], ti_mutation_hook: Callable, hook_is_noop: bool
@@ -961,7 +979,13 @@ class DagRun(Base, LoggingMixin):
         return creator
 
     def _create_missing_tasks(
-        self, dag: "DAG", task_creator: Callable, task_filter: Callable, *, session: Session
+        self,
+        dag: "DAG",
+        task_creator: Callable,
+        task_filter: Callable,
+        missing_indexes: Optional[Dict["MappedOperator", Sequence[int]]],
+        *,
+        session: Session,
     ) -> Iterable["Operator"]:
         """
         Create missing tasks -- and expand any MappedOperator that _only_ have literals as input
@@ -972,7 +996,9 @@ class DagRun(Base, LoggingMixin):
         :param session: the session to use
         """
 
-        def expand_mapped_literals(task: "Operator") -> Tuple["Operator", Sequence[int]]:
+        def expand_mapped_literals(
+            task: "Operator", sequence: Union[Sequence[int], None] = None
+        ) -> Tuple["Operator", Sequence[int]]:
             if not task.is_mapped:
                 return (task, (-1,))
             task = cast("MappedOperator", task)
@@ -981,11 +1007,19 @@ class DagRun(Base, LoggingMixin):
             )
             if not count:
                 return (task, (-1,))
+            if sequence:
+                return (task, sequence)
             return (task, range(count))
 
         tasks_and_map_idxs = map(expand_mapped_literals, filter(task_filter, dag.task_dict.values()))
 
         tasks = itertools.chain.from_iterable(itertools.starmap(task_creator, tasks_and_map_idxs))
+        if missing_indexes:
+            # If there are missing indexes, override the tasks to create
+            new_tasks_and_map_idxs = itertools.starmap(
+                expand_mapped_literals, [(k, v) for k, v in missing_indexes.items() if len(v) > 0]
+            )
+            tasks = itertools.chain.from_iterable(itertools.starmap(task_creator, new_tasks_and_map_idxs))
         return tasks
 
     def _create_task_instances(
@@ -1027,6 +1061,42 @@ class DagRun(Base, LoggingMixin):
             # TODO[HA]: We probably need to savepoint this so we can keep the transaction alive.
             session.rollback()
 
+    def _find_missing_task_indexes(self, dag, tis, *, session) -> Dict["MappedOperator", Sequence[int]]:
+        """
+        Check whether the number of mapped task instances for a task changed
+        at runtime and, if so, find the missing map indexes.
+
+        This function also marks task instances whose task no longer exists in the dag as REMOVED.
+
+        :param dag: DAG object corresponding to the dagrun
+        :param tis: task instances to check
+        :param session: the session to use
+        """
+        existing_indexes: Dict["MappedOperator", list] = defaultdict(list)
+        new_indexes: Dict["MappedOperator", Sequence[int]] = defaultdict(list)
+        for ti in tis:
+            try:
+                task = ti.task = dag.get_task(ti.task_id)
+            except TaskNotFound:
+                self.log.error("Failed to get task '%s' for dag '%s'. Marking it as removed.", ti, ti.dag_id)
+
+                ti.state = State.REMOVED
+                session.flush()
+                continue
+            if not task.is_mapped:
+                continue
+            # skip unexpanded tasks and also tasks that expands with literal arguments
+            if ti.map_index < 0 or task.parse_time_mapped_ti_count:
+                continue
+            existing_indexes[task].append(ti.map_index)
+            task.run_time_mapped_ti_count.cache_clear()
+            new_length = task.run_time_mapped_ti_count(self.run_id, session=session) or 0
+            new_indexes[task] = range(new_length)
+        missing_indexes: Dict["MappedOperator", Sequence[int]] = defaultdict(list)
+        for k, v in existing_indexes.items():
+            missing_indexes.update({k: list(set(new_indexes[k]).difference(v))})
+        return missing_indexes
+
     @staticmethod
     def get_run(session: Session, dag_id: str, execution_date: datetime) -> Optional['DagRun']:
         """
diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py
index f73f5d1c45..d45fd41370 100644
--- a/tests/models/test_dagrun.py
+++ b/tests/models/test_dagrun.py
@@ -41,7 +41,7 @@ from airflow.utils.state import DagRunState, State, TaskInstanceState
 from airflow.utils.trigger_rule import TriggerRule
 from airflow.utils.types import DagRunType
 from tests.models import DEFAULT_DATE as _DEFAULT_DATE
-from tests.test_utils.db import clear_db_dags, clear_db_pools, clear_db_runs
+from tests.test_utils.db import clear_db_dags, clear_db_pools, clear_db_runs, clear_db_variables
 from tests.test_utils.mock_operators import MockOperator
 
 DEFAULT_DATE = pendulum.instance(_DEFAULT_DATE)
@@ -54,11 +54,13 @@ class TestDagRun:
         clear_db_runs()
         clear_db_pools()
         clear_db_dags()
+        clear_db_variables()
 
     def teardown_method(self) -> None:
         clear_db_runs()
         clear_db_pools()
         clear_db_dags()
+        clear_db_variables()
 
     def create_dag_run(
         self,
@@ -899,7 +901,7 @@ def test_verify_integrity_task_start_and_end_date(Stats_incr, session, run_type,
 
     session.add(dag_run)
     session.flush()
-    dag_run.verify_integrity(session)
+    dag_run.verify_integrity(session=session)
 
     tis = dag_run.task_instances
     assert len(tis) == expected_tis
@@ -1027,6 +1029,161 @@ def test_mapped_literal_to_xcom_arg_verify_integrity(dag_maker, session):
     ]
 
 
+def test_mapped_literal_length_increase_adds_additional_ti(dag_maker, session):
+    """Test that when the length of mapped literal increases, additional ti is added"""
+
+    with dag_maker(session=session) as dag:
+
+        @task
+        def task_2(arg2):
+            ...
+
+        task_2.expand(arg2=[1, 2, 3, 4])
+
+    dr = dag_maker.create_dagrun()
+    tis = dr.get_task_instances()
+    indices = [(ti.map_index, ti.state) for ti in tis]
+    assert sorted(indices) == [
+        (0, State.NONE),
+        (1, State.NONE),
+        (2, State.NONE),
+        (3, State.NONE),
+    ]
+
+    # Now "increase" the length of literal
+    dag._remove_task('task_2')
+
+    with dag:
+        task_2.expand(arg2=[1, 2, 3, 4, 5]).operator
+
+    # At this point, we need to test that the change works on the serialized
+    # DAG (which is what the scheduler operates on)
+    serialized_dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
+
+    dr.dag = serialized_dag
+    # Since we change the literal on the dag file itself, the dag_hash will
+    # change which will have the scheduler verify the dr integrity
+    dr.verify_integrity()
+
+    tis = dr.get_task_instances()
+    indices = [(ti.map_index, ti.state) for ti in tis]
+    assert sorted(indices) == [
+        (0, State.NONE),
+        (1, State.NONE),
+        (2, State.NONE),
+        (3, State.NONE),
+        (4, State.NONE),
+    ]
+
+
+def test_mapped_literal_length_reduction_adds_removed_state(dag_maker, session):
+    """Test that when the length of mapped literal reduces, removed state is added"""
+
+    with dag_maker(session=session) as dag:
+
+        @task
+        def task_2(arg2):
+            ...
+
+        task_2.expand(arg2=[1, 2, 3, 4])
+
+    dr = dag_maker.create_dagrun()
+    tis = dr.get_task_instances()
+    indices = [(ti.map_index, ti.state) for ti in tis]
+    assert sorted(indices) == [
+        (0, State.NONE),
+        (1, State.NONE),
+        (2, State.NONE),
+        (3, State.NONE),
+    ]
+
+    # Now "reduce" the length of literal
+    dag._remove_task('task_2')
+
+    with dag:
+        task_2.expand(arg2=[1, 2]).operator
+
+    # At this point, we need to test that the change works on the serialized
+    # DAG (which is what the scheduler operates on)
+    serialized_dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
+
+    dr.dag = serialized_dag
+    # Since we change the literal on the dag file itself, the dag_hash will
+    # change which will have the scheduler verify the dr integrity
+    dr.verify_integrity()
+
+    tis = dr.get_task_instances()
+    indices = [(ti.map_index, ti.state) for ti in tis]
+    assert sorted(indices) == [
+        (0, State.NONE),
+        (1, State.NONE),
+        (2, State.REMOVED),
+        (3, State.REMOVED),
+    ]
+
+
+def test_mapped_literal_length_increase_at_runtime_adds_additional_tis(dag_maker, session):
+    """Test that when the length of mapped literal increases at runtime, additional ti is added"""
+    from airflow.models import Variable
+
+    Variable.set(key='arg1', value=[1, 2, 3])
+
+    @task
+    def task_1():
+        return Variable.get('arg1', deserialize_json=True)
+
+    with dag_maker(session=session) as dag:
+
+        @task
+        def task_2(arg2):
+            ...
+
+        task_2.expand(arg2=task_1())
+
+    dr = dag_maker.create_dagrun()
+    ti = dr.get_task_instance(task_id='task_1')
+    ti.run()
+    dr.task_instance_scheduling_decisions()
+    tis = dr.get_task_instances()
+    indices = [(ti.map_index, ti.state) for ti in tis if ti.map_index >= 0]
+    assert sorted(indices) == [
+        (0, State.NONE),
+        (1, State.NONE),
+        (2, State.NONE),
+    ]
+
+    # Now "clear" and "increase" the length of literal
+    dag.clear()
+    Variable.set(key='arg1', value=[1, 2, 3, 4])
+
+    with dag:
+        task_2.expand(arg2=task_1()).operator
+
+    # At this point, we need to test that the change works on the serialized
+    # DAG (which is what the scheduler operates on)
+    serialized_dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
+
+    dr.dag = serialized_dag
+
+    # Run the first task again to get the new lengths
+    ti = dr.get_task_instance(task_id='task_1')
+    task1 = dag.get_task('task_1')
+    ti.refresh_from_task(task1)
+    ti.run()
+
+    # this would be called by the localtask job
+    dr.task_instance_scheduling_decisions()
+    tis = dr.get_task_instances()
+
+    indices = [(ti.map_index, ti.state) for ti in tis if ti.map_index >= 0]
+    assert sorted(indices) == [
+        (0, State.NONE),
+        (1, State.NONE),
+        (2, State.NONE),
+        (3, State.NONE),
+    ]
+
+
 @pytest.mark.need_serialized_dag
 def test_mapped_mixed__literal_not_expanded_at_create(dag_maker, session):
     literal = [1, 2, 3, 4]