You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by hu...@apache.org on 2023/09/25 08:49:59 UTC

[airflow] branch main updated: respect soft_fail argument when exception is raised for celery sensors (#34474)

This is an automated email from the ASF dual-hosted git repository.

husseinawala pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new f19e055789 respect soft_fail argument when exception is raised for celery sensors (#34474)
f19e055789 is described below

commit f19e0557890a86f7a622bada99f7a054edd3cfe0
Author: Wei Lee <we...@gmail.com>
AuthorDate: Mon Sep 25 16:49:50 2023 +0800

    respect soft_fail argument when exception is raised for celery sensors (#34474)
---
 airflow/providers/celery/sensors/celery_queue.py    | 14 +++++++++++---
 tests/providers/celery/sensors/test_celery_queue.py | 17 +++++++++++++++++
 2 files changed, 28 insertions(+), 3 deletions(-)

diff --git a/airflow/providers/celery/sensors/celery_queue.py b/airflow/providers/celery/sensors/celery_queue.py
index 4533217bff..9800ccdb5b 100644
--- a/airflow/providers/celery/sensors/celery_queue.py
+++ b/airflow/providers/celery/sensors/celery_queue.py
@@ -21,6 +21,7 @@ from typing import TYPE_CHECKING
 
 from celery.app import control
 
+from airflow.exceptions import AirflowSkipException
 from airflow.sensors.base import BaseSensorOperator
 
 if TYPE_CHECKING:
@@ -39,7 +40,6 @@ class CeleryQueueSensor(BaseSensorOperator):
     """
 
     def __init__(self, *, celery_queue: str, target_task_id: str | None = None, **kwargs) -> None:
-
         super().__init__(**kwargs)
         self.celery_queue = celery_queue
         self.target_task_id = target_task_id
@@ -56,7 +56,6 @@ class CeleryQueueSensor(BaseSensorOperator):
         return celery_result.ready()
 
     def poke(self, context: Context) -> bool:
-
         if self.target_task_id:
             return self._check_task_id(context)
 
@@ -74,4 +73,13 @@ class CeleryQueueSensor(BaseSensorOperator):
 
             return reserved == 0 and scheduled == 0 and active == 0
         except KeyError:
-            raise KeyError(f"Could not locate Celery queue {self.celery_queue}")
+            # TODO: remove this `if` check once min_airflow_version is set higher than 2.7.1
+            message = f"Could not locate Celery queue {self.celery_queue}"
+            if self.soft_fail:
+                raise AirflowSkipException(message)
+            raise KeyError(message)
+        except Exception as err:
+            # TODO: remove this `if` check once min_airflow_version is set higher than 2.7.1
+            if self.soft_fail:
+                raise AirflowSkipException from err
+            raise
diff --git a/tests/providers/celery/sensors/test_celery_queue.py b/tests/providers/celery/sensors/test_celery_queue.py
index 8d09085352..f2d619e50c 100644
--- a/tests/providers/celery/sensors/test_celery_queue.py
+++ b/tests/providers/celery/sensors/test_celery_queue.py
@@ -19,6 +19,9 @@ from __future__ import annotations
 
 from unittest.mock import patch
 
+import pytest
+
+from airflow.exceptions import AirflowSkipException
 from airflow.providers.celery.sensors.celery_queue import CeleryQueueSensor
 
 
@@ -54,6 +57,20 @@ class TestCeleryQueueSensor:
         test_sensor = self.sensor(celery_queue="test_queue", task_id="test-task")
         assert not test_sensor.poke(None)
 
+    @pytest.mark.parametrize(
+        "soft_fail, expected_exception", ((False, KeyError), (True, AirflowSkipException))
+    )
+    @patch("celery.app.control.Inspect")
+    def test_poke_fail_with_exception(self, mock_inspect, soft_fail, expected_exception):
+        mock_inspect_result = mock_inspect.return_value
+        mock_inspect_result.reserved.return_value = {}
+        mock_inspect_result.scheduled.return_value = {}
+        mock_inspect_result.active.return_value = {}
+
+        with pytest.raises(expected_exception):
+            test_sensor = self.sensor(celery_queue="test_queue", task_id="test-task", soft_fail=soft_fail)
+            test_sensor.poke(None)
+
     @patch("celery.app.control.Inspect")
     def test_poke_success_with_taskid(self, mock_inspect):
         test_sensor = self.sensor(