You are viewing a plain-text version of this content. The canonical link to the original message is not preserved in this plain-text copy.
Posted to commits@airflow.apache.org by po...@apache.org on 2021/08/02 20:47:24 UTC
[airflow] branch main updated: Improve AWS SQS Sensor (#16880) (#16904)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new d28efbf Improve AWS SQS Sensor (#16880) (#16904)
d28efbf is described below
commit d28efbfb7780afd1ff13a258dc5dc3e3381ddabd
Author: Bjorn Olsen <bj...@gmail.com>
AuthorDate: Mon Aug 2 22:47:10 2021 +0200
Improve AWS SQS Sensor (#16880) (#16904)
---
airflow/providers/amazon/aws/sensors/sqs.py | 135 ++++++++++++++++---
setup.py | 1 +
tests/providers/amazon/aws/sensors/test_sqs.py | 178 +++++++++++++++++++++++++
3 files changed, 292 insertions(+), 22 deletions(-)
diff --git a/airflow/providers/amazon/aws/sensors/sqs.py b/airflow/providers/amazon/aws/sensors/sqs.py
index dc6217b..e2d04e1 100644
--- a/airflow/providers/amazon/aws/sensors/sqs.py
+++ b/airflow/providers/amazon/aws/sensors/sqs.py
@@ -16,7 +16,11 @@
# specific language governing permissions and limitations
# under the License.
"""Reads and then deletes the message from SQS queue"""
-from typing import Optional
+import json
+from typing import Any, Optional
+
+from jsonpath_ng import parse
+from typing_extensions import Literal
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sqs import SQSHook
@@ -37,9 +41,26 @@ class SQSSensor(BaseSensorOperator):
:type max_messages: int
:param wait_time_seconds: The time in seconds to wait for receiving messages (default: 1 second)
:type wait_time_seconds: int
+ :param visibility_timeout: Visibility timeout, a period of time during which
+ Amazon SQS prevents other consumers from receiving and processing the message.
+ :type visibility_timeout: Optional[Int]
+ :param message_filtering: Specified how received messages should be filtered. Supported options are:
+ `None` (no filtering, default), `'literal'` (message Body literal match) or `'jsonpath'`
+ (message Body filtered using a JSONPath expression).
+ You may add further methods by overriding the relevant class methods.
+ :type message_filtering: Optional[Literal["literal", "jsonpath"]]
+ :param message_filtering_match_values: Optional value/s for the message filter to match on.
+ For example, with literal matching, if a message body matches any of the specified values
+ then it is included. For JSONPath matching, the result of the JSONPath expression is used
+ and may match any of the specified values.
+ :type message_filtering_match_values: Any
+ :param message_filtering_config: Additional configuration to pass to the message filter.
+ For example with JSONPath filtering you can pass a JSONPath expression string here,
+ such as `'foo[*].baz'`. Messages with a Body which does not match are ignored.
+ :type message_filtering_config: Any
"""
- template_fields = ('sqs_queue', 'max_messages')
+ template_fields = ('sqs_queue', 'max_messages', 'message_filtering_config')
def __init__(
self,
@@ -48,6 +69,10 @@ class SQSSensor(BaseSensorOperator):
aws_conn_id: str = 'aws_default',
max_messages: int = 5,
wait_time_seconds: int = 1,
+ visibility_timeout: Optional[int] = None,
+ message_filtering: Optional[Literal["literal", "jsonpath"]] = None,
+ message_filtering_match_values: Any = None,
+ message_filtering_config: Any = None,
**kwargs,
):
super().__init__(**kwargs)
@@ -55,6 +80,21 @@ class SQSSensor(BaseSensorOperator):
self.aws_conn_id = aws_conn_id
self.max_messages = max_messages
self.wait_time_seconds = wait_time_seconds
+ self.visibility_timeout = visibility_timeout
+
+ self.message_filtering = message_filtering
+
+ if message_filtering_match_values is not None:
+ if not isinstance(message_filtering_match_values, set):
+ message_filtering_match_values = set(message_filtering_match_values)
+ self.message_filtering_match_values = message_filtering_match_values
+
+ if self.message_filtering == 'literal':
+ if self.message_filtering_match_values is None:
+ raise TypeError('message_filtering_match_values must be specified for literal matching')
+
+ self.message_filtering_config = message_filtering_config
+
self.hook: Optional[SQSHook] = None
def poke(self, context):
@@ -69,31 +109,48 @@ class SQSSensor(BaseSensorOperator):
self.log.info('SQSSensor checking for message on queue: %s', self.sqs_queue)
- messages = sqs_conn.receive_message(
- QueueUrl=self.sqs_queue,
- MaxNumberOfMessages=self.max_messages,
- WaitTimeSeconds=self.wait_time_seconds,
- )
+ receive_message_kwargs = {
+ 'QueueUrl': self.sqs_queue,
+ 'MaxNumberOfMessages': self.max_messages,
+ 'WaitTimeSeconds': self.wait_time_seconds,
+ }
+ if self.visibility_timeout is not None:
+ receive_message_kwargs['VisibilityTimeout'] = self.visibility_timeout
- self.log.info("received message %s", str(messages))
+ response = sqs_conn.receive_message(**receive_message_kwargs)
- if 'Messages' in messages and messages['Messages']:
- entries = [
- {'Id': message['MessageId'], 'ReceiptHandle': message['ReceiptHandle']}
- for message in messages['Messages']
- ]
+ if "Messages" not in response:
+ return False
- result = sqs_conn.delete_message_batch(QueueUrl=self.sqs_queue, Entries=entries)
+ messages = response['Messages']
+ num_messages = len(messages)
+ self.log.info("Received %d messages", num_messages)
- if 'Successful' in result:
- context['ti'].xcom_push(key='messages', value=messages)
- return True
- else:
- raise AirflowException(
- 'Delete SQS Messages failed ' + str(result) + ' for messages ' + str(messages)
- )
+ if not num_messages:
+ return False
- return False
+ if self.message_filtering:
+ messages = self.filter_messages(messages)
+ num_messages = len(messages)
+ self.log.info("There are %d messages left after filtering", num_messages)
+
+ if not num_messages:
+ return False
+
+ self.log.info("Deleting %d messages", num_messages)
+
+ entries = [
+ {'Id': message['MessageId'], 'ReceiptHandle': message['ReceiptHandle']} for message in messages
+ ]
+ response = sqs_conn.delete_message_batch(QueueUrl=self.sqs_queue, Entries=entries)
+
+ if 'Successful' in response:
+ context['ti'].xcom_push(key='messages', value=messages)
+ return True
+ else:
+ raise AirflowException(
+ 'Delete SQS Messages failed ' + str(response) + ' for messages ' + str(messages)
+ )
def get_hook(self) -> SQSHook:
"""Create and return an SQSHook"""
@@ -102,3 +159,37 @@ class SQSSensor(BaseSensorOperator):
self.hook = SQSHook(aws_conn_id=self.aws_conn_id)
return self.hook
+
+ def filter_messages(self, messages):
+ if self.message_filtering == 'literal':
+ return self.filter_messages_literal(messages)
+ if self.message_filtering == 'jsonpath':
+ return self.filter_messages_jsonpath(messages)
+ else:
+ raise NotImplementedError('Override this method to define custom filters')
+
+ def filter_messages_literal(self, messages):
+ filtered_messages = []
+ for message in messages:
+ if message['Body'] in self.message_filtering_match_values:
+ filtered_messages.append(message)
+ return filtered_messages
+
+ def filter_messages_jsonpath(self, messages):
+ jsonpath_expr = parse(self.message_filtering_config)
+ filtered_messages = []
+ for message in messages:
+ body = message['Body']
+ # Body is a string, deserialise to an object and then parse
+ body = json.loads(body)
+ results = jsonpath_expr.find(body)
+ if not results:
+ continue
+ if self.message_filtering_match_values is None:
+ filtered_messages.append(message)
+ continue
+ for result in results:
+ if result.value in self.message_filtering_match_values:
+ filtered_messages.append(message)
+ break
+ return filtered_messages
diff --git a/setup.py b/setup.py
index 09df05c..3b2389c 100644
--- a/setup.py
+++ b/setup.py
@@ -182,6 +182,7 @@ def write_version(filename: str = os.path.join(*[my_dir, "airflow", "git_version
amazon = [
'boto3>=1.15.0,<1.18.0',
'watchtower~=1.0.6',
+ 'jsonpath_ng>=1.5.3',
]
apache_beam = [
'apache-beam>=2.20.0',
diff --git a/tests/providers/amazon/aws/sensors/test_sqs.py b/tests/providers/amazon/aws/sensors/test_sqs.py
index 90349a3..82a1aac 100644
--- a/tests/providers/amazon/aws/sensors/test_sqs.py
+++ b/tests/providers/amazon/aws/sensors/test_sqs.py
@@ -17,6 +17,7 @@
# under the License.
+import json
import unittest
from unittest import mock
@@ -107,3 +108,180 @@ class TestSQSSensor(unittest.TestCase):
self.sensor.poke(self.mock_context)
assert 'test exception' in ctx.value.args[0]
+
+ @mock.patch.object(SQSHook, 'get_conn')
+ def test_poke_visibility_timeout(self, mock_conn):
+ # Check without visibility_timeout parameter
+ self.sqs_hook.create_queue('test')
+ self.sqs_hook.send_message(queue_url='test', message_body='hello')
+
+ self.sensor.poke(self.mock_context)
+
+ calls_receive_message = [
+ mock.call().receive_message(QueueUrl='test', MaxNumberOfMessages=5, WaitTimeSeconds=1)
+ ]
+ mock_conn.assert_has_calls(calls_receive_message)
+ # Check with visibility_timeout parameter
+ self.sensor = SQSSensor(
+ task_id='test_task2',
+ dag=self.dag,
+ sqs_queue='test',
+ aws_conn_id='aws_default',
+ visibility_timeout=42,
+ )
+ self.sensor.poke(self.mock_context)
+
+ calls_receive_message = [
+ mock.call().receive_message(
+ QueueUrl='test', MaxNumberOfMessages=5, WaitTimeSeconds=1, VisibilityTimeout=42
+ )
+ ]
+ mock_conn.assert_has_calls(calls_receive_message)
+
+ @mock_sqs
+ def test_poke_message_invalid_filtering(self):
+ self.sqs_hook.create_queue('test')
+ self.sqs_hook.send_message(queue_url='test', message_body='hello')
+ sensor = SQSSensor(
+ task_id='test_task2',
+ dag=self.dag,
+ sqs_queue='test',
+ aws_conn_id='aws_default',
+ message_filtering='invalid_option',
+ )
+ with pytest.raises(NotImplementedError) as ctx:
+ sensor.poke(self.mock_context)
+ assert 'Override this method to define custom filters' in ctx.value.args[0]
+
+ @mock.patch.object(SQSHook, "get_conn")
+ def test_poke_message_filtering_literal_values(self, mock_conn):
+ self.sqs_hook.create_queue('test')
+ matching = [{"id": 11, "body": "a matching message"}]
+ non_matching = [{"id": 12, "body": "a non-matching message"}]
+ all = matching + non_matching
+
+ def mock_receive_message(**kwargs):
+ messages = []
+ for message in all:
+ messages.append(
+ {
+ 'MessageId': message['id'],
+ 'ReceiptHandle': 100 + message['id'],
+ 'Body': message['body'],
+ }
+ )
+ return {'Messages': messages}
+
+ mock_conn.return_value.receive_message.side_effect = mock_receive_message
+
+ def mock_delete_message_batch(**kwargs):
+ return {'Successful'}
+
+ mock_conn.return_value.delete_message_batch.side_effect = mock_delete_message_batch
+
+ # Test that messages are filtered
+ self.sensor.message_filtering = 'literal'
+ self.sensor.message_filtering_match_values = ["a matching message"]
+ result = self.sensor.poke(self.mock_context)
+ assert result
+
+ # Test that only filtered messages are deleted
+ delete_entries = [{'Id': x['id'], 'ReceiptHandle': 100 + x['id']} for x in matching]
+ calls_delete_message_batch = [
+ mock.call().delete_message_batch(QueueUrl='test', Entries=delete_entries)
+ ]
+ mock_conn.assert_has_calls(calls_delete_message_batch)
+
+ @mock.patch.object(SQSHook, "get_conn")
+ def test_poke_message_filtering_jsonpath(self, mock_conn):
+ self.sqs_hook.create_queue('test')
+ matching = [
+ {"id": 11, "key": {"matches": [1, 2]}},
+ {"id": 12, "key": {"matches": [3, 4, 5]}},
+ {"id": 13, "key": {"matches": [10]}},
+ ]
+ non_matching = [
+ {"id": 14, "key": {"nope": [5, 6]}},
+ {"id": 15, "key": {"nope": [7, 8]}},
+ ]
+ all = matching + non_matching
+
+ def mock_receive_message(**kwargs):
+ messages = []
+ for message in all:
+ messages.append(
+ {
+ 'MessageId': message['id'],
+ 'ReceiptHandle': 100 + message['id'],
+ 'Body': json.dumps(message),
+ }
+ )
+ return {'Messages': messages}
+
+ mock_conn.return_value.receive_message.side_effect = mock_receive_message
+
+ def mock_delete_message_batch(**kwargs):
+ return {'Successful'}
+
+ mock_conn.return_value.delete_message_batch.side_effect = mock_delete_message_batch
+
+ # Test that messages are filtered
+ self.sensor.message_filtering = 'jsonpath'
+ self.sensor.message_filtering_config = 'key.matches[*]'
+ result = self.sensor.poke(self.mock_context)
+ assert result
+
+ # Test that only filtered messages are deleted
+ delete_entries = [{'Id': x['id'], 'ReceiptHandle': 100 + x['id']} for x in matching]
+ calls_delete_message_batch = [
+ mock.call().delete_message_batch(QueueUrl='test', Entries=delete_entries)
+ ]
+ mock_conn.assert_has_calls(calls_delete_message_batch)
+
+ @mock.patch.object(SQSHook, "get_conn")
+ def test_poke_message_filtering_jsonpath_values(self, mock_conn):
+ self.sqs_hook.create_queue('test')
+ matching = [
+ {"id": 11, "key": {"matches": [1, 2]}},
+ {"id": 12, "key": {"matches": [1, 4, 5]}},
+ {"id": 13, "key": {"matches": [4, 5]}},
+ ]
+ non_matching = [
+ {"id": 21, "key": {"matches": [10]}},
+ {"id": 22, "key": {"nope": [5, 6]}},
+ {"id": 23, "key": {"nope": [7, 8]}},
+ ]
+ all = matching + non_matching
+
+ def mock_receive_message(**kwargs):
+ messages = []
+ for message in all:
+ messages.append(
+ {
+ 'MessageId': message['id'],
+ 'ReceiptHandle': 100 + message['id'],
+ 'Body': json.dumps(message),
+ }
+ )
+ return {'Messages': messages}
+
+ mock_conn.return_value.receive_message.side_effect = mock_receive_message
+
+ def mock_delete_message_batch(**kwargs):
+ return {'Successful'}
+
+ mock_conn.return_value.delete_message_batch.side_effect = mock_delete_message_batch
+
+ # Test that messages are filtered
+ self.sensor.message_filtering = 'jsonpath'
+ self.sensor.message_filtering_config = 'key.matches[*]'
+ self.sensor.message_filtering_match_values = [1, 4]
+ result = self.sensor.poke(self.mock_context)
+ assert result
+
+ # Test that only filtered messages are deleted
+ delete_entries = [{'Id': x['id'], 'ReceiptHandle': 100 + x['id']} for x in matching]
+ calls_delete_message_batch = [
+ mock.call().delete_message_batch(QueueUrl='test', Entries=delete_entries)
+ ]
+ mock_conn.assert_has_calls(calls_delete_message_batch)