You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ka...@apache.org on 2023/01/10 14:31:27 UTC
[airflow] branch main updated: Add deferrable ``GCSObjectExistenceSensorAsync`` (#28763)
This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 284cd52989 Add deferrable ``GCSObjectExistenceSensorAsync`` (#28763)
284cd52989 is described below
commit 284cd529898fbadd14308004a0b0cb6f389b4318
Author: Rajath <92...@users.noreply.github.com>
AuthorDate: Tue Jan 10 20:01:16 2023 +0530
Add deferrable ``GCSObjectExistenceSensorAsync`` (#28763)
This PR donates `GCSObjectExistenceSensorAsync`, developed in the [astronomer-providers](https://github.com/astronomer/astronomer-providers) repo, to Apache Airflow.
`GCSObjectExistenceSensorAsync`
---
airflow/providers/google/cloud/hooks/gcs.py | 16 ++-
airflow/providers/google/cloud/sensors/gcs.py | 52 ++++++-
airflow/providers/google/cloud/triggers/gcs.py | 99 +++++++++++++
.../operators/cloud/gcs.rst | 16 +++
tests/providers/google/cloud/sensors/test_gcs.py | 53 ++++++-
tests/providers/google/cloud/triggers/test_gcs.py | 159 +++++++++++++++++++++
.../google/cloud/gcs/example_gcs_sensor.py | 11 +-
7 files changed, 402 insertions(+), 4 deletions(-)
diff --git a/airflow/providers/google/cloud/hooks/gcs.py b/airflow/providers/google/cloud/hooks/gcs.py
index 68e1c74986..6976724873 100644
--- a/airflow/providers/google/cloud/hooks/gcs.py
+++ b/airflow/providers/google/cloud/hooks/gcs.py
@@ -32,6 +32,8 @@ from tempfile import NamedTemporaryFile
from typing import IO, Callable, Generator, Sequence, TypeVar, cast, overload
from urllib.parse import urlsplit
+from aiohttp import ClientSession
+from gcloud.aio.storage import Storage
from google.api_core.exceptions import NotFound
from google.api_core.retry import Retry
@@ -39,11 +41,12 @@ from google.api_core.retry import Retry
from google.cloud import storage # type: ignore[attr-defined]
from google.cloud.exceptions import GoogleCloudError
from google.cloud.storage.retry import DEFAULT_RETRY
+from requests import Session
from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.utils.helpers import normalize_directory_path
from airflow.providers.google.common.consts import CLIENT_INFO
-from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
+from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook
from airflow.utils import timezone
from airflow.version import version
@@ -1174,3 +1177,14 @@ def _parse_gcs_url(gsurl: str) -> tuple[str, str]:
# Remove leading '/' but NOT trailing one
blob = parsed_url.path.lstrip("/")
return bucket, blob
+
+
+class GCSAsyncHook(GoogleBaseAsyncHook):
+ """GCSAsyncHook runs on the trigger worker and inherits from GoogleBaseAsyncHook."""
+
+ sync_hook_class = GCSHook
+
+ async def get_storage_client(self, session: ClientSession) -> Storage:
+ """Returns a gcloud-aio ``Storage`` client built from the service file and the given session."""
+ with await self.service_file_as_context() as file:
+ return Storage(service_file=file, session=cast(Session, session))
diff --git a/airflow/providers/google/cloud/sensors/gcs.py b/airflow/providers/google/cloud/sensors/gcs.py
index 8264a09a7b..a944488a06 100644
--- a/airflow/providers/google/cloud/sensors/gcs.py
+++ b/airflow/providers/google/cloud/sensors/gcs.py
@@ -20,7 +20,7 @@ from __future__ import annotations
import os
import textwrap
-from datetime import datetime
+from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Callable, Sequence
from google.api_core.retry import Retry
@@ -28,6 +28,7 @@ from google.cloud.storage.retry import DEFAULT_RETRY
from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.gcs import GCSHook
+from airflow.providers.google.cloud.triggers.gcs import GCSBlobTrigger
from airflow.sensors.base import BaseSensorOperator, poke_mode_only
if TYPE_CHECKING:
@@ -94,6 +95,55 @@ class GCSObjectExistenceSensor(BaseSensorOperator):
return hook.exists(self.bucket, self.object, self.retry)
+class GCSObjectExistenceAsyncSensor(GCSObjectExistenceSensor):
+ """
+ Checks for the existence of a file in Google Cloud Storage, deferring to ``GCSBlobTrigger``.
+
+ :param bucket: The Google Cloud Storage bucket where the object is.
+ :param object: The name of the object to check in the Google cloud storage bucket.
+ :param google_cloud_conn_id: The connection ID to use when connecting to Google Cloud Storage.
+ :param delegate_to: The account to impersonate using domain-wide delegation of authority,
+ if any. For this to work, the service account making the request must have
+ domain-wide delegation enabled.
+ :param impersonation_chain: Optional service account to impersonate using short-term
+ credentials, or chained list of accounts required to get the access_token
+ of the last account in the list, which will be impersonated in the request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding identity, with first
+ account from the list granting this role to the originating account (templated).
+ """
+
+ def execute(self, context: Context) -> None:
+ """Defers the task to a ``GCSBlobTrigger``; resumes in ``execute_complete``."""
+ self.defer(
+ timeout=timedelta(seconds=self.timeout),
+ trigger=GCSBlobTrigger(
+ bucket=self.bucket,
+ object_name=self.object,
+ poke_interval=self.poke_interval,
+ google_cloud_conn_id=self.google_cloud_conn_id,
+ hook_params={
+ "delegate_to": self.delegate_to,
+ "impersonation_chain": self.impersonation_chain,
+ },
+ ),
+ method_name="execute_complete",
+ )
+
+ def execute_complete(self, context: Context, event: dict[str, str]) -> str:
+ """
+ Callback for when the trigger fires - returns immediately.
+ Raises AirflowException with the event message when the event status is "error";
+ otherwise logs that the file was found and returns the event message.
+ """
+ if event["status"] == "error":
+ raise AirflowException(event["message"])
+ self.log.info("File %s was found in bucket %s.", self.object, self.bucket)
+ return event["message"]
+
+
def ts_function(context):
"""
Default callback for the GoogleCloudStorageObjectUpdatedSensor. The default
diff --git a/airflow/providers/google/cloud/triggers/gcs.py b/airflow/providers/google/cloud/triggers/gcs.py
new file mode 100644
index 0000000000..32ca257eae
--- /dev/null
+++ b/airflow/providers/google/cloud/triggers/gcs.py
@@ -0,0 +1,99 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+from typing import Any, AsyncIterator
+
+from aiohttp import ClientSession
+
+from airflow.providers.google.cloud.hooks.gcs import GCSAsyncHook
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class GCSBlobTrigger(BaseTrigger):
+    """
+    A trigger that fires when it finds the requested file or folder present in the given bucket.
+
+    :param bucket: the bucket in the google cloud storage where the objects are residing.
+    :param object_name: the file or folder present in the bucket
+    :param google_cloud_conn_id: reference to the Google Connection
+    :param poke_interval: polling period in seconds to check for file/folder
+    :param hook_params: extra keyword arguments passed to ``GCSAsyncHook``
+    """
+
+    def __init__(
+        self,
+        bucket: str,
+        object_name: str,
+        poke_interval: float,
+        google_cloud_conn_id: str,
+        hook_params: dict[str, Any],
+    ):
+        super().__init__()
+        self.bucket = bucket
+        self.object_name = object_name
+        self.poke_interval = poke_interval
+        self.google_cloud_conn_id: str = google_cloud_conn_id
+        self.hook_params = hook_params
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Serializes GCSBlobTrigger arguments and classpath."""
+        return (
+            "airflow.providers.google.cloud.triggers.gcs.GCSBlobTrigger",
+            {
+                "bucket": self.bucket,
+                "object_name": self.object_name,
+                "poke_interval": self.poke_interval,
+                "google_cloud_conn_id": self.google_cloud_conn_id,
+                "hook_params": self.hook_params,
+            },
+        )
+
+    async def run(self) -> AsyncIterator["TriggerEvent"]:
+        """
+        Polls every ``poke_interval`` seconds until the object is found, then emits one event.
+
+        Yields a "success" event once the object exists, or an "error" event carrying the
+        exception text if anything raises while polling. The generator ends after either event.
+        """
+        try:
+            hook = self._get_async_hook()
+            while True:
+                res = await self._object_exists(
+                    hook=hook, bucket_name=self.bucket, object_name=self.object_name
+                )
+                if res == "success":
+                    yield TriggerEvent({"status": "success", "message": res})
+                    # Stop here; otherwise the loop would sleep and poll again after
+                    # the success event has already been emitted.
+                    return
+                await asyncio.sleep(self.poke_interval)
+        except Exception as e:
+            yield TriggerEvent({"status": "error", "message": str(e)})
+            return
+
+    def _get_async_hook(self) -> GCSAsyncHook:
+        """Builds the async hook from the connection id and ``hook_params``."""
+        return GCSAsyncHook(gcp_conn_id=self.google_cloud_conn_id, **self.hook_params)
+
+    async def _object_exists(self, hook: GCSAsyncHook, bucket_name: str, object_name: str) -> str:
+        """
+        Checks for the existence of a file in Google Cloud Storage.
+
+        :param hook: The async hook used to obtain the storage client.
+        :param bucket_name: The Google Cloud Storage bucket where the object is.
+        :param object_name: The name of the blob_name to check in the Google cloud
+            storage bucket.
+        :return: "success" if the blob exists, otherwise "pending".
+        """
+        async with ClientSession() as s:
+            client = await hook.get_storage_client(s)
+            bucket = client.get_bucket(bucket_name)
+            object_response = await bucket.blob_exists(blob_name=object_name)
+            if object_response:
+                return "success"
+            return "pending"
diff --git a/docs/apache-airflow-providers-google/operators/cloud/gcs.rst b/docs/apache-airflow-providers-google/operators/cloud/gcs.rst
index 2a21a37cb5..6f60cc5452 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/gcs.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/gcs.rst
@@ -188,6 +188,22 @@ Use the :class:`~airflow.providers.google.cloud.sensors.gcs.GCSObjectExistenceSe
:start-after: [START howto_sensor_object_exists_task]
:end-before: [END howto_sensor_object_exists_task]
+
+.. _howto/sensor:GCSObjectExistenceAsyncSensor:
+
+GCSObjectExistenceAsyncSensor
+-----------------------------
+
+Use the :class:`~airflow.providers.google.cloud.sensors.gcs.GCSObjectExistenceAsyncSensor`
+(deferrable version) if you would like to free up the worker slots while the sensor is running.
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/gcs/example_gcs_sensor.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_sensor_object_exists_task_async]
+ :end-before: [END howto_sensor_object_exists_task_async]
+
+
.. _howto/sensor:GCSObjectsWithPrefixExistenceSensor:
GCSObjectsWithPrefixExistenceSensor
diff --git a/tests/providers/google/cloud/sensors/test_gcs.py b/tests/providers/google/cloud/sensors/test_gcs.py
index bc242ce509..bf586dfe1c 100644
--- a/tests/providers/google/cloud/sensors/test_gcs.py
+++ b/tests/providers/google/cloud/sensors/test_gcs.py
@@ -24,15 +24,17 @@ import pendulum
import pytest
from google.cloud.storage.retry import DEFAULT_RETRY
-from airflow.exceptions import AirflowSensorTimeout
+from airflow.exceptions import AirflowSensorTimeout, TaskDeferred
from airflow.models.dag import DAG, AirflowException
from airflow.providers.google.cloud.sensors.gcs import (
+ GCSObjectExistenceAsyncSensor,
GCSObjectExistenceSensor,
GCSObjectsWithPrefixExistenceSensor,
GCSObjectUpdateSensor,
GCSUploadSessionCompleteSensor,
ts_function,
)
+from airflow.providers.google.cloud.triggers.gcs import GCSBlobTrigger
TEST_BUCKET = "TEST_BUCKET"
@@ -53,6 +55,15 @@ DEFAULT_DATE = datetime(2015, 1, 1)
MOCK_DATE_ARRAY = [datetime(2019, 2, 24, 12, 0, 0) - i * timedelta(seconds=10) for i in range(25)]
+@pytest.fixture()
+def context():
+ """
+ Creates a minimal context containing only ``data_interval_end``.
+ """
+ context = {"data_interval_end": datetime.utcnow()}
+ yield context
+
+
def next_time_side_effect():
"""
This each time this is called mock a time 10 seconds later
@@ -88,6 +99,46 @@ class TestGoogleCloudStorageObjectSensor(TestCase):
mock_hook.return_value.exists.assert_called_once_with(TEST_BUCKET, TEST_OBJECT, DEFAULT_RETRY)
+class TestGoogleCloudStorageObjectSensorAsync(TestCase):
+ def test_gcs_object_existence_sensor_async(self):
+ """
+ Asserts that a task is deferred and a GCSBlobTrigger will be fired
+ when the GCSObjectExistenceAsyncSensor is executed.
+ """
+ task = GCSObjectExistenceAsyncSensor(
+ task_id="task-id",
+ bucket=TEST_BUCKET,
+ object=TEST_OBJECT,
+ google_cloud_conn_id=TEST_GCP_CONN_ID,
+ )
+ with pytest.raises(TaskDeferred) as exc:
+ # NOTE(review): this method takes no ``context`` argument, so the name resolves to the
+ # module-level pytest fixture object rather than a context dict - confirm this is intended.
+ task.execute(context)
+ assert isinstance(exc.value.trigger, GCSBlobTrigger), "Trigger is not a GCSBlobTrigger"
+
+ def test_gcs_object_existence_sensor_async_execute_failure(self):
+ """Tests that an AirflowException is raised in case of error event"""
+ task = GCSObjectExistenceAsyncSensor(
+ task_id="task-id",
+ bucket=TEST_BUCKET,
+ object=TEST_OBJECT,
+ google_cloud_conn_id=TEST_GCP_CONN_ID,
+ )
+ with pytest.raises(AirflowException):
+ task.execute_complete(context=None, event={"status": "error", "message": "test failure message"})
+
+ def test_gcs_object_existence_sensor_async_execute_complete(self):
+ """Asserts that the "file found" message is logged on a success event"""
+ task = GCSObjectExistenceAsyncSensor(
+ task_id="task-id",
+ bucket=TEST_BUCKET,
+ object=TEST_OBJECT,
+ google_cloud_conn_id=TEST_GCP_CONN_ID,
+ )
+ with mock.patch.object(task.log, "info") as mock_log_info:
+ task.execute_complete(context=None, event={"status": "success", "message": "Job completed"})
+ mock_log_info.assert_called_with("File %s was found in bucket %s.", TEST_OBJECT, TEST_BUCKET)
+
+
class TestTsFunction(TestCase):
def test_should_support_datetime(self):
context = {
diff --git a/tests/providers/google/cloud/triggers/test_gcs.py b/tests/providers/google/cloud/triggers/test_gcs.py
new file mode 100644
index 0000000000..f7d735d1cf
--- /dev/null
+++ b/tests/providers/google/cloud/triggers/test_gcs.py
@@ -0,0 +1,159 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import asyncio
+import sys
+
+import pytest
+from gcloud.aio.storage import Bucket, Storage
+
+from airflow.providers.google.cloud.hooks.gcs import GCSAsyncHook
+from airflow.providers.google.cloud.triggers.gcs import GCSBlobTrigger
+from airflow.triggers.base import TriggerEvent
+
+if sys.version_info < (3, 8):
+ from asynctest import mock
+ from asynctest.mock import CoroutineMock as AsyncMock
+else:
+ from unittest import mock
+ from unittest.mock import AsyncMock
+
+TEST_BUCKET = "TEST_BUCKET"
+TEST_OBJECT = "TEST_OBJECT"
+TEST_PREFIX = "TEST_PREFIX"
+TEST_GCP_CONN_ID = "TEST_GCP_CONN_ID"
+TEST_POLLING_INTERVAL = 3.0
+TEST_HOOK_PARAMS = {}
+
+
+def test_gcs_blob_trigger_serialization():
+ """
+ Asserts that the GCSBlobTrigger correctly serializes its arguments
+ and classpath.
+ """
+ # Positional order: bucket, object_name, poke_interval, google_cloud_conn_id, hook_params.
+ trigger = GCSBlobTrigger(
+ TEST_BUCKET,
+ TEST_OBJECT,
+ TEST_POLLING_INTERVAL,
+ TEST_GCP_CONN_ID,
+ TEST_HOOK_PARAMS,
+ )
+ classpath, kwargs = trigger.serialize()
+ assert classpath == "airflow.providers.google.cloud.triggers.gcs.GCSBlobTrigger"
+ assert kwargs == {
+ "bucket": TEST_BUCKET,
+ "object_name": TEST_OBJECT,
+ "poke_interval": TEST_POLLING_INTERVAL,
+ "google_cloud_conn_id": TEST_GCP_CONN_ID,
+ "hook_params": TEST_HOOK_PARAMS,
+ }
+
+
+@pytest.mark.asyncio
+@mock.patch("airflow.providers.google.cloud.triggers.gcs.GCSBlobTrigger._object_exists")
+async def test_gcs_blob_trigger_success(mock_object_exists):
+ """
+ Tests that the GCSBlobTrigger yields a success event when the object is reported as found.
+ """
+ mock_object_exists.return_value = "success"
+
+ trigger = GCSBlobTrigger(
+ TEST_BUCKET,
+ TEST_OBJECT,
+ TEST_POLLING_INTERVAL,
+ TEST_GCP_CONN_ID,
+ TEST_HOOK_PARAMS,
+ )
+
+ # Pull only the first event from the async generator.
+ generator = trigger.run()
+ actual = await generator.asend(None)
+ assert TriggerEvent({"status": "success", "message": "success"}) == actual
+
+
+@pytest.mark.asyncio
+@mock.patch("airflow.providers.google.cloud.triggers.gcs.GCSBlobTrigger._object_exists")
+async def test_gcs_blob_trigger_pending(mock_object_exists):
+ """
+ Tests that GCSBlobTrigger keeps polling and emits no event while the file isn't found.
+ """
+ mock_object_exists.return_value = "pending"
+
+ trigger = GCSBlobTrigger(
+ TEST_BUCKET,
+ TEST_OBJECT,
+ TEST_POLLING_INTERVAL,
+ TEST_GCP_CONN_ID,
+ TEST_HOOK_PARAMS,
+ )
+ task = asyncio.create_task(trigger.run().__anext__())
+ # Wait less than TEST_POLLING_INTERVAL so the trigger is still inside its first sleep.
+ await asyncio.sleep(0.5)
+
+ # TriggerEvent was not returned
+ assert task.done() is False
+ asyncio.get_event_loop().stop()
+
+
+@pytest.mark.asyncio
+@mock.patch("airflow.providers.google.cloud.triggers.gcs.GCSBlobTrigger._object_exists")
+async def test_gcs_blob_trigger_exception(mock_object_exists):
+ """
+ Tests that the GCSBlobTrigger yields exactly one error event when the existence check raises.
+ """
+ mock_object_exists.side_effect = AsyncMock(side_effect=Exception("Test exception"))
+ trigger = GCSBlobTrigger(
+ bucket=TEST_BUCKET,
+ object_name=TEST_OBJECT,
+ poke_interval=TEST_POLLING_INTERVAL,
+ google_cloud_conn_id=TEST_GCP_CONN_ID,
+ hook_params=TEST_HOOK_PARAMS,
+ )
+ # Drain the generator; it must finish after emitting the error event.
+ task = [i async for i in trigger.run()]
+ assert len(task) == 1
+ assert TriggerEvent({"status": "error", "message": "Test exception"}) in task
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "exists,response",
+ [
+ (True, "success"),
+ (False, "pending"),
+ ],
+)
+async def test_object_exists(exists, response):
+ """
+ Tests that ``_object_exists`` maps the result of ``blob_exists`` to
+ "success" (found) or "pending" (not found).
+ """
+ # Mock the hook -> storage client -> bucket chain used by ``_object_exists``.
+ hook = AsyncMock(GCSAsyncHook)
+ storage = AsyncMock(Storage)
+ hook.get_storage_client.return_value = storage
+ bucket = AsyncMock(Bucket)
+ storage.get_bucket.return_value = bucket
+ bucket.blob_exists.return_value = exists
+ trigger = GCSBlobTrigger(
+ bucket=TEST_BUCKET,
+ object_name=TEST_OBJECT,
+ poke_interval=TEST_POLLING_INTERVAL,
+ google_cloud_conn_id=TEST_GCP_CONN_ID,
+ hook_params=TEST_HOOK_PARAMS,
+ )
+ res = await trigger._object_exists(hook, TEST_BUCKET, TEST_OBJECT)
+ assert res == response
+ bucket.blob_exists.assert_called_once_with(blob_name=TEST_OBJECT)
diff --git a/tests/system/providers/google/cloud/gcs/example_gcs_sensor.py b/tests/system/providers/google/cloud/gcs/example_gcs_sensor.py
index aff018d31d..ef48fa28db 100644
--- a/tests/system/providers/google/cloud/gcs/example_gcs_sensor.py
+++ b/tests/system/providers/google/cloud/gcs/example_gcs_sensor.py
@@ -29,6 +29,7 @@ from airflow.models.baseoperator import chain
from airflow.operators.bash import BashOperator
from airflow.providers.google.cloud.operators.gcs import GCSCreateBucketOperator, GCSDeleteBucketOperator
from airflow.providers.google.cloud.sensors.gcs import (
+ GCSObjectExistenceAsyncSensor,
GCSObjectExistenceSensor,
GCSObjectsWithPrefixExistenceSensor,
GCSObjectUpdateSensor,
@@ -115,6 +116,14 @@ with models.DAG(
)
# [END howto_sensor_object_exists_task]
+ # [START howto_sensor_object_exists_task_async]
+ gcs_object_exists_async = GCSObjectExistenceAsyncSensor(
+ bucket=BUCKET_NAME,
+ object=FILE_NAME,
+ task_id="gcs_object_exists_task_async",
+ )
+ # [END howto_sensor_object_exists_task_async]
+
# [START howto_sensor_object_with_prefix_exists_task]
gcs_object_with_prefix_exists = GCSObjectsWithPrefixExistenceSensor(
bucket=BUCKET_NAME,
@@ -135,7 +144,7 @@ with models.DAG(
sleep,
upload_file,
# TEST BODY
- [gcs_object_exists, gcs_object_with_prefix_exists],
+ [gcs_object_exists, gcs_object_exists_async, gcs_object_with_prefix_exists],
# TEST TEARDOWN
delete_bucket,
)