Posted to commits@airflow.apache.org by on...@apache.org on 2023/08/04 05:18:45 UTC
[airflow] branch main updated: Deferrable mode for Sqs Sensor (#32809)
This is an automated email from the ASF dual-hosted git repository.
onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 087d14ada2 Deferrable mode for Sqs Sensor (#32809)
087d14ada2 is described below
commit 087d14ada24e87fdf9db58a13acf0f2753191256
Author: Syed Hussain <10...@users.noreply.github.com>
AuthorDate: Thu Aug 3 22:18:38 2023 -0700
Deferrable mode for Sqs Sensor (#32809)
---
airflow/providers/amazon/aws/sensors/sqs.py | 94 +++++------
.../amazon/aws/{sensors => triggers}/sqs.py | 183 +++++++--------------
airflow/providers/amazon/aws/utils/sqs.py | 90 ++++++++++
airflow/providers/amazon/provider.yaml | 3 +
.../operators/sqs.rst | 1 +
tests/providers/amazon/aws/sensors/test_sqs.py | 15 +-
tests/providers/amazon/aws/triggers/test_sqs.py | 108 ++++++++++++
7 files changed, 322 insertions(+), 172 deletions(-)
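For orientation before the diffs: a minimal sketch of the new feature from a DAG author's point of view. The DAG id, task id, and queue URL below are illustrative, not taken from the commit; only the deferrable flag and the poke_interval-to-waiter_delay mapping are what this change introduces.

    from __future__ import annotations

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.sensors.sqs import SqsSensor

    # Hypothetical DAG and queue URL, for illustration only.
    with DAG(dag_id="example_sqs_deferrable", start_date=datetime(2023, 8, 1), schedule=None) as dag:
        wait_for_message = SqsSensor(
            task_id="wait_for_message",
            sqs_queue="https://sqs.us-east-1.amazonaws.com/123456789012/my-queue",
            deferrable=True,   # new in this commit: wait on the triggerer, not a worker slot
            poke_interval=30,  # passed to the trigger as waiter_delay (see the sensor diff below)
        )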
diff --git a/airflow/providers/amazon/aws/sensors/sqs.py b/airflow/providers/amazon/aws/sensors/sqs.py
index da7f7e513e..a61652acb2 100644
--- a/airflow/providers/amazon/aws/sensors/sqs.py
+++ b/airflow/providers/amazon/aws/sensors/sqs.py
@@ -18,20 +18,22 @@
"""Reads and then deletes the message from SQS queue."""
from __future__ import annotations
-import json
from functools import cached_property
from typing import TYPE_CHECKING, Any, Collection, Literal, Sequence
from deprecated import deprecated
-from jsonpath_ng import parse
+from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import BaseAwsConnection
from airflow.providers.amazon.aws.hooks.sqs import SqsHook
+from airflow.providers.amazon.aws.triggers.sqs import SqsSensorTrigger
+from airflow.providers.amazon.aws.utils.sqs import process_response
from airflow.sensors.base import BaseSensorOperator
if TYPE_CHECKING:
from airflow.utils.context import Context
+from datetime import timedelta
class SqsSensor(BaseSensorOperator):
@@ -70,6 +72,9 @@ class SqsSensor(BaseSensorOperator):
:param delete_message_on_reception: Defaults to `True`; the messages are deleted from the queue
as soon as they are consumed. Otherwise, the messages remain in the queue after consumption and
should be deleted manually.
+ :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires the
+ aiobotocore module to be installed.
+ (default: False, but can be overridden in config file by setting default_deferrable to True)
"""
@@ -88,6 +93,7 @@ class SqsSensor(BaseSensorOperator):
message_filtering_match_values: Any = None,
message_filtering_config: Any = None,
delete_message_on_reception: bool = True,
+ deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
):
super().__init__(**kwargs)
@@ -112,6 +118,34 @@ class SqsSensor(BaseSensorOperator):
raise TypeError("message_filtering_match_values must be specified for literal matching")
self.message_filtering_config = message_filtering_config
+ self.deferrable = deferrable
+
+ def execute(self, context: Context) -> Any:
+ if self.deferrable:
+ self.defer(
+ trigger=SqsSensorTrigger(
+ sqs_queue=self.sqs_queue,
+ aws_conn_id=self.aws_conn_id,
+ max_messages=self.max_messages,
+ num_batches=self.num_batches,
+ wait_time_seconds=self.wait_time_seconds,
+ visibility_timeout=self.visibility_timeout,
+ message_filtering=self.message_filtering,
+ message_filtering_match_values=self.message_filtering_match_values,
+ message_filtering_config=self.message_filtering_config,
+ delete_message_on_reception=self.delete_message_on_reception,
+ waiter_delay=int(self.poke_interval),
+ ),
+ method_name="execute_complete",
+ timeout=timedelta(seconds=self.timeout),
+ )
+ else:
+ super().execute(context=context)
+
+ def execute_complete(self, context: Context, event: dict | None = None) -> None:
+ if event is None or event["status"] != "success":
+ raise AirflowException(f"Trigger error: event is {event}")
+ context["ti"].xcom_push(key="messages", value=event["message_batch"])
def poll_sqs(self, sqs_conn: BaseAwsConnection) -> Collection:
"""
@@ -131,19 +165,7 @@ class SqsSensor(BaseSensorOperator):
receive_message_kwargs["VisibilityTimeout"] = self.visibility_timeout
response = sqs_conn.receive_message(**receive_message_kwargs)
-
- if "Messages" not in response:
- return []
-
- messages = response["Messages"]
- num_messages = len(messages)
- self.log.info("Received %d messages", num_messages)
-
- if num_messages and self.message_filtering:
- messages = self.filter_messages(messages)
- num_messages = len(messages)
- self.log.info("There are %d messages left after filtering", num_messages)
- return messages
+ return response
def poke(self, context: Context):
"""
@@ -156,7 +178,13 @@ class SqsSensor(BaseSensorOperator):
# perform multiple SQS call to retrieve messages in series
for _ in range(self.num_batches):
- messages = self.poll_sqs(sqs_conn=self.hook.conn)
+ response = self.poll_sqs(sqs_conn=self.hook.conn)
+ messages = process_response(
+ response,
+ self.message_filtering,
+ self.message_filtering_match_values,
+ self.message_filtering_config,
+ )
if not len(messages):
continue
@@ -191,37 +219,3 @@ class SqsSensor(BaseSensorOperator):
@cached_property
def hook(self) -> SqsHook:
return SqsHook(aws_conn_id=self.aws_conn_id)
-
- def filter_messages(self, messages):
- if self.message_filtering == "literal":
- return self.filter_messages_literal(messages)
- if self.message_filtering == "jsonpath":
- return self.filter_messages_jsonpath(messages)
- else:
- raise NotImplementedError("Override this method to define custom filters")
-
- def filter_messages_literal(self, messages):
- filtered_messages = []
- for message in messages:
- if message["Body"] in self.message_filtering_match_values:
- filtered_messages.append(message)
- return filtered_messages
-
- def filter_messages_jsonpath(self, messages):
- jsonpath_expr = parse(self.message_filtering_config)
- filtered_messages = []
- for message in messages:
- body = message["Body"]
- # Body is a string, deserialize to an object and then parse
- body = json.loads(body)
- results = jsonpath_expr.find(body)
- if not results:
- continue
- if self.message_filtering_match_values is None:
- filtered_messages.append(message)
- continue
- for result in results:
- if result.value in self.message_filtering_match_values:
- filtered_messages.append(message)
- break
- return filtered_messages
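One behavioral note on the hunk above: in deferrable mode the resumed task pushes the collected messages to XCom under the same "messages" key that the synchronous poke path uses (see execute_complete), so downstream consumers do not need to know which mode ran. A minimal sketch of a downstream reader; the task id is the illustrative one from the sketch near the top, not from the commit:

    from airflow.decorators import task

    @task
    def handle_messages(ti=None):
        # Same XCom key in both synchronous and deferrable mode.
        messages = ti.xcom_pull(task_ids="wait_for_message", key="messages")
        for message in messages or []:
            print(message["MessageId"], message["Body"])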
diff --git a/airflow/providers/amazon/aws/sensors/sqs.py b/airflow/providers/amazon/aws/triggers/sqs.py
similarity index 51%
copy from airflow/providers/amazon/aws/sensors/sqs.py
copy to airflow/providers/amazon/aws/triggers/sqs.py
index da7f7e513e..7e26b9c28f 100644
--- a/airflow/providers/amazon/aws/sensors/sqs.py
+++ b/airflow/providers/amazon/aws/triggers/sqs.py
@@ -1,4 +1,3 @@
-#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -15,42 +14,24 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Reads and then deletes the message from SQS queue."""
from __future__ import annotations
-import json
-from functools import cached_property
-from typing import TYPE_CHECKING, Any, Collection, Literal, Sequence
-
-from deprecated import deprecated
-from jsonpath_ng import parse
+import asyncio
+from typing import Any, AsyncIterator, Collection, Literal
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import BaseAwsConnection
from airflow.providers.amazon.aws.hooks.sqs import SqsHook
-from airflow.sensors.base import BaseSensorOperator
+from airflow.providers.amazon.aws.utils.sqs import process_response
+from airflow.triggers.base import BaseTrigger, TriggerEvent
-if TYPE_CHECKING:
- from airflow.utils.context import Context
-
-class SqsSensor(BaseSensorOperator):
+class SqsSensorTrigger(BaseTrigger):
"""
- Get messages from an Amazon SQS queue and then delete the messages from the queue.
-
- If deletion of messages fails, an AirflowException is thrown. Otherwise, the messages
- are pushed through XCom with the key ``messages``.
-
- By default,the sensor performs one and only one SQS call per poke, which limits the result to
- a maximum of 10 messages. However, the total number of SQS API calls per poke can be controlled
- by num_batches param.
-
- .. seealso::
- For more information on how to use this sensor, take a look at the guide:
- :ref:`howto/sensor:SqsSensor`
+ Asynchronously get messages from an Amazon SQS queue and then delete the messages from the queue.
+ :param sqs_queue: The SQS queue url
:param aws_conn_id: AWS connection id
- :param sqs_queue: The SQS queue url (templated)
:param max_messages: The maximum number of messages to retrieve for each poke (templated)
:param num_batches: The number of times the sensor will call the SQS API to receive messages (default: 1)
:param wait_time_seconds: The time in seconds to wait for receiving messages (default: 1 second)
@@ -70,15 +51,12 @@ class SqsSensor(BaseSensorOperator):
:param delete_message_on_reception: Defaults to `True`; the messages are deleted from the queue
as soon as they are consumed. Otherwise, the messages remain in the queue after consumption and
should be deleted manually.
-
+ :param waiter_delay: The time in seconds to wait between calls to the SQS API to receive messages.
"""
- template_fields: Sequence[str] = ("sqs_queue", "max_messages", "message_filtering_config")
-
def __init__(
self,
- *,
- sqs_queue,
+ sqs_queue: str,
aws_conn_id: str = "aws_default",
max_messages: int = 5,
num_batches: int = 1,
@@ -88,36 +66,47 @@ class SqsSensor(BaseSensorOperator):
message_filtering_match_values: Any = None,
message_filtering_config: Any = None,
delete_message_on_reception: bool = True,
- **kwargs,
+ waiter_delay: int = 60,
):
- super().__init__(**kwargs)
self.sqs_queue = sqs_queue
self.aws_conn_id = aws_conn_id
self.max_messages = max_messages
self.num_batches = num_batches
self.wait_time_seconds = wait_time_seconds
self.visibility_timeout = visibility_timeout
-
self.message_filtering = message_filtering
-
self.delete_message_on_reception = delete_message_on_reception
-
- if message_filtering_match_values is not None:
- if not isinstance(message_filtering_match_values, set):
- message_filtering_match_values = set(message_filtering_match_values)
self.message_filtering_match_values = message_filtering_match_values
-
- if self.message_filtering == "literal":
- if self.message_filtering_match_values is None:
- raise TypeError("message_filtering_match_values must be specified for literal matching")
-
self.message_filtering_config = message_filtering_config
+ self.waiter_delay = waiter_delay
+
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ return (
+ self.__class__.__module__ + "." + self.__class__.__qualname__,
+ {
+ "sqs_queue": self.sqs_queue,
+ "aws_conn_id": self.aws_conn_id,
+ "max_messages": self.max_messages,
+ "num_batches": self.num_batches,
+ "wait_time_seconds": self.wait_time_seconds,
+ "visibility_timeout": self.visibility_timeout,
+ "message_filtering": self.message_filtering,
+ "delete_message_on_reception": self.delete_message_on_reception,
+ "message_filtering_match_values": self.message_filtering_match_values,
+ "message_filtering_config": self.message_filtering_config,
+ "waiter_delay": self.waiter_delay,
+ },
+ )
+
+ @property
+ def hook(self) -> SqsHook:
+ return SqsHook(aws_conn_id=self.aws_conn_id)
- def poll_sqs(self, sqs_conn: BaseAwsConnection) -> Collection:
+ async def poll_sqs(self, client: BaseAwsConnection) -> Collection:
"""
- Poll SQS queue to retrieve messages.
+ Asynchronously poll SQS queue to retrieve messages.
- :param sqs_conn: SQS connection
+ :param client: SQS connection
:return: A list of messages retrieved from SQS
"""
self.log.info("SqsSensor checking for message on queue: %s", self.sqs_queue)
@@ -130,98 +119,50 @@ class SqsSensor(BaseSensorOperator):
if self.visibility_timeout is not None:
receive_message_kwargs["VisibilityTimeout"] = self.visibility_timeout
- response = sqs_conn.receive_message(**receive_message_kwargs)
+ response = await client.receive_message(**receive_message_kwargs)
+ return response
- if "Messages" not in response:
- return []
-
- messages = response["Messages"]
- num_messages = len(messages)
- self.log.info("Received %d messages", num_messages)
-
- if num_messages and self.message_filtering:
- messages = self.filter_messages(messages)
- num_messages = len(messages)
- self.log.info("There are %d messages left after filtering", num_messages)
- return messages
-
- def poke(self, context: Context):
- """
- Check subscribed queue for messages and write them to xcom with the ``messages`` key.
-
- :param context: the context object
- :return: ``True`` if message is available or ``False``
- """
+ async def poke(self, client: Any):
message_batch: list[Any] = []
-
- # perform multiple SQS call to retrieve messages in series
for _ in range(self.num_batches):
- messages = self.poll_sqs(sqs_conn=self.hook.conn)
-
- if not len(messages):
+ self.log.info("starting call to poll sqs")
+ response = await self.poll_sqs(client=client)
+ messages = process_response(
+ response,
+ self.message_filtering,
+ self.message_filtering_match_values,
+ self.message_filtering_config,
+ )
+
+ if not messages:
continue
message_batch.extend(messages)
if self.delete_message_on_reception:
-
self.log.info("Deleting %d messages", len(messages))
entries = [
{"Id": message["MessageId"], "ReceiptHandle": message["ReceiptHandle"]}
for message in messages
]
- response = self.hook.conn.delete_message_batch(QueueUrl=self.sqs_queue, Entries=entries)
+ response = await client.delete_message_batch(QueueUrl=self.sqs_queue, Entries=entries)
if "Successful" not in response:
raise AirflowException(
- "Delete SQS Messages failed " + str(response) + " for messages " + str(messages)
+ f"Delete SQS Messages failed {str(response)} for messages {str(messages)}"
)
- if not len(message_batch):
- return False
-
- context["ti"].xcom_push(key="messages", value=message_batch)
- return True
- @deprecated(reason="use `hook` property instead.")
- def get_hook(self) -> SqsHook:
- """Create and return an SqsHook."""
- return self.hook
+ return message_batch
- @cached_property
- def hook(self) -> SqsHook:
- return SqsHook(aws_conn_id=self.aws_conn_id)
-
- def filter_messages(self, messages):
- if self.message_filtering == "literal":
- return self.filter_messages_literal(messages)
- if self.message_filtering == "jsonpath":
- return self.filter_messages_jsonpath(messages)
- else:
- raise NotImplementedError("Override this method to define custom filters")
-
- def filter_messages_literal(self, messages):
- filtered_messages = []
- for message in messages:
- if message["Body"] in self.message_filtering_match_values:
- filtered_messages.append(message)
- return filtered_messages
-
- def filter_messages_jsonpath(self, messages):
- jsonpath_expr = parse(self.message_filtering_config)
- filtered_messages = []
- for message in messages:
- body = message["Body"]
- # Body is a string, deserialize to an object and then parse
- body = json.loads(body)
- results = jsonpath_expr.find(body)
- if not results:
- continue
- if self.message_filtering_match_values is None:
- filtered_messages.append(message)
- continue
- for result in results:
- if result.value in self.message_filtering_match_values:
- filtered_messages.append(message)
+ async def run(self) -> AsyncIterator[TriggerEvent]:
+ while True:
+ # This loop will run indefinitely until the timeout, which is set in the self.defer
+ # method, is reached.
+ async with self.hook.async_conn as client:
+ result = await self.poke(client=client)
+ if result:
+ yield TriggerEvent({"status": "success", "message_batch": result})
break
- return filtered_messages
+ else:
+ await asyncio.sleep(self.waiter_delay)
diff --git a/airflow/providers/amazon/aws/utils/sqs.py b/airflow/providers/amazon/aws/utils/sqs.py
new file mode 100644
index 0000000000..2b081e5259
--- /dev/null
+++ b/airflow/providers/amazon/aws/utils/sqs.py
@@ -0,0 +1,90 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import json
+import logging
+from typing import Any, Literal
+
+from jsonpath_ng import parse
+
+log = logging.getLogger(__name__)
+
+
+def process_response(
+ response: Any,
+ message_filtering: Literal["literal", "jsonpath"] | None = None,
+ message_filtering_match_values: Any = None,
+ message_filtering_config: Any = None,
+) -> Any:
+ """
+ Process the response from SQS.
+
+ :param response: The response from SQS
+ :return: The processed response
+ """
+ if not isinstance(response, dict):
+ return []
+ elif "Messages" not in response:
+ return []
+
+ messages = response["Messages"]
+ num_messages = len(messages)
+ log.info("Received %d messages", num_messages)
+
+ if num_messages and message_filtering:
+ messages = filter_messages(
+ messages, message_filtering, message_filtering_match_values, message_filtering_config
+ )
+ num_messages = len(messages)
+ log.info("There are %d messages left after filtering", num_messages)
+ return messages
+
+
+def filter_messages(
+ messages, message_filtering, message_filtering_match_values, message_filtering_config
+) -> list[Any]:
+ if message_filtering == "literal":
+ return filter_messages_literal(messages, message_filtering_match_values)
+ if message_filtering == "jsonpath":
+ return filter_messages_jsonpath(messages, message_filtering_match_values, message_filtering_config)
+ else:
+ raise NotImplementedError("Override this method to define custom filters")
+
+
+def filter_messages_literal(messages, message_filtering_match_values) -> list[Any]:
+ return [message for message in messages if message["Body"] in message_filtering_match_values]
+
+
+def filter_messages_jsonpath(messages, message_filtering_match_values, message_filtering_config) -> list[Any]:
+ jsonpath_expr = parse(message_filtering_config)
+ filtered_messages = []
+ for message in messages:
+ body = message["Body"]
+ # Body is a string, deserialize to an object and then parse
+ body = json.loads(body)
+ results = jsonpath_expr.find(body)
+ if not results:
+ continue
+ if message_filtering_match_values is None:
+ filtered_messages.append(message)
+ continue
+ for result in results:
+ if result.value in message_filtering_match_values:
+ filtered_messages.append(message)
+ break
+ return filtered_messages
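Because the filtering logic now lives in plain module-level functions, it can be exercised directly against a hand-built response dict, with no hook or sensor involved. A quick sketch of both filter modes; the message bodies and receipt handles are made up:

    from airflow.providers.amazon.aws.utils.sqs import process_response

    # Literal matching compares the raw Body string against the match values.
    literal_response = {
        "Messages": [
            {"MessageId": "1", "Body": "keep-me", "ReceiptHandle": "rh-1"},
            {"MessageId": "2", "Body": "drop-me", "ReceiptHandle": "rh-2"},
        ]
    }
    assert len(process_response(literal_response, "literal", {"keep-me"})) == 1

    # JSONPath matching deserializes each Body as JSON before evaluating the expression.
    jsonpath_response = {
        "Messages": [
            {"MessageId": "3", "Body": '{"status": "done"}', "ReceiptHandle": "rh-3"},
            {"MessageId": "4", "Body": '{"status": "pending"}', "ReceiptHandle": "rh-4"},
        ]
    }
    assert len(process_response(jsonpath_response, "jsonpath", {"done"}, "status")) == 1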
diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml
index a0af095b40..76d1cb376e 100644
--- a/airflow/providers/amazon/provider.yaml
+++ b/airflow/providers/amazon/provider.yaml
@@ -572,6 +572,9 @@ triggers:
- integration-name: AWS Step Functions
python-modules:
- airflow.providers.amazon.aws.triggers.step_function
+ - integration-name: Amazon Simple Queue Service (SQS)
+ python-modules:
+ - airflow.providers.amazon.aws.triggers.sqs
transfers:
- source-integration-name: Amazon DynamoDB
diff --git a/docs/apache-airflow-providers-amazon/operators/sqs.rst b/docs/apache-airflow-providers-amazon/operators/sqs.rst
index 13c806626e..9f72a2eff8 100644
--- a/docs/apache-airflow-providers-amazon/operators/sqs.rst
+++ b/docs/apache-airflow-providers-amazon/operators/sqs.rst
@@ -61,6 +61,7 @@ Read messages from an Amazon SQS queue
To read messages from an Amazon SQS queue until exhausted use the
:class:`~airflow.providers.amazon.aws.sensors.sqs.SqsSensor`
+This sensor can also be run in deferrable mode by setting the ``deferrable`` parameter to ``True``.
.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_sqs.py
:language: python
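Per the new parameter doc, the flag's default is not a hard-coded False: it re-reads the same config key shown in the sensor diff above, so deferral can be enabled instance-wide without touching DAGs. The snippet below simply re-runs the lookup the sensor's constructor default uses:

    from airflow.configuration import conf

    # The constructor default added in this commit is exactly this lookup.
    print(conf.getboolean("operators", "default_deferrable", fallback=False))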
diff --git a/tests/providers/amazon/aws/sensors/test_sqs.py b/tests/providers/amazon/aws/sensors/test_sqs.py
index e6175c61d3..a9b03bd40a 100644
--- a/tests/providers/amazon/aws/sensors/test_sqs.py
+++ b/tests/providers/amazon/aws/sensors/test_sqs.py
@@ -24,7 +24,7 @@ from unittest.mock import patch
import pytest
from moto import mock_sqs
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, TaskDeferred
from airflow.models.dag import DAG
from airflow.providers.amazon.aws.hooks.sqs import SqsHook
from airflow.providers.amazon.aws.sensors.sqs import SqsSensor
@@ -333,3 +333,16 @@ class TestSqsSensor:
assert f"'Body': '{message}'" in str(
self.mock_context["ti"].method_calls
), "context call should contain message '{message}'"
+
+ def test_sqs_deferrable(self):
+ self.sensor = SqsSensor(
+ task_id="test_task_deferrable",
+ dag=self.dag,
+ sqs_queue=QUEUE_URL,
+ aws_conn_id="aws_default",
+ max_messages=1,
+ num_batches=3,
+ deferrable=True,
+ )
+ with pytest.raises(TaskDeferred):
+ self.sensor.execute(None)
diff --git a/tests/providers/amazon/aws/triggers/test_sqs.py b/tests/providers/amazon/aws/triggers/test_sqs.py
new file mode 100644
index 0000000000..74ffd837a1
--- /dev/null
+++ b/tests/providers/amazon/aws/triggers/test_sqs.py
@@ -0,0 +1,108 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import AsyncMock
+
+import pytest
+
+from airflow.providers.amazon.aws.triggers.sqs import SqsSensorTrigger
+
+TEST_SQS_QUEUE = "test-sqs-queue"
+TEST_AWS_CONN_ID = "test-aws-conn-id"
+TEST_MAX_MESSAGES = 1
+TEST_NUM_BATCHES = 1
+TEST_WAIT_TIME_SECONDS = 1
+TEST_VISIBILITY_TIMEOUT = 1
+TEST_MESSAGE_FILTERING = "literal"
+TEST_MESSAGE_FILTERING_MATCH_VALUES = "test"
+TEST_MESSAGE_FILTERING_CONFIG = "test-message-filtering-config"
+TEST_DELETE_MESSAGE_ON_RECEPTION = False
+TEST_WAITER_DELAY = 1
+
+trigger = SqsSensorTrigger(
+ sqs_queue=TEST_SQS_QUEUE,
+ aws_conn_id=TEST_AWS_CONN_ID,
+ max_messages=TEST_MAX_MESSAGES,
+ num_batches=TEST_NUM_BATCHES,
+ wait_time_seconds=TEST_WAIT_TIME_SECONDS,
+ visibility_timeout=TEST_VISIBILITY_TIMEOUT,
+ message_filtering=TEST_MESSAGE_FILTERING,
+ message_filtering_match_values=TEST_MESSAGE_FILTERING_MATCH_VALUES,
+ message_filtering_config=TEST_MESSAGE_FILTERING_CONFIG,
+ delete_message_on_reception=TEST_DELETE_MESSAGE_ON_RECEPTION,
+ waiter_delay=TEST_WAITER_DELAY,
+)
+
+
+class TestSqsTriggers:
+ @pytest.mark.parametrize(
+ "trigger",
+ [
+ trigger,
+ ],
+ )
+ def test_serialize_recreate(self, trigger):
+ class_path, args = trigger.serialize()
+
+ class_name = class_path.split(".")[-1]
+ clazz = globals()[class_name]
+ instance = clazz(**args)
+
+ class_path2, args2 = instance.serialize()
+
+ assert class_path == class_path2
+ assert args == args2
+
+ @pytest.mark.asyncio
+ async def test_poke(self):
+ sqs_trigger = trigger
+ mock_client = AsyncMock()
+ message = {
+ "MessageId": "test_message_id",
+ "Body": "test",
+ }
+ mock_response = {
+ "Messages": [message],
+ }
+ mock_client.receive_message.return_value = mock_response
+ messages = await sqs_trigger.poke(client=mock_client)
+ assert messages[0] == message
+
+ @pytest.mark.asyncio
+ async def test_poke_filtered_message(self):
+ sqs_trigger = trigger
+ mock_client = AsyncMock()
+ message = {
+ "MessageId": "test_message_id",
+ "Body": "This will be filtered out",
+ }
+ mock_response = {
+ "Messages": [message],
+ }
+ mock_client.receive_message.return_value = mock_response
+ messages = await sqs_trigger.poke(client=mock_client)
+ assert len(messages) == 0
+
+ @pytest.mark.asyncio
+ async def test_poke_no_messages(self):
+ sqs_trigger = trigger
+ mock_client = AsyncMock()
+ mock_response = {"Messages": []}
+ mock_client.receive_message.return_value = mock_response
+ messages = await sqs_trigger.poke(client=mock_client)
+ assert len(messages) == 0