You are viewing a plain-text version of this content. The canonical link to it was a hyperlink on the original page and is not preserved in this text.
Posted to commits@airflow.apache.org by ur...@apache.org on 2021/07/30 04:37:09 UTC

[airflow] branch main updated: Fix `airflow celery stop` to accept the pid file. (#17278)

This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new b814a58  Fix `airflow celery stop` to accept the pid file. (#17278)
b814a58 is described below

commit b814a58edbf6cad9c7bcc8375c544f35c40ccc97
Author: Santosh Pingale <pi...@gmail.com>
AuthorDate: Fri Jul 30 06:36:51 2021 +0200

    Fix `airflow celery stop` to accept the pid file. (#17278)
---
 airflow/cli/cli_parser.py                 |  2 +-
 airflow/cli/commands/celery_command.py    |  5 +++-
 tests/cli/commands/test_celery_command.py | 39 +++++++++++++++++++++++++++++++
 3 files changed, 44 insertions(+), 2 deletions(-)

diff --git a/airflow/cli/cli_parser.py b/airflow/cli/cli_parser.py
index b652838..1cca4c2 100644
--- a/airflow/cli/cli_parser.py
+++ b/airflow/cli/cli_parser.py
@@ -1383,7 +1383,7 @@ CELERY_COMMANDS = (
         name='stop',
         help="Stop the Celery worker gracefully",
         func=lazy_load_command('airflow.cli.commands.celery_command.stop_worker'),
-        args=(),
+        args=(ARG_PID,),
     ),
 )
 
diff --git a/airflow/cli/commands/celery_command.py b/airflow/cli/commands/celery_command.py
index ba3c45e..e9c3e38 100644
--- a/airflow/cli/commands/celery_command.py
+++ b/airflow/cli/commands/celery_command.py
@@ -183,7 +183,10 @@ def worker(args):
 def stop_worker(args):
     """Sends SIGTERM to Celery worker"""
     # Read PID from file
-    pid_file_path, _, _, _ = setup_locations(process=WORKER_PROCESS_NAME)
+    if args.pid:
+        pid_file_path = args.pid
+    else:
+        pid_file_path, _, _, _ = setup_locations(process=WORKER_PROCESS_NAME)
     pid = read_pid_from_pidfile(pid_file_path)
 
     # Send SIGTERM
diff --git a/tests/cli/commands/test_celery_command.py b/tests/cli/commands/test_celery_command.py
index e2c1668..3fbadf1 100644
--- a/tests/cli/commands/test_celery_command.py
+++ b/tests/cli/commands/test_celery_command.py
@@ -20,6 +20,7 @@ from argparse import Namespace
 from tempfile import NamedTemporaryFile
 from unittest import mock
 
+import os
 import pytest
 import sqlalchemy
 
@@ -142,6 +143,44 @@ class TestCeleryStopCommand(unittest.TestCase):
         celery_command.stop_worker(stop_args)
         mock_read_pid_from_pidfile.assert_called_once_with(pid_file)
 
+    @mock.patch("airflow.cli.commands.celery_command.read_pid_from_pidfile")
+    @mock.patch("airflow.cli.commands.celery_command.worker_bin.worker")
+    @mock.patch("airflow.cli.commands.celery_command.psutil.Process")
+    @conf_vars({("core", "executor"): "CeleryExecutor"})
+    def test_custom_pid_file_is_used_in_start_and_stop(
+        self, mock_process, mock_celery_worker, mock_read_pid_from_pidfile
+    ):
+        """`celery worker --pid X` and `celery stop --pid X` use the same pid file."""
+        # Stacked mock.patch decorators inject mocks bottom-up (Process first).
+        pid_file = "custom_test_pid_file"
+        pid = 12345
+
+        # Call worker: the custom pid file must be forwarded to Celery.
+        worker_args = self.parser.parse_args(['celery', 'worker', '--skip-serve-logs', '--pid', pid_file])
+        celery_command.worker(worker_args)
+        run_mock = mock_celery_worker.return_value.run
+        assert run_mock.call_args
+        args, kwargs = run_mock.call_args
+        assert kwargs['pidfile'] == pid_file
+        assert not args
+
+        # Celery is mocked and writes no pid file, so simulate the one it would write.
+        with open(pid_file, "w") as pid_fd:
+            pid_fd.write(str(pid))
+        mock_read_pid_from_pidfile.return_value = pid
+        try:
+            # Call stop: it must read the same custom pid file and terminate that pid.
+            stop_args = self.parser.parse_args(['celery', 'stop', '--pid', pid_file])
+            celery_command.stop_worker(stop_args)
+            mock_read_pid_from_pidfile.assert_called_once_with(pid_file)
+            mock_process.assert_called_once_with(pid)
+            mock_process.return_value.terminate.assert_called_once_with()
+            assert not os.path.exists(pid_file)
+        finally:
+            # Never leak the pid file into the CWD, even if an assertion fails.
+            if os.path.exists(pid_file):
+                os.remove(pid_file)
+
 
 @pytest.mark.backend("mysql", "postgres")
 class TestWorkerStart(unittest.TestCase):