You are viewing a plain-text version of this content. The canonical link to it was a hyperlink that is not reproduced in this plain-text copy.
Posted to commits@airflow.apache.org by vi...@apache.org on 2023/09/25 14:03:45 UTC

[airflow] branch main updated: Respect `soft_fail` argument when running `SqsSensor` (#34569)

This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 2b5c7676b5 Respect `soft_fail` argument when running `SqsSensor` (#34569)
2b5c7676b5 is described below

commit 2b5c7676b535bc5910c726c851181f9b87362994
Author: Utkarsh Sharma <ut...@gmail.com>
AuthorDate: Mon Sep 25 19:33:37 2023 +0530

    Respect `soft_fail` argument when running `SqsSensor` (#34569)
---
 airflow/providers/amazon/aws/sensors/sqs.py    | 14 ++++++--
 tests/providers/amazon/aws/sensors/test_sqs.py | 48 +++++++++++++++++++++++++-
 2 files changed, 58 insertions(+), 4 deletions(-)

diff --git a/airflow/providers/amazon/aws/sensors/sqs.py b/airflow/providers/amazon/aws/sensors/sqs.py
index 33f989d03c..fce4363517 100644
--- a/airflow/providers/amazon/aws/sensors/sqs.py
+++ b/airflow/providers/amazon/aws/sensors/sqs.py
@@ -25,7 +25,7 @@ from deprecated import deprecated
 from typing_extensions import Literal
 
 from airflow.configuration import conf
-from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
+from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException
 from airflow.providers.amazon.aws.hooks.sqs import SqsHook
 from airflow.providers.amazon.aws.triggers.sqs import SqsSensorTrigger
 from airflow.providers.amazon.aws.utils.sqs import process_response
@@ -145,7 +145,11 @@ class SqsSensor(BaseSensorOperator):
 
     def execute_complete(self, context: Context, event: dict | None = None) -> None:
         if event is None or event["status"] != "success":
-            raise AirflowException(f"Trigger error: event is {event}")
+            # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
+            message = f"Trigger error: event is {event}"
+            if self.soft_fail:
+                raise AirflowSkipException(message)
+            raise AirflowException(message)
         context["ti"].xcom_push(key="messages", value=event["message_batch"])
 
     def poll_sqs(self, sqs_conn: BaseAwsConnection) -> Collection:
@@ -203,7 +207,11 @@ class SqsSensor(BaseSensorOperator):
                 response = self.hook.conn.delete_message_batch(QueueUrl=self.sqs_queue, Entries=entries)
 
                 if "Successful" not in response:
-                    raise AirflowException(f"Delete SQS Messages failed {response} for messages {messages}")
+                    # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
+                    error_message = f"Delete SQS Messages failed {response} for messages {messages}"
+                    if self.soft_fail:
+                        raise AirflowSkipException(error_message)
+                    raise AirflowException(error_message)
         if message_batch:
             context["ti"].xcom_push(key="messages", value=message_batch)
             return True
diff --git a/tests/providers/amazon/aws/sensors/test_sqs.py b/tests/providers/amazon/aws/sensors/test_sqs.py
index a9b03bd40a..fe3aadb430 100644
--- a/tests/providers/amazon/aws/sensors/test_sqs.py
+++ b/tests/providers/amazon/aws/sensors/test_sqs.py
@@ -24,7 +24,7 @@ from unittest.mock import patch
 import pytest
 from moto import mock_sqs
 
-from airflow.exceptions import AirflowException, TaskDeferred
+from airflow.exceptions import AirflowException, AirflowSkipException, TaskDeferred
 from airflow.models.dag import DAG
 from airflow.providers.amazon.aws.hooks.sqs import SqsHook
 from airflow.providers.amazon.aws.sensors.sqs import SqsSensor
@@ -346,3 +346,49 @@ class TestSqsSensor:
         )
         with pytest.raises(TaskDeferred):
             self.sensor.execute(None)
+
+    @pytest.mark.parametrize(
+        "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
+    )
+    def test_fail_execute_complete(self, soft_fail, expected_exception):
+        self.sensor = SqsSensor(
+            task_id="test_task_deferrable",
+            dag=self.dag,
+            sqs_queue=QUEUE_URL,
+            aws_conn_id="aws_default",
+            max_messages=1,
+            num_batches=3,
+            deferrable=True,
+            soft_fail=soft_fail,
+        )
+        event = {"status": "failed"}
+        message = f"Trigger error: event is {event}"
+        with pytest.raises(expected_exception, match=message):
+            self.sensor.execute_complete(context={}, event=event)
+
+    @pytest.mark.parametrize(
+        "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
+    )
+    @mock.patch("airflow.providers.amazon.aws.sensors.sqs.SqsSensor.poll_sqs")
+    @mock.patch("airflow.providers.amazon.aws.sensors.sqs.process_response")
+    @mock.patch("airflow.providers.amazon.aws.hooks.sqs.SqsHook.conn")
+    def test_fail_poke(self, conn, process_response, poll_sqs, soft_fail, expected_exception):
+        self.sensor = SqsSensor(
+            task_id="test_task_deferrable",
+            dag=self.dag,
+            sqs_queue=QUEUE_URL,
+            aws_conn_id="aws_default",
+            max_messages=1,
+            num_batches=3,
+            deferrable=True,
+            soft_fail=soft_fail,
+        )
+        response = "error message"
+        messages = [{"MessageId": "1", "ReceiptHandle": "test"}]
+        poll_sqs.return_value = response
+        process_response.return_value = messages
+        conn.delete_message_batch.return_value = response
+        error_message = f"Delete SQS Messages failed {response} for messages"
+        self.sensor.delete_message_on_reception = True
+        with pytest.raises(expected_exception, match=error_message):
+            self.sensor.poke(context={})