You are viewing a plain-text version of this content; the link to its canonical version was not preserved in this copy.
Posted to commits@airflow.apache.org by po...@apache.org on 2021/03/03 09:32:45 UTC

[airflow] 21/41: Support google-cloud-pubsub>=2.0.0 (#13127)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch v2-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 8faa1bdb02422bd62eb730da7d653164050a7dd9
Author: Kamil Breguła <mi...@users.noreply.github.com>
AuthorDate: Tue Dec 22 13:02:59 2020 +0100

    Support google-cloud-pubsub>=2.0.0 (#13127)
    
    (cherry picked from commit 8c00ec89b97aa6e725379d08c8ff29a01be47e73)
---
 airflow/providers/google/cloud/hooks/pubsub.py     |  81 ++++----
 airflow/providers/google/cloud/operators/pubsub.py |   3 +-
 airflow/providers/google/cloud/sensors/pubsub.py   |   3 +-
 setup.py                                           |   2 +-
 tests/providers/google/cloud/hooks/test_pubsub.py  | 221 +++++++++++----------
 .../google/cloud/operators/test_pubsub.py          |  16 +-
 .../providers/google/cloud/sensors/test_pubsub.py  |  16 +-
 7 files changed, 177 insertions(+), 165 deletions(-)

diff --git a/airflow/providers/google/cloud/hooks/pubsub.py b/airflow/providers/google/cloud/hooks/pubsub.py
index f2ae190..37240a2 100644
--- a/airflow/providers/google/cloud/hooks/pubsub.py
+++ b/airflow/providers/google/cloud/hooks/pubsub.py
@@ -111,7 +111,7 @@ class PubSubHook(GoogleBaseHook):
         self._validate_messages(messages)
 
         publisher = self.get_conn()
-        topic_path = PublisherClient.topic_path(project_id, topic)  # pylint: disable=no-member
+        topic_path = f"projects/{project_id}/topics/{topic}"
 
         self.log.info("Publish %d messages to topic (path) %s", len(messages), topic_path)
         try:
@@ -206,7 +206,7 @@ class PubSubHook(GoogleBaseHook):
         :type metadata: Sequence[Tuple[str, str]]]
         """
         publisher = self.get_conn()
-        topic_path = PublisherClient.topic_path(project_id, topic)  # pylint: disable=no-member
+        topic_path = f"projects/{project_id}/topics/{topic}"
 
         # Add airflow-version label to the topic
         labels = labels or {}
@@ -216,13 +216,15 @@ class PubSubHook(GoogleBaseHook):
         try:
             # pylint: disable=no-member
             publisher.create_topic(
-                name=topic_path,
-                labels=labels,
-                message_storage_policy=message_storage_policy,
-                kms_key_name=kms_key_name,
+                request={
+                    "name": topic_path,
+                    "labels": labels,
+                    "message_storage_policy": message_storage_policy,
+                    "kms_key_name": kms_key_name,
+                },
                 retry=retry,
                 timeout=timeout,
-                metadata=metadata,
+                metadata=metadata or (),
             )
         except AlreadyExists:
             self.log.warning('Topic already exists: %s', topic)
@@ -266,16 +268,13 @@ class PubSubHook(GoogleBaseHook):
         :type metadata: Sequence[Tuple[str, str]]]
         """
         publisher = self.get_conn()
-        topic_path = PublisherClient.topic_path(project_id, topic)  # pylint: disable=no-member
+        topic_path = f"projects/{project_id}/topics/{topic}"
 
         self.log.info("Deleting topic (path) %s", topic_path)
         try:
             # pylint: disable=no-member
             publisher.delete_topic(
-                topic=topic_path,
-                retry=retry,
-                timeout=timeout,
-                metadata=metadata,
+                request={"topic": topic_path}, retry=retry, timeout=timeout, metadata=metadata or ()
             )
         except NotFound:
             self.log.warning('Topic does not exist: %s', topic_path)
@@ -401,27 +400,29 @@ class PubSubHook(GoogleBaseHook):
         labels['airflow-version'] = 'v' + version.replace('.', '-').replace('+', '-')
 
         # pylint: disable=no-member
-        subscription_path = SubscriberClient.subscription_path(subscription_project_id, subscription)
-        topic_path = SubscriberClient.topic_path(project_id, topic)
+        subscription_path = f"projects/{subscription_project_id}/subscriptions/{subscription}"
+        topic_path = f"projects/{project_id}/topics/{topic}"
 
         self.log.info("Creating subscription (path) %s for topic (path) %a", subscription_path, topic_path)
         try:
             subscriber.create_subscription(
-                name=subscription_path,
-                topic=topic_path,
-                push_config=push_config,
-                ack_deadline_seconds=ack_deadline_secs,
-                retain_acked_messages=retain_acked_messages,
-                message_retention_duration=message_retention_duration,
-                labels=labels,
-                enable_message_ordering=enable_message_ordering,
-                expiration_policy=expiration_policy,
-                filter_=filter_,
-                dead_letter_policy=dead_letter_policy,
-                retry_policy=retry_policy,
+                request={
+                    "name": subscription_path,
+                    "topic": topic_path,
+                    "push_config": push_config,
+                    "ack_deadline_seconds": ack_deadline_secs,
+                    "retain_acked_messages": retain_acked_messages,
+                    "message_retention_duration": message_retention_duration,
+                    "labels": labels,
+                    "enable_message_ordering": enable_message_ordering,
+                    "expiration_policy": expiration_policy,
+                    "filter": filter_,
+                    "dead_letter_policy": dead_letter_policy,
+                    "retry_policy": retry_policy,
+                },
                 retry=retry,
                 timeout=timeout,
-                metadata=metadata,
+                metadata=metadata or (),
             )
         except AlreadyExists:
             self.log.warning('Subscription already exists: %s', subscription_path)
@@ -466,13 +467,16 @@ class PubSubHook(GoogleBaseHook):
         """
         subscriber = self.subscriber_client
         # noqa E501 # pylint: disable=no-member
-        subscription_path = SubscriberClient.subscription_path(project_id, subscription)
+        subscription_path = f"projects/{project_id}/subscriptions/{subscription}"
 
         self.log.info("Deleting subscription (path) %s", subscription_path)
         try:
             # pylint: disable=no-member
             subscriber.delete_subscription(
-                subscription=subscription_path, retry=retry, timeout=timeout, metadata=metadata
+                request={"subscription": subscription_path},
+                retry=retry,
+                timeout=timeout,
+                metadata=metadata or (),
             )
 
         except NotFound:
@@ -527,18 +531,20 @@ class PubSubHook(GoogleBaseHook):
         """
         subscriber = self.subscriber_client
         # noqa E501 # pylint: disable=no-member,line-too-long
-        subscription_path = SubscriberClient.subscription_path(project_id, subscription)
+        subscription_path = f"projects/{project_id}/subscriptions/{subscription}"
 
         self.log.info("Pulling max %d messages from subscription (path) %s", max_messages, subscription_path)
         try:
             # pylint: disable=no-member
             response = subscriber.pull(
-                subscription=subscription_path,
-                max_messages=max_messages,
-                return_immediately=return_immediately,
+                request={
+                    "subscription": subscription_path,
+                    "max_messages": max_messages,
+                    "return_immediately": return_immediately,
+                },
                 retry=retry,
                 timeout=timeout,
-                metadata=metadata,
+                metadata=metadata or (),
             )
             result = getattr(response, 'received_messages', [])
             self.log.info("Pulled %d messages from subscription (path) %s", len(result), subscription_path)
@@ -591,17 +597,16 @@ class PubSubHook(GoogleBaseHook):
 
         subscriber = self.subscriber_client
         # noqa E501 # pylint: disable=no-member
-        subscription_path = SubscriberClient.subscription_path(project_id, subscription)
+        subscription_path = f"projects/{project_id}/subscriptions/{subscription}"
 
         self.log.info("Acknowledging %d ack_ids from subscription (path) %s", len(ack_ids), subscription_path)
         try:
             # pylint: disable=no-member
             subscriber.acknowledge(
-                subscription=subscription_path,
-                ack_ids=ack_ids,
+                request={"subscription": subscription_path, "ack_ids": ack_ids},
                 retry=retry,
                 timeout=timeout,
-                metadata=metadata,
+                metadata=metadata or (),
             )
         except (HttpError, GoogleAPICallError) as e:
             raise PubSubException(
diff --git a/airflow/providers/google/cloud/operators/pubsub.py b/airflow/providers/google/cloud/operators/pubsub.py
index e8cf735..23b545f 100644
--- a/airflow/providers/google/cloud/operators/pubsub.py
+++ b/airflow/providers/google/cloud/operators/pubsub.py
@@ -29,7 +29,6 @@ from google.cloud.pubsub_v1.types import (
     ReceivedMessage,
     RetryPolicy,
 )
-from google.protobuf.json_format import MessageToDict
 
 from airflow.models import BaseOperator
 from airflow.providers.google.cloud.hooks.pubsub import PubSubHook
@@ -958,6 +957,6 @@ class PubSubPullOperator(BaseOperator):
         :param context: same as in `execute`
         :return: value to be saved to XCom.
         """
-        messages_json = [MessageToDict(m) for m in pulled_messages]
+        messages_json = [ReceivedMessage.to_dict(m) for m in pulled_messages]
 
         return messages_json
diff --git a/airflow/providers/google/cloud/sensors/pubsub.py b/airflow/providers/google/cloud/sensors/pubsub.py
index d6e0be5..ff1f811 100644
--- a/airflow/providers/google/cloud/sensors/pubsub.py
+++ b/airflow/providers/google/cloud/sensors/pubsub.py
@@ -20,7 +20,6 @@ import warnings
 from typing import Any, Callable, Dict, List, Optional, Sequence, Union
 
 from google.cloud.pubsub_v1.types import ReceivedMessage
-from google.protobuf.json_format import MessageToDict
 
 from airflow.providers.google.cloud.hooks.pubsub import PubSubHook
 from airflow.sensors.base import BaseSensorOperator
@@ -200,6 +199,6 @@ class PubSubPullSensor(BaseSensorOperator):
         :param context: same as in `execute`
         :return: value to be saved to XCom.
         """
-        messages_json = [MessageToDict(m) for m in pulled_messages]
+        messages_json = [ReceivedMessage.to_dict(m) for m in pulled_messages]
 
         return messages_json
diff --git a/setup.py b/setup.py
index 1ec4f5d..ff9fd71 100644
--- a/setup.py
+++ b/setup.py
@@ -296,7 +296,7 @@ google = [
     'google-cloud-memcache>=0.2.0',
     'google-cloud-monitoring>=0.34.0,<2.0.0',
     'google-cloud-os-login>=2.0.0,<3.0.0',
-    'google-cloud-pubsub>=1.0.0,<2.0.0',
+    'google-cloud-pubsub>=2.0.0,<3.0.0',
     'google-cloud-redis>=0.3.0,<2.0.0',
     'google-cloud-secret-manager>=0.2.0,<2.0.0',
     'google-cloud-spanner>=1.10.0,<2.0.0',
diff --git a/tests/providers/google/cloud/hooks/test_pubsub.py b/tests/providers/google/cloud/hooks/test_pubsub.py
index 4086526..628d619 100644
--- a/tests/providers/google/cloud/hooks/test_pubsub.py
+++ b/tests/providers/google/cloud/hooks/test_pubsub.py
@@ -25,7 +25,6 @@ import pytest
 from google.api_core.exceptions import AlreadyExists, GoogleAPICallError
 from google.cloud.exceptions import NotFound
 from google.cloud.pubsub_v1.types import ReceivedMessage
-from google.protobuf.json_format import ParseDict
 from googleapiclient.errors import HttpError
 from parameterized import parameterized
 
@@ -67,15 +66,12 @@ class TestPubSubHook(unittest.TestCase):
 
     def _generate_messages(self, count) -> List[ReceivedMessage]:
         return [
-            ParseDict(
-                {
-                    "ack_id": str(i),
-                    "message": {
-                        "data": f'Message {i}'.encode('utf8'),
-                        "attributes": {"type": "generated message"},
-                    },
+            ReceivedMessage(
+                ack_id=str(i),
+                message={
+                    "data": f'Message {i}'.encode('utf8'),
+                    "attributes": {"type": "generated message"},
                 },
-                ReceivedMessage(),
             )
             for i in range(1, count + 1)
         ]
@@ -112,20 +108,19 @@ class TestPubSubHook(unittest.TestCase):
         create_method = mock_service.return_value.create_topic
         self.pubsub_hook.create_topic(project_id=TEST_PROJECT, topic=TEST_TOPIC)
         create_method.assert_called_once_with(
-            name=EXPANDED_TOPIC,
-            labels=LABELS,
-            message_storage_policy=None,
-            kms_key_name=None,
+            request=dict(name=EXPANDED_TOPIC, labels=LABELS, message_storage_policy=None, kms_key_name=None),
             retry=None,
             timeout=None,
-            metadata=None,
+            metadata=(),
         )
 
     @mock.patch(PUBSUB_STRING.format('PubSubHook.get_conn'))
     def test_delete_topic(self, mock_service):
         delete_method = mock_service.return_value.delete_topic
         self.pubsub_hook.delete_topic(project_id=TEST_PROJECT, topic=TEST_TOPIC)
-        delete_method.assert_called_once_with(topic=EXPANDED_TOPIC, retry=None, timeout=None, metadata=None)
+        delete_method.assert_called_once_with(
+            request=dict(topic=EXPANDED_TOPIC), retry=None, timeout=None, metadata=()
+        )
 
     @mock.patch(PUBSUB_STRING.format('PubSubHook.get_conn'))
     def test_delete_nonexisting_topic_failifnotexists(self, mock_service):
@@ -177,21 +172,23 @@ class TestPubSubHook(unittest.TestCase):
             project_id=TEST_PROJECT, topic=TEST_TOPIC, subscription=TEST_SUBSCRIPTION
         )
         create_method.assert_called_once_with(
-            name=EXPANDED_SUBSCRIPTION,
-            topic=EXPANDED_TOPIC,
-            push_config=None,
-            ack_deadline_seconds=10,
-            retain_acked_messages=None,
-            message_retention_duration=None,
-            labels=LABELS,
-            enable_message_ordering=False,
-            expiration_policy=None,
-            filter_=None,
-            dead_letter_policy=None,
-            retry_policy=None,
+            request=dict(
+                name=EXPANDED_SUBSCRIPTION,
+                topic=EXPANDED_TOPIC,
+                push_config=None,
+                ack_deadline_seconds=10,
+                retain_acked_messages=None,
+                message_retention_duration=None,
+                labels=LABELS,
+                enable_message_ordering=False,
+                expiration_policy=None,
+                filter=None,
+                dead_letter_policy=None,
+                retry_policy=None,
+            ),
             retry=None,
             timeout=None,
-            metadata=None,
+            metadata=(),
         )
         assert TEST_SUBSCRIPTION == response
 
@@ -208,21 +205,23 @@ class TestPubSubHook(unittest.TestCase):
             'a-different-project', TEST_SUBSCRIPTION
         )
         create_method.assert_called_once_with(
-            name=expected_subscription,
-            topic=EXPANDED_TOPIC,
-            push_config=None,
-            ack_deadline_seconds=10,
-            retain_acked_messages=None,
-            message_retention_duration=None,
-            labels=LABELS,
-            enable_message_ordering=False,
-            expiration_policy=None,
-            filter_=None,
-            dead_letter_policy=None,
-            retry_policy=None,
+            request=dict(
+                name=expected_subscription,
+                topic=EXPANDED_TOPIC,
+                push_config=None,
+                ack_deadline_seconds=10,
+                retain_acked_messages=None,
+                message_retention_duration=None,
+                labels=LABELS,
+                enable_message_ordering=False,
+                expiration_policy=None,
+                filter=None,
+                dead_letter_policy=None,
+                retry_policy=None,
+            ),
             retry=None,
             timeout=None,
-            metadata=None,
+            metadata=(),
         )
 
         assert TEST_SUBSCRIPTION == response
@@ -232,7 +231,7 @@ class TestPubSubHook(unittest.TestCase):
         self.pubsub_hook.delete_subscription(project_id=TEST_PROJECT, subscription=TEST_SUBSCRIPTION)
         delete_method = mock_service.delete_subscription
         delete_method.assert_called_once_with(
-            subscription=EXPANDED_SUBSCRIPTION, retry=None, timeout=None, metadata=None
+            request=dict(subscription=EXPANDED_SUBSCRIPTION), retry=None, timeout=None, metadata=()
         )
 
     @mock.patch(PUBSUB_STRING.format('PubSubHook.subscriber_client'))
@@ -266,21 +265,23 @@ class TestPubSubHook(unittest.TestCase):
 
         response = self.pubsub_hook.create_subscription(project_id=TEST_PROJECT, topic=TEST_TOPIC)
         create_method.assert_called_once_with(
-            name=expected_name,
-            topic=EXPANDED_TOPIC,
-            push_config=None,
-            ack_deadline_seconds=10,
-            retain_acked_messages=None,
-            message_retention_duration=None,
-            labels=LABELS,
-            enable_message_ordering=False,
-            expiration_policy=None,
-            filter_=None,
-            dead_letter_policy=None,
-            retry_policy=None,
+            request=dict(
+                name=expected_name,
+                topic=EXPANDED_TOPIC,
+                push_config=None,
+                ack_deadline_seconds=10,
+                retain_acked_messages=None,
+                message_retention_duration=None,
+                labels=LABELS,
+                enable_message_ordering=False,
+                expiration_policy=None,
+                filter=None,
+                dead_letter_policy=None,
+                retry_policy=None,
+            ),
             retry=None,
             timeout=None,
-            metadata=None,
+            metadata=(),
         )
         assert f'sub-{TEST_UUID}' == response
 
@@ -292,21 +293,23 @@ class TestPubSubHook(unittest.TestCase):
             project_id=TEST_PROJECT, topic=TEST_TOPIC, subscription=TEST_SUBSCRIPTION, ack_deadline_secs=30
         )
         create_method.assert_called_once_with(
-            name=EXPANDED_SUBSCRIPTION,
-            topic=EXPANDED_TOPIC,
-            push_config=None,
-            ack_deadline_seconds=30,
-            retain_acked_messages=None,
-            message_retention_duration=None,
-            labels=LABELS,
-            enable_message_ordering=False,
-            expiration_policy=None,
-            filter_=None,
-            dead_letter_policy=None,
-            retry_policy=None,
+            request=dict(
+                name=EXPANDED_SUBSCRIPTION,
+                topic=EXPANDED_TOPIC,
+                push_config=None,
+                ack_deadline_seconds=30,
+                retain_acked_messages=None,
+                message_retention_duration=None,
+                labels=LABELS,
+                enable_message_ordering=False,
+                expiration_policy=None,
+                filter=None,
+                dead_letter_policy=None,
+                retry_policy=None,
+            ),
             retry=None,
             timeout=None,
-            metadata=None,
+            metadata=(),
         )
         assert TEST_SUBSCRIPTION == response
 
@@ -321,21 +324,23 @@ class TestPubSubHook(unittest.TestCase):
             filter_='attributes.domain="com"',
         )
         create_method.assert_called_once_with(
-            name=EXPANDED_SUBSCRIPTION,
-            topic=EXPANDED_TOPIC,
-            push_config=None,
-            ack_deadline_seconds=10,
-            retain_acked_messages=None,
-            message_retention_duration=None,
-            labels=LABELS,
-            enable_message_ordering=False,
-            expiration_policy=None,
-            filter_='attributes.domain="com"',
-            dead_letter_policy=None,
-            retry_policy=None,
+            request=dict(
+                name=EXPANDED_SUBSCRIPTION,
+                topic=EXPANDED_TOPIC,
+                push_config=None,
+                ack_deadline_seconds=10,
+                retain_acked_messages=None,
+                message_retention_duration=None,
+                labels=LABELS,
+                enable_message_ordering=False,
+                expiration_policy=None,
+                filter='attributes.domain="com"',
+                dead_letter_policy=None,
+                retry_policy=None,
+            ),
             retry=None,
             timeout=None,
-            metadata=None,
+            metadata=(),
         )
         assert TEST_SUBSCRIPTION == response
 
@@ -401,12 +406,14 @@ class TestPubSubHook(unittest.TestCase):
             project_id=TEST_PROJECT, subscription=TEST_SUBSCRIPTION, max_messages=10
         )
         pull_method.assert_called_once_with(
-            subscription=EXPANDED_SUBSCRIPTION,
-            max_messages=10,
-            return_immediately=False,
+            request=dict(
+                subscription=EXPANDED_SUBSCRIPTION,
+                max_messages=10,
+                return_immediately=False,
+            ),
             retry=None,
             timeout=None,
-            metadata=None,
+            metadata=(),
         )
         assert pulled_messages == response
 
