Posted to commits@airflow.apache.org by el...@apache.org on 2023/02/26 19:09:27 UTC

[airflow] branch main updated: Add a new param for BigQuery operators to support additional actions when resource exists (#29394)

This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new a5adb87ab4 Add a new param for BigQuery operators to support additional actions when resource exists (#29394)
a5adb87ab4 is described below

commit a5adb87ab4ee537eb37ef31aba755b40f6f29a1e
Author: Hussein Awala <ho...@gmail.com>
AuthorDate: Sun Feb 26 20:09:08 2023 +0100

    Add a new param for BigQuery operators to support additional actions when resource exists (#29394)
    
    * Add a new param to support additional actions when the resource exists and deprecate the old one
    ---------
    
    Co-authored-by: eladkal <45...@users.noreply.github.com>
---
 .../providers/google/cloud/operators/bigquery.py   | 101 ++++++++++++++-------
 .../google/cloud/operators/test_bigquery.py        |  71 ++++++++++++++-
 2 files changed, 140 insertions(+), 32 deletions(-)
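
For orientation before the diff: the new ``if_exists`` parameter accepts
"ignore", "log" (the default), "fail", or "skip". A minimal usage sketch,
with placeholder project/dataset/table names (the parameter names match
those exercised in the tests below):

    from airflow.providers.google.cloud.operators.bigquery import (
        BigQueryCreateEmptyTableOperator,
    )

    # Skip the task instance instead of merely logging when the
    # table already exists.
    create_table = BigQueryCreateEmptyTableOperator(
        task_id="create_table",
        project_id="my-project",
        dataset_id="my_dataset",
        table_id="my_table",
        if_exists="skip",
    )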

diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py
index 23144c1c11..75b8c02a0d 100644
--- a/airflow/providers/google/cloud/operators/bigquery.py
+++ b/airflow/providers/google/cloud/operators/bigquery.py
@@ -28,7 +28,7 @@ from google.api_core.exceptions import Conflict
 from google.api_core.retry import Retry
 from google.cloud.bigquery import DEFAULT_RETRY, CopyJob, ExtractJob, LoadJob, QueryJob
 
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowSkipException
 from airflow.models import BaseOperator, BaseOperatorLink
 from airflow.models.xcom import XCom
 from airflow.providers.common.sql.operators.sql import (
@@ -68,6 +68,15 @@ class BigQueryUIColors(enum.Enum):
     DATASET = "#5F86FF"
 
 
+class IfExistAction(enum.Enum):
+    """Action to take if the resource exist"""
+
+    IGNORE = "ignore"
+    LOG = "log"
+    FAIL = "fail"
+    SKIP = "skip"
+
+
 class BigQueryConsoleLink(BaseOperatorLink):
     """Helper class for constructing BigQuery link."""
 
@@ -248,7 +257,9 @@ class BigQueryCheckOperator(_BigQueryDbHookMixin, SQLCheckOperator):
         if not records:
             raise AirflowException("The query returned empty results")
         elif not all(bool(r) for r in records):
-            self._raise_exception(f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}")
+            self._raise_exception(  # type: ignore[attr-defined]
+                f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}"
+            )
         self.log.info("Record: %s", event["records"])
         self.log.info("Success.")
 
@@ -773,9 +784,6 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
     :param selected_fields: List of fields to return (comma-separated). If
         unspecified, all fields are returned.
     :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud.
-    :param delegate_to: The account to impersonate using domain-wide delegation of authority,
-        if any. For this to work, the service account making the request must have
-        domain-wide delegation enabled. Deprecated.
     :param location: The location used for the operation.
     :param impersonation_chain: Optional service account to impersonate using short-term
         credentials, or chained list of accounts required to get the access_token
@@ -786,6 +794,9 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
         Service Account Token Creator IAM role to the directly preceding identity, with first
         account from the list granting this role to the originating account (templated).
     :param deferrable: Run operator in the deferrable mode
+    :param delegate_to: The account to impersonate using domain-wide delegation of authority,
+        if any. For this to work, the service account making the request must have
+        domain-wide delegation enabled. Deprecated.
     """
 
     template_fields: Sequence[str] = (
@@ -807,10 +818,10 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
         max_results: int = 100,
         selected_fields: str | None = None,
         gcp_conn_id: str = "google_cloud_default",
-        delegate_to: str | None = None,
         location: str | None = None,
         impersonation_chain: str | Sequence[str] | None = None,
         deferrable: bool = False,
+        delegate_to: str | None = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -1253,7 +1264,10 @@ class BigQueryCreateEmptyTableOperator(GoogleCloudBaseOperator):
         If set as a sequence, the identities from the list must grant
         Service Account Token Creator IAM role to the directly preceding identity, with first
         account from the list granting this role to the originating account (templated).
-    :param exists_ok: If ``True``, ignore "already exists" errors when creating the table.
+    :param if_exists: What should Airflow do if the table exists. If set to `log`, the TI will succeed and
+        an error message will be logged. Set to `ignore` to ignore the error, set to `fail` to fail the TI,
+        and set to `skip` to skip it.
+    :param exists_ok: Deprecated - use `if_exists="ignore"` instead.
     """
 
     template_fields: Sequence[str] = (
@@ -1282,9 +1296,7 @@ class BigQueryCreateEmptyTableOperator(GoogleCloudBaseOperator):
         gcs_schema_object: str | None = None,
         time_partitioning: dict | None = None,
         gcp_conn_id: str = "google_cloud_default",
-        bigquery_conn_id: str | None = None,
         google_cloud_storage_conn_id: str = "google_cloud_default",
-        delegate_to: str | None = None,
         labels: dict | None = None,
         view: dict | None = None,
         materialized_view: dict | None = None,
@@ -1292,7 +1304,10 @@ class BigQueryCreateEmptyTableOperator(GoogleCloudBaseOperator):
         location: str | None = None,
         cluster_fields: list[str] | None = None,
         impersonation_chain: str | Sequence[str] | None = None,
-        exists_ok: bool = False,
+        if_exists: str = "log",
+        delegate_to: str | None = None,
+        bigquery_conn_id: str | None = None,
+        exists_ok: bool | None = None,
         **kwargs,
     ) -> None:
         if bigquery_conn_id:
@@ -1326,7 +1341,11 @@ class BigQueryCreateEmptyTableOperator(GoogleCloudBaseOperator):
         self.cluster_fields = cluster_fields
         self.table_resource = table_resource
         self.impersonation_chain = impersonation_chain
-        self.exists_ok = exists_ok
+        if exists_ok is not None:
+            warnings.warn("`exists_ok` parameter is deprecated, please use `if_exists`", DeprecationWarning)
+            self.if_exists = IfExistAction.IGNORE if exists_ok else IfExistAction.LOG
+        else:
+            self.if_exists = IfExistAction(if_exists)
 
     def execute(self, context: Context) -> None:
         bq_hook = BigQueryHook(
@@ -1362,7 +1381,7 @@ class BigQueryCreateEmptyTableOperator(GoogleCloudBaseOperator):
                 materialized_view=self.materialized_view,
                 encryption_configuration=self.encryption_configuration,
                 table_resource=self.table_resource,
-                exists_ok=self.exists_ok,
+                exists_ok=self.if_exists == IfExistAction.IGNORE,
             )
             BigQueryTableLink.persist(
                 context=context,
@@ -1375,7 +1394,13 @@ class BigQueryCreateEmptyTableOperator(GoogleCloudBaseOperator):
                 "Table %s.%s.%s created successfully", table.project, table.dataset_id, table.table_id
             )
         except Conflict:
-            self.log.info("Table %s.%s already exists.", self.dataset_id, self.table_id)
+            error_msg = f"Table {self.dataset_id}.{self.table_id} already exists."
+            if self.if_exists == IfExistAction.LOG:
+                self.log.info(error_msg)
+            elif self.if_exists == IfExistAction.FAIL:
+                raise AirflowException(error_msg)
+            else:
+                raise AirflowSkipException(error_msg)
 
 
 class BigQueryCreateExternalTableOperator(GoogleCloudBaseOperator):
@@ -1490,14 +1515,14 @@ class BigQueryCreateExternalTableOperator(GoogleCloudBaseOperator):
         allow_quoted_newlines: bool = False,
         allow_jagged_rows: bool = False,
         gcp_conn_id: str = "google_cloud_default",
-        bigquery_conn_id: str | None = None,
         google_cloud_storage_conn_id: str = "google_cloud_default",
-        delegate_to: str | None = None,
         src_fmt_configs: dict | None = None,
         labels: dict | None = None,
         encryption_configuration: dict | None = None,
         location: str | None = None,
         impersonation_chain: str | Sequence[str] | None = None,
+        delegate_to: str | None = None,
+        bigquery_conn_id: str | None = None,
         **kwargs,
     ) -> None:
         if bigquery_conn_id:
@@ -1721,8 +1746,8 @@ class BigQueryDeleteDatasetOperator(GoogleCloudBaseOperator):
         project_id: str | None = None,
         delete_contents: bool = False,
         gcp_conn_id: str = "google_cloud_default",
-        delegate_to: str | None = None,
         impersonation_chain: str | Sequence[str] | None = None,
+        delegate_to: str | None = None,
         **kwargs,
     ) -> None:
         self.dataset_id = dataset_id
@@ -1779,7 +1804,9 @@ class BigQueryCreateEmptyDatasetOperator(GoogleCloudBaseOperator):
         If set as a sequence, the identities from the list must grant
         Service Account Token Creator IAM role to the directly preceding identity, with first
         account from the list granting this role to the originating account (templated).
-    :param exists_ok: If ``True``, ignore "already exists" errors when creating the dataset.
+    :param if_exists: What should Airflow do if the dataset exists. If set to `log`, the TI will succeed and
+        an error message will be logged. Set to `ignore` to ignore the error, set to `fail` to fail the TI,
+        and set to `skip` to skip it.
         **Example**: ::
 
             create_new_dataset = BigQueryCreateEmptyDatasetOperator(
@@ -1789,6 +1816,7 @@ class BigQueryCreateEmptyDatasetOperator(GoogleCloudBaseOperator):
                 gcp_conn_id='_my_gcp_conn_',
                 task_id='newDatasetCreator',
                 dag=dag)
+    :param exists_ok: Deprecated - use `if_exists="ignore"` instead.
     """
 
     template_fields: Sequence[str] = (
@@ -1809,9 +1837,10 @@ class BigQueryCreateEmptyDatasetOperator(GoogleCloudBaseOperator):
         dataset_reference: dict | None = None,
         location: str | None = None,
         gcp_conn_id: str = "google_cloud_default",
-        delegate_to: str | None = None,
         impersonation_chain: str | Sequence[str] | None = None,
-        exists_ok: bool = False,
+        if_exists: str = "log",
+        delegate_to: str | None = None,
+        exists_ok: bool | None = None,
         **kwargs,
     ) -> None:
 
@@ -1826,7 +1855,11 @@ class BigQueryCreateEmptyDatasetOperator(GoogleCloudBaseOperator):
             )
         self.delegate_to = delegate_to
         self.impersonation_chain = impersonation_chain
-        self.exists_ok = exists_ok
+        if exists_ok is not None:
+            warnings.warn("`exists_ok` parameter is deprecated, please use `if_exists`", DeprecationWarning)
+            self.if_exists = IfExistAction.IGNORE if exists_ok else IfExistAction.LOG
+        else:
+            self.if_exists = IfExistAction(if_exists)
 
         super().__init__(**kwargs)
 
@@ -1844,7 +1877,7 @@ class BigQueryCreateEmptyDatasetOperator(GoogleCloudBaseOperator):
                 dataset_id=self.dataset_id,
                 dataset_reference=self.dataset_reference,
                 location=self.location,
-                exists_ok=self.exists_ok,
+                exists_ok=self.if_exists == IfExistAction.IGNORE,
             )
             BigQueryDatasetLink.persist(
                 context=context,
@@ -1854,7 +1887,13 @@ class BigQueryCreateEmptyDatasetOperator(GoogleCloudBaseOperator):
             )
         except Conflict:
             dataset_id = self.dataset_reference.get("datasetReference", {}).get("datasetId", self.dataset_id)
-            self.log.info("Dataset %s already exists.", dataset_id)
+            error_msg = f"Dataset {dataset_id} already exists."
+            if self.if_exists == IfExistAction.LOG:
+                self.log.info(error_msg)
+            elif self.if_exists == IfExistAction.FAIL:
+                raise AirflowException(error_msg)
+            else:
+                raise AirflowSkipException(error_msg)
 
 
 class BigQueryGetDatasetOperator(GoogleCloudBaseOperator):
@@ -1897,8 +1936,8 @@ class BigQueryGetDatasetOperator(GoogleCloudBaseOperator):
         dataset_id: str,
         project_id: str | None = None,
         gcp_conn_id: str = "google_cloud_default",
-        delegate_to: str | None = None,
         impersonation_chain: str | Sequence[str] | None = None,
+        delegate_to: str | None = None,
         **kwargs,
     ) -> None:
         self.dataset_id = dataset_id
@@ -1972,8 +2011,8 @@ class BigQueryGetDatasetTablesOperator(GoogleCloudBaseOperator):
         project_id: str | None = None,
         max_results: int | None = None,
         gcp_conn_id: str = "google_cloud_default",
-        delegate_to: str | None = None,
         impersonation_chain: str | Sequence[str] | None = None,
+        delegate_to: str | None = None,
         **kwargs,
     ) -> None:
         self.dataset_id = dataset_id
@@ -2045,8 +2084,8 @@ class BigQueryPatchDatasetOperator(GoogleCloudBaseOperator):
         dataset_resource: dict,
         project_id: str | None = None,
         gcp_conn_id: str = "google_cloud_default",
-        delegate_to: str | None = None,
         impersonation_chain: str | Sequence[str] | None = None,
+        delegate_to: str | None = None,
         **kwargs,
     ) -> None:
         warnings.warn(
@@ -2133,8 +2172,8 @@ class BigQueryUpdateTableOperator(GoogleCloudBaseOperator):
         table_id: str | None = None,
         project_id: str | None = None,
         gcp_conn_id: str = "google_cloud_default",
-        delegate_to: str | None = None,
         impersonation_chain: str | Sequence[str] | None = None,
+        delegate_to: str | None = None,
         **kwargs,
     ) -> None:
         self.dataset_id = dataset_id
@@ -2227,8 +2266,8 @@ class BigQueryUpdateDatasetOperator(GoogleCloudBaseOperator):
         dataset_id: str | None = None,
         project_id: str | None = None,
         gcp_conn_id: str = "google_cloud_default",
-        delegate_to: str | None = None,
         impersonation_chain: str | Sequence[str] | None = None,
+        delegate_to: str | None = None,
         **kwargs,
     ) -> None:
         self.dataset_id = dataset_id
@@ -2308,10 +2347,10 @@ class BigQueryDeleteTableOperator(GoogleCloudBaseOperator):
         *,
         deletion_dataset_table: str,
         gcp_conn_id: str = "google_cloud_default",
-        delegate_to: str | None = None,
         ignore_if_missing: bool = False,
         location: str | None = None,
         impersonation_chain: str | Sequence[str] | None = None,
+        delegate_to: str | None = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -2385,9 +2424,9 @@ class BigQueryUpsertTableOperator(GoogleCloudBaseOperator):
         table_resource: dict,
         project_id: str | None = None,
         gcp_conn_id: str = "google_cloud_default",
-        delegate_to: str | None = None,
         location: str | None = None,
         impersonation_chain: str | Sequence[str] | None = None,
+        delegate_to: str | None = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -2496,8 +2535,8 @@ class BigQueryUpdateTableSchemaOperator(GoogleCloudBaseOperator):
         include_policy_tags: bool = False,
         project_id: str | None = None,
         gcp_conn_id: str = "google_cloud_default",
-        delegate_to: str | None = None,
         impersonation_chain: str | Sequence[str] | None = None,
+        delegate_to: str | None = None,
         **kwargs,
     ) -> None:
         self.schema_fields_updates = schema_fields_updates
@@ -2616,12 +2655,12 @@ class BigQueryInsertJobOperator(GoogleCloudBaseOperator):
         force_rerun: bool = True,
         reattach_states: set[str] | None = None,
         gcp_conn_id: str = "google_cloud_default",
-        delegate_to: str | None = None,
         impersonation_chain: str | Sequence[str] | None = None,
         cancel_on_kill: bool = True,
         result_retry: Retry = DEFAULT_RETRY,
         result_timeout: float | None = None,
         deferrable: bool = False,
+        delegate_to: str | None = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
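
Note that the ``except Conflict`` branches added above only ever see the
"log", "fail", and "skip" actions: when ``if_exists`` is "ignore", the
operator forwards ``exists_ok=True`` to the hook, so the BigQuery client
suppresses the Conflict before it can propagate. A condensed sketch of
the dispatch (``error_msg`` stands in for the "already exists" message
each operator builds):

    if self.if_exists == IfExistAction.LOG:
        self.log.info(error_msg)               # TI succeeds, message logged
    elif self.if_exists == IfExistAction.FAIL:
        raise AirflowException(error_msg)      # TI fails
    else:  # IfExistAction.SKIP
        raise AirflowSkipException(error_msg)  # TI is skipped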
diff --git a/tests/providers/google/cloud/operators/test_bigquery.py b/tests/providers/google/cloud/operators/test_bigquery.py
index d814c894bb..b14b43e108 100644
--- a/tests/providers/google/cloud/operators/test_bigquery.py
+++ b/tests/providers/google/cloud/operators/test_bigquery.py
@@ -26,7 +26,7 @@ import pytest
 from google.cloud.bigquery import DEFAULT_RETRY
 from google.cloud.exceptions import Conflict
 
-from airflow.exceptions import AirflowException, AirflowTaskTimeout, TaskDeferred
+from airflow.exceptions import AirflowException, AirflowSkipException, AirflowTaskTimeout, TaskDeferred
 from airflow.models import DAG
 from airflow.models.dagrun import DagRun
 from airflow.models.taskinstance import TaskInstance
@@ -201,6 +201,41 @@ class TestBigQueryCreateEmptyTableOperator(unittest.TestCase):
         )
 
 
+@pytest.mark.parametrize(
+    "if_exists, is_conflict, expected_error, log_msg",
+    [
+        ("ignore", False, None, None),
+        ("log", False, None, None),
+        ("log", True, None, f"Table {TEST_DATASET}.{TEST_TABLE_ID} already exists."),
+        ("fail", False, None, None),
+        ("fail", True, AirflowException, None),
+        ("skip", False, None, None),
+        ("skip", True, AirflowSkipException, None),
+    ],
+)
+@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
+def test_create_existing_table(mock_hook, caplog, if_exists, is_conflict, expected_error, log_msg):
+    operator = BigQueryCreateEmptyTableOperator(
+        task_id=TASK_ID,
+        dataset_id=TEST_DATASET,
+        project_id=TEST_GCP_PROJECT_ID,
+        table_id=TEST_TABLE_ID,
+        view=VIEW_DEFINITION,
+        if_exists=if_exists,
+    )
+    if is_conflict:
+        mock_hook.return_value.create_empty_table.side_effect = Conflict("any")
+    else:
+        mock_hook.return_value.create_empty_table.side_effect = None
+    if expected_error is not None:
+        with pytest.raises(expected_error):
+            operator.execute(context=MagicMock())
+    else:
+        operator.execute(context=MagicMock())
+    if log_msg is not None:
+        assert log_msg in caplog.text
+
+
 class TestBigQueryCreateExternalTableOperator(unittest.TestCase):
     @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
     def test_execute(self, mock_hook):
@@ -288,6 +323,40 @@ class TestBigQueryCreateEmptyDatasetOperator(unittest.TestCase):
         )
 
 
+@pytest.mark.parametrize(
+    "if_exists, is_conflict, expected_error, log_msg",
+    [
+        ("ignore", False, None, None),
+        ("log", False, None, None),
+        ("log", True, None, f"Dataset {TEST_DATASET} already exists."),
+        ("fail", False, None, None),
+        ("fail", True, AirflowException, None),
+        ("skip", False, None, None),
+        ("skip", True, AirflowSkipException, None),
+    ],
+)
+@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
+def test_create_empty_dataset(mock_hook, caplog, if_exists, is_conflict, expected_error, log_msg):
+    operator = BigQueryCreateEmptyDatasetOperator(
+        task_id=TASK_ID,
+        dataset_id=TEST_DATASET,
+        project_id=TEST_GCP_PROJECT_ID,
+        location=TEST_DATASET_LOCATION,
+        if_exists=if_exists,
+    )
+    if is_conflict:
+        mock_hook.return_value.create_empty_dataset.side_effect = Conflict("any")
+    else:
+        mock_hook.return_value.create_empty_dataset.side_effect = None
+    if expected_error is not None:
+        with pytest.raises(expected_error):
+            operator.execute(context=MagicMock())
+    else:
+        operator.execute(context=MagicMock())
+    if log_msg is not None:
+        assert log_msg in caplog.text
+
+
 class TestBigQueryGetDatasetOperator(unittest.TestCase):
     @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
     def test_execute(self, mock_hook):
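
For existing DAGs, migrating off the deprecated flag is mechanical; a
before/after sketch (task and dataset IDs are placeholders):

    # Before: emits a DeprecationWarning as of this change
    BigQueryCreateEmptyDatasetOperator(
        task_id="create_dataset", dataset_id="my_dataset", exists_ok=True
    )

    # After
    BigQueryCreateEmptyDatasetOperator(
        task_id="create_dataset", dataset_id="my_dataset", if_exists="ignore"
    )

Passing ``exists_ok=False`` explicitly maps to ``if_exists="log"``, which
is also the default when neither parameter is given.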