Posted to commits@airflow.apache.org by po...@apache.org on 2022/10/31 01:53:47 UTC
[airflow] branch main updated: Add deferrable mode to GCPToBigQueryOperator + tests (#27052)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 124fb3948d Add deferrable mode to GCPToBigQueryOperator + tests (#27052)
124fb3948d is described below
commit 124fb3948d18c4fe4b2aad12eecfd5ba1efca4bc
Author: VladaZakharova <80...@users.noreply.github.com>
AuthorDate: Mon Oct 31 02:53:40 2022 +0100
Add deferrable mode to GCPToBigQueryOperator + tests (#27052)
---
airflow/providers/google/cloud/hooks/bigquery.py | 4 +-
.../providers/google/cloud/operators/bigquery.py | 2 +-
.../google/cloud/transfers/gcs_to_bigquery.py | 409 ++++--
.../operators/cloud/gcs.rst | 10 +-
.../google/cloud/transfers/test_gcs_to_bigquery.py | 1554 ++++++++++++++++++--
.../cloud/gcs/example_gcs_to_bigquery_async.py | 121 ++
6 files changed, 1845 insertions(+), 255 deletions(-)
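For orientation, a minimal, hypothetical usage sketch of the new deferrable mode added by this commit (the DAG id, bucket, object, and table names below are placeholders, not values from this change):

    from airflow import DAG
    from airflow.providers.google.cloud.transfers.gcs_to_bigquery import GCSToBigQueryOperator
    from airflow.utils.timezone import datetime

    with DAG("example_gcs_to_bq_deferrable", start_date=datetime(2022, 10, 1), schedule=None) as dag:
        load_csv = GCSToBigQueryOperator(
            task_id="gcs_to_bq_deferrable",
            bucket="my-bucket",  # placeholder bucket
            source_objects=["data/sample.csv"],  # placeholder object
            destination_project_dataset_table="my-project.my_dataset.my_table",  # placeholder table
            write_disposition="WRITE_TRUNCATE",
            deferrable=True,  # new flag introduced by this commit
        )

With deferrable=True the operator submits the load job, defers on BigQueryInsertJobTrigger, and resumes in execute_complete once the trigger reports the final job status.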
diff --git a/airflow/providers/google/cloud/hooks/bigquery.py b/airflow/providers/google/cloud/hooks/bigquery.py
index 28635123f8..0618967eb3 100644
--- a/airflow/providers/google/cloud/hooks/bigquery.py
+++ b/airflow/providers/google/cloud/hooks/bigquery.py
@@ -2247,7 +2247,7 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
if project_id is None:
if var_name is not None:
self.log.info(
- 'Project not included in %s: %s; using project "%s"',
+ 'Project is not included in %s: %s; using project "%s"',
var_name,
table_input,
default_project_id,
@@ -2913,7 +2913,7 @@ def split_tablename(
if project_id is None:
if var_name is not None:
log.info(
- 'Project not included in %s: %s; using project "%s"',
+ 'Project is not included in %s: %s; using project "%s"',
var_name,
table_input,
default_project_id,
diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py
index b67d5a050a..9dd481c093 100644
--- a/airflow/providers/google/cloud/operators/bigquery.py
+++ b/airflow/providers/google/cloud/operators/bigquery.py
@@ -2511,7 +2511,7 @@ class BigQueryInsertJobOperator(BaseOperator):
:param configuration: The configuration parameter maps directly to BigQuery's
- configuration field in the job object. For more details see
+ configuration field in the job object. For more details see
https://cloud.google.com/bigquery/docs/reference/v2/jobs
:param job_id: The ID of the job. It will be suffixed with hash of job configuration
unless ``force_rerun`` is True.
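For context, the ``configuration`` argument referenced above maps directly to the BigQuery job resource; a minimal hypothetical example (the query is a placeholder):

    from airflow.providers.google.cloud.operators.bigquery import BigQueryInsertJobOperator

    insert_job = BigQueryInsertJobOperator(
        task_id="bq_insert_job",
        configuration={
            "query": {
                "query": "SELECT 1",  # placeholder query
                "useLegacySql": False,
            }
        },
    )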
diff --git a/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py b/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py
index e4fbb9daa0..ce7da49062 100644
--- a/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py
+++ b/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py
@@ -19,12 +19,18 @@
from __future__ import annotations
import json
-import warnings
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, Any, Sequence
+from google.api_core.exceptions import Conflict
+from google.api_core.retry import Retry
+from google.cloud.bigquery import DEFAULT_RETRY, CopyJob, ExtractJob, LoadJob, QueryJob
+
+from airflow import AirflowException
from airflow.models import BaseOperator
-from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
+from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook, BigQueryJob
from airflow.providers.google.cloud.hooks.gcs import GCSHook
+from airflow.providers.google.cloud.links.bigquery import BigQueryTableLink
+from airflow.providers.google.cloud.triggers.bigquery import BigQueryInsertJobTrigger
if TYPE_CHECKING:
from airflow.utils.context import Context
@@ -66,7 +72,18 @@ class GCSToBigQueryOperator(BaseOperator):
This setting is ignored for Google Cloud Bigtable,
Google Cloud Datastore backups and Avro formats.
:param create_disposition: The create disposition if the table doesn't exist.
- :param skip_leading_rows: Number of rows to skip when loading from a CSV.
+ :param skip_leading_rows: The number of rows at the top of a CSV file that BigQuery
+ will skip when loading the data.
+ When autodetect is on, the behavior is the following:
+ skip_leading_rows unspecified - Autodetect tries to detect headers in the first row.
+ If they are not detected, the row is read as data. Otherwise, data is read starting
+ from the second row.
+ skip_leading_rows is 0 - Instructs autodetect that there are no headers and data
+ should be read starting from the first row.
+ skip_leading_rows = N > 0 - Autodetect skips N-1 rows and tries to detect headers
+ in row N. If headers are not detected, row N is just skipped. Otherwise, row N is
+ used to extract column names for the detected schema.
+ The default value is None so that the autodetect option can detect schema fields.
:param write_disposition: The write disposition if the table already exists.
:param field_delimiter: The delimiter to use when loading from a CSV.
:param max_bad_records: The maximum number of bad records that BigQuery can
@@ -129,7 +146,10 @@ class GCSToBigQueryOperator(BaseOperator):
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:param labels: [Optional] Labels for the BiqQuery table.
- :param description: [Optional] Description for the BigQuery table.
+ :param description: [Optional] Description for the BigQuery table. This will only be used if the
+ destination table is newly created. If the table already exists and a value different than the
+ current description is provided, the job will fail.
+ :param deferrable: Run the operator in deferrable mode
"""
template_fields: Sequence[str] = (
@@ -142,6 +162,7 @@ class GCSToBigQueryOperator(BaseOperator):
)
template_ext: Sequence[str] = (".sql",)
ui_color = "#f0eee4"
+ operator_extra_links = (BigQueryTableLink(),)
def __init__(
self,
@@ -155,7 +176,7 @@ class GCSToBigQueryOperator(BaseOperator):
source_format="CSV",
compression="NONE",
create_disposition="CREATE_IF_NEEDED",
- skip_leading_rows=0,
+ skip_leading_rows=None,
write_disposition="WRITE_EMPTY",
field_delimiter=",",
max_bad_records=0,
@@ -178,10 +199,19 @@ class GCSToBigQueryOperator(BaseOperator):
impersonation_chain: str | Sequence[str] | None = None,
labels=None,
description=None,
+ deferrable: bool = False,
+ result_retry: Retry = DEFAULT_RETRY,
+ result_timeout: float | None = None,
+ cancel_on_kill: bool = True,
+ job_id: str | None = None,
+ force_rerun: bool = True,
+ reattach_states: set[str] | None = None,
**kwargs,
- ):
+ ) -> None:
super().__init__(**kwargs)
+ self.hook: BigQueryHook | None = None
+ self.configuration: dict[str, Any] = {}
# GCS config
if src_fmt_configs is None:
@@ -229,16 +259,275 @@ class GCSToBigQueryOperator(BaseOperator):
self.labels = labels
self.description = description
+ self.job_id = job_id
+ self.deferrable = deferrable
+ self.result_retry = result_retry
+ self.result_timeout = result_timeout
+ self.force_rerun = force_rerun
+ self.reattach_states: set[str] = reattach_states or set()
+ self.cancel_on_kill = cancel_on_kill
+
+ def _submit_job(
+ self,
+ hook: BigQueryHook,
+ job_id: str,
+ ) -> BigQueryJob:
+ # Submit a new job without waiting for it to complete.
+ return hook.insert_job(
+ configuration=self.configuration,
+ project_id=hook.project_id,
+ location=self.location,
+ job_id=job_id,
+ timeout=self.result_timeout,
+ retry=self.result_retry,
+ nowait=True,
+ )
+
+ @staticmethod
+ def _handle_job_error(job: BigQueryJob) -> None:
+ if job.error_result:
+ raise AirflowException(f"BigQuery job {job.job_id} failed: {job.error_result}")
+
def execute(self, context: Context):
- bq_hook = BigQueryHook(
+ hook = BigQueryHook(
gcp_conn_id=self.gcp_conn_id,
delegate_to=self.delegate_to,
location=self.location,
impersonation_chain=self.impersonation_chain,
)
+ self.hook = hook
+ job_id = self.hook.generate_job_id(
+ job_id=self.job_id,
+ dag_id=self.dag_id,
+ task_id=self.task_id,
+ logical_date=context["logical_date"],
+ configuration=self.configuration,
+ force_rerun=self.force_rerun,
+ )
+ self.source_objects = (
+ self.source_objects if isinstance(self.source_objects, list) else [self.source_objects]
+ )
+ source_uris = [f"gs://{self.bucket}/{source_object}" for source_object in self.source_objects]
if not self.schema_fields:
+ gcs_hook = GCSHook(
+ gcp_conn_id=self.gcp_conn_id,
+ delegate_to=self.delegate_to,
+ impersonation_chain=self.impersonation_chain,
+ )
if self.schema_object and self.source_format != "DATASTORE_BACKUP":
+ schema_fields = json.loads(gcs_hook.download(self.bucket, self.schema_object).decode("utf-8"))
+ self.log.info("Autodetected fields from schema object: %s", schema_fields)
+
+ if self.external_table:
+ self.log.info("Creating a new BigQuery table for storing data...")
+ project_id, dataset_id, table_id = self.hook.split_tablename(
+ table_input=self.destination_project_dataset_table,
+ default_project_id=self.hook.project_id or "",
+ )
+ table_resource = {
+ "tableReference": {
+ "projectId": project_id,
+ "datasetId": dataset_id,
+ "tableId": table_id,
+ },
+ "labels": self.labels,
+ "description": self.description,
+ "externalDataConfiguration": {
+ "source_uris": source_uris,
+ "source_format": self.source_format,
+ "maxBadRecords": self.max_bad_records,
+ "autodetect": self.autodetect,
+ "compression": self.compression,
+ "csvOptions": {
+ "fieldDelimeter": self.field_delimiter,
+ "skipLeadingRows": self.skip_leading_rows,
+ "quote": self.quote_character,
+ "allowQuotedNewlines": self.allow_quoted_newlines,
+ "allowJaggedRows": self.allow_jagged_rows,
+ },
+ },
+ "location": self.location,
+ "encryptionConfiguration": self.encryption_configuration,
+ }
+ table_resource_checked_schema = self._check_schema_fields(table_resource)
+ table = self.hook.create_empty_table(
+ table_resource=table_resource_checked_schema,
+ )
+ max_id = self._find_max_value_in_column()
+ BigQueryTableLink.persist(
+ context=context,
+ task_instance=self,
+ dataset_id=table.to_api_repr()["tableReference"]["datasetId"],
+ project_id=table.to_api_repr()["tableReference"]["projectId"],
+ table_id=table.to_api_repr()["tableReference"]["tableId"],
+ )
+ return max_id
+ else:
+ self.log.info("Using existing BigQuery table for storing data...")
+ destination_project, destination_dataset, destination_table = self.hook.split_tablename(
+ table_input=self.destination_project_dataset_table,
+ default_project_id=self.hook.project_id or "",
+ var_name="destination_project_dataset_table",
+ )
+ self.configuration = {
+ "load": {
+ "autodetect": self.autodetect,
+ "createDisposition": self.create_disposition,
+ "destinationTable": {
+ "projectId": destination_project,
+ "datasetId": destination_dataset,
+ "tableId": destination_table,
+ },
+ "destinationTableProperties": {
+ "description": self.description,
+ "labels": self.labels,
+ },
+ "sourceFormat": self.source_format,
+ "skipLeadingRows": self.skip_leading_rows,
+ "sourceUris": source_uris,
+ "writeDisposition": self.write_disposition,
+ "ignoreUnknownValues": self.ignore_unknown_values,
+ "allowQuotedNewlines": self.allow_quoted_newlines,
+ "encoding": self.encoding,
+ },
+ }
+ self.configuration = self._check_schema_fields(self.configuration)
+ try:
+ self.log.info("Executing: %s", self.configuration)
+ job = self._submit_job(self.hook, job_id)
+ except Conflict:
+ # If the job already exists retrieve it
+ job = self.hook.get_job(
+ project_id=self.hook.project_id,
+ location=self.location,
+ job_id=job_id,
+ )
+ if job.state in self.reattach_states:
+ # We are reattaching to a job
+ job._begin()
+ self._handle_job_error(job)
+ else:
+ # Same job configuration so we need force_rerun
+ raise AirflowException(
+ f"Job with id: {job_id} already exists and is in {job.state} state. If you "
+ f"want to force rerun it consider setting `force_rerun=True`."
+ f"Or, if you want to reattach in this scenario add {job.state} to `reattach_states`"
+ )
+
+ job_types = {
+ LoadJob._JOB_TYPE: ["sourceTable", "destinationTable"],
+ CopyJob._JOB_TYPE: ["sourceTable", "destinationTable"],
+ ExtractJob._JOB_TYPE: ["sourceTable"],
+ QueryJob._JOB_TYPE: ["destinationTable"],
+ }
+
+ if self.hook.project_id:
+ for job_type, tables_prop in job_types.items():
+ job_configuration = job.to_api_repr()["configuration"]
+ if job_type in job_configuration:
+ for table_prop in tables_prop:
+ if table_prop in job_configuration[job_type]:
+ table = job_configuration[job_type][table_prop]
+ persist_kwargs = {
+ "context": context,
+ "task_instance": self,
+ "project_id": self.hook.project_id,
+ "table_id": table,
+ }
+ if not isinstance(table, str):
+ persist_kwargs["table_id"] = table["tableId"]
+ persist_kwargs["dataset_id"] = table["datasetId"]
+ BigQueryTableLink.persist(**persist_kwargs)
+
+ self.job_id = job.job_id
+ context["ti"].xcom_push(key="job_id", value=self.job_id)
+ if self.deferrable:
+ self.defer(
+ timeout=self.execution_timeout,
+ trigger=BigQueryInsertJobTrigger(
+ conn_id=self.gcp_conn_id,
+ job_id=self.job_id,
+ project_id=self.hook.project_id,
+ ),
+ method_name="execute_complete",
+ )
+ else:
+ job.result(timeout=self.result_timeout, retry=self.result_retry)
+ max_id = self._find_max_value_in_column()
+ self._handle_job_error(job)
+ return max_id
+
+ def execute_complete(self, context: Context, event: dict[str, Any]):
+ """
+ Callback for when the trigger fires - returns immediately.
+ Relies on trigger to throw an exception, otherwise it assumes execution was
+ successful.
+ """
+ if event["status"] == "error":
+ raise AirflowException(event["message"])
+ self.log.info(
+ "%s completed with response %s ",
+ self.task_id,
+ event["message"],
+ )
+ return self._find_max_value_in_column()
+
+ def _find_max_value_in_column(self):
+ hook = BigQueryHook(
+ gcp_conn_id=self.gcp_conn_id,
+ delegate_to=self.delegate_to,
+ location=self.location,
+ impersonation_chain=self.impersonation_chain,
+ )
+ if self.max_id_key:
+ self.log.info(f"Selecting the MAX value from BigQuery column '{self.max_id_key}'...")
+ select_command = (
+ f"SELECT MAX({self.max_id_key}) AS max_value "
+ f"FROM {self.destination_project_dataset_table}"
+ )
+
+ self.configuration = {
+ "query": {
+ "query": select_command,
+ "useLegacySql": False,
+ "schemaUpdateOptions": [],
+ }
+ }
+ job_id = hook.insert_job(configuration=self.configuration, project_id=hook.project_id)
+ rows = list(hook.get_job(job_id=job_id, location=self.location).result())
+ if rows:
+ for row in rows:
+ max_id = row[0] if row[0] else 0
+ self.log.info(
+ "Loaded BQ data with MAX value of column %s.%s: %s",
+ self.destination_project_dataset_table,
+ self.max_id_key,
+ max_id,
+ )
+ return str(max_id)
+ else:
+ raise RuntimeError(f"The {select_command} returned no rows!")
+
+ def _check_schema_fields(self, table_resource):
+ """
+ Helper method to detect schema fields if they were not specified by the user and autodetect=True.
+ If source_objects were passed, the method reads the second row of the CSV file. If it contains
+ at least one digit, table_resource is returned unchanged so that BigQuery can determine
+ schema_fields in the next step.
+ If there are only characters, the first row is used to construct the schema_fields argument
+ with type 'STRING'. table_resource is updated with the new schema_fields key and returned to the operator.
+ :param table_resource: Configuration or table_resource dictionary
+ :return: table_resource: Updated table_resource dict with schema_fields
+ """
+ if not self.autodetect and not self.schema_fields:
+ raise RuntimeError(
+ "Table schema was not found. Set autodetect=True to "
+ "automatically set schema fields from source objects or pass "
+ "schema_fields explicitly"
+ )
+ elif not self.schema_fields:
+ for source_object in self.source_objects:
gcs_hook = GCSHook(
gcp_conn_id=self.gcp_conn_id,
delegate_to=self.delegate_to,
@@ -246,88 +535,30 @@ class GCSToBigQueryOperator(BaseOperator):
)
blob = gcs_hook.download(
bucket_name=self.schema_object_bucket,
- object_name=self.schema_object,
+ object_name=source_object,
)
- schema_fields = json.loads(blob.decode("utf-8"))
- else:
- schema_fields = None
- else:
- schema_fields = self.schema_fields
-
- self.source_objects = (
- self.source_objects if isinstance(self.source_objects, list) else [self.source_objects]
- )
- source_uris = [f"gs://{self.bucket}/{source_object}" for source_object in self.source_objects]
+ fields, values = [item.split(",") for item in blob.decode("utf-8").splitlines()][:2]
+ import re
+ if any(re.match(r"[\d\-\\.]+$", value) for value in values):
+ return table_resource
+ else:
+ schema_fields = []
+ for field in fields:
+ schema_fields.append({"name": field, "type": "STRING", "mode": "NULLABLE"})
+ self.schema_fields = schema_fields
+ if self.external_table:
+ table_resource["externalDataConfiguration"]["csvOptions"]["skipLeadingRows"] = 1
+ elif not self.external_table:
+ table_resource["load"]["skipLeadingRows"] = 1
if self.external_table:
- with warnings.catch_warnings():
- warnings.simplefilter("ignore", DeprecationWarning)
- bq_hook.create_external_table(
- external_project_dataset_table=self.destination_project_dataset_table,
- schema_fields=schema_fields,
- source_uris=source_uris,
- source_format=self.source_format,
- autodetect=self.autodetect,
- compression=self.compression,
- skip_leading_rows=self.skip_leading_rows,
- field_delimiter=self.field_delimiter,
- max_bad_records=self.max_bad_records,
- quote_character=self.quote_character,
- ignore_unknown_values=self.ignore_unknown_values,
- allow_quoted_newlines=self.allow_quoted_newlines,
- allow_jagged_rows=self.allow_jagged_rows,
- encoding=self.encoding,
- src_fmt_configs=self.src_fmt_configs,
- encryption_configuration=self.encryption_configuration,
- labels=self.labels,
- description=self.description,
- )
- else:
- with warnings.catch_warnings():
- warnings.simplefilter("ignore", DeprecationWarning)
- bq_hook.run_load(
- destination_project_dataset_table=self.destination_project_dataset_table,
- schema_fields=schema_fields,
- source_uris=source_uris,
- source_format=self.source_format,
- autodetect=self.autodetect,
- create_disposition=self.create_disposition,
- skip_leading_rows=self.skip_leading_rows,
- write_disposition=self.write_disposition,
- field_delimiter=self.field_delimiter,
- max_bad_records=self.max_bad_records,
- quote_character=self.quote_character,
- ignore_unknown_values=self.ignore_unknown_values,
- allow_quoted_newlines=self.allow_quoted_newlines,
- allow_jagged_rows=self.allow_jagged_rows,
- encoding=self.encoding,
- schema_update_options=self.schema_update_options,
- src_fmt_configs=self.src_fmt_configs,
- time_partitioning=self.time_partitioning,
- cluster_fields=self.cluster_fields,
- encryption_configuration=self.encryption_configuration,
- labels=self.labels,
- description=self.description,
- )
+ table_resource["schema"] = {"fields": self.schema_fields}
+ elif not self.external_table:
+ table_resource["load"]["schema"] = {"fields": self.schema_fields}
+ return table_resource
- if self.max_id_key:
- select_command = f"SELECT MAX({self.max_id_key}) FROM `{self.destination_project_dataset_table}`"
- with warnings.catch_warnings():
- warnings.simplefilter("ignore", DeprecationWarning)
- job_id = bq_hook.run_query(
- sql=select_command,
- location=self.location,
- use_legacy_sql=False,
- )
- result = bq_hook.get_job(job_id=job_id, location=self.location).result()
- row = next(iter(result), None)
- if row is None:
- raise RuntimeError(f"The {select_command} returned no rows!")
- max_id = row[0]
- self.log.info(
- "Loaded BQ data with max %s.%s=%s",
- self.destination_project_dataset_table,
- self.max_id_key,
- max_id,
- )
- return max_id
+ def on_kill(self) -> None:
+ if self.job_id and self.cancel_on_kill:
+ self.hook.cancel_job(job_id=self.job_id, location=self.location) # type: ignore[union-attr]
+ else:
+ self.log.info("Skipping to cancel job: %s.%s", self.location, self.job_id)
diff --git a/docs/apache-airflow-providers-google/operators/cloud/gcs.rst b/docs/apache-airflow-providers-google/operators/cloud/gcs.rst
index 923e7f7396..2a21a37cb5 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/gcs.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/gcs.rst
@@ -39,7 +39,7 @@ GCSToBigQueryOperator
Use the
:class:`~airflow.providers.google.cloud.transfers.gcs_to_bigquery.GCSToBigQueryOperator`
-to execute a BigQuery load job.
+to execute a BigQuery load job that loads data from Google Cloud Storage into a BigQuery table.
.. exampleinclude:: /../../tests/system/providers/google/cloud/gcs/example_gcs_to_bigquery.py
:language: python
@@ -47,6 +47,14 @@ to execute a BigQuery load job.
:start-after: [START howto_operator_gcs_to_bigquery]
:end-before: [END howto_operator_gcs_to_bigquery]
+You can also use GCSToBigQueryOperator in deferrable mode:
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/gcs/example_gcs_to_bigquery_async.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_gcs_to_bigquery_async]
+ :end-before: [END howto_operator_gcs_to_bigquery_async]
+
.. _howto/operator:GCSTimeSpanFileTransformOperator:
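For completeness, a hypothetical illustration of the new job-conflict handling (parameter values are placeholders; the behavior follows the Conflict branch added to execute() above):

    # force_rerun=True (default): a unique suffix is added to the job id, so reruns do not collide.
    # force_rerun=False and a job with the same id already exists:
    #   - if job.state is in reattach_states, the operator reattaches and waits for that job
    #   - otherwise an AirflowException is raised suggesting force_rerun or reattach_states
    load_csv = GCSToBigQueryOperator(
        task_id="gcs_to_bq_reattach",
        bucket="my-bucket",  # placeholder
        source_objects=["data/sample.csv"],  # placeholder
        destination_project_dataset_table="my-project.my_dataset.my_table",  # placeholder
        force_rerun=False,
        reattach_states={"PENDING", "RUNNING"},  # illustrative set of states
    )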
diff --git a/tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py b/tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py
index 33e9dc55c1..ee69c214e7 100644
--- a/tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py
+++ b/tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py
@@ -19,260 +19,1490 @@ from __future__ import annotations
import unittest
from unittest import mock
+from unittest.mock import MagicMock, call
-from google.cloud.bigquery.table import Row
+import pytest
+from google.cloud.bigquery import DEFAULT_RETRY
+from google.cloud.exceptions import Conflict
+from airflow.exceptions import AirflowException, TaskDeferred
+from airflow.models import DAG
+from airflow.models.dagrun import DagRun
+from airflow.models.taskinstance import TaskInstance
from airflow.providers.google.cloud.transfers.gcs_to_bigquery import GCSToBigQueryOperator
+from airflow.providers.google.cloud.triggers.bigquery import BigQueryInsertJobTrigger
+from airflow.utils.timezone import datetime
+from airflow.utils.types import DagRunType
TASK_ID = "test-gcs-to-bq-operator"
TEST_EXPLICIT_DEST = "test-project.dataset.table"
TEST_BUCKET = "test-bucket"
+PROJECT_ID = "test-project"
+DATASET = "dataset"
+TABLE = "table"
+WRITE_DISPOSITION = "WRITE_TRUNCATE"
MAX_ID_KEY = "id"
-TEST_SOURCE_OBJECTS = ["test/objects/*"]
-TEST_SOURCE_OBJECTS_AS_STRING = "test/objects/*"
+TEST_DATASET_LOCATION = "US"
+SCHEMA_FIELDS = [
+ {"name": "id", "type": "STRING", "mode": "NULLABLE"},
+ {"name": "name", "type": "STRING", "mode": "NULLABLE"},
+]
+SCHEMA_FIELDS_INT = [
+ {"name": "id", "type": "INTEGER", "mode": "NULLABLE"},
+ {"name": "name", "type": "STRING", "mode": "NULLABLE"},
+]
+TEST_SOURCE_OBJECTS = ["test/objects/test.csv"]
+TEST_SOURCE_OBJECTS_AS_STRING = "test/objects/test.csv"
LABELS = {"k1": "v1"}
DESCRIPTION = "Test Description"
+job_id = "123456"
+hash_ = "hash"
+pytest.real_job_id = f"{job_id}_{hash_}"
+
class TestGCSToBigQueryOperator(unittest.TestCase):
@mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook")
- def test_execute_explicit_project(self, bq_hook):
+ def test_max_value_external_table_should_execute_successfully(self, hook):
+ hook.return_value.insert_job.side_effect = [
+ MagicMock(job_id=pytest.real_job_id, error_result=False),
+ pytest.real_job_id,
+ ]
+ hook.return_value.generate_job_id.return_value = pytest.real_job_id
+ hook.return_value.split_tablename.return_value = (PROJECT_ID, DATASET, TABLE)
+ hook.return_value.get_job.return_value.result.return_value = ("1",)
operator = GCSToBigQueryOperator(
task_id=TASK_ID,
bucket=TEST_BUCKET,
source_objects=TEST_SOURCE_OBJECTS,
destination_project_dataset_table=TEST_EXPLICIT_DEST,
+ write_disposition=WRITE_DISPOSITION,
+ schema_fields=SCHEMA_FIELDS,
max_id_key=MAX_ID_KEY,
+ external_table=True,
)
- bq_hook.return_value.get_job.return_value.result.return_value = [Row(("100",), {"f0_": 0})]
+ result = operator.execute(context=MagicMock())
- result = operator.execute(None)
+ assert result == "1"
+ hook.return_value.create_empty_table.assert_called_once_with(
+ table_resource={
+ "tableReference": {"projectId": PROJECT_ID, "datasetId": DATASET, "tableId": TABLE},
+ "labels": None,
+ "description": None,
+ "externalDataConfiguration": {
+ "source_uris": [f"gs://{TEST_BUCKET}/{TEST_SOURCE_OBJECTS_AS_STRING}"],
+ "source_format": "CSV",
+ "maxBadRecords": 0,
+ "autodetect": True,
+ "compression": "NONE",
+ "csvOptions": {
+ "fieldDelimeter": ",",
+ "skipLeadingRows": None,
+ "quote": None,
+ "allowQuotedNewlines": False,
+ "allowJaggedRows": False,
+ },
+ },
+ "location": None,
+ "encryptionConfiguration": None,
+ "schema": {"fields": SCHEMA_FIELDS},
+ }
+ )
+ hook.return_value.insert_job.assert_called_once_with(
+ configuration={
+ "query": {
+ "query": f"SELECT MAX({MAX_ID_KEY}) AS max_value FROM {TEST_EXPLICIT_DEST}",
+ "useLegacySql": False,
+ "schemaUpdateOptions": [],
+ }
+ },
+ project_id=hook.return_value.project_id,
+ )
- assert result == "100"
+ @mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook")
+ def test_max_value_without_external_table_should_execute_successfully(self, hook):
+ hook.return_value.insert_job.side_effect = [
+ MagicMock(job_id=pytest.real_job_id, error_result=False),
+ pytest.real_job_id,
+ ]
+ hook.return_value.generate_job_id.return_value = pytest.real_job_id
+ hook.return_value.split_tablename.return_value = (PROJECT_ID, DATASET, TABLE)
+ hook.return_value.get_job.return_value.result.return_value = ("1",)
- bq_hook.return_value.run_query.assert_called_once_with(
- sql="SELECT MAX(id) FROM `test-project.dataset.table`",
- location=None,
- use_legacy_sql=False,
+ operator = GCSToBigQueryOperator(
+ task_id=TASK_ID,
+ bucket=TEST_BUCKET,
+ source_objects=TEST_SOURCE_OBJECTS,
+ destination_project_dataset_table=TEST_EXPLICIT_DEST,
+ schema_fields=SCHEMA_FIELDS,
+ max_id_key=MAX_ID_KEY,
+ write_disposition=WRITE_DISPOSITION,
+ external_table=False,
)
+ result = operator.execute(context=MagicMock())
+ assert result == "1"
+
+ calls = [
+ call(
+ configuration={
+ "load": dict(
+ autodetect=True,
+ createDisposition="CREATE_IF_NEEDED",
+ destinationTable={"projectId": PROJECT_ID, "datasetId": DATASET, "tableId": TABLE},
+ destinationTableProperties={
+ "description": None,
+ "labels": None,
+ },
+ sourceFormat="CSV",
+ skipLeadingRows=None,
+ sourceUris=[f"gs://{TEST_BUCKET}/{TEST_SOURCE_OBJECTS_AS_STRING}"],
+ writeDisposition=WRITE_DISPOSITION,
+ ignoreUnknownValues=False,
+ allowQuotedNewlines=False,
+ encoding="UTF-8",
+ schema={"fields": SCHEMA_FIELDS},
+ ),
+ },
+ project_id=hook.return_value.project_id,
+ location=None,
+ job_id=pytest.real_job_id,
+ timeout=None,
+ retry=DEFAULT_RETRY,
+ nowait=True,
+ ),
+ call(
+ configuration={
+ "query": {
+ "query": f"SELECT MAX({MAX_ID_KEY}) AS max_value FROM {TEST_EXPLICIT_DEST}",
+ "useLegacySql": False,
+ "schemaUpdateOptions": [],
+ }
+ },
+ project_id=hook.return_value.project_id,
+ ),
+ ]
+
+ hook.return_value.insert_job.assert_has_calls(calls)
+
+ @mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook")
+ def test_max_value_should_throw_ex_when_query_returns_no_rows(self, hook):
+ hook.return_value.insert_job.side_effect = [
+ MagicMock(job_id=pytest.real_job_id, error_result=False),
+ pytest.real_job_id,
+ ]
+ hook.return_value.generate_job_id.return_value = pytest.real_job_id
+ hook.return_value.split_tablename.return_value = (PROJECT_ID, DATASET, TABLE)
+ with pytest.raises(RuntimeError, match=r"returned no rows!"):
+ operator = GCSToBigQueryOperator(
+ task_id=TASK_ID,
+ bucket=TEST_BUCKET,
+ source_objects=TEST_SOURCE_OBJECTS,
+ destination_project_dataset_table=TEST_EXPLICIT_DEST,
+ schema_fields=SCHEMA_FIELDS,
+ max_id_key=MAX_ID_KEY,
+ write_disposition=WRITE_DISPOSITION,
+ external_table=False,
+ )
+ operator.execute(context=MagicMock())
+
+ calls = [
+ call(
+ configuration={
+ "load": dict(
+ autodetect=True,
+ createDisposition="CREATE_IF_NEEDED",
+ destinationTable={"projectId": PROJECT_ID, "datasetId": DATASET, "tableId": TABLE},
+ destinationTableProperties={
+ "description": None,
+ "labels": None,
+ },
+ sourceFormat="CSV",
+ skipLeadingRows=None,
+ sourceUris=[f"gs://{TEST_BUCKET}/{TEST_SOURCE_OBJECTS_AS_STRING}"],
+ writeDisposition=WRITE_DISPOSITION,
+ ignoreUnknownValues=False,
+ allowQuotedNewlines=False,
+ encoding="UTF-8",
+ schema={"fields": SCHEMA_FIELDS},
+ ),
+ },
+ project_id=hook.return_value.project_id,
+ location=None,
+ job_id=pytest.real_job_id,
+ timeout=None,
+ retry=DEFAULT_RETRY,
+ nowait=True,
+ ),
+ call(
+ configuration={
+ "query": {
+ "query": f"SELECT MAX({MAX_ID_KEY}) AS max_value FROM {TEST_EXPLICIT_DEST}",
+ "useLegacySql": False,
+ "schemaUpdateOptions": [],
+ }
+ },
+ project_id=hook.return_value.project_id,
+ ),
+ ]
+
+ hook.return_value.insert_job.assert_has_calls(calls)
+
@mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook")
- def test_labels(self, bq_hook):
+ def test_labels_external_table_should_execute_successfully(self, hook):
+ hook.return_value.insert_job.side_effect = [
+ MagicMock(job_id=pytest.real_job_id, error_result=False),
+ pytest.real_job_id,
+ ]
+ hook.return_value.generate_job_id.return_value = pytest.real_job_id
+ hook.return_value.split_tablename.return_value = (PROJECT_ID, DATASET, TABLE)
operator = GCSToBigQueryOperator(
task_id=TASK_ID,
bucket=TEST_BUCKET,
source_objects=TEST_SOURCE_OBJECTS,
destination_project_dataset_table=TEST_EXPLICIT_DEST,
+ schema_fields=SCHEMA_FIELDS,
+ write_disposition=WRITE_DISPOSITION,
+ external_table=True,
labels=LABELS,
)
- operator.execute(None)
-
- bq_hook.return_value.run_load.assert_called_once_with(
- destination_project_dataset_table=mock.ANY,
- schema_fields=mock.ANY,
- source_uris=mock.ANY,
- source_format=mock.ANY,
- autodetect=mock.ANY,
- create_disposition=mock.ANY,
- skip_leading_rows=mock.ANY,
- write_disposition=mock.ANY,
- field_delimiter=mock.ANY,
- max_bad_records=mock.ANY,
- quote_character=mock.ANY,
- ignore_unknown_values=mock.ANY,
- allow_quoted_newlines=mock.ANY,
- allow_jagged_rows=mock.ANY,
- encoding=mock.ANY,
- schema_update_options=mock.ANY,
- src_fmt_configs=mock.ANY,
- time_partitioning=mock.ANY,
- cluster_fields=mock.ANY,
- encryption_configuration=mock.ANY,
+ operator.execute(context=MagicMock())
+ hook.return_value.create_empty_table.assert_called_once_with(
+ table_resource={
+ "tableReference": {"projectId": PROJECT_ID, "datasetId": DATASET, "tableId": TABLE},
+ "labels": LABELS,
+ "description": None,
+ "externalDataConfiguration": {
+ "source_uris": [f"gs://{TEST_BUCKET}/{TEST_SOURCE_OBJECTS_AS_STRING}"],
+ "source_format": "CSV",
+ "maxBadRecords": 0,
+ "autodetect": True,
+ "compression": "NONE",
+ "csvOptions": {
+ "fieldDelimeter": ",",
+ "skipLeadingRows": None,
+ "quote": None,
+ "allowQuotedNewlines": False,
+ "allowJaggedRows": False,
+ },
+ },
+ "location": None,
+ "encryptionConfiguration": None,
+ "schema": {"fields": SCHEMA_FIELDS},
+ }
+ )
+
+ @mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook")
+ def test_labels_without_external_table_should_execute_successfully(self, hook):
+ hook.return_value.insert_job.side_effect = [
+ MagicMock(job_id=pytest.real_job_id, error_result=False),
+ pytest.real_job_id,
+ ]
+ hook.return_value.generate_job_id.return_value = pytest.real_job_id
+ hook.return_value.split_tablename.return_value = (PROJECT_ID, DATASET, TABLE)
+
+ operator = GCSToBigQueryOperator(
+ task_id=TASK_ID,
+ bucket=TEST_BUCKET,
+ source_objects=TEST_SOURCE_OBJECTS,
+ destination_project_dataset_table=TEST_EXPLICIT_DEST,
+ write_disposition=WRITE_DISPOSITION,
+ schema_fields=SCHEMA_FIELDS,
+ external_table=False,
labels=LABELS,
- description=mock.ANY,
)
+ operator.execute(context=MagicMock())
+ calls = [
+ call(
+ configuration={
+ "load": dict(
+ autodetect=True,
+ createDisposition="CREATE_IF_NEEDED",
+ destinationTable={"projectId": PROJECT_ID, "datasetId": DATASET, "tableId": TABLE},
+ destinationTableProperties={
+ "description": None,
+ "labels": LABELS,
+ },
+ sourceFormat="CSV",
+ skipLeadingRows=None,
+ sourceUris=[f"gs://{TEST_BUCKET}/{TEST_SOURCE_OBJECTS_AS_STRING}"],
+ writeDisposition=WRITE_DISPOSITION,
+ ignoreUnknownValues=False,
+ allowQuotedNewlines=False,
+ encoding="UTF-8",
+ schema={"fields": SCHEMA_FIELDS},
+ ),
+ },
+ project_id=hook.return_value.project_id,
+ location=None,
+ job_id=pytest.real_job_id,
+ timeout=None,
+ retry=DEFAULT_RETRY,
+ nowait=True,
+ ),
+ ]
+
+ hook.return_value.insert_job.assert_has_calls(calls)
+
@mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook")
- def test_description(self, bq_hook):
+ def test_description_external_table_should_execute_successfully(self, hook):
+ hook.return_value.insert_job.side_effect = [
+ MagicMock(job_id=pytest.real_job_id, error_result=False),
+ pytest.real_job_id,
+ ]
+ hook.return_value.generate_job_id.return_value = pytest.real_job_id
+ hook.return_value.split_tablename.return_value = (PROJECT_ID, DATASET, TABLE)
operator = GCSToBigQueryOperator(
task_id=TASK_ID,
bucket=TEST_BUCKET,
source_objects=TEST_SOURCE_OBJECTS,
destination_project_dataset_table=TEST_EXPLICIT_DEST,
+ write_disposition=WRITE_DISPOSITION,
+ schema_fields=SCHEMA_FIELDS,
description=DESCRIPTION,
+ external_table=True,
)
- operator.execute(None)
-
- bq_hook.return_value.run_load.assert_called_once_with(
- destination_project_dataset_table=mock.ANY,
- schema_fields=mock.ANY,
- source_uris=mock.ANY,
- source_format=mock.ANY,
- autodetect=mock.ANY,
- create_disposition=mock.ANY,
- skip_leading_rows=mock.ANY,
- write_disposition=mock.ANY,
- field_delimiter=mock.ANY,
- max_bad_records=mock.ANY,
- quote_character=mock.ANY,
- ignore_unknown_values=mock.ANY,
- allow_quoted_newlines=mock.ANY,
- allow_jagged_rows=mock.ANY,
- encoding=mock.ANY,
- schema_update_options=mock.ANY,
- src_fmt_configs=mock.ANY,
- time_partitioning=mock.ANY,
- cluster_fields=mock.ANY,
- encryption_configuration=mock.ANY,
- labels=mock.ANY,
+ operator.execute(context=MagicMock())
+ hook.return_value.create_empty_table.assert_called_once_with(
+ table_resource={
+ "tableReference": {"projectId": PROJECT_ID, "datasetId": DATASET, "tableId": TABLE},
+ "labels": None,
+ "description": DESCRIPTION,
+ "externalDataConfiguration": {
+ "source_uris": [f"gs://{TEST_BUCKET}/{TEST_SOURCE_OBJECTS_AS_STRING}"],
+ "source_format": "CSV",
+ "maxBadRecords": 0,
+ "autodetect": True,
+ "compression": "NONE",
+ "csvOptions": {
+ "fieldDelimeter": ",",
+ "skipLeadingRows": None,
+ "quote": None,
+ "allowQuotedNewlines": False,
+ "allowJaggedRows": False,
+ },
+ },
+ "location": None,
+ "encryptionConfiguration": None,
+ "schema": {"fields": SCHEMA_FIELDS},
+ }
+ )
+
+ @mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook")
+ def test_description_without_external_table_should_execute_successfully(self, hook):
+ hook.return_value.insert_job.side_effect = [
+ MagicMock(job_id=pytest.real_job_id, error_result=False),
+ pytest.real_job_id,
+ ]
+ hook.return_value.generate_job_id.return_value = pytest.real_job_id
+ hook.return_value.split_tablename.return_value = (PROJECT_ID, DATASET, TABLE)
+
+ operator = GCSToBigQueryOperator(
+ task_id=TASK_ID,
+ bucket=TEST_BUCKET,
+ source_objects=TEST_SOURCE_OBJECTS,
+ destination_project_dataset_table=TEST_EXPLICIT_DEST,
+ schema_fields=SCHEMA_FIELDS,
+ write_disposition=WRITE_DISPOSITION,
+ external_table=False,
description=DESCRIPTION,
)
+ operator.execute(context=MagicMock())
+ calls = [
+ call(
+ configuration={
+ "load": dict(
+ autodetect=True,
+ createDisposition="CREATE_IF_NEEDED",
+ destinationTable={"projectId": PROJECT_ID, "datasetId": DATASET, "tableId": TABLE},
+ destinationTableProperties={
+ "description": DESCRIPTION,
+ "labels": None,
+ },
+ sourceFormat="CSV",
+ skipLeadingRows=None,
+ sourceUris=[f"gs://{TEST_BUCKET}/{TEST_SOURCE_OBJECTS_AS_STRING}"],
+ writeDisposition=WRITE_DISPOSITION,
+ ignoreUnknownValues=False,
+ allowQuotedNewlines=False,
+ encoding="UTF-8",
+ schema={"fields": SCHEMA_FIELDS},
+ ),
+ },
+ project_id=hook.return_value.project_id,
+ location=None,
+ job_id=pytest.real_job_id,
+ timeout=None,
+ retry=DEFAULT_RETRY,
+ nowait=True,
+ ),
+ ]
+
+ hook.return_value.insert_job.assert_has_calls(calls)
+
+ @mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook")
+ def test_source_objs_as_list_external_table_should_execute_successfully(self, hook):
+ hook.return_value.insert_job.side_effect = [
+ MagicMock(job_id=pytest.real_job_id, error_result=False),
+ pytest.real_job_id,
+ ]
+ hook.return_value.generate_job_id.return_value = pytest.real_job_id
+ hook.return_value.split_tablename.return_value = (PROJECT_ID, DATASET, TABLE)
+ operator = GCSToBigQueryOperator(
+ task_id=TASK_ID,
+ bucket=TEST_BUCKET,
+ source_objects=TEST_SOURCE_OBJECTS,
+ schema_fields=SCHEMA_FIELDS,
+ write_disposition=WRITE_DISPOSITION,
+ destination_project_dataset_table=TEST_EXPLICIT_DEST,
+ external_table=True,
+ )
+
+ operator.execute(context=MagicMock())
+
+ hook.return_value.create_empty_table.assert_called_once_with(
+ table_resource={
+ "tableReference": {"projectId": PROJECT_ID, "datasetId": DATASET, "tableId": TABLE},
+ "labels": None,
+ "description": None,
+ "externalDataConfiguration": {
+ "source_uris": [
+ f"gs://{TEST_BUCKET}/{source_object}" for source_object in TEST_SOURCE_OBJECTS
+ ],
+ "source_format": "CSV",
+ "maxBadRecords": 0,
+ "autodetect": True,
+ "compression": "NONE",
+ "csvOptions": {
+ "fieldDelimeter": ",",
+ "skipLeadingRows": None,
+ "quote": None,
+ "allowQuotedNewlines": False,
+ "allowJaggedRows": False,
+ },
+ },
+ "location": None,
+ "encryptionConfiguration": None,
+ "schema": {"fields": SCHEMA_FIELDS},
+ }
+ )
+
@mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook")
- def test_labels_external_table(self, bq_hook):
+ def test_source_objs_as_list_without_external_table_should_execute_successfully(self, hook):
+ hook.return_value.insert_job.side_effect = [
+ MagicMock(job_id=pytest.real_job_id, error_result=False),
+ pytest.real_job_id,
+ ]
+ hook.return_value.generate_job_id.return_value = pytest.real_job_id
+ hook.return_value.split_tablename.return_value = (PROJECT_ID, DATASET, TABLE)
+ operator = GCSToBigQueryOperator(
+ task_id=TASK_ID,
+ bucket=TEST_BUCKET,
+ source_objects=TEST_SOURCE_OBJECTS,
+ schema_fields=SCHEMA_FIELDS,
+ write_disposition=WRITE_DISPOSITION,
+ destination_project_dataset_table=TEST_EXPLICIT_DEST,
+ external_table=False,
+ )
+
+ operator.execute(context=MagicMock())
+
+ calls = [
+ call(
+ configuration={
+ "load": dict(
+ autodetect=True,
+ createDisposition="CREATE_IF_NEEDED",
+ destinationTable={"projectId": PROJECT_ID, "datasetId": DATASET, "tableId": TABLE},
+ destinationTableProperties={
+ "description": None,
+ "labels": None,
+ },
+ sourceFormat="CSV",
+ skipLeadingRows=None,
+ sourceUris=[
+ f"gs://{TEST_BUCKET}/{source_object}" for source_object in TEST_SOURCE_OBJECTS
+ ],
+ writeDisposition=WRITE_DISPOSITION,
+ ignoreUnknownValues=False,
+ allowQuotedNewlines=False,
+ encoding="UTF-8",
+ schema={"fields": SCHEMA_FIELDS},
+ ),
+ },
+ project_id=hook.return_value.project_id,
+ location=None,
+ job_id=pytest.real_job_id,
+ timeout=None,
+ retry=DEFAULT_RETRY,
+ nowait=True,
+ ),
+ ]
+ hook.return_value.insert_job.assert_has_calls(calls)
+
+ @mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook")
+ def test_source_objs_as_string_external_table_should_execute_successfully(self, hook):
+ hook.return_value.insert_job.side_effect = [
+ MagicMock(job_id=pytest.real_job_id, error_result=False),
+ pytest.real_job_id,
+ ]
+ hook.return_value.generate_job_id.return_value = pytest.real_job_id
+ hook.return_value.split_tablename.return_value = (PROJECT_ID, DATASET, TABLE)
operator = GCSToBigQueryOperator(
task_id=TASK_ID,
bucket=TEST_BUCKET,
source_objects=TEST_SOURCE_OBJECTS,
+ schema_fields=SCHEMA_FIELDS,
+ write_disposition=WRITE_DISPOSITION,
destination_project_dataset_table=TEST_EXPLICIT_DEST,
- labels=LABELS,
external_table=True,
)
- operator.execute(None)
- # fmt: off
- bq_hook.return_value.create_external_table.assert_called_once_with(
- external_project_dataset_table=mock.ANY,
- schema_fields=mock.ANY,
- source_uris=mock.ANY,
- source_format=mock.ANY,
- autodetect=mock.ANY,
- compression=mock.ANY,
- skip_leading_rows=mock.ANY,
- field_delimiter=mock.ANY,
- max_bad_records=mock.ANY,
- quote_character=mock.ANY,
- ignore_unknown_values=mock.ANY,
- allow_quoted_newlines=mock.ANY,
- allow_jagged_rows=mock.ANY,
- encoding=mock.ANY,
- src_fmt_configs=mock.ANY,
- encryption_configuration=mock.ANY,
- labels=LABELS,
- description=mock.ANY,
+ operator.execute(context=MagicMock())
+
+ hook.return_value.create_empty_table.assert_called_once_with(
+ table_resource={
+ "tableReference": {"projectId": PROJECT_ID, "datasetId": DATASET, "tableId": TABLE},
+ "labels": None,
+ "description": None,
+ "externalDataConfiguration": {
+ "source_uris": [f"gs://{TEST_BUCKET}/{TEST_SOURCE_OBJECTS_AS_STRING}"],
+ "source_format": "CSV",
+ "maxBadRecords": 0,
+ "autodetect": True,
+ "compression": "NONE",
+ "csvOptions": {
+ "fieldDelimeter": ",",
+ "skipLeadingRows": None,
+ "quote": None,
+ "allowQuotedNewlines": False,
+ "allowJaggedRows": False,
+ },
+ },
+ "location": None,
+ "encryptionConfiguration": None,
+ "schema": {"fields": SCHEMA_FIELDS},
+ }
)
- # fmt: on
@mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook")
- def test_description_external_table(self, bq_hook):
+ def test_source_objs_as_string_without_external_table_should_execute_successfully(self, hook):
+ hook.return_value.insert_job.side_effect = [
+ MagicMock(job_id=pytest.real_job_id, error_result=False),
+ pytest.real_job_id,
+ ]
+ hook.return_value.generate_job_id.return_value = pytest.real_job_id
+ hook.return_value.split_tablename.return_value = (PROJECT_ID, DATASET, TABLE)
+ operator = GCSToBigQueryOperator(
+ task_id=TASK_ID,
+ bucket=TEST_BUCKET,
+ source_objects=TEST_SOURCE_OBJECTS,
+ schema_fields=SCHEMA_FIELDS,
+ destination_project_dataset_table=TEST_EXPLICIT_DEST,
+ write_disposition=WRITE_DISPOSITION,
+ external_table=False,
+ )
+
+ operator.execute(context=MagicMock())
+
+ calls = [
+ call(
+ configuration={
+ "load": dict(
+ autodetect=True,
+ createDisposition="CREATE_IF_NEEDED",
+ destinationTable={"projectId": PROJECT_ID, "datasetId": DATASET, "tableId": TABLE},
+ destinationTableProperties={
+ "description": None,
+ "labels": None,
+ },
+ sourceFormat="CSV",
+ skipLeadingRows=None,
+ sourceUris=[f"gs://{TEST_BUCKET}/{TEST_SOURCE_OBJECTS_AS_STRING}"],
+ writeDisposition=WRITE_DISPOSITION,
+ ignoreUnknownValues=False,
+ allowQuotedNewlines=False,
+ encoding="UTF-8",
+ schema={"fields": SCHEMA_FIELDS},
+ ),
+ },
+ project_id=hook.return_value.project_id,
+ location=None,
+ job_id=pytest.real_job_id,
+ timeout=None,
+ retry=DEFAULT_RETRY,
+ nowait=True,
+ ),
+ ]
+
+ hook.return_value.insert_job.assert_has_calls(calls)
+
+ @mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook")
+ def test_execute_should_throw_ex_when_no_bucket_specified(self, hook):
+ hook.return_value.insert_job.side_effect = [
+ MagicMock(job_id=pytest.real_job_id, error_result=False),
+ pytest.real_job_id,
+ ]
+ hook.return_value.generate_job_id.return_value = pytest.real_job_id
+ hook.return_value.split_tablename.return_value = (PROJECT_ID, DATASET, TABLE)
+ with pytest.raises(AirflowException, match=r"missing keyword argument 'bucket'"):
+ operator = GCSToBigQueryOperator(
+ task_id=TASK_ID,
+ source_objects=TEST_SOURCE_OBJECTS,
+ destination_project_dataset_table=TEST_EXPLICIT_DEST,
+ schema_fields=SCHEMA_FIELDS,
+ max_id_key=MAX_ID_KEY,
+ write_disposition=WRITE_DISPOSITION,
+ external_table=False,
+ )
+ operator.execute(context=MagicMock())
+
+ @mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook")
+ def test_execute_should_throw_ex_when_no_source_objects_specified(self, hook):
+
+ hook.return_value.insert_job.side_effect = [
+ MagicMock(job_id=pytest.real_job_id, error_result=False),
+ pytest.real_job_id,
+ ]
+ hook.return_value.generate_job_id.return_value = pytest.real_job_id
+ hook.return_value.split_tablename.return_value = (PROJECT_ID, DATASET, TABLE)
+ with pytest.raises(AirflowException, match=r"missing keyword argument 'source_objects'"):
+ operator = GCSToBigQueryOperator(
+ task_id=TASK_ID,
+ destination_project_dataset_table=TEST_EXPLICIT_DEST,
+ schema_fields=SCHEMA_FIELDS,
+ bucket=TEST_BUCKET,
+ max_id_key=MAX_ID_KEY,
+ write_disposition=WRITE_DISPOSITION,
+ external_table=False,
+ )
+ operator.execute(context=MagicMock())
+
+ @mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook")
+ def test_execute_should_throw_ex_when_no_destination_project_dataset_table_specified(self, hook):
+ hook.return_value.insert_job.side_effect = [
+ MagicMock(job_id=pytest.real_job_id, error_result=False),
+ pytest.real_job_id,
+ ]
+ hook.return_value.generate_job_id.return_value = pytest.real_job_id
+ hook.return_value.split_tablename.return_value = (PROJECT_ID, DATASET, TABLE)
+ with pytest.raises(
+ AirflowException, match=r"missing keyword argument 'destination_project_dataset_table'"
+ ):
+ operator = GCSToBigQueryOperator(
+ task_id=TASK_ID,
+ schema_fields=SCHEMA_FIELDS,
+ bucket=TEST_BUCKET,
+ source_objects=TEST_SOURCE_OBJECTS,
+ max_id_key=MAX_ID_KEY,
+ write_disposition=WRITE_DISPOSITION,
+ external_table=False,
+ )
+ operator.execute(context=MagicMock())
+
+ @mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.GCSHook")
+ @mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook")
+ def test_schema_fields_scanner_external_table_should_execute_successfully(self, bq_hook, gcs_hook):
+ """
+ Check detection of schema fields when the schema_fields parameter is not
+ specified and all values in the source object are non-numeric. In this case the
+ operator reads the field names from the source object and updates the
+ configuration with the constructed schema_fields.
+ """
+ bq_hook.return_value.insert_job.side_effect = [
+ MagicMock(job_id=pytest.real_job_id, error_result=False),
+ pytest.real_job_id,
+ ]
+ bq_hook.return_value.generate_job_id.return_value = pytest.real_job_id
+ bq_hook.return_value.split_tablename.return_value = (PROJECT_ID, DATASET, TABLE)
+ bq_hook.return_value.get_job.return_value.result.return_value = ("1",)
+
+ gcs_hook.return_value.download.return_value = b"id,name\r\none,Anna"
operator = GCSToBigQueryOperator(
task_id=TASK_ID,
bucket=TEST_BUCKET,
source_objects=TEST_SOURCE_OBJECTS,
destination_project_dataset_table=TEST_EXPLICIT_DEST,
- description=DESCRIPTION,
+ max_id_key=MAX_ID_KEY,
+ write_disposition=WRITE_DISPOSITION,
external_table=True,
+ autodetect=True,
)
- operator.execute(None)
- # fmt: off
- bq_hook.return_value.create_external_table.assert_called_once_with(
- external_project_dataset_table=mock.ANY,
- schema_fields=mock.ANY,
- source_uris=mock.ANY,
- source_format=mock.ANY,
- autodetect=mock.ANY,
- compression=mock.ANY,
- skip_leading_rows=mock.ANY,
- field_delimiter=mock.ANY,
- max_bad_records=mock.ANY,
- quote_character=mock.ANY,
- ignore_unknown_values=mock.ANY,
- allow_quoted_newlines=mock.ANY,
- allow_jagged_rows=mock.ANY,
- encoding=mock.ANY,
- src_fmt_configs=mock.ANY,
- encryption_configuration=mock.ANY,
- labels=mock.ANY,
- description=DESCRIPTION,
+ result = operator.execute(context=MagicMock())
+
+ assert result == "1"
+ bq_hook.return_value.create_empty_table.assert_called_once_with(
+ table_resource={
+ "tableReference": {"projectId": PROJECT_ID, "datasetId": DATASET, "tableId": TABLE},
+ "labels": None,
+ "description": None,
+ "externalDataConfiguration": {
+ "source_uris": [f"gs://{TEST_BUCKET}/{TEST_SOURCE_OBJECTS_AS_STRING}"],
+ "source_format": "CSV",
+ "maxBadRecords": 0,
+ "autodetect": True,
+ "compression": "NONE",
+ "csvOptions": {
+ "fieldDelimeter": ",",
+ "skipLeadingRows": 1,
+ "quote": None,
+ "allowQuotedNewlines": False,
+ "allowJaggedRows": False,
+ },
+ },
+ "location": None,
+ "encryptionConfiguration": None,
+ "schema": {"fields": SCHEMA_FIELDS},
+ }
+ )
+ bq_hook.return_value.insert_job.assert_called_once_with(
+ configuration={
+ "query": {
+ "query": f"SELECT MAX({MAX_ID_KEY}) AS max_value FROM {TEST_EXPLICIT_DEST}",
+ "useLegacySql": False,
+ "schemaUpdateOptions": [],
+ }
+ },
+ project_id=bq_hook.return_value.project_id,
)
- # fmt: on
+ @mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.GCSHook")
@mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook")
- def test_source_objects_as_list(self, bq_hook):
+ def test_schema_fields_scanner_without_external_table_should_execute_successfully(
+ self, bq_hook, gcs_hook
+ ):
+ """
+ Check detection of schema fields when the schema_fields parameter is not
+ specified and all values in the source object are non-numeric. In this case the
+ operator reads the field names from the source object and updates the
+ configuration with the constructed schema_fields.
+ """
+ bq_hook.return_value.insert_job.side_effect = [
+ MagicMock(job_id=pytest.real_job_id, error_result=False),
+ pytest.real_job_id,
+ ]
+ bq_hook.return_value.generate_job_id.return_value = pytest.real_job_id
+ bq_hook.return_value.split_tablename.return_value = (PROJECT_ID, DATASET, TABLE)
+ bq_hook.return_value.get_job.return_value.result.return_value = ("1",)
+
+ gcs_hook.return_value.download.return_value = b"id,name\r\none,Anna"
+
operator = GCSToBigQueryOperator(
task_id=TASK_ID,
bucket=TEST_BUCKET,
source_objects=TEST_SOURCE_OBJECTS,
destination_project_dataset_table=TEST_EXPLICIT_DEST,
+ write_disposition=WRITE_DISPOSITION,
+ max_id_key=MAX_ID_KEY,
+ external_table=False,
+ autodetect=True,
)
- operator.execute(None)
-
- bq_hook.return_value.run_load.assert_called_once_with(
- destination_project_dataset_table=mock.ANY,
- schema_fields=mock.ANY,
- source_uris=[f"gs://{TEST_BUCKET}/{source_object}" for source_object in TEST_SOURCE_OBJECTS],
- source_format=mock.ANY,
- autodetect=mock.ANY,
- create_disposition=mock.ANY,
- skip_leading_rows=mock.ANY,
- write_disposition=mock.ANY,
- field_delimiter=mock.ANY,
- max_bad_records=mock.ANY,
- quote_character=mock.ANY,
- ignore_unknown_values=mock.ANY,
- allow_quoted_newlines=mock.ANY,
- allow_jagged_rows=mock.ANY,
- encoding=mock.ANY,
- schema_update_options=mock.ANY,
- src_fmt_configs=mock.ANY,
- time_partitioning=mock.ANY,
- cluster_fields=mock.ANY,
- encryption_configuration=mock.ANY,
- labels=mock.ANY,
- description=mock.ANY,
+ result = operator.execute(context=MagicMock())
+
+ assert result == "1"
+ calls = [
+ call(
+ configuration={
+ "load": dict(
+ autodetect=True,
+ createDisposition="CREATE_IF_NEEDED",
+ destinationTable={"projectId": PROJECT_ID, "datasetId": DATASET, "tableId": TABLE},
+ destinationTableProperties={
+ "description": None,
+ "labels": None,
+ },
+ sourceFormat="CSV",
+ skipLeadingRows=1,
+ sourceUris=[f"gs://{TEST_BUCKET}/{TEST_SOURCE_OBJECTS_AS_STRING}"],
+ writeDisposition=WRITE_DISPOSITION,
+ ignoreUnknownValues=False,
+ allowQuotedNewlines=False,
+ encoding="UTF-8",
+ schema={"fields": SCHEMA_FIELDS},
+ ),
+ },
+ project_id=bq_hook.return_value.project_id,
+ location=None,
+ job_id=pytest.real_job_id,
+ timeout=None,
+ retry=DEFAULT_RETRY,
+ nowait=True,
+ ),
+ call(
+ configuration={
+ "query": {
+ "query": f"SELECT MAX({MAX_ID_KEY}) AS max_value FROM {TEST_EXPLICIT_DEST}",
+ "useLegacySql": False,
+ "schemaUpdateOptions": [],
+ }
+ },
+ project_id=bq_hook.return_value.project_id,
+ ),
+ ]
+
+ bq_hook.return_value.insert_job.assert_has_calls(calls)
+
+ @mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook")
+ def test_schema_fields_scanner_external_table_should_throw_ex_when_autodetect_not_specified(
+ self,
+ hook,
+ ):
+ hook.return_value.insert_job.side_effect = [
+ MagicMock(job_id=pytest.real_job_id, error_result=False),
+ pytest.real_job_id,
+ ]
+ hook.return_value.generate_job_id.return_value = pytest.real_job_id
+ hook.return_value.split_tablename.return_value = (PROJECT_ID, DATASET, TABLE)
+ hook.return_value.get_job.return_value.result.return_value = ("1",)
+
+ with pytest.raises(RuntimeError, match=r"Table schema was not found."):
+ operator = GCSToBigQueryOperator(
+ task_id=TASK_ID,
+ bucket=TEST_BUCKET,
+ source_objects=TEST_SOURCE_OBJECTS,
+ destination_project_dataset_table=TEST_EXPLICIT_DEST,
+ max_id_key=MAX_ID_KEY,
+ write_disposition=WRITE_DISPOSITION,
+ external_table=True,
+ autodetect=False,
+ )
+ operator.execute(context=MagicMock())
+
+ @mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook")
+ def test_schema_fields_scanner_without_external_table_should_throw_ex_when_autodetect_not_specified(
+ self,
+ hook,
+ ):
+ hook.return_value.insert_job.side_effect = [
+ MagicMock(job_id=pytest.real_job_id, error_result=False),
+ pytest.real_job_id,
+ ]
+ hook.return_value.generate_job_id.return_value = pytest.real_job_id
+ hook.return_value.split_tablename.return_value = (PROJECT_ID, DATASET, TABLE)
+ hook.return_value.get_job.return_value.result.return_value = ("1",)
+
+ with pytest.raises(RuntimeError, match=r"Table schema was not found."):
+ operator = GCSToBigQueryOperator(
+ task_id=TASK_ID,
+ bucket=TEST_BUCKET,
+ source_objects=TEST_SOURCE_OBJECTS,
+ destination_project_dataset_table=TEST_EXPLICIT_DEST,
+ max_id_key=MAX_ID_KEY,
+ write_disposition=WRITE_DISPOSITION,
+ external_table=False,
+ autodetect=False,
+ )
+ operator.execute(context=MagicMock())
+
+ @mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.GCSHook")
+ @mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook")
+ def test_schema_fields_integer_scanner_external_table_should_execute_successfully(
+ self, bq_hook, gcs_hook
+ ):
+ """
+ Check that schema detection is left to BigQuery when the schema_fields parameter
+ is not specified and at least one value in the source object is non-string (numeric).
+ """
+ bq_hook.return_value.insert_job.side_effect = [
+ MagicMock(job_id=pytest.real_job_id, error_result=False),
+ pytest.real_job_id,
+ ]
+ bq_hook.return_value.generate_job_id.return_value = pytest.real_job_id
+ bq_hook.return_value.split_tablename.return_value = (PROJECT_ID, DATASET, TABLE)
+ bq_hook.return_value.get_job.return_value.result.return_value = ("1",)
+ gcs_hook.return_value.download.return_value = b"id,name\r\n1,Anna"
+
+ operator = GCSToBigQueryOperator(
+ task_id=TASK_ID,
+ bucket=TEST_BUCKET,
+ source_objects=TEST_SOURCE_OBJECTS,
+ destination_project_dataset_table=TEST_EXPLICIT_DEST,
+ write_disposition=WRITE_DISPOSITION,
+ max_id_key=MAX_ID_KEY,
+ external_table=True,
+ autodetect=True,
+ )
+
+ result = operator.execute(context=MagicMock())
+
+ assert result == "1"
+ bq_hook.return_value.create_empty_table.assert_called_once_with(
+ table_resource={
+ "tableReference": {"projectId": PROJECT_ID, "datasetId": DATASET, "tableId": TABLE},
+ "labels": None,
+ "description": None,
+ "externalDataConfiguration": {
+ "source_uris": [f"gs://{TEST_BUCKET}/{TEST_SOURCE_OBJECTS_AS_STRING}"],
+ "source_format": "CSV",
+ "maxBadRecords": 0,
+ "autodetect": True,
+ "compression": "NONE",
+ "csvOptions": {
+ "fieldDelimeter": ",",
+ "skipLeadingRows": None,
+ "quote": None,
+ "allowQuotedNewlines": False,
+ "allowJaggedRows": False,
+ },
+ },
+ "location": None,
+ "encryptionConfiguration": None,
+ }
+ )
+ bq_hook.return_value.insert_job.assert_called_once_with(
+ configuration={
+ "query": {
+ "query": f"SELECT MAX({MAX_ID_KEY}) AS max_value FROM {TEST_EXPLICIT_DEST}",
+ "useLegacySql": False,
+ "schemaUpdateOptions": [],
+ }
+ },
+ project_id=bq_hook.return_value.project_id,
)
+ @mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.GCSHook")
@mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook")
- def test_source_objects_as_string(self, bq_hook):
+ def test_schema_fields_integer_scanner_without_external_table_should_execute_successfully(
+ self, bq_hook, gcs_hook
+ ):
+ """
+ Check that schema detection is left to BigQuery when the schema_fields parameter
+ is not specified and at least one value in the source object is non-string (numeric).
+ """
+ bq_hook.return_value.insert_job.side_effect = [
+ MagicMock(job_id=pytest.real_job_id, error_result=False),
+ pytest.real_job_id,
+ ]
+ bq_hook.return_value.generate_job_id.return_value = pytest.real_job_id
+ bq_hook.return_value.split_tablename.return_value = (PROJECT_ID, DATASET, TABLE)
+ bq_hook.return_value.get_job.return_value.result.return_value = ("1",)
+ gcs_hook.return_value.download.return_value = b"id,name\r\n1,Anna"
+
operator = GCSToBigQueryOperator(
task_id=TASK_ID,
bucket=TEST_BUCKET,
- source_objects=TEST_SOURCE_OBJECTS_AS_STRING,
+ source_objects=TEST_SOURCE_OBJECTS,
destination_project_dataset_table=TEST_EXPLICIT_DEST,
+ write_disposition=WRITE_DISPOSITION,
+ max_id_key=MAX_ID_KEY,
+ external_table=False,
+ autodetect=True,
)
- operator.execute(None)
-
- bq_hook.return_value.run_load.assert_called_once_with(
- destination_project_dataset_table=mock.ANY,
- schema_fields=mock.ANY,
- source_uris=[f"gs://{TEST_BUCKET}/{TEST_SOURCE_OBJECTS_AS_STRING}"],
- source_format=mock.ANY,
- autodetect=mock.ANY,
- create_disposition=mock.ANY,
- skip_leading_rows=mock.ANY,
- write_disposition=mock.ANY,
- field_delimiter=mock.ANY,
- max_bad_records=mock.ANY,
- quote_character=mock.ANY,
- ignore_unknown_values=mock.ANY,
- allow_quoted_newlines=mock.ANY,
- allow_jagged_rows=mock.ANY,
- encoding=mock.ANY,
- schema_update_options=mock.ANY,
- src_fmt_configs=mock.ANY,
- time_partitioning=mock.ANY,
- cluster_fields=mock.ANY,
- encryption_configuration=mock.ANY,
- labels=mock.ANY,
- description=mock.ANY,
+ result = operator.execute(context=MagicMock())
+
+ assert result == "1"
+ calls = [
+ call(
+ configuration={
+ "load": dict(
+ autodetect=True,
+ createDisposition="CREATE_IF_NEEDED",
+ destinationTable={"projectId": PROJECT_ID, "datasetId": DATASET, "tableId": TABLE},
+ destinationTableProperties={
+ "description": None,
+ "labels": None,
+ },
+ sourceFormat="CSV",
+ skipLeadingRows=None,
+ sourceUris=[f"gs://{TEST_BUCKET}/{TEST_SOURCE_OBJECTS_AS_STRING}"],
+ writeDisposition=WRITE_DISPOSITION,
+ ignoreUnknownValues=False,
+ allowQuotedNewlines=False,
+ encoding="UTF-8",
+ ),
+ },
+ project_id=bq_hook.return_value.project_id,
+ location=None,
+ job_id=pytest.real_job_id,
+ timeout=None,
+ retry=DEFAULT_RETRY,
+ nowait=True,
+ ),
+ call(
+ configuration={
+ "query": {
+ "query": f"SELECT MAX({MAX_ID_KEY}) AS max_value FROM {TEST_EXPLICIT_DEST}",
+ "useLegacySql": False,
+ "schemaUpdateOptions": [],
+ }
+ },
+ project_id=bq_hook.return_value.project_id,
+ ),
+ ]
+
+ bq_hook.return_value.insert_job.assert_has_calls(calls)
+
+ @mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook")
+ def test_schema_fields_without_external_table_should_execute_successfully(self, hook):
+ hook.return_value.insert_job.side_effect = [
+ MagicMock(job_id=pytest.real_job_id, error_result=False),
+ pytest.real_job_id,
+ ]
+ hook.return_value.generate_job_id.return_value = pytest.real_job_id
+ hook.return_value.split_tablename.return_value = (PROJECT_ID, DATASET, TABLE)
+ hook.return_value.get_job.return_value.result.return_value = ("1",)
+
+ operator = GCSToBigQueryOperator(
+ task_id=TASK_ID,
+ bucket=TEST_BUCKET,
+ source_objects=TEST_SOURCE_OBJECTS,
+ destination_project_dataset_table=TEST_EXPLICIT_DEST,
+ write_disposition=WRITE_DISPOSITION,
+ schema_fields=SCHEMA_FIELDS_INT,
+ external_table=False,
+ autodetect=True,
+ )
+
+ operator.execute(context=MagicMock())
+ calls = [
+ call(
+ configuration={
+ "load": dict(
+ autodetect=True,
+ createDisposition="CREATE_IF_NEEDED",
+ destinationTable={"projectId": PROJECT_ID, "datasetId": DATASET, "tableId": TABLE},
+ destinationTableProperties={
+ "description": None,
+ "labels": None,
+ },
+ sourceFormat="CSV",
+ skipLeadingRows=None,
+ sourceUris=[f"gs://{TEST_BUCKET}/{TEST_SOURCE_OBJECTS_AS_STRING}"],
+ writeDisposition=WRITE_DISPOSITION,
+ ignoreUnknownValues=False,
+ allowQuotedNewlines=False,
+ encoding="UTF-8",
+ schema={"fields": SCHEMA_FIELDS_INT},
+ ),
+ },
+ project_id=hook.return_value.project_id,
+ location=None,
+ job_id=pytest.real_job_id,
+ timeout=None,
+ retry=DEFAULT_RETRY,
+ nowait=True,
+ ),
+ ]
+
+ hook.return_value.insert_job.assert_has_calls(calls)
+
+ @mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook")
+ def test_schema_fields_external_table_should_execute_successfully(self, hook):
+ hook.return_value.insert_job.side_effect = [
+ MagicMock(job_id=pytest.real_job_id, error_result=False),
+ pytest.real_job_id,
+ ]
+ hook.return_value.generate_job_id.return_value = pytest.real_job_id
+ hook.return_value.split_tablename.return_value = (PROJECT_ID, DATASET, TABLE)
+ hook.return_value.get_job.return_value.result.return_value = ("1",)
+
+ operator = GCSToBigQueryOperator(
+ task_id=TASK_ID,
+ bucket=TEST_BUCKET,
+ source_objects=TEST_SOURCE_OBJECTS,
+ destination_project_dataset_table=TEST_EXPLICIT_DEST,
+ write_disposition=WRITE_DISPOSITION,
+ schema_fields=SCHEMA_FIELDS_INT,
+ external_table=True,
+ autodetect=True,
)
+
+ operator.execute(context=MagicMock())
+ hook.return_value.create_empty_table.assert_called_once_with(
+ table_resource={
+ "tableReference": {"projectId": PROJECT_ID, "datasetId": DATASET, "tableId": TABLE},
+ "labels": None,
+ "description": None,
+ "externalDataConfiguration": {
+ "source_uris": [f"gs://{TEST_BUCKET}/{TEST_SOURCE_OBJECTS_AS_STRING}"],
+ "source_format": "CSV",
+ "maxBadRecords": 0,
+ "autodetect": True,
+ "compression": "NONE",
+ "csvOptions": {
+ "fieldDelimeter": ",",
+ "skipLeadingRows": None,
+ "quote": None,
+ "allowQuotedNewlines": False,
+ "allowJaggedRows": False,
+ },
+ },
+ "location": None,
+ "encryptionConfiguration": None,
+ "schema": {"fields": SCHEMA_FIELDS_INT},
+ }
+ )
+
+
+@mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook")
+def test_execute_without_external_table_async_should_execute_successfully(hook):
+ """
+ Asserts that a task is deferred and a BigQueryInsertJobTrigger will be fired
+ when Operator is executed in deferrable.
+ """
+ hook.return_value.insert_job.return_value = MagicMock(job_id=pytest.real_job_id, error_result=False)
+ hook.return_value.generate_job_id.return_value = pytest.real_job_id
+ hook.return_value.split_tablename.return_value = (PROJECT_ID, DATASET, TABLE)
+ hook.return_value.get_job.return_value.result.return_value = ("1",)
+
+ operator = GCSToBigQueryOperator(
+ task_id=TASK_ID,
+ bucket=TEST_BUCKET,
+ source_objects=TEST_SOURCE_OBJECTS,
+ destination_project_dataset_table=TEST_EXPLICIT_DEST,
+ write_disposition=WRITE_DISPOSITION,
+ schema_fields=SCHEMA_FIELDS,
+ external_table=False,
+ autodetect=True,
+ deferrable=True,
+ )
+
+ with pytest.raises(TaskDeferred) as exc:
+ operator.execute(create_context(operator))
+
+ assert isinstance(
+ exc.value.trigger, BigQueryInsertJobTrigger
+ ), "Trigger is not a BigQueryInsertJobTrigger"
+
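
The deferrable tests in this block exercise the standard Airflow deferral hand-off: execute() submits the load job without waiting and raises TaskDeferred carrying a BigQueryInsertJobTrigger, and execute_complete() runs again once the trigger fires. Below is a minimal sketch of that pattern, not the provider's code; the class name, connection id, project id and job id are placeholders.

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.google.cloud.triggers.bigquery import BigQueryInsertJobTrigger


class DeferrableLoadSketch(BaseOperator):
    """Illustrative only: the execute/execute_complete split these tests assert."""

    def execute(self, context):
        # In the real operator the job is submitted via BigQueryHook.insert_job(nowait=True)
        # and its id is handed to the trigger before deferring.
        job_id = "example-load-job-id"  # placeholder
        self.defer(
            trigger=BigQueryInsertJobTrigger(
                conn_id="google_cloud_default",
                job_id=job_id,
                project_id="example-project",  # placeholder
            ),
            method_name="execute_complete",
        )

    def execute_complete(self, context, event):
        # The trigger reports back with a status/message event.
        if event["status"] == "error":
            raise AirflowException(event["message"])
        self.log.info("%s completed with response %s ", self.task_id, event["message"])

In the unit tests this pattern shows up as pytest.raises(TaskDeferred) around execute() and a direct call to execute_complete() with a synthetic success or error event.
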
+
+def test_execute_without_external_table_async_should_throw_ex_when_event_status_error():
+ """
+ Tests that an AirflowException is raised in case of error event.
+ """
+
+ with pytest.raises(AirflowException):
+ operator = GCSToBigQueryOperator(
+ task_id=TASK_ID,
+ bucket=TEST_BUCKET,
+ source_objects=TEST_SOURCE_OBJECTS,
+ destination_project_dataset_table=TEST_EXPLICIT_DEST,
+ write_disposition=WRITE_DISPOSITION,
+ schema_fields=SCHEMA_FIELDS,
+ external_table=False,
+ autodetect=True,
+ deferrable=True,
+ )
+ operator.execute_complete(context=None, event={"status": "error", "message": "test failure message"})
+
+
+def test_execute_logging_without_external_table_async_should_execute_successfully():
+ """
+ Asserts that logging occurs as expected.
+ """
+
+ operator = GCSToBigQueryOperator(
+ task_id=TASK_ID,
+ bucket=TEST_BUCKET,
+ source_objects=TEST_SOURCE_OBJECTS,
+ destination_project_dataset_table=TEST_EXPLICIT_DEST,
+ write_disposition=WRITE_DISPOSITION,
+ schema_fields=SCHEMA_FIELDS,
+ external_table=False,
+ autodetect=True,
+ deferrable=True,
+ )
+ with mock.patch.object(operator.log, "info") as mock_log_info:
+ operator.execute_complete(
+ context=create_context(operator),
+ event={"status": "success", "message": "Job completed", "job_id": job_id},
+ )
+ mock_log_info.assert_called_with(
+ "%s completed with response %s ", "test-gcs-to-bq-operator", "Job completed"
+ )
+
+
+@mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook")
+def test_execute_without_external_table_generate_job_id_async_should_execute_successfully(hook):
+ hook.return_value.insert_job.side_effect = Conflict("any")
+ hook.return_value.split_tablename.return_value = (PROJECT_ID, DATASET, TABLE)
+ job = MagicMock(
+ job_id=pytest.real_job_id,
+ error_result=False,
+ state="PENDING",
+ done=lambda: False,
+ )
+ hook.return_value.get_job.return_value = job
+
+ operator = GCSToBigQueryOperator(
+ task_id=TASK_ID,
+ bucket=TEST_BUCKET,
+ source_objects=TEST_SOURCE_OBJECTS,
+ destination_project_dataset_table=TEST_EXPLICIT_DEST,
+ write_disposition=WRITE_DISPOSITION,
+ schema_fields=SCHEMA_FIELDS,
+ reattach_states={"PENDING"},
+ external_table=False,
+ autodetect=True,
+ deferrable=True,
+ )
+
+ with pytest.raises(TaskDeferred):
+ operator.execute(create_context(operator))
+
+ hook.return_value.generate_job_id.assert_called_once_with(
+ job_id=None,
+ dag_id="adhoc_airflow",
+ task_id=TASK_ID,
+ logical_date=datetime(2022, 1, 1, 0, 0),
+ configuration={},
+ force_rerun=True,
+ )
+
+
+@mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook")
+def test_execute_without_external_table_reattach_async_should_execute_successfully(hook):
+ hook.return_value.generate_job_id.return_value = pytest.real_job_id
+
+ hook.return_value.insert_job.side_effect = Conflict("any")
+ hook.return_value.split_tablename.return_value = (PROJECT_ID, DATASET, TABLE)
+ job = MagicMock(
+ job_id=pytest.real_job_id,
+ error_result=False,
+ state="PENDING",
+ done=lambda: False,
+ )
+ hook.return_value.get_job.return_value = job
+
+ operator = GCSToBigQueryOperator(
+ task_id=TASK_ID,
+ bucket=TEST_BUCKET,
+ source_objects=TEST_SOURCE_OBJECTS,
+ destination_project_dataset_table=TEST_EXPLICIT_DEST,
+ write_disposition=WRITE_DISPOSITION,
+ schema_fields=SCHEMA_FIELDS,
+ location=TEST_DATASET_LOCATION,
+ reattach_states={"PENDING"},
+ external_table=False,
+ autodetect=True,
+ deferrable=True,
+ )
+
+ with pytest.raises(TaskDeferred):
+ operator.execute(create_context(operator))
+
+ hook.return_value.get_job.assert_called_once_with(
+ location=TEST_DATASET_LOCATION,
+ job_id=pytest.real_job_id,
+ project_id=hook.return_value.project_id,
+ )
+
+ job._begin.assert_called_once_with()
+
+
+@mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook")
+def test_execute_without_external_table_force_rerun_async_should_execute_successfully(hook):
+ hook.return_value.generate_job_id.return_value = f"{job_id}_{hash_}"
+ hook.return_value.split_tablename.return_value = (PROJECT_ID, DATASET, TABLE)
+
+ hook.return_value.insert_job.side_effect = Conflict("any")
+ job = MagicMock(
+ job_id=pytest.real_job_id,
+ error_result=False,
+ state="DONE",
+ done=lambda: False,
+ )
+ hook.return_value.get_job.return_value = job
+
+ operator = GCSToBigQueryOperator(
+ task_id=TASK_ID,
+ bucket=TEST_BUCKET,
+ source_objects=TEST_SOURCE_OBJECTS,
+ destination_project_dataset_table=TEST_EXPLICIT_DEST,
+ write_disposition=WRITE_DISPOSITION,
+ schema_fields=SCHEMA_FIELDS,
+ location=TEST_DATASET_LOCATION,
+ reattach_states={"PENDING"},
+ external_table=False,
+ autodetect=True,
+ deferrable=True,
+ )
+
+ with pytest.raises(AirflowException) as exc:
+ operator.execute(create_context(operator))
+
+ expected_exception_msg = (
+ f"Job with id: {pytest.real_job_id} already exists and is in {job.state} state. "
+ f"If you want to force rerun it consider setting `force_rerun=True`."
+ f"Or, if you want to reattach in this scenario add {job.state} to `reattach_states`"
+ )
+
+ assert str(exc.value) == expected_exception_msg
+
+ hook.return_value.get_job.assert_called_once_with(
+ location=TEST_DATASET_LOCATION,
+ job_id=pytest.real_job_id,
+ project_id=hook.return_value.project_id,
+ )
+
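
The Conflict-handling tests above pin down three outcomes on a duplicate job id: generate_job_id is consulted, get_job retrieves the existing job, and either the job is resumed via _begin() (when its state is in reattach_states) or an AirflowException with the force_rerun/reattach hint is raised. A rough sketch of that branch follows, assuming a BigQueryHook-like object; it is not the operator's code, only the behaviour the assertions describe.

from google.api_core.exceptions import Conflict

from airflow.exceptions import AirflowException


def submit_or_reattach(hook, configuration, job_id, project_id, location, reattach_states):
    """Illustrative only: submit a load job, reattaching to an existing one on Conflict."""
    try:
        return hook.insert_job(
            configuration=configuration,
            job_id=job_id,
            project_id=project_id,
            location=location,
            nowait=True,
        )
    except Conflict:
        job = hook.get_job(location=location, job_id=job_id, project_id=project_id)
        if job.state in reattach_states:
            job._begin()  # resume polling of the already-submitted job
            return job
        raise AirflowException(
            f"Job with id: {job_id} already exists and is in {job.state} state. "
            f"If you want to force rerun it consider setting `force_rerun=True`."
            f"Or, if you want to reattach in this scenario add {job.state} to `reattach_states`"
        )
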
+
+@mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.GCSHook")
+@mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook")
+def test_schema_fields_without_external_table_async_should_execute_successfully(bq_hook, gcs_hook):
+ bq_hook.return_value.insert_job.return_value = MagicMock(job_id=pytest.real_job_id, error_result=False)
+ bq_hook.return_value.generate_job_id.return_value = pytest.real_job_id
+ bq_hook.return_value.split_tablename.return_value = (PROJECT_ID, DATASET, TABLE)
+ bq_hook.return_value.get_job.return_value.result.return_value = ("1",)
+ gcs_hook.return_value.download.return_value = b"id,name\r\none,Anna"
+
+ operator = GCSToBigQueryOperator(
+ task_id=TASK_ID,
+ bucket=TEST_BUCKET,
+ source_objects=TEST_SOURCE_OBJECTS,
+ destination_project_dataset_table=TEST_EXPLICIT_DEST,
+ write_disposition=WRITE_DISPOSITION,
+ schema_fields=SCHEMA_FIELDS,
+ max_id_key=MAX_ID_KEY,
+ external_table=False,
+ autodetect=True,
+ deferrable=True,
+ )
+
+ with pytest.raises(TaskDeferred):
+ result = operator.execute(create_context(operator))
+ assert result == "1"
+
+ calls = [
+ call(
+ configuration={
+ "load": dict(
+ autodetect=True,
+ createDisposition="CREATE_IF_NEEDED",
+ destinationTable={"projectId": PROJECT_ID, "datasetId": DATASET, "tableId": TABLE},
+ destinationTableProperties={
+ "description": None,
+ "labels": None,
+ },
+ sourceFormat="CSV",
+ skipLeadingRows=None,
+ sourceUris=[f"gs://{TEST_BUCKET}/{TEST_SOURCE_OBJECTS_AS_STRING}"],
+ writeDisposition=WRITE_DISPOSITION,
+ ignoreUnknownValues=False,
+ allowQuotedNewlines=False,
+ encoding="UTF-8",
+ schema={"fields": SCHEMA_FIELDS},
+ ),
+ },
+ project_id=bq_hook.return_value.project_id,
+ location=None,
+ job_id=pytest.real_job_id,
+ timeout=None,
+ retry=DEFAULT_RETRY,
+ nowait=True,
+ ),
+ call(
+ configuration={
+ "query": {
+ "query": f"SELECT MAX({MAX_ID_KEY}) AS max_value FROM {TEST_EXPLICIT_DEST}",
+ "useLegacySql": False,
+ "schemaUpdateOptions": [],
+ }
+ },
+ project_id=bq_hook.return_value.project_id,
+ ),
+ ]
+
+ bq_hook.return_value.insert_job.assert_has_calls(calls)
+
+
+@mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.GCSHook")
+@mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook")
+def test_schema_fields_int_without_external_table_async_should_execute_successfully(bq_hook, gcs_hook):
+ bq_hook.return_value.insert_job.return_value = MagicMock(job_id=pytest.real_job_id, error_result=False)
+ bq_hook.return_value.generate_job_id.return_value = pytest.real_job_id
+ bq_hook.return_value.split_tablename.return_value = (PROJECT_ID, DATASET, TABLE)
+ bq_hook.return_value.get_job.return_value.result.return_value = ("1",)
+ gcs_hook.return_value.download.return_value = b"id,name\r\n1,Anna"
+
+ operator = GCSToBigQueryOperator(
+ task_id=TASK_ID,
+ bucket=TEST_BUCKET,
+ source_objects=TEST_SOURCE_OBJECTS,
+ destination_project_dataset_table=TEST_EXPLICIT_DEST,
+ write_disposition=WRITE_DISPOSITION,
+ schema_fields=SCHEMA_FIELDS,
+ max_id_key=MAX_ID_KEY,
+ external_table=False,
+ autodetect=True,
+ deferrable=True,
+ )
+
+ with pytest.raises(TaskDeferred):
+ result = operator.execute(create_context(operator))
+ assert result == "1"
+
+ calls = [
+ call(
+ configuration={
+ "load": dict(
+ autodetect=True,
+ createDisposition="CREATE_IF_NEEDED",
+ destinationTable={"projectId": PROJECT_ID, "datasetId": DATASET, "tableId": TABLE},
+ destinationTableProperties={
+ "description": None,
+ "labels": None,
+ },
+ sourceFormat="CSV",
+ skipLeadingRows=None,
+ sourceUris=[f"gs://{TEST_BUCKET}/{TEST_SOURCE_OBJECTS_AS_STRING}"],
+ writeDisposition=WRITE_DISPOSITION,
+ ignoreUnknownValues=False,
+ allowQuotedNewlines=False,
+ encoding="UTF-8",
+ ),
+ },
+ project_id=bq_hook.return_value.project_id,
+ location=None,
+ job_id=pytest.real_job_id,
+ timeout=None,
+ retry=DEFAULT_RETRY,
+ nowait=True,
+ ),
+ call(
+ configuration={
+ "query": {
+ "query": f"SELECT MAX({MAX_ID_KEY}) AS max_value FROM {TEST_EXPLICIT_DEST}",
+ "useLegacySql": False,
+ "schemaUpdateOptions": [],
+ }
+ },
+ project_id=bq_hook.return_value.project_id,
+ ),
+ ]
+
+ bq_hook.return_value.insert_job.assert_has_calls(calls)
+
+
+def create_context(task):
+ dag = DAG(dag_id="dag")
+ logical_date = datetime(2022, 1, 1, 0, 0, 0)
+ dag_run = DagRun(
+ dag_id=dag.dag_id,
+ execution_date=logical_date,
+ run_id=DagRun.generate_run_id(DagRunType.MANUAL, logical_date),
+ )
+ task_instance = TaskInstance(task=task)
+ task_instance.dag_run = dag_run
+ task_instance.dag_id = dag.dag_id
+ task_instance.xcom_push = mock.Mock()
+ return {
+ "dag": dag,
+ "run_id": dag_run.run_id,
+ "task": task,
+ "ti": task_instance,
+ "task_instance": task_instance,
+ "logical_date": logical_date,
+ }
diff --git a/tests/system/providers/google/cloud/gcs/example_gcs_to_bigquery_async.py b/tests/system/providers/google/cloud/gcs/example_gcs_to_bigquery_async.py
new file mode 100644
index 0000000000..ebfebab585
--- /dev/null
+++ b/tests/system/providers/google/cloud/gcs/example_gcs_to_bigquery_async.py
@@ -0,0 +1,121 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Example DAG using GCSToBigQueryOperator in deferrable mode.
+"""
+from __future__ import annotations
+
+import os
+from datetime import datetime
+
+from airflow import models
+from airflow.providers.google.cloud.operators.bigquery import (
+ BigQueryCreateEmptyDatasetOperator,
+ BigQueryDeleteDatasetOperator,
+)
+from airflow.providers.google.cloud.transfers.gcs_to_bigquery import GCSToBigQueryOperator
+from airflow.utils.trigger_rule import TriggerRule
+
+ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
+PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT")
+DAG_ID = "gcs_to_bigquery_operator_async"
+
+DATASET_NAME_STR = f"dataset_{DAG_ID}_{ENV_ID}_STR"
+DATASET_NAME_DATE = f"dataset_{DAG_ID}_{ENV_ID}_DATE"
+TABLE_NAME_STR = "test_str"
+TABLE_NAME_DATE = "test_date"
+MAX_ID_STR = "name"
+MAX_ID_DATE = "date"
+
+with models.DAG(
+ dag_id=DAG_ID,
+ schedule="@once",
+ start_date=datetime(2021, 1, 1),
+ catchup=False,
+ tags=["example", "gcs"],
+) as dag:
+ create_test_dataset_for_string_fields = BigQueryCreateEmptyDatasetOperator(
+ task_id="create_airflow_test_dataset_str", dataset_id=DATASET_NAME_STR, project_id=PROJECT_ID
+ )
+
+ create_test_dataset_for_date_fields = BigQueryCreateEmptyDatasetOperator(
+ task_id="create_airflow_test_dataset_date", dataset_id=DATASET_NAME_DATE, project_id=PROJECT_ID
+ )
+
+ # [START howto_operator_gcs_to_bigquery_async]
+ load_string_based_csv = GCSToBigQueryOperator(
+ task_id="gcs_to_bigquery_example_str_csv_async",
+ bucket="cloud-samples-data",
+ source_objects=["bigquery/us-states/us-states.csv"],
+ destination_project_dataset_table=f"{DATASET_NAME_STR}.{TABLE_NAME_STR}",
+ write_disposition="WRITE_TRUNCATE",
+ external_table=False,
+ autodetect=True,
+ max_id_key=MAX_ID_STR,
+ deferrable=True,
+ )
+
+ load_date_based_csv = GCSToBigQueryOperator(
+ task_id="gcs_to_bigquery_example_date_csv_async",
+ bucket="cloud-samples-data",
+ source_objects=["bigquery/us-states/us-states-by-date.csv"],
+ destination_project_dataset_table=f"{DATASET_NAME_DATE}.{TABLE_NAME_DATE}",
+ write_disposition="WRITE_TRUNCATE",
+ external_table=False,
+ autodetect=True,
+ max_id_key=MAX_ID_DATE,
+ deferrable=True,
+ )
+ # [END howto_operator_gcs_to_bigquery_async]
+
+ delete_test_dataset_str = BigQueryDeleteDatasetOperator(
+ task_id="delete_airflow_test_str_dataset",
+ dataset_id=DATASET_NAME_STR,
+ delete_contents=True,
+ trigger_rule=TriggerRule.ALL_DONE,
+ )
+
+ delete_test_dataset_date = BigQueryDeleteDatasetOperator(
+ task_id="delete_airflow_test_date_dataset",
+ dataset_id=DATASET_NAME_DATE,
+ delete_contents=True,
+ trigger_rule=TriggerRule.ALL_DONE,
+ )
+
+ (
+ # TEST SETUP
+ create_test_dataset_for_string_fields
+ >> create_test_dataset_for_date_fields
+ # TEST BODY
+ >> load_string_based_csv
+ >> load_date_based_csv
+ # TEST TEARDOWN
+ >> delete_test_dataset_str
+ >> delete_test_dataset_date
+ )
+
+ from tests.system.utils.watcher import watcher
+
+ # This test needs watcher in order to properly mark success/failure
+ # when "tearDown" task with trigger rule is part of the DAG
+ list(dag.tasks) >> watcher()
+
+from tests.system.utils import get_test_run # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)