Posted to commits@airflow.apache.org by qi...@apache.org on 2021/05/29 15:28:21 UTC

[airflow] 01/02: Fix Celery executor getting stuck randomly because of reset_signals in multiprocessing (#15989)

This is an automated email from the ASF dual-hosted git repository.

qian pushed a commit to branch v2-1-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 285791d60317bb3faf0601124256a3b49bc33c46
Author: yuqian90 <yu...@gmail.com>
AuthorDate: Sat May 29 23:00:54 2021 +0800

    Fix Celery executor getting stuck randomly because of reset_signals in multiprocessing (#15989)
    
    Fixes #15938
    
    multiprocessing.Pool is known to often become stuck. It causes celery_executor to hang randomly. This happens at least on Debian and Ubuntu using Python 3.8.7 and Python 3.8.10. The issue is reproducible by running test_send_tasks_to_celery_hang in this PR several times (with the db backend set to something other than sqlite, because sqlite disables some parallelization).
    
    The issue goes away once switched to concurrent.futures.ProcessPoolExecutor. In Python 3.6 and earlier, ProcessPoolExecutor has no initializer argument. Fortunately, it is not needed here: reset_signals is no longer required, because the signal handlers now check whether the current process is the parent.
    
    (cherry picked from commit f75dd7ae6e755dad328ba6f3fd462ade194dab25)
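For context, the pattern this commit applies looks roughly like the standalone sketch below. It is illustrative only, not the committed code (see the diff that follows): _is_parent_process and the parent-only signal handler mirror the functions added in the diff, while the _square helper and the demo main block are made up for the example.

    import multiprocessing
    import signal
    import sys
    from concurrent.futures import ProcessPoolExecutor

    def _is_parent_process():
        # Worker processes started by multiprocessing get names like
        # 'ForkProcess-1', so only the parent matches 'MainProcess'.
        return multiprocessing.current_process().name == 'MainProcess'

    def _exit_gracefully(signum, frame):
        if not _is_parent_process():
            # Children inherit this handler (with the fork start method);
            # they must not run the parent's cleanup or call sys.exit().
            return
        print(f"Parent exiting on signal {signum}")
        sys.exit(0)

    def _square(n):
        return n * n

    if __name__ == '__main__':
        signal.signal(signal.SIGTERM, _exit_gracefully)
        # No pool initializer is needed to reset signals in the workers,
        # because the handler itself is a no-op outside the parent.
        with ProcessPoolExecutor(max_workers=4) as pool:
            results = list(pool.map(_square, range(10), chunksize=3))
        print(results)

Because the handler simply returns in child processes, swapping multiprocessing.Pool for ProcessPoolExecutor does not need the old reset_signals initializer, which is what made the ProcessPoolExecutor switch possible on Python 3.6 as well.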
---
 airflow/executors/celery_executor.py    | 24 +++++----------
 airflow/jobs/scheduler_job.py           | 16 ++++++++++
 scripts/ci/docker-compose/base.yml      |  2 ++
 tests/executors/test_celery_executor.py | 52 +++++++++++++++++++++++++++++++++
 4 files changed, 78 insertions(+), 16 deletions(-)

diff --git a/airflow/executors/celery_executor.py b/airflow/executors/celery_executor.py
index bc321c6..553639b 100644
--- a/airflow/executors/celery_executor.py
+++ b/airflow/executors/celery_executor.py
@@ -30,7 +30,8 @@ import subprocess
 import time
 import traceback
 from collections import OrderedDict
-from multiprocessing import Pool, cpu_count
+from concurrent.futures import ProcessPoolExecutor
+from multiprocessing import cpu_count
 from typing import Any, Dict, List, Mapping, MutableMapping, Optional, Set, Tuple, Union
 
 from celery import Celery, Task, states as celery_states
@@ -318,18 +319,9 @@ class CeleryExecutor(BaseExecutor):
         chunksize = self._num_tasks_per_send_process(len(task_tuples_to_send))
         num_processes = min(len(task_tuples_to_send), self._sync_parallelism)
 
-        def reset_signals():
-            # Since we are run from inside the SchedulerJob, we don't to
-            # inherit the signal handlers that we registered there.
-            import signal
-
-            signal.signal(signal.SIGINT, signal.SIG_DFL)
-            signal.signal(signal.SIGTERM, signal.SIG_DFL)
-            signal.signal(signal.SIGUSR2, signal.SIG_DFL)
-
-        with Pool(processes=num_processes, initializer=reset_signals) as send_pool:
-            key_and_async_results = send_pool.map(
-                send_task_to_executor, task_tuples_to_send, chunksize=chunksize
+        with ProcessPoolExecutor(max_workers=num_processes) as send_pool:
+            key_and_async_results = list(
+                send_pool.map(send_task_to_executor, task_tuples_to_send, chunksize=chunksize)
             )
         return key_and_async_results
 
@@ -592,11 +584,11 @@ class BulkStateFetcher(LoggingMixin):
     def _get_many_using_multiprocessing(self, async_results) -> Mapping[str, EventBufferValueType]:
         num_process = min(len(async_results), self._sync_parallelism)
 
-        with Pool(processes=num_process) as sync_pool:
+        with ProcessPoolExecutor(max_workers=num_process) as sync_pool:
             chunksize = max(1, math.floor(math.ceil(1.0 * len(async_results) / self._sync_parallelism)))
 
-            task_id_to_states_and_info = sync_pool.map(
-                fetch_celery_task_state, async_results, chunksize=chunksize
+            task_id_to_states_and_info = list(
+                sync_pool.map(fetch_celery_task_state, async_results, chunksize=chunksize)
             )
 
             states_and_info_by_task_id: MutableMapping[str, EventBufferValueType] = {}
diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py
index e86a6e7..cece87e 100644
--- a/airflow/jobs/scheduler_job.py
+++ b/airflow/jobs/scheduler_job.py
@@ -670,6 +670,14 @@ class DagFileProcessor(LoggingMixin):
         return len(dagbag.dags), len(dagbag.import_errors)
 
 
+def _is_parent_process():
+    """
+    Returns True if the current process is the parent process. False if the current process is a child
+    process started by multiprocessing.
+    """
+    return multiprocessing.current_process().name == 'MainProcess'
+
+
 class SchedulerJob(BaseJob):  # pylint: disable=too-many-instance-attributes
     """
     This SchedulerJob runs for a specific time interval and schedules the jobs
@@ -745,12 +753,20 @@ class SchedulerJob(BaseJob):  # pylint: disable=too-many-instance-attributes
 
     def _exit_gracefully(self, signum, frame) -> None:  # pylint: disable=unused-argument
         """Helper method to clean up processor_agent to avoid leaving orphan processes."""
+        if not _is_parent_process():
+            # Only the parent process should perform the cleanup.
+            return
+
         self.log.info("Exiting gracefully upon receiving signal %s", signum)
         if self.processor_agent:
             self.processor_agent.end()
         sys.exit(os.EX_OK)
 
     def _debug_dump(self, signum, frame):  # pylint: disable=unused-argument
+        if not _is_parent_process():
+            # Only the parent process should perform the debug dump.
+            return
+
         try:
             sig_name = signal.Signals(signum).name  # pylint: disable=no-member
         except Exception:  # pylint: disable=broad-except
diff --git a/scripts/ci/docker-compose/base.yml b/scripts/ci/docker-compose/base.yml
index eab6425..6b1cb4e 100644
--- a/scripts/ci/docker-compose/base.yml
+++ b/scripts/ci/docker-compose/base.yml
@@ -34,6 +34,8 @@ services:
     ports:
       - "${WEBSERVER_HOST_PORT}:8080"
       - "${FLOWER_HOST_PORT}:5555"
+    cap_add:
+      - SYS_PTRACE
 volumes:
   sqlite-db-volume:
   postgres-db-volume:
diff --git a/tests/executors/test_celery_executor.py b/tests/executors/test_celery_executor.py
index f454c5a..19c8a0d 100644
--- a/tests/executors/test_celery_executor.py
+++ b/tests/executors/test_celery_executor.py
@@ -18,6 +18,7 @@
 import contextlib
 import json
 import os
+import signal
 import sys
 import unittest
 from datetime import datetime, timedelta
@@ -484,3 +485,54 @@ class TestBulkStateFetcher(unittest.TestCase):
         assert [
             'DEBUG:airflow.executors.celery_executor.BulkStateFetcher:Fetched 2 state(s) for 2 task(s)'
         ] == cm.output
+
+
+class MockTask:
+    """
+    A picklable object used to mock tasks sent to Celery. Can't use the mock library
+    here because it's not picklable.
+    """
+
+    def apply_async(self, *args, **kwargs):
+        return 1
+
+
+def _exit_gracefully(signum, _):
+    print(f"{os.getpid()} Exiting gracefully upon receiving signal {signum}")
+    sys.exit(signum)
+
+
+@pytest.fixture
+def register_signals():
+    """
+    Register the same signals as scheduler does to test celery_executor to make sure it does not
+    hang.
+    """
+    orig_sigint = orig_sigterm = orig_sigusr2 = signal.SIG_DFL
+
+    orig_sigint = signal.signal(signal.SIGINT, _exit_gracefully)
+    orig_sigterm = signal.signal(signal.SIGTERM, _exit_gracefully)
+    orig_sigusr2 = signal.signal(signal.SIGUSR2, _exit_gracefully)
+
+    yield
+
+    # Restore original signal handlers after test
+    signal.signal(signal.SIGINT, orig_sigint)
+    signal.signal(signal.SIGTERM, orig_sigterm)
+    signal.signal(signal.SIGUSR2, orig_sigusr2)
+
+
+def test_send_tasks_to_celery_hang(register_signals):  # pylint: disable=unused-argument
+    """
+    Test that celery_executor does not hang after many runs.
+    """
+    executor = celery_executor.CeleryExecutor()
+
+    task = MockTask()
+    task_tuples_to_send = [(None, None, None, None, task) for _ in range(26)]
+
+    for _ in range(500):
+        # This loop can hang on Linux if celery_executor does something wrong with
+        # multiprocessing.
+        results = executor._send_tasks_to_celery(task_tuples_to_send)
+        assert results == [(None, None, 1) for _ in task_tuples_to_send]
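
For reference, the reproduction described in the commit message amounts to running this test repeatedly against a non-sqlite metadata database, roughly:

    pytest tests/executors/test_celery_executor.py::test_send_tasks_to_celery_hang

The test file and function name come from the diff above; the exact invocation and database setup are assumptions, and the CI docker-compose changes in this commit are what allow attaching a debugger (SYS_PTRACE) when the hang is investigated in CI.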