Posted to commits@airflow.apache.org by ka...@apache.org on 2022/09/08 21:17:41 UTC
[airflow] branch main updated: Add deferrable big query operators and sensors (#26156)
This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new f938cd4fc8 Add deferrable big query operators and sensors (#26156)
f938cd4fc8 is described below
commit f938cd4fc867513e729aa9a005d663c9713f74e6
Author: Phani Kumar <94...@users.noreply.github.com>
AuthorDate: Fri Sep 9 02:47:33 2022 +0530
Add deferrable big query operators and sensors (#26156)
This PR donates the following BigQuery deferrable operators and sensors, developed in the [astronomer-providers](https://github.com/astronomer/astronomer-providers) repo, to Apache Airflow.
- `BigQueryInsertJobAsyncOperator`
- `BigQueryCheckAsyncOperator`
- `BigQueryGetDataAsyncOperator`
- `BigQueryIntervalCheckAsyncOperator`
- `BigQueryValueCheckAsyncOperator`
- `BigQueryTableExistenceAsyncSensor`
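
A minimal usage sketch of one of the donated operators (illustrative only; the project id and query are placeholders, and a Triggerer must be running for the deferral to complete):

```python
from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.bigquery import BigQueryInsertJobAsyncOperator

with DAG(
    dag_id="example_bigquery_deferrable",
    start_date=datetime(2022, 9, 1),
    schedule_interval=None,
) as dag:
    # Submits the query, pushes the job id to XCom, then defers to the Triggerer.
    insert_query_job = BigQueryInsertJobAsyncOperator(
        task_id="insert_query_job",
        project_id="my-gcp-project",  # placeholder
        configuration={"query": {"query": "SELECT 1", "useLegacySql": False}},
    )
```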
---
airflow/providers/google/cloud/hooks/bigquery.py | 259 ++++-
.../providers/google/cloud/operators/bigquery.py | 464 +++++++++
airflow/providers/google/cloud/sensors/bigquery.py | 77 +-
.../providers/google/cloud/triggers/bigquery.py | 528 ++++++++++
.../providers/google/common/hooks/base_google.py | 25 +-
airflow/providers/google/provider.yaml | 3 +
.../operators/cloud/bigquery.rst | 89 +-
docs/spelling_wordlist.txt | 2 +
generated/provider_dependencies.json | 3 +
.../providers/google/cloud/hooks/test_bigquery.py | 219 ++++-
.../google/cloud/operators/test_bigquery.py | 552 ++++++++++-
.../google/cloud/sensors/test_bigquery.py | 68 ++
tests/providers/google/cloud/triggers/__init__.py | 16 +
.../google/cloud/triggers/test_bigquery.py | 1040 ++++++++++++++++++++
.../bigquery/example_bigquery_queries_async.py | 251 +++++
.../cloud/bigquery/example_bigquery_sensors.py | 12 +-
16 files changed, 3582 insertions(+), 26 deletions(-)
diff --git a/airflow/providers/google/cloud/hooks/bigquery.py b/airflow/providers/google/cloud/hooks/bigquery.py
index 3ec149b577..d9852218da 100644
--- a/airflow/providers/google/cloud/hooks/bigquery.py
+++ b/airflow/providers/google/cloud/hooks/bigquery.py
@@ -29,8 +29,10 @@ import uuid
import warnings
from copy import deepcopy
from datetime import datetime, timedelta
-from typing import Any, Dict, Iterable, List, Mapping, NoReturn, Optional, Sequence, Tuple, Type, Union
+from typing import Any, Dict, Iterable, List, Mapping, NoReturn, Optional, Sequence, Tuple, Type, Union, cast
+from aiohttp import ClientSession as ClientSession
+from gcloud.aio.bigquery import Job, Table as Table_async
from google.api_core.retry import Retry
from google.cloud.bigquery import (
DEFAULT_RETRY,
@@ -49,12 +51,13 @@ from googleapiclient.discovery import Resource, build
from pandas import DataFrame
from pandas_gbq import read_gbq
from pandas_gbq.gbq import GbqConnector # noqa
+from requests import Session
from sqlalchemy import create_engine
from airflow.exceptions import AirflowException
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.providers.google.common.consts import CLIENT_INFO
-from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
+from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook
from airflow.utils.helpers import convert_camel_to_snake
from airflow.utils.log.logging_mixin import LoggingMixin
@@ -2305,7 +2308,6 @@ class BigQueryBaseCursor(LoggingMixin):
num_retries: int = 5,
labels: Optional[Dict] = None,
) -> None:
-
super().__init__()
self.service = service
self.project_id = project_id
@@ -2870,7 +2872,6 @@ def _bq_cast(string_field: str, bq_type: str) -> Union[None, int, float, bool, s
def split_tablename(
table_input: str, default_project_id: str, var_name: Optional[str] = None
) -> Tuple[str, str, str]:
-
if '.' not in table_input:
raise ValueError(f'Expected table name in the format of <dataset>.<table>. Got: {table_input}')
@@ -3010,3 +3011,253 @@ def _format_schema_for_description(schema: Dict) -> List:
)
description.append(field_description)
return description
+
+
+class BigQueryAsyncHook(GoogleBaseAsyncHook):
+ """Uses gcloud-aio library to retrieve Job details"""
+
+ sync_hook_class = BigQueryHook
+
+ async def get_job_instance(
+ self, project_id: Optional[str], job_id: Optional[str], session: ClientSession
+ ) -> Job:
+ """Get the specified job resource by job ID and project ID."""
+ with await self.service_file_as_context() as f:
+ return Job(job_id=job_id, project=project_id, service_file=f, session=cast(Session, session))
+
+ async def get_job_status(
+ self,
+ job_id: Optional[str],
+ project_id: Optional[str] = None,
+ ) -> Optional[str]:
+ """
+ Polls for job status asynchronously using gcloud-aio.
+
+        Note that an OSError is raised while Job results are still pending;
+        any other exception means the Job finished with errors.
+ """
+ async with ClientSession() as s:
+ try:
+ self.log.info("Executing get_job_status...")
+ job_client = await self.get_job_instance(project_id, job_id, s)
+ job_status_response = await job_client.result(cast(Session, s))
+ if job_status_response:
+ job_status = "success"
+ except OSError:
+ job_status = "pending"
+ except Exception as e:
+ self.log.info("Query execution finished with errors...")
+ job_status = str(e)
+ return job_status
+
+ async def get_job_output(
+ self,
+ job_id: Optional[str],
+ project_id: Optional[str] = None,
+ ) -> Dict[str, Any]:
+ """Get the big query job output for the given job id asynchronously using gcloud-aio."""
+ async with ClientSession() as session:
+ self.log.info("Executing get_job_output..")
+ job_client = await self.get_job_instance(project_id, job_id, session)
+ job_query_response = await job_client.get_query_results(cast(Session, session))
+ return job_query_response
+
+ def get_records(self, query_results: Dict[str, Any]) -> List[Any]:
+ """
+ Given the output query response from gcloud-aio bigquery, convert the response to records.
+
+ :param query_results: the results from a SQL query
+ """
+ buffer = []
+ if "rows" in query_results and query_results["rows"]:
+ rows = query_results["rows"]
+ for dict_row in rows:
+ typed_row = [vs["v"] for vs in dict_row["f"]]
+ buffer.append(typed_row)
+ return buffer
+
+ def value_check(
+ self,
+ sql: str,
+ pass_value: Any,
+ records: List[Any],
+ tolerance: Optional[float] = None,
+ ) -> None:
+ """
+        Match a single row of query results against ``pass_value``, within an optional tolerance
+
+        :raises AirflowException: if the match fails.
+ """
+ if not records:
+ raise AirflowException("The query returned None")
+ pass_value_conv = self._convert_to_float_if_possible(pass_value)
+ is_numeric_value_check = isinstance(pass_value_conv, float)
+ tolerance_pct_str = str(tolerance * 100) + "%" if tolerance else None
+
+ error_msg = (
+ "Test failed.\nPass value:{pass_value_conv}\n"
+ "Tolerance:{tolerance_pct_str}\n"
+ "Query:\n{sql}\nResults:\n{records!s}"
+ ).format(
+ pass_value_conv=pass_value_conv,
+ tolerance_pct_str=tolerance_pct_str,
+ sql=sql,
+ records=records,
+ )
+
+ if not is_numeric_value_check:
+ tests = [str(record) == pass_value_conv for record in records]
+ else:
+ try:
+ numeric_records = [float(record) for record in records]
+ except (ValueError, TypeError):
+ raise AirflowException(f"Converting a result to float failed.\n{error_msg}")
+ tests = self._get_numeric_matches(numeric_records, pass_value_conv, tolerance)
+
+ if not all(tests):
+ raise AirflowException(error_msg)
+
+ @staticmethod
+ def _get_numeric_matches(
+ records: List[float], pass_value: Any, tolerance: Optional[float] = None
+ ) -> List[bool]:
+ """
+        A helper function to match numeric records against a numeric ``pass_value`` within an optional tolerance
+
+ :param records: List of value to match against
+ :param pass_value: Expected value
+ :param tolerance: Allowed tolerance for match to succeed
+ """
+ if tolerance:
+ return [
+ pass_value * (1 - tolerance) <= record <= pass_value * (1 + tolerance) for record in records
+ ]
+
+ return [record == pass_value for record in records]
+
+ @staticmethod
+ def _convert_to_float_if_possible(s: Any) -> Any:
+ """
+ A small helper function to convert a string to a numeric value if appropriate
+
+ :param s: the string to be converted
+ """
+ try:
+ return float(s)
+ except (ValueError, TypeError):
+ return s
+
+ def interval_check(
+ self,
+ row1: Optional[str],
+ row2: Optional[str],
+ metrics_thresholds: Dict[str, Any],
+ ignore_zero: bool,
+ ratio_formula: str,
+ ) -> None:
+ """
+ Checks that the values of metrics given as SQL expressions are within a certain tolerance
+
+ :param row1: first resulting row of a query execution job for first SQL query
+ :param row2: first resulting row of a query execution job for second SQL query
+ :param metrics_thresholds: a dictionary of ratios indexed by metrics, for
+ example 'COUNT(*)': 1.5 would require a 50 percent or less difference
+ between the current day, and the prior days_back.
+ :param ignore_zero: whether we should ignore zero metrics
+ :param ratio_formula: which formula to use to compute the ratio between
+ the two metrics. Assuming cur is the metric of today and ref is
+            the metric of today - days_back:
+ max_over_min: computes max(cur, ref) / min(cur, ref)
+ relative_diff: computes abs(cur-ref) / ref
+ """
+ if not row2:
+ raise AirflowException("The second SQL query returned None")
+ if not row1:
+ raise AirflowException("The first SQL query returned None")
+
+ ratio_formulas = {
+ "max_over_min": lambda cur, ref: float(max(cur, ref)) / min(cur, ref),
+ "relative_diff": lambda cur, ref: float(abs(cur - ref)) / ref,
+ }
+
+ metrics_sorted = sorted(metrics_thresholds.keys())
+
+ current = dict(zip(metrics_sorted, row1))
+ reference = dict(zip(metrics_sorted, row2))
+ ratios: Dict[str, Any] = {}
+ test_results: Dict[str, Any] = {}
+
+ for metric in metrics_sorted:
+ cur = float(current[metric])
+ ref = float(reference[metric])
+ threshold = float(metrics_thresholds[metric])
+ if cur == 0 or ref == 0:
+ ratios[metric] = None
+ test_results[metric] = ignore_zero
+ else:
+ ratios[metric] = ratio_formulas[ratio_formula](
+ float(current[metric]), float(reference[metric])
+ )
+ test_results[metric] = float(ratios[metric]) < threshold
+
+            self.log.info(
+                (
+                    "Current metric for %s: %s\n"
+                    "Past metric for %s: %s\n"
+                    "Ratio for %s: %s\n"
+                    "Threshold: %s\n"
+                ),
+                metric,
+                cur,
+                metric,
+                ref,
+                metric,
+                ratios[metric],
+                threshold,
+            )
+
+ if not all(test_results.values()):
+ failed_tests = [metric for metric, value in test_results.items() if not value]
+ self.log.warning(
+ "The following %s tests out of %s failed:",
+ len(failed_tests),
+ len(metrics_sorted),
+ )
+ for k in failed_tests:
+ self.log.warning(
+ "'%s' check failed. %s is above %s",
+ k,
+ ratios[k],
+ metrics_thresholds[k],
+ )
+ raise AirflowException(f"The following tests have failed:\n {', '.join(sorted(failed_tests))}")
+
+ self.log.info("All tests have passed")
+
+
+class BigQueryTableAsyncHook(GoogleBaseAsyncHook):
+ """Class to get async hook for Bigquery Table Async"""
+
+ sync_hook_class = BigQueryHook
+
+ async def get_table_client(
+ self, dataset: str, table_id: str, project_id: str, session: ClientSession
+ ) -> Table_async:
+ """
+        Returns a Google BigQuery Table object.
+
+        :param dataset: The name of the dataset in which to look for the table.
+ :param table_id: The name of the table to check the existence of.
+ :param project_id: The Google cloud project in which to look for the table.
+ The connection supplied to the hook must provide
+ access to the specified project.
+ :param session: aiohttp ClientSession
+ """
+ with await self.service_file_as_context() as file:
+ return Table_async(
+ dataset_name=dataset,
+ table_name=table_id,
+ project=project_id,
+ service_file=file,
+ session=cast(Session, session),
+ )
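
The value_check and interval_check helpers added above are pure Python, so their semantics can be illustrated without any GCP credentials; a minimal sketch (values are made up):

    from airflow.providers.google.cloud.hooks.bigquery import BigQueryAsyncHook

    hook = BigQueryAsyncHook(gcp_conn_id="google_cloud_default")

    # With tolerance=0.1 any record within +/-10% of pass_value passes,
    # i.e. [90.0, 110.0] for pass_value=100, so the record 95 is accepted.
    hook.value_check(sql="SELECT COUNT(*) FROM t", pass_value=100, records=[95], tolerance=0.1)

    # The default "max_over_min" formula gives max(100, 80) / min(100, 80) = 1.25,
    # which is below the 1.5 threshold, so the interval check passes.
    hook.interval_check(
        row1=["100"],
        row2=["80"],
        metrics_thresholds={"COUNT(*)": 1.5},
        ignore_zero=True,
        ratio_formula="max_over_min",
    )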
diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py
index 74b9bd0a2e..b361378c8c 100644
--- a/airflow/providers/google/cloud/operators/bigquery.py
+++ b/airflow/providers/google/cloud/operators/bigquery.py
@@ -40,6 +40,13 @@ from airflow.providers.common.sql.operators.sql import (
from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook, BigQueryJob
from airflow.providers.google.cloud.hooks.gcs import GCSHook, _parse_gcs_url
from airflow.providers.google.cloud.links.bigquery import BigQueryDatasetLink, BigQueryTableLink
+from airflow.providers.google.cloud.triggers.bigquery import (
+ BigQueryCheckTrigger,
+ BigQueryGetDataTrigger,
+ BigQueryInsertJobTrigger,
+ BigQueryIntervalCheckTrigger,
+ BigQueryValueCheckTrigger,
+)
if TYPE_CHECKING:
from airflow.models.taskinstance import TaskInstanceKey
@@ -2241,3 +2248,460 @@ class BigQueryInsertJobOperator(BaseOperator):
)
else:
self.log.info('Skipping to cancel job: %s:%s.%s', self.project_id, self.location, self.job_id)
+
+
+class BigQueryInsertJobAsyncOperator(BigQueryInsertJobOperator, BaseOperator):
+ """
+ Starts a BigQuery job asynchronously, and returns job id.
+ This operator works in the following way:
+
+    - it calculates a unique hash of the job using the job's configuration, or a uuid if ``force_rerun`` is True
+ - creates ``job_id`` in form of
+ ``[provided_job_id | airflow_{dag_id}_{task_id}_{exec_date}]_{uniqueness_suffix}``
+ - submits a BigQuery job using the ``job_id``
+    - if a job with the given id already exists then it tries to reattach to the job if it's not done and its
+ state is in ``reattach_states``. If the job is done the operator will raise ``AirflowException``.
+
+ Using ``force_rerun`` will submit a new job every time without attaching to already existing ones.
+
+ For job definition see here:
+
+ https://cloud.google.com/bigquery/docs/reference/v2/jobs
+
+ :param configuration: The configuration parameter maps directly to BigQuery's
+ configuration field in the job object. For more details see
+ https://cloud.google.com/bigquery/docs/reference/v2/jobs
+ :param job_id: The ID of the job. It will be suffixed with hash of job configuration
+ unless ``force_rerun`` is True.
+ The ID must contain only letters (a-z, A-Z), numbers (0-9), underscores (_), or
+        dashes (-). The maximum length is 1,024 characters. If not provided then a uuid will
+ be generated.
+    :param force_rerun: If True then the operator will use a hash of a uuid as the job id suffix
+ :param reattach_states: Set of BigQuery job's states in case of which we should reattach
+ to the job. Should be other than final states.
+ :param project_id: Google Cloud Project where the job is running
+ :param location: location the job is running
+ :param gcp_conn_id: The connection ID used to connect to Google Cloud.
+ :param delegate_to: The account to impersonate using domain-wide delegation of authority,
+ if any. For this to work, the service account making the request must have
+ domain-wide delegation enabled.
+ :param impersonation_chain: Optional service account to impersonate using short-term
+ credentials, or chained list of accounts required to get the access_token
+ of the last account in the list, which will be impersonated in the request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding identity, with first
+ account from the list granting this role to the originating account (templated).
+    :param cancel_on_kill: Flag which indicates whether to cancel the hook's job when on_kill is called
+ """
+
+ def _submit_job(self, hook: BigQueryHook, job_id: str) -> BigQueryJob: # type: ignore[override]
+ """Submit a new job and get the job id for polling the status using Triggerer."""
+ return hook.insert_job(
+ configuration=self.configuration,
+ project_id=self.project_id,
+ location=self.location,
+ job_id=job_id,
+ nowait=True,
+ )
+
+ def execute(self, context: Any) -> None:
+ hook = BigQueryHook(gcp_conn_id=self.gcp_conn_id)
+
+ self.hook = hook
+ job_id = self.hook.generate_job_id(
+ job_id=self.job_id,
+ dag_id=self.dag_id,
+ task_id=self.task_id,
+ logical_date=context["logical_date"],
+ configuration=self.configuration,
+ force_rerun=self.force_rerun,
+ )
+
+ try:
+ job = self._submit_job(hook, job_id)
+ self._handle_job_error(job)
+ except Conflict:
+ # If the job already exists retrieve it
+ job = hook.get_job(
+ project_id=self.project_id,
+ location=self.location,
+ job_id=job_id,
+ )
+ if job.state in self.reattach_states:
+ # We are reattaching to a job
+ job._begin()
+ self._handle_job_error(job)
+ else:
+ # Same job configuration so we need force_rerun
+ raise AirflowException(
+ f"Job with id: {job_id} already exists and is in {job.state} state. If you "
+ f"want to force rerun it consider setting `force_rerun=True`."
+ f"Or, if you want to reattach in this scenario add {job.state} to `reattach_states`"
+ )
+
+ self.job_id = job.job_id
+ context["ti"].xcom_push(key="job_id", value=self.job_id)
+ self.defer(
+ timeout=self.execution_timeout,
+ trigger=BigQueryInsertJobTrigger(
+ conn_id=self.gcp_conn_id,
+ job_id=self.job_id,
+ project_id=self.project_id,
+ ),
+ method_name="execute_complete",
+ )
+
+ def execute_complete(self, context: Any, event: Dict[str, Any]) -> None:
+ """
+ Callback for when the trigger fires - returns immediately.
+ Relies on trigger to throw an exception, otherwise it assumes execution was
+ successful.
+ """
+ if event["status"] == "error":
+ raise AirflowException(event["message"])
+ self.log.info(
+ "%s completed with response %s ",
+ self.task_id,
+ event["message"],
+ )
+
+
+class BigQueryCheckAsyncOperator(BigQueryCheckOperator):
+ """
+    BigQueryCheckAsyncOperator is an asynchronous operator that submits the job and
+    checks its status in async mode using the job id
+ """
+
+ def _submit_job(
+ self,
+ hook: BigQueryHook,
+ job_id: str,
+ ) -> BigQueryJob:
+ """Submit a new job and get the job id for polling the status using Trigger."""
+ configuration = {"query": {"query": self.sql}}
+
+ return hook.insert_job(
+ configuration=configuration,
+ project_id=hook.project_id,
+ location=self.location,
+ job_id=job_id,
+ nowait=True,
+ )
+
+ def execute(self, context: Any) -> None:
+ hook = BigQueryHook(
+ gcp_conn_id=self.gcp_conn_id,
+ )
+ job = self._submit_job(hook, job_id="")
+ context["ti"].xcom_push(key="job_id", value=job.job_id)
+ self.defer(
+ timeout=self.execution_timeout,
+ trigger=BigQueryCheckTrigger(
+ conn_id=self.gcp_conn_id,
+ job_id=job.job_id,
+ project_id=hook.project_id,
+ ),
+ method_name="execute_complete",
+ )
+
+ def execute_complete(self, context: Any, event: Dict[str, Any]) -> None:
+ """
+ Callback for when the trigger fires - returns immediately.
+ Relies on trigger to throw an exception, otherwise it assumes execution was
+ successful.
+ """
+ if event["status"] == "error":
+ raise AirflowException(event["message"])
+
+ records = event["records"]
+ if not records:
+ raise AirflowException("The query returned None")
+ elif not all(bool(r) for r in records):
+ raise AirflowException(f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}")
+ self.log.info("Record: %s", event["records"])
+ self.log.info("Success.")
+
+
+class BigQueryGetDataAsyncOperator(BigQueryGetDataOperator):
+ """
+ Fetches the data from a BigQuery table (alternatively fetch data for selected columns)
+ and returns data in a python list. The number of elements in the returned list will
+    be equal to the number of rows fetched. Each element in the list will again be a list
+    where each element represents the column values for that row.
+
+ **Example Result**: ``[['Tony', '10'], ['Mike', '20'], ['Steve', '15']]``
+
+ .. note::
+        If you pass fields to ``selected_fields`` in a different order than the
+        order of the columns already in the BQ table, the data will still be in
+        the order of the BQ table. For example, if the BQ table has 3 columns
+        ``[A,B,C]`` and you pass 'B,A' in ``selected_fields``, the data will
+        still be of the form ``'A,B'``.
+
+ **Example**: ::
+
+ get_data = BigQueryGetDataOperator(
+ task_id='get_data_from_bq',
+ dataset_id='test_dataset',
+ table_id='Transaction_partitions',
+ max_results=100,
+ selected_fields='DATE',
+ gcp_conn_id='airflow-conn-id'
+ )
+
+ :param dataset_id: The dataset ID of the requested table. (templated)
+ :param table_id: The table ID of the requested table. (templated)
+ :param max_results: The maximum number of records (rows) to be fetched from the table. (templated)
+ :param selected_fields: List of fields to return (comma-separated). If
+ unspecified, all fields are returned.
+ :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud.
+ :param delegate_to: The account to impersonate using domain-wide delegation of authority,
+ if any. For this to work, the service account making the request must have
+ domain-wide delegation enabled.
+ :param location: The location used for the operation.
+ :param impersonation_chain: Optional service account to impersonate using short-term
+ credentials, or chained list of accounts required to get the access_token
+ of the last account in the list, which will be impersonated in the request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding identity, with first
+ account from the list granting this role to the originating account (templated).
+ """
+
+ def _submit_job(
+ self,
+ hook: BigQueryHook,
+ job_id: str,
+ configuration: Dict[str, Any],
+ ) -> BigQueryJob:
+ """Submit a new job and get the job id for polling the status using Triggerer."""
+ return hook.insert_job(
+ configuration=configuration,
+ location=self.location,
+ project_id=hook.project_id,
+ job_id=job_id,
+ nowait=True,
+ )
+
+ def generate_query(self) -> str:
+ """
+        Generate a select query for the given dataset and table id, using the
+        selected fields if given, or ``*`` otherwise
+ """
+ selected_fields = self.selected_fields if self.selected_fields else "*"
+ return f"select {selected_fields} from {self.dataset_id}.{self.table_id} limit {self.max_results}"
+
+ def execute(self, context: Any) -> None: # type: ignore[override]
+ get_query = self.generate_query()
+ configuration = {"query": {"query": get_query}}
+
+ hook = BigQueryHook(
+ gcp_conn_id=self.gcp_conn_id,
+ delegate_to=self.delegate_to,
+ location=self.location,
+ impersonation_chain=self.impersonation_chain,
+ )
+
+ self.hook = hook
+ job = self._submit_job(hook, job_id="", configuration=configuration)
+ self.job_id = job.job_id
+ context["ti"].xcom_push(key="job_id", value=self.job_id)
+ self.defer(
+ timeout=self.execution_timeout,
+ trigger=BigQueryGetDataTrigger(
+ conn_id=self.gcp_conn_id,
+ job_id=self.job_id,
+ dataset_id=self.dataset_id,
+ table_id=self.table_id,
+ project_id=hook.project_id,
+ ),
+ method_name="execute_complete",
+ )
+
+ def execute_complete(self, context: Any, event: Dict[str, Any]) -> Any:
+ """
+ Callback for when the trigger fires - returns immediately.
+ Relies on trigger to throw an exception, otherwise it assumes execution was
+ successful.
+ """
+ if event["status"] == "error":
+ raise AirflowException(event["message"])
+
+ self.log.info("Total extracted rows: %s", len(event["records"]))
+ return event["records"]
+
+
+class BigQueryIntervalCheckAsyncOperator(BigQueryIntervalCheckOperator):
+ """
+ Checks asynchronously that the values of metrics given as SQL expressions are within
+ a certain tolerance of the ones from days_back before.
+
+    This method constructs a query like so ::
+
+ SELECT {metrics_threshold_dict_key} FROM {table}
+ WHERE {date_filter_column}=<date>
+
+ :param table: the table name
+ :param days_back: number of days between ds and the ds we want to check
+ against. Defaults to 7 days
+ :param metrics_thresholds: a dictionary of ratios indexed by metrics, for
+ example 'COUNT(*)': 1.5 would require a 50 percent or less difference
+ between the current day, and the prior days_back.
+ :param use_legacy_sql: Whether to use legacy SQL (true)
+ or standard SQL (false).
+ :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud.
+ :param location: The geographic location of the job. See details at:
+ https://cloud.google.com/bigquery/docs/locations#specifying_your_location
+ :param impersonation_chain: Optional service account to impersonate using short-term
+ credentials, or chained list of accounts required to get the access_token
+ of the last account in the list, which will be impersonated in the request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding identity, with first
+ account from the list granting this role to the originating account (templated).
+ :param labels: a dictionary containing labels for the table, passed to BigQuery
+ """
+
+ def _submit_job(
+ self,
+ hook: BigQueryHook,
+ sql: str,
+ job_id: str,
+ ) -> BigQueryJob:
+ """Submit a new job and get the job id for polling the status using Triggerer."""
+ configuration = {"query": {"query": sql}}
+ return hook.insert_job(
+ configuration=configuration,
+ project_id=hook.project_id,
+ location=self.location,
+ job_id=job_id,
+ nowait=True,
+ )
+
+ def execute(self, context: Any) -> None:
+ """Execute the job in sync mode and defers the trigger with job id to poll for the status"""
+ hook = BigQueryHook(gcp_conn_id=self.gcp_conn_id)
+ self.log.info("Using ratio formula: %s", self.ratio_formula)
+
+ self.log.info("Executing SQL check: %s", self.sql1)
+ job_1 = self._submit_job(hook, sql=self.sql1, job_id="")
+ context["ti"].xcom_push(key="job_id", value=job_1.job_id)
+
+ self.log.info("Executing SQL check: %s", self.sql2)
+ job_2 = self._submit_job(hook, sql=self.sql2, job_id="")
+ self.defer(
+ timeout=self.execution_timeout,
+ trigger=BigQueryIntervalCheckTrigger(
+ conn_id=self.gcp_conn_id,
+ first_job_id=job_1.job_id,
+ second_job_id=job_2.job_id,
+ project_id=hook.project_id,
+ table=self.table,
+ metrics_thresholds=self.metrics_thresholds,
+ date_filter_column=self.date_filter_column,
+ days_back=self.days_back,
+ ratio_formula=self.ratio_formula,
+ ignore_zero=self.ignore_zero,
+ ),
+ method_name="execute_complete",
+ )
+
+ def execute_complete(self, context: Any, event: Dict[str, Any]) -> None:
+ """
+ Callback for when the trigger fires - returns immediately.
+ Relies on trigger to throw an exception, otherwise it assumes execution was
+ successful.
+ """
+ if event["status"] == "error":
+ raise AirflowException(event["message"])
+
+ self.log.info(
+ "%s completed with response %s ",
+ self.task_id,
+ event["status"],
+ )
+
+
+class BigQueryValueCheckAsyncOperator(BigQueryValueCheckOperator):
+ """
+ Performs a simple value check using sql code.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:BigQueryValueCheckOperator`
+
+ :param sql: the sql to be executed
+ :param use_legacy_sql: Whether to use legacy SQL (true)
+ or standard SQL (false).
+ :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud.
+ :param location: The geographic location of the job. See details at:
+ https://cloud.google.com/bigquery/docs/locations#specifying_your_location
+ :param impersonation_chain: Optional service account to impersonate using short-term
+ credentials, or chained list of accounts required to get the access_token
+ of the last account in the list, which will be impersonated in the request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding identity, with first
+ account from the list granting this role to the originating account (templated).
+ :param labels: a dictionary containing labels for the table, passed to BigQuery
+ """
+
+ def _submit_job(
+ self,
+ hook: BigQueryHook,
+ job_id: str,
+ ) -> BigQueryJob:
+ """Submit a new job and get the job id for polling the status using Triggerer."""
+ configuration = {
+ "query": {
+ "query": self.sql,
+ "useLegacySql": False,
+ }
+ }
+ if self.use_legacy_sql:
+ configuration["query"]["useLegacySql"] = self.use_legacy_sql
+
+ return hook.insert_job(
+ configuration=configuration,
+ project_id=hook.project_id,
+ location=self.location,
+ job_id=job_id,
+ nowait=True,
+ )
+
+ def execute(self, context: Any) -> None:
+ hook = BigQueryHook(gcp_conn_id=self.gcp_conn_id)
+
+ job = self._submit_job(hook, job_id="")
+ context["ti"].xcom_push(key="job_id", value=job.job_id)
+ self.defer(
+ timeout=self.execution_timeout,
+ trigger=BigQueryValueCheckTrigger(
+ conn_id=self.gcp_conn_id,
+ job_id=job.job_id,
+ project_id=hook.project_id,
+ sql=self.sql,
+ pass_value=self.pass_value,
+ tolerance=self.tol,
+ ),
+ method_name="execute_complete",
+ )
+
+ def execute_complete(self, context: Any, event: Dict[str, Any]) -> None:
+ """
+ Callback for when the trigger fires - returns immediately.
+ Relies on trigger to throw an exception, otherwise it assumes execution was
+ successful.
+ """
+ if event["status"] == "error":
+ raise AirflowException(event["message"])
+ self.log.info(
+ "%s completed with response %s ",
+ self.task_id,
+ event["message"],
+ )
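
Each async operator above pushes the submitted job's id to XCom under the "job_id" key before deferring, so downstream tasks can pick it up; a minimal sketch (the upstream task id is a placeholder):

    from airflow.decorators import task

    @task
    def report_job_id(ti=None):
        # Pull the job id pushed by the deferrable operator before it deferred.
        job_id = ti.xcom_pull(task_ids="insert_query_job", key="job_id")
        print(f"BigQuery job submitted: {job_id}")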
diff --git a/airflow/providers/google/cloud/sensors/bigquery.py b/airflow/providers/google/cloud/sensors/bigquery.py
index f0a9d67f58..ddcd203eb8 100644
--- a/airflow/providers/google/cloud/sensors/bigquery.py
+++ b/airflow/providers/google/cloud/sensors/bigquery.py
@@ -16,9 +16,12 @@
# specific language governing permissions and limitations
# under the License.
"""This module contains a Google Bigquery sensor."""
-from typing import TYPE_CHECKING, Optional, Sequence, Union
+from datetime import timedelta
+from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Union
+from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
+from airflow.providers.google.cloud.triggers.bigquery import BigQueryTableExistenceTrigger
from airflow.sensors.base import BaseSensorOperator
if TYPE_CHECKING:
@@ -68,7 +71,6 @@ class BigQueryTableExistenceSensor(BaseSensorOperator):
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
**kwargs,
) -> None:
-
super().__init__(**kwargs)
self.project_id = project_id
@@ -137,7 +139,6 @@ class BigQueryTablePartitionExistenceSensor(BaseSensorOperator):
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
**kwargs,
) -> None:
-
super().__init__(**kwargs)
self.project_id = project_id
@@ -162,3 +163,73 @@ class BigQueryTablePartitionExistenceSensor(BaseSensorOperator):
table_id=self.table_id,
partition_id=self.partition_id,
)
+
+
+class BigQueryTableExistenceAsyncSensor(BigQueryTableExistenceSensor):
+ """
+    Checks for the existence of a table in Google BigQuery.
+
+ :param project_id: The Google cloud project in which to look for the table.
+ The connection supplied to the hook must provide
+ access to the specified project.
+    :param dataset_id: The name of the dataset in which to look for the table.
+ :param table_id: The name of the table to check the existence of.
+ :param gcp_conn_id: The connection ID used to connect to Google Cloud.
+ :param bigquery_conn_id: (Deprecated) The connection ID used to connect to Google Cloud.
+ This parameter has been deprecated. You should pass the gcp_conn_id parameter instead.
+ :param delegate_to: The account to impersonate using domain-wide delegation of authority,
+ if any. For this to work, the service account making the request must have
+ domain-wide delegation enabled.
+ :param impersonation_chain: Optional service account to impersonate using short-term
+ credentials, or chained list of accounts required to get the access_token
+ of the last account in the list, which will be impersonated in the request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding identity, with first
+ account from the list granting this role to the originating account (templated).
+    :param polling_interval: The interval in seconds to wait between checks for table existence.
+ """
+
+ def __init__(
+ self,
+ gcp_conn_id: str = "google_cloud_default",
+ polling_interval: float = 5.0,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.polling_interval = polling_interval
+ self.gcp_conn_id = gcp_conn_id
+
+ def execute(self, context: 'Context') -> None:
+ """Airflow runs this method on the worker and defers using the trigger."""
+ self.defer(
+ timeout=timedelta(seconds=self.timeout),
+ trigger=BigQueryTableExistenceTrigger(
+ dataset_id=self.dataset_id,
+ table_id=self.table_id,
+ project_id=self.project_id,
+ poll_interval=self.polling_interval,
+ gcp_conn_id=self.gcp_conn_id,
+ hook_params={
+ "delegate_to": self.delegate_to,
+ "impersonation_chain": self.impersonation_chain,
+ },
+ ),
+ method_name="execute_complete",
+ )
+
+ def execute_complete(self, context: Dict[str, Any], event: Optional[Dict[str, str]] = None) -> str:
+ """
+ Callback for when the trigger fires - returns immediately.
+ Relies on trigger to throw an exception, otherwise it assumes execution was
+ successful.
+ """
+ table_uri = f"{self.project_id}:{self.dataset_id}.{self.table_id}"
+ self.log.info("Sensor checks existence of table: %s", table_uri)
+ if event:
+ if event["status"] == "success":
+ return event["message"]
+ raise AirflowException(event["message"])
+ raise AirflowException("No event received in trigger callback")
diff --git a/airflow/providers/google/cloud/triggers/bigquery.py b/airflow/providers/google/cloud/triggers/bigquery.py
new file mode 100644
index 0000000000..56e985ea1f
--- /dev/null
+++ b/airflow/providers/google/cloud/triggers/bigquery.py
@@ -0,0 +1,528 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import asyncio
+from typing import Any, AsyncIterator, Dict, Optional, SupportsAbs, Tuple, Union
+
+from aiohttp import ClientSession
+from aiohttp.client_exceptions import ClientResponseError
+
+from airflow.providers.google.cloud.hooks.bigquery import BigQueryAsyncHook, BigQueryTableAsyncHook
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class BigQueryInsertJobTrigger(BaseTrigger):
+ """
+    BigQueryInsertJobTrigger runs on the trigger worker to perform the insert operation
+
+ :param conn_id: Reference to google cloud connection id
+ :param job_id: The ID of the job. It will be suffixed with hash of job configuration
+ :param project_id: Google Cloud Project where the job is running
+ :param dataset_id: The dataset ID of the requested table. (templated)
+ :param table_id: The table ID of the requested table. (templated)
+ :param poll_interval: polling period in seconds to check for the status
+ """
+
+ def __init__(
+ self,
+ conn_id: str,
+ job_id: Optional[str],
+ project_id: Optional[str],
+ dataset_id: Optional[str] = None,
+ table_id: Optional[str] = None,
+ poll_interval: float = 4.0,
+ ):
+ super().__init__()
+ self.log.info("Using the connection %s .", conn_id)
+ self.conn_id = conn_id
+ self.job_id = job_id
+ self._job_conn = None
+ self.dataset_id = dataset_id
+ self.project_id = project_id
+ self.table_id = table_id
+ self.poll_interval = poll_interval
+
+ def serialize(self) -> Tuple[str, Dict[str, Any]]:
+ """Serializes BigQueryInsertJobTrigger arguments and classpath."""
+ return (
+ "airflow.providers.google.cloud.triggers.bigquery.BigQueryInsertJobTrigger",
+ {
+ "conn_id": self.conn_id,
+ "job_id": self.job_id,
+ "dataset_id": self.dataset_id,
+ "project_id": self.project_id,
+ "table_id": self.table_id,
+ "poll_interval": self.poll_interval,
+ },
+ )
+
+ async def run(self) -> AsyncIterator["TriggerEvent"]: # type: ignore[override]
+ """Gets current job execution status and yields a TriggerEvent"""
+ hook = self._get_async_hook()
+ while True:
+ try:
+ # Poll for job execution status
+ response_from_hook = await hook.get_job_status(job_id=self.job_id, project_id=self.project_id)
+ self.log.debug("Response from hook: %s", response_from_hook)
+
+ if response_from_hook == "success":
+ yield TriggerEvent(
+ {
+ "job_id": self.job_id,
+ "status": "success",
+ "message": "Job completed",
+ }
+                    )
+                    return
+                elif response_from_hook == "pending":
+ self.log.info("Query is still running...")
+ self.log.info("Sleeping for %s seconds.", self.poll_interval)
+ await asyncio.sleep(self.poll_interval)
+ else:
+ yield TriggerEvent({"status": "error", "message": response_from_hook})
+
+ except Exception as e:
+ self.log.exception("Exception occurred while checking for query completion")
+ yield TriggerEvent({"status": "error", "message": str(e)})
+
+ def _get_async_hook(self) -> BigQueryAsyncHook:
+ return BigQueryAsyncHook(gcp_conn_id=self.conn_id)
+
+
+class BigQueryCheckTrigger(BigQueryInsertJobTrigger):
+ """BigQueryCheckTrigger run on the trigger worker"""
+
+ def serialize(self) -> Tuple[str, Dict[str, Any]]:
+ """Serializes BigQueryCheckTrigger arguments and classpath."""
+ return (
+ "airflow.providers.google.cloud.triggers.bigquery.BigQueryCheckTrigger",
+ {
+ "conn_id": self.conn_id,
+ "job_id": self.job_id,
+ "dataset_id": self.dataset_id,
+ "project_id": self.project_id,
+ "table_id": self.table_id,
+ "poll_interval": self.poll_interval,
+ },
+ )
+
+ async def run(self) -> AsyncIterator["TriggerEvent"]: # type: ignore[override]
+ """Gets current job execution status and yields a TriggerEvent"""
+ hook = self._get_async_hook()
+ while True:
+ try:
+ # Poll for job execution status
+ response_from_hook = await hook.get_job_status(job_id=self.job_id, project_id=self.project_id)
+ if response_from_hook == "success":
+ query_results = await hook.get_job_output(job_id=self.job_id, project_id=self.project_id)
+
+ records = hook.get_records(query_results)
+
+ # If empty list, then no records are available
+ if not records:
+ yield TriggerEvent(
+ {
+ "status": "success",
+ "records": None,
+ }
+ )
+ else:
+ # Extract only first record from the query results
+ first_record = records.pop(0)
+ yield TriggerEvent(
+ {
+ "status": "success",
+ "records": first_record,
+ }
+ )
+ return
+
+ elif response_from_hook == "pending":
+ self.log.info("Query is still running...")
+ self.log.info("Sleeping for %s seconds.", self.poll_interval)
+ await asyncio.sleep(self.poll_interval)
+ else:
+ yield TriggerEvent({"status": "error", "message": response_from_hook})
+ except Exception as e:
+ self.log.exception("Exception occurred while checking for query completion")
+ yield TriggerEvent({"status": "error", "message": str(e)})
+
+
+class BigQueryGetDataTrigger(BigQueryInsertJobTrigger):
+ """BigQueryGetDataTrigger run on the trigger worker, inherits from BigQueryInsertJobTrigger class"""
+
+ def serialize(self) -> Tuple[str, Dict[str, Any]]:
+ """Serializes BigQueryInsertJobTrigger arguments and classpath."""
+ return (
+ "airflow.providers.google.cloud.triggers.bigquery.BigQueryGetDataTrigger",
+ {
+ "conn_id": self.conn_id,
+ "job_id": self.job_id,
+ "dataset_id": self.dataset_id,
+ "project_id": self.project_id,
+ "table_id": self.table_id,
+ "poll_interval": self.poll_interval,
+ },
+ )
+
+ async def run(self) -> AsyncIterator["TriggerEvent"]: # type: ignore[override]
+ """Gets current job execution status and yields a TriggerEvent with response data"""
+ hook = self._get_async_hook()
+ while True:
+ try:
+ # Poll for job execution status
+ response_from_hook = await hook.get_job_status(job_id=self.job_id, project_id=self.project_id)
+ if response_from_hook == "success":
+ query_results = await hook.get_job_output(job_id=self.job_id, project_id=self.project_id)
+ records = hook.get_records(query_results)
+ self.log.debug("Response from hook: %s", response_from_hook)
+ yield TriggerEvent(
+ {
+ "status": "success",
+ "message": response_from_hook,
+ "records": records,
+ }
+ )
+ return
+ elif response_from_hook == "pending":
+ self.log.info("Query is still running...")
+ self.log.info("Sleeping for %s seconds.", self.poll_interval)
+ await asyncio.sleep(self.poll_interval)
+ else:
+ yield TriggerEvent({"status": "error", "message": response_from_hook})
+ return
+ except Exception as e:
+ self.log.exception("Exception occurred while checking for query completion")
+ yield TriggerEvent({"status": "error", "message": str(e)})
+ return
+
+
+class BigQueryIntervalCheckTrigger(BigQueryInsertJobTrigger):
+ """
+    BigQueryIntervalCheckTrigger runs on the trigger worker; inherits from the BigQueryInsertJobTrigger class
+
+ :param conn_id: Reference to google cloud connection id
+    :param first_job_id: The ID of the first job performed
+    :param second_job_id: The ID of the second job performed
+ :param project_id: Google Cloud Project where the job is running
+ :param dataset_id: The dataset ID of the requested table. (templated)
+ :param table: table name
+ :param metrics_thresholds: dictionary of ratios indexed by metrics
+    :param date_filter_column: the name of the date column to filter on
+ :param days_back: number of days between ds and the ds we want to check
+ against
+    :param ratio_formula: the ratio formula to use (``max_over_min`` or ``relative_diff``)
+    :param ignore_zero: whether to ignore zero metrics
+ :param table_id: The table ID of the requested table. (templated)
+ :param poll_interval: polling period in seconds to check for the status
+ """
+
+ def __init__(
+ self,
+ conn_id: str,
+ first_job_id: str,
+ second_job_id: str,
+ project_id: Optional[str],
+ table: str,
+ metrics_thresholds: Dict[str, int],
+ date_filter_column: Optional[str] = "ds",
+ days_back: SupportsAbs[int] = -7,
+ ratio_formula: str = "max_over_min",
+ ignore_zero: bool = True,
+ dataset_id: Optional[str] = None,
+ table_id: Optional[str] = None,
+ poll_interval: float = 4.0,
+ ):
+ super().__init__(
+ conn_id=conn_id,
+ job_id=first_job_id,
+ project_id=project_id,
+ dataset_id=dataset_id,
+ table_id=table_id,
+ poll_interval=poll_interval,
+ )
+ self.conn_id = conn_id
+ self.first_job_id = first_job_id
+ self.second_job_id = second_job_id
+ self.project_id = project_id
+ self.table = table
+ self.metrics_thresholds = metrics_thresholds
+ self.date_filter_column = date_filter_column
+ self.days_back = days_back
+ self.ratio_formula = ratio_formula
+ self.ignore_zero = ignore_zero
+
+ def serialize(self) -> Tuple[str, Dict[str, Any]]:
+ """Serializes BigQueryCheckTrigger arguments and classpath."""
+ return (
+ "airflow.providers.google.cloud.triggers.bigquery.BigQueryIntervalCheckTrigger",
+ {
+ "conn_id": self.conn_id,
+ "first_job_id": self.first_job_id,
+ "second_job_id": self.second_job_id,
+ "project_id": self.project_id,
+ "table": self.table,
+ "metrics_thresholds": self.metrics_thresholds,
+ "date_filter_column": self.date_filter_column,
+ "days_back": self.days_back,
+ "ratio_formula": self.ratio_formula,
+ "ignore_zero": self.ignore_zero,
+ },
+ )
+
+ async def run(self) -> AsyncIterator["TriggerEvent"]: # type: ignore[override]
+ """Gets current job execution status and yields a TriggerEvent"""
+ hook = self._get_async_hook()
+ while True:
+ try:
+ first_job_response_from_hook = await hook.get_job_status(
+ job_id=self.first_job_id, project_id=self.project_id
+ )
+ second_job_response_from_hook = await hook.get_job_status(
+ job_id=self.second_job_id, project_id=self.project_id
+ )
+
+ if first_job_response_from_hook == "success" and second_job_response_from_hook == "success":
+ first_query_results = await hook.get_job_output(
+ job_id=self.first_job_id, project_id=self.project_id
+ )
+
+ second_query_results = await hook.get_job_output(
+ job_id=self.second_job_id, project_id=self.project_id
+ )
+
+ first_records = hook.get_records(first_query_results)
+
+ second_records = hook.get_records(second_query_results)
+
+ # If empty list, then no records are available
+ if not first_records:
+ first_job_row: Optional[str] = None
+ else:
+ # Extract only first record from the query results
+ first_job_row = first_records.pop(0)
+
+ # If empty list, then no records are available
+ if not second_records:
+ second_job_row: Optional[str] = None
+ else:
+ # Extract only first record from the query results
+ second_job_row = second_records.pop(0)
+
+ hook.interval_check(
+ first_job_row,
+ second_job_row,
+ self.metrics_thresholds,
+ self.ignore_zero,
+ self.ratio_formula,
+ )
+
+ yield TriggerEvent(
+ {
+ "status": "success",
+ "message": "Job completed",
+ "first_row_data": first_job_row,
+ "second_row_data": second_job_row,
+ }
+ )
+ return
+ elif first_job_response_from_hook == "pending" or second_job_response_from_hook == "pending":
+ self.log.info("Query is still running...")
+ self.log.info("Sleeping for %s seconds.", self.poll_interval)
+ await asyncio.sleep(self.poll_interval)
+ else:
+ yield TriggerEvent(
+ {"status": "error", "message": second_job_response_from_hook, "data": None}
+ )
+ return
+
+ except Exception as e:
+ self.log.exception("Exception occurred while checking for query completion")
+ yield TriggerEvent({"status": "error", "message": str(e)})
+ return
+
+
+class BigQueryValueCheckTrigger(BigQueryInsertJobTrigger):
+ """
+    BigQueryValueCheckTrigger runs on the trigger worker; inherits from the BigQueryInsertJobTrigger class
+
+ :param conn_id: Reference to google cloud connection id
+ :param sql: the sql to be executed
+    :param pass_value: the expected value the query result is checked against
+ :param job_id: The ID of the job
+ :param project_id: Google Cloud Project where the job is running
+    :param tolerance: the allowed tolerance for the match to succeed
+ :param dataset_id: The dataset ID of the requested table. (templated)
+ :param table_id: The table ID of the requested table. (templated)
+ :param poll_interval: polling period in seconds to check for the status
+ """
+
+ def __init__(
+ self,
+ conn_id: str,
+ sql: str,
+ pass_value: Union[int, float, str],
+ job_id: Optional[str],
+ project_id: Optional[str],
+ tolerance: Any = None,
+ dataset_id: Optional[str] = None,
+ table_id: Optional[str] = None,
+ poll_interval: float = 4.0,
+ ):
+ super().__init__(
+ conn_id=conn_id,
+ job_id=job_id,
+ project_id=project_id,
+ dataset_id=dataset_id,
+ table_id=table_id,
+ poll_interval=poll_interval,
+ )
+ self.sql = sql
+ self.pass_value = pass_value
+ self.tolerance = tolerance
+
+ def serialize(self) -> Tuple[str, Dict[str, Any]]:
+ """Serializes BigQueryValueCheckTrigger arguments and classpath."""
+ return (
+ "airflow.providers.google.cloud.triggers.bigquery.BigQueryValueCheckTrigger",
+ {
+ "conn_id": self.conn_id,
+ "pass_value": self.pass_value,
+ "job_id": self.job_id,
+ "dataset_id": self.dataset_id,
+ "project_id": self.project_id,
+ "sql": self.sql,
+ "table_id": self.table_id,
+ "tolerance": self.tolerance,
+ "poll_interval": self.poll_interval,
+ },
+ )
+
+ async def run(self) -> AsyncIterator["TriggerEvent"]: # type: ignore[override]
+ """Gets current job execution status and yields a TriggerEvent"""
+ hook = self._get_async_hook()
+ while True:
+ try:
+ # Poll for job execution status
+ response_from_hook = await hook.get_job_status(job_id=self.job_id, project_id=self.project_id)
+ if response_from_hook == "success":
+ query_results = await hook.get_job_output(job_id=self.job_id, project_id=self.project_id)
+ records = hook.get_records(query_results)
+ records = records.pop(0) if records else None
+ hook.value_check(self.sql, self.pass_value, records, self.tolerance)
+ yield TriggerEvent({"status": "success", "message": "Job completed", "records": records})
+ return
+ elif response_from_hook == "pending":
+ self.log.info("Query is still running...")
+ self.log.info("Sleeping for %s seconds.", self.poll_interval)
+ await asyncio.sleep(self.poll_interval)
+ else:
+ yield TriggerEvent({"status": "error", "message": response_from_hook, "records": None})
+ return
+
+ except Exception as e:
+ self.log.exception("Exception occurred while checking for query completion")
+ yield TriggerEvent({"status": "error", "message": str(e)})
+ return
+
+
+class BigQueryTableExistenceTrigger(BaseTrigger):
+ """
+    Checks for the existence of a table in Google BigQuery; runs on the trigger worker
+
+ :param project_id: Google Cloud Project where the job is running
+ :param dataset_id: The dataset ID of the requested table.
+ :param table_id: The table ID of the requested table.
+ :param gcp_conn_id: Reference to google cloud connection id
+    :param hook_params: params passed through to the hook
+ :param poll_interval: polling period in seconds to check for the status
+ """
+
+ def __init__(
+ self,
+ project_id: str,
+ dataset_id: str,
+ table_id: str,
+ gcp_conn_id: str,
+ hook_params: Dict[str, Any],
+ poll_interval: float = 4.0,
+    ):
+        super().__init__()
+        self.dataset_id = dataset_id
+ self.project_id = project_id
+ self.table_id = table_id
+ self.gcp_conn_id: str = gcp_conn_id
+ self.poll_interval = poll_interval
+ self.hook_params = hook_params
+
+ def serialize(self) -> Tuple[str, Dict[str, Any]]:
+ """Serializes BigQueryTableExistenceTrigger arguments and classpath."""
+ return (
+ "airflow.providers.google.cloud.triggers.bigquery.BigQueryTableExistenceTrigger",
+ {
+ "dataset_id": self.dataset_id,
+ "project_id": self.project_id,
+ "table_id": self.table_id,
+ "gcp_conn_id": self.gcp_conn_id,
+ "poll_interval": self.poll_interval,
+ "hook_params": self.hook_params,
+ },
+ )
+
+ def _get_async_hook(self) -> BigQueryTableAsyncHook:
+ return BigQueryTableAsyncHook(gcp_conn_id=self.gcp_conn_id)
+
+ async def run(self) -> AsyncIterator["TriggerEvent"]: # type: ignore[override]
+ """Will run until the table exists in the Google Big Query."""
+ while True:
+ try:
+ hook = self._get_async_hook()
+ response = await self._table_exists(
+ hook=hook, dataset=self.dataset_id, table_id=self.table_id, project_id=self.project_id
+ )
+ if response:
+ yield TriggerEvent({"status": "success", "message": "success"})
+ return
+ await asyncio.sleep(self.poll_interval)
+ except Exception as e:
+ self.log.exception("Exception occurred while checking for Table existence")
+ yield TriggerEvent({"status": "error", "message": str(e)})
+ return
+
+ async def _table_exists(
+ self, hook: BigQueryTableAsyncHook, dataset: str, table_id: str, project_id: str
+ ) -> bool:
+ """
+        Create a client session, call BigQueryTableAsyncHook, and check for the table in
+        Google BigQuery.
+
+ :param hook: BigQueryTableAsyncHook Hook class
+        :param dataset: The name of the dataset in which to look for the table.
+ :param table_id: The name of the table to check the existence of.
+ :param project_id: The Google cloud project in which to look for the table.
+ The connection supplied to the hook must provide
+ access to the specified project.
+ """
+ async with ClientSession() as session:
+ try:
+ client = await hook.get_table_client(
+ dataset=dataset, table_id=table_id, project_id=project_id, session=session
+ )
+ response = await client.get()
+ return True if response else False
+ except ClientResponseError as err:
+ if err.status == 404:
+ return False
+ raise err
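
Because run() is an async generator yielding TriggerEvents, the triggers above can be exercised outside Airflow, e.g. in a test; a minimal sketch (ids are placeholders, and valid GCP credentials are assumed):

    import asyncio

    from airflow.providers.google.cloud.triggers.bigquery import BigQueryInsertJobTrigger

    async def first_event():
        trigger = BigQueryInsertJobTrigger(
            conn_id="google_cloud_default",
            job_id="airflow_example_job_123",  # placeholder
            project_id="my-gcp-project",       # placeholder
        )
        # Pull the first TriggerEvent from the async generator.
        return await trigger.run().__anext__()

    print(asyncio.run(first_event()))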
diff --git a/airflow/providers/google/common/hooks/base_google.py b/airflow/providers/google/common/hooks/base_google.py
index 79ed286a3d..e2179bedfa 100644
--- a/airflow/providers/google/common/hooks/base_google.py
+++ b/airflow/providers/google/common/hooks/base_google.py
@@ -33,6 +33,7 @@ import google.oauth2.service_account
import google_auth_httplib2
import requests
import tenacity
+from asgiref.sync import sync_to_async
from google.api_core.exceptions import Forbidden, ResourceExhausted, TooManyRequests
from google.api_core.gapic_v1.client_info import ClientInfo
from google.auth import _cloud_sdk, compute_engine
@@ -56,7 +57,6 @@ from airflow.utils.process_utils import patch_environ
log = logging.getLogger(__name__)
-
# Constants used by the mechanism of repeating requests in reaction to exceeding the temporary quota.
INVALID_KEYS = [
'DefaultRequestsPerMinutePerProject',
@@ -602,3 +602,26 @@ class GoogleBaseHook(BaseHook):
message = str(e)
return status, message
+
+
+class GoogleBaseAsyncHook(BaseHook):
+ """GoogleBaseAsyncHook inherits from BaseHook class, run on the trigger worker"""
+
+ sync_hook_class: Any = None
+
+ def __init__(self, **kwargs: Any):
+ self._hook_kwargs = kwargs
+ self._sync_hook = None
+
+ async def get_sync_hook(self) -> Any:
+ """
+        The sync version of the Google Cloud Hooks makes blocking calls in ``__init__``, so we
+        don't inherit from it.
+ """
+ if not self._sync_hook:
+ self._sync_hook = await sync_to_async(self.sync_hook_class)(**self._hook_kwargs)
+ return self._sync_hook
+
+    async def service_file_as_context(self) -> Any:
+        """Return the sync hook's GCP credential file context manager, constructed off the event loop."""
+ sync_hook = await self.get_sync_hook()
+ return await sync_to_async(sync_hook.provide_gcp_credential_file_as_context)()
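
A minimal sketch of the pattern this base class enables; GCSHook stands in for any blocking sync Google hook:

    from airflow.providers.google.cloud.hooks.gcs import GCSHook
    from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook

    class GCSAsyncHook(GoogleBaseAsyncHook):
        sync_hook_class = GCSHook  # the blocking hook to wrap

        async def get_project_id(self) -> str:
            # The sync hook is instantiated via sync_to_async, off the event loop.
            sync_hook = await self.get_sync_hook()
            return sync_hook.project_id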
diff --git a/airflow/providers/google/provider.yaml b/airflow/providers/google/provider.yaml
index 5923805302..a3edbf5867 100644
--- a/airflow/providers/google/provider.yaml
+++ b/airflow/providers/google/provider.yaml
@@ -64,6 +64,9 @@ dependencies:
# Introduced breaking changes across the board. Those libraries should be upgraded soon
# TODO: Upgrade all Google libraries that are limited to <2.0.0
- PyOpenSSL
+ - asgiref
+ - gcloud-aio-bigquery
+ - gcloud-aio-storage
- google-ads>=15.1.1
- google-api-core>=2.7.0,<3.0.0
- google-api-python-client>=1.6.0,<2.0.0
diff --git a/docs/apache-airflow-providers-google/operators/cloud/bigquery.rst b/docs/apache-airflow-providers-google/operators/cloud/bigquery.rst
index 57e4d87ff8..6672d2ce58 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/bigquery.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/bigquery.rst
@@ -202,7 +202,8 @@ Fetch data from table
"""""""""""""""""""""
To fetch data from a BigQuery table you can use
-:class:`~airflow.providers.google.cloud.operators.bigquery.BigQueryGetDataOperator`.
+:class:`~airflow.providers.google.cloud.operators.bigquery.BigQueryGetDataOperator` or
+:class:`~airflow.providers.google.cloud.operators.bigquery.BigQueryGetDataAsyncOperator`.
Alternatively you can fetch data for selected columns if you pass fields to
``selected_fields``.
@@ -217,6 +218,17 @@ that row.
:start-after: [START howto_operator_bigquery_get_data]
:end-before: [END howto_operator_bigquery_get_data]
+The below example shows how to use
+:class:`~airflow.providers.google.cloud.operators.bigquery.BigQueryGetDataAsyncOperator`.
+Note that this is a deferrable operator which requires the Triggerer to be running on your Airflow
+deployment.
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/bigquery/example_bigquery_queries_async.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_bigquery_get_data_async]
+ :end-before: [END howto_operator_bigquery_get_data_async]
+
.. _howto/operator:BigQueryUpsertTableOperator:
Upsert table
@@ -294,9 +306,10 @@ Let's say you would like to execute the following query.
:start-after: [START howto_operator_bigquery_query]
:end-before: [END howto_operator_bigquery_query]
-To execute the SQL query in a specific BigQuery database you can use
-:class:`~airflow.providers.google.cloud.operators.bigquery.BigQueryInsertJobOperator` with
-proper query job configuration that can be Jinja templated.
+To execute the SQL query in a specific BigQuery database you can use either
+:class:`~airflow.providers.google.cloud.operators.bigquery.BigQueryInsertJobOperator` or
+:class:`~airflow.providers.google.cloud.operators.bigquery.BigQueryInsertJobAsyncOperator`
+with proper query job configuration that can be Jinja templated.
.. exampleinclude:: /../../tests/system/providers/google/cloud/bigquery/example_bigquery_queries.py
:language: python
@@ -304,6 +317,17 @@ proper query job configuration that can be Jinja templated.
:start-after: [START howto_operator_bigquery_insert_job]
:end-before: [END howto_operator_bigquery_insert_job]
+The below example shows how to use
+:class:`~airflow.providers.google.cloud.operators.bigquery.BigQueryInsertJobAsyncOperator`.
+Note that this is a deferrable operator which requires the Triggerer to be running on your Airflow
+deployment.
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/bigquery/example_bigquery_queries_async.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_bigquery_insert_job_async]
+ :end-before: [END howto_operator_bigquery_insert_job_async]
+
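+A minimal sketch of the deferrable job submission (the query, location and
+project id are placeholder values, not defaults):
+
+.. code-block:: python
+
+    insert_query_job = BigQueryInsertJobAsyncOperator(
+        task_id="insert_query_job",
+        configuration={
+            "query": {
+                "query": "SELECT COUNT(*) FROM my_dataset.my_table",
+                "useLegacySql": False,
+            }
+        },
+        location="US",
+        project_id="my-project",
+    )
+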
For more information on types of BigQuery job please check
`documentation <https://cloud.google.com/bigquery/docs/reference/v2/jobs>`__.
@@ -332,8 +356,9 @@ Validate data
Check if query result has data
""""""""""""""""""""""""""""""
-To perform checks against BigQuery you can use
-:class:`~airflow.providers.google.cloud.operators.bigquery.BigQueryCheckOperator`.
+To perform checks against BigQuery you can use either
+:class:`~airflow.providers.google.cloud.operators.bigquery.BigQueryCheckOperator` or
+:class:`~airflow.providers.google.cloud.operators.bigquery.BigQueryCheckAsyncOperator`.
This operator expects a sql query that will return a single row. Each value on
that first row is evaluated using python ``bool`` casting. If any of the values
@@ -345,15 +370,25 @@ return ``False`` the check is failed and errors out.
:start-after: [START howto_operator_bigquery_check]
:end-before: [END howto_operator_bigquery_check]
+The below example shows the usage of :class:`~airflow.providers.google.cloud.operators.bigquery.BigQueryCheckAsyncOperator`,
+which is the deferrable version of the operator.
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/bigquery/example_bigquery_queries_async.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_bigquery_check_async]
+ :end-before: [END howto_operator_bigquery_check_async]
+
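+A minimal sketch of such a check (the table name and location are placeholders):
+
+.. code-block:: python
+
+    check_count = BigQueryCheckAsyncOperator(
+        task_id="check_count",
+        sql="SELECT COUNT(*) FROM my_dataset.my_table",
+        location="US",
+    )
+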
.. _howto/operator:BigQueryValueCheckOperator:
Compare query result to pass value
""""""""""""""""""""""""""""""""""
To perform a simple value check using sql code you can use
-:class:`~airflow.providers.google.cloud.operators.bigquery.BigQueryValueCheckOperator`.
+:class:`~airflow.providers.google.cloud.operators.bigquery.BigQueryValueCheckOperator` or
+:class:`~airflow.providers.google.cloud.operators.bigquery.BigQueryValueCheckAsyncOperator`.
-This operator expects a sql query that will return a single row. Each value on
+These operators expect a sql query that will return a single row. Each value on
that first row is evaluated against ``pass_value`` which can be either a string
or numeric value. If numeric, you can also specify ``tolerance``.
@@ -363,14 +398,26 @@ or numeric value. If numeric, you can also specify ``tolerance``.
:start-after: [START howto_operator_bigquery_value_check]
:end-before: [END howto_operator_bigquery_value_check]
+The below example shows how to use
+:class:`~airflow.providers.google.cloud.operators.bigquery.BigQueryValueCheckAsyncOperator`.
+Note that this is a deferrable operator which requires the Triggerer to be running on your Airflow
+deployment.
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/bigquery/example_bigquery_queries_async.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_bigquery_value_check_async]
+ :end-before: [END howto_operator_bigquery_value_check_async]
+
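+A minimal sketch of such a value check (the query and ``pass_value`` are
+placeholder values):
+
+.. code-block:: python
+
+    check_value = BigQueryValueCheckAsyncOperator(
+        task_id="check_value",
+        sql="SELECT COUNT(*) FROM my_dataset.my_table",
+        pass_value=100,
+        use_legacy_sql=False,
+    )
+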
.. _howto/operator:BigQueryIntervalCheckOperator:
Compare metrics over time
"""""""""""""""""""""""""
To check that the values of metrics given as SQL expressions are within a certain
-tolerance of the ones from ``days_back`` before you can use
-:class:`~airflow.providers.google.cloud.operators.bigquery.BigQueryIntervalCheckOperator`.
+tolerance of the ones from ``days_back`` before, you can use either
+:class:`~airflow.providers.google.cloud.operators.bigquery.BigQueryIntervalCheckOperator` or
+:class:`~airflow.providers.google.cloud.operators.bigquery.BigQueryIntervalCheckAsyncOperator`.
.. exampleinclude:: /../../tests/system/providers/google/cloud/bigquery/example_bigquery_queries.py
:language: python
@@ -378,6 +425,17 @@ tolerance of the ones from ``days_back`` before you can use
:start-after: [START howto_operator_bigquery_interval_check]
:end-before: [END howto_operator_bigquery_interval_check]
+The below example shows how to use
+:class:`~airflow.providers.google.cloud.operators.bigquery.BigQueryIntervalCheckAsyncOperator`.
+Note that this is a deferrable operator which requires the Triggerer to be running on your Airflow
+deployment.
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/bigquery/example_bigquery_queries_async.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_bigquery_interval_check_async]
+ :end-before: [END howto_operator_bigquery_interval_check_async]
+
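+A minimal sketch of such an interval check (the table name and threshold values
+are placeholders):
+
+.. code-block:: python
+
+    check_interval = BigQueryIntervalCheckAsyncOperator(
+        task_id="check_interval",
+        table="my_dataset.my_table",
+        days_back=7,
+        metrics_thresholds={"COUNT(*)": 1.5},
+        location="US",
+    )
+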
Sensors
^^^^^^^
@@ -396,6 +454,17 @@ use the ``{{ ds_nodash }}`` macro as the table name suffix.
:start-after: [START howto_sensor_bigquery_table]
:end-before: [END howto_sensor_bigquery_table]
+Use the :class:`~airflow.providers.google.cloud.sensors.bigquery.BigQueryTableExistenceAsyncSensor`
+(deferrable version) if you would like to free up the worker slots while the sensor is running.
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/bigquery/example_bigquery_sensors.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_sensor_async_bigquery_table]
+ :end-before: [END howto_sensor_async_bigquery_table]
+
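+A minimal sketch of the deferrable sensor (the project, dataset and table ids
+are placeholders):
+
+.. code-block:: python
+
+    check_table_exists = BigQueryTableExistenceAsyncSensor(
+        task_id="check_table_exists",
+        project_id="my-project",
+        dataset_id="my_dataset",
+        table_id="my_table",
+    )
+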
Check that a Table Partition exists
"""""""""""""""""""""""""""""""""""
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index 4d8bcd6eec..f7539f34ff 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -14,6 +14,7 @@ adhoc
adls
afterall
AgentKey
+aio
Airbnb
airbnb
Airbyte
@@ -71,6 +72,7 @@ asc
ascii
asciiart
asctime
+asend
asia
assertEqualIgnoreMultipleSpaces
assigment
diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index a4cda45f89..bc266ec27a 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -301,6 +301,9 @@
"PyOpenSSL",
"apache-airflow-providers-common-sql>=1.1.0",
"apache-airflow>=2.2.0",
+ "asgiref",
+ "gcloud-aio-bigquery",
+ "gcloud-aio-storage",
"google-ads>=15.1.1",
"google-api-core>=2.7.0,<3.0.0",
"google-api-python-client>=1.6.0,<2.0.0",
diff --git a/tests/providers/google/cloud/hooks/test_bigquery.py b/tests/providers/google/cloud/hooks/test_bigquery.py
index 6d69431f14..c2342dd8fd 100644
--- a/tests/providers/google/cloud/hooks/test_bigquery.py
+++ b/tests/providers/google/cloud/hooks/test_bigquery.py
@@ -18,20 +18,23 @@
import re
+import sys
import unittest
from datetime import datetime
-from unittest import mock
import pytest
+from gcloud.aio.bigquery import Job, Table as Table_async
from google.cloud.bigquery import DEFAULT_RETRY, DatasetReference, Table, TableReference
from google.cloud.bigquery.dataset import AccessEntry, Dataset, DatasetListItem
from google.cloud.exceptions import NotFound
from parameterized import parameterized
-from airflow import AirflowException
+from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.bigquery import (
+ BigQueryAsyncHook,
BigQueryCursor,
BigQueryHook,
+ BigQueryTableAsyncHook,
_api_resource_configs_duplication_check,
_cleanse_time_partitioning,
_format_schema_for_description,
@@ -40,6 +43,13 @@ from airflow.providers.google.cloud.hooks.bigquery import (
split_tablename,
)
+if sys.version_info < (3, 8):
+ from asynctest import mock
+ from asynctest.mock import CoroutineMock as AsyncMock
+else:
+ from unittest import mock
+ from unittest.mock import AsyncMock
+
PROJECT_ID = "bq-project"
CREDENTIALS = "bq-credentials"
DATASET_ID = "bq_dataset"
@@ -2011,7 +2021,6 @@ class TestBigQueryBaseCursorMethodsDeprecationWarning(unittest.TestCase):
class TestBigQueryWithLabelsAndDescription(_BigQueryBaseTestClass):
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job")
def test_run_load_labels(self, mock_insert):
-
labels = {'label1': 'test1', 'label2': 'test2'}
self.hook.run_load(
destination_project_dataset_table='my_dataset.my_table',
@@ -2025,7 +2034,6 @@ class TestBigQueryWithLabelsAndDescription(_BigQueryBaseTestClass):
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job")
def test_run_load_description(self, mock_insert):
-
description = "Test Description"
self.hook.run_load(
destination_project_dataset_table='my_dataset.my_table',
@@ -2039,7 +2047,6 @@ class TestBigQueryWithLabelsAndDescription(_BigQueryBaseTestClass):
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.create_empty_table")
def test_create_external_table_labels(self, mock_create):
-
labels = {'label1': 'test1', 'label2': 'test2'}
self.hook.create_external_table(
external_project_dataset_table='my_dataset.my_table',
@@ -2053,7 +2060,6 @@ class TestBigQueryWithLabelsAndDescription(_BigQueryBaseTestClass):
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.create_empty_table")
def test_create_external_table_description(self, mock_create):
-
description = "Test Description"
self.hook.create_external_table(
external_project_dataset_table='my_dataset.my_table',
@@ -2064,3 +2070,204 @@ class TestBigQueryWithLabelsAndDescription(_BigQueryBaseTestClass):
_, kwargs = mock_create.call_args
assert kwargs['table_resource']['description'] is description
+
+
+class _BigQueryBaseAsyncTestClass:
+ def setup_method(self) -> None:
+ class MockedBigQueryAsyncHook(BigQueryAsyncHook):
+ def get_credentials_and_project_id(self):
+ return CREDENTIALS, PROJECT_ID
+
+ self.hook = MockedBigQueryAsyncHook()
+
+
+class TestBigQueryAsyncHookMethods(_BigQueryBaseAsyncTestClass):
+ @pytest.mark.asyncio
+ @mock.patch("airflow.providers.google.cloud.hooks.bigquery.ClientSession")
+ async def test_get_job_instance(self, mock_session):
+ hook = BigQueryAsyncHook()
+ result = await hook.get_job_instance(project_id=PROJECT_ID, job_id=JOB_ID, session=mock_session)
+ assert isinstance(result, Job)
+
+ @pytest.mark.asyncio
+ @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_instance")
+ async def test_get_job_status_success(self, mock_job_instance):
+ hook = BigQueryAsyncHook()
+ mock_job_client = AsyncMock(Job)
+ mock_job_instance.return_value = mock_job_client
+ response = "success"
+ mock_job_instance.return_value.result.return_value = response
+ resp = await hook.get_job_status(job_id=JOB_ID, project_id=PROJECT_ID)
+ assert resp == response
+
+ @pytest.mark.asyncio
+ @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_instance")
+ async def test_get_job_status_oserror(self, mock_job_instance):
+ """Assets that the BigQueryAsyncHook returns a pending response when OSError is raised"""
+ mock_job_instance.return_value.result.side_effect = OSError()
+ hook = BigQueryAsyncHook()
+ job_status = await hook.get_job_status(job_id=JOB_ID, project_id=PROJECT_ID)
+ assert job_status == "pending"
+
+ @pytest.mark.asyncio
+ @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_instance")
+ async def test_get_job_status_exception(self, mock_job_instance, caplog):
+ """Assets that the logging is done correctly when BigQueryAsyncHook raises Exception"""
+ mock_job_instance.return_value.result.side_effect = Exception()
+ hook = BigQueryAsyncHook()
+ await hook.get_job_status(job_id=JOB_ID, project_id=PROJECT_ID)
+ assert "Query execution finished with errors..." in caplog.text
+
+ @pytest.mark.asyncio
+ @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_instance")
+ async def test_get_job_output_assert_once_with(self, mock_job_instance):
+ hook = BigQueryAsyncHook()
+ mock_job_client = AsyncMock(Job)
+ mock_job_instance.return_value = mock_job_client
+ response = "success"
+ mock_job_instance.return_value.get_query_results.return_value = response
+ resp = await hook.get_job_output(job_id=JOB_ID, project_id=PROJECT_ID)
+ assert resp == response
+
+ def test_interval_check_for_airflow_exception(self):
+ """
+ Assert that the check raises an AirflowException
+ """
+ hook = BigQueryAsyncHook()
+
+ row1, row2, metrics_thresholds, ignore_zero, ratio_formula = (
+ None,
+ "0",
+ {"COUNT(*)": 1.5},
+ True,
+ "max_over_min",
+ )
+ with pytest.raises(AirflowException):
+ hook.interval_check(row1, row2, metrics_thresholds, ignore_zero, ratio_formula)
+
+ row1, row2, metrics_thresholds, ignore_zero, ratio_formula = (
+ "0",
+ None,
+ {"COUNT(*)": 1.5},
+ True,
+ "max_over_min",
+ )
+ with pytest.raises(AirflowException):
+ hook.interval_check(row1, row2, metrics_thresholds, ignore_zero, ratio_formula)
+
+ row1, row2, metrics_thresholds, ignore_zero, ratio_formula = (
+ "1",
+ "1",
+ {"COUNT(*)": 0},
+ True,
+ "max_over_min",
+ )
+ with pytest.raises(AirflowException):
+ hook.interval_check(row1, row2, metrics_thresholds, ignore_zero, ratio_formula)
+
+ def test_interval_check_for_success(self):
+ """
+ Assert that the check returns None
+ """
+ hook = BigQueryAsyncHook()
+
+ row1, row2, metrics_thresholds, ignore_zero, ratio_formula = (
+ "0",
+ "0",
+ {"COUNT(*)": 1.5},
+ True,
+ "max_over_min",
+ )
+ response = hook.interval_check(row1, row2, metrics_thresholds, ignore_zero, ratio_formula)
+ assert response is None
+
+ @pytest.mark.asyncio
+ @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_instance")
+ async def test_get_job_output(self, mock_job_instance):
+ """
+ Tests that get_job_output returns the query results from the mocked job client
+ """
+ response = {
+ "kind": "bigquery#tableDataList",
+ "etag": "test_etag",
+ "schema": {"fields": [{"name": "f0_", "type": "INTEGER", "mode": "NULLABLE"}]},
+ "jobReference": {
+ "projectId": "test_astronomer-airflow-providers",
+ "jobId": "test_jobid",
+ "location": "US",
+ },
+ "totalRows": "10",
+ "rows": [{"f": [{"v": "42"}, {"v": "monthy python"}]}, {"f": [{"v": "42"}, {"v": "fishy fish"}]}],
+ "totalBytesProcessed": "0",
+ "jobComplete": True,
+ "cacheHit": False,
+ }
+ hook = BigQueryAsyncHook()
+ mock_job_client = AsyncMock(Job)
+ mock_job_instance.return_value = mock_job_client
+ mock_job_client.get_query_results.return_value = response
+ resp = await hook.get_job_output(job_id=JOB_ID, project_id=PROJECT_ID)
+ assert resp == response
+
+ @pytest.mark.parametrize(
+ "records,pass_value,tolerance", [(["str"], "str", None), ([2], 2, None), ([0], 2, 1), ([4], 2, 1)]
+ )
+ def test_value_check_success(self, records, pass_value, tolerance):
+ """
+ Assert that the value_check method execution succeeds
+ """
+ hook = BigQueryAsyncHook()
+ query = "SELECT COUNT(*) from Any"
+ response = hook.value_check(query, pass_value, records, tolerance)
+ assert response is None
+
+ @pytest.mark.parametrize(
+ "records,pass_value,tolerance",
+ [([], "", None), (["str"], "str1", None), ([2], 21, None), ([5], 2, 1), (["str"], 2, None)],
+ )
+ def test_value_check_fail(self, records, pass_value, tolerance):
+ """Assert that check raise AirflowException"""
+ hook = BigQueryAsyncHook()
+ query = "SELECT COUNT(*) from Any"
+
+ with pytest.raises(AirflowException) as ex:
+ hook.value_check(query, pass_value, records, tolerance)
+ assert isinstance(ex.value, AirflowException)
+
+ @pytest.mark.parametrize(
+ "records,pass_value,tolerance, expected",
+ [
+ ([2.0], 2.0, None, [True]),
+ ([2.0], 2.1, None, [False]),
+ ([2.0], 2.0, 0.5, [True]),
+ ([1.0], 2.0, 0.5, [True]),
+ ([3.0], 2.0, 0.5, [True]),
+ ([0.9], 2.0, 0.5, [False]),
+ ([3.1], 2.0, 0.5, [False]),
+ ],
+ )
+ def test_get_numeric_matches(self, records, pass_value, tolerance, expected):
+ """Assert the if response list have all element match with pass_value with tolerance"""
+
+ assert BigQueryAsyncHook._get_numeric_matches(records, pass_value, tolerance) == expected
+
+ @pytest.mark.parametrize("test_input,expected", [(5.0, 5.0), (5, 5.0), ("5", 5), ("str", "str")])
+ def test_convert_to_float_if_possible(self, test_input, expected):
+ """
+ Assert that type casting succeeds for values that can be converted,
+ and that the same value is returned otherwise
+ """
+
+ assert BigQueryAsyncHook._convert_to_float_if_possible(test_input) == expected
+
+ @pytest.mark.asyncio
+ @mock.patch("aiohttp.client.ClientSession")
+ async def test_get_table_client(self, mock_session):
+ """Test get_table_client async function and check whether the return value is a
+ Table instance object"""
+ hook = BigQueryTableAsyncHook()
+ result = await hook.get_table_client(
+ dataset=DATASET_ID, project_id=PROJECT_ID, table_id=TABLE_ID, session=mock_session
+ )
+ assert isinstance(result, Table_async)
diff --git a/tests/providers/google/cloud/operators/test_bigquery.py b/tests/providers/google/cloud/operators/test_bigquery.py
index 7d53a017c1..4aebe77cd4 100644
--- a/tests/providers/google/cloud/operators/test_bigquery.py
+++ b/tests/providers/google/cloud/operators/test_bigquery.py
@@ -23,8 +23,12 @@ import pytest
from google.cloud.bigquery import DEFAULT_RETRY
from google.cloud.exceptions import Conflict
-from airflow.exceptions import AirflowException, AirflowTaskTimeout
+from airflow.exceptions import AirflowException, AirflowTaskTimeout, TaskDeferred
+from airflow.models import DAG
+from airflow.models.dagrun import DagRun
+from airflow.models.taskinstance import TaskInstance
from airflow.providers.google.cloud.operators.bigquery import (
+ BigQueryCheckAsyncOperator,
BigQueryCheckOperator,
BigQueryConsoleIndexableLink,
BigQueryConsoleLink,
@@ -34,20 +38,32 @@ from airflow.providers.google.cloud.operators.bigquery import (
BigQueryDeleteDatasetOperator,
BigQueryDeleteTableOperator,
BigQueryExecuteQueryOperator,
+ BigQueryGetDataAsyncOperator,
BigQueryGetDataOperator,
BigQueryGetDatasetOperator,
BigQueryGetDatasetTablesOperator,
+ BigQueryInsertJobAsyncOperator,
BigQueryInsertJobOperator,
+ BigQueryIntervalCheckAsyncOperator,
BigQueryIntervalCheckOperator,
BigQueryPatchDatasetOperator,
BigQueryUpdateDatasetOperator,
BigQueryUpdateTableOperator,
BigQueryUpdateTableSchemaOperator,
BigQueryUpsertTableOperator,
+ BigQueryValueCheckAsyncOperator,
BigQueryValueCheckOperator,
)
+from airflow.providers.google.cloud.triggers.bigquery import (
+ BigQueryCheckTrigger,
+ BigQueryGetDataTrigger,
+ BigQueryInsertJobTrigger,
+ BigQueryIntervalCheckTrigger,
+ BigQueryValueCheckTrigger,
+)
from airflow.serialization.serialized_objects import SerializedDAG
from airflow.utils.timezone import datetime
+from airflow.utils.types import DagRunType
from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags, clear_db_xcom
TASK_ID = 'test-bq-generic-operator'
@@ -71,6 +87,7 @@ MATERIALIZED_VIEW_DEFINITION = {
'enableRefresh': True,
'refreshIntervalMs': 2000000,
}
+TEST_TABLE = "test-table"
class TestBigQueryCreateEmptyTableOperator(unittest.TestCase):
@@ -1103,3 +1120,536 @@ class TestBigQueryInsertJobOperator:
# No force rerun
with pytest.raises(AirflowException):
op.execute(context=MagicMock())
+
+
+@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
+def test_bigquery_insert_job_operator_async(mock_hook):
+ """
+ Asserts that a task is deferred and a BigQueryInsertJobTrigger will be fired
+ when the BigQueryInsertJobAsyncOperator is executed.
+ """
+ job_id = "123456"
+ hash_ = "hash"
+ real_job_id = f"{job_id}_{hash_}"
+
+ configuration = {
+ "query": {
+ "query": "SELECT * FROM any",
+ "useLegacySql": False,
+ }
+ }
+ mock_hook.return_value.insert_job.return_value = MagicMock(job_id=real_job_id, error_result=False)
+
+ op = BigQueryInsertJobAsyncOperator(
+ task_id="insert_query_job",
+ configuration=configuration,
+ location=TEST_DATASET_LOCATION,
+ job_id=job_id,
+ project_id=TEST_GCP_PROJECT_ID,
+ )
+
+ with pytest.raises(TaskDeferred) as exc:
+ op.execute(create_context(op))
+
+ assert isinstance(
+ exc.value.trigger, BigQueryInsertJobTrigger
+ ), "Trigger is not a BigQueryInsertJobTrigger"
+
+
+def test_bigquery_insert_job_operator_execute_failure():
+ """Tests that an AirflowException is raised in case of error event"""
+ configuration = {
+ "query": {
+ "query": "SELECT * FROM any",
+ "useLegacySql": False,
+ }
+ }
+ job_id = "123456"
+
+ operator = BigQueryInsertJobAsyncOperator(
+ task_id="insert_query_job",
+ configuration=configuration,
+ location=TEST_DATASET_LOCATION,
+ job_id=job_id,
+ project_id=TEST_GCP_PROJECT_ID,
+ )
+
+ with pytest.raises(AirflowException):
+ operator.execute_complete(context=None, event={"status": "error", "message": "test failure message"})
+
+
+def create_context(task):
+ dag = DAG(dag_id="dag")
+ logical_date = datetime(2022, 1, 1, 0, 0, 0)
+ dag_run = DagRun(
+ dag_id=dag.dag_id,
+ execution_date=logical_date,
+ run_id=DagRun.generate_run_id(DagRunType.MANUAL, logical_date),
+ )
+ task_instance = TaskInstance(task=task)
+ task_instance.dag_run = dag_run
+ task_instance.dag_id = dag.dag_id
+ task_instance.xcom_push = mock.Mock()
+ return {
+ "dag": dag,
+ "run_id": dag_run.run_id,
+ "task": task,
+ "ti": task_instance,
+ "task_instance": task_instance,
+ "logical_date": logical_date,
+ }
+
+
+def test_bigquery_insert_job_operator_execute_complete():
+ """Asserts that logging occurs as expected"""
+ configuration = {
+ "query": {
+ "query": "SELECT * FROM any",
+ "useLegacySql": False,
+ }
+ }
+ job_id = "123456"
+
+ operator = BigQueryInsertJobAsyncOperator(
+ task_id="insert_query_job",
+ configuration=configuration,
+ location=TEST_DATASET_LOCATION,
+ job_id=job_id,
+ project_id=TEST_GCP_PROJECT_ID,
+ )
+ with mock.patch.object(operator.log, "info") as mock_log_info:
+ operator.execute_complete(
+ context=create_context(operator),
+ event={"status": "success", "message": "Job completed", "job_id": job_id},
+ )
+ mock_log_info.assert_called_with("%s completed with response %s ", "insert_query_job", "Job completed")
+
+
+@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
+def test_bigquery_insert_job_operator_with_job_id_generate(mock_hook):
+ job_id = "123456"
+ hash_ = "hash"
+ real_job_id = f"{job_id}_{hash_}"
+
+ configuration = {
+ "query": {
+ "query": "SELECT * FROM any",
+ "useLegacySql": False,
+ }
+ }
+
+ mock_hook.return_value.insert_job.side_effect = Conflict("any")
+ job = MagicMock(
+ job_id=real_job_id,
+ error_result=False,
+ state="PENDING",
+ done=lambda: False,
+ )
+ mock_hook.return_value.get_job.return_value = job
+
+ op = BigQueryInsertJobAsyncOperator(
+ task_id="insert_query_job",
+ configuration=configuration,
+ location=TEST_DATASET_LOCATION,
+ job_id=job_id,
+ project_id=TEST_GCP_PROJECT_ID,
+ reattach_states={"PENDING"},
+ )
+
+ with pytest.raises(TaskDeferred):
+ op.execute(create_context(op))
+
+ mock_hook.return_value.generate_job_id.assert_called_once_with(
+ job_id=job_id,
+ dag_id="adhoc_airflow",
+ task_id="insert_query_job",
+ logical_date=datetime(2022, 1, 1, 0, 0),
+ configuration=configuration,
+ force_rerun=True,
+ )
+
+
+@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
+def test_execute_reattach(mock_hook):
+ job_id = "123456"
+ hash_ = "hash"
+ real_job_id = f"{job_id}_{hash_}"
+ mock_hook.return_value.generate_job_id.return_value = f"{job_id}_{hash_}"
+
+ configuration = {
+ "query": {
+ "query": "SELECT * FROM any",
+ "useLegacySql": False,
+ }
+ }
+
+ mock_hook.return_value.insert_job.side_effect = Conflict("any")
+ job = MagicMock(
+ job_id=real_job_id,
+ error_result=False,
+ state="PENDING",
+ done=lambda: False,
+ )
+ mock_hook.return_value.get_job.return_value = job
+
+ op = BigQueryInsertJobAsyncOperator(
+ task_id="insert_query_job",
+ configuration=configuration,
+ location=TEST_DATASET_LOCATION,
+ job_id=job_id,
+ project_id=TEST_GCP_PROJECT_ID,
+ reattach_states={"PENDING"},
+ )
+
+ with pytest.raises(TaskDeferred):
+ op.execute(create_context(op))
+
+ mock_hook.return_value.get_job.assert_called_once_with(
+ location=TEST_DATASET_LOCATION,
+ job_id=real_job_id,
+ project_id=TEST_GCP_PROJECT_ID,
+ )
+
+ job._begin.assert_called_once_with()
+
+
+@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
+def test_execute_force_rerun(mock_hook):
+ job_id = "123456"
+ hash_ = "hash"
+ real_job_id = f"{job_id}_{hash_}"
+ mock_hook.return_value.generate_job_id.return_value = f"{job_id}_{hash_}"
+
+ configuration = {
+ "query": {
+ "query": "SELECT * FROM any",
+ "useLegacySql": False,
+ }
+ }
+
+ mock_hook.return_value.insert_job.side_effect = Conflict("any")
+ job = MagicMock(
+ job_id=real_job_id,
+ error_result=False,
+ state="DONE",
+ done=lambda: False,
+ )
+ mock_hook.return_value.get_job.return_value = job
+
+ op = BigQueryInsertJobAsyncOperator(
+ task_id="insert_query_job",
+ configuration=configuration,
+ location=TEST_DATASET_LOCATION,
+ job_id=job_id,
+ project_id=TEST_GCP_PROJECT_ID,
+ reattach_states={"PENDING"},
+ )
+
+ with pytest.raises(AirflowException) as exc:
+ op.execute(create_context(op))
+
+ expected_exception_msg = (
+ f"Job with id: {real_job_id} already exists and is in {job.state} state. "
+ f"If you want to force rerun it consider setting `force_rerun=True`."
+ f"Or, if you want to reattach in this scenario add {job.state} to `reattach_states`"
+ )
+
+ assert str(exc.value) == expected_exception_msg
+
+ mock_hook.return_value.get_job.assert_called_once_with(
+ location=TEST_DATASET_LOCATION,
+ job_id=real_job_id,
+ project_id=TEST_GCP_PROJECT_ID,
+ )
+
+
+@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
+def test_bigquery_check_operator_async(mock_hook):
+ """
+ Asserts that a task is deferred and a BigQueryCheckTrigger will be fired
+ when the BigQueryCheckAsyncOperator is executed.
+ """
+ job_id = "123456"
+ hash_ = "hash"
+ real_job_id = f"{job_id}_{hash_}"
+
+ mock_hook.return_value.insert_job.return_value = MagicMock(job_id=real_job_id, error_result=False)
+
+ op = BigQueryCheckAsyncOperator(
+ task_id="bq_check_operator_job",
+ sql="SELECT * FROM any",
+ location=TEST_DATASET_LOCATION,
+ )
+
+ with pytest.raises(TaskDeferred) as exc:
+ op.execute(create_context(op))
+
+ assert isinstance(exc.value.trigger, BigQueryCheckTrigger), "Trigger is not a BigQueryCheckTrigger"
+
+
+def test_bigquery_check_operator_execute_failure():
+ """Tests that an AirflowException is raised in case of error event"""
+
+ operator = BigQueryCheckAsyncOperator(
+ task_id="bq_check_operator_execute_failure", sql="SELECT * FROM any", location=TEST_DATASET_LOCATION
+ )
+
+ with pytest.raises(AirflowException):
+ operator.execute_complete(context=None, event={"status": "error", "message": "test failure message"})
+
+
+def test_bigquery_check_op_execute_complete_with_no_records():
+ """Asserts that exception is raised with correct expected exception message"""
+
+ operator = BigQueryCheckAsyncOperator(
+ task_id="bq_check_operator_execute_complete", sql="SELECT * FROM any", location=TEST_DATASET_LOCATION
+ )
+
+ with pytest.raises(AirflowException) as exc:
+ operator.execute_complete(context=None, event={"status": "success", "records": None})
+
+ expected_exception_msg = "The query returned None"
+
+ assert str(exc.value) == expected_exception_msg
+
+
+def test_bigquery_check_op_execute_complete_with_non_boolean_records():
+ """Executing a sql which returns a non-boolean value should raise exception"""
+
+ test_sql = "SELECT * FROM any"
+
+ operator = BigQueryCheckAsyncOperator(
+ task_id="bq_check_operator_execute_complete", sql=test_sql, location=TEST_DATASET_LOCATION
+ )
+
+ expected_exception_msg = f"Test failed.\nQuery:\n{test_sql}\nResults:\n{[20, False]!s}"
+
+ with pytest.raises(AirflowException) as exc:
+ operator.execute_complete(context=None, event={"status": "success", "records": [20, False]})
+
+ assert str(exc.value) == expected_exception_msg
+
+
+def test_bigquery_check_operator_execute_complete():
+ """Asserts that logging occurs as expected"""
+
+ operator = BigQueryCheckAsyncOperator(
+ task_id="bq_check_operator_execute_complete", sql="SELECT * FROM any", location=TEST_DATASET_LOCATION
+ )
+
+ with mock.patch.object(operator.log, "info") as mock_log_info:
+ operator.execute_complete(context=None, event={"status": "success", "records": [20]})
+ mock_log_info.assert_called_with("Success.")
+
+
+def test_bigquery_interval_check_operator_execute_complete():
+ """Asserts that logging occurs as expected"""
+
+ operator = BigQueryIntervalCheckAsyncOperator(
+ task_id="bq_interval_check_operator_execute_complete",
+ table="test_table",
+ metrics_thresholds={"COUNT(*)": 1.5},
+ location=TEST_DATASET_LOCATION,
+ )
+
+ with mock.patch.object(operator.log, "info") as mock_log_info:
+ operator.execute_complete(context=None, event={"status": "success", "message": "Job completed"})
+ mock_log_info.assert_called_with(
+ "%s completed with response %s ", "bq_interval_check_operator_execute_complete", "success"
+ )
+
+
+def test_bigquery_interval_check_operator_execute_failure():
+ """Tests that an AirflowException is raised in case of error event"""
+
+ operator = BigQueryIntervalCheckAsyncOperator(
+ task_id="bq_interval_check_operator_execute_complete",
+ table="test_table",
+ metrics_thresholds={"COUNT(*)": 1.5},
+ location=TEST_DATASET_LOCATION,
+ )
+
+ with pytest.raises(AirflowException):
+ operator.execute_complete(context=None, event={"status": "error", "message": "test failure message"})
+
+
+@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
+def test_bigquery_interval_check_operator_async(mock_hook):
+ """
+ Asserts that a task is deferred and a BigQueryIntervalCheckTrigger will be fired
+ when the BigQueryIntervalCheckAsyncOperator is executed.
+ """
+ job_id = "123456"
+ hash_ = "hash"
+ real_job_id = f"{job_id}_{hash_}"
+
+ mock_hook.return_value.insert_job.return_value = MagicMock(job_id=real_job_id, error_result=False)
+
+ op = BigQueryIntervalCheckAsyncOperator(
+ task_id="bq_interval_check_operator_execute_complete",
+ table="test_table",
+ metrics_thresholds={"COUNT(*)": 1.5},
+ location=TEST_DATASET_LOCATION,
+ )
+
+ with pytest.raises(TaskDeferred) as exc:
+ op.execute(create_context(op))
+
+ assert isinstance(
+ exc.value.trigger, BigQueryIntervalCheckTrigger
+ ), "Trigger is not a BigQueryIntervalCheckTrigger"
+
+
+@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
+def test_bigquery_get_data_operator_async_with_selected_fields(mock_hook):
+ """
+ Asserts that a task is deferred and a BigQueryGetDataTrigger will be fired
+ when the BigQueryGetDataAsyncOperator is executed.
+ """
+ job_id = "123456"
+ hash_ = "hash"
+ real_job_id = f"{job_id}_{hash_}"
+
+ mock_hook.return_value.insert_job.return_value = MagicMock(job_id=real_job_id, error_result=False)
+
+ op = BigQueryGetDataAsyncOperator(
+ task_id="get_data_from_bq",
+ dataset_id=TEST_DATASET,
+ table_id=TEST_TABLE,
+ max_results=100,
+ selected_fields="value,name",
+ )
+
+ with pytest.raises(TaskDeferred) as exc:
+ op.execute(create_context(op))
+
+ assert isinstance(exc.value.trigger, BigQueryGetDataTrigger), "Trigger is not a BigQueryGetDataTrigger"
+
+
+@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
+def test_bigquery_get_data_operator_async_without_selected_fields(mock_hook):
+ """
+ Asserts that a task is deferred and a BigQueryGetDataTrigger will be fired
+ when the BigQueryGetDataAsyncOperator is executed.
+ """
+ job_id = "123456"
+ hash_ = "hash"
+ real_job_id = f"{job_id}_{hash_}"
+
+ mock_hook.return_value.insert_job.return_value = MagicMock(job_id=real_job_id, error_result=False)
+
+ op = BigQueryGetDataAsyncOperator(
+ task_id="get_data_from_bq",
+ dataset_id=TEST_DATASET,
+ table_id=TEST_TABLE,
+ max_results=100,
+ )
+
+ with pytest.raises(TaskDeferred) as exc:
+ op.execute(create_context(op))
+
+ assert isinstance(exc.value.trigger, BigQueryGetDataTrigger), "Trigger is not a BigQueryGetDataTrigger"
+
+
+def test_bigquery_get_data_operator_execute_failure():
+ """Tests that an AirflowException is raised in case of error event"""
+
+ operator = BigQueryGetDataAsyncOperator(
+ task_id="get_data_from_bq",
+ dataset_id=TEST_DATASET,
+ table_id="any",
+ max_results=100,
+ )
+
+ with pytest.raises(AirflowException):
+ operator.execute_complete(context=None, event={"status": "error", "message": "test failure message"})
+
+
+def test_bigquery_get_data_op_execute_complete_with_records():
+ """Asserts that exception is raised with correct expected exception message"""
+
+ operator = BigQueryGetDataAsyncOperator(
+ task_id="get_data_from_bq",
+ dataset_id=TEST_DATASET,
+ table_id="any",
+ max_results=100,
+ )
+
+ with mock.patch.object(operator.log, "info") as mock_log_info:
+ operator.execute_complete(context=None, event={"status": "success", "records": [20]})
+ mock_log_info.assert_called_with("Total extracted rows: %s", 1)
+
+
+def _get_value_check_async_operator(use_legacy_sql: bool = False):
+ """Helper function to initialise BigQueryValueCheckOperatorAsync operator"""
+ query = "SELECT COUNT(*) FROM Any"
+ pass_val = 2
+
+ return BigQueryValueCheckAsyncOperator(
+ task_id="check_value",
+ sql=query,
+ pass_value=pass_val,
+ use_legacy_sql=use_legacy_sql,
+ )
+
+
+@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
+def test_bigquery_value_check_async(mock_hook):
+ """
+ Asserts that a task is deferred and a BigQueryValueCheckTrigger will be fired
+ when the BigQueryValueCheckAsyncOperator is executed.
+ """
+ operator = _get_value_check_async_operator(True)
+ job_id = "123456"
+ hash_ = "hash"
+ real_job_id = f"{job_id}_{hash_}"
+ mock_hook.return_value.insert_job.return_value = MagicMock(job_id=real_job_id, error_result=False)
+ with pytest.raises(TaskDeferred) as exc:
+ operator.execute(create_context(operator))
+
+ assert isinstance(
+ exc.value.trigger, BigQueryValueCheckTrigger
+ ), "Trigger is not a BigQueryValueCheckTrigger"
+
+
+def test_bigquery_value_check_operator_execute_complete_success():
+ """Tests response message in case of success event"""
+ operator = _get_value_check_async_operator()
+
+ assert (
+ operator.execute_complete(context=None, event={"status": "success", "message": "Job completed!"})
+ is None
+ )
+
+
+def test_bigquery_value_check_operator_execute_complete_failure():
+ """Tests that an AirflowException is raised in case of error event"""
+ operator = _get_value_check_async_operator()
+
+ with pytest.raises(AirflowException):
+ operator.execute_complete(context=None, event={"status": "error", "message": "test failure message"})
+
+
+@pytest.mark.parametrize(
+ "kwargs, expected",
+ [
+ ({"sql": "SELECT COUNT(*) from Any"}, "missing keyword argument 'pass_value'"),
+ ({"pass_value": "Any"}, "missing keyword argument 'sql'"),
+ ],
+)
+def test_bigquery_value_check_missing_param(kwargs, expected):
+ """Assert the exception if require param not pass to BigQueryValueCheckOperatorAsync operator"""
+ with pytest.raises(AirflowException) as missing_param:
+ BigQueryValueCheckAsyncOperator(**kwargs)
+ assert missing_param.value.args[0] == expected
+
+
+def test_bigquery_value_check_empty():
+ """Assert the exception if require param not pass to BigQueryValueCheckOperatorAsync operator"""
+ expected, expected1 = (
+ "missing keyword arguments 'sql', 'pass_value'",
+ "missing keyword arguments 'pass_value', 'sql'",
+ )
+ with pytest.raises(AirflowException) as missing_param:
+ BigQueryValueCheckAsyncOperator(kwargs={})
+ assert (missing_param.value.args[0] == expected) or (missing_param.value.args[0] == expected1)
diff --git a/tests/providers/google/cloud/sensors/test_bigquery.py b/tests/providers/google/cloud/sensors/test_bigquery.py
index 87ec3dbacb..5ea3b9b67a 100644
--- a/tests/providers/google/cloud/sensors/test_bigquery.py
+++ b/tests/providers/google/cloud/sensors/test_bigquery.py
@@ -17,10 +17,15 @@
from unittest import TestCase, mock
+import pytest
+
+from airflow.exceptions import AirflowException, TaskDeferred
from airflow.providers.google.cloud.sensors.bigquery import (
+ BigQueryTableExistenceAsyncSensor,
BigQueryTableExistenceSensor,
BigQueryTablePartitionExistenceSensor,
)
+from airflow.providers.google.cloud.triggers.bigquery import BigQueryTableExistenceTrigger
TEST_PROJECT_ID = "test_project"
TEST_DATASET_ID = 'test_dataset'
@@ -87,3 +92,66 @@ class TestBigqueryTablePartitionExistenceSensor(TestCase):
table_id=TEST_TABLE_ID,
partition_id=TEST_PARTITION_ID,
)
+
+
+@pytest.fixture()
+def context():
+ """
+ Creates an empty context.
+ """
+ context = {}
+ yield context
+
+
+class TestBigQueryTableExistenceAsyncSensor(TestCase):
+ def test_big_query_table_existence_sensor_async(self):
+ """
+ Asserts that a task is deferred and a BigQueryTableExistenceTrigger will be fired
+ when the BigQueryTableExistenceAsyncSensor is executed.
+ """
+ task = BigQueryTableExistenceAsyncSensor(
+ task_id="check_table_exists",
+ project_id=TEST_PROJECT_ID,
+ dataset_id=TEST_DATASET_ID,
+ table_id=TEST_TABLE_ID,
+ )
+ with pytest.raises(TaskDeferred) as exc:
+ task.execute(context={})
+ assert isinstance(
+ exc.value.trigger, BigQueryTableExistenceTrigger
+ ), "Trigger is not a BigQueryTableExistenceTrigger"
+
+ def test_big_query_table_existence_sensor_async_execute_failure(self):
+ """Tests that an AirflowException is raised in case of error event"""
+ task = BigQueryTableExistenceAsyncSensor(
+ task_id="task-id",
+ project_id=TEST_PROJECT_ID,
+ dataset_id=TEST_DATASET_ID,
+ table_id=TEST_TABLE_ID,
+ )
+ with pytest.raises(AirflowException):
+ task.execute_complete(context={}, event={"status": "error", "message": "test failure message"})
+
+ def test_big_query_table_existence_sensor_async_execute_complete(self):
+ """Asserts that logging occurs as expected"""
+ task = BigQueryTableExistenceAsyncSensor(
+ task_id="task-id",
+ project_id=TEST_PROJECT_ID,
+ dataset_id=TEST_DATASET_ID,
+ table_id=TEST_TABLE_ID,
+ )
+ table_uri = f"{TEST_PROJECT_ID}:{TEST_DATASET_ID}.{TEST_TABLE_ID}"
+ with mock.patch.object(task.log, "info") as mock_log_info:
+ task.execute_complete(context={}, event={"status": "success", "message": "Job completed"})
+ mock_log_info.assert_called_with("Sensor checks existence of table: %s", table_uri)
+
+ def test_big_query_sensor_async_execute_complete_event_none(self):
+ """Asserts that logging occurs as expected"""
+ task = BigQueryTableExistenceAsyncSensor(
+ task_id="task-id",
+ project_id=TEST_PROJECT_ID,
+ dataset_id=TEST_DATASET_ID,
+ table_id=TEST_TABLE_ID,
+ )
+ with pytest.raises(AirflowException):
+ task.execute_complete(context={}, event=None)
diff --git a/tests/providers/google/cloud/triggers/__init__.py b/tests/providers/google/cloud/triggers/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/tests/providers/google/cloud/triggers/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/providers/google/cloud/triggers/test_bigquery.py b/tests/providers/google/cloud/triggers/test_bigquery.py
new file mode 100644
index 0000000000..7a771a124c
--- /dev/null
+++ b/tests/providers/google/cloud/triggers/test_bigquery.py
@@ -0,0 +1,1040 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import asyncio
+import logging
+import sys
+from typing import Any, Dict
+
+import pytest
+from aiohttp import ClientResponseError, RequestInfo
+from gcloud.aio.bigquery import Table
+from multidict import CIMultiDict
+from yarl import URL
+
+from airflow.providers.google.cloud.hooks.bigquery import BigQueryTableAsyncHook
+from airflow.providers.google.cloud.triggers.bigquery import (
+ BigQueryCheckTrigger,
+ BigQueryGetDataTrigger,
+ BigQueryInsertJobTrigger,
+ BigQueryIntervalCheckTrigger,
+ BigQueryTableExistenceTrigger,
+ BigQueryValueCheckTrigger,
+)
+from airflow.triggers.base import TriggerEvent
+
+if sys.version_info < (3, 8):
+ from asynctest import mock
+ from asynctest.mock import CoroutineMock as AsyncMock
+else:
+ from unittest import mock
+ from unittest.mock import AsyncMock
+
+TEST_CONN_ID = "bq_default"
+TEST_JOB_ID = "1234"
+RUN_ID = "1"
+RETRY_LIMIT = 2
+RETRY_DELAY = 1.0
+TEST_GCP_PROJECT_ID = "test-project"
+TEST_DATASET_ID = "bq_dataset"
+TEST_TABLE_ID = "bq_table"
+POLLING_PERIOD_SECONDS = 4.0
+TEST_SQL_QUERY = "SELECT count(*) from Any"
+TEST_PASS_VALUE = 2
+TEST_TOLERANCE = 1
+TEST_FIRST_JOB_ID = "5678"
+TEST_SECOND_JOB_ID = "6789"
+TEST_METRIC_THRESHOLDS: Dict[str, int] = {}
+TEST_DATE_FILTER_COLUMN = "ds"
+TEST_DAYS_BACK = -7
+TEST_RATIO_FORMULA = "max_over_min"
+TEST_IGNORE_ZERO = True
+TEST_GCP_CONN_ID = "TEST_GCP_CONN_ID"
+TEST_HOOK_PARAMS: Dict[str, Any] = {}
+
+
+def test_bigquery_insert_job_op_trigger_serialization():
+ """
+ Asserts that the BigQueryInsertJobTrigger correctly serializes its arguments
+ and classpath.
+ """
+ trigger = BigQueryInsertJobTrigger(
+ TEST_CONN_ID,
+ TEST_JOB_ID,
+ TEST_GCP_PROJECT_ID,
+ TEST_DATASET_ID,
+ TEST_TABLE_ID,
+ POLLING_PERIOD_SECONDS,
+ )
+ classpath, kwargs = trigger.serialize()
+ assert classpath == "airflow.providers.google.cloud.triggers.bigquery.BigQueryInsertJobTrigger"
+ assert kwargs == {
+ "conn_id": TEST_CONN_ID,
+ "job_id": TEST_JOB_ID,
+ "project_id": TEST_GCP_PROJECT_ID,
+ "dataset_id": TEST_DATASET_ID,
+ "table_id": TEST_TABLE_ID,
+ "poll_interval": POLLING_PERIOD_SECONDS,
+ }
+
+
+@pytest.mark.asyncio
+@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status")
+async def test_bigquery_insert_job_op_trigger_success(mock_job_status):
+ """
+ Tests the BigQueryInsertJobTrigger only fires once the query execution reaches a successful state.
+ """
+ mock_job_status.return_value = "success"
+
+ trigger = BigQueryInsertJobTrigger(
+ TEST_CONN_ID,
+ TEST_JOB_ID,
+ TEST_GCP_PROJECT_ID,
+ TEST_DATASET_ID,
+ TEST_TABLE_ID,
+ POLLING_PERIOD_SECONDS,
+ )
+
+ generator = trigger.run()
+ actual = await generator.asend(None)
+ assert TriggerEvent({"status": "success", "message": "Job completed", "job_id": TEST_JOB_ID}) == actual
+
+
+@pytest.mark.asyncio
+@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_instance")
+async def test_bigquery_insert_job_trigger_running(mock_job_instance, caplog):
+ """
+ Test that BigQuery Triggers do not fire while a query is still running.
+ """
+
+ from gcloud.aio.bigquery import Job
+
+ mock_job_client = AsyncMock(Job)
+ mock_job_instance.return_value = mock_job_client
+ mock_job_instance.return_value.result.side_effect = OSError
+ caplog.set_level(logging.INFO)
+
+ trigger = BigQueryInsertJobTrigger(
+ conn_id=TEST_CONN_ID,
+ job_id=TEST_JOB_ID,
+ project_id=TEST_GCP_PROJECT_ID,
+ dataset_id=TEST_DATASET_ID,
+ table_id=TEST_TABLE_ID,
+ poll_interval=POLLING_PERIOD_SECONDS,
+ )
+ task = asyncio.create_task(trigger.run().__anext__())
+ await asyncio.sleep(0.5)
+
+ # TriggerEvent was not returned
+ assert task.done() is False
+
+ assert f"Using the connection {TEST_CONN_ID} ." in caplog.text
+
+ assert "Query is still running..." in caplog.text
+ assert f"Sleeping for {POLLING_PERIOD_SECONDS} seconds." in caplog.text
+
+ # Prevents error when task is destroyed while in "pending" state
+ asyncio.get_event_loop().stop()
+
+
+@pytest.mark.asyncio
+@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_instance")
+async def test_bigquery_get_data_trigger_running(mock_job_instance, caplog):
+ """
+ Test that BigQuery Triggers do not fire while a query is still running.
+ """
+
+ from gcloud.aio.bigquery import Job
+
+ mock_job_client = AsyncMock(Job)
+ mock_job_instance.return_value = mock_job_client
+ mock_job_instance.return_value.result.side_effect = OSError
+ caplog.set_level(logging.INFO)
+
+ trigger = BigQueryGetDataTrigger(
+ conn_id=TEST_CONN_ID,
+ job_id=TEST_JOB_ID,
+ project_id=TEST_GCP_PROJECT_ID,
+ dataset_id=TEST_DATASET_ID,
+ table_id=TEST_TABLE_ID,
+ poll_interval=POLLING_PERIOD_SECONDS,
+ )
+ task = asyncio.create_task(trigger.run().__anext__())
+ await asyncio.sleep(0.5)
+
+ # TriggerEvent was not returned
+ assert task.done() is False
+
+ assert f"Using the connection {TEST_CONN_ID} ." in caplog.text
+
+ assert "Query is still running..." in caplog.text
+ assert f"Sleeping for {POLLING_PERIOD_SECONDS} seconds." in caplog.text
+
+ # Prevents error when task is destroyed while in "pending" state
+ asyncio.get_event_loop().stop()
+
+
+@pytest.mark.asyncio
+@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_instance")
+async def test_bigquery_check_trigger_running(mock_job_instance, caplog):
+ """
+ Test that BigQuery Triggers do not fire while a query is still running.
+ """
+
+ from gcloud.aio.bigquery import Job
+
+ mock_job_client = AsyncMock(Job)
+ mock_job_instance.return_value = mock_job_client
+ mock_job_instance.return_value.result.side_effect = OSError
+ caplog.set_level(logging.INFO)
+
+ trigger = BigQueryCheckTrigger(
+ conn_id=TEST_CONN_ID,
+ job_id=TEST_JOB_ID,
+ project_id=TEST_GCP_PROJECT_ID,
+ dataset_id=TEST_DATASET_ID,
+ table_id=TEST_TABLE_ID,
+ poll_interval=POLLING_PERIOD_SECONDS,
+ )
+ task = asyncio.create_task(trigger.run().__anext__())
+ await asyncio.sleep(0.5)
+
+ # TriggerEvent was not returned
+ assert task.done() is False
+
+ assert f"Using the connection {TEST_CONN_ID} ." in caplog.text
+
+ assert "Query is still running..." in caplog.text
+ assert f"Sleeping for {POLLING_PERIOD_SECONDS} seconds." in caplog.text
+
+ # Prevents error when task is destroyed while in "pending" state
+ asyncio.get_event_loop().stop()
+
+
+@pytest.mark.asyncio
+@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status")
+async def test_bigquery_op_trigger_terminated(mock_job_status, caplog):
+ """
+ Test that BigQuery Triggers fire the correct event in case of an error.
+ """
+ # Set the status to a value other than success or pending
+
+ mock_job_status.return_value = "error"
+
+ trigger = BigQueryInsertJobTrigger(
+ conn_id=TEST_CONN_ID,
+ job_id=TEST_JOB_ID,
+ project_id=TEST_GCP_PROJECT_ID,
+ dataset_id=TEST_DATASET_ID,
+ table_id=TEST_TABLE_ID,
+ poll_interval=POLLING_PERIOD_SECONDS,
+ )
+
+ generator = trigger.run()
+ actual = await generator.asend(None)
+ assert TriggerEvent({"status": "error", "message": "error"}) == actual
+
+
+@pytest.mark.asyncio
+@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status")
+async def test_bigquery_check_trigger_terminated(mock_job_status, caplog):
+ """
+ Test that BigQuery Triggers fire the correct event in case of an error.
+ """
+ # Set the status to a value other than success or pending
+
+ mock_job_status.return_value = "error"
+
+ trigger = BigQueryCheckTrigger(
+ conn_id=TEST_CONN_ID,
+ job_id=TEST_JOB_ID,
+ project_id=TEST_GCP_PROJECT_ID,
+ dataset_id=TEST_DATASET_ID,
+ table_id=TEST_TABLE_ID,
+ poll_interval=POLLING_PERIOD_SECONDS,
+ )
+
+ generator = trigger.run()
+ actual = await generator.asend(None)
+ assert TriggerEvent({"status": "error", "message": "error"}) == actual
+
+
+@pytest.mark.asyncio
+@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status")
+async def test_bigquery_get_data_trigger_terminated(mock_job_status, caplog):
+ """
+ Test that BigQuery Triggers fire the correct event in case of an error.
+ """
+ # Set the status to a value other than success or pending
+
+ mock_job_status.return_value = "error"
+
+ trigger = BigQueryGetDataTrigger(
+ conn_id=TEST_CONN_ID,
+ job_id=TEST_JOB_ID,
+ project_id=TEST_GCP_PROJECT_ID,
+ dataset_id=TEST_DATASET_ID,
+ table_id=TEST_TABLE_ID,
+ poll_interval=POLLING_PERIOD_SECONDS,
+ )
+
+ generator = trigger.run()
+ actual = await generator.asend(None)
+ assert TriggerEvent({"status": "error", "message": "error"}) == actual
+
+
+@pytest.mark.asyncio
+@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status")
+async def test_bigquery_op_trigger_exception(mock_job_status, caplog):
+ """
+ Test that BigQuery Triggers fire the correct event in case of an error.
+ """
+ mock_job_status.side_effect = Exception("Test exception")
+
+ trigger = BigQueryInsertJobTrigger(
+ conn_id=TEST_CONN_ID,
+ job_id=TEST_JOB_ID,
+ project_id=TEST_GCP_PROJECT_ID,
+ dataset_id=TEST_DATASET_ID,
+ table_id=TEST_TABLE_ID,
+ poll_interval=POLLING_PERIOD_SECONDS,
+ )
+
+ generator = trigger.run()
+ actual = await generator.asend(None)
+ assert TriggerEvent({"status": "error", "message": "Test exception"}) == actual
+
+
+@pytest.mark.asyncio
+@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status")
+async def test_bigquery_check_trigger_exception(mock_job_status, caplog):
+ """
+ Test that BigQuery Triggers fire the correct event in case of an error.
+ """
+ mock_job_status.side_effect = Exception("Test exception")
+
+ trigger = BigQueryCheckTrigger(
+ conn_id=TEST_CONN_ID,
+ job_id=TEST_JOB_ID,
+ project_id=TEST_GCP_PROJECT_ID,
+ dataset_id=TEST_DATASET_ID,
+ table_id=TEST_TABLE_ID,
+ poll_interval=POLLING_PERIOD_SECONDS,
+ )
+
+ generator = trigger.run()
+ actual = await generator.asend(None)
+ assert TriggerEvent({"status": "error", "message": "Test exception"}) == actual
+
+
+@pytest.mark.asyncio
+@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status")
+async def test_bigquery_get_data_trigger_exception(mock_job_status, caplog):
+ """
+ Test that BigQuery Triggers fire the correct event in case of an error.
+ """
+ mock_job_status.side_effect = Exception("Test exception")
+
+ trigger = BigQueryGetDataTrigger(
+ conn_id=TEST_CONN_ID,
+ job_id=TEST_JOB_ID,
+ project_id=TEST_GCP_PROJECT_ID,
+ dataset_id=TEST_DATASET_ID,
+ table_id=TEST_TABLE_ID,
+ poll_interval=POLLING_PERIOD_SECONDS,
+ )
+
+ generator = trigger.run()
+ actual = await generator.asend(None)
+ assert TriggerEvent({"status": "error", "message": "Test exception"}) == actual
+
+
+def test_bigquery_check_op_trigger_serialization():
+ """
+ Asserts that the BigQueryCheckTrigger correctly serializes its arguments
+ and classpath.
+ """
+ trigger = BigQueryCheckTrigger(
+ TEST_CONN_ID,
+ TEST_JOB_ID,
+ TEST_GCP_PROJECT_ID,
+ TEST_DATASET_ID,
+ TEST_TABLE_ID,
+ POLLING_PERIOD_SECONDS,
+ )
+ classpath, kwargs = trigger.serialize()
+ assert classpath == "airflow.providers.google.cloud.triggers.bigquery.BigQueryCheckTrigger"
+ assert kwargs == {
+ "conn_id": TEST_CONN_ID,
+ "job_id": TEST_JOB_ID,
+ "dataset_id": TEST_DATASET_ID,
+ "project_id": TEST_GCP_PROJECT_ID,
+ "table_id": TEST_TABLE_ID,
+ "poll_interval": POLLING_PERIOD_SECONDS,
+ }
+
+
+@pytest.mark.asyncio
+@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status")
+@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_output")
+async def test_bigquery_check_op_trigger_success_with_data(mock_job_output, mock_job_status):
+ """
+ Test the BigQueryCheckTrigger only fires once the query execution reaches a successful state.
+ """
+ mock_job_status.return_value = "success"
+ mock_job_output.return_value = {
+ "kind": "bigquery#getQueryResultsResponse",
+ "etag": "test_etag",
+ "schema": {"fields": [{"name": "f0_", "type": "INTEGER", "mode": "NULLABLE"}]},
+ "jobReference": {
+ "projectId": "test_airflow-providers",
+ "jobId": "test_jobid",
+ "location": "US",
+ },
+ "totalRows": "1",
+ "rows": [{"f": [{"v": "22"}]}],
+ "totalBytesProcessed": "0",
+ "jobComplete": True,
+ "cacheHit": False,
+ }
+
+ trigger = BigQueryCheckTrigger(
+ TEST_CONN_ID,
+ TEST_JOB_ID,
+ TEST_GCP_PROJECT_ID,
+ TEST_DATASET_ID,
+ TEST_TABLE_ID,
+ POLLING_PERIOD_SECONDS,
+ )
+
+ generator = trigger.run()
+ actual = await generator.asend(None)
+
+ assert TriggerEvent({"status": "success", "records": ["22"]}) == actual
+
+
+@pytest.mark.asyncio
+@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status")
+@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_output")
+async def test_bigquery_check_op_trigger_success_without_data(mock_job_output, mock_job_status):
+ """
+ Tests that BigQueryCheckTrigger sends TriggerEvent as { "status": "success", "records": None}
+ when no rows are available in the query result.
+ """
+ mock_job_status.return_value = "success"
+ mock_job_output.return_value = {
+ "kind": "bigquery#getQueryResultsResponse",
+ "etag": "test_etag",
+ "schema": {
+ "fields": [
+ {"name": "value", "type": "INTEGER", "mode": "NULLABLE"},
+ {"name": "name", "type": "STRING", "mode": "NULLABLE"},
+ {"name": "ds", "type": "DATE", "mode": "NULLABLE"},
+ ]
+ },
+ "jobReference": {
+ "projectId": "test_airflow-airflow-providers",
+ "jobId": "test_jobid",
+ "location": "US",
+ },
+ "totalRows": "0",
+ "totalBytesProcessed": "0",
+ "jobComplete": True,
+ "cacheHit": False,
+ }
+
+ trigger = BigQueryCheckTrigger(
+ TEST_CONN_ID,
+ TEST_JOB_ID,
+ TEST_GCP_PROJECT_ID,
+ TEST_DATASET_ID,
+ TEST_TABLE_ID,
+ POLLING_PERIOD_SECONDS,
+ )
+ generator = trigger.run()
+ actual = await generator.asend(None)
+ assert TriggerEvent({"status": "success", "records": None}) == actual
+
+
+def test_bigquery_get_data_trigger_serialization():
+ """
+ Asserts that the BigQueryGetDataTrigger correctly serializes its arguments
+ and classpath.
+ """
+ trigger = BigQueryGetDataTrigger(
+ conn_id=TEST_CONN_ID,
+ job_id=TEST_JOB_ID,
+ project_id=TEST_GCP_PROJECT_ID,
+ dataset_id=TEST_DATASET_ID,
+ table_id=TEST_TABLE_ID,
+ poll_interval=POLLING_PERIOD_SECONDS,
+ )
+ classpath, kwargs = trigger.serialize()
+ assert classpath == "airflow.providers.google.cloud.triggers.bigquery.BigQueryGetDataTrigger"
+ assert kwargs == {
+ "conn_id": TEST_CONN_ID,
+ "job_id": TEST_JOB_ID,
+ "dataset_id": TEST_DATASET_ID,
+ "project_id": TEST_GCP_PROJECT_ID,
+ "table_id": TEST_TABLE_ID,
+ "poll_interval": POLLING_PERIOD_SECONDS,
+ }
+
+
+@pytest.mark.asyncio
+@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status")
+@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_output")
+async def test_bigquery_get_data_trigger_success_with_data(mock_job_output, mock_job_status):
+ """
+ Tests that BigQueryGetDataTrigger only fires once the query execution reaches a successful state.
+ """
+ mock_job_status.return_value = "success"
+ mock_job_output.return_value = {
+ "kind": "bigquery#tableDataList",
+ "etag": "test_etag",
+ "schema": {"fields": [{"name": "f0_", "type": "INTEGER", "mode": "NULLABLE"}]},
+ "jobReference": {
+ "projectId": "test-airflow-providers",
+ "jobId": "test_jobid",
+ "location": "US",
+ },
+ "totalRows": "10",
+ "rows": [{"f": [{"v": "42"}, {"v": "monthy python"}]}, {"f": [{"v": "42"}, {"v": "fishy fish"}]}],
+ "totalBytesProcessed": "0",
+ "jobComplete": True,
+ "cacheHit": False,
+ }
+
+ trigger = BigQueryGetDataTrigger(
+ TEST_CONN_ID,
+ TEST_JOB_ID,
+ TEST_GCP_PROJECT_ID,
+ TEST_DATASET_ID,
+ TEST_TABLE_ID,
+ POLLING_PERIOD_SECONDS,
+ )
+
+ generator = trigger.run()
+ actual = await generator.asend(None)
+ # The extracted row will be parsed and formatted to retrieve the value from the
+ # structure - "rows": [{"f": [{"v": "42"}, {"v": "monthy python"}]}, {"f": [{"v": "42"}, {"v": "fishy fish"}]}]
+
+ assert (
+ TriggerEvent(
+ {
+ "status": "success",
+ "message": "success",
+ "records": [["42", "monthy python"], ["42", "fishy fish"]],
+ }
+ )
+ == actual
+ )
+ # Prevents error when task is destroyed while in "pending" state
+ asyncio.get_event_loop().stop()
+
+
+def test_bigquery_interval_check_trigger_serialization():
+ """
+ Asserts that the BigQueryIntervalCheckTrigger correctly serializes its arguments
+ and classpath.
+ """
+ trigger = BigQueryIntervalCheckTrigger(
+ TEST_CONN_ID,
+ TEST_FIRST_JOB_ID,
+ TEST_SECOND_JOB_ID,
+ TEST_GCP_PROJECT_ID,
+ TEST_TABLE_ID,
+ TEST_METRIC_THRESHOLDS,
+ TEST_DATE_FILTER_COLUMN,
+ TEST_DAYS_BACK,
+ TEST_RATIO_FORMULA,
+ TEST_IGNORE_ZERO,
+ TEST_DATASET_ID,
+ TEST_TABLE_ID,
+ POLLING_PERIOD_SECONDS,
+ )
+ classpath, kwargs = trigger.serialize()
+ assert classpath == "airflow.providers.google.cloud.triggers.bigquery.BigQueryIntervalCheckTrigger"
+ assert kwargs == {
+ "conn_id": TEST_CONN_ID,
+ "first_job_id": TEST_FIRST_JOB_ID,
+ "second_job_id": TEST_SECOND_JOB_ID,
+ "project_id": TEST_GCP_PROJECT_ID,
+ "table": TEST_TABLE_ID,
+ "metrics_thresholds": TEST_METRIC_THRESHOLDS,
+ "date_filter_column": TEST_DATE_FILTER_COLUMN,
+ "days_back": TEST_DAYS_BACK,
+ "ratio_formula": TEST_RATIO_FORMULA,
+ "ignore_zero": TEST_IGNORE_ZERO,
+ }
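+ # Note: dataset_id, table_id and poll_interval are accepted by the trigger's
+ # constructor but are not round-tripped by serialize().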
+
+
+@pytest.mark.asyncio
+@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status")
+@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_output")
+async def test_bigquery_interval_check_trigger_success(mock_get_job_output, mock_job_status):
+ """
+ Tests that the BigQueryIntervalCheckTrigger fires an error event when the second job's query returns no records.
+ """
+ mock_job_status.return_value = "success"
+ mock_get_job_output.return_value = ["0"]
+
+ trigger = BigQueryIntervalCheckTrigger(
+ conn_id=TEST_CONN_ID,
+ first_job_id=TEST_FIRST_JOB_ID,
+ second_job_id=TEST_SECOND_JOB_ID,
+ project_id=TEST_GCP_PROJECT_ID,
+ table=TEST_TABLE_ID,
+ metrics_thresholds=TEST_METRIC_THRESHOLDS,
+ date_filter_column=TEST_DATE_FILTER_COLUMN,
+ days_back=TEST_DAYS_BACK,
+ ratio_formula=TEST_RATIO_FORMULA,
+ ignore_zero=TEST_IGNORE_ZERO,
+ dataset_id=TEST_DATASET_ID,
+ table_id=TEST_TABLE_ID,
+ poll_interval=POLLING_PERIOD_SECONDS,
+ )
+
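+ # get_job_output is mocked with a payload that contains no "rows" key, so the
+ # trigger finds no records for the second job and reports an error event.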
+ generator = trigger.run()
+ actual = await generator.asend(None)
+ assert actual == TriggerEvent({"status": "error", "message": "The second SQL query returned None"})
+
+
+@pytest.mark.asyncio
+@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status")
+async def test_bigquery_interval_check_trigger_pending(mock_job_status, caplog):
+ """
+ Tests that the BigQueryIntervalCheckTrigger does not fire while the query is still running.
+ """
+ mock_job_status.return_value = "pending"
+ caplog.set_level(logging.INFO)
+
+ trigger = BigQueryIntervalCheckTrigger(
+ conn_id=TEST_CONN_ID,
+ first_job_id=TEST_FIRST_JOB_ID,
+ second_job_id=TEST_SECOND_JOB_ID,
+ project_id=TEST_GCP_PROJECT_ID,
+ table=TEST_TABLE_ID,
+ metrics_thresholds=TEST_METRIC_THRESHOLDS,
+ date_filter_column=TEST_DATE_FILTER_COLUMN,
+ days_back=TEST_DAYS_BACK,
+ ratio_formula=TEST_RATIO_FORMULA,
+ ignore_zero=TEST_IGNORE_ZERO,
+ dataset_id=TEST_DATASET_ID,
+ table_id=TEST_TABLE_ID,
+ poll_interval=POLLING_PERIOD_SECONDS,
+ )
+ task = asyncio.create_task(trigger.run().__anext__())
+ await asyncio.sleep(0.5)
+
+ # TriggerEvent was not returned
+ assert task.done() is False
+
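+ # The assertion string (including the space before the period) must match
+ # the hook's log message verbatim.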
+ assert f"Using the connection {TEST_CONN_ID} ." in caplog.text
+
+ assert "Query is still running..." in caplog.text
+ assert f"Sleeping for {POLLING_PERIOD_SECONDS} seconds." in caplog.text
+
+ # Prevents error when task is destroyed while in "pending" state
+ asyncio.get_event_loop().stop()
+
+
+@pytest.mark.asyncio
+@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status")
+async def test_bigquery_interval_check_trigger_terminated(mock_job_status):
+ """
+ Tests that the BigQueryIntervalCheckTrigger fires an error event when the job finishes in a non-success, non-pending state.
+ """
+ # Set the status to a value other than success or pending
+ mock_job_status.return_value = "error"
+ trigger = BigQueryIntervalCheckTrigger(
+ conn_id=TEST_CONN_ID,
+ first_job_id=TEST_FIRST_JOB_ID,
+ second_job_id=TEST_SECOND_JOB_ID,
+ project_id=TEST_GCP_PROJECT_ID,
+ table=TEST_TABLE_ID,
+ metrics_thresholds=TEST_METRIC_THRESHOLDS,
+ date_filter_column=TEST_DATE_FILTER_COLUMN,
+ days_back=TEST_DAYS_BACK,
+ ratio_formula=TEST_RATIO_FORMULA,
+ ignore_zero=TEST_IGNORE_ZERO,
+ dataset_id=TEST_DATASET_ID,
+ table_id=TEST_TABLE_ID,
+ poll_interval=POLLING_PERIOD_SECONDS,
+ )
+
+ generator = trigger.run()
+ actual = await generator.asend(None)
+
+ assert TriggerEvent({"status": "error", "message": "error", "data": None}) == actual
+
+
+@pytest.mark.asyncio
+@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status")
+async def test_bigquery_interval_check_trigger_exception(mock_job_status, caplog):
+ """
+ Tests that the BigQueryIntervalCheckTrigger fires a single error event when the hook raises an exception.
+ """
+ mock_job_status.side_effect = Exception("Test exception")
+ caplog.set_level(logging.DEBUG)
+
+ trigger = BigQueryIntervalCheckTrigger(
+ conn_id=TEST_CONN_ID,
+ first_job_id=TEST_FIRST_JOB_ID,
+ second_job_id=TEST_SECOND_JOB_ID,
+ project_id=TEST_GCP_PROJECT_ID,
+ table=TEST_TABLE_ID,
+ metrics_thresholds=TEST_METRIC_THRESHOLDS,
+ date_filter_column=TEST_DATE_FILTER_COLUMN,
+ days_back=TEST_DAYS_BACK,
+ ratio_formula=TEST_RATIO_FORMULA,
+ ignore_zero=TEST_IGNORE_ZERO,
+ dataset_id=TEST_DATASET_ID,
+ table_id=TEST_TABLE_ID,
+ poll_interval=POLLING_PERIOD_SECONDS,
+ )
+
+ # trigger.run() yields a single TriggerEvent and then returns, so an
+ # async comprehension over it should collect exactly one event.
+ task = [i async for i in trigger.run()]
+
+ assert len(task) == 1
+ assert TriggerEvent({"status": "error", "message": "Test exception"}) in task
+
+
+def test_bigquery_value_check_op_trigger_serialization():
+ """
+ Asserts that the BigQueryValueCheckTrigger correctly serializes its arguments
+ and classpath.
+ """
+
+ trigger = BigQueryValueCheckTrigger(
+ conn_id=TEST_CONN_ID,
+ pass_value=TEST_PASS_VALUE,
+ job_id=TEST_JOB_ID,
+ dataset_id=TEST_DATASET_ID,
+ project_id=TEST_GCP_PROJECT_ID,
+ sql=TEST_SQL_QUERY,
+ table_id=TEST_TABLE_ID,
+ tolerance=TEST_TOLERANCE,
+ poll_interval=POLLING_PERIOD_SECONDS,
+ )
+ classpath, kwargs = trigger.serialize()
+
+ assert classpath == "airflow.providers.google.cloud.triggers.bigquery.BigQueryValueCheckTrigger"
+ assert kwargs == {
+ "conn_id": TEST_CONN_ID,
+ "pass_value": TEST_PASS_VALUE,
+ "job_id": TEST_JOB_ID,
+ "dataset_id": TEST_DATASET_ID,
+ "project_id": TEST_GCP_PROJECT_ID,
+ "sql": TEST_SQL_QUERY,
+ "table_id": TEST_TABLE_ID,
+ "tolerance": TEST_TOLERANCE,
+ "poll_interval": POLLING_PERIOD_SECONDS,
+ }
+
+
+@pytest.mark.asyncio
+@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_records")
+@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_output")
+@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status")
+async def test_bigquery_value_check_op_trigger_success(mock_job_status, get_job_output, get_records):
+ """
+ Tests that the BigQueryValueCheckTrigger only fires once the query execution reaches a successful state.
+ """
+ mock_job_status.return_value = "success"
+ get_job_output.return_value = {}
+ get_records.return_value = [[2], [4]]
+
+ trigger = BigQueryValueCheckTrigger(
+ conn_id=TEST_CONN_ID,
+ pass_value=TEST_PASS_VALUE,
+ job_id=TEST_JOB_ID,
+ dataset_id=TEST_DATASET_ID,
+ project_id=TEST_GCP_PROJECT_ID,
+ sql=TEST_SQL_QUERY,
+ table_id=TEST_TABLE_ID,
+ tolerance=TEST_TOLERANCE,
+ poll_interval=POLLING_PERIOD_SECONDS,
+ )
+
+ generator = trigger.run()
+ actual = await generator.asend(None)
+ assert actual == TriggerEvent({"status": "success", "message": "Job completed", "records": [4]})
+
+
+@pytest.mark.asyncio
+@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status")
+async def test_bigquery_value_check_op_trigger_pending(mock_job_status, caplog):
+ """
+ Tests that the BigQueryValueCheckTrigger does not fire while the query is still running.
+ """
+ mock_job_status.return_value = "pending"
+ caplog.set_level(logging.INFO)
+
+ trigger = BigQueryValueCheckTrigger(
+ TEST_CONN_ID,
+ TEST_PASS_VALUE,
+ TEST_JOB_ID,
+ TEST_DATASET_ID,
+ TEST_GCP_PROJECT_ID,
+ TEST_SQL_QUERY,
+ TEST_TABLE_ID,
+ TEST_TOLERANCE,
+ POLLING_PERIOD_SECONDS,
+ )
+
+ task = asyncio.create_task(trigger.run().__anext__())
+ await asyncio.sleep(0.5)
+
+ # TriggerEvent was not returned
+ assert task.done() is False
+
+ assert "Query is still running..." in caplog.text
+
+ assert f"Sleeping for {POLLING_PERIOD_SECONDS} seconds." in caplog.text
+
+ # Prevents error when task is destroyed while in "pending" state
+ asyncio.get_event_loop().stop()
+
+
+@pytest.mark.asyncio
+@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status")
+async def test_bigquery_value_check_op_trigger_fail(mock_job_status):
+ """
+ Tests that the BigQueryValueCheckTrigger fires an error event when the job finishes in an unexpected state.
+ """
+ mock_job_status.return_value = "dummy"
+
+ trigger = BigQueryValueCheckTrigger(
+ TEST_CONN_ID,
+ TEST_PASS_VALUE,
+ TEST_JOB_ID,
+ TEST_DATASET_ID,
+ TEST_GCP_PROJECT_ID,
+ TEST_SQL_QUERY,
+ TEST_TABLE_ID,
+ TEST_TOLERANCE,
+ POLLING_PERIOD_SECONDS,
+ )
+
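+ # Any terminal status other than "success" or "pending" is surfaced
+ # verbatim as the error message.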
+ generator = trigger.run()
+ actual = await generator.asend(None)
+ assert TriggerEvent({"status": "error", "message": "dummy", "records": None}) == actual
+
+
+@pytest.mark.asyncio
+@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status")
+async def test_bigquery_value_check_trigger_exception(mock_job_status):
+ """
+ Tests that the BigQueryValueCheckTrigger fires a single error event when the hook raises an exception.
+ """
+ mock_job_status.side_effect = Exception("Test exception")
+
+ trigger = BigQueryValueCheckTrigger(
+ conn_id=TEST_CONN_ID,
+ sql=TEST_SQL_QUERY,
+ pass_value=TEST_PASS_VALUE,
+ tolerance=1,
+ job_id=TEST_JOB_ID,
+ project_id=TEST_GCP_PROJECT_ID,
+ )
+
+ # trigger.run() yields a single TriggerEvent and then returns, so an
+ # async comprehension over it should collect exactly one event.
+ task = [i async for i in trigger.run()]
+
+ assert len(task) == 1
+ assert TriggerEvent({"status": "error", "message": "Test exception"}) in task
+
+
+def test_big_query_table_existence_trigger_serialization():
+ """
+ Asserts that the BigQueryTableExistenceTrigger correctly serializes its arguments
+ and classpath.
+ """
+ trigger = BigQueryTableExistenceTrigger(
+ TEST_GCP_PROJECT_ID,
+ TEST_DATASET_ID,
+ TEST_TABLE_ID,
+ TEST_GCP_CONN_ID,
+ TEST_HOOK_PARAMS,
+ POLLING_PERIOD_SECONDS,
+ )
+ classpath, kwargs = trigger.serialize()
+ assert classpath == "airflow.providers.google.cloud.triggers.bigquery.BigQueryTableExistenceTrigger"
+ assert kwargs == {
+ "dataset_id": TEST_DATASET_ID,
+ "project_id": TEST_GCP_PROJECT_ID,
+ "table_id": TEST_TABLE_ID,
+ "gcp_conn_id": TEST_GCP_CONN_ID,
+ "poll_interval": POLLING_PERIOD_SECONDS,
+ "hook_params": TEST_HOOK_PARAMS,
+ }
+
+
+@pytest.mark.asyncio
+@mock.patch("airflow.providers.google.cloud.triggers.bigquery.BigQueryTableExistenceTrigger._table_exists")
+async def test_big_query_table_existence_trigger_success(mock_table_exists):
+ """
+ Tests the success case for BigQueryTableExistenceTrigger.
+ """
+ mock_table_exists.return_value = True
+
+ trigger = BigQueryTableExistenceTrigger(
+ TEST_GCP_PROJECT_ID,
+ TEST_DATASET_ID,
+ TEST_TABLE_ID,
+ TEST_GCP_CONN_ID,
+ TEST_HOOK_PARAMS,
+ POLLING_PERIOD_SECONDS,
+ )
+
+ generator = trigger.run()
+ actual = await generator.asend(None)
+ assert TriggerEvent({"status": "success", "message": "success"}) == actual
+
+
+@pytest.mark.asyncio
+@mock.patch("airflow.providers.google.cloud.triggers.bigquery.BigQueryTableExistenceTrigger._table_exists")
+async def test_big_query_table_existence_trigger_pending(mock_table_exists):
+ """
+ Tests that BigQueryTableExistenceTrigger keeps polling while the table does not exist.
+ """
+ mock_table_exists.return_value = False
+
+ trigger = BigQueryTableExistenceTrigger(
+ TEST_GCP_PROJECT_ID,
+ TEST_DATASET_ID,
+ TEST_TABLE_ID,
+ TEST_GCP_CONN_ID,
+ TEST_HOOK_PARAMS,
+ POLLING_PERIOD_SECONDS,
+ )
+ task = asyncio.create_task(trigger.run().__anext__())
+ await asyncio.sleep(0.5)
+
+ # TriggerEvent was not returned
+ assert task.done() is False
+ asyncio.get_event_loop().stop()
+
+
+@pytest.mark.asyncio
+@mock.patch("airflow.providers.google.cloud.triggers.bigquery.BigQueryTableExistenceTrigger._table_exists")
+async def test_big_query_table_existence_trigger_exception(mock_table_exists):
+ """
+ Tests that BigQueryTableExistenceTrigger fires an error event when an exception is raised.
+ """
+ mock_table_exists.side_effect = AsyncMock(side_effect=Exception("Test exception"))
+
+ trigger = BigQueryTableExistenceTrigger(
+ TEST_GCP_PROJECT_ID,
+ TEST_DATASET_ID,
+ TEST_TABLE_ID,
+ TEST_GCP_CONN_ID,
+ TEST_HOOK_PARAMS,
+ POLLING_PERIOD_SECONDS,
+ )
+ task = [i async for i in trigger.run()]
+ assert len(task) == 1
+ assert TriggerEvent({"status": "error", "message": "Test exception"}) in task
+
+
+@pytest.mark.asyncio
+@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryTableAsyncHook.get_table_client")
+async def test_table_exists(mock_get_table_client):
+ """Test BigQueryTableExistenceTrigger._table_exists async function with mocked value
+ and mocked return value"""
+ hook = BigQueryTableAsyncHook()
+ mock_get_table_client.return_value = AsyncMock(Table)
+ trigger = BigQueryTableExistenceTrigger(
+ TEST_GCP_PROJECT_ID,
+ TEST_DATASET_ID,
+ TEST_TABLE_ID,
+ TEST_GCP_CONN_ID,
+ TEST_HOOK_PARAMS,
+ POLLING_PERIOD_SECONDS,
+ )
+ res = await trigger._table_exists(hook, TEST_DATASET_ID, TEST_TABLE_ID, TEST_GCP_PROJECT_ID)
+ assert res is True
+
+
+@pytest.mark.asyncio
+@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryTableAsyncHook.get_table_client")
+async def test_table_exists_exception(mock_get_table_client):
+ """Test BigQueryTableExistenceTrigger._table_exists async function with exception and return False"""
+ hook = BigQueryTableAsyncHook()
+ mock_get_table_client.side_effect = ClientResponseError(
+ history=(),
+ request_info=RequestInfo(
+ headers=CIMultiDict(),
+ real_url=URL("https://example.com"),
+ method="GET",
+ url=URL("https://example.com"),
+ ),
+ status=404,
+ message="Not Found",
+ )
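+ # A 404 from the table client means the table does not exist yet, so
+ # _table_exists should swallow it and return False.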
+ trigger = BigQueryTableExistenceTrigger(
+ TEST_GCP_PROJECT_ID,
+ TEST_DATASET_ID,
+ TEST_TABLE_ID,
+ TEST_GCP_CONN_ID,
+ TEST_HOOK_PARAMS,
+ POLLING_PERIOD_SECONDS,
+ )
+ res = await trigger._table_exists(hook, TEST_DATASET_ID, TEST_TABLE_ID, TEST_GCP_PROJECT_ID)
+ expected_response = False
+ assert res == expected_response
+
+
+@pytest.mark.asyncio
+@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryTableAsyncHook.get_table_client")
+async def test_table_exists_raise_exception(mock_get_table_client):
+ """Test BigQueryTableExistenceTrigger._table_exists async function with raise exception"""
+ hook = BigQueryTableAsyncHook()
+ mock_get_table_client.side_effect = ClientResponseError(
+ history=(),
+ request_info=RequestInfo(
+ headers=CIMultiDict(),
+ real_url=URL("https://example.com"),
+ method="GET",
+ url=URL("https://example.com"),
+ ),
+ status=400,
+ message="Not Found",
+ )
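+ # Any ClientResponseError other than 404 (here a 400) should propagate.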
+ trigger = BigQueryTableExistenceTrigger(
+ TEST_GCP_PROJECT_ID,
+ TEST_DATASET_ID,
+ TEST_TABLE_ID,
+ TEST_GCP_CONN_ID,
+ TEST_HOOK_PARAMS,
+ POLLING_PERIOD_SECONDS,
+ )
+ with pytest.raises(ClientResponseError):
+ await trigger._table_exists(hook, TEST_DATASET_ID, TEST_TABLE_ID, TEST_GCP_PROJECT_ID)
diff --git a/tests/system/providers/google/cloud/bigquery/example_bigquery_queries_async.py b/tests/system/providers/google/cloud/bigquery/example_bigquery_queries_async.py
new file mode 100644
index 0000000000..36e4844807
--- /dev/null
+++ b/tests/system/providers/google/cloud/bigquery/example_bigquery_queries_async.py
@@ -0,0 +1,251 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Example Airflow DAG for Google BigQuery service.
+Uses Async version of the Big Query Operators
+
+"""
+import os
+from datetime import datetime, timedelta
+
+from airflow import DAG
+from airflow.operators.bash import BashOperator
+from airflow.operators.empty import EmptyOperator
+from airflow.providers.google.cloud.operators.bigquery import (
+ BigQueryCheckAsyncOperator,
+ BigQueryCreateEmptyDatasetOperator,
+ BigQueryCreateEmptyTableOperator,
+ BigQueryDeleteDatasetOperator,
+ BigQueryGetDataAsyncOperator,
+ BigQueryInsertJobAsyncOperator,
+ BigQueryIntervalCheckAsyncOperator,
+ BigQueryValueCheckAsyncOperator,
+)
+from airflow.utils.trigger_rule import TriggerRule
+
+ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
+PROJECT_ID = os.getenv("SYSTEM_TESTS_GCP_PROJECT")
+DAG_ID = "bigquery_queries_async"
+DATASET_NAME = f"dataset_{DAG_ID}_{ENV_ID}"
+LOCATION = "us"
+EXECUTION_TIMEOUT = 6
+
+TABLE_1 = "table1"
+TABLE_2 = "table2"
+
+SCHEMA = [
+ {"name": "value", "type": "INTEGER", "mode": "REQUIRED"},
+ {"name": "name", "type": "STRING", "mode": "NULLABLE"},
+ {"name": "ds", "type": "STRING", "mode": "NULLABLE"},
+]
+
+DATASET = DATASET_NAME
+INSERT_DATE = datetime.now().strftime("%Y-%m-%d")
+INSERT_ROWS_QUERY = (
+ f"INSERT {DATASET}.{TABLE_1} VALUES "
+ f"(42, 'monthy python', '{INSERT_DATE}'), "
+ f"(42, 'fishy fish', '{INSERT_DATE}');"
+)
+
+default_args = {
+ "execution_timeout": timedelta(hours=EXECUTION_TIMEOUT),
+ "retries": int(os.getenv("DEFAULT_TASK_RETRIES", 2)),
+ "retry_delay": timedelta(seconds=int(os.getenv("DEFAULT_RETRY_DELAY_SECONDS", 60))),
+}
+
+with DAG(
+ dag_id="example_async_bigquery_queries_async",
+ schedule=None,
+ start_date=datetime(2022, 1, 1),
+ catchup=False,
+ default_args=default_args,
+ tags=["example", "async", "bigquery"],
+ user_defined_macros={"DATASET": DATASET, "TABLE": TABLE_1},
+) as dag:
+ create_dataset = BigQueryCreateEmptyDatasetOperator(
+ task_id="create_dataset",
+ dataset_id=DATASET,
+ location=LOCATION,
+ )
+
+ create_table_1 = BigQueryCreateEmptyTableOperator(
+ task_id="create_table_1",
+ dataset_id=DATASET,
+ table_id=TABLE_1,
+ schema_fields=SCHEMA,
+ location=LOCATION,
+ )
+
+ create_dataset >> create_table_1
+
+ delete_dataset = BigQueryDeleteDatasetOperator(
+ task_id="delete_dataset", dataset_id=DATASET, delete_contents=True, trigger_rule=TriggerRule.ALL_DONE
+ )
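+ # trigger_rule=ALL_DONE makes the cleanup run even if upstream tasks fail.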
+
+ # [START howto_operator_bigquery_insert_job_async]
+ insert_query_job = BigQueryInsertJobAsyncOperator(
+ task_id="insert_query_job",
+ configuration={
+ "query": {
+ "query": INSERT_ROWS_QUERY,
+ "useLegacySql": False,
+ }
+ },
+ location=LOCATION,
+ )
+ # [END howto_operator_bigquery_insert_job_async]
+
+ # [START howto_operator_bigquery_select_job_async]
+ select_query_job = BigQueryInsertJobAsyncOperator(
+ task_id="select_query_job",
+ configuration={
+ "query": {
+ "query": "{% include 'example_bigquery_query.sql' %}",
+ "useLegacySql": False,
+ }
+ },
+ location=LOCATION,
+ )
+ # [END howto_operator_bigquery_select_job_async]
+
+ # [START howto_operator_bigquery_value_check_async]
+ check_value = BigQueryValueCheckAsyncOperator(
+ task_id="check_value",
+ sql=f"SELECT COUNT(*) FROM {DATASET}.{TABLE_1}",
+ pass_value=2,
+ use_legacy_sql=False,
+ location=LOCATION,
+ )
+ # [END howto_operator_bigquery_value_check_async]
+
+ # [START howto_operator_bigquery_interval_check_async]
+ check_interval = BigQueryIntervalCheckAsyncOperator(
+ task_id="check_interval",
+ table=f"{DATASET}.{TABLE_1}",
+ days_back=1,
+ metrics_thresholds={"COUNT(*)": 1.5},
+ use_legacy_sql=False,
+ location=LOCATION,
+ )
+ # [END howto_operator_bigquery_interval_check_async]
+
+ # [START howto_operator_bigquery_multi_query_async]
+ bigquery_execute_multi_query = BigQueryInsertJobAsyncOperator(
+ task_id="execute_multi_query",
+ configuration={
+ "query": {
+ "query": [
+ f"SELECT * FROM {DATASET}.{TABLE_2}",
+ f"SELECT COUNT(*) FROM {DATASET}.{TABLE_2}",
+ ],
+ "useLegacySql": False,
+ }
+ },
+ location=LOCATION,
+ )
+ # [END howto_operator_bigquery_multi_query_async]
+
+ # [START howto_operator_bigquery_get_data_async]
+ get_data = BigQueryGetDataAsyncOperator(
+ task_id="get_data",
+ dataset_id=DATASET,
+ table_id=TABLE_1,
+ max_results=10,
+ selected_fields="value,name",
+ location=LOCATION,
+ )
+ # [END howto_operator_bigquery_get_data_async]
+
+ get_data_result = BashOperator(
+ task_id="get_data_result",
+ bash_command=f"echo {get_data.output}",
+ trigger_rule=TriggerRule.ALL_DONE,
+ )
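+ # get_data.output is an XComArg: the rows returned by get_data are pulled
+ # from XCom and rendered into the bash command at runtime.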
+
+ # [START howto_operator_bigquery_check_async]
+ check_count = BigQueryCheckAsyncOperator(
+ task_id="check_count",
+ sql=f"SELECT COUNT(*) FROM {DATASET}.{TABLE_1}",
+ use_legacy_sql=False,
+ location=LOCATION,
+ )
+ # [END howto_operator_bigquery_check_async]
+
+ # [START howto_operator_bigquery_execute_query_save_async]
+ execute_query_save = BigQueryInsertJobAsyncOperator(
+ task_id="execute_query_save",
+ configuration={
+ "query": {
+ "query": f"SELECT * FROM {DATASET}.{TABLE_1}",
+ "useLegacySql": False,
+ "destinationTable": {
+ "projectId": PROJECT_ID,
+ "datasetId": DATASET,
+ "tableId": TABLE_2,
+ },
+ }
+ },
+ location=LOCATION,
+ )
+ # [END howto_operator_bigquery_execute_query_save_async]
+
+ execute_long_running_query = BigQueryInsertJobAsyncOperator(
+ task_id="execute_long_running_query",
+ configuration={
+ "query": {
+ "query": f"""DECLARE success BOOL;
+ DECLARE size_bytes INT64;
+ DECLARE row_count INT64;
+ DECLARE DELAY_TIME DATETIME;
+ DECLARE WAIT STRING;
+ SET success = FALSE;
+
+ SET row_count = (SELECT row_count FROM {DATASET}.__TABLES__ WHERE table_id='NON_EXISTING_TABLE');
+ IF row_count > 0 THEN
+ SELECT 'Table Exists!' as message;
+ SET success = TRUE;
+ ELSE
+ SELECT 'Table does not exist' as message, row_count;
+ SET WAIT = 'TRUE';
+ SET DELAY_TIME = DATETIME_ADD(CURRENT_DATETIME,INTERVAL 1 MINUTE);
+ WHILE WAIT = 'TRUE' DO
+ IF (DELAY_TIME < CURRENT_DATETIME) THEN
+ SET WAIT = 'FALSE';
+ END IF;
+ END WHILE;
+ END IF;""",
+ "useLegacySql": False,
+ }
+ },
+ location=LOCATION,
+ )
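+ # The scripting loop above busy-waits for roughly a minute, presumably to
+ # keep the job running long enough to exercise the deferral/polling path.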
+
+ end = EmptyOperator(task_id="end")
+
+ create_table_1 >> insert_query_job >> select_query_job >> check_count
+ insert_query_job >> get_data >> get_data_result
+ insert_query_job >> execute_query_save >> bigquery_execute_multi_query
+ insert_query_job >> execute_long_running_query >> check_value >> check_interval
+ [check_count, check_interval, bigquery_execute_multi_query, get_data_result] >> delete_dataset
+ [check_count, check_interval, bigquery_execute_multi_query, get_data_result, delete_dataset] >> end
+
+
+from tests.system.utils import get_test_run # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)
diff --git a/tests/system/providers/google/cloud/bigquery/example_bigquery_sensors.py b/tests/system/providers/google/cloud/bigquery/example_bigquery_sensors.py
index 45e44343a9..6faea7cbe9 100644
--- a/tests/system/providers/google/cloud/bigquery/example_bigquery_sensors.py
+++ b/tests/system/providers/google/cloud/bigquery/example_bigquery_sensors.py
@@ -31,6 +31,7 @@ from airflow.providers.google.cloud.operators.bigquery import (
BigQueryInsertJobOperator,
)
from airflow.providers.google.cloud.sensors.bigquery import (
+ BigQueryTableExistenceAsyncSensor,
BigQueryTableExistenceSensor,
BigQueryTablePartitionExistenceSensor,
)
@@ -86,6 +87,15 @@ with models.DAG(
)
# [END howto_sensor_bigquery_table]
+ # [START howto_sensor_async_bigquery_table]
+ check_table_exists_async = BigQueryTableExistenceAsyncSensor(
+ task_id="check_table_exists_async",
+ project_id=PROJECT_ID,
+ dataset_id=DATASET_NAME,
+ table_id=TABLE_NAME,
+ )
+ # [END howto_sensor_async_bigquery_table]
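+ # Unlike the sync sensor above, the async sensor defers to the new
+ # BigQueryTableExistenceTrigger instead of holding a worker slot while polling.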
+
execute_insert_query: BaseOperator = BigQueryInsertJobOperator(
task_id="execute_insert_query",
configuration={
@@ -116,7 +126,7 @@ with models.DAG(
create_dataset >> create_table
create_table >> [check_table_exists, execute_insert_query]
execute_insert_query >> check_table_partition_exists
- [check_table_exists, check_table_partition_exists] >> delete_dataset
+ [check_table_exists, check_table_exists_async, check_table_partition_exists] >> delete_dataset
from tests.system.utils.watcher import watcher