Posted to commits@airflow.apache.org by je...@apache.org on 2022/08/16 15:35:16 UTC

[airflow] 09/11: Ensure that zombie tasks for dags with errors get cleaned up (#25550)

This is an automated email from the ASF dual-hosted git repository.

jedcunningham pushed a commit to tag v2.3.3+astro.2
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit dfb4e81e33e77efd92727b0ef7db592845cc6354
Author: Ash Berlin-Taylor <as...@apache.org>
AuthorDate: Fri Aug 5 18:36:24 2022 +0100

    Ensure that zombie tasks for dags with errors get cleaned up (#25550)
    
    If there is a parse error in a DAG, the zombie cleanup request never
    runs, which results in the TI never leaving the running state and
    just continually being re-detected as a zombie.
    
    (Prior to AIP-45 landing, this bug/behaviour resulted in a DAG with a
    parse error never actually leaving the queued state.)
    
    The fix here is to _always_ make sure we run `ti.handle_failure` when
    we are given a request, even if we can't load the DAG. To work as
    well as we can, we try to load the serialized DAG, but in cases where
    we can't (for whatever reason) we also make sure
    TaskInstance.handle_failure is able to operate even when `self.task`
    is None.
    
    (cherry picked from commit 1d8507af07353e5cf29a860314b5ba5caad5cdf3)
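
In condensed form, the new callback handling in airflow/dag_processing/processor.py (see the diff below) behaves roughly like the sketch that follows. The helper name is illustrative only, logging and exception handling are trimmed, and it assumes an Airflow 2.3-era environment with an open SQLAlchemy session:

    from typing import Optional

    from airflow import models
    from airflow.models.serialized_dag import SerializedDagModel

    def _handle_task_callback_sketch(request, dagbag, session) -> None:
        # Look the TI up directly in the DB instead of constructing it from the task.
        simple_ti = request.simple_task_instance
        ti: Optional[models.TaskInstance] = (
            session.query(models.TaskInstance)
            .filter_by(
                dag_id=simple_ti.dag_id,
                run_id=simple_ti.run_id,
                task_id=simple_ti.task_id,
                map_index=simple_ti.map_index,
            )
            .one_or_none()
        )
        if not ti:
            return

        task = None
        if dagbag and simple_ti.dag_id in dagbag.dags:
            # Normal case: the file parsed and the real task is available.
            task = dagbag.dags[simple_ti.dag_id].get_task(simple_ti.task_id)
        else:
            # Parse error: best effort, fall back to the serialized DAG in the metadata DB.
            model = session.query(SerializedDagModel).get(simple_ti.dag_id)
            if model:
                task = model.dag.get_task(simple_ti.task_id)

        if task:
            ti.refresh_from_task(task)
        # Runs even if no task could be resolved; handle_failure now copes with task=None.
        ti.handle_failure_with_callback(error=request.msg, session=session)
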
---
 airflow/dag_processing/processor.py    | 93 ++++++++++++++++++++++++++++------
 airflow/models/log.py                  |  3 +-
 airflow/models/taskinstance.py         |  8 ++-
 tests/conftest.py                      |  2 +-
 tests/dag_processing/test_processor.py | 41 +++++++++++++--
 tests/models/test_taskinstance.py      | 33 +++++++++++-
 6 files changed, 157 insertions(+), 23 deletions(-)

diff --git a/airflow/dag_processing/processor.py b/airflow/dag_processing/processor.py
index 469b55cfeb..df22371e6f 100644
--- a/airflow/dag_processing/processor.py
+++ b/airflow/dag_processing/processor.py
@@ -25,10 +25,10 @@ import time
 from contextlib import redirect_stderr, redirect_stdout, suppress
 from datetime import timedelta
 from multiprocessing.connection import Connection as MultiprocessingConnection
-from typing import Iterator, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Iterator, List, Optional, Set, Tuple
 
 from setproctitle import setproctitle
-from sqlalchemy import func, or_
+from sqlalchemy import exc, func, or_
 from sqlalchemy.orm.session import Session
 
 from airflow import models, settings
@@ -51,6 +51,9 @@ from airflow.utils.mixins import MultiprocessingStartMethodMixin
 from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.state import State
 
+if TYPE_CHECKING:
+    from airflow.models.operator import Operator
+
 DR = models.DagRun
 TI = models.TaskInstance
 
@@ -575,7 +578,7 @@ class DagFileProcessor(LoggingMixin):
             self.log.debug("Processing Callback Request: %s", request)
             try:
                 if isinstance(request, TaskCallbackRequest):
-                    self._execute_task_callbacks(dagbag, request)
+                    self._execute_task_callbacks(dagbag, request, session=session)
                 elif isinstance(request, SlaCallbackRequest):
                     self.manage_slas(dagbag.get_dag(request.dag_id), session=session)
                 elif isinstance(request, DagCallbackRequest):
@@ -587,7 +590,27 @@ class DagFileProcessor(LoggingMixin):
                     request.full_filepath,
                 )
 
-        session.commit()
+        session.flush()
+
+    def execute_callbacks_without_dag(
+        self, callback_requests: List[CallbackRequest], session: Session
+    ) -> None:
+        """
+        Execute what callbacks we can as "best effort" when the dag cannot be found/had parse errors.
+
+        This is important so that tasks that failed when there is a parse
+        error don't get stuck in the queued state.
+        """
+        for request in callback_requests:
+            self.log.debug("Processing Callback Request: %s", request)
+            if isinstance(request, TaskCallbackRequest):
+                self._execute_task_callbacks(None, request, session)
+            else:
+                self.log.info(
+                    "Not executing %s callback for file %s as there was a dag parse error",
+                    request.__class__.__name__,
+                    request.full_filepath,
+                )
 
     @provide_session
     def _execute_dag_callbacks(self, dagbag: DagBag, request: DagCallbackRequest, session: Session):
@@ -597,18 +620,51 @@ class DagFileProcessor(LoggingMixin):
             dagrun=dag_run, success=not request.is_failure_callback, reason=request.msg, session=session
         )
 
-    def _execute_task_callbacks(self, dagbag: DagBag, request: TaskCallbackRequest):
+    def _execute_task_callbacks(
+        self, dagbag: Optional[DagBag], request: TaskCallbackRequest, session: Session
+    ):
+        if not request.is_failure_callback:
+            return
+
         simple_ti = request.simple_task_instance
-        if simple_ti.dag_id in dagbag.dags:
+        ti: Optional[TI] = (
+            session.query(TI)
+            .filter_by(
+                dag_id=simple_ti.dag_id,
+                run_id=simple_ti.run_id,
+                task_id=simple_ti.task_id,
+                map_index=simple_ti.map_index,
+            )
+            .one_or_none()
+        )
+        if not ti:
+            return
+
+        task: Optional["Operator"] = None
+
+        if dagbag and simple_ti.dag_id in dagbag.dags:
             dag = dagbag.dags[simple_ti.dag_id]
             if simple_ti.task_id in dag.task_ids:
                 task = dag.get_task(simple_ti.task_id)
-                if request.is_failure_callback:
-                    ti = TI(task, run_id=simple_ti.run_id, map_index=simple_ti.map_index)
-                    # TODO: Use simple_ti to improve performance here in the future
-                    ti.refresh_from_db()
-                    ti.handle_failure_with_callback(error=request.msg, test_mode=self.UNIT_TEST_MODE)
-                    self.log.info('Executed failure callback for %s in state %s', ti, ti.state)
+        else:
+            # We don't have the _real_ dag here (perhaps it had a parse error?) but we still want to run
+            # `handle_failure` so that the state of the TI gets progressed.
+            #
+            # Since handle_failure _really_ wants a task, we do our best effort to give it one
+            from airflow.models.serialized_dag import SerializedDagModel
+
+            try:
+                model = session.query(SerializedDagModel).get(simple_ti.dag_id)
+                if model:
+                    task = model.dag.get_task(simple_ti.task_id)
+            except (exc.NoResultFound, TaskNotFound):
+                pass
+        if task:
+            ti.refresh_from_task(task)
+
+        ti.handle_failure_with_callback(error=request.msg, test_mode=self.UNIT_TEST_MODE, session=session)
+        self.log.info('Executed failure callback for %s in state %s', ti, ti.state)
+        session.flush()
 
     @provide_session
     def process_file(
@@ -616,7 +672,7 @@ class DagFileProcessor(LoggingMixin):
         file_path: str,
         callback_requests: List[CallbackRequest],
         pickle_dags: bool = False,
-        session: Session = None,
+        session: Session = NEW_SESSION,
     ) -> Tuple[int, int]:
         """
         Process a Python file containing Airflow DAGs.
@@ -652,12 +708,19 @@ class DagFileProcessor(LoggingMixin):
         else:
             self.log.warning("No viable dags retrieved from %s", file_path)
             self.update_import_errors(session, dagbag)
+            if callback_requests:
+                # If there were callback requests for this file but there was a
+                # parse error we still need to progress the state of TIs,
+                # otherwise they might be stuck in queued/running for ever!
+                self.execute_callbacks_without_dag(callback_requests, session)
             return 0, len(dagbag.import_errors)
 
-        self.execute_callbacks(dagbag, callback_requests)
+        self.execute_callbacks(dagbag, callback_requests, session)
+        session.commit()
 
         # Save individual DAGs in the ORM
-        dagbag.sync_to_db()
+        dagbag.sync_to_db(session)
+        session.commit()
 
         if pickle_dags:
             paused_dag_ids = DagModel.get_paused_dag_ids(dag_ids=dagbag.dag_ids)
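
A note on the session handling in this file: execute_callbacks now only flushes, and process_file (the caller) decides when to commit. The standalone snippet below is generic SQLAlchemy, not part of the commit, and only illustrates the distinction this relies on: flush() sends pending changes inside the still-open transaction, while commit() ends it:

    from sqlalchemy import Column, Integer, String, create_engine
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()

    class Note(Base):
        __tablename__ = "note"
        id = Column(Integer, primary_key=True)
        text = Column(String)

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add(Note(text="callback handled"))
        session.flush()     # INSERT is sent; visible to this session, transaction still open
        assert session.query(Note).count() == 1
        session.rollback()  # the caller can still abandon everything flushed so far
        assert session.query(Note).count() == 0

    with Session(engine) as session:
        session.add(Note(text="callback handled"))
        session.commit()    # only the caller's commit makes the work permanent
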
diff --git a/airflow/models/log.py b/airflow/models/log.py
index b2a5639dcd..4633dd3785 100644
--- a/airflow/models/log.py
+++ b/airflow/models/log.py
@@ -55,7 +55,8 @@ class Log(Base):
             self.task_id = task_instance.task_id
             self.execution_date = task_instance.execution_date
             self.map_index = task_instance.map_index
-            task_owner = task_instance.task.owner
+            if task_instance.task:
+                task_owner = task_instance.task.owner
 
         if 'task_id' in kwargs:
             self.task_id = kwargs['task_id']
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index debd0aa6b0..72efd4c8db 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -1919,7 +1919,7 @@ class TaskInstance(Base, LoggingMixin):
 
         self.end_date = timezone.utcnow()
         self.set_duration()
-        Stats.incr(f'operator_failures_{self.task.task_type}')
+        Stats.incr(f'operator_failures_{self.operator}')
         Stats.incr('ti_failures')
         if not test_mode:
             session.add(Log(State.FAILED, self))
@@ -1943,7 +1943,8 @@ class TaskInstance(Base, LoggingMixin):
 
         task = None
         try:
-            task = self.task.unmap()
+            if self.task:
+                task = self.task.unmap()
         except Exception:
             self.log.error("Unable to unmap task, can't determine if we need to send an alert email or not")
 
@@ -1985,6 +1986,9 @@ class TaskInstance(Base, LoggingMixin):
             # If a task is cleared when running, it goes into RESTARTING state and is always
             # eligible for retry
             return True
+        if not self.task:
+            # Couldn't load the task, don't know number of retries, guess:
+            return self.try_number <= self.max_tries
 
         return self.task.retries and self.try_number <= self.max_tries
 
diff --git a/tests/conftest.py b/tests/conftest.py
index b153c213d5..8447ba240f 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -611,7 +611,7 @@ def dag_maker(request):
                     if not dag_ids:
                         return
                     # To isolate problems here with problems from elsewhere on the session object
-                    self.session.flush()
+                    self.session.rollback()
 
                     self.session.query(SerializedDagModel).filter(
                         SerializedDagModel.dag_id.in_(dag_ids)
diff --git a/tests/dag_processing/test_processor.py b/tests/dag_processing/test_processor.py
index 1c0b164be8..2bad13e5f6 100644
--- a/tests/dag_processing/test_processor.py
+++ b/tests/dag_processing/test_processor.py
@@ -30,6 +30,7 @@ from airflow.configuration import TEST_DAGS_FOLDER, conf
 from airflow.dag_processing.manager import DagFileProcessorAgent
 from airflow.dag_processing.processor import DagFileProcessor
 from airflow.models import DagBag, DagModel, SlaMiss, TaskInstance, errors
+from airflow.models.serialized_dag import SerializedDagModel
 from airflow.models.taskinstance import SimpleTaskInstance
 from airflow.operators.empty import EmptyOperator
 from airflow.utils import timezone
@@ -386,10 +387,44 @@ class TestDagFileProcessor:
                 full_filepath="A", simple_task_instance=SimpleTaskInstance.from_ti(ti), msg="Message"
             )
         ]
-        dag_file_processor.execute_callbacks(dagbag, requests)
+        dag_file_processor.execute_callbacks(dagbag, requests, session)
+        mock_ti_handle_failure.assert_called_once_with(
+            error="Message", test_mode=conf.getboolean('core', 'unit_test_mode'), session=session
+        )
+
+    @pytest.mark.parametrize(
+        ["has_serialized_dag"],
+        [pytest.param(True, id="dag_in_db"), pytest.param(False, id="no_dag_found")],
+    )
+    @patch.object(TaskInstance, 'handle_failure')
+    def test_execute_on_failure_callbacks_without_dag(self, mock_ti_handle_failure, has_serialized_dag):
+        dagbag = DagBag(dag_folder="/dev/null", include_examples=True, read_dags_from_db=False)
+        dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())
+        with create_session() as session:
+            session.query(TaskInstance).delete()
+            dag = dagbag.get_dag('example_branch_operator')
+            dagrun = dag.create_dagrun(
+                state=State.RUNNING,
+                execution_date=DEFAULT_DATE,
+                run_type=DagRunType.SCHEDULED,
+                session=session,
+            )
+            task = dag.get_task(task_id='run_this_first')
+            ti = TaskInstance(task, run_id=dagrun.run_id, state=State.QUEUED)
+            session.add(ti)
+
+            if has_serialized_dag:
+                assert SerializedDagModel.write_dag(dag, session=session) is True
+                session.flush()
+
+        requests = [
+            TaskCallbackRequest(
+                full_filepath="A", simple_task_instance=SimpleTaskInstance.from_ti(ti), msg="Message"
+            )
+        ]
+        dag_file_processor.execute_callbacks_without_dag(requests, session)
         mock_ti_handle_failure.assert_called_once_with(
-            error="Message",
-            test_mode=conf.getboolean('core', 'unit_test_mode'),
+            error="Message", test_mode=conf.getboolean('core', 'unit_test_mode'), session=session
         )
 
     def test_failure_callbacks_should_not_drop_hostname(self):
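
The tests above patch TaskInstance.handle_failure at the class level and then assert on the exact keyword arguments that were forwarded. A generic, self-contained illustration of that pattern (plain unittest.mock, not Airflow code):

    from unittest import mock

    class Worker:
        def handle_failure(self, error, session=None):
            raise RuntimeError("real implementation should be patched out in the test")

    def process(worker, session):
        # Code under test: must forward the session it was given.
        worker.handle_failure(error="Message", session=session)

    @mock.patch.object(Worker, "handle_failure")
    def test_process(mock_handle_failure):
        session = object()
        process(Worker(), session)
        mock_handle_failure.assert_called_once_with(error="Message", session=session)

    test_process()
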
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index 3990c3cbf5..b1385c1179 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -1915,7 +1915,7 @@ class TestTaskInstance:
         ti = TI(task=task, run_id=dr.run_id)
         ti.state = State.QUEUED
         session.merge(ti)
-        session.commit()
+        session.flush()
         assert ti.state == State.QUEUED
         assert ti.try_number == 1
         ti.handle_failure("test queued ti", test_mode=True)
@@ -1925,6 +1925,37 @@ class TestTaskInstance:
         # Check 'ti.try_number' is bumped to 2. This is try_number for next run
         assert ti.try_number == 2
 
+    @patch.object(Stats, 'incr')
+    def test_handle_failure_no_task(self, Stats_incr, dag_maker):
+        """
+        When a zombie is detected for a DAG with a parse error, we need to be able to run handle_failure
+        _without_ ti.task being set
+        """
+        session = settings.Session()
+        with dag_maker():
+            task = EmptyOperator(task_id="mytask", retries=1)
+        dr = dag_maker.create_dagrun()
+        ti = TI(task=task, run_id=dr.run_id)
+        ti = session.merge(ti)
+        ti.task = None
+        ti.state = State.QUEUED
+        session.flush()
+
+        assert ti.task is None, "Check critical pre-condition"
+
+        assert ti.state == State.QUEUED
+        assert ti.try_number == 1
+
+        ti.handle_failure("test queued ti", test_mode=False)
+        assert ti.state == State.UP_FOR_RETRY
+        # Assert that 'ti._try_number' is bumped from 0 to 1. This is the last/current try
+        assert ti._try_number == 1
+        # Check 'ti.try_number' is bumped to 2. This is try_number for next run
+        assert ti.try_number == 2
+
+        Stats_incr.assert_any_call('ti_failures')
+        Stats_incr.assert_any_call('operator_failures_EmptyOperator')
+
     def test_does_not_retry_on_airflow_fail_exception(self, dag_maker):
         def fail():
             raise AirflowFailException("hopeless")