@@ -419,12 +426,14 @@ class TestPubSubHook(unittest.TestCase):
             project_id=TEST_PROJECT, subscription=TEST_SUBSCRIPTION, max_messages=10
         )
         pull_method.assert_called_once_with(
-            subscription=EXPANDED_SUBSCRIPTION,
-            max_messages=10,
-            return_immediately=False,
+            request=dict(
+                subscription=EXPANDED_SUBSCRIPTION,
+                max_messages=10,
+                return_immediately=False,
+            ),
             retry=None,
             timeout=None,
-            metadata=None,
+            metadata=(),
         )
         assert [] == response
 
@@ -445,12 +454,14 @@ class TestPubSubHook(unittest.TestCase):
         with pytest.raises(PubSubException):
             self.pubsub_hook.pull(project_id=TEST_PROJECT, subscription=TEST_SUBSCRIPTION, max_messages=10)
             pull_method.assert_called_once_with(
-                subscription=EXPANDED_SUBSCRIPTION,
-                max_messages=10,
-                return_immediately=False,
+                request=dict(
+                    subscription=EXPANDED_SUBSCRIPTION,
+                    max_messages=10,
+                    return_immediately=False,
+                ),
                 retry=None,
                 timeout=None,
-                metadata=None,
+                metadata=(),
             )
 
     @mock.patch(PUBSUB_STRING.format('PubSubHook.subscriber_client'))
