You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink target was not preserved in this plain-text copy).
Posted to commits@airflow.apache.org by vi...@apache.org on 2023/09/25 14:03:45 UTC
[airflow] branch main updated: Respect `soft_fail` argument when running `SqsSensor` (#34569)
This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 2b5c7676b5 Respect `soft_fail` argument when running `SqsSensor` (#34569)
2b5c7676b5 is described below
commit 2b5c7676b535bc5910c726c851181f9b87362994
Author: Utkarsh Sharma <ut...@gmail.com>
AuthorDate: Mon Sep 25 19:33:37 2023 +0530
Respect `soft_fail` argument when running `SqsSensor` (#34569)
---
airflow/providers/amazon/aws/sensors/sqs.py | 14 ++++++--
tests/providers/amazon/aws/sensors/test_sqs.py | 48 +++++++++++++++++++++++++-
2 files changed, 58 insertions(+), 4 deletions(-)
diff --git a/airflow/providers/amazon/aws/sensors/sqs.py b/airflow/providers/amazon/aws/sensors/sqs.py
index 33f989d03c..fce4363517 100644
--- a/airflow/providers/amazon/aws/sensors/sqs.py
+++ b/airflow/providers/amazon/aws/sensors/sqs.py
@@ -25,7 +25,7 @@ from deprecated import deprecated
from typing_extensions import Literal
from airflow.configuration import conf
-from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
+from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException
from airflow.providers.amazon.aws.hooks.sqs import SqsHook
from airflow.providers.amazon.aws.triggers.sqs import SqsSensorTrigger
from airflow.providers.amazon.aws.utils.sqs import process_response
@@ -145,7 +145,11 @@ class SqsSensor(BaseSensorOperator):
def execute_complete(self, context: Context, event: dict | None = None) -> None:
if event is None or event["status"] != "success":
- raise AirflowException(f"Trigger error: event is {event}")
+ # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
+ message = f"Trigger error: event is {event}"
+ if self.soft_fail:
+ raise AirflowSkipException(message)
+ raise AirflowException(message)
context["ti"].xcom_push(key="messages", value=event["message_batch"])
def poll_sqs(self, sqs_conn: BaseAwsConnection) -> Collection:
@@ -203,7 +207,11 @@ class SqsSensor(BaseSensorOperator):
response = self.hook.conn.delete_message_batch(QueueUrl=self.sqs_queue, Entries=entries)
if "Successful" not in response:
- raise AirflowException(f"Delete SQS Messages failed {response} for messages {messages}")
+ # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
+ error_message = f"Delete SQS Messages failed {response} for messages {messages}"
+ if self.soft_fail:
+ raise AirflowSkipException(error_message)
+ raise AirflowException(error_message)
if message_batch:
context["ti"].xcom_push(key="messages", value=message_batch)
return True
diff --git a/tests/providers/amazon/aws/sensors/test_sqs.py b/tests/providers/amazon/aws/sensors/test_sqs.py
index a9b03bd40a..fe3aadb430 100644
--- a/tests/providers/amazon/aws/sensors/test_sqs.py
+++ b/tests/providers/amazon/aws/sensors/test_sqs.py
@@ -24,7 +24,7 @@ from unittest.mock import patch
import pytest
from moto import mock_sqs
-from airflow.exceptions import AirflowException, TaskDeferred
+from airflow.exceptions import AirflowException, AirflowSkipException, TaskDeferred
from airflow.models.dag import DAG
from airflow.providers.amazon.aws.hooks.sqs import SqsHook
from airflow.providers.amazon.aws.sensors.sqs import SqsSensor
@@ -346,3 +346,49 @@ class TestSqsSensor:
)
with pytest.raises(TaskDeferred):
self.sensor.execute(None)
+
+ @pytest.mark.parametrize(
+ "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
+ )
+ def test_fail_execute_complete(self, soft_fail, expected_exception):
+ self.sensor = SqsSensor(
+ task_id="test_task_deferrable",
+ dag=self.dag,
+ sqs_queue=QUEUE_URL,
+ aws_conn_id="aws_default",
+ max_messages=1,
+ num_batches=3,
+ deferrable=True,
+ soft_fail=soft_fail,
+ )
+ event = {"status": "failed"}
+ message = f"Trigger error: event is {event}"
+ with pytest.raises(expected_exception, match=message):
+ self.sensor.execute_complete(context={}, event=event)
+
+ @pytest.mark.parametrize(
+ "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
+ )
+ @mock.patch("airflow.providers.amazon.aws.sensors.sqs.SqsSensor.poll_sqs")
+ @mock.patch("airflow.providers.amazon.aws.sensors.sqs.process_response")
+ @mock.patch("airflow.providers.amazon.aws.hooks.sqs.SqsHook.conn")
+ def test_fail_poke(self, conn, process_response, poll_sqs, soft_fail, expected_exception):
+ self.sensor = SqsSensor(
+ task_id="test_task_deferrable",
+ dag=self.dag,
+ sqs_queue=QUEUE_URL,
+ aws_conn_id="aws_default",
+ max_messages=1,
+ num_batches=3,
+ deferrable=True,
+ soft_fail=soft_fail,
+ )
+ response = "error message"
+ messages = [{"MessageId": "1", "ReceiptHandle": "test"}]
+ poll_sqs.return_value = response
+ process_response.return_value = messages
+ conn.delete_message_batch.return_value = response
+ error_message = f"Delete SQS Messages failed {response} for messages"
+ self.sensor.delete_message_on_reception = True
+ with pytest.raises(expected_exception, match=error_message):
+ self.sensor.poke(context={})