Posted to commits@airflow.apache.org by ka...@apache.org on 2023/01/10 14:31:27 UTC

[airflow] branch main updated: Add deferrable ``GCSObjectExistenceSensorAsync`` (#28763)

This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 284cd52989 Add deferrable ``GCSObjectExistenceSensorAsync`` (#28763)
284cd52989 is described below

commit 284cd529898fbadd14308004a0b0cb6f389b4318
Author: Rajath <92...@users.noreply.github.com>
AuthorDate: Tue Jan 10 20:01:16 2023 +0530

    Add deferrable ``GCSObjectExistenceSensorAsync`` (#28763)
    
    This PR donates the `GCSObjectExistenceSensorAsync` sensor, developed in the [astronomer-providers](https://github.com/astronomer/astronomer-providers) repo, to Apache Airflow.
---
 airflow/providers/google/cloud/hooks/gcs.py        |  16 ++-
 airflow/providers/google/cloud/sensors/gcs.py      |  52 ++++++-
 airflow/providers/google/cloud/triggers/gcs.py     |  99 +++++++++++++
 .../operators/cloud/gcs.rst                        |  16 +++
 tests/providers/google/cloud/sensors/test_gcs.py   |  53 ++++++-
 tests/providers/google/cloud/triggers/test_gcs.py  | 159 +++++++++++++++++++++
 .../google/cloud/gcs/example_gcs_sensor.py         |  11 +-
 7 files changed, 402 insertions(+), 4 deletions(-)
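
For reference, a minimal usage sketch of the donated sensor in a DAG (not part
of this commit; the connection id, bucket, and object names are illustrative):

    from __future__ import annotations

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.google.cloud.sensors.gcs import GCSObjectExistenceAsyncSensor

    with DAG(
        dag_id="example_gcs_deferrable_sensor",
        start_date=datetime(2023, 1, 1),
        schedule=None,
        catchup=False,
    ):
        # While deferred, the task frees its worker slot; the GCSBlobTrigger
        # polls for the object on the triggerer instead.
        wait_for_file = GCSObjectExistenceAsyncSensor(
            task_id="wait_for_file",
            bucket="my-bucket",
            object="data/input.csv",
            google_cloud_conn_id="google_cloud_default",
            poke_interval=30,
        )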

diff --git a/airflow/providers/google/cloud/hooks/gcs.py b/airflow/providers/google/cloud/hooks/gcs.py
index 68e1c74986..6976724873 100644
--- a/airflow/providers/google/cloud/hooks/gcs.py
+++ b/airflow/providers/google/cloud/hooks/gcs.py
@@ -32,6 +32,8 @@ from tempfile import NamedTemporaryFile
 from typing import IO, Callable, Generator, Sequence, TypeVar, cast, overload
 from urllib.parse import urlsplit
 
+from aiohttp import ClientSession
+from gcloud.aio.storage import Storage
 from google.api_core.exceptions import NotFound
 from google.api_core.retry import Retry
 
@@ -39,11 +41,12 @@ from google.api_core.retry import Retry
 from google.cloud import storage  # type: ignore[attr-defined]
 from google.cloud.exceptions import GoogleCloudError
 from google.cloud.storage.retry import DEFAULT_RETRY
+from requests import Session
 
 from airflow.exceptions import AirflowException
 from airflow.providers.google.cloud.utils.helpers import normalize_directory_path
 from airflow.providers.google.common.consts import CLIENT_INFO
-from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
+from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook
 from airflow.utils import timezone
 from airflow.version import version
 
@@ -1174,3 +1177,14 @@ def _parse_gcs_url(gsurl: str) -> tuple[str, str]:
     # Remove leading '/' but NOT trailing one
     blob = parsed_url.path.lstrip("/")
     return bucket, blob
+
+
+class GCSAsyncHook(GoogleBaseAsyncHook):
+    """GCSAsyncHook run on the trigger worker, inherits from GoogleBaseHookAsync"""
+
+    sync_hook_class = GCSHook
+
+    async def get_storage_client(self, session: ClientSession) -> Storage:
+        """Returns a Google Cloud Storage service object."""
+        with await self.service_file_as_context() as file:
+            return Storage(service_file=file, session=cast(Session, session))
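
A minimal sketch of driving the new hook from async code (assuming a configured
"google_cloud_default" connection; bucket and blob names are illustrative):

    import asyncio

    from aiohttp import ClientSession

    from airflow.providers.google.cloud.hooks.gcs import GCSAsyncHook

    async def blob_exists(bucket_name: str, blob_name: str) -> bool:
        # The async hook reuses the sync GCSHook's connection configuration.
        hook = GCSAsyncHook(gcp_conn_id="google_cloud_default")
        async with ClientSession() as session:
            client = await hook.get_storage_client(session)
            bucket = client.get_bucket(bucket_name)
            return await bucket.blob_exists(blob_name=blob_name)

    # asyncio.run(blob_exists("my-bucket", "data/input.csv"))
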
diff --git a/airflow/providers/google/cloud/sensors/gcs.py b/airflow/providers/google/cloud/sensors/gcs.py
index 8264a09a7b..a944488a06 100644
--- a/airflow/providers/google/cloud/sensors/gcs.py
+++ b/airflow/providers/google/cloud/sensors/gcs.py
@@ -20,7 +20,7 @@ from __future__ import annotations
 
 import os
 import textwrap
-from datetime import datetime
+from datetime import datetime, timedelta
 from typing import TYPE_CHECKING, Callable, Sequence
 
 from google.api_core.retry import Retry
@@ -28,6 +28,7 @@ from google.cloud.storage.retry import DEFAULT_RETRY
 
 from airflow.exceptions import AirflowException
 from airflow.providers.google.cloud.hooks.gcs import GCSHook
+from airflow.providers.google.cloud.triggers.gcs import GCSBlobTrigger
 from airflow.sensors.base import BaseSensorOperator, poke_mode_only
 
 if TYPE_CHECKING:
@@ -94,6 +95,55 @@ class GCSObjectExistenceSensor(BaseSensorOperator):
         return hook.exists(self.bucket, self.object, self.retry)
 
 
+class GCSObjectExistenceAsyncSensor(GCSObjectExistenceSensor):
+    """
+    Checks for the existence of a file in Google Cloud Storage.
+
+    :param bucket: The Google Cloud Storage bucket where the object is.
+    :param object: The name of the object to check in the Google Cloud Storage bucket.
+    :param google_cloud_conn_id: The connection ID to use when connecting to Google Cloud Storage.
+    :param delegate_to: The account to impersonate using domain-wide delegation of authority,
+        if any. For this to work, the service account making the request must have
+        domain-wide delegation enabled.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    """
+
+    def execute(self, context: Context) -> None:
+        """Airflow runs this method on the worker and defers using the trigger."""
+        self.defer(
+            timeout=timedelta(seconds=self.timeout),
+            trigger=GCSBlobTrigger(
+                bucket=self.bucket,
+                object_name=self.object,
+                poke_interval=self.poke_interval,
+                google_cloud_conn_id=self.google_cloud_conn_id,
+                hook_params={
+                    "delegate_to": self.delegate_to,
+                    "impersonation_chain": self.impersonation_chain,
+                },
+            ),
+            method_name="execute_complete",
+        )
+
+    def execute_complete(self, context: Context, event: dict[str, str]) -> str:
+        """
+        Callback for when the trigger fires; returns immediately.
+        Relies on the trigger to throw an exception, otherwise it assumes execution was
+        successful.
+        """
+        if event["status"] == "error":
+            raise AirflowException(event["message"])
+        self.log.info("File %s was found in bucket %s.", self.object, self.bucket)
+        return event["message"]
+
+
 def ts_function(context):
     """
     Default callback for the GoogleCloudStorageObjectUpdatedSensor. The default
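
The deferral handshake above follows Airflow's standard pattern: execute()
raises TaskDeferred, the triggerer runs the trigger, and the resulting event
payload is handed back to execute_complete. A test-style sketch of that round
trip (names are illustrative; no GCP access is needed because execute() defers
before touching the bucket):

    import pytest

    from airflow.exceptions import TaskDeferred
    from airflow.providers.google.cloud.sensors.gcs import GCSObjectExistenceAsyncSensor

    sensor = GCSObjectExistenceAsyncSensor(
        task_id="wait_for_file",
        bucket="my-bucket",
        object="data/input.csv",
    )

    # execute() never returns normally; it raises TaskDeferred carrying the
    # trigger, a timeout, and the name of the method to resume with.
    with pytest.raises(TaskDeferred) as deferral:
        sensor.execute(context={})
    assert deferral.value.method_name == "execute_complete"

    # When the trigger fires, Airflow calls execute_complete with the event;
    # an "error" status raises, a "success" status returns the message.
    result = sensor.execute_complete(
        context={}, event={"status": "success", "message": "success"}
    )
    assert result == "success"
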
diff --git a/airflow/providers/google/cloud/triggers/gcs.py b/airflow/providers/google/cloud/triggers/gcs.py
new file mode 100644
index 0000000000..32ca257eae
--- /dev/null
+++ b/airflow/providers/google/cloud/triggers/gcs.py
@@ -0,0 +1,99 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+from typing import Any, AsyncIterator
+
+from aiohttp import ClientSession
+
+from airflow.providers.google.cloud.hooks.gcs import GCSAsyncHook
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class GCSBlobTrigger(BaseTrigger):
+    """
+    A trigger that fires when it finds the requested file or folder in the given bucket.
+
+    :param bucket: the Google Cloud Storage bucket where the object resides
+    :param object_name: the file or folder to look for in the bucket
+    :param google_cloud_conn_id: reference to the Google Cloud connection
+    :param poke_interval: polling period in seconds to check for the file/folder
+    :param hook_params: extra keyword arguments passed through to the hook
+    """
+
+    def __init__(
+        self,
+        bucket: str,
+        object_name: str,
+        poke_interval: float,
+        google_cloud_conn_id: str,
+        hook_params: dict[str, Any],
+    ):
+        super().__init__()
+        self.bucket = bucket
+        self.object_name = object_name
+        self.poke_interval = poke_interval
+        self.google_cloud_conn_id: str = google_cloud_conn_id
+        self.hook_params = hook_params
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Serializes GCSBlobTrigger arguments and classpath."""
+        return (
+            "airflow.providers.google.cloud.triggers.gcs.GCSBlobTrigger",
+            {
+                "bucket": self.bucket,
+                "object_name": self.object_name,
+                "poke_interval": self.poke_interval,
+                "google_cloud_conn_id": self.google_cloud_conn_id,
+                "hook_params": self.hook_params,
+            },
+        )
+
+    async def run(self) -> AsyncIterator["TriggerEvent"]:
+        """Simple loop until the relevant file/folder is found."""
+        try:
+            hook = self._get_async_hook()
+            while True:
+                res = await self._object_exists(
+                    hook=hook, bucket_name=self.bucket, object_name=self.object_name
+                )
+                if res == "success":
+                    yield TriggerEvent({"status": "success", "message": res})
+                    return
+                await asyncio.sleep(self.poke_interval)
+        except Exception as e:
+            yield TriggerEvent({"status": "error", "message": str(e)})
+            return
+
+    def _get_async_hook(self) -> GCSAsyncHook:
+        return GCSAsyncHook(gcp_conn_id=self.google_cloud_conn_id, **self.hook_params)
+
+    async def _object_exists(self, hook: GCSAsyncHook, bucket_name: str, object_name: str) -> str:
+        """
+        Checks for the existence of a file in Google Cloud Storage.
+
+        :param bucket_name: The Google Cloud Storage bucket where the object is.
+        :param object_name: The name of the blob to check in the Google Cloud
+            Storage bucket.
+        """
+        async with ClientSession() as s:
+            client = await hook.get_storage_client(s)
+            bucket = client.get_bucket(bucket_name)
+            object_response = await bucket.blob_exists(blob_name=object_name)
+            if object_response:
+                return "success"
+            return "pending"
diff --git a/docs/apache-airflow-providers-google/operators/cloud/gcs.rst b/docs/apache-airflow-providers-google/operators/cloud/gcs.rst
index 2a21a37cb5..6f60cc5452 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/gcs.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/gcs.rst
@@ -188,6 +188,22 @@ Use the :class:`~airflow.providers.google.cloud.sensors.gcs.GCSObjectExistenceSe
     :start-after: [START howto_sensor_object_exists_task]
     :end-before: [END howto_sensor_object_exists_task]
 
+
+.. _howto/sensor:GCSObjectExistenceAsyncSensor:
+
+GCSObjectExistenceAsyncSensor
+-----------------------------
+
+Use the :class:`~airflow.providers.google.cloud.sensors.gcs.GCSObjectExistenceAsyncSensor`
+(deferrable version) if you would like to free up the worker slots while the sensor is running.
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/gcs/example_gcs_sensor.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_sensor_object_exists_task_async]
+    :end-before: [END howto_sensor_object_exists_task_async]
+
+
 .. _howto/sensor:GCSObjectsWithPrefixExistenceSensor:
 
 GCSObjectsWithPrefixExistenceSensor
diff --git a/tests/providers/google/cloud/sensors/test_gcs.py b/tests/providers/google/cloud/sensors/test_gcs.py
index bc242ce509..bf586dfe1c 100644
--- a/tests/providers/google/cloud/sensors/test_gcs.py
+++ b/tests/providers/google/cloud/sensors/test_gcs.py
@@ -24,15 +24,17 @@ import pendulum
 import pytest
 from google.cloud.storage.retry import DEFAULT_RETRY
 
-from airflow.exceptions import AirflowSensorTimeout
+from airflow.exceptions import AirflowSensorTimeout, TaskDeferred
 from airflow.models.dag import DAG, AirflowException
 from airflow.providers.google.cloud.sensors.gcs import (
+    GCSObjectExistenceAsyncSensor,
     GCSObjectExistenceSensor,
     GCSObjectsWithPrefixExistenceSensor,
     GCSObjectUpdateSensor,
     GCSUploadSessionCompleteSensor,
     ts_function,
 )
+from airflow.providers.google.cloud.triggers.gcs import GCSBlobTrigger
 
 TEST_BUCKET = "TEST_BUCKET"
 
@@ -53,6 +55,15 @@ DEFAULT_DATE = datetime(2015, 1, 1)
 MOCK_DATE_ARRAY = [datetime(2019, 2, 24, 12, 0, 0) - i * timedelta(seconds=10) for i in range(25)]
 
 
+@pytest.fixture()
+def context():
+    """
+    Creates a minimal test context with a data_interval_end.
+    """
+    context = {"data_interval_end": datetime.utcnow()}
+    yield context
+
+
 def next_time_side_effect():
     """
     Each time this is called, mock a time 10 seconds later
@@ -88,6 +99,46 @@ class TestGoogleCloudStorageObjectSensor(TestCase):
         mock_hook.return_value.exists.assert_called_once_with(TEST_BUCKET, TEST_OBJECT, DEFAULT_RETRY)
 
 
+class TestGoogleCloudStorageObjectSensorAsync(TestCase):
+    def test_gcs_object_existence_sensor_async(self):
+        """
+        Asserts that a task is deferred and a GCSBlobTrigger will be fired
+        when the GCSObjectExistenceAsyncSensor is executed.
+        """
+        task = GCSObjectExistenceAsyncSensor(
+            task_id="task-id",
+            bucket=TEST_BUCKET,
+            object=TEST_OBJECT,
+            google_cloud_conn_id=TEST_GCP_CONN_ID,
+        )
+        with pytest.raises(TaskDeferred) as exc:
+            task.execute(context={})
+        assert isinstance(exc.value.trigger, GCSBlobTrigger), "Trigger is not a GCSBlobTrigger"
+
+    def test_gcs_object_existence_sensor_async_execute_failure(self):
+        """Tests that an AirflowException is raised in case of error event"""
+        task = GCSObjectExistenceAsyncSensor(
+            task_id="task-id",
+            bucket=TEST_BUCKET,
+            object=TEST_OBJECT,
+            google_cloud_conn_id=TEST_GCP_CONN_ID,
+        )
+        with pytest.raises(AirflowException):
+            task.execute_complete(context=None, event={"status": "error", "message": "test failure message"})
+
+    def test_gcs_object_existence_sensor_async_execute_complete(self):
+        """Asserts that logging occurs as expected"""
+        task = GCSObjectExistenceAsyncSensor(
+            task_id="task-id",
+            bucket=TEST_BUCKET,
+            object=TEST_OBJECT,
+            google_cloud_conn_id=TEST_GCP_CONN_ID,
+        )
+        with mock.patch.object(task.log, "info") as mock_log_info:
+            task.execute_complete(context=None, event={"status": "success", "message": "Job completed"})
+        mock_log_info.assert_called_with("File %s was found in bucket %s.", TEST_OBJECT, TEST_BUCKET)
+
+
 class TestTsFunction(TestCase):
     def test_should_support_datetime(self):
         context = {
diff --git a/tests/providers/google/cloud/triggers/test_gcs.py b/tests/providers/google/cloud/triggers/test_gcs.py
new file mode 100644
index 0000000000..f7d735d1cf
--- /dev/null
+++ b/tests/providers/google/cloud/triggers/test_gcs.py
@@ -0,0 +1,159 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import asyncio
+import sys
+
+import pytest
+from gcloud.aio.storage import Bucket, Storage
+
+from airflow.providers.google.cloud.hooks.gcs import GCSAsyncHook
+from airflow.providers.google.cloud.triggers.gcs import GCSBlobTrigger
+from airflow.triggers.base import TriggerEvent
+
+if sys.version_info < (3, 8):
+    from asynctest import mock
+    from asynctest.mock import CoroutineMock as AsyncMock
+else:
+    from unittest import mock
+    from unittest.mock import AsyncMock
+
+TEST_BUCKET = "TEST_BUCKET"
+TEST_OBJECT = "TEST_OBJECT"
+TEST_PREFIX = "TEST_PREFIX"
+TEST_GCP_CONN_ID = "TEST_GCP_CONN_ID"
+TEST_POLLING_INTERVAL = 3.0
+TEST_HOOK_PARAMS = {}
+
+
+def test_gcs_blob_trigger_serialization():
+    """
+    Asserts that the GCSBlobTrigger correctly serializes its arguments
+    and classpath.
+    """
+    trigger = GCSBlobTrigger(
+        TEST_BUCKET,
+        TEST_OBJECT,
+        TEST_POLLING_INTERVAL,
+        TEST_GCP_CONN_ID,
+        TEST_HOOK_PARAMS,
+    )
+    classpath, kwargs = trigger.serialize()
+    assert classpath == "airflow.providers.google.cloud.triggers.gcs.GCSBlobTrigger"
+    assert kwargs == {
+        "bucket": TEST_BUCKET,
+        "object_name": TEST_OBJECT,
+        "poke_interval": TEST_POLLING_INTERVAL,
+        "google_cloud_conn_id": TEST_GCP_CONN_ID,
+        "hook_params": TEST_HOOK_PARAMS,
+    }
+
+
+@pytest.mark.asyncio
+@mock.patch("airflow.providers.google.cloud.triggers.gcs.GCSBlobTrigger._object_exists")
+async def test_gcs_blob_trigger_success(mock_object_exists):
+    """
+    Tests the success case for the GCSBlobTrigger.
+    """
+    mock_object_exists.return_value = "success"
+
+    trigger = GCSBlobTrigger(
+        TEST_BUCKET,
+        TEST_OBJECT,
+        TEST_POLLING_INTERVAL,
+        TEST_GCP_CONN_ID,
+        TEST_HOOK_PARAMS,
+    )
+
+    generator = trigger.run()
+    actual = await generator.asend(None)
+    assert TriggerEvent({"status": "success", "message": "success"}) == actual
+
+
+@pytest.mark.asyncio
+@mock.patch("airflow.providers.google.cloud.triggers.gcs.GCSBlobTrigger._object_exists")
+async def test_gcs_blob_trigger_pending(mock_object_exists):
+    """
+    Tests that the GCSBlobTrigger stays in its loop if the file isn't found.
+    """
+    mock_object_exists.return_value = "pending"
+
+    trigger = GCSBlobTrigger(
+        TEST_BUCKET,
+        TEST_OBJECT,
+        TEST_POLLING_INTERVAL,
+        TEST_GCP_CONN_ID,
+        TEST_HOOK_PARAMS,
+    )
+    task = asyncio.create_task(trigger.run().__anext__())
+    await asyncio.sleep(0.5)
+
+    # TriggerEvent was not returned
+    assert task.done() is False
+    asyncio.get_event_loop().stop()
+
+
+@pytest.mark.asyncio
+@mock.patch("airflow.providers.google.cloud.triggers.gcs.GCSBlobTrigger._object_exists")
+async def test_gcs_blob_trigger_exception(mock_object_exists):
+    """
+    Tests that the GCSBlobTrigger fires an error event if there is an exception.
+    """
+    mock_object_exists.side_effect = AsyncMock(side_effect=Exception("Test exception"))
+    trigger = GCSBlobTrigger(
+        bucket=TEST_BUCKET,
+        object_name=TEST_OBJECT,
+        poke_interval=TEST_POLLING_INTERVAL,
+        google_cloud_conn_id=TEST_GCP_CONN_ID,
+        hook_params=TEST_HOOK_PARAMS,
+    )
+    task = [i async for i in trigger.run()]
+    assert len(task) == 1
+    assert TriggerEvent({"status": "error", "message": "Test exception"}) in task
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "exists,response",
+    [
+        (True, "success"),
+        (False, "pending"),
+    ],
+)
+async def test_object_exists(exists, response):
+    """
+    Tests that _object_exists correctly reports whether a particular
+    object in Google Cloud Storage is found.
+    """
+    hook = AsyncMock(GCSAsyncHook)
+    storage = AsyncMock(Storage)
+    hook.get_storage_client.return_value = storage
+    bucket = AsyncMock(Bucket)
+    storage.get_bucket.return_value = bucket
+    bucket.blob_exists.return_value = exists
+    trigger = GCSBlobTrigger(
+        bucket=TEST_BUCKET,
+        object_name=TEST_OBJECT,
+        poke_interval=TEST_POLLING_INTERVAL,
+        google_cloud_conn_id=TEST_GCP_CONN_ID,
+        hook_params=TEST_HOOK_PARAMS,
+    )
+    res = await trigger._object_exists(hook, TEST_BUCKET, TEST_OBJECT)
+    assert res == response
+    bucket.blob_exists.assert_called_once_with(blob_name=TEST_OBJECT)
diff --git a/tests/system/providers/google/cloud/gcs/example_gcs_sensor.py b/tests/system/providers/google/cloud/gcs/example_gcs_sensor.py
index aff018d31d..ef48fa28db 100644
--- a/tests/system/providers/google/cloud/gcs/example_gcs_sensor.py
+++ b/tests/system/providers/google/cloud/gcs/example_gcs_sensor.py
@@ -29,6 +29,7 @@ from airflow.models.baseoperator import chain
 from airflow.operators.bash import BashOperator
 from airflow.providers.google.cloud.operators.gcs import GCSCreateBucketOperator, GCSDeleteBucketOperator
 from airflow.providers.google.cloud.sensors.gcs import (
+    GCSObjectExistenceAsyncSensor,
     GCSObjectExistenceSensor,
     GCSObjectsWithPrefixExistenceSensor,
     GCSObjectUpdateSensor,
@@ -115,6 +116,14 @@ with models.DAG(
     )
     # [END howto_sensor_object_exists_task]
 
+    # [START howto_sensor_object_exists_task_async]
+    gcs_object_exists_async = GCSObjectExistenceAsyncSensor(
+        bucket=BUCKET_NAME,
+        object=FILE_NAME,
+        task_id="gcs_object_exists_task_async",
+    )
+    # [END howto_sensor_object_exists_task_async]
+
     # [START howto_sensor_object_with_prefix_exists_task]
     gcs_object_with_prefix_exists = GCSObjectsWithPrefixExistenceSensor(
         bucket=BUCKET_NAME,
@@ -135,7 +144,7 @@ with models.DAG(
         sleep,
         upload_file,
         # TEST BODY
-        [gcs_object_exists, gcs_object_with_prefix_exists],
+        [gcs_object_exists, gcs_object_exists_async, gcs_object_with_prefix_exists],
         # TEST TEARDOWN
         delete_bucket,
     )