@@ -461,11 +472,13 @@ class TestPubSubHook(unittest.TestCase):
             project_id=TEST_PROJECT, subscription=TEST_SUBSCRIPTION, ack_ids=['1', '2', '3']
         )
         ack_method.assert_called_once_with(
-            subscription=EXPANDED_SUBSCRIPTION,
-            ack_ids=['1', '2', '3'],
+            request=dict(
+                subscription=EXPANDED_SUBSCRIPTION,
+                ack_ids=['1', '2', '3'],
+            ),
             retry=None,
             timeout=None,
-            metadata=None,
+            metadata=(),
         )
 
     @mock.patch(PUBSUB_STRING.format('PubSubHook.subscriber_client'))
@@ -478,11 +491,13 @@ class TestPubSubHook(unittest.TestCase):
             messages=self._generate_messages(3),
         )
         ack_method.assert_called_once_with(
-            subscription=EXPANDED_SUBSCRIPTION,
-            ack_ids=['1', '2', '3'],
+            request=dict(
+                subscription=EXPANDED_SUBSCRIPTION,
+                ack_ids=['1', '2', '3'],
+            ),
             retry=None,
             timeout=None,
-            metadata=None,
+            metadata=(),
         )
 
     @parameterized.expand(
@@ -504,11 +519,13 @@ class TestPubSubHook(unittest.TestCase):
                 project_id=TEST_PROJECT, subscription=TEST_SUBSCRIPTION, ack_ids=['1', '2', '3']
             )
             ack_method.assert_called_once_with(
-                subscription=EXPANDED_SUBSCRIPTION,
-                ack_ids=['1', '2', '3'],
+                request=dict(
+                    subscription=EXPANDED_SUBSCRIPTION,
+                    ack_ids=['1', '2', '3'],
+                ),
                 retry=None,
                 timeout=None,
-                metadata=None,
+                metadata=(),
             )
 
     @parameterized.expand(
diff --git a/tests/providers/google/cloud/operators/test_pubsub.py b/tests/providers/google/cloud/operators/test_pubsub.py
index 9ff71e6..6abfffa 100644
--- a/tests/providers/google/cloud/operators/test_pubsub.py
+++ b/tests/providers/google/cloud/operators/test_pubsub.py
@@ -21,7 +21,6 @@ from typing import Any, Dict, List
 from unittest import mock
 
 from google.cloud.pubsub_v1.types import ReceivedMessage
-from google.protobuf.json_format import MessageToDict, ParseDict
 
 from airflow.providers.google.cloud.operators.pubsub import (
     PubSubCreateSubscriptionOperator,
@@ -230,21 +229,18 @@ class TestPubSubPublishOperator(unittest.TestCase):
 class TestPubSubPullOperator(unittest.TestCase):
     def _generate_messages(self, count):
         return [
-            ParseDict(
-                {
-                    "ack_id": "%s" % i,
-                    "message": {
-                        "data": f'Message {i}'.encode('utf8'),
-                        "attributes": {"type": "generated message"},
-                    },
+            ReceivedMessage(
+                ack_id="%s" % i,
+                message={
+                    "data": f'Message {i}'.encode('utf8'),
+                    "attributes": {"type": "generated message"},
                 },
-                ReceivedMessage(),
             )
             for i in range(1, count + 1)
         ]
 
     def _generate_dicts(self, count):
-        return [MessageToDict(m) for m in self._generate_messages(count)]
+        return [ReceivedMessage.to_dict(m) for m in self._generate_messages(count)]
 
     @mock.patch('airflow.providers.google.cloud.operators.pubsub.PubSubHook')
     def test_execute_no_messages(self, mock_hook):
diff --git a/tests/providers/google/cloud/sensors/test_pubsub.py b/tests/providers/google/cloud/sensors/test_pubsub.py
index ba1aee9..795860b 100644
--- a/tests/providers/google/cloud/sensors/test_pubsub.py
+++ b/tests/providers/google/cloud/sensors/test_pubsub.py
@@ -22,7 +22,6 @@ from unittest import mock
 
 import pytest
 from google.cloud.pubsub_v1.types import ReceivedMessage
-from google.protobuf.json_format import MessageToDict, ParseDict
 
 from airflow.exceptions import AirflowSensorTimeout
 from airflow.providers.google.cloud.sensors.pubsub import PubSubPullSensor
@@ -35,21 +34,18 @@ TEST_SUBSCRIPTION = 'test-subscription'
 class TestPubSubPullSensor(unittest.TestCase):
     def _generate_messages(self, count):
         return [
-            ParseDict(
-                {
-                    "ack_id": "%s" % i,
-                    "message": {
-                        "data": f'Message {i}'.encode('utf8'),
-                        "attributes": {"type": "generated message"},
-                    },
+            ReceivedMessage(
+                ack_id="%s" % i,
+                message={
+                    "data": f'Message {i}'.encode('utf8'),
+                    "attributes": {"type": "generated message"},
                 },
-                ReceivedMessage(),
             )
             for i in range(1, count + 1)
         ]
 
     def _generate_dicts(self, count):
-        return [MessageToDict(m) for m in self._generate_messages(count)]
+        return [ReceivedMessage.to_dict(m) for m in self._generate_messages(count)]
 
     @mock.patch('airflow.providers.google.cloud.sensors.pubsub.PubSubHook')
     def test_poke_no_messages(self, mock_hook):