You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink was not preserved in this plain-text copy).
Posted to commits@airflow.apache.org by hu...@apache.org on 2023/09/25 08:49:59 UTC
[airflow] branch main updated: respect soft_fail argument when exception is raised for celery sensors (#34474)
This is an automated email from the ASF dual-hosted git repository.
husseinawala pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new f19e055789 respect soft_fail argument when exception is raised for celery sensors (#34474)
f19e055789 is described below
commit f19e0557890a86f7a622bada99f7a054edd3cfe0
Author: Wei Lee <we...@gmail.com>
AuthorDate: Mon Sep 25 16:49:50 2023 +0800
respect soft_fail argument when exception is raised for celery sensors (#34474)
---
airflow/providers/celery/sensors/celery_queue.py | 14 +++++++++++---
tests/providers/celery/sensors/test_celery_queue.py | 17 +++++++++++++++++
2 files changed, 28 insertions(+), 3 deletions(-)
diff --git a/airflow/providers/celery/sensors/celery_queue.py b/airflow/providers/celery/sensors/celery_queue.py
index 4533217bff..9800ccdb5b 100644
--- a/airflow/providers/celery/sensors/celery_queue.py
+++ b/airflow/providers/celery/sensors/celery_queue.py
@@ -21,6 +21,7 @@ from typing import TYPE_CHECKING
from celery.app import control
+from airflow.exceptions import AirflowSkipException
from airflow.sensors.base import BaseSensorOperator
if TYPE_CHECKING:
@@ -39,7 +40,6 @@ class CeleryQueueSensor(BaseSensorOperator):
"""
def __init__(self, *, celery_queue: str, target_task_id: str | None = None, **kwargs) -> None:
-
super().__init__(**kwargs)
self.celery_queue = celery_queue
self.target_task_id = target_task_id
@@ -56,7 +56,6 @@ class CeleryQueueSensor(BaseSensorOperator):
return celery_result.ready()
def poke(self, context: Context) -> bool:
-
if self.target_task_id:
return self._check_task_id(context)
@@ -74,4 +73,13 @@ class CeleryQueueSensor(BaseSensorOperator):
return reserved == 0 and scheduled == 0 and active == 0
except KeyError:
- raise KeyError(f"Could not locate Celery queue {self.celery_queue}")
+ # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
+ message = f"Could not locate Celery queue {self.celery_queue}"
+ if self.soft_fail:
+ raise AirflowSkipException(message)
+ raise KeyError(message)
+ except Exception as err:
+ # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
+ if self.soft_fail:
+ raise AirflowSkipException from err
+ raise
diff --git a/tests/providers/celery/sensors/test_celery_queue.py b/tests/providers/celery/sensors/test_celery_queue.py
index 8d09085352..f2d619e50c 100644
--- a/tests/providers/celery/sensors/test_celery_queue.py
+++ b/tests/providers/celery/sensors/test_celery_queue.py
@@ -19,6 +19,9 @@ from __future__ import annotations
from unittest.mock import patch
+import pytest
+
+from airflow.exceptions import AirflowSkipException
from airflow.providers.celery.sensors.celery_queue import CeleryQueueSensor
@@ -54,6 +57,20 @@ class TestCeleryQueueSensor:
test_sensor = self.sensor(celery_queue="test_queue", task_id="test-task")
assert not test_sensor.poke(None)
+ @pytest.mark.parametrize(
+ "soft_fail, expected_exception", ((False, KeyError), (True, AirflowSkipException))
+ )
+ @patch("celery.app.control.Inspect")
+ def test_poke_fail_with_exception(self, mock_inspect, soft_fail, expected_exception):
+ mock_inspect_result = mock_inspect.return_value
+ mock_inspect_result.reserved.return_value = {}
+ mock_inspect_result.scheduled.return_value = {}
+ mock_inspect_result.active.return_value = {}
+
+ with pytest.raises(expected_exception):
+ test_sensor = self.sensor(celery_queue="test_queue", task_id="test-task", soft_fail=soft_fail)
+ test_sensor.poke(None)
+
@patch("celery.app.control.Inspect")
def test_poke_success_with_taskid(self, mock_inspect):
test_sensor = self.sensor(