You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ep...@apache.org on 2022/08/25 19:12:00 UTC

[airflow] branch main updated: Properly check the existence of missing mapped TIs (#25788)

This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new db818ae666 Properly check the existence of missing mapped TIs (#25788)
db818ae666 is described below

commit db818ae6665b37cd032aa6d2b0f97232462d41e1
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Thu Aug 25 20:11:48 2022 +0100

    Properly check the existence of missing mapped TIs (#25788)
    
    The previous handling of missing mapped-task indexes was incorrect: missing
    indexes were checked every time `task_instance_scheduling_decisions` was called.
    Missing tasks should only be revised after the last-resort expansion of mapped tasks has been done. If a task is in a schedulable state and has already been expanded, its indexes are revised to ensure they are complete: missing indexes get new task instances created for them, and indexes beyond the new mapped length are marked as removed.
    This implementation allows the revision to be done in one place.
    
    Co-authored-by: Tzu-ping Chung <ur...@gmail.com>
---
 airflow/models/dagrun.py    | 131 ++++++---------
 tests/models/test_dagrun.py | 389 +++++++++++++++++++++++++++++++++++++++++++-
 2 files changed, 440 insertions(+), 80 deletions(-)

diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index 4dd0493551..701618d5c4 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -656,9 +656,6 @@ class DagRun(Base, LoggingMixin):
                     yield ti
 
         tis = list(_filter_tis_and_exclude_removed(self.get_dag(), tis))
-        missing_indexes = self._revise_mapped_task_indexes(tis, session=session)
-        if missing_indexes:
-            self.verify_integrity(missing_indexes=missing_indexes, session=session)
 
         unfinished_tis = [t for t in tis if t.state in State.unfinished]
         finished_tis = [t for t in tis if t.state in State.finished]
@@ -730,6 +727,11 @@ class DagRun(Base, LoggingMixin):
                     additional_tis.extend(expanded_tis[1:])
                 expansion_happened = True
             if schedulable.state in SCHEDULEABLE_STATES:
+                task = schedulable.task
+                if isinstance(schedulable.task, MappedOperator):
+                    # Ensure the task indexes are complete
+                    created = self._revise_mapped_task_indexes(task, session=session)
+                    ready_tis.extend(created)
                 ready_tis.append(schedulable)
 
         # Check if any ti changed state
@@ -825,7 +827,6 @@ class DagRun(Base, LoggingMixin):
     def verify_integrity(
         self,
         *,
-        missing_indexes: Optional[Dict["MappedOperator", Sequence[int]]] = None,
         session: Session = NEW_SESSION,
     ):
         """
@@ -842,15 +843,10 @@ class DagRun(Base, LoggingMixin):
 
         dag = self.get_dag()
         task_ids: Set[str] = set()
-        if missing_indexes:
-            tis = self.get_task_instances(session=session)
-            for ti in tis:
-                task_instance_mutation_hook(ti)
-                task_ids.add(ti.task_id)
-        else:
-            task_ids, missing_indexes = self._check_for_removed_or_restored_tasks(
-                dag, task_instance_mutation_hook, session=session
-            )
+
+        task_ids = self._check_for_removed_or_restored_tasks(
+            dag, task_instance_mutation_hook, session=session
+        )
 
         def task_filter(task: "Operator") -> bool:
             return task.task_id not in task_ids and (
@@ -865,13 +861,13 @@ class DagRun(Base, LoggingMixin):
         task_creator = self._get_task_creator(created_counts, task_instance_mutation_hook, hook_is_noop)
 
         # Create the missing tasks, including mapped tasks
-        tasks = self._create_missing_tasks(dag, task_creator, task_filter, missing_indexes, session=session)
+        tasks = self._create_tasks(dag, task_creator, task_filter, session=session)
 
         self._create_task_instances(dag.dag_id, tasks, created_counts, hook_is_noop, session=session)
 
     def _check_for_removed_or_restored_tasks(
         self, dag: "DAG", ti_mutation_hook, *, session: Session
-    ) -> Tuple[Set[str], Dict["MappedOperator", Sequence[int]]]:
+    ) -> Set[str]:
         """
         Check for removed tasks/restored/missing tasks.
 
@@ -879,15 +875,13 @@ class DagRun(Base, LoggingMixin):
         :param ti_mutation_hook: task_instance_mutation_hook function
         :param session: Sqlalchemy ORM Session
 
-        :return: List of task_ids in the dagrun and missing task indexes
+        :return: Task IDs in the DAG run
 
         """
         tis = self.get_task_instances(session=session)
 
         # check for removed or restored tasks
         task_ids = set()
-        existing_indexes: Dict["MappedOperator", List[int]] = defaultdict(list)
-        expected_indexes: Dict["MappedOperator", Sequence[int]] = defaultdict(list)
         for ti in tis:
             ti_mutation_hook(ti)
             task_ids.add(ti.task_id)
@@ -925,13 +919,9 @@ class DagRun(Base, LoggingMixin):
                 elif ti.map_index < 0:
                     self.log.debug("Removing the unmapped TI '%s' as the mapping can now be performed", ti)
                     ti.state = State.REMOVED
-                else:
-                    self.log.info("Restoring mapped task '%s'", ti)
-                    Stats.incr(f"task_restored_to_dag.{dag.dag_id}", 1, 1)
-                    existing_indexes[task].append(ti.map_index)
-                    expected_indexes[task] = range(num_mapped_tis)
             else:
                 #  What if it is _now_ dynamically mapped, but wasn't before?
+                task.run_time_mapped_ti_count.cache_clear()  # type: ignore[attr-defined]
                 total_length = task.run_time_mapped_ti_count(self.run_id, session=session)
 
                 if total_length is None:
@@ -950,16 +940,8 @@ class DagRun(Base, LoggingMixin):
                         total_length,
                     )
                     ti.state = State.REMOVED
-                else:
-                    self.log.info("Restoring mapped task '%s'", ti)
-                    Stats.incr(f"task_restored_to_dag.{dag.dag_id}", 1, 1)
-                    existing_indexes[task].append(ti.map_index)
-                    expected_indexes[task] = range(total_length)
-        # Check if we have some missing indexes to create ti for
-        missing_indexes: Dict["MappedOperator", Sequence[int]] = defaultdict(list)
-        for k, v in existing_indexes.items():
-            missing_indexes.update({k: list(set(expected_indexes[k]).difference(v))})
-        return task_ids, missing_indexes
+
+        return task_ids
 
     def _get_task_creator(
         self, created_counts: Dict[str, int], ti_mutation_hook: Callable, hook_is_noop: bool
@@ -995,12 +977,11 @@ class DagRun(Base, LoggingMixin):
             creator = create_ti
         return creator
 
-    def _create_missing_tasks(
+    def _create_tasks(
         self,
         dag: "DAG",
         task_creator: Callable,
         task_filter: Callable,
-        missing_indexes: Optional[Dict["MappedOperator", Sequence[int]]],
         *,
         session: Session,
     ) -> Iterable["Operator"]:
@@ -1031,12 +1012,7 @@ class DagRun(Base, LoggingMixin):
         tasks_and_map_idxs = map(expand_mapped_literals, filter(task_filter, dag.task_dict.values()))
 
         tasks = itertools.chain.from_iterable(itertools.starmap(task_creator, tasks_and_map_idxs))
-        if missing_indexes:
-            # If there are missing indexes, override the tasks to create
-            new_tasks_and_map_idxs = itertools.starmap(
-                expand_mapped_literals, [(k, v) for k, v in missing_indexes.items() if len(v) > 0]
-            )
-            tasks = itertools.chain.from_iterable(itertools.starmap(task_creator, new_tasks_and_map_idxs))
+
         return tasks
 
     def _create_task_instances(
@@ -1082,44 +1058,45 @@ class DagRun(Base, LoggingMixin):
             # TODO[HA]: We probably need to savepoint this so we can keep the transaction alive.
             session.rollback()
 
-    def _revise_mapped_task_indexes(
-        self,
-        tis: Iterable[TI],
-        *,
-        session: Session,
-    ) -> Dict["MappedOperator", Sequence[int]]:
-        """Check if the length of the mapped task instances changed at runtime and find the missing indexes.
+    def _revise_mapped_task_indexes(self, task, session: Session):
+        """Check if task increased or reduced in length and handle appropriately"""
+        from airflow.models.taskinstance import TaskInstance
+        from airflow.settings import task_instance_mutation_hook
 
-        :param tis: Task instances to check
-        :param session: The session to use
-        """
-        from airflow.models.mappedoperator import MappedOperator
+        task.run_time_mapped_ti_count.cache_clear()
+        total_length = (
+            task.parse_time_mapped_ti_count
+            or task.run_time_mapped_ti_count(self.run_id, session=session)
+            or 0
+        )
+        query = session.query(TaskInstance.map_index).filter(
+            TaskInstance.dag_id == self.dag_id,
+            TaskInstance.task_id == task.task_id,
+            TaskInstance.run_id == self.run_id,
+        )
+        existing_indexes = {i for (i,) in query}
+        missing_indexes = set(range(total_length)).difference(existing_indexes)
+        removed_indexes = existing_indexes.difference(range(total_length))
+        created_tis = []
 
-        existing_indexes: Dict[MappedOperator, List[int]] = defaultdict(list)
-        new_indexes: Dict[MappedOperator, Sequence[int]] = defaultdict(list)
-        for ti in tis:
-            task = ti.task
-            if not isinstance(task, MappedOperator):
-                continue
-            # skip unexpanded tasks and also tasks that expands with literal arguments
-            if ti.map_index < 0 or task.parse_time_mapped_ti_count:
-                continue
-            existing_indexes[task].append(ti.map_index)
-            task.run_time_mapped_ti_count.cache_clear()  # type: ignore[attr-defined]
-            new_length = task.run_time_mapped_ti_count(self.run_id, session=session) or 0
-
-            if ti.map_index >= new_length:
-                self.log.debug(
-                    "Removing task '%s' as the map_index is longer than the resolved mapping list (%d)",
-                    ti,
-                    new_length,
-                )
-                ti.state = State.REMOVED
-            new_indexes[task] = range(new_length)
-        missing_indexes: Dict[MappedOperator, Sequence[int]] = defaultdict(list)
-        for k, v in existing_indexes.items():
-            missing_indexes.update({k: list(set(new_indexes[k]).difference(v))})
-        return missing_indexes
+        if missing_indexes:
+            for index in missing_indexes:
+                ti = TaskInstance(task, run_id=self.run_id, map_index=index, state=None)
+                self.log.debug("Expanding TIs upserted %s", ti)
+                task_instance_mutation_hook(ti)
+                ti = session.merge(ti)
+                ti.refresh_from_task(task)
+                session.flush()
+                created_tis.append(ti)
+        elif removed_indexes:
+            session.query(TaskInstance).filter(
+                TaskInstance.dag_id == self.dag_id,
+                TaskInstance.task_id == task.task_id,
+                TaskInstance.run_id == self.run_id,
+                TaskInstance.map_index.in_(removed_indexes),
+            ).update({TaskInstance.state: TaskInstanceState.REMOVED})
+            session.flush()
+        return created_tis
 
     @staticmethod
     def get_run(session: Session, dag_id: str, execution_date: datetime) -> Optional['DagRun']:
diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py
index 6f18567680..13fb47fdd9 100644
--- a/tests/models/test_dagrun.py
+++ b/tests/models/test_dagrun.py
@@ -1092,9 +1092,8 @@ def test_mapped_literal_length_increase_adds_additional_ti(dag_maker, session):
     serialized_dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
 
     dr.dag = serialized_dag
-    # Since we change the literal on the dag file itself, the dag_hash will
-    # change which will have the scheduler verify the dr integrity
-    dr.verify_integrity()
+    # Every mapped task is revised at task_instance_scheduling_decision
+    dr.task_instance_scheduling_decisions()
 
     tis = dr.get_task_instances()
     indices = [(ti.map_index, ti.state) for ti in tis]
@@ -1279,6 +1278,390 @@ def test_mapped_literal_length_reduction_at_runtime_adds_removed_state(dag_maker
     ]
 
 
+def test_mapped_literal_length_with_no_change_at_runtime_doesnt_call_verify_integrity(dag_maker, session):
+    """
+    Test that when there's no change to mapped task indexes at runtime, the dagrun.verify_integrity
+    is not called
+    """
+    from airflow.models import Variable
+
+    Variable.set(key='arg1', value=[1, 2, 3])
+
+    @task
+    def task_1():
+        return Variable.get('arg1', deserialize_json=True)
+
+    with dag_maker(session=session) as dag:
+
+        @task
+        def task_2(arg2):
+            ...
+
+        task_2.expand(arg2=task_1())
+
+    dr = dag_maker.create_dagrun()
+    ti = dr.get_task_instance(task_id='task_1')
+    ti.run()
+    dr.task_instance_scheduling_decisions()
+    tis = dr.get_task_instances()
+    indices = [(ti.map_index, ti.state) for ti in tis if ti.map_index >= 0]
+    assert sorted(indices) == [
+        (0, State.NONE),
+        (1, State.NONE),
+        (2, State.NONE),
+    ]
+
+    # Now "clear" and no change to length
+    dag.clear()
+    Variable.set(key='arg1', value=[1, 2, 3])
+
+    with dag:
+        task_2.expand(arg2=task_1()).operator
+
+    # At this point, we need to test that the change works on the serialized
+    # DAG (which is what the scheduler operates on)
+    serialized_dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
+
+    dr.dag = serialized_dag
+
+    # Run the first task again to get the new lengths
+    ti = dr.get_task_instance(task_id='task_1')
+    task1 = dag.get_task('task_1')
+    ti.refresh_from_task(task1)
+    ti.run()
+
+    # this would be called by the localtask job
+    # Verify that DagRun.verify_integrity is not called
+    with mock.patch('airflow.models.dagrun.DagRun.verify_integrity') as mock_verify_integrity:
+        dr.task_instance_scheduling_decisions()
+        mock_verify_integrity.assert_not_called()
+
+
+def test_calls_to_verify_integrity_with_mapped_task_increase_at_runtime(dag_maker, session):
+    """
+    Test increase in mapped task at runtime with calls to dagrun.verify_integrity
+    """
+    from airflow.models import Variable
+
+    Variable.set(key='arg1', value=[1, 2, 3])
+
+    @task
+    def task_1():
+        return Variable.get('arg1', deserialize_json=True)
+
+    with dag_maker(session=session) as dag:
+
+        @task
+        def task_2(arg2):
+            ...
+
+        task_2.expand(arg2=task_1())
+
+    dr = dag_maker.create_dagrun()
+    ti = dr.get_task_instance(task_id='task_1')
+    ti.run()
+    dr.task_instance_scheduling_decisions()
+    tis = dr.get_task_instances()
+    indices = [(ti.map_index, ti.state) for ti in tis if ti.map_index >= 0]
+    assert sorted(indices) == [
+        (0, State.NONE),
+        (1, State.NONE),
+        (2, State.NONE),
+    ]
+    # Now "clear" and "increase" the length of literal
+    dag.clear()
+    Variable.set(key='arg1', value=[1, 2, 3, 4, 5])
+
+    with dag:
+        task_2.expand(arg2=task_1()).operator
+
+    # At this point, we need to test that the change works on the serialized
+    # DAG (which is what the scheduler operates on)
+    serialized_dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
+
+    dr.dag = serialized_dag
+
+    # Run the first task again to get the new lengths
+    ti = dr.get_task_instance(task_id='task_1')
+    task1 = dag.get_task('task_1')
+    ti.refresh_from_task(task1)
+    ti.run()
+    task2 = dag.get_task('task_2')
+    for ti in dr.get_task_instances():
+        if ti.map_index < 0:
+            ti.task = task1
+        else:
+            ti.task = task2
+        session.merge(ti)
+    session.flush()
+    # create the additional task
+    dr.task_instance_scheduling_decisions()
+    # Run verify_integrity as a whole and assert new tasks were added
+    dr.verify_integrity()
+    tis = dr.get_task_instances()
+    indices = [(ti.map_index, ti.state) for ti in tis if ti.map_index >= 0]
+    assert sorted(indices) == [
+        (0, State.NONE),
+        (1, State.NONE),
+        (2, State.NONE),
+        (3, State.NONE),
+        (4, State.NONE),
+    ]
+    ti3 = dr.get_task_instance(task_id='task_2', map_index=3)
+    ti3.task = task2
+    ti3.state = TaskInstanceState.FAILED
+    session.merge(ti3)
+    session.flush()
+    # assert repeated calls did not change the instances
+    dr.verify_integrity()
+    tis = dr.get_task_instances()
+    indices = [(ti.map_index, ti.state) for ti in tis if ti.map_index >= 0]
+    assert sorted(indices) == [
+        (0, State.NONE),
+        (1, State.NONE),
+        (2, State.NONE),
+        (3, TaskInstanceState.FAILED),
+        (4, State.NONE),
+    ]
+
+
+def test_calls_to_verify_integrity_with_mapped_task_reduction_at_runtime(dag_maker, session):
+    """
+    Test reduction in mapped task at runtime with calls to dagrun.verify_integrity
+    """
+    from airflow.models import Variable
+
+    Variable.set(key='arg1', value=[1, 2, 3])
+
+    @task
+    def task_1():
+        return Variable.get('arg1', deserialize_json=True)
+
+    with dag_maker(session=session) as dag:
+
+        @task
+        def task_2(arg2):
+            ...
+
+        task_2.expand(arg2=task_1())
+
+    dr = dag_maker.create_dagrun()
+    ti = dr.get_task_instance(task_id='task_1')
+    ti.run()
+    dr.task_instance_scheduling_decisions()
+    tis = dr.get_task_instances()
+    indices = [(ti.map_index, ti.state) for ti in tis if ti.map_index >= 0]
+    assert sorted(indices) == [
+        (0, State.NONE),
+        (1, State.NONE),
+        (2, State.NONE),
+    ]
+    # Now "clear" and "reduce" the length of literal
+    dag.clear()
+    Variable.set(key='arg1', value=[1])
+
+    with dag:
+        task_2.expand(arg2=task_1()).operator
+
+    # At this point, we need to test that the change works on the serialized
+    # DAG (which is what the scheduler operates on)
+    serialized_dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
+
+    dr.dag = serialized_dag
+
+    # Run the first task again to get the new lengths
+    ti = dr.get_task_instance(task_id='task_1')
+    task1 = dag.get_task('task_1')
+    ti.refresh_from_task(task1)
+    ti.run()
+    task2 = dag.get_task('task_2')
+    for ti in dr.get_task_instances():
+        if ti.map_index < 0:
+            ti.task = task1
+        else:
+            ti.task = task2
+            ti.state = TaskInstanceState.SUCCESS
+        session.merge(ti)
+    session.flush()
+
+    # Run verify_integrity as a whole and assert some tasks were removed
+    dr.verify_integrity()
+    tis = dr.get_task_instances()
+    indices = [(ti.map_index, ti.state) for ti in tis if ti.map_index >= 0]
+    assert sorted(indices) == [
+        (0, TaskInstanceState.SUCCESS),
+        (1, TaskInstanceState.REMOVED),
+        (2, TaskInstanceState.REMOVED),
+    ]
+
+    # assert repeated calls did not change the instances
+    dr.verify_integrity()
+    tis = dr.get_task_instances()
+    indices = [(ti.map_index, ti.state) for ti in tis if ti.map_index >= 0]
+    assert sorted(indices) == [
+        (0, TaskInstanceState.SUCCESS),
+        (1, TaskInstanceState.REMOVED),
+        (2, TaskInstanceState.REMOVED),
+    ]
+
+
+def test_calls_to_verify_integrity_with_mapped_task_with_no_changes_at_runtime(dag_maker, session):
+    """
+    Test no change in mapped task at runtime with calls to dagrun.verify_integrity
+    """
+    from airflow.models import Variable
+
+    Variable.set(key='arg1', value=[1, 2, 3])
+
+    @task
+    def task_1():
+        return Variable.get('arg1', deserialize_json=True)
+
+    with dag_maker(session=session) as dag:
+
+        @task
+        def task_2(arg2):
+            ...
+
+        task_2.expand(arg2=task_1())
+
+    dr = dag_maker.create_dagrun()
+    ti = dr.get_task_instance(task_id='task_1')
+    ti.run()
+    dr.task_instance_scheduling_decisions()
+    tis = dr.get_task_instances()
+    indices = [(ti.map_index, ti.state) for ti in tis if ti.map_index >= 0]
+    assert sorted(indices) == [
+        (0, State.NONE),
+        (1, State.NONE),
+        (2, State.NONE),
+    ]
+    # Now "clear" and return the same length
+    dag.clear()
+    Variable.set(key='arg1', value=[1, 2, 3])
+
+    with dag:
+        task_2.expand(arg2=task_1()).operator
+
+    # At this point, we need to test that the change works on the serialized
+    # DAG (which is what the scheduler operates on)
+    serialized_dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
+
+    dr.dag = serialized_dag
+
+    # Run the first task again to get the new lengths
+    ti = dr.get_task_instance(task_id='task_1')
+    task1 = dag.get_task('task_1')
+    ti.refresh_from_task(task1)
+    ti.run()
+    task2 = dag.get_task('task_2')
+    for ti in dr.get_task_instances():
+        if ti.map_index < 0:
+            ti.task = task1
+        else:
+            ti.task = task2
+            ti.state = TaskInstanceState.SUCCESS
+        session.merge(ti)
+    session.flush()
+
+    # Run verify_integrity as a whole and assert no changes
+    dr.verify_integrity()
+    tis = dr.get_task_instances()
+    indices = [(ti.map_index, ti.state) for ti in tis if ti.map_index >= 0]
+    assert sorted(indices) == [
+        (0, TaskInstanceState.SUCCESS),
+        (1, TaskInstanceState.SUCCESS),
+        (2, TaskInstanceState.SUCCESS),
+    ]
+
+    # assert repeated calls did not change the instances
+    dr.verify_integrity()
+    tis = dr.get_task_instances()
+    indices = [(ti.map_index, ti.state) for ti in tis if ti.map_index >= 0]
+    assert sorted(indices) == [
+        (0, TaskInstanceState.SUCCESS),
+        (1, TaskInstanceState.SUCCESS),
+        (2, TaskInstanceState.SUCCESS),
+    ]
+
+
+def test_calls_to_verify_integrity_with_mapped_task_zero_length_at_runtime(dag_maker, session, caplog):
+    """
+    Test zero length reduction in mapped task at runtime with calls to dagrun.verify_integrity
+    """
+    import logging
+
+    from airflow.models import Variable
+
+    Variable.set(key='arg1', value=[1, 2, 3])
+
+    @task
+    def task_1():
+        return Variable.get('arg1', deserialize_json=True)
+
+    with dag_maker(session=session) as dag:
+
+        @task
+        def task_2(arg2):
+            ...
+
+        task_2.expand(arg2=task_1())
+
+    dr = dag_maker.create_dagrun()
+    ti = dr.get_task_instance(task_id='task_1')
+    ti.run()
+    dr.task_instance_scheduling_decisions()
+    tis = dr.get_task_instances()
+    indices = [(ti.map_index, ti.state) for ti in tis if ti.map_index >= 0]
+    assert sorted(indices) == [
+        (0, State.NONE),
+        (1, State.NONE),
+        (2, State.NONE),
+    ]
+    ti1 = [i for i in tis if i.map_index == 0][0]
+    # Now "clear" and "reduce" the length to empty list
+    dag.clear()
+    Variable.set(key='arg1', value=[])
+
+    with dag:
+        task_2.expand(arg2=task_1()).operator
+
+    # At this point, we need to test that the change works on the serialized
+    # DAG (which is what the scheduler operates on)
+    serialized_dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
+
+    dr.dag = serialized_dag
+
+    # Run the first task again to get the new lengths
+    ti = dr.get_task_instance(task_id='task_1')
+    task1 = dag.get_task('task_1')
+    ti.refresh_from_task(task1)
+    ti.run()
+    task2 = dag.get_task('task_2')
+    for ti in dr.get_task_instances():
+        if ti.map_index < 0:
+            ti.task = task1
+        else:
+            ti.task = task2
+        session.merge(ti)
+    session.flush()
+    with caplog.at_level(logging.DEBUG):
+
+        # Run verify_integrity as a whole and assert the tasks were removed
+        dr.verify_integrity()
+        tis = dr.get_task_instances()
+        indices = [(ti.map_index, ti.state) for ti in tis if ti.map_index >= 0]
+        assert sorted(indices) == [
+            (0, TaskInstanceState.REMOVED),
+            (1, TaskInstanceState.REMOVED),
+            (2, TaskInstanceState.REMOVED),
+        ]
+        assert (
+            f"Removing task '{ti1}' as the map_index is longer than the resolved mapping list (0)"
+            in caplog.text
+        )
+
+
 @pytest.mark.need_serialized_dag
 def test_mapped_mixed__literal_not_expanded_at_create(dag_maker, session):
     literal = [1, 2, 3, 4]