You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ep...@apache.org on 2022/08/15 18:45:18 UTC

[airflow] 42/45: Fix reducing mapped length of a mapped task at runtime after a clear (#25531)

This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v2-3-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 31a4b8cf1f803343fdc9681bfa2784132c05411b
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Fri Aug 5 10:30:51 2022 +0100

    Fix reducing mapped length of a mapped task at runtime after a clear (#25531)
    
    The previous fix on task immutability after a run did not fix a case where the task was removed at runtime when the literal is dynamic. This PR addresses it.
    
    (cherry picked from commit d3028ada36a43a0d549d22c280fb16d868b90b6d)
---
 airflow/models/dagrun.py    | 12 +++++++--
 tests/models/test_dagrun.py | 64 +++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 74 insertions(+), 2 deletions(-)

diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index 1b746e8a06..8b7f3a1c39 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -655,7 +655,7 @@ class DagRun(Base, LoggingMixin):
                     yield ti
 
         tis = list(_filter_tis_and_exclude_removed(self.get_dag(), tis))
-        missing_indexes = self._find_missing_task_indexes(tis, session=session)
+        missing_indexes = self._revise_mapped_task_indexes(tis, session=session)
         if missing_indexes:
             self.verify_integrity(missing_indexes=missing_indexes, session=session)
 
@@ -1081,7 +1081,7 @@ class DagRun(Base, LoggingMixin):
             # TODO[HA]: We probably need to savepoint this so we can keep the transaction alive.
             session.rollback()
 
-    def _find_missing_task_indexes(
+    def _revise_mapped_task_indexes(
         self,
         tis: Iterable[TI],
         *,
@@ -1106,6 +1106,14 @@ class DagRun(Base, LoggingMixin):
             existing_indexes[task].append(ti.map_index)
             task.run_time_mapped_ti_count.cache_clear()  # type: ignore[attr-defined]
             new_length = task.run_time_mapped_ti_count(self.run_id, session=session) or 0
+
+            if ti.map_index >= new_length:
+                self.log.debug(
+                    "Removing task '%s' as the map_index is longer than the resolved mapping list (%d)",
+                    ti,
+                    new_length,
+                )
+                ti.state = State.REMOVED
             new_indexes[task] = range(new_length)
         missing_indexes: Dict[MappedOperator, Sequence[int]] = defaultdict(list)
         for k, v in existing_indexes.items():
diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py
index 6c3cc1c91c..e28b203640 100644
--- a/tests/models/test_dagrun.py
+++ b/tests/models/test_dagrun.py
@@ -1207,6 +1207,70 @@ def test_mapped_literal_length_increase_at_runtime_adds_additional_tis(dag_maker
     ]
 
 
+def test_mapped_literal_length_reduction_at_runtime_adds_removed_state(dag_maker, session):
+    """
+    Test that when the length of a mapped literal is reduced at runtime, the task instances
+    whose map_index falls outside the new length are marked as removed.
+    """
+    from airflow.models import Variable
+
+    Variable.set(key='arg1', value=[1, 2, 3])
+
+    @task
+    def task_1():
+        return Variable.get('arg1', deserialize_json=True)
+
+    with dag_maker(session=session) as dag:
+
+        @task
+        def task_2(arg2):
+            ...
+
+        task_2.expand(arg2=task_1())
+
+    dr = dag_maker.create_dagrun()
+    ti = dr.get_task_instance(task_id='task_1')
+    ti.run()
+    dr.task_instance_scheduling_decisions()
+    tis = dr.get_task_instances()
+    indices = [(ti.map_index, ti.state) for ti in tis if ti.map_index >= 0]
+    assert sorted(indices) == [
+        (0, State.NONE),
+        (1, State.NONE),
+        (2, State.NONE),
+    ]
+
+    # Now clear the DAG and reduce the length of the literal from three items to two
+    dag.clear()
+    Variable.set(key='arg1', value=[1, 2])
+
+    with dag:
+        task_2.expand(arg2=task_1()).operator
+
+    # At this point, we need to test that the change works on the serialized
+    # DAG (which is what the scheduler operates on)
+    serialized_dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
+
+    dr.dag = serialized_dag
+
+    # Run the first task again to get the new lengths
+    ti = dr.get_task_instance(task_id='task_1')
+    task1 = dag.get_task('task_1')
+    ti.refresh_from_task(task1)
+    ti.run()
+
+    # this would be called by the localtask job
+    dr.task_instance_scheduling_decisions()
+    tis = dr.get_task_instances()
+
+    indices = [(ti.map_index, ti.state) for ti in tis if ti.map_index >= 0]
+    assert sorted(indices) == [
+        (0, State.NONE),
+        (1, State.NONE),
+        (2, TaskInstanceState.REMOVED),
+    ]
+
+
 @pytest.mark.need_serialized_dag
 def test_mapped_mixed__literal_not_expanded_at_create(dag_maker, session):
     literal = [1, 2, 3, 4]