You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2021/08/02 20:47:24 UTC

[airflow] branch main updated: Improve AWS SQS Sensor (#16880) (#16904)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new d28efbf  Improve AWS SQS Sensor (#16880) (#16904)
d28efbf is described below

commit d28efbfb7780afd1ff13a258dc5dc3e3381ddabd
Author: Bjorn Olsen <bj...@gmail.com>
AuthorDate: Mon Aug 2 22:47:10 2021 +0200

    Improve AWS SQS Sensor (#16880) (#16904)
---
 airflow/providers/amazon/aws/sensors/sqs.py    | 135 ++++++++++++++++---
 setup.py                                       |   1 +
 tests/providers/amazon/aws/sensors/test_sqs.py | 178 +++++++++++++++++++++++++
 3 files changed, 292 insertions(+), 22 deletions(-)

diff --git a/airflow/providers/amazon/aws/sensors/sqs.py b/airflow/providers/amazon/aws/sensors/sqs.py
index dc6217b..e2d04e1 100644
--- a/airflow/providers/amazon/aws/sensors/sqs.py
+++ b/airflow/providers/amazon/aws/sensors/sqs.py
@@ -16,7 +16,11 @@
 # specific language governing permissions and limitations
 # under the License.
 """Reads and then deletes the message from SQS queue"""
-from typing import Optional
+import json
+from typing import Any, Optional
+
+from jsonpath_ng import parse
+from typing_extensions import Literal
 
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.sqs import SQSHook
@@ -37,9 +41,26 @@ class SQSSensor(BaseSensorOperator):
     :type max_messages: int
     :param wait_time_seconds: The time in seconds to wait for receiving messages (default: 1 second)
     :type wait_time_seconds: int
+    :param visibility_timeout: Visibility timeout, a period of time during which
+        Amazon SQS prevents other consumers from receiving and processing the message.
+    :type visibility_timeout: Optional[int]
+    :param message_filtering: Specifies how received messages should be filtered. Supported options are:
+        `None` (no filtering, default), `'literal'` (message Body literal match) or `'jsonpath'`
+        (message Body filtered using a JSONPath expression).
+        You may add further methods by overriding the relevant class methods.
+    :type message_filtering: Optional[Literal["literal", "jsonpath"]]
+    :param message_filtering_match_values: Optional value/s for the message filter to match on.
+        For example, with literal matching, if a message body matches any of the specified values
+        then it is included. For JSONPath matching, the result of the JSONPath expression is used
+        and may match any of the specified values.
+    :type message_filtering_match_values: Any
+    :param message_filtering_config: Additional configuration to pass to the message filter.
+        For example, with JSONPath filtering you can pass a JSONPath expression string here,
+        such as `'foo[*].baz'`. Messages with a Body which does not match are ignored.
+    :type message_filtering_config: Any
     """
 
-    template_fields = ('sqs_queue', 'max_messages')
+    template_fields = ('sqs_queue', 'max_messages', 'message_filtering_config')
 
     def __init__(
         self,
@@ -48,6 +69,10 @@ class SQSSensor(BaseSensorOperator):
         aws_conn_id: str = 'aws_default',
         max_messages: int = 5,
         wait_time_seconds: int = 1,
+        visibility_timeout: Optional[int] = None,
+        message_filtering: Optional[Literal["literal", "jsonpath"]] = None,
+        message_filtering_match_values: Any = None,
+        message_filtering_config: Any = None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -55,6 +80,21 @@ class SQSSensor(BaseSensorOperator):
         self.aws_conn_id = aws_conn_id
         self.max_messages = max_messages
         self.wait_time_seconds = wait_time_seconds
+        self.visibility_timeout = visibility_timeout
+
+        self.message_filtering = message_filtering
+
+        if message_filtering_match_values is not None:
+            if not isinstance(message_filtering_match_values, set):
+                message_filtering_match_values = set(message_filtering_match_values)
+        self.message_filtering_match_values = message_filtering_match_values
+
+        if self.message_filtering == 'literal':
+            if self.message_filtering_match_values is None:
+                raise TypeError('message_filtering_match_values must be specified for literal matching')
+
+        self.message_filtering_config = message_filtering_config
+
         self.hook: Optional[SQSHook] = None
 
     def poke(self, context):
@@ -69,31 +109,48 @@ class SQSSensor(BaseSensorOperator):
 
         self.log.info('SQSSensor checking for message on queue: %s', self.sqs_queue)
 
-        messages = sqs_conn.receive_message(
-            QueueUrl=self.sqs_queue,
-            MaxNumberOfMessages=self.max_messages,
-            WaitTimeSeconds=self.wait_time_seconds,
-        )
+        receive_message_kwargs = {
+            'QueueUrl': self.sqs_queue,
+            'MaxNumberOfMessages': self.max_messages,
+            'WaitTimeSeconds': self.wait_time_seconds,
+        }
+        if self.visibility_timeout is not None:
+            receive_message_kwargs['VisibilityTimeout'] = self.visibility_timeout
 
-        self.log.info("received message %s", str(messages))
+        response = sqs_conn.receive_message(**receive_message_kwargs)
 
-        if 'Messages' in messages and messages['Messages']:
-            entries = [
-                {'Id': message['MessageId'], 'ReceiptHandle': message['ReceiptHandle']}
-                for message in messages['Messages']
-            ]
+        if "Messages" not in response:
+            return False
 
-            result = sqs_conn.delete_message_batch(QueueUrl=self.sqs_queue, Entries=entries)
+        messages = response['Messages']
+        num_messages = len(messages)
+        self.log.info("Received %d messages", num_messages)
 
-            if 'Successful' in result:
-                context['ti'].xcom_push(key='messages', value=messages)
-                return True
-            else:
-                raise AirflowException(
-                    'Delete SQS Messages failed ' + str(result) + ' for messages ' + str(messages)
-                )
+        if not num_messages:
+            return False
 
-        return False
+        if self.message_filtering:
+            messages = self.filter_messages(messages)
+            num_messages = len(messages)
+            self.log.info("There are %d messages left after filtering", num_messages)
+
+        if not num_messages:
+            return False
+
+        self.log.info("Deleting %d messages", num_messages)
+
+        entries = [
+            {'Id': message['MessageId'], 'ReceiptHandle': message['ReceiptHandle']} for message in messages
+        ]
+        response = sqs_conn.delete_message_batch(QueueUrl=self.sqs_queue, Entries=entries)
+
+        if 'Successful' in response:
+            context['ti'].xcom_push(key='messages', value=messages)
+            return True
+        else:
+            raise AirflowException(
+                'Delete SQS Messages failed ' + str(response) + ' for messages ' + str(messages)
+            )
 
     def get_hook(self) -> SQSHook:
         """Create and return an SQSHook"""
@@ -102,3 +159,37 @@ class SQSSensor(BaseSensorOperator):
 
         self.hook = SQSHook(aws_conn_id=self.aws_conn_id)
         return self.hook
+
+    def filter_messages(self, messages):
+        if self.message_filtering == 'literal':
+            return self.filter_messages_literal(messages)
+        if self.message_filtering == 'jsonpath':
+            return self.filter_messages_jsonpath(messages)
+        else:
+            raise NotImplementedError('Override this method to define custom filters')
+
+    def filter_messages_literal(self, messages):
+        filtered_messages = []
+        for message in messages:
+            if message['Body'] in self.message_filtering_match_values:
+                filtered_messages.append(message)
+        return filtered_messages
+
+    def filter_messages_jsonpath(self, messages):
+        jsonpath_expr = parse(self.message_filtering_config)
+        filtered_messages = []
+        for message in messages:
+            body = message['Body']
+            # Body is a string, deserialise to an object and then parse
+            body = json.loads(body)
+            results = jsonpath_expr.find(body)
+            if not results:
+                continue
+            if self.message_filtering_match_values is None:
+                filtered_messages.append(message)
+                continue
+            for result in results:
+                if result.value in self.message_filtering_match_values:
+                    filtered_messages.append(message)
+                    break
+        return filtered_messages
diff --git a/setup.py b/setup.py
index 09df05c..3b2389c 100644
--- a/setup.py
+++ b/setup.py
@@ -182,6 +182,7 @@ def write_version(filename: str = os.path.join(*[my_dir, "airflow", "git_version
 amazon = [
     'boto3>=1.15.0,<1.18.0',
     'watchtower~=1.0.6',
+    'jsonpath_ng>=1.5.3',
 ]
 apache_beam = [
     'apache-beam>=2.20.0',
diff --git a/tests/providers/amazon/aws/sensors/test_sqs.py b/tests/providers/amazon/aws/sensors/test_sqs.py
index 90349a3..82a1aac 100644
--- a/tests/providers/amazon/aws/sensors/test_sqs.py
+++ b/tests/providers/amazon/aws/sensors/test_sqs.py
@@ -17,6 +17,7 @@
 # under the License.
 
 
+import json
 import unittest
 from unittest import mock
 
@@ -107,3 +108,180 @@ class TestSQSSensor(unittest.TestCase):
             self.sensor.poke(self.mock_context)
 
         assert 'test exception' in ctx.value.args[0]
+
+    @mock.patch.object(SQSHook, 'get_conn')
+    def test_poke_visibility_timeout(self, mock_conn):
+        # Check without visibility_timeout parameter
+        self.sqs_hook.create_queue('test')
+        self.sqs_hook.send_message(queue_url='test', message_body='hello')
+
+        self.sensor.poke(self.mock_context)
+
+        calls_receive_message = [
+            mock.call().receive_message(QueueUrl='test', MaxNumberOfMessages=5, WaitTimeSeconds=1)
+        ]
+        mock_conn.assert_has_calls(calls_receive_message)
+        # Check with visibility_timeout parameter
+        self.sensor = SQSSensor(
+            task_id='test_task2',
+            dag=self.dag,
+            sqs_queue='test',
+            aws_conn_id='aws_default',
+            visibility_timeout=42,
+        )
+        self.sensor.poke(self.mock_context)
+
+        calls_receive_message = [
+            mock.call().receive_message(
+                QueueUrl='test', MaxNumberOfMessages=5, WaitTimeSeconds=1, VisibilityTimeout=42
+            )
+        ]
+        mock_conn.assert_has_calls(calls_receive_message)
+
+    @mock_sqs
+    def test_poke_message_invalid_filtering(self):
+        self.sqs_hook.create_queue('test')
+        self.sqs_hook.send_message(queue_url='test', message_body='hello')
+        sensor = SQSSensor(
+            task_id='test_task2',
+            dag=self.dag,
+            sqs_queue='test',
+            aws_conn_id='aws_default',
+            message_filtering='invalid_option',
+        )
+        with pytest.raises(NotImplementedError) as ctx:
+            sensor.poke(self.mock_context)
+        assert 'Override this method to define custom filters' in ctx.value.args[0]
+
+    @mock.patch.object(SQSHook, "get_conn")
+    def test_poke_message_filtering_literal_values(self, mock_conn):
+        self.sqs_hook.create_queue('test')
+        matching = [{"id": 11, "body": "a matching message"}]
+        non_matching = [{"id": 12, "body": "a non-matching message"}]
+        all = matching + non_matching
+
+        def mock_receive_message(**kwargs):
+            messages = []
+            for message in all:
+                messages.append(
+                    {
+                        'MessageId': message['id'],
+                        'ReceiptHandle': 100 + message['id'],
+                        'Body': message['body'],
+                    }
+                )
+            return {'Messages': messages}
+
+        mock_conn.return_value.receive_message.side_effect = mock_receive_message
+
+        def mock_delete_message_batch(**kwargs):
+            return {'Successful'}
+
+        mock_conn.return_value.delete_message_batch.side_effect = mock_delete_message_batch
+
+        # Test that messages are filtered
+        self.sensor.message_filtering = 'literal'
+        self.sensor.message_filtering_match_values = ["a matching message"]
+        result = self.sensor.poke(self.mock_context)
+        assert result
+
+        # Test that only filtered messages are deleted
+        delete_entries = [{'Id': x['id'], 'ReceiptHandle': 100 + x['id']} for x in matching]
+        calls_delete_message_batch = [
+            mock.call().delete_message_batch(QueueUrl='test', Entries=delete_entries)
+        ]
+        mock_conn.assert_has_calls(calls_delete_message_batch)
+
+    @mock.patch.object(SQSHook, "get_conn")
+    def test_poke_message_filtering_jsonpath(self, mock_conn):
+        self.sqs_hook.create_queue('test')
+        matching = [
+            {"id": 11, "key": {"matches": [1, 2]}},
+            {"id": 12, "key": {"matches": [3, 4, 5]}},
+            {"id": 13, "key": {"matches": [10]}},
+        ]
+        non_matching = [
+            {"id": 14, "key": {"nope": [5, 6]}},
+            {"id": 15, "key": {"nope": [7, 8]}},
+        ]
+        all = matching + non_matching
+
+        def mock_receive_message(**kwargs):
+            messages = []
+            for message in all:
+                messages.append(
+                    {
+                        'MessageId': message['id'],
+                        'ReceiptHandle': 100 + message['id'],
+                        'Body': json.dumps(message),
+                    }
+                )
+            return {'Messages': messages}
+
+        mock_conn.return_value.receive_message.side_effect = mock_receive_message
+
+        def mock_delete_message_batch(**kwargs):
+            return {'Successful'}
+
+        mock_conn.return_value.delete_message_batch.side_effect = mock_delete_message_batch
+
+        # Test that messages are filtered
+        self.sensor.message_filtering = 'jsonpath'
+        self.sensor.message_filtering_config = 'key.matches[*]'
+        result = self.sensor.poke(self.mock_context)
+        assert result
+
+        # Test that only filtered messages are deleted
+        delete_entries = [{'Id': x['id'], 'ReceiptHandle': 100 + x['id']} for x in matching]
+        calls_delete_message_batch = [
+            mock.call().delete_message_batch(QueueUrl='test', Entries=delete_entries)
+        ]
+        mock_conn.assert_has_calls(calls_delete_message_batch)
+
+    @mock.patch.object(SQSHook, "get_conn")
+    def test_poke_message_filtering_jsonpath_values(self, mock_conn):
+        self.sqs_hook.create_queue('test')
+        matching = [
+            {"id": 11, "key": {"matches": [1, 2]}},
+            {"id": 12, "key": {"matches": [1, 4, 5]}},
+            {"id": 13, "key": {"matches": [4, 5]}},
+        ]
+        non_matching = [
+            {"id": 21, "key": {"matches": [10]}},
+            {"id": 22, "key": {"nope": [5, 6]}},
+            {"id": 23, "key": {"nope": [7, 8]}},
+        ]
+        all = matching + non_matching
+
+        def mock_receive_message(**kwargs):
+            messages = []
+            for message in all:
+                messages.append(
+                    {
+                        'MessageId': message['id'],
+                        'ReceiptHandle': 100 + message['id'],
+                        'Body': json.dumps(message),
+                    }
+                )
+            return {'Messages': messages}
+
+        mock_conn.return_value.receive_message.side_effect = mock_receive_message
+
+        def mock_delete_message_batch(**kwargs):
+            return {'Successful'}
+
+        mock_conn.return_value.delete_message_batch.side_effect = mock_delete_message_batch
+
+        # Test that messages are filtered
+        self.sensor.message_filtering = 'jsonpath'
+        self.sensor.message_filtering_config = 'key.matches[*]'
+        self.sensor.message_filtering_match_values = [1, 4]
+        result = self.sensor.poke(self.mock_context)
+        assert result
+
+        # Test that only filtered messages are deleted
+        delete_entries = [{'Id': x['id'], 'ReceiptHandle': 100 + x['id']} for x in matching]
+        calls_delete_message_batch = [
+            mock.call().delete_message_batch(QueueUrl='test', Entries=delete_entries)
+        ]
+        mock_conn.assert_has_calls(calls_delete_message_batch)