You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/07/14 12:07:33 UTC

[airflow] branch main updated: Implement Azure Service Bus (Update and Receive) Subscription Operator (#25029)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 292440d54f Implement Azure Service Bus (Update and Receive) Subscription Operator (#25029)
292440d54f is described below

commit 292440d54f4db84aaf0c5a98cf5fcf34303f2fa8
Author: Bharanidharan <94...@users.noreply.github.com>
AuthorDate: Thu Jul 14 17:37:11 2022 +0530

    Implement Azure Service Bus (Update and Receive) Subscription Operator (#25029)
    
    Implement Azure Service Bus (Update and Receive) Subscription Operator
---
 airflow/providers/microsoft/azure/hooks/asb.py     |  36 +++++++
 airflow/providers/microsoft/azure/operators/asb.py | 116 +++++++++++++++++++++
 .../operators/asb.rst                              |  32 ++++++
 tests/providers/microsoft/azure/hooks/test_asb.py  |  81 ++++++++++----
 .../microsoft/azure/operators/test_asb.py          |  87 ++++++++++++++++
 .../microsoft/azure/example_azure_service_bus.py   |  22 ++++
 6 files changed, 355 insertions(+), 19 deletions(-)

diff --git a/airflow/providers/microsoft/azure/hooks/asb.py b/airflow/providers/microsoft/azure/hooks/asb.py
index 72cb65a5f4..b7d3074ba1 100644
--- a/airflow/providers/microsoft/azure/hooks/asb.py
+++ b/airflow/providers/microsoft/azure/hooks/asb.py
@@ -212,3 +212,39 @@ class MessageHook(BaseAzureServiceBusHook):
                 for msg in received_msgs:
                     self.log.info(msg)
                     receiver.complete_message(msg)
+
+    def receive_subscription_message(
+        self,
+        topic_name: str,
+        subscription_name: str,
+        max_message_count: Optional[int],
+        max_wait_time: Optional[float],
+    ):
+        """
+        Receive a batch of subscription messages at once, then log and complete each one.
+        This suits processing several messages together or an ad-hoc single-call receive.
+
+        :param topic_name: The topic the subscription belongs to.
+        :param subscription_name: Name of the subscription to receive messages from.
+        :param max_message_count: Maximum number of messages in the batch. The actual number
+            depends on prefetch_count and stream rate; None fully depends on the prefetch config.
+        :param max_wait_time: Maximum time to wait in seconds for the first message to arrive. If no
+            messages arrive, and no timeout is specified, this call will not return until the
+            connection is closed. If specified, and no messages arrive within the timeout period,
+            the batch is empty and nothing is logged or completed.
+        :raises TypeError: If ``topic_name`` or ``subscription_name`` is None.
+        """
+        if subscription_name is None:
+            raise TypeError("Subscription name cannot be None.")
+        if topic_name is None:
+            raise TypeError("Topic name cannot be None.")
+        with self.get_conn() as service_bus_client, service_bus_client.get_subscription_receiver(
+            topic_name, subscription_name
+        ) as subscription_receiver:
+            with subscription_receiver:
+                received_msgs = subscription_receiver.receive_messages(
+                    max_message_count=max_message_count, max_wait_time=max_wait_time
+                )
+                for msg in received_msgs:
+                    self.log.info(msg)
+                    subscription_receiver.complete_message(msg)
diff --git a/airflow/providers/microsoft/azure/operators/asb.py b/airflow/providers/microsoft/azure/operators/asb.py
index 52897ed4f6..2e7599c469 100644
--- a/airflow/providers/microsoft/azure/operators/asb.py
+++ b/airflow/providers/microsoft/azure/operators/asb.py
@@ -314,6 +314,122 @@ class AzureServiceBusSubscriptionCreateOperator(BaseOperator):
             self.log.info("Created subscription %s", subscription.name)
 
 
+class AzureServiceBusUpdateSubscriptionOperator(BaseOperator):
+    """
+    Update an Azure ServiceBus Topic Subscription under a ServiceBus Namespace
+    by using ServiceBusAdministrationClient
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:AzureServiceBusUpdateSubscriptionOperator`
+
+    :param topic_name: The topic that owns the subscription to be updated.
+    :param subscription_name: Name of the subscription that needs to be updated.
+    :param max_delivery_count: The maximum delivery count. A message is automatically dead lettered
+        after this number of deliveries. If not set, the current value is left unchanged.
+    :param dead_lettering_on_message_expiration: A value that indicates whether this subscription
+        has dead letter support when a message expires. If not set, it is left unchanged.
+    :param enable_batched_operations: Value that indicates whether server-side batched
+        operations are enabled. If not set, it is left unchanged.
+    :param azure_service_bus_conn_id: Reference to the
+        :ref:`Azure Service Bus connection<howto/connection:azure_service_bus>`.
+    """
+
+    template_fields: Sequence[str] = ("topic_name", "subscription_name")
+    ui_color = "#e4f0e8"
+
+    def __init__(
+        self,
+        *,
+        topic_name: str,
+        subscription_name: str,
+        max_delivery_count: Optional[int] = None,
+        dead_lettering_on_message_expiration: Optional[bool] = None,
+        enable_batched_operations: Optional[bool] = None,
+        azure_service_bus_conn_id: str = 'azure_service_bus_default',
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.topic_name = topic_name
+        self.subscription_name = subscription_name
+        self.max_delivery_count = max_delivery_count
+        self.dl_on_message_expiration = dead_lettering_on_message_expiration
+        self.enable_batched_operations = enable_batched_operations
+        self.azure_service_bus_conn_id = azure_service_bus_conn_id
+
+    def execute(self, context: "Context") -> None:
+        """Fetch the subscription, apply every setting that is not None and push the update"""
+        hook = AdminClientHook(azure_service_bus_conn_id=self.azure_service_bus_conn_id)
+
+        with hook.get_conn() as service_mgmt_conn:
+            subscription_prop = service_mgmt_conn.get_subscription(self.topic_name, self.subscription_name)
+            if self.max_delivery_count is not None:
+                subscription_prop.max_delivery_count = self.max_delivery_count
+            if self.dl_on_message_expiration is not None:
+                subscription_prop.dead_lettering_on_message_expiration = self.dl_on_message_expiration
+            if self.enable_batched_operations is not None:
+                subscription_prop.enable_batched_operations = self.enable_batched_operations
+            # Push the modified properties model back; only settings given as non-None changed.
+            service_mgmt_conn.update_subscription(self.topic_name, subscription_prop)
+            updated_subscription = service_mgmt_conn.get_subscription(self.topic_name, self.subscription_name)
+            self.log.info("Subscription Updated successfully %s", updated_subscription)
+
+
+class ASBReceiveSubscriptionMessageOperator(BaseOperator):
+    """
+    Receive a batch of messages from a Service Bus subscription under a specific topic.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:ASBReceiveSubscriptionMessageOperator`
+
+    :param topic_name: The topic the subscription belongs to.
+    :param subscription_name: Name of the subscription to receive messages from.
+    :param max_message_count: Maximum number of messages in the batch.
+        Actual number returned will depend on prefetch_count and incoming stream rate.
+        Setting to None will fully depend on the prefetch config. The default value is 1.
+    :param max_wait_time: Maximum time to wait in seconds for the first message to arrive. If no
+        messages arrive, and no timeout is specified, this call will not return until the
+        connection is closed. If specified, and no messages arrive within the timeout period,
+        the task completes without receiving any message. The default value is 5 seconds.
+    :param azure_service_bus_conn_id: Reference to the
+        :ref:`Azure Service Bus connection <howto/connection:azure_service_bus>`.
+    """
+
+    template_fields: Sequence[str] = ("topic_name", "subscription_name")
+    ui_color = "#e4f0e8"
+
+    def __init__(
+        self,
+        *,
+        topic_name: str,
+        subscription_name: str,
+        max_message_count: Optional[int] = 1,
+        max_wait_time: Optional[float] = 5,
+        azure_service_bus_conn_id: str = 'azure_service_bus_default',
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.topic_name = topic_name
+        self.subscription_name = subscription_name
+        self.max_message_count = max_message_count
+        self.max_wait_time = max_wait_time
+        self.azure_service_bus_conn_id = azure_service_bus_conn_id
+
+    def execute(self, context: "Context") -> None:
+        """
+        Receive messages from a specific topic subscription in the Service Bus namespace,
+        by connecting to Service Bus client
+        """
+        # Create the hook
+        hook = MessageHook(azure_service_bus_conn_id=self.azure_service_bus_conn_id)
+
+        # Receive a batch from the subscription; the hook logs and completes each message
+        hook.receive_subscription_message(
+            self.topic_name, self.subscription_name, self.max_message_count, self.max_wait_time
+        )
+
+
 class AzureServiceBusSubscriptionDeleteOperator(BaseOperator):
     """
     Deletes the topic subscription in the Azure ServiceBus namespace
diff --git a/docs/apache-airflow-providers-microsoft-azure/operators/asb.rst b/docs/apache-airflow-providers-microsoft-azure/operators/asb.rst
index 510a41c5e7..0614d13224 100644
--- a/docs/apache-airflow-providers-microsoft-azure/operators/asb.rst
+++ b/docs/apache-airflow-providers-microsoft-azure/operators/asb.rst
@@ -119,6 +119,38 @@ Below is an example of using this operator to execute an Azure Service Bus Creat
     :start-after: [START howto_operator_create_service_bus_subscription]
     :end-before: [END howto_operator_create_service_bus_subscription]
 
+.. _howto/operator:AzureServiceBusUpdateSubscriptionOperator:
+
+Update Azure Service Bus Subscription
+======================================
+
+To update an existing Azure Service Bus topic subscription with specific parameters, you can use
+:class:`~airflow.providers.microsoft.azure.operators.asb.AzureServiceBusUpdateSubscriptionOperator`.
+
+Below is an example of using this operator to execute an Azure Service Bus Update Subscription.
+
+.. exampleinclude:: /../../tests/system/providers/microsoft/azure/example_azure_service_bus.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_update_service_bus_subscription]
+    :end-before: [END howto_operator_update_service_bus_subscription]
+
+.. _howto/operator:ASBReceiveSubscriptionMessageOperator:
+
+Receive Azure Service Bus Subscription Message
+===============================================
+
+To receive a batch of messages from a Service Bus subscription under a specific topic, you can use
+:class:`~airflow.providers.microsoft.azure.operators.asb.ASBReceiveSubscriptionMessageOperator`.
+
+Below is an example of using this operator to execute an Azure Service Bus Receive Subscription Message.
+
+.. exampleinclude:: /../../tests/system/providers/microsoft/azure/example_azure_service_bus.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_receive_message_service_bus_subscription]
+    :end-before: [END howto_operator_receive_message_service_bus_subscription]
+
 .. _howto/operator:AzureServiceBusSubscriptionDeleteOperator:
 
 Delete Azure Service Bus Subscription
diff --git a/tests/providers/microsoft/azure/hooks/test_asb.py b/tests/providers/microsoft/azure/hooks/test_asb.py
index e8a7326dfc..770d7f4fb1 100644
--- a/tests/providers/microsoft/azure/hooks/test_asb.py
+++ b/tests/providers/microsoft/azure/hooks/test_asb.py
@@ -88,6 +88,35 @@ class TestAdminClientHook:
         with pytest.raises(TypeError):
             hook.delete_queue(None)
 
+    @mock.patch('airflow.providers.microsoft.azure.hooks.asb.AdminClientHook.get_conn')
+    def test_delete_subscription(self, mock_sb_admin_client):
+        """
+        Test `delete_subscription` with a subscription name and topic name, asserting that the
+        admin client's `delete_subscription` is called with (topic_name, subscription_name)
+        """
+        subscription_name = "test_subscription_name"
+        topic_name = "test_topic_name"
+        hook = AdminClientHook(azure_service_bus_conn_id=self.conn_id)
+        hook.delete_subscription(subscription_name, topic_name)
+        expected_calls = [mock.call().__enter__().delete_subscription(topic_name, subscription_name)]
+        mock_sb_admin_client.assert_has_calls(expected_calls)
+
+    @pytest.mark.parametrize(
+        "mock_subscription_name, mock_topic_name",
+        [("subscription_1", None), (None, "topic_1")],
+    )
+    @mock.patch('airflow.providers.microsoft.azure.hooks.asb.AdminClientHook')
+    def test_delete_subscription_exception(
+        self, mock_sb_admin_client, mock_subscription_name, mock_topic_name
+    ):
+        """
+        Test `delete_subscription` raises TypeError when either the subscription name or the
+        topic name is None
+        """
+        hook = AdminClientHook(azure_service_bus_conn_id=self.conn_id)
+        with pytest.raises(TypeError):
+            hook.delete_subscription(mock_subscription_name, mock_topic_name)
+
 
 class TestMessageHook:
     def setup_class(self) -> None:
@@ -202,31 +231,45 @@ class TestMessageHook:
         with pytest.raises(TypeError):
             hook.receive_message(None)
 
-    @mock.patch('airflow.providers.microsoft.azure.hooks.asb.AdminClientHook.get_conn')
-    def test_delete_subscription(self, mock_sb_admin_client):
+    @mock.patch('airflow.providers.microsoft.azure.hooks.asb.MessageHook.get_conn')
+    def test_receive_subscription_message(self, mock_sb_client):
         """
-        Test Delete subscription functionality by passing subscription name and topic name,
-        assert the function with values, mock the azure service bus function  `delete_subscription`
+        Test `receive_subscription_message`: the client is entered, the subscription receiver is
+        requested with (topic_name, subscription_name) and `receive_messages` gets the batch settings
         """
-        subscription_name = "test_subscription_name"
-        topic_name = "test_topic_name"
-        hook = AdminClientHook(azure_service_bus_conn_id=self.conn_id)
-        hook.delete_subscription(subscription_name, topic_name)
-        expected_calls = [mock.call().__enter__().delete_subscription(topic_name, subscription_name)]
-        mock_sb_admin_client.assert_has_calls(expected_calls)
+        subscription_name = "subscription_1"
+        topic_name = "topic_name"
+        max_message_count = 10
+        max_wait_time = 5
+        hook = MessageHook(azure_service_bus_conn_id=self.conn_id)
+        hook.receive_subscription_message(topic_name, subscription_name, max_message_count, max_wait_time)
+        # `get_conn()` is entered as a context manager; the subscription receiver it hands out is
+        # entered as well, and `receive_messages` is called on that entered receiver.
+        mock_sb_client.assert_called_once_with()
+        sb_client = mock_sb_client.return_value.__enter__.return_value
+        sb_client.get_subscription_receiver.assert_called_once_with(topic_name, subscription_name)
+        receiver = sb_client.get_subscription_receiver.return_value.__enter__.return_value
+        receiver.receive_messages.assert_called_once_with(
+            max_message_count=max_message_count,
+            max_wait_time=max_wait_time,
+        )
+        # The mocked batch is empty, so no message should have been completed.
+        receiver.complete_message.assert_not_called()
 
     @pytest.mark.parametrize(
-        "mock_subscription_name, mock_topic_name",
-        [("subscription_1", None), (None, "topic_1")],
+        "mock_subscription_name, mock_topic_name, mock_max_count, mock_wait_time",
+        [("subscription_1", None, None, None), (None, "topic_1", None, None)],
     )
-    @mock.patch('airflow.providers.microsoft.azure.hooks.asb.AdminClientHook')
-    def test_delete_subscription_exception(
-        self, mock_sb_admin_client, mock_subscription_name, mock_topic_name
+    @mock.patch('airflow.providers.microsoft.azure.hooks.asb.MessageHook.get_conn')
+    def test_receive_subscription_message_exception(
+        self, mock_sb_client, mock_subscription_name, mock_topic_name, mock_max_count, mock_wait_time
     ):
         """
-        Test `delete_subscription` functionality to raise AirflowException,
-         by passing subscription name and topic name as None and pytest raise Airflow Exception
+        Test `receive_subscription_message` raises TypeError when either the topic name or the
+        subscription name is None
         """
-        hook = AdminClientHook(azure_service_bus_conn_id=self.conn_id)
+        hook = MessageHook(azure_service_bus_conn_id=self.conn_id)
         with pytest.raises(TypeError):
-            hook.delete_subscription(mock_subscription_name, mock_topic_name)
+            hook.receive_subscription_message(
+                mock_topic_name, mock_subscription_name, mock_max_count, mock_wait_time
+            )
diff --git a/tests/providers/microsoft/azure/operators/test_asb.py b/tests/providers/microsoft/azure/operators/test_asb.py
index fda3457889..25a79cbf12 100644
--- a/tests/providers/microsoft/azure/operators/test_asb.py
+++ b/tests/providers/microsoft/azure/operators/test_asb.py
@@ -21,12 +21,14 @@ import pytest
 from azure.servicebus import ServiceBusMessage
 
 from airflow.providers.microsoft.azure.operators.asb import (
+    ASBReceiveSubscriptionMessageOperator,
     AzureServiceBusCreateQueueOperator,
     AzureServiceBusDeleteQueueOperator,
     AzureServiceBusReceiveMessageOperator,
     AzureServiceBusSendMessageOperator,
     AzureServiceBusSubscriptionCreateOperator,
     AzureServiceBusSubscriptionDeleteOperator,
+    AzureServiceBusUpdateSubscriptionOperator,
 )
 
 QUEUE_NAME = "test_queue"
@@ -290,3 +292,88 @@ class TestASBDeleteSubscriptionOperator:
         mock_get_conn.return_value.__enter__.return_value.delete_subscription.assert_called_once_with(
             TOPIC_NAME, SUBSCRIPTION_NAME
         )
+
+
+class TestAzureServiceBusUpdateSubscriptionOperator:
+    def test_init(self):
+        """
+        Test init by creating AzureServiceBusUpdateSubscriptionOperator with task id, subscription name,
+        topic name and asserting with values
+        """
+        asb_update_subscription_operator = AzureServiceBusUpdateSubscriptionOperator(
+            task_id="asb_update_subscription",
+            topic_name=TOPIC_NAME,
+            subscription_name=SUBSCRIPTION_NAME,
+            max_delivery_count=10,
+        )
+        assert asb_update_subscription_operator.task_id == "asb_update_subscription"
+        assert asb_update_subscription_operator.topic_name == TOPIC_NAME
+        assert asb_update_subscription_operator.subscription_name == SUBSCRIPTION_NAME
+        assert asb_update_subscription_operator.max_delivery_count == 10
+
+    @mock.patch('azure.servicebus.management.SubscriptionProperties')
+    @mock.patch("airflow.providers.microsoft.azure.hooks.asb.AdminClientHook.get_conn")
+    def test_update_subscription(self, mock_get_conn, mock_subscription_properties):
+        """
+        Test AzureServiceBusUpdateSubscriptionOperator fetches the subscription, applies the new
+        max_delivery_count and pushes the modified properties back via update_subscription
+        """
+        mock_subscription_properties.max_delivery_count = 10  # value before the update
+        admin_client = mock_get_conn.return_value.__enter__.return_value
+        admin_client.get_subscription.return_value = mock_subscription_properties
+        asb_update_subscription = AzureServiceBusUpdateSubscriptionOperator(
+            task_id="asb_update_subscription",
+            topic_name=TOPIC_NAME,
+            subscription_name=SUBSCRIPTION_NAME,
+            max_delivery_count=20,
+        )
+        with mock.patch.object(asb_update_subscription.log, "info") as mock_log_info:
+            asb_update_subscription.execute(None)
+        mock_log_info.assert_called_with("Subscription Updated successfully %s", mock_subscription_properties)
+        assert mock_subscription_properties.max_delivery_count == 20
+        admin_client.update_subscription.assert_called_once_with(TOPIC_NAME, mock_subscription_properties)
+
+
+class TestASBSubscriptionReceiveMessageOperator:
+    def test_init(self):
+        """
+        Test init by creating ASBReceiveSubscriptionMessageOperator with task id, topic_name,
+        subscription_name, batch and asserting with values
+        """
+
+        asb_subscription_receive_message = ASBReceiveSubscriptionMessageOperator(
+            task_id="asb_subscription_receive_message",
+            topic_name=TOPIC_NAME,
+            subscription_name=SUBSCRIPTION_NAME,
+            max_message_count=10,
+        )
+        assert asb_subscription_receive_message.task_id == "asb_subscription_receive_message"
+        assert asb_subscription_receive_message.topic_name == TOPIC_NAME
+        assert asb_subscription_receive_message.subscription_name == SUBSCRIPTION_NAME
+        assert asb_subscription_receive_message.max_message_count == 10
+
+    @mock.patch("airflow.providers.microsoft.azure.hooks.asb.MessageHook.get_conn")
+    def test_receive_subscription_message(self, mock_get_conn):
+        """
+        Test ASBReceiveSubscriptionMessageOperator receives from the given topic subscription
+        with max_message_count=10 and the default max_wait_time of 5 seconds
+        """
+        asb_subscription_receive_message = ASBReceiveSubscriptionMessageOperator(
+            task_id="asb_subscription_receive_message",
+            topic_name=TOPIC_NAME,
+            subscription_name=SUBSCRIPTION_NAME,
+            max_message_count=10,
+        )
+        asb_subscription_receive_message.execute(None)
+        # `get_conn()` is entered as a context manager; the subscription receiver it hands out is
+        # entered as well, and `receive_messages` is called on that entered receiver.
+        mock_get_conn.assert_called_once_with()
+        sb_client = mock_get_conn.return_value.__enter__.return_value
+        sb_client.get_subscription_receiver.assert_called_once_with(TOPIC_NAME, SUBSCRIPTION_NAME)
+        receiver = sb_client.get_subscription_receiver.return_value.__enter__.return_value
+        receiver.receive_messages.assert_called_once_with(
+            max_message_count=10,
+            max_wait_time=5,
+        )
+        # The mocked batch is empty, so no message should have been completed.
+        receiver.complete_message.assert_not_called()
diff --git a/tests/system/providers/microsoft/azure/example_azure_service_bus.py b/tests/system/providers/microsoft/azure/example_azure_service_bus.py
index 099907276f..a17a3a0bee 100644
--- a/tests/system/providers/microsoft/azure/example_azure_service_bus.py
+++ b/tests/system/providers/microsoft/azure/example_azure_service_bus.py
@@ -21,12 +21,14 @@ from datetime import datetime, timedelta
 from airflow import DAG
 from airflow.models.baseoperator import chain
 from airflow.providers.microsoft.azure.operators.asb import (
+    ASBReceiveSubscriptionMessageOperator,
     AzureServiceBusCreateQueueOperator,
     AzureServiceBusDeleteQueueOperator,
     AzureServiceBusReceiveMessageOperator,
     AzureServiceBusSendMessageOperator,
     AzureServiceBusSubscriptionCreateOperator,
     AzureServiceBusSubscriptionDeleteOperator,
+    AzureServiceBusUpdateSubscriptionOperator,
 )
 
 EXECUTION_TIMEOUT = int(os.getenv("EXECUTION_TIMEOUT", 6))
@@ -100,6 +102,24 @@ with DAG(
     )
     # [END howto_operator_create_service_bus_subscription]
 
+    # [START howto_operator_update_service_bus_subscription]
+    update_service_bus_subscription = AzureServiceBusUpdateSubscriptionOperator(
+        task_id="update_service_bus_subscription",
+        topic_name=TOPIC_NAME,
+        subscription_name=SUBSCRIPTION_NAME,
+        max_delivery_count=5,
+    )
+    # [END howto_operator_update_service_bus_subscription]
+
+    # [START howto_operator_receive_message_service_bus_subscription]
+    receive_message_service_bus_subscription = ASBReceiveSubscriptionMessageOperator(
+        task_id="receive_message_service_bus_subscription",
+        topic_name=TOPIC_NAME,
+        subscription_name=SUBSCRIPTION_NAME,
+        max_message_count=10,
+    )
+    # [END howto_operator_receive_message_service_bus_subscription]
+
     # [START howto_operator_delete_service_bus_subscription]
     delete_service_bus_subscription = AzureServiceBusSubscriptionDeleteOperator(
         task_id="delete_service_bus_subscription",
@@ -122,6 +142,8 @@ with DAG(
         send_list_message_to_service_bus_queue,
         send_batch_message_to_service_bus_queue,
         receive_message_service_bus_queue,
+        update_service_bus_subscription,
+        receive_message_service_bus_subscription,
         delete_service_bus_subscription,
         delete_service_bus_queue,
     )