Posted to commits@airflow.apache.org by je...@apache.org on 2022/01/28 21:25:30 UTC
[airflow] 17/17: Handle stuck queued tasks in Celery for db backend (#19769)
This is an automated email from the ASF dual-hosted git repository.
jedcunningham pushed a commit to branch v2-2-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit ec60dd799d65bdb80b83893db7df215d98342dde
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Fri Jan 14 09:55:15 2022 +0100
Handle stuck queued tasks in Celery for db backend (#19769)

Move the state of stuck queued tasks in Celery to Scheduled so that
the scheduler can queue them again. This applies only to the
DatabaseBackend result backend.

(cherry picked from commit 14ee831c7ad767e31a3aeccf3edbc519b3b8c923)
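
In outline: the executor's sync loop now runs a periodic sweep, gated on the new [celery] stuck_queued_task_check_interval option, that flips stuck QUEUED task instances back to SCHEDULED. A minimal standalone sketch of that gating, reusing the names introduced in the diff below (resetting the timestamp after a sweep is this sketch's own assumption, added so the example is self-contained):

import time

class SweepGate:
    """Sketch of the time-gated check added to CeleryExecutor.sync()."""

    def __init__(self, check_interval: int = 300) -> None:
        # Mirrors [celery] stuck_queued_task_check_interval (default: 300 seconds).
        self.stuck_queued_task_check_interval = check_interval
        self.stuck_tasks_last_check_time = time.time()

    def should_sweep(self) -> bool:
        now = time.time()
        if now - self.stuck_tasks_last_check_time > self.stuck_queued_task_check_interval:
            # Resetting the timestamp here is this sketch's assumption.
            self.stuck_tasks_last_check_time = now
            return True
        return False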
---
airflow/config_templates/config.yml | 7 +
airflow/config_templates/default_airflow.cfg | 3 +
airflow/executors/celery_executor.py | 52 +++++++
tests/executors/test_celery_executor.py | 197 ++++++++++++++++++++++++---
4 files changed, 242 insertions(+), 17 deletions(-)
diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml
index 6941f03..e061568 100644
--- a/airflow/config_templates/config.yml
+++ b/airflow/config_templates/config.yml
@@ -1663,6 +1663,13 @@
       type: string
       example: ~
       default: "False"
+    - name: stuck_queued_task_check_interval
+      description: |
+        How often to check for stuck queued tasks (in seconds)
+      version_added: 2.3.0
+      type: integer
+      example: ~
+      default: "300"
 - name: celery_broker_transport_options
   description: |
     This section is for specifying options which can be passed to the
diff --git a/airflow/config_templates/default_airflow.cfg b/airflow/config_templates/default_airflow.cfg
index 6a5449b..4024922 100644
--- a/airflow/config_templates/default_airflow.cfg
+++ b/airflow/config_templates/default_airflow.cfg
@@ -830,6 +830,9 @@ task_publish_max_retries = 3
 # Worker initialisation check to validate Metadata Database connection
 worker_precheck = False
 
+# How often to check for stuck queued tasks (in seconds)
+stuck_queued_task_check_interval = 300
+
 [celery_broker_transport_options]
 
 # This section is for specifying options which can be passed to the
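
Since the option lives in the [celery] section, the standard AIRFLOW__CELERY__STUCK_QUEUED_TASK_CHECK_INTERVAL environment variable override applies as well. A quick sketch for confirming the effective value at runtime (assumes an initialized Airflow configuration):

from airflow.configuration import conf

# The same lookup the executor performs; falls back to the shipped default.
interval = conf.getint('celery', 'stuck_queued_task_check_interval', fallback=300)
print(f"stuck queued tasks are checked every {interval}s")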
diff --git a/airflow/executors/celery_executor.py b/airflow/executors/celery_executor.py
index f257b0c..8daced6 100644
--- a/airflow/executors/celery_executor.py
+++ b/airflow/executors/celery_executor.py
@@ -40,6 +40,7 @@ from celery.backends.database import DatabaseBackend, Task as TaskDb, session_cleanup
 from celery.result import AsyncResult
 from celery.signals import import_modules as celery_import_modules
 from setproctitle import setproctitle
+from sqlalchemy.orm.session import Session
 
 import airflow.settings as settings
 from airflow.config_templates.default_celery import DEFAULT_CELERY_CONFIG
@@ -50,6 +51,7 @@ from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
 from airflow.stats import Stats
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.net import get_hostname
+from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.state import State
 from airflow.utils.timeout import timeout
 from airflow.utils.timezone import utcnow
@@ -231,6 +233,10 @@ class CeleryExecutor(BaseExecutor):
         self.task_adoption_timeout = datetime.timedelta(
             seconds=conf.getint('celery', 'task_adoption_timeout', fallback=600)
         )
+        self.stuck_tasks_last_check_time: float = time.time()
+        self.stuck_queued_task_check_interval = conf.getint(
+            'celery', 'stuck_queued_task_check_interval', fallback=300
+        )
         self.task_publish_retries: Dict[TaskInstanceKey, int] = OrderedDict()
         self.task_publish_max_retries = conf.getint('celery', 'task_publish_max_retries', fallback=3)
 
@@ -335,6 +341,8 @@
 
         if self.adopted_task_timeouts:
             self._check_for_stalled_adopted_tasks()
+        if time.time() - self.stuck_tasks_last_check_time > self.stuck_queued_task_check_interval:
+            self._clear_stuck_queued_tasks()
 
     def _check_for_stalled_adopted_tasks(self):
         """
@@ -375,6 +383,50 @@
             for key in timedout_keys:
                 self.change_state(key, State.FAILED)
 
+    @provide_session
+    def _clear_stuck_queued_tasks(self, session: Session = NEW_SESSION) -> None:
+        """
+        Tasks can get stuck in the queued state in the DB while no longer
+        present in the worker. This happens when the worker is autoscaled down
+        and a task was queued but had not been picked up by any worker prior
+        to the scaling.
+
+        In such a situation, we update the task instance state to scheduled so
+        that it can be queued again. We chose to use task_adoption_timeout to
+        decide when a queued task is considered stuck and should be rescheduled.
+        """
+        if not isinstance(app.backend, DatabaseBackend):
+            # We only want to do this for database backends where
+            # this case has been spotted
+            return
+        # We use this instead of bulk_state_fetcher because we may not have
+        # the stuck task in self.tasks, and we don't want to clear tasks
+        # that are in self.tasks either
+        session_ = app.backend.ResultSession()
+        task_cls = getattr(app.backend, "task_cls", TaskDb)
+        with session_cleanup(session_):
+            celery_task_ids = [t.task_id for t in session_.query(task_cls.task_id).all()]
+
+        self.log.debug("Checking for stuck queued tasks")
+
+        max_allowed_time = utcnow() - self.task_adoption_timeout
+
+        for task in session.query(TaskInstance).filter(
+            TaskInstance.state == State.QUEUED, TaskInstance.queued_dttm < max_allowed_time
+        ):
+            if task.key in self.queued_tasks or task.key in self.running:
+                continue
+
+            if task.external_executor_id in celery_task_ids:
+                # The task is still running in the worker
+                continue
+
+            self.log.info(
+                'TaskInstance: %s found in queued state for more than %s seconds, rescheduling',
+                task,
+                self.task_adoption_timeout.total_seconds(),
+            )
+            task.state = State.SCHEDULED
+            session.merge(task)
+
     def debug_dump(self) -> None:
         """Called in response to SIGUSR2 by the scheduler"""
         super().debug_dump()
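
Condensed, the sweep above does three things: collect the task ids that Celery's database result backend still tracks, find TaskInstances that have sat in QUEUED longer than task_adoption_timeout, and reset to SCHEDULED any of them that neither the executor nor Celery knows about. A standalone sketch of that logic under those assumptions (an Airflow ORM session plus a Celery app whose result backend is a DatabaseBackend); it mirrors the diff rather than being the executor code itself:

from celery.backends.database import DatabaseBackend, Task as TaskDb, session_cleanup

from airflow.models.taskinstance import TaskInstance
from airflow.utils.state import State
from airflow.utils.timezone import utcnow


def clear_stuck_queued_tasks(app, session, task_adoption_timeout, queued_tasks, running):
    """Reset stuck QUEUED task instances to SCHEDULED (sketch of the diff above)."""
    if not isinstance(app.backend, DatabaseBackend):
        return  # only the database result backend is handled

    result_session = app.backend.ResultSession()
    with session_cleanup(result_session):
        # Task ids Celery still tracks -- anything listed here is not stuck.
        celery_task_ids = {t.task_id for t in result_session.query(TaskDb.task_id)}

    cutoff = utcnow() - task_adoption_timeout
    stuck = session.query(TaskInstance).filter(
        TaskInstance.state == State.QUEUED, TaskInstance.queued_dttm < cutoff
    )
    for ti in stuck:
        if ti.key in queued_tasks or ti.key in running:
            continue  # the executor is still tracking this task
        if ti.external_executor_id in celery_task_ids:
            continue  # a worker has picked it up after all
        ti.state = State.SCHEDULED  # hand it back to the scheduler
        session.merge(ti)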
diff --git a/tests/executors/test_celery_executor.py b/tests/executors/test_celery_executor.py
index db63b18..5632f7d 100644
--- a/tests/executors/test_celery_executor.py
+++ b/tests/executors/test_celery_executor.py
@@ -17,10 +17,13 @@
 # under the License.
 
 import contextlib
 import json
+import logging
 import os
 import signal
 import sys
+import time
 import unittest
+from collections import namedtuple
 from datetime import datetime, timedelta
 from unittest import mock
@@ -32,6 +35,7 @@ from celery.backends.base import BaseBackend, BaseKeyValueStoreBackend
 from celery.backends.database import DatabaseBackend
 from celery.contrib.testing.worker import start_worker
 from celery.result import AsyncResult
+from freezegun import freeze_time
 from kombu.asynchronous import set_event_loop
 from parameterized import parameterized
 
@@ -94,12 +98,12 @@ def _prepare_app(broker_url=None, execute=None):
         set_event_loop(None)
 
 
-class TestCeleryExecutor(unittest.TestCase):
-    def setUp(self) -> None:
+class TestCeleryExecutor:
+    def setup_method(self) -> None:
         db.clear_db_runs()
         db.clear_db_jobs()
 
-    def tearDown(self) -> None:
+    def teardown_method(self) -> None:
         db.clear_db_runs()
         db.clear_db_jobs()
 
@@ -196,10 +200,10 @@ class TestCeleryExecutor(unittest.TestCase):
     @pytest.mark.integration("redis")
     @pytest.mark.integration("rabbitmq")
     @pytest.mark.backend("mysql", "postgres")
-    def test_retry_on_error_sending_task(self):
+    def test_retry_on_error_sending_task(self, caplog):
         """Test that Airflow retries publishing tasks to Celery Broker at least 3 times"""
 
-        with _prepare_app(), self.assertLogs(celery_executor.log) as cm, mock.patch.object(
+        with _prepare_app(), caplog.at_level(logging.INFO), mock.patch.object(
             # Mock `with timeout()` to _instantly_ fail.
             celery_executor.timeout,
             "__enter__",
@@ -227,28 +231,19 @@ class TestCeleryExecutor(unittest.TestCase):
         assert dict(executor.task_publish_retries) == {key: 2}
         assert 1 == len(executor.queued_tasks), "Task should remain in queue"
         assert executor.event_buffer == {}
-        assert (
-            "INFO:airflow.executors.celery_executor.CeleryExecutor:"
-            f"[Try 1 of 3] Task Timeout Error for Task: ({key})." in cm.output
-        )
+        assert f"[Try 1 of 3] Task Timeout Error for Task: ({key})." in caplog.text
 
         executor.heartbeat()
         assert dict(executor.task_publish_retries) == {key: 3}
         assert 1 == len(executor.queued_tasks), "Task should remain in queue"
         assert executor.event_buffer == {}
-        assert (
-            "INFO:airflow.executors.celery_executor.CeleryExecutor:"
-            f"[Try 2 of 3] Task Timeout Error for Task: ({key})." in cm.output
-        )
+        assert f"[Try 2 of 3] Task Timeout Error for Task: ({key})." in caplog.text
 
         executor.heartbeat()
         assert dict(executor.task_publish_retries) == {key: 4}
         assert 1 == len(executor.queued_tasks), "Task should remain in queue"
         assert executor.event_buffer == {}
-        assert (
-            "INFO:airflow.executors.celery_executor.CeleryExecutor:"
-            f"[Try 3 of 3] Task Timeout Error for Task: ({key})." in cm.output
-        )
+        assert f"[Try 3 of 3] Task Timeout Error for Task: ({key})." in caplog.text
 
         executor.heartbeat()
         assert dict(executor.task_publish_retries) == {}
@@ -411,6 +406,174 @@
         assert executor.running == {key_2}
         assert executor.adopted_task_timeouts == {key_2: queued_dttm_2 + executor.task_adoption_timeout}
 
+    @pytest.mark.backend("mysql", "postgres")
+    @pytest.mark.parametrize(
+        "state, queued_dttm, executor_id",
+        [
+            (State.SCHEDULED, timezone.utcnow() - timedelta(days=2), '231'),
+            (State.QUEUED, timezone.utcnow(), '231'),
+            (State.QUEUED, timezone.utcnow(), None),
+        ],
+    )
+    def test_stuck_queued_tasks_are_cleared(
+        self, state, queued_dttm, executor_id, session, dag_maker, create_dummy_dag, create_task_instance
+    ):
+        """Test that _clear_stuck_queued_tasks works"""
+        ti = create_task_instance(state=State.QUEUED)
+        ti.queued_dttm = queued_dttm
+        ti.external_executor_id = executor_id
+        session.merge(ti)
+        session.flush()
+        executor = celery_executor.CeleryExecutor()
+        executor._clear_stuck_queued_tasks()
+        session.flush()
+        ti = session.query(TaskInstance).filter(TaskInstance.task_id == ti.task_id).one()
+        assert ti.state == state
+
+    @pytest.mark.backend("mysql", "postgres")
+    def test_task_in_queued_tasks_dict_are_not_cleared(
+        self, session, dag_maker, create_dummy_dag, create_task_instance
+    ):
+        """Test that _clear_stuck_queued_tasks doesn't clear tasks in executor.queued_tasks"""
+        ti = create_task_instance(state=State.QUEUED)
+        ti.queued_dttm = timezone.utcnow() - timedelta(days=2)
+        ti.external_executor_id = '231'
+        session.merge(ti)
+        session.flush()
+        executor = celery_executor.CeleryExecutor()
+        executor.queued_tasks = {ti.key: AsyncResult("231")}
+        executor._clear_stuck_queued_tasks()
+        session.flush()
+        ti = session.query(TaskInstance).filter(TaskInstance.task_id == ti.task_id).one()
+        assert executor.queued_tasks == {ti.key: AsyncResult("231")}
+        assert ti.state == State.QUEUED
+
+    @pytest.mark.backend("mysql", "postgres")
+    def test_task_in_running_dict_are_not_cleared(
+        self, session, dag_maker, create_dummy_dag, create_task_instance
+    ):
+        """Test that _clear_stuck_queued_tasks doesn't clear tasks in executor.running"""
+        ti = create_task_instance(state=State.QUEUED)
+        ti.queued_dttm = timezone.utcnow() - timedelta(days=2)
+        ti.external_executor_id = '231'
+        session.merge(ti)
+        session.flush()
+        executor = celery_executor.CeleryExecutor()
+        executor.running = {ti.key: AsyncResult("231")}
+        executor._clear_stuck_queued_tasks()
+        session.flush()
+        ti = session.query(TaskInstance).filter(TaskInstance.task_id == ti.task_id).one()
+        assert executor.running == {ti.key: AsyncResult("231")}
+        assert ti.state == State.QUEUED
+
+    @pytest.mark.backend("mysql", "postgres")
+    def test_only_database_result_backend_supports_clearing_queued_task(
+        self, session, dag_maker, create_dummy_dag, create_task_instance
+    ):
+        with _prepare_app():
+            mock_backend = BaseKeyValueStoreBackend(app=celery_executor.app)
+            with mock.patch('airflow.executors.celery_executor.Celery.backend', mock_backend):
+                ti = create_task_instance(state=State.QUEUED)
+                ti.queued_dttm = timezone.utcnow() - timedelta(days=2)
+                ti.external_executor_id = '231'
+                session.merge(ti)
+                session.flush()
+                executor = celery_executor.CeleryExecutor()
+                executor.tasks = {ti.key: AsyncResult("231")}
+                executor._clear_stuck_queued_tasks()
+                session.flush()
+                ti = session.query(TaskInstance).filter(TaskInstance.task_id == ti.task_id).one()
+                # Not cleared
+                assert ti.state == State.QUEUED
+                assert executor.tasks == {ti.key: AsyncResult("231")}
+
+    @mock.patch("celery.backends.database.DatabaseBackend.ResultSession")
+    @pytest.mark.backend("mysql", "postgres")
+    @freeze_time("2020-01-01")
+    @pytest.mark.parametrize(
+        "state",
+        [
+            (State.SCHEDULED),
+            (State.QUEUED),
+        ],
+    )
+    def test_the_check_interval_to_clear_stuck_queued_task_is_correct(
+        self,
+        mock_result_session,
+        state,
+        session,
+        dag_maker,
+        create_dummy_dag,
+        create_task_instance,
+    ):
+        with _prepare_app():
+            mock_backend = DatabaseBackend(app=celery_executor.app, url="sqlite3://")
+            with mock.patch('airflow.executors.celery_executor.Celery.backend', mock_backend):
+                mock_session = mock_backend.ResultSession.return_value
+                mock_session.query.return_value.all.return_value = [
+                    mock.MagicMock(**{"to_dict.return_value": {"status": "SUCCESS", "task_id": "123"}})
+                ]
+                if state == State.SCHEDULED:
+                    last_check_time = time.time() - 302  # interval has passed; should clear ti state
+                else:
+                    last_check_time = time.time() - 298  # interval not yet reached; should not clear ti state
+
+                ti = create_task_instance(state=State.QUEUED)
+                ti.queued_dttm = timezone.utcnow() - timedelta(days=2)
+                ti.external_executor_id = '231'
+                session.merge(ti)
+                session.flush()
+                executor = celery_executor.CeleryExecutor()
+                executor.tasks = {ti.key: AsyncResult("231")}
+                executor.stuck_tasks_last_check_time = last_check_time
+                executor.sync()
+                session.flush()
+                ti = session.query(TaskInstance).filter(TaskInstance.task_id == ti.task_id).one()
+                assert ti.state == state
+
+    @mock.patch("celery.backends.database.DatabaseBackend.ResultSession")
+    @pytest.mark.backend("mysql", "postgres")
+    @freeze_time("2020-01-01")
+    @pytest.mark.parametrize(
+        "task_id, state",
+        [
+            ('231', State.QUEUED),
+            ('111', State.SCHEDULED),
+        ],
+    )
+    def test_the_check_interval_to_clear_stuck_queued_task_is_correct_for_db_query(
+        self,
+        mock_result_session,
+        task_id,
+        state,
+        session,
+        dag_maker,
+        create_dummy_dag,
+        create_task_instance,
+    ):
+        """Here we test that tasks are not cleared if still found in the Celery database"""
+        result_obj = namedtuple('Result', ['status', 'task_id'])
+        with _prepare_app():
+            mock_backend = DatabaseBackend(app=celery_executor.app, url="sqlite3://")
+            with mock.patch('airflow.executors.celery_executor.Celery.backend', mock_backend):
+                mock_session = mock_backend.ResultSession.return_value
+                mock_session.query.return_value.all.return_value = [result_obj("SUCCESS", task_id)]
+
+                last_check_time = time.time() - 302  # interval has passed; the sweep will run
+
+                ti = create_task_instance(state=State.QUEUED)
+                ti.queued_dttm = timezone.utcnow() - timedelta(days=2)
+                ti.external_executor_id = '231'
+                session.merge(ti)
+                session.flush()
+                executor = celery_executor.CeleryExecutor()
+                executor.tasks = {ti.key: AsyncResult("231")}
+                executor.stuck_tasks_last_check_time = last_check_time
+                executor.sync()
+                session.flush()
+                ti = session.query(TaskInstance).filter(TaskInstance.task_id == ti.task_id).one()
+                assert ti.state == state
+
 
 def test_operation_timeout_config():
     assert celery_executor.OPERATION_TIMEOUT == 1