Posted to commits@airflow.apache.org by po...@apache.org on 2022/12/04 19:02:21 UTC

[airflow] branch main updated: Fix: re-enable use of parameters in gcs_to_bq which had been disabled (#27961)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 2d663df055 Fix: re-enable use of parameters in gcs_to_bq which had been disabled (#27961)
2d663df055 is described below

commit 2d663df0552542efcef6e59bc2bc1586f8d1c7f3
Author: Matt <md...@users.noreply.github.com>
AuthorDate: Sun Dec 4 14:02:09 2022 -0500

    Fix: re-enable use of parameters in gcs_to_bq which had been disabled (#27961)
---
 .../google/cloud/transfers/gcs_to_bigquery.py      |  26 ++-
 .../google/cloud/transfers/test_gcs_to_bigquery.py | 228 +++++++++++++++++++++
 2 files changed, 253 insertions(+), 1 deletion(-)
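
The change restores CSV load-job options (allowJaggedRows, fieldDelimiter, maxBadRecords, quote, schemaUpdateOptions) that GCSToBigQueryOperator had stopped forwarding to BigQuery, and re-adds clustering, time partitioning, and encryption configuration. With the fix, a DAG can again pass these parameters through the operator. A minimal sketch follows; task, bucket, object, table, and key names are illustrative placeholders, and the parameter names match those exercised by the new tests further down:

    from airflow.providers.google.cloud.transfers.gcs_to_bigquery import GCSToBigQueryOperator

    # Illustrative sketch; all resource names below are placeholders.
    load_csv = GCSToBigQueryOperator(
        task_id="gcs_to_bq_example",
        bucket="my-bucket",
        source_objects=["data/part-*.csv"],
        destination_project_dataset_table="my-project.my_dataset.my_table",
        schema_fields=[
            {"name": "id", "type": "INTEGER", "mode": "NULLABLE"},
            {"name": "name", "type": "STRING", "mode": "NULLABLE"},
        ],
        write_disposition="WRITE_TRUNCATE",
        external_table=False,
        # Parameters whose handling is re-enabled by this commit:
        field_delimiter=";",
        max_bad_records=13,
        quote_character="|",
        allow_jagged_rows=True,
        schema_update_options=["ALLOW_FIELD_ADDITION"],
        cluster_fields=["id"],
        time_partitioning={"type": "DAY"},
        encryption_configuration={"kmsKeyName": "projects/p/locations/l/keyRings/k/cryptoKeys/c"},
    )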

diff --git a/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py b/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py
index ce7da49062..b681b92d03 100644
--- a/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py
+++ b/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py
@@ -27,7 +27,11 @@ from google.cloud.bigquery import DEFAULT_RETRY, CopyJob, ExtractJob, LoadJob, Q
 
 from airflow import AirflowException
 from airflow.models import BaseOperator
-from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook, BigQueryJob
+from airflow.providers.google.cloud.hooks.bigquery import (
+    BigQueryHook,
+    BigQueryJob,
+    _cleanse_time_partitioning,
+)
 from airflow.providers.google.cloud.hooks.gcs import GCSHook
 from airflow.providers.google.cloud.links.bigquery import BigQueryTableLink
 from airflow.providers.google.cloud.triggers.bigquery import BigQueryInsertJobTrigger
@@ -390,8 +394,28 @@ class GCSToBigQueryOperator(BaseOperator):
                     "ignoreUnknownValues": self.ignore_unknown_values,
                     "allowQuotedNewlines": self.allow_quoted_newlines,
                     "encoding": self.encoding,
+                    "allowJaggedRows": self.allow_jagged_rows,
+                    "fieldDelimiter": self.field_delimiter,
+                    "maxBadRecords": self.max_bad_records,
+                    "quote": self.quote_character,
+                    "schemaUpdateOptions": self.schema_update_options,
                 },
             }
+            if self.cluster_fields:
+                self.configuration["load"].update({"clustering": {"fields": self.cluster_fields}})
+            time_partitioning = _cleanse_time_partitioning(
+                self.destination_project_dataset_table, self.time_partitioning
+            )
+            if time_partitioning:
+                self.configuration["load"].update({"timePartitioning": time_partitioning})
+            # fields that should only be set if defined
+            set_if_def = {
+                "quote": self.quote_character,
+                "destinationEncryptionConfiguration": self.encryption_configuration,
+            }
+            for k, v in set_if_def.items():
+                if v:
+                    self.configuration["load"][k] = v
             self.configuration = self._check_schema_fields(self.configuration)
             try:
                 self.log.info("Executing: %s", self.configuration)
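
The hunk above assembles the load configuration in three steps: unconditional CSV options go straight into configuration["load"]; clustering and time partitioning are added only when present, with _cleanse_time_partitioning deriving DAY partitioning from a $YYYYMMDD decorator on the destination table; and a small set_if_def mapping copies quote and destinationEncryptionConfiguration only when truthy, so unset keys never reach the job payload. A standalone sketch of that pattern follows; _partition_from_table is a simplified, hypothetical stand-in for _cleanse_time_partitioning, not the hook's actual helper:

    # Simplified stand-in for _cleanse_time_partitioning (hypothetical).
    def _partition_from_table(dest_table, time_partitioning):
        partitioning = dict(time_partitioning or {})
        if "$" in dest_table and "type" not in partitioning:
            # A "table$20221123" decorator implies daily partitioning.
            partitioning["type"] = "DAY"
        return partitioning

    configuration = {"load": {"fieldDelimiter": ";", "maxBadRecords": 0}}

    time_partitioning = _partition_from_table("my_dataset.my_table$20221123", None)
    if time_partitioning:
        configuration["load"]["timePartitioning"] = time_partitioning

    # Fields that should only be set if defined: falsy values
    # (None, "", {}) are simply omitted from the payload.
    set_if_def = {
        "quote": None,
        "destinationEncryptionConfiguration": {"kmsKeyName": "projects/p/.../cryptoKeys/c"},
    }
    for k, v in set_if_def.items():
        if v:
            configuration["load"][k] = v
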
diff --git a/tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py b/tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py
index ee69c214e7..f4b5f59f82 100644
--- a/tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py
+++ b/tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py
@@ -163,6 +163,11 @@ class TestGCSToBigQueryOperator(unittest.TestCase):
                         allowQuotedNewlines=False,
                         encoding="UTF-8",
                         schema={"fields": SCHEMA_FIELDS},
+                        allowJaggedRows=False,
+                        fieldDelimiter=",",
+                        maxBadRecords=0,
+                        quote=None,
+                        schemaUpdateOptions=(),
                     ),
                 },
                 project_id=hook.return_value.project_id,
@@ -226,6 +231,11 @@ class TestGCSToBigQueryOperator(unittest.TestCase):
                         allowQuotedNewlines=False,
                         encoding="UTF-8",
                         schema={"fields": SCHEMA_FIELDS},
+                        allowJaggedRows=False,
+                        fieldDelimiter=",",
+                        maxBadRecords=0,
+                        quote=None,
+                        schemaUpdateOptions=(),
                     ),
                 },
                 project_id=hook.return_value.project_id,
@@ -335,6 +345,11 @@ class TestGCSToBigQueryOperator(unittest.TestCase):
                         allowQuotedNewlines=False,
                         encoding="UTF-8",
                         schema={"fields": SCHEMA_FIELDS},
+                        allowJaggedRows=False,
+                        fieldDelimiter=",",
+                        maxBadRecords=0,
+                        quote=None,
+                        schemaUpdateOptions=(),
                     ),
                 },
                 project_id=hook.return_value.project_id,
@@ -434,6 +449,11 @@ class TestGCSToBigQueryOperator(unittest.TestCase):
                         allowQuotedNewlines=False,
                         encoding="UTF-8",
                         schema={"fields": SCHEMA_FIELDS},
+                        allowJaggedRows=False,
+                        fieldDelimiter=",",
+                        maxBadRecords=0,
+                        quote=None,
+                        schemaUpdateOptions=(),
                     ),
                 },
                 project_id=hook.return_value.project_id,
@@ -535,6 +555,11 @@ class TestGCSToBigQueryOperator(unittest.TestCase):
                         allowQuotedNewlines=False,
                         encoding="UTF-8",
                         schema={"fields": SCHEMA_FIELDS},
+                        allowJaggedRows=False,
+                        fieldDelimiter=",",
+                        maxBadRecords=0,
+                        quote=None,
+                        schemaUpdateOptions=(),
                     ),
                 },
                 project_id=hook.return_value.project_id,
@@ -632,6 +657,194 @@ class TestGCSToBigQueryOperator(unittest.TestCase):
                         allowQuotedNewlines=False,
                         encoding="UTF-8",
                         schema={"fields": SCHEMA_FIELDS},
+                        allowJaggedRows=False,
+                        fieldDelimiter=",",
+                        maxBadRecords=0,
+                        quote=None,
+                        schemaUpdateOptions=(),
+                    ),
+                },
+                project_id=hook.return_value.project_id,
+                location=None,
+                job_id=pytest.real_job_id,
+                timeout=None,
+                retry=DEFAULT_RETRY,
+                nowait=True,
+            ),
+        ]
+
+        hook.return_value.insert_job.assert_has_calls(calls)
+
+    @mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook")
+    def test_all_fields_should_be_present(self, hook):
+        hook.return_value.insert_job.side_effect = [
+            MagicMock(job_id=pytest.real_job_id, error_result=False),
+            pytest.real_job_id,
+        ]
+        hook.return_value.generate_job_id.return_value = pytest.real_job_id
+        hook.return_value.split_tablename.return_value = (PROJECT_ID, DATASET, TABLE)
+        operator = GCSToBigQueryOperator(
+            task_id=TASK_ID,
+            bucket=TEST_BUCKET,
+            source_objects=TEST_SOURCE_OBJECTS,
+            schema_fields=SCHEMA_FIELDS,
+            destination_project_dataset_table=TEST_EXPLICIT_DEST,
+            write_disposition=WRITE_DISPOSITION,
+            external_table=False,
+            field_delimiter=";",
+            max_bad_records=13,
+            quote_character="|",
+            schema_update_options={"foo": "bar"},
+            allow_jagged_rows=True,
+            encryption_configuration={"bar": "baz"},
+            cluster_fields=["field_1", "field_2"],
+        )
+
+        operator.execute(context=MagicMock())
+
+        calls = [
+            call(
+                configuration={
+                    "load": dict(
+                        autodetect=True,
+                        createDisposition="CREATE_IF_NEEDED",
+                        destinationTable={"projectId": PROJECT_ID, "datasetId": DATASET, "tableId": TABLE},
+                        destinationTableProperties={
+                            "description": None,
+                            "labels": None,
+                        },
+                        sourceFormat="CSV",
+                        skipLeadingRows=None,
+                        sourceUris=[f"gs://{TEST_BUCKET}/{TEST_SOURCE_OBJECTS_AS_STRING}"],
+                        writeDisposition=WRITE_DISPOSITION,
+                        ignoreUnknownValues=False,
+                        allowQuotedNewlines=False,
+                        encoding="UTF-8",
+                        schema={"fields": SCHEMA_FIELDS},
+                        allowJaggedRows=True,
+                        fieldDelimiter=";",
+                        maxBadRecords=13,
+                        quote="|",
+                        schemaUpdateOptions={"foo": "bar"},
+                        destinationEncryptionConfiguration={"bar": "baz"},
+                        clustering={"fields": ["field_1", "field_2"]},
+                    ),
+                },
+                project_id=hook.return_value.project_id,
+                location=None,
+                job_id=pytest.real_job_id,
+                timeout=None,
+                retry=DEFAULT_RETRY,
+                nowait=True,
+            ),
+        ]
+
+        hook.return_value.insert_job.assert_has_calls(calls)
+
+    @mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook")
+    def test_date_partitioned_explicit_setting_should_be_found(self, hook):
+        hook.return_value.insert_job.side_effect = [
+            MagicMock(job_id=pytest.real_job_id, error_result=False),
+            pytest.real_job_id,
+        ]
+        hook.return_value.generate_job_id.return_value = pytest.real_job_id
+        hook.return_value.split_tablename.return_value = (PROJECT_ID, DATASET, TABLE)
+        operator = GCSToBigQueryOperator(
+            task_id=TASK_ID,
+            bucket=TEST_BUCKET,
+            source_objects=TEST_SOURCE_OBJECTS,
+            schema_fields=SCHEMA_FIELDS,
+            destination_project_dataset_table=TEST_EXPLICIT_DEST,
+            write_disposition=WRITE_DISPOSITION,
+            external_table=False,
+            time_partitioning={"type": "DAY"},
+        )
+
+        operator.execute(context=MagicMock())
+
+        calls = [
+            call(
+                configuration={
+                    "load": dict(
+                        autodetect=True,
+                        createDisposition="CREATE_IF_NEEDED",
+                        destinationTable={"projectId": PROJECT_ID, "datasetId": DATASET, "tableId": TABLE},
+                        destinationTableProperties={
+                            "description": None,
+                            "labels": None,
+                        },
+                        sourceFormat="CSV",
+                        skipLeadingRows=None,
+                        sourceUris=[f"gs://{TEST_BUCKET}/{TEST_SOURCE_OBJECTS_AS_STRING}"],
+                        writeDisposition=WRITE_DISPOSITION,
+                        ignoreUnknownValues=False,
+                        allowQuotedNewlines=False,
+                        encoding="UTF-8",
+                        schema={"fields": SCHEMA_FIELDS},
+                        allowJaggedRows=False,
+                        fieldDelimiter=",",
+                        maxBadRecords=0,
+                        quote=None,
+                        schemaUpdateOptions=(),
+                        timePartitioning={"type": "DAY"},
+                    ),
+                },
+                project_id=hook.return_value.project_id,
+                location=None,
+                job_id=pytest.real_job_id,
+                timeout=None,
+                retry=DEFAULT_RETRY,
+                nowait=True,
+            ),
+        ]
+
+        hook.return_value.insert_job.assert_has_calls(calls)
+
+    @mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook")
+    def test_date_partitioned_implied_in_table_name_should_be_found(self, hook):
+        hook.return_value.insert_job.side_effect = [
+            MagicMock(job_id=pytest.real_job_id, error_result=False),
+            pytest.real_job_id,
+        ]
+        hook.return_value.generate_job_id.return_value = pytest.real_job_id
+        hook.return_value.split_tablename.return_value = (PROJECT_ID, DATASET, TABLE)
+        operator = GCSToBigQueryOperator(
+            task_id=TASK_ID,
+            bucket=TEST_BUCKET,
+            source_objects=TEST_SOURCE_OBJECTS,
+            schema_fields=SCHEMA_FIELDS,
+            destination_project_dataset_table=TEST_EXPLICIT_DEST + "$20221123",
+            write_disposition=WRITE_DISPOSITION,
+            external_table=False,
+        )
+
+        operator.execute(context=MagicMock())
+
+        calls = [
+            call(
+                configuration={
+                    "load": dict(
+                        autodetect=True,
+                        createDisposition="CREATE_IF_NEEDED",
+                        destinationTable={"projectId": PROJECT_ID, "datasetId": DATASET, "tableId": TABLE},
+                        destinationTableProperties={
+                            "description": None,
+                            "labels": None,
+                        },
+                        sourceFormat="CSV",
+                        skipLeadingRows=None,
+                        sourceUris=[f"gs://{TEST_BUCKET}/{TEST_SOURCE_OBJECTS_AS_STRING}"],
+                        writeDisposition=WRITE_DISPOSITION,
+                        ignoreUnknownValues=False,
+                        allowQuotedNewlines=False,
+                        encoding="UTF-8",
+                        schema={"fields": SCHEMA_FIELDS},
+                        allowJaggedRows=False,
+                        fieldDelimiter=",",
+                        maxBadRecords=0,
+                        quote=None,
+                        schemaUpdateOptions=(),
+                        timePartitioning={"type": "DAY"},
                     ),
                 },
                 project_id=hook.return_value.project_id,
@@ -830,6 +1043,11 @@ class TestGCSToBigQueryOperator(unittest.TestCase):
                         allowQuotedNewlines=False,
                         encoding="UTF-8",
                         schema={"fields": SCHEMA_FIELDS},
+                        allowJaggedRows=False,
+                        fieldDelimiter=",",
+                        maxBadRecords=0,
+                        quote=None,
+                        schemaUpdateOptions=(),
                     ),
                 },
                 project_id=bq_hook.return_value.project_id,
@@ -1023,6 +1241,11 @@ class TestGCSToBigQueryOperator(unittest.TestCase):
                         ignoreUnknownValues=False,
                         allowQuotedNewlines=False,
                         encoding="UTF-8",
+                        allowJaggedRows=False,
+                        fieldDelimiter=",",
+                        maxBadRecords=0,
+                        quote=None,
+                        schemaUpdateOptions=(),
                     ),
                 },
                 project_id=bq_hook.return_value.project_id,
@@ -1087,6 +1310,11 @@ class TestGCSToBigQueryOperator(unittest.TestCase):
                         allowQuotedNewlines=False,
                         encoding="UTF-8",
                         schema={"fields": SCHEMA_FIELDS_INT},
+                        allowJaggedRows=False,
+                        fieldDelimiter=",",
+                        maxBadRecords=0,
+                        quote=None,
+                        schemaUpdateOptions=(),
                     ),
                 },
                 project_id=hook.return_value.project_id,
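
Each of the new tests follows the same pattern: patch BigQueryHook at the module under test, run the operator with external_table=False, and assert that insert_job received a load configuration containing the forwarded parameters. A condensed sketch of that pattern, assuming the mocked hook satisfies everything execute() touches (the real test module supplies additional fixtures such as pytest.real_job_id):

    from unittest import mock
    from unittest.mock import MagicMock

    from airflow.providers.google.cloud.transfers.gcs_to_bigquery import GCSToBigQueryOperator

    # Condensed sketch of the test pattern above; values are illustrative.
    @mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook")
    def test_quote_character_is_forwarded(hook):
        hook.return_value.insert_job.return_value = MagicMock(job_id="job_1", error_result=False)
        hook.return_value.generate_job_id.return_value = "job_1"
        hook.return_value.split_tablename.return_value = ("project", "dataset", "table")

        GCSToBigQueryOperator(
            task_id="gcs_to_bq",
            bucket="bucket",
            source_objects=["file.csv"],
            schema_fields=[{"name": "id", "type": "INTEGER", "mode": "NULLABLE"}],
            destination_project_dataset_table="project.dataset.table",
            external_table=False,
            quote_character="|",
        ).execute(context=MagicMock())

        # The re-enabled parameter must appear in the generated load config.
        _, kwargs = hook.return_value.insert_job.call_args
        assert kwargs["configuration"]["load"]["quote"] == "|"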