Posted to commits@airflow.apache.org by el...@apache.org on 2021/05/31 05:07:20 UTC

[airflow] branch master updated: Fix: GCS To BigQuery source_object (#16160)

This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new 99d1535  Fix: GCS To BigQuery source_object (#16160)
99d1535 is described below

commit 99d1535287df7f8cfced39baff7a08f6fcfdf8ca
Author: Tegar D Pratama <te...@gmail.com>
AuthorDate: Mon May 31 12:06:44 2021 +0700

    Fix: GCS To BigQuery source_object (#16160)
    
    * Fix: GCS To BigQuery source_object #16008
    
    Fix GCSToBigQueryOperator's source_objects parameter to accept both a str and a list
    
    * convert source_objects to a list if it is not one already
    
    Normalize source_objects to a list in __init__ instead of branching in the load logic; a usage sketch follows this message
    
    * add tests
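
    For context, a minimal usage sketch of the operator before and after this
    change. The bucket, table, and task names below are illustrative only and
    are not part of the commit:

        from airflow.providers.google.cloud.transfers.gcs_to_bigquery import (
            GCSToBigQueryOperator,
        )

        # List form: the only form accepted before this commit.
        load_list = GCSToBigQueryOperator(
            task_id='gcs_to_bq_list',
            bucket='my-bucket',  # illustrative bucket name
            source_objects=['data/part-*.csv'],
            destination_project_dataset_table='my-project.my_dataset.my_table',
        )

        # String form: newly accepted by this commit; __init__ wraps it
        # into a single-element list.
        load_str = GCSToBigQueryOperator(
            task_id='gcs_to_bq_str',
            bucket='my-bucket',  # illustrative bucket name
            source_objects='data/part-*.csv',
            destination_project_dataset_table='my-project.my_dataset.my_table',
        )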
---
 .../google/cloud/transfers/gcs_to_bigquery.py      |  6 +-
 .../google/cloud/transfers/test_gcs_to_bigquery.py | 73 ++++++++++++++++++++++
 2 files changed, 76 insertions(+), 3 deletions(-)

diff --git a/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py b/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py
index b925110..cef0d71 100644
--- a/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py
+++ b/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py
@@ -42,9 +42,9 @@ class GCSToBigQueryOperator(BaseOperator):
 
     :param bucket: The bucket to load from. (templated)
     :type bucket: str
-    :param source_objects: List of Google Cloud Storage URIs to load from. (templated)
+    :param source_objects: String or List of Google Cloud Storage URIs to load from. (templated)
         If source_format is 'DATASTORE_BACKUP', the list must only contain a single URI.
-    :type source_objects: list[str]
+    :type source_objects: str or list[str]
     :param destination_project_dataset_table: The dotted
         ``(<project>.|<project>:)<dataset>.<table>`` BigQuery table to load data into.
         If ``<project>`` is not included, project will be the project defined in
@@ -219,7 +219,7 @@ class GCSToBigQueryOperator(BaseOperator):
         if time_partitioning is None:
             time_partitioning = {}
         self.bucket = bucket
-        self.source_objects = source_objects
+        self.source_objects = source_objects if isinstance(source_objects, list) else [source_objects]
         self.schema_object = schema_object
 
         # BQ config
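
The one-line change above normalizes the argument at construction time, so
execute() and the templated-field machinery always see a list. A minimal
standalone sketch of the same pattern, assuming nothing beyond the standard
library; normalize_source_objects is a hypothetical helper name, not code
from the commit:

    from typing import List, Union

    def normalize_source_objects(source_objects: Union[str, List[str]]) -> List[str]:
        # Wrap a bare string in a single-element list; pass lists through
        # untouched. Mirrors the one-liner in GCSToBigQueryOperator.__init__.
        return source_objects if isinstance(source_objects, list) else [source_objects]

    # Both forms yield the same list, so the URI-building loop in execute()
    # needs no branching:
    assert normalize_source_objects('test/objects/*') == ['test/objects/*']
    assert normalize_source_objects(['test/objects/*']) == ['test/objects/*']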
diff --git a/tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py b/tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py
index b0c2b3d..9bc3b3b 100644
--- a/tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py
+++ b/tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py
@@ -26,6 +26,7 @@ TEST_EXPLICIT_DEST = 'test-project.dataset.table'
 TEST_BUCKET = 'test-bucket'
 MAX_ID_KEY = 'id'
 TEST_SOURCE_OBJECTS = ['test/objects/*']
+TEST_SOURCE_OBJECTS_AS_STRING = 'test/objects/*'
 LABELS = {'k1': 'v1'}
 DESCRIPTION = "Test Description"
 
@@ -216,3 +217,75 @@ class TestGoogleCloudStorageToBigQueryOperator(unittest.TestCase):
                 description=DESCRIPTION,
             )
         # fmt: on
+
+    @mock.patch('airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook')
+    def test_source_objects_as_list(self, bq_hook):
+        operator = GCSToBigQueryOperator(
+            task_id=TASK_ID,
+            bucket=TEST_BUCKET,
+            source_objects=TEST_SOURCE_OBJECTS,
+            destination_project_dataset_table=TEST_EXPLICIT_DEST,
+        )
+
+        operator.execute(None)
+
+        bq_hook.return_value.get_conn.return_value.cursor.return_value.run_load.assert_called_once_with(
+            destination_project_dataset_table=mock.ANY,
+            schema_fields=mock.ANY,
+            source_uris=[f'gs://{TEST_BUCKET}/{source_object}' for source_object in TEST_SOURCE_OBJECTS],
+            source_format=mock.ANY,
+            autodetect=mock.ANY,
+            create_disposition=mock.ANY,
+            skip_leading_rows=mock.ANY,
+            write_disposition=mock.ANY,
+            field_delimiter=mock.ANY,
+            max_bad_records=mock.ANY,
+            quote_character=mock.ANY,
+            ignore_unknown_values=mock.ANY,
+            allow_quoted_newlines=mock.ANY,
+            allow_jagged_rows=mock.ANY,
+            encoding=mock.ANY,
+            schema_update_options=mock.ANY,
+            src_fmt_configs=mock.ANY,
+            time_partitioning=mock.ANY,
+            cluster_fields=mock.ANY,
+            encryption_configuration=mock.ANY,
+            labels=mock.ANY,
+            description=mock.ANY,
+        )
+
+    @mock.patch('airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook')
+    def test_source_objects_as_string(self, bq_hook):
+        operator = GCSToBigQueryOperator(
+            task_id=TASK_ID,
+            bucket=TEST_BUCKET,
+            source_objects=TEST_SOURCE_OBJECTS_AS_STRING,
+            destination_project_dataset_table=TEST_EXPLICIT_DEST,
+        )
+
+        operator.execute(None)
+
+        bq_hook.return_value.get_conn.return_value.cursor.return_value.run_load.assert_called_once_with(
+            destination_project_dataset_table=mock.ANY,
+            schema_fields=mock.ANY,
+            source_uris=[f'gs://{TEST_BUCKET}/{TEST_SOURCE_OBJECTS_AS_STRING}'],
+            source_format=mock.ANY,
+            autodetect=mock.ANY,
+            create_disposition=mock.ANY,
+            skip_leading_rows=mock.ANY,
+            write_disposition=mock.ANY,
+            field_delimiter=mock.ANY,
+            max_bad_records=mock.ANY,
+            quote_character=mock.ANY,
+            ignore_unknown_values=mock.ANY,
+            allow_quoted_newlines=mock.ANY,
+            allow_jagged_rows=mock.ANY,
+            encoding=mock.ANY,
+            schema_update_options=mock.ANY,
+            src_fmt_configs=mock.ANY,
+            time_partitioning=mock.ANY,
+            cluster_fields=mock.ANY,
+            encryption_configuration=mock.ANY,
+            labels=mock.ANY,
+            description=mock.ANY,
+        )
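
Both new tests pin down only the source_uris argument and pass mock.ANY for
every other keyword of run_load. Since the string constant names the same
object as the single-element list, the expected URIs are identical for both
input forms; a small check using the test module's constants illustrates this:

    TEST_BUCKET = 'test-bucket'
    TEST_SOURCE_OBJECTS = ['test/objects/*']
    TEST_SOURCE_OBJECTS_AS_STRING = 'test/objects/*'

    # List input: one gs:// URI per object in the list.
    expected_from_list = [f'gs://{TEST_BUCKET}/{obj}' for obj in TEST_SOURCE_OBJECTS]

    # String input: __init__ wraps it, so a single-element list comes out.
    expected_from_string = [f'gs://{TEST_BUCKET}/{TEST_SOURCE_OBJECTS_AS_STRING}']

    assert expected_from_list == expected_from_string == ['gs://test-bucket/test/objects/*']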