Posted to commits@airflow.apache.org by je...@apache.org on 2022/01/28 21:25:30 UTC

[airflow] 17/17: Handle stuck queued tasks in Celery for db backend (#19769)

This is an automated email from the ASF dual-hosted git repository.

jedcunningham pushed a commit to branch v2-2-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit ec60dd799d65bdb80b83893db7df215d98342dde
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Fri Jan 14 09:55:15 2022 +0100

    Handle stuck queued tasks in Celery for db backend (#19769)
    
    Move the state of stuck queued tasks in Celery to Scheduled so that
    the Scheduler can queue them again. This only applies to the DatabaseBackend.
    
    (cherry picked from commit 14ee831c7ad767e31a3aeccf3edbc519b3b8c923)
---
 airflow/config_templates/config.yml          |   7 +
 airflow/config_templates/default_airflow.cfg |   3 +
 airflow/executors/celery_executor.py         |  52 +++++++
 tests/executors/test_celery_executor.py      | 197 ++++++++++++++++++++++++---
 4 files changed, 242 insertions(+), 17 deletions(-)
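
For readers skimming the diff, the heart of the change can be condensed as below. This is a
simplified sketch only: the helper name reschedule_stuck_queued and its arguments are
illustrative, while the real logic lives in CeleryExecutor._clear_stuck_queued_tasks further
down and additionally skips tasks tracked in executor.queued_tasks and executor.running.

    # Hypothetical, condensed version of the recovery path added by this commit.
    from airflow.models.taskinstance import TaskInstance
    from airflow.utils.state import State
    from airflow.utils.timezone import utcnow

    def reschedule_stuck_queued(session, celery_task_ids, task_adoption_timeout):
        # Only task instances queued longer than task_adoption_timeout are candidates.
        cutoff = utcnow() - task_adoption_timeout
        stuck = session.query(TaskInstance).filter(
            TaskInstance.state == State.QUEUED,
            TaskInstance.queued_dttm < cutoff,
        )
        for ti in stuck:
            if ti.external_executor_id in celery_task_ids:
                # The Celery result backend still knows about this task, so it is not stuck.
                continue
            # Hand it back to the scheduler so it can be queued again.
            ti.state = State.SCHEDULED
            session.merge(ti)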

diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml
index 6941f03..e061568 100644
--- a/airflow/config_templates/config.yml
+++ b/airflow/config_templates/config.yml
@@ -1663,6 +1663,13 @@
       type: string
       example: ~
       default: "False"
+    - name: stuck_queued_task_check_interval
+      description: |
+        How often to check for stuck queued tasks (in seconds)
+      version_added: 2.3.0
+      type: integer
+      example: ~
+      default: "300"
 - name: celery_broker_transport_options
   description: |
     This section is for specifying options which can be passed to the
diff --git a/airflow/config_templates/default_airflow.cfg b/airflow/config_templates/default_airflow.cfg
index 6a5449b..4024922 100644
--- a/airflow/config_templates/default_airflow.cfg
+++ b/airflow/config_templates/default_airflow.cfg
@@ -830,6 +830,9 @@ task_publish_max_retries = 3
 # Worker initialisation check to validate Metadata Database connection
 worker_precheck = False
 
+# How often to check for stuck queued tasks (in seconds)
+stuck_queued_task_check_interval = 300
+
 [celery_broker_transport_options]
 
 # This section is for specifying options which can be passed to the
diff --git a/airflow/executors/celery_executor.py b/airflow/executors/celery_executor.py
index f257b0c..8daced6 100644
--- a/airflow/executors/celery_executor.py
+++ b/airflow/executors/celery_executor.py
@@ -40,6 +40,7 @@ from celery.backends.database import DatabaseBackend, Task as TaskDb, session_cl
 from celery.result import AsyncResult
 from celery.signals import import_modules as celery_import_modules
 from setproctitle import setproctitle
+from sqlalchemy.orm.session import Session
 
 import airflow.settings as settings
 from airflow.config_templates.default_celery import DEFAULT_CELERY_CONFIG
@@ -50,6 +51,7 @@ from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
 from airflow.stats import Stats
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.net import get_hostname
+from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.state import State
 from airflow.utils.timeout import timeout
 from airflow.utils.timezone import utcnow
@@ -231,6 +233,10 @@ class CeleryExecutor(BaseExecutor):
         self.task_adoption_timeout = datetime.timedelta(
             seconds=conf.getint('celery', 'task_adoption_timeout', fallback=600)
         )
+        self.stuck_tasks_last_check_time: float = time.time()
+        self.stuck_queued_task_check_interval = conf.getint(
+            'celery', 'stuck_queued_task_check_interval', fallback=300
+        )
         self.task_publish_retries: Dict[TaskInstanceKey, int] = OrderedDict()
         self.task_publish_max_retries = conf.getint('celery', 'task_publish_max_retries', fallback=3)
 
@@ -335,6 +341,8 @@ class CeleryExecutor(BaseExecutor):
 
         if self.adopted_task_timeouts:
             self._check_for_stalled_adopted_tasks()
+        if time.time() - self.stuck_tasks_last_check_time > self.stuck_queued_task_check_interval:
+            self._clear_stuck_queued_tasks()
 
     def _check_for_stalled_adopted_tasks(self):
         """
@@ -375,6 +383,50 @@ class CeleryExecutor(BaseExecutor):
             for key in timedout_keys:
                 self.change_state(key, State.FAILED)
 
+    @provide_session
+    def _clear_stuck_queued_tasks(self, session: Session = NEW_SESSION) -> None:
+        """
+        Tasks can get stuck in the queued state in the DB without ever reaching a
+        worker. This happens when workers are autoscaled down and a task has been
+        queued but not picked up by any worker prior to the scaling.
+
+        In such a situation, we update the task instance state to scheduled so that
+        it can be queued again. We use task_adoption_timeout to decide when a queued
+        task is considered stuck and should be rescheduled.
+        """
+        if not isinstance(app.backend, DatabaseBackend):
+            # We only want to do this for database backends where
+            # this case has been spotted
+            return
+        # We query the Celery result backend directly instead of using
+        # bulk_state_fetcher because the stuck task may not be in self.tasks,
+        # and we don't want to clear tasks that are in self.tasks either.
+        session_ = app.backend.ResultSession()
+        task_cls = getattr(app.backend, "task_cls", TaskDb)
+        with session_cleanup(session_):
+            celery_task_ids = [t.task_id for t in session_.query(task_cls.task_id).all()]
+        self.log.debug("Checking for stuck queued tasks")
+
+        max_allowed_time = utcnow() - self.task_adoption_timeout
+
+        for task in session.query(TaskInstance).filter(
+            TaskInstance.state == State.QUEUED, TaskInstance.queued_dttm < max_allowed_time
+        ):
+            if task.key in self.queued_tasks or task.key in self.running:
+                continue
+
+            if task.external_executor_id in celery_task_ids:
+                # The task is still running in the worker
+                continue
+
+            self.log.info(
+                'TaskInstance: %s found in queued state for more than %s seconds, rescheduling',
+                task,
+                self.task_adoption_timeout.total_seconds(),
+            )
+            task.state = State.SCHEDULED
+            session.merge(task)
+
     def debug_dump(self) -> None:
         """Called in response to SIGUSR2 by the scheduler"""
         super().debug_dump()
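
With the defaults carried in this diff, the two timers interact as follows: sync() looks for
stuck tasks at most once per stuck_queued_task_check_interval (300 seconds), and a task
instance only qualifies once it has sat in QUEUED for longer than the existing
task_adoption_timeout (600 seconds) with no matching row in the Celery result backend. A
small, self-contained illustration of the cutoff arithmetic (the timestamps are made up):

    from datetime import datetime, timedelta

    task_adoption_timeout = timedelta(seconds=600)   # [celery] task_adoption_timeout default
    check_interval = timedelta(seconds=300)          # [celery] stuck_queued_task_check_interval default

    queued_dttm = datetime(2022, 1, 14, 12, 0, 0)    # when the TI entered QUEUED
    now = datetime(2022, 1, 14, 12, 12, 0)           # 12 minutes later
    cutoff = now - task_adoption_timeout             # 12:02:00
    print(queued_dttm < cutoff)                      # True -> eligible for rescheduling

    last_check = datetime(2022, 1, 14, 12, 5, 0)     # when sync() last ran the check
    print(now - last_check > check_interval)         # True -> sync() will check again
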
diff --git a/tests/executors/test_celery_executor.py b/tests/executors/test_celery_executor.py
index db63b18..5632f7d 100644
--- a/tests/executors/test_celery_executor.py
+++ b/tests/executors/test_celery_executor.py
@@ -17,10 +17,13 @@
 # under the License.
 import contextlib
 import json
+import logging
 import os
 import signal
 import sys
+import time
 import unittest
+from collections import namedtuple
 from datetime import datetime, timedelta
 from unittest import mock
 
@@ -32,6 +35,7 @@ from celery.backends.base import BaseBackend, BaseKeyValueStoreBackend
 from celery.backends.database import DatabaseBackend
 from celery.contrib.testing.worker import start_worker
 from celery.result import AsyncResult
+from freezegun import freeze_time
 from kombu.asynchronous import set_event_loop
 from parameterized import parameterized
 
@@ -94,12 +98,12 @@ def _prepare_app(broker_url=None, execute=None):
             set_event_loop(None)
 
 
-class TestCeleryExecutor(unittest.TestCase):
-    def setUp(self) -> None:
+class TestCeleryExecutor:
+    def setup_method(self) -> None:
         db.clear_db_runs()
         db.clear_db_jobs()
 
-    def tearDown(self) -> None:
+    def teardown_method(self) -> None:
         db.clear_db_runs()
         db.clear_db_jobs()
 
@@ -196,10 +200,10 @@ class TestCeleryExecutor(unittest.TestCase):
     @pytest.mark.integration("redis")
     @pytest.mark.integration("rabbitmq")
     @pytest.mark.backend("mysql", "postgres")
-    def test_retry_on_error_sending_task(self):
+    def test_retry_on_error_sending_task(self, caplog):
         """Test that Airflow retries publishing tasks to Celery Broker at least 3 times"""
 
-        with _prepare_app(), self.assertLogs(celery_executor.log) as cm, mock.patch.object(
+        with _prepare_app(), caplog.at_level(logging.INFO), mock.patch.object(
             # Mock `with timeout()` to _instantly_ fail.
             celery_executor.timeout,
             "__enter__",
@@ -227,28 +231,19 @@ class TestCeleryExecutor(unittest.TestCase):
             assert dict(executor.task_publish_retries) == {key: 2}
             assert 1 == len(executor.queued_tasks), "Task should remain in queue"
             assert executor.event_buffer == {}
-            assert (
-                "INFO:airflow.executors.celery_executor.CeleryExecutor:"
-                f"[Try 1 of 3] Task Timeout Error for Task: ({key})." in cm.output
-            )
+            assert f"[Try 1 of 3] Task Timeout Error for Task: ({key})." in caplog.text
 
             executor.heartbeat()
             assert dict(executor.task_publish_retries) == {key: 3}
             assert 1 == len(executor.queued_tasks), "Task should remain in queue"
             assert executor.event_buffer == {}
-            assert (
-                "INFO:airflow.executors.celery_executor.CeleryExecutor:"
-                f"[Try 2 of 3] Task Timeout Error for Task: ({key})." in cm.output
-            )
+            assert f"[Try 2 of 3] Task Timeout Error for Task: ({key})." in caplog.text
 
             executor.heartbeat()
             assert dict(executor.task_publish_retries) == {key: 4}
             assert 1 == len(executor.queued_tasks), "Task should remain in queue"
             assert executor.event_buffer == {}
-            assert (
-                "INFO:airflow.executors.celery_executor.CeleryExecutor:"
-                f"[Try 3 of 3] Task Timeout Error for Task: ({key})." in cm.output
-            )
+            assert f"[Try 3 of 3] Task Timeout Error for Task: ({key})." in caplog.text
 
             executor.heartbeat()
             assert dict(executor.task_publish_retries) == {}
@@ -411,6 +406,174 @@ class TestCeleryExecutor(unittest.TestCase):
         assert executor.running == {key_2}
         assert executor.adopted_task_timeouts == {key_2: queued_dttm_2 + executor.task_adoption_timeout}
 
+    @pytest.mark.backend("mysql", "postgres")
+    @pytest.mark.parametrize(
+        "state, queued_dttm, executor_id",
+        [
+            (State.SCHEDULED, timezone.utcnow() - timedelta(days=2), '231'),
+            (State.QUEUED, timezone.utcnow(), '231'),
+            (State.QUEUED, timezone.utcnow(), None),
+        ],
+    )
+    def test_stuck_queued_tasks_are_cleared(
+        self, state, queued_dttm, executor_id, session, dag_maker, create_dummy_dag, create_task_instance
+    ):
+        """Test that _clear_stuck_queued_tasks works"""
+        ti = create_task_instance(state=State.QUEUED)
+        ti.queued_dttm = queued_dttm
+        ti.external_executor_id = executor_id
+        session.merge(ti)
+        session.flush()
+        executor = celery_executor.CeleryExecutor()
+        executor._clear_stuck_queued_tasks()
+        session.flush()
+        ti = session.query(TaskInstance).filter(TaskInstance.task_id == ti.task_id).one()
+        assert ti.state == state
+
+    @pytest.mark.backend("mysql", "postgres")
+    def test_task_in_queued_tasks_dict_are_not_cleared(
+        self, session, dag_maker, create_dummy_dag, create_task_instance
+    ):
+        """Test that _clear_stuck_queued_tasks doesn't clear tasks in executor.queued_tasks"""
+        ti = create_task_instance(state=State.QUEUED)
+        ti.queued_dttm = timezone.utcnow() - timedelta(days=2)
+        ti.external_executor_id = '231'
+        session.merge(ti)
+        session.flush()
+        executor = celery_executor.CeleryExecutor()
+        executor.queued_tasks = {ti.key: AsyncResult("231")}
+        executor._clear_stuck_queued_tasks()
+        session.flush()
+        ti = session.query(TaskInstance).filter(TaskInstance.task_id == ti.task_id).one()
+        assert executor.queued_tasks == {ti.key: AsyncResult("231")}
+        assert ti.state == State.QUEUED
+
+    @pytest.mark.backend("mysql", "postgres")
+    def test_task_in_running_dict_are_not_cleared(
+        self, session, dag_maker, create_dummy_dag, create_task_instance
+    ):
+        """Test that _clear_stuck_queued_tasks doesn't clear tasks in executor.running"""
+        ti = create_task_instance(state=State.QUEUED)
+        ti.queued_dttm = timezone.utcnow() - timedelta(days=2)
+        ti.external_executor_id = '231'
+        session.merge(ti)
+        session.flush()
+        executor = celery_executor.CeleryExecutor()
+        executor.running = {ti.key: AsyncResult("231")}
+        executor._clear_stuck_queued_tasks()
+        session.flush()
+        ti = session.query(TaskInstance).filter(TaskInstance.task_id == ti.task_id).one()
+        assert executor.running == {ti.key: AsyncResult("231")}
+        assert ti.state == State.QUEUED
+
+    @pytest.mark.backend("mysql", "postgres")
+    def test_only_database_result_backend_supports_clearing_queued_task(
+        self, session, dag_maker, create_dummy_dag, create_task_instance
+    ):
+        with _prepare_app():
+            mock_backend = BaseKeyValueStoreBackend(app=celery_executor.app)
+            with mock.patch('airflow.executors.celery_executor.Celery.backend', mock_backend):
+                ti = create_task_instance(state=State.QUEUED)
+                ti.queued_dttm = timezone.utcnow() - timedelta(days=2)
+                ti.external_executor_id = '231'
+                session.merge(ti)
+                session.flush()
+                executor = celery_executor.CeleryExecutor()
+                executor.tasks = {ti.key: AsyncResult("231")}
+                executor._clear_stuck_queued_tasks()
+                session.flush()
+                ti = session.query(TaskInstance).filter(TaskInstance.task_id == ti.task_id).one()
+                # Not cleared
+                assert ti.state == State.QUEUED
+                assert executor.tasks == {ti.key: AsyncResult("231")}
+
+    @mock.patch("celery.backends.database.DatabaseBackend.ResultSession")
+    @pytest.mark.backend("mysql", "postgres")
+    @freeze_time("2020-01-01")
+    @pytest.mark.parametrize(
+        "state",
+        [
+            (State.SCHEDULED),
+            (State.QUEUED),
+        ],
+    )
+    def test_the_check_interval_to_clear_stuck_queued_task_is_correct(
+        self,
+        mock_result_session,
+        state,
+        session,
+        dag_maker,
+        create_dummy_dag,
+        create_task_instance,
+    ):
+        with _prepare_app():
+            mock_backend = DatabaseBackend(app=celery_executor.app, url="sqlite3://")
+            with mock.patch('airflow.executors.celery_executor.Celery.backend', mock_backend):
+                mock_session = mock_backend.ResultSession.return_value
+                mock_session.query.return_value.all.return_value = [
+                    mock.MagicMock(**{"to_dict.return_value": {"status": "SUCCESS", "task_id": "123"}})
+                ]
+                if state == State.SCHEDULED:
+                    last_check_time = time.time() - 302  # should clear ti state
+                else:
+                    last_check_time = time.time() - 298  # should not clear ti state
+
+                ti = create_task_instance(state=State.QUEUED)
+                ti.queued_dttm = timezone.utcnow() - timedelta(days=2)
+                ti.external_executor_id = '231'
+                session.merge(ti)
+                session.flush()
+                executor = celery_executor.CeleryExecutor()
+                executor.tasks = {ti.key: AsyncResult("231")}
+                executor.stuck_tasks_last_check_time = last_check_time
+                executor.sync()
+                session.flush()
+                ti = session.query(TaskInstance).filter(TaskInstance.task_id == ti.task_id).one()
+                assert ti.state == state
+
+    @mock.patch("celery.backends.database.DatabaseBackend.ResultSession")
+    @pytest.mark.backend("mysql", "postgres")
+    @freeze_time("2020-01-01")
+    @pytest.mark.parametrize(
+        "task_id, state",
+        [
+            ('231', State.QUEUED),
+            ('111', State.SCHEDULED),
+        ],
+    )
+    def test_the_check_interval_to_clear_stuck_queued_task_is_correct_for_db_query(
+        self,
+        mock_result_session,
+        task_id,
+        state,
+        session,
+        dag_maker,
+        create_dummy_dag,
+        create_task_instance,
+    ):
+        """Here we test that tasks are not cleared if they are found in the Celery result database"""
+        result_obj = namedtuple('Result', ['status', 'task_id'])
+        with _prepare_app():
+            mock_backend = DatabaseBackend(app=celery_executor.app, url="sqlite3://")
+            with mock.patch('airflow.executors.celery_executor.Celery.backend', mock_backend):
+                mock_session = mock_backend.ResultSession.return_value
+                mock_session.query.return_value.all.return_value = [result_obj("SUCCESS", task_id)]
+
+                last_check_time = time.time() - 302  # should clear ti state
+
+                ti = create_task_instance(state=State.QUEUED)
+                ti.queued_dttm = timezone.utcnow() - timedelta(days=2)
+                ti.external_executor_id = '231'
+                session.merge(ti)
+                session.flush()
+                executor = celery_executor.CeleryExecutor()
+                executor.tasks = {ti.key: AsyncResult("231")}
+                executor.stuck_tasks_last_check_time = last_check_time
+                executor.sync()
+                session.flush()
+                ti = session.query(TaskInstance).filter(TaskInstance.task_id == ti.task_id).one()
+                assert ti.state == state
+
 
 def test_operation_timeout_config():
     assert celery_executor.OPERATION_TIMEOUT == 1
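
Once a build containing this commit is installed, the effective values can be sanity checked
from a Python shell; this uses only the standard Airflow configuration API and the option
names introduced above:

    from airflow.configuration import conf

    check_interval = conf.getint("celery", "stuck_queued_task_check_interval", fallback=300)
    adoption_timeout = conf.getint("celery", "task_adoption_timeout", fallback=600)
    print(f"checking for stuck queued tasks every {check_interval}s; "
          f"queued task instances older than {adoption_timeout}s may be moved back to scheduled")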