You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by on...@apache.org on 2023/08/04 05:18:45 UTC

[airflow] branch main updated: Deferrable mode for Sqs Sensor (#32809)

This is an automated email from the ASF dual-hosted git repository.

onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 087d14ada2 Deferrable mode for Sqs Sensor (#32809)
087d14ada2 is described below

commit 087d14ada24e87fdf9db58a13acf0f2753191256
Author: Syed Hussain <10...@users.noreply.github.com>
AuthorDate: Thu Aug 3 22:18:38 2023 -0700

    Deferrable mode for Sqs Sensor (#32809)
---
 airflow/providers/amazon/aws/sensors/sqs.py        |  94 +++++------
 .../amazon/aws/{sensors => triggers}/sqs.py        | 183 +++++++--------------
 airflow/providers/amazon/aws/utils/sqs.py          |  90 ++++++++++
 airflow/providers/amazon/provider.yaml             |   3 +
 .../operators/sqs.rst                              |   1 +
 tests/providers/amazon/aws/sensors/test_sqs.py     |  15 +-
 tests/providers/amazon/aws/triggers/test_sqs.py    | 108 ++++++++++++
 7 files changed, 322 insertions(+), 172 deletions(-)

diff --git a/airflow/providers/amazon/aws/sensors/sqs.py b/airflow/providers/amazon/aws/sensors/sqs.py
index da7f7e513e..a61652acb2 100644
--- a/airflow/providers/amazon/aws/sensors/sqs.py
+++ b/airflow/providers/amazon/aws/sensors/sqs.py
@@ -18,20 +18,22 @@
 """Reads and then deletes the message from SQS queue."""
 from __future__ import annotations
 
-import json
 from functools import cached_property
 from typing import TYPE_CHECKING, Any, Collection, Literal, Sequence
 
 from deprecated import deprecated
-from jsonpath_ng import parse
 
+from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.base_aws import BaseAwsConnection
 from airflow.providers.amazon.aws.hooks.sqs import SqsHook
+from airflow.providers.amazon.aws.triggers.sqs import SqsSensorTrigger
+from airflow.providers.amazon.aws.utils.sqs import process_response
 from airflow.sensors.base import BaseSensorOperator
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
+from datetime import timedelta
 
 
 class SqsSensor(BaseSensorOperator):
@@ -70,6 +72,9 @@ class SqsSensor(BaseSensorOperator):
     :param delete_message_on_reception: Default to `True`, the messages are deleted from the queue
         as soon as being consumed. Otherwise, the messages remain in the queue after consumption and
         should be deleted manually.
+    :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore
+        module to be installed.
+        (default: False, but can be overridden via the ``[operators] default_deferrable`` config option)
 
     """
 
@@ -88,6 +93,7 @@ class SqsSensor(BaseSensorOperator):
         message_filtering_match_values: Any = None,
         message_filtering_config: Any = None,
         delete_message_on_reception: bool = True,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -112,6 +118,34 @@ class SqsSensor(BaseSensorOperator):
                 raise TypeError("message_filtering_match_values must be specified for literal matching")
 
         self.message_filtering_config = message_filtering_config
+        self.deferrable = deferrable
+
+    def execute(self, context: Context) -> Any:
+        if self.deferrable:
+            self.defer(
+                trigger=SqsSensorTrigger(
+                    sqs_queue=self.sqs_queue,
+                    aws_conn_id=self.aws_conn_id,
+                    max_messages=self.max_messages,
+                    num_batches=self.num_batches,
+                    wait_time_seconds=self.wait_time_seconds,
+                    visibility_timeout=self.visibility_timeout,
+                    message_filtering=self.message_filtering,
+                    message_filtering_match_values=self.message_filtering_match_values,
+                    message_filtering_config=self.message_filtering_config,
+                    delete_message_on_reception=self.delete_message_on_reception,
+                    waiter_delay=int(self.poke_interval),
+                ),
+                method_name="execute_complete",
+                timeout=timedelta(seconds=self.timeout),
+            )
+        else:
+            super().execute(context=context)
+
+    def execute_complete(self, context: Context, event: dict | None = None) -> None:
+        if event is None or event["status"] != "success":
+            raise AirflowException(f"Trigger error: event is {event}")
+        context["ti"].xcom_push(key="messages", value=event["message_batch"])
 
     def poll_sqs(self, sqs_conn: BaseAwsConnection) -> Collection:
         """
@@ -131,19 +165,7 @@ class SqsSensor(BaseSensorOperator):
             receive_message_kwargs["VisibilityTimeout"] = self.visibility_timeout
 
         response = sqs_conn.receive_message(**receive_message_kwargs)
-
-        if "Messages" not in response:
-            return []
-
-        messages = response["Messages"]
-        num_messages = len(messages)
-        self.log.info("Received %d messages", num_messages)
-
-        if num_messages and self.message_filtering:
-            messages = self.filter_messages(messages)
-            num_messages = len(messages)
-            self.log.info("There are %d messages left after filtering", num_messages)
-        return messages
+        return response
 
     def poke(self, context: Context):
         """
@@ -156,7 +178,13 @@ class SqsSensor(BaseSensorOperator):
 
         # perform multiple SQS call to retrieve messages in series
         for _ in range(self.num_batches):
-            messages = self.poll_sqs(sqs_conn=self.hook.conn)
+            response = self.poll_sqs(sqs_conn=self.hook.conn)
+            messages = process_response(
+                response,
+                self.message_filtering,
+                self.message_filtering_match_values,
+                self.message_filtering_config,
+            )
 
             if not len(messages):
                 continue
@@ -191,37 +219,3 @@ class SqsSensor(BaseSensorOperator):
     @cached_property
     def hook(self) -> SqsHook:
         return SqsHook(aws_conn_id=self.aws_conn_id)
-
-    def filter_messages(self, messages):
-        if self.message_filtering == "literal":
-            return self.filter_messages_literal(messages)
-        if self.message_filtering == "jsonpath":
-            return self.filter_messages_jsonpath(messages)
-        else:
-            raise NotImplementedError("Override this method to define custom filters")
-
-    def filter_messages_literal(self, messages):
-        filtered_messages = []
-        for message in messages:
-            if message["Body"] in self.message_filtering_match_values:
-                filtered_messages.append(message)
-        return filtered_messages
-
-    def filter_messages_jsonpath(self, messages):
-        jsonpath_expr = parse(self.message_filtering_config)
-        filtered_messages = []
-        for message in messages:
-            body = message["Body"]
-            # Body is a string, deserialize to an object and then parse
-            body = json.loads(body)
-            results = jsonpath_expr.find(body)
-            if not results:
-                continue
-            if self.message_filtering_match_values is None:
-                filtered_messages.append(message)
-                continue
-            for result in results:
-                if result.value in self.message_filtering_match_values:
-                    filtered_messages.append(message)
-                    break
-        return filtered_messages
diff --git a/airflow/providers/amazon/aws/sensors/sqs.py b/airflow/providers/amazon/aws/triggers/sqs.py
similarity index 51%
copy from airflow/providers/amazon/aws/sensors/sqs.py
copy to airflow/providers/amazon/aws/triggers/sqs.py
index da7f7e513e..7e26b9c28f 100644
--- a/airflow/providers/amazon/aws/sensors/sqs.py
+++ b/airflow/providers/amazon/aws/triggers/sqs.py
@@ -1,4 +1,3 @@
-#
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
@@ -15,42 +14,24 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""Reads and then deletes the message from SQS queue."""
 from __future__ import annotations
 
-import json
-from functools import cached_property
-from typing import TYPE_CHECKING, Any, Collection, Literal, Sequence
-
-from deprecated import deprecated
-from jsonpath_ng import parse
+import asyncio
+from typing import Any, AsyncIterator, Collection, Literal
 
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.base_aws import BaseAwsConnection
 from airflow.providers.amazon.aws.hooks.sqs import SqsHook
-from airflow.sensors.base import BaseSensorOperator
+from airflow.providers.amazon.aws.utils.sqs import process_response
+from airflow.triggers.base import BaseTrigger, TriggerEvent
 
-if TYPE_CHECKING:
-    from airflow.utils.context import Context
 
-
-class SqsSensor(BaseSensorOperator):
+class SqsSensorTrigger(BaseTrigger):
     """
-    Get messages from an Amazon SQS queue and then delete the messages from the queue.
-
-    If deletion of messages fails, an AirflowException is thrown. Otherwise, the messages
-    are pushed through XCom with the key ``messages``.
-
-    By default,the sensor performs one and only one SQS call per poke, which limits the result to
-    a maximum of 10 messages. However, the total number of SQS API calls per poke can be controlled
-    by num_batches param.
-
-    .. seealso::
-        For more information on how to use this sensor, take a look at the guide:
-        :ref:`howto/sensor:SqsSensor`
+    Asynchronously get messages from an Amazon SQS queue and then delete the messages from the queue.
 
+    :param sqs_queue: The SQS queue url
     :param aws_conn_id: AWS connection id
-    :param sqs_queue: The SQS queue url (templated)
     :param max_messages: The maximum number of messages to retrieve for each poke (templated)
     :param num_batches: The number of times the sensor will call the SQS API to receive messages (default: 1)
     :param wait_time_seconds: The time in seconds to wait for receiving messages (default: 1 second)
@@ -70,15 +51,12 @@ class SqsSensor(BaseSensorOperator):
     :param delete_message_on_reception: Default to `True`, the messages are deleted from the queue
         as soon as being consumed. Otherwise, the messages remain in the queue after consumption and
         should be deleted manually.
-
+    :param waiter_delay: The time in seconds to wait before polling again when no messages were received.
     """
 
-    template_fields: Sequence[str] = ("sqs_queue", "max_messages", "message_filtering_config")
-
     def __init__(
         self,
-        *,
-        sqs_queue,
+        sqs_queue: str,
         aws_conn_id: str = "aws_default",
         max_messages: int = 5,
         num_batches: int = 1,
@@ -88,36 +66,47 @@ class SqsSensor(BaseSensorOperator):
         message_filtering_match_values: Any = None,
         message_filtering_config: Any = None,
         delete_message_on_reception: bool = True,
-        **kwargs,
+        waiter_delay: int = 60,
     ):
-        super().__init__(**kwargs)
         self.sqs_queue = sqs_queue
         self.aws_conn_id = aws_conn_id
         self.max_messages = max_messages
         self.num_batches = num_batches
         self.wait_time_seconds = wait_time_seconds
         self.visibility_timeout = visibility_timeout
-
         self.message_filtering = message_filtering
-
         self.delete_message_on_reception = delete_message_on_reception
-
-        if message_filtering_match_values is not None:
-            if not isinstance(message_filtering_match_values, set):
-                message_filtering_match_values = set(message_filtering_match_values)
         self.message_filtering_match_values = message_filtering_match_values
-
-        if self.message_filtering == "literal":
-            if self.message_filtering_match_values is None:
-                raise TypeError("message_filtering_match_values must be specified for literal matching")
-
         self.message_filtering_config = message_filtering_config
+        self.waiter_delay = waiter_delay
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        return (
+            self.__class__.__module__ + "." + self.__class__.__qualname__,
+            {
+                "sqs_queue": self.sqs_queue,
+                "aws_conn_id": self.aws_conn_id,
+                "max_messages": self.max_messages,
+                "num_batches": self.num_batches,
+                "wait_time_seconds": self.wait_time_seconds,
+                "visibility_timeout": self.visibility_timeout,
+                "message_filtering": self.message_filtering,
+                "delete_message_on_reception": self.delete_message_on_reception,
+                "message_filtering_match_values": self.message_filtering_match_values,
+                "message_filtering_config": self.message_filtering_config,
+                "waiter_delay": self.waiter_delay,
+            },
+        )
+
+    @property
+    def hook(self) -> SqsHook:
+        return SqsHook(aws_conn_id=self.aws_conn_id)
 
-    def poll_sqs(self, sqs_conn: BaseAwsConnection) -> Collection:
+    async def poll_sqs(self, client: BaseAwsConnection) -> Collection:
         """
-        Poll SQS queue to retrieve messages.
+        Asynchronously poll SQS queue to retrieve messages.
 
-        :param sqs_conn: SQS connection
+        :param client: SQS connection
         :return: A list of messages retrieved from SQS
         """
         self.log.info("SqsSensor checking for message on queue: %s", self.sqs_queue)
@@ -130,98 +119,50 @@ class SqsSensor(BaseSensorOperator):
         if self.visibility_timeout is not None:
             receive_message_kwargs["VisibilityTimeout"] = self.visibility_timeout
 
-        response = sqs_conn.receive_message(**receive_message_kwargs)
+        response = await client.receive_message(**receive_message_kwargs)
+        return response
 
-        if "Messages" not in response:
-            return []
-
-        messages = response["Messages"]
-        num_messages = len(messages)
-        self.log.info("Received %d messages", num_messages)
-
-        if num_messages and self.message_filtering:
-            messages = self.filter_messages(messages)
-            num_messages = len(messages)
-            self.log.info("There are %d messages left after filtering", num_messages)
-        return messages
-
-    def poke(self, context: Context):
-        """
-        Check subscribed queue for messages and write them to xcom with the ``messages`` key.
-
-        :param context: the context object
-        :return: ``True`` if message is available or ``False``
-        """
+    async def poke(self, client: Any):
         message_batch: list[Any] = []
-
-        # perform multiple SQS call to retrieve messages in series
         for _ in range(self.num_batches):
-            messages = self.poll_sqs(sqs_conn=self.hook.conn)
-
-            if not len(messages):
+            self.log.info("starting call to poll sqs")
+            response = await self.poll_sqs(client=client)
+            messages = process_response(
+                response,
+                self.message_filtering,
+                self.message_filtering_match_values,
+                self.message_filtering_config,
+            )
+
+            if not messages:
                 continue
 
             message_batch.extend(messages)
 
             if self.delete_message_on_reception:
-
                 self.log.info("Deleting %d messages", len(messages))
 
                 entries = [
                     {"Id": message["MessageId"], "ReceiptHandle": message["ReceiptHandle"]}
                     for message in messages
                 ]
-                response = self.hook.conn.delete_message_batch(QueueUrl=self.sqs_queue, Entries=entries)
+                response = await client.delete_message_batch(QueueUrl=self.sqs_queue, Entries=entries)
 
                 if "Successful" not in response:
                     raise AirflowException(
-                        "Delete SQS Messages failed " + str(response) + " for messages " + str(messages)
+                        f"Delete SQS Messages failed {str(response)} for messages {str(messages)}"
                     )
-        if not len(message_batch):
-            return False
-
-        context["ti"].xcom_push(key="messages", value=message_batch)
-        return True
 
-    @deprecated(reason="use `hook` property instead.")
-    def get_hook(self) -> SqsHook:
-        """Create and return an SqsHook."""
-        return self.hook
+        return message_batch
 
-    @cached_property
-    def hook(self) -> SqsHook:
-        return SqsHook(aws_conn_id=self.aws_conn_id)
-
-    def filter_messages(self, messages):
-        if self.message_filtering == "literal":
-            return self.filter_messages_literal(messages)
-        if self.message_filtering == "jsonpath":
-            return self.filter_messages_jsonpath(messages)
-        else:
-            raise NotImplementedError("Override this method to define custom filters")
-
-    def filter_messages_literal(self, messages):
-        filtered_messages = []
-        for message in messages:
-            if message["Body"] in self.message_filtering_match_values:
-                filtered_messages.append(message)
-        return filtered_messages
-
-    def filter_messages_jsonpath(self, messages):
-        jsonpath_expr = parse(self.message_filtering_config)
-        filtered_messages = []
-        for message in messages:
-            body = message["Body"]
-            # Body is a string, deserialize to an object and then parse
-            body = json.loads(body)
-            results = jsonpath_expr.find(body)
-            if not results:
-                continue
-            if self.message_filtering_match_values is None:
-                filtered_messages.append(message)
-                continue
-            for result in results:
-                if result.value in self.message_filtering_match_values:
-                    filtered_messages.append(message)
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        while True:
+            # This loop will run indefinitely until the timeout, which is set in the self.defer
+            # method, is reached.
+            async with self.hook.async_conn as client:
+                result = await self.poke(client=client)
+                if result:
+                    yield TriggerEvent({"status": "success", "message_batch": result})
                     break
-        return filtered_messages
+                else:
+                    await asyncio.sleep(self.waiter_delay)
diff --git a/airflow/providers/amazon/aws/utils/sqs.py b/airflow/providers/amazon/aws/utils/sqs.py
new file mode 100644
index 0000000000..2b081e5259
--- /dev/null
+++ b/airflow/providers/amazon/aws/utils/sqs.py
@@ -0,0 +1,90 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import json
+import logging
+from typing import Any, Literal
+
+from jsonpath_ng import parse
+
+log = logging.getLogger(__name__)
+
+
+def process_response(
+    response: Any,
+    message_filtering: Literal["literal", "jsonpath"] | None = None,
+    message_filtering_match_values: Any = None,
+    message_filtering_config: Any = None,
+) -> Any:
+    """
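+    Extract the received messages from an SQS ``receive_message`` response.
+
+    :param response: The raw ``receive_message`` response from SQS
+    :return: The list of messages (empty if none), filtered when ``message_filtering`` is set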
+    """
+    if not isinstance(response, dict):
+        return []
+    elif "Messages" not in response:
+        return []
+
+    messages = response["Messages"]
+    num_messages = len(messages)
+    log.info("Received %d messages", num_messages)
+
+    if num_messages and message_filtering:
+        messages = filter_messages(
+            messages, message_filtering, message_filtering_match_values, message_filtering_config
+        )
+        num_messages = len(messages)
+        log.info("There are %d messages left after filtering", num_messages)
+    return messages
+
+
+def filter_messages(
+    messages, message_filtering, message_filtering_match_values, message_filtering_config
+) -> list[Any]:
+    if message_filtering == "literal":
+        return filter_messages_literal(messages, message_filtering_match_values)
+    if message_filtering == "jsonpath":
+        return filter_messages_jsonpath(messages, message_filtering_match_values, message_filtering_config)
+    else:
+        raise NotImplementedError("Override this method to define custom filters")
+
+
+def filter_messages_literal(messages, message_filtering_match_values) -> list[Any]:
+    return [message for message in messages if message["Body"] in message_filtering_match_values]
+
+
+def filter_messages_jsonpath(messages, message_filtering_match_values, message_filtering_config) -> list[Any]:
+    jsonpath_expr = parse(message_filtering_config)
+    filtered_messages = []
+    for message in messages:
+        body = message["Body"]
+        # Body is a string, deserialize to an object and then parse
+        body = json.loads(body)
+        results = jsonpath_expr.find(body)
+        if not results:
+            continue
+        if message_filtering_match_values is None:
+            filtered_messages.append(message)
+            continue
+        for result in results:
+            if result.value in message_filtering_match_values:
+                filtered_messages.append(message)
+                break
+    return filtered_messages
diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml
index a0af095b40..76d1cb376e 100644
--- a/airflow/providers/amazon/provider.yaml
+++ b/airflow/providers/amazon/provider.yaml
@@ -572,6 +572,9 @@ triggers:
   - integration-name: AWS Step Functions
     python-modules:
       - airflow.providers.amazon.aws.triggers.step_function
+  - integration-name: Amazon Simple Queue Service (SQS)
+    python-modules:
+      - airflow.providers.amazon.aws.triggers.sqs
 
 transfers:
   - source-integration-name: Amazon DynamoDB
diff --git a/docs/apache-airflow-providers-amazon/operators/sqs.rst b/docs/apache-airflow-providers-amazon/operators/sqs.rst
index 13c806626e..9f72a2eff8 100644
--- a/docs/apache-airflow-providers-amazon/operators/sqs.rst
+++ b/docs/apache-airflow-providers-amazon/operators/sqs.rst
@@ -61,6 +61,7 @@ Read messages from an Amazon SQS queue
 
 To read messages from an Amazon SQS queue until exhausted use the
 :class:`~airflow.providers.amazon.aws.sensors.sqs.SqsSensor`
+which can also be run in deferrable mode by setting the ``deferrable`` parameter to ``True``.
 
 .. exampleinclude:: /../../tests/system/providers/amazon/aws/example_sqs.py
     :language: python
diff --git a/tests/providers/amazon/aws/sensors/test_sqs.py b/tests/providers/amazon/aws/sensors/test_sqs.py
index e6175c61d3..a9b03bd40a 100644
--- a/tests/providers/amazon/aws/sensors/test_sqs.py
+++ b/tests/providers/amazon/aws/sensors/test_sqs.py
@@ -24,7 +24,7 @@ from unittest.mock import patch
 import pytest
 from moto import mock_sqs
 
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, TaskDeferred
 from airflow.models.dag import DAG
 from airflow.providers.amazon.aws.hooks.sqs import SqsHook
 from airflow.providers.amazon.aws.sensors.sqs import SqsSensor
@@ -333,3 +333,16 @@ class TestSqsSensor:
             assert f"'Body': '{message}'" in str(
                 self.mock_context["ti"].method_calls
             ), "context call should contain message '{message}'"
+
+    def test_sqs_deferrable(self):
+        self.sensor = SqsSensor(
+            task_id="test_task_deferrable",
+            dag=self.dag,
+            sqs_queue=QUEUE_URL,
+            aws_conn_id="aws_default",
+            max_messages=1,
+            num_batches=3,
+            deferrable=True,
+        )
+        with pytest.raises(TaskDeferred):
+            self.sensor.execute(None)
diff --git a/tests/providers/amazon/aws/triggers/test_sqs.py b/tests/providers/amazon/aws/triggers/test_sqs.py
new file mode 100644
index 0000000000..74ffd837a1
--- /dev/null
+++ b/tests/providers/amazon/aws/triggers/test_sqs.py
@@ -0,0 +1,108 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import AsyncMock
+
+import pytest
+
+from airflow.providers.amazon.aws.triggers.sqs import SqsSensorTrigger
+
+TEST_SQS_QUEUE = "test-sqs-queue"
+TEST_AWS_CONN_ID = "test-aws-conn-id"
+TEST_MAX_MESSAGES = 1
+TEST_NUM_BATCHES = 1
+TEST_WAIT_TIME_SECONDS = 1
+TEST_VISIBILITY_TIMEOUT = 1
+TEST_MESSAGE_FILTERING = "literal"
+TEST_MESSAGE_FILTERING_MATCH_VALUES = "test"
+TEST_MESSAGE_FILTERING_CONFIG = "test-message-filtering-config"
+TEST_DELETE_MESSAGE_ON_RECEPTION = False
+TEST_WAITER_DELAY = 1
+
+trigger = SqsSensorTrigger(
+    sqs_queue=TEST_SQS_QUEUE,
+    aws_conn_id=TEST_AWS_CONN_ID,
+    max_messages=TEST_MAX_MESSAGES,
+    num_batches=TEST_NUM_BATCHES,
+    wait_time_seconds=TEST_WAIT_TIME_SECONDS,
+    visibility_timeout=TEST_VISIBILITY_TIMEOUT,
+    message_filtering=TEST_MESSAGE_FILTERING,
+    message_filtering_match_values=TEST_MESSAGE_FILTERING_MATCH_VALUES,
+    message_filtering_config=TEST_MESSAGE_FILTERING_CONFIG,
+    delete_message_on_reception=TEST_DELETE_MESSAGE_ON_RECEPTION,
+    waiter_delay=TEST_WAITER_DELAY,
+)
+
+
+class TestSqsTriggers:
+    @pytest.mark.parametrize(
+        "trigger",
+        [
+            trigger,
+        ],
+    )
+    def test_serialize_recreate(self, trigger):
+        class_path, args = trigger.serialize()
+
+        class_name = class_path.split(".")[-1]
+        clazz = globals()[class_name]
+        instance = clazz(**args)
+
+        class_path2, args2 = instance.serialize()
+
+        assert class_path == class_path2
+        assert args == args2
+
+    @pytest.mark.asyncio
+    async def test_poke(self):
+        sqs_trigger = trigger
+        mock_client = AsyncMock()
+        message = {
+            "MessageId": "test_message_id",
+            "Body": "test",
+        }
+        mock_response = {
+            "Messages": [message],
+        }
+        mock_client.receive_message.return_value = mock_response
+        messages = await sqs_trigger.poke(client=mock_client)
+        assert messages[0] == message
+
+    @pytest.mark.asyncio
+    async def test_poke_filtered_message(self):
+        sqs_trigger = trigger
+        mock_client = AsyncMock()
+        message = {
+            "MessageId": "test_message_id",
+            "Body": "This will be filtered out",
+        }
+        mock_response = {
+            "Messages": [message],
+        }
+        mock_client.receive_message.return_value = mock_response
+        messages = await sqs_trigger.poke(client=mock_client)
+        assert len(messages) == 0
+
+    @pytest.mark.asyncio
+    async def test_poke_no_messages(self):
+        sqs_trigger = trigger
+        mock_client = AsyncMock()
+        mock_response = {"Messages": []}
+        mock_client.receive_message.return_value = mock_response
+        messages = await sqs_trigger.poke(client=mock_client)
+        assert len(messages) == 0