You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ep...@apache.org on 2023/03/07 16:16:12 UTC

[airflow] 05/23: Fix Scheduler crash when clear a previous run of a normal task that is now a mapped task (#29645)

This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v2-5-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 3aa0c008fcbd445f9b3f277f3bf79f00da01425e
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Mon Feb 20 20:45:25 2023 +0100

    Fix Scheduler crash when clear a previous run of a normal task that is now a mapped task (#29645)
    
    The fix is to clear the DB references to the task instance in XCom,
    RenderedTaskInstanceFields, and TaskFail, so that the mapped task can then be expanded and run.
    
    (cherry picked from commit a770edfac493f3972c10a43e45bcd0e7cfaea65f)
---
 airflow/models/dagrun.py          |  8 +++++
 airflow/models/taskinstance.py    | 19 +++++++++++
 tests/models/test_dagrun.py       | 67 +++++++++++++++++++++++++++++----------
 tests/models/test_taskinstance.py | 16 ++++++++++
 4 files changed, 93 insertions(+), 17 deletions(-)

diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index e06d911c27..6f1505d489 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -778,6 +778,14 @@ class DagRun(Base, LoggingMixin):
             """
             if ti.map_index >= 0:  # Already expanded, we're good.
                 return None
+
+            from airflow.models.mappedoperator import MappedOperator
+
+            if isinstance(ti.task, MappedOperator):
+                # If we get here, either the task moved from non-mapped to mapped after
+                # the task instance was cleared, or this ti is not yet expanded. In both
+                # cases it is safe to clear the db references.
+                ti.clear_db_references(session=session)
             try:
                 expanded_tis, _ = ti.task.expand_mapped_task(self.run_id, session=session)
             except NotMapped:  # Not a mapped task, nothing needed.
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 615a98f6cb..1eb778df55 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -2658,6 +2658,25 @@ class TaskInstance(Base, LoggingMixin):
         map_index_start = ancestor_map_index * further_count
         return range(map_index_start, map_index_start + further_count)
 
+    def clear_db_references(self, session):
+        """
+        Clear DB references to XCom, TaskFail and RenderedTaskInstanceFields.
+
+        :param session: ORM Session
+
+        :meta private:
+        """
+        from airflow.models.renderedtifields import RenderedTaskInstanceFields
+
+        tables = [TaskFail, XCom, RenderedTaskInstanceFields]
+        for table in tables:
+            session.query(table).filter(
+                table.dag_id == self.dag_id,
+                table.task_id == self.task_id,
+                table.run_id == self.run_id,
+                table.map_index == self.map_index,
+            ).delete()
+
 
 def _find_common_ancestor_mapped_group(node1: Operator, node2: Operator) -> MappedTaskGroup | None:
     """Given two operators, find their innermost common mapped task group."""
diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py
index 3d368ae962..34295473c9 100644
--- a/tests/models/test_dagrun.py
+++ b/tests/models/test_dagrun.py
@@ -49,13 +49,7 @@ from airflow.utils.state import DagRunState, State, TaskInstanceState
 from airflow.utils.trigger_rule import TriggerRule
 from airflow.utils.types import DagRunType
 from tests.models import DEFAULT_DATE as _DEFAULT_DATE
-from tests.test_utils.db import (
-    clear_db_dags,
-    clear_db_datasets,
-    clear_db_pools,
-    clear_db_runs,
-    clear_db_variables,
-)
+from tests.test_utils import db
 from tests.test_utils.mock_operators import MockOperator
 
 DEFAULT_DATE = pendulum.instance(_DEFAULT_DATE)
@@ -64,19 +58,21 @@ DEFAULT_DATE = pendulum.instance(_DEFAULT_DATE)
 class TestDagRun:
     dagbag = DagBag(include_examples=True)
 
+    @staticmethod
+    def clean_db():
+        db.clear_db_runs()
+        db.clear_db_pools()
+        db.clear_db_dags()
+        db.clear_db_variables()
+        db.clear_db_datasets()
+        db.clear_db_xcom()
+        db.clear_db_task_fail()
+
     def setup_class(self) -> None:
-        clear_db_runs()
-        clear_db_pools()
-        clear_db_dags()
-        clear_db_variables()
-        clear_db_datasets()
+        self.clean_db()
 
     def teardown_method(self) -> None:
-        clear_db_runs()
-        clear_db_pools()
-        clear_db_dags()
-        clear_db_variables()
-        clear_db_datasets()
+        self.clean_db()
 
     def create_dag_run(
         self,
@@ -2222,3 +2218,40 @@ def test_mapped_task_depends_on_past(dag_maker, session):
     assert len(decision.unfinished_tis) == 0
     decision = dr2.task_instance_scheduling_decisions(session=session)
     assert len(decision.unfinished_tis) == 0
+
+
+def test_clearing_task_and_moving_from_non_mapped_to_mapped(dag_maker, session):
+    """
+    Test that clearing a task and moving from non-mapped to mapped clears existing
+    references in XCom, TaskFail, and RenderedTaskInstanceFields.
+    RenderedTaskInstanceFields is not exercised in this test, since creating one
+    would require that the task is expanded first.
+    """
+
+    from airflow.models.taskfail import TaskFail
+    from airflow.models.xcom import XCom
+
+    @task
+    def printx(x):
+        print(x)
+
+    with dag_maker() as dag:
+        printx.expand(x=[1])
+
+    dr1: DagRun = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
+    ti = dr1.get_task_instances()[0]
+    # Mimic the case where the task moved from non-mapped to mapped:
+    # the ti would then have a map_index of -1 even though its task is mapped.
+    ti.map_index = -1
+    session.merge(ti)
+    session.flush()
+    # RenderedTaskInstanceFields is purposely omitted because the ti needs
+    # to be expanded first, but here we mimic the unexpanded state with map_index -1.
+    session.add(TaskFail(ti))
+    XCom.set(key="test", value="value", task_id=ti.task_id, dag_id=dag.dag_id, run_id=ti.run_id)
+    session.commit()
+    for table in [TaskFail, XCom]:
+        assert session.query(table).count() == 1
+    dr1.task_instance_scheduling_decisions(session)
+    for table in [TaskFail, XCom]:
+        assert session.query(table).count() == 0
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index 849358cc69..bb2b676a59 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -129,6 +129,7 @@ class TestTaskInstance:
         db.clear_rendered_ti_fields()
         db.clear_db_task_reschedule()
         db.clear_db_datasets()
+        db.clear_db_xcom()
 
     def setup_method(self):
         self.clean_db()
@@ -2846,6 +2847,21 @@ class TestTaskInstance:
         assert ser_ti.operator == "EmptyOperator"
         assert ser_ti.task.operator_name == "EmptyOperator"
 
+    def test_clear_db_references(self, session, create_task_instance):
+        tables = [TaskFail, RenderedTaskInstanceFields, XCom]
+        ti = create_task_instance()
+        session.merge(ti)
+        session.commit()
+        for table in [TaskFail, RenderedTaskInstanceFields]:
+            session.add(table(ti))
+        XCom.set(key="key", value="value", task_id=ti.task_id, dag_id=ti.dag_id, run_id=ti.run_id)
+        session.commit()
+        for table in tables:
+            assert session.query(table).count() == 1
+        ti.clear_db_references(session)
+        for table in tables:
+            assert session.query(table).count() == 0
+
 
 @pytest.mark.parametrize("pool_override", [None, "test_pool2"])
 def test_refresh_from_task(pool_override):