Posted to commits@airflow.apache.org by po...@apache.org on 2022/06/06 13:02:42 UTC

[airflow] branch main updated: Cloud Storage assets & StorageLink update (#23865)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 80c1ce76e1 Cloud Storage assets & StorageLink update (#23865)
80c1ce76e1 is described below

commit 80c1ce76e19d363916f2253cdd536372f6a43aee
Author: Wojciech Januszek <wj...@sigma.ug.edu.pl>
AuthorDate: Mon Jun 6 15:02:35 2022 +0200

    Cloud Storage assets & StorageLink update (#23865)
    
    Co-authored-by: Wojciech Januszek <ja...@google.com>
---
 .../google/cloud/operators/dataproc_metastore.py   |  2 +-
 .../providers/google/cloud/operators/datastore.py  |  1 +
 airflow/providers/google/cloud/operators/gcs.py    | 57 ++++++++++++++++++++++
 airflow/providers/google/common/links/storage.py   |  4 +-
 tests/providers/google/cloud/operators/test_gcs.py | 14 +++---
 5 files changed, 69 insertions(+), 9 deletions(-)
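
Summary of the change: StorageLink.persist previously pulled the project id
off the task instance ("project_id": task_instance.project_id), which only
worked for operators that happened to define that attribute. This commit
makes project_id an explicit parameter and registers StorageLink /
FileDetailsLink as operator_extra_links on the GCS operators, so the UI can
render a link to the bucket or object a task touched. A minimal DAG sketch
of how the new link surfaces (bucket and project names below are
placeholders, not part of the commit):

    import datetime

    from airflow import DAG
    from airflow.providers.google.cloud.operators.gcs import GCSCreateBucketOperator

    with DAG(
        dag_id="gcs_link_demo",
        start_date=datetime.datetime(2022, 6, 1),
        schedule_interval=None,
    ) as dag:
        # After a run, the task-instance view shows the StorageLink extra
        # link, built from the (uri, project_id) payload that execute()
        # pushes to XCom via StorageLink.persist.
        GCSCreateBucketOperator(
            task_id="create_bucket",
            bucket_name="example-bucket",
            project_id="example-project",
        )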

diff --git a/airflow/providers/google/cloud/operators/dataproc_metastore.py b/airflow/providers/google/cloud/operators/dataproc_metastore.py
index d0ca4a5f28..4bdf519d2f 100644
--- a/airflow/providers/google/cloud/operators/dataproc_metastore.py
+++ b/airflow/providers/google/cloud/operators/dataproc_metastore.py
@@ -711,7 +711,7 @@ class DataprocMetastoreExportMetadataOperator(BaseOperator):
 
         DataprocMetastoreLink.persist(context=context, task_instance=self, url=METASTORE_EXPORT_LINK)
         uri = self._get_uri_from_destination(MetadataExport.to_dict(metadata_export)["destination_gcs_uri"])
-        StorageLink.persist(context=context, task_instance=self, uri=uri)
+        StorageLink.persist(context=context, task_instance=self, uri=uri, project_id=self.project_id)
         return MetadataExport.to_dict(metadata_export)
 
     def _get_uri_from_destination(self, destination_uri: str):
diff --git a/airflow/providers/google/cloud/operators/datastore.py b/airflow/providers/google/cloud/operators/datastore.py
index 8a92665e36..db08d53ba7 100644
--- a/airflow/providers/google/cloud/operators/datastore.py
+++ b/airflow/providers/google/cloud/operators/datastore.py
@@ -140,6 +140,7 @@ class CloudDatastoreExportEntitiesOperator(BaseOperator):
             context=context,
             task_instance=self,
             uri=f"{self.bucket}/{result['response']['outputUrl'].split('/')[3]}",
+            project_id=self.project_id or ds_hook.project_id,
         )
         return result
 
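Note on the datastore hunk: CloudDatastoreExportEntitiesOperator may be
constructed without an explicit project_id, so the link payload falls back
to the project configured on the hook's connection. A sketch of the pattern
(names as in the diff above):

    # Prefer the operator's own project_id; otherwise use the default
    # project resolved from the hook's GCP connection.
    project_id = self.project_id or ds_hook.project_id
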
diff --git a/airflow/providers/google/cloud/operators/gcs.py b/airflow/providers/google/cloud/operators/gcs.py
index 27cc6f79bd..bfc2d96919 100644
--- a/airflow/providers/google/cloud/operators/gcs.py
+++ b/airflow/providers/google/cloud/operators/gcs.py
@@ -35,6 +35,7 @@ from pendulum.datetime import DateTime
 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
 from airflow.providers.google.cloud.hooks.gcs import GCSHook
+from airflow.providers.google.common.links.storage import FileDetailsLink, StorageLink
 from airflow.utils import timezone
 
 
@@ -107,6 +108,7 @@ class GCSCreateBucketOperator(BaseOperator):
         'impersonation_chain',
     )
     ui_color = '#f0eee4'
+    operator_extra_links = (StorageLink(),)
 
     def __init__(
         self,
@@ -139,6 +141,12 @@ class GCSCreateBucketOperator(BaseOperator):
             delegate_to=self.delegate_to,
             impersonation_chain=self.impersonation_chain,
         )
+        StorageLink.persist(
+            context=context,
+            task_instance=self,
+            uri=self.bucket_name,
+            project_id=self.project_id or hook.project_id,
+        )
         try:
             hook.create_bucket(
                 bucket_name=self.bucket_name,
@@ -200,6 +208,8 @@ class GCSListObjectsOperator(BaseOperator):
 
     ui_color = '#f0eee4'
 
+    operator_extra_links = (StorageLink(),)
+
     def __init__(
         self,
         *,
@@ -234,6 +244,13 @@ class GCSListObjectsOperator(BaseOperator):
             self.prefix,
         )
 
+        StorageLink.persist(
+            context=context,
+            task_instance=self,
+            uri=self.bucket,
+            project_id=hook.project_id,
+        )
+
         return hook.list(bucket_name=self.bucket, prefix=self.prefix, delimiter=self.delimiter)
 
 
@@ -346,6 +363,7 @@ class GCSBucketCreateAclEntryOperator(BaseOperator):
         'impersonation_chain',
     )
     # [END gcs_bucket_create_acl_template_fields]
+    operator_extra_links = (StorageLink(),)
 
     def __init__(
         self,
@@ -371,6 +389,12 @@ class GCSBucketCreateAclEntryOperator(BaseOperator):
             gcp_conn_id=self.gcp_conn_id,
             impersonation_chain=self.impersonation_chain,
         )
+        StorageLink.persist(
+            context=context,
+            task_instance=self,
+            uri=self.bucket,
+            project_id=hook.project_id,
+        )
         hook.insert_bucket_acl(
             bucket_name=self.bucket, entity=self.entity, role=self.role, user_project=self.user_project
         )
@@ -418,6 +442,7 @@ class GCSObjectCreateAclEntryOperator(BaseOperator):
         'impersonation_chain',
     )
     # [END gcs_object_create_acl_template_fields]
+    operator_extra_links = (FileDetailsLink(),)
 
     def __init__(
         self,
@@ -447,6 +472,12 @@ class GCSObjectCreateAclEntryOperator(BaseOperator):
             gcp_conn_id=self.gcp_conn_id,
             impersonation_chain=self.impersonation_chain,
         )
+        FileDetailsLink.persist(
+            context=context,
+            task_instance=self,
+            uri=f"{self.bucket}/{self.object_name}",
+            project_id=hook.project_id,
+        )
         hook.insert_object_acl(
             bucket_name=self.bucket,
             object_name=self.object_name,
@@ -498,6 +529,7 @@ class GCSFileTransformOperator(BaseOperator):
         'transform_script',
         'impersonation_chain',
     )
+    operator_extra_links = (FileDetailsLink(),)
 
     def __init__(
         self,
@@ -549,6 +581,12 @@ class GCSFileTransformOperator(BaseOperator):
             self.log.info("Transformation succeeded. Output temporarily located at %s", destination_file.name)
 
             self.log.info("Uploading file to %s as %s", self.destination_bucket, self.destination_object)
+            FileDetailsLink.persist(
+                context=context,
+                task_instance=self,
+                uri=f"{self.destination_bucket}/{self.destination_object}",
+                project_id=hook.project_id,
+            )
             hook.upload(
                 bucket_name=self.destination_bucket,
                 object_name=self.destination_object,
@@ -628,6 +666,7 @@ class GCSTimeSpanFileTransformOperator(BaseOperator):
         'source_impersonation_chain',
         'destination_impersonation_chain',
     )
+    operator_extra_links = (StorageLink(),)
 
     @staticmethod
     def interpolate_prefix(prefix: str, dt: datetime.datetime) -> Optional[str]:
@@ -718,6 +757,12 @@ class GCSTimeSpanFileTransformOperator(BaseOperator):
             gcp_conn_id=self.destination_gcp_conn_id,
             impersonation_chain=self.destination_impersonation_chain,
         )
+        StorageLink.persist(
+            context=context,
+            task_instance=self,
+            uri=self.destination_bucket,
+            project_id=destination_hook.project_id,
+        )
 
         # Fetch list of files.
         blobs_to_transform = source_hook.list_by_timespan(
@@ -904,6 +949,7 @@ class GCSSynchronizeBucketsOperator(BaseOperator):
         'delegate_to',
         'impersonation_chain',
     )
+    operator_extra_links = (StorageLink(),)
 
     def __init__(
         self,
@@ -938,6 +984,12 @@ class GCSSynchronizeBucketsOperator(BaseOperator):
             delegate_to=self.delegate_to,
             impersonation_chain=self.impersonation_chain,
         )
+        StorageLink.persist(
+            context=context,
+            task_instance=self,
+            uri=self._get_uri(self.destination_bucket, self.destination_object),
+            project_id=hook.project_id,
+        )
         hook.sync(
             source_bucket=self.source_bucket,
             destination_bucket=self.destination_bucket,
@@ -947,3 +999,8 @@ class GCSSynchronizeBucketsOperator(BaseOperator):
             delete_extra_files=self.delete_extra_files,
             allow_overwrite=self.allow_overwrite,
         )
+
+    def _get_uri(self, gcs_bucket: str, gcs_object: Optional[str]) -> str:
+        if gcs_object and gcs_object[-1] == "/":
+            gcs_object = gcs_object[:-1]
+        return f"{gcs_bucket}/{gcs_object}" if gcs_object else gcs_bucket
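
The _get_uri helper above normalizes the link target: a trailing slash on
the destination object is stripped, and a bare bucket is used when no
object is given. A quick illustration of the expected results (assuming an
operator instance op, for illustration only):

    op._get_uri("my-bucket", "exports/2022/")  # -> "my-bucket/exports/2022"
    op._get_uri("my-bucket", "exports/2022")   # -> "my-bucket/exports/2022"
    op._get_uri("my-bucket", None)             # -> "my-bucket"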
diff --git a/airflow/providers/google/common/links/storage.py b/airflow/providers/google/common/links/storage.py
index 7934d95d33..013dcc25f9 100644
--- a/airflow/providers/google/common/links/storage.py
+++ b/airflow/providers/google/common/links/storage.py
@@ -36,11 +36,11 @@ class StorageLink(BaseGoogleLink):
     format_str = GCS_STORAGE_LINK
 
     @staticmethod
-    def persist(context: "Context", task_instance, uri: str):
+    def persist(context: "Context", task_instance, uri: str, project_id: Optional[str]):
         task_instance.xcom_push(
             context=context,
             key=StorageLink.key,
-            value={"uri": uri, "project_id": task_instance.project_id},
+            value={"uri": uri, "project_id": project_id},
         )
 
 
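With the new signature, callers pass the project explicitly instead of
relying on a project_id attribute existing on the task instance. A minimal
call-site sketch (hook and attribute names are illustrative):

    # Inside an operator's execute(); hook is a GCSHook (or similar).
    StorageLink.persist(
        context=context,
        task_instance=self,
        uri=self.bucket,                                # bucket, or "bucket/prefix"
        project_id=self.project_id or hook.project_id,  # now explicit
    )
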
diff --git a/tests/providers/google/cloud/operators/test_gcs.py b/tests/providers/google/cloud/operators/test_gcs.py
index cac11ccf03..3d6cba0374 100644
--- a/tests/providers/google/cloud/operators/test_gcs.py
+++ b/tests/providers/google/cloud/operators/test_gcs.py
@@ -57,7 +57,7 @@ class TestGoogleCloudStorageCreateBucket(unittest.TestCase):
             project_id=TEST_PROJECT,
         )
 
-        operator.execute(None)
+        operator.execute(context=mock.MagicMock())
         mock_hook.return_value.create_bucket.assert_called_once_with(
             bucket_name=TEST_BUCKET,
             storage_class="MULTI_REGIONAL",
@@ -78,7 +78,7 @@ class TestGoogleCloudStorageAcl(unittest.TestCase):
             user_project="test-user-project",
             task_id="id",
         )
-        operator.execute(None)
+        operator.execute(context=mock.MagicMock())
         mock_hook.return_value.insert_bucket_acl.assert_called_once_with(
             bucket_name="test-bucket",
             entity="test-entity",
@@ -97,7 +97,7 @@ class TestGoogleCloudStorageAcl(unittest.TestCase):
             user_project="test-user-project",
             task_id="id",
         )
-        operator.execute(None)
+        operator.execute(context=mock.MagicMock())
         mock_hook.return_value.insert_object_acl.assert_called_once_with(
             bucket_name="test-bucket",
             object_name="test-object",
@@ -148,7 +148,7 @@ class TestGoogleCloudStorageListOperator(unittest.TestCase):
             task_id=TASK_ID, bucket=TEST_BUCKET, prefix=PREFIX, delimiter=DELIMITER
         )
 
-        files = operator.execute(None)
+        files = operator.execute(context=mock.MagicMock())
         mock_hook.return_value.list.assert_called_once_with(
             bucket_name=TEST_BUCKET, prefix=PREFIX, delimiter=DELIMITER
         )
@@ -197,7 +197,7 @@ class TestGCSFileTransformOperator(unittest.TestCase):
             destination_bucket=destination_bucket,
             transform_script=transform_script,
         )
-        op.execute(None)
+        op.execute(context=mock.MagicMock())
 
         mock_hook.return_value.download.assert_called_once_with(
             bucket_name=source_bucket, object_name=source_object, filename=source
@@ -273,9 +273,11 @@ class TestGCSTimeSpanFileTransformOperator(unittest.TestCase):
         timespan_end = timespan_start + timedelta(hours=1)
         mock_dag = mock.Mock()
         mock_dag.following_schedule = lambda x: x + timedelta(hours=1)
+        mock_ti = mock.Mock()
         context = dict(
             execution_date=timespan_start,
             dag=mock_dag,
+            ti=mock_ti,
         )
 
         mock_tempdir.return_value.__enter__.side_effect = [source, destination]
@@ -397,7 +399,7 @@ class TestGoogleCloudStorageSync(unittest.TestCase):
             delegate_to="DELEGATE_TO",
             impersonation_chain=IMPERSONATION_CHAIN,
         )
-        task.execute({})
+        task.execute(context=mock.MagicMock())
         mock_hook.assert_called_once_with(
             gcp_conn_id='GCP_CONN_ID',
             delegate_to='DELEGATE_TO',
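
A note on the test changes: the operators now call
StorageLink.persist(..., task_instance=self), and BaseOperator.xcom_push
delegates to context["ti"].xcom_push, so execute(None) would fail as soon
as the link payload is pushed. Hence the MagicMock contexts above, and the
explicit ti entry in the dict-style context. A hedged sketch of an
assertion such a test could add (not part of this commit):

    mock_ti.xcom_push.assert_called()  # the link payload reached XCom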