You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/09/21 08:50:08 UTC

[airflow] branch main updated: Add BigQuery Column and Table Check Operators (#26368)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new c4256ca1a0 Add BigQuery Column and Table Check Operators (#26368)
c4256ca1a0 is described below

commit c4256ca1a029240299b83841bdd034385665cdda
Author: Benji Lampel <be...@astronomer.io>
AuthorDate: Wed Sep 21 04:49:57 2022 -0400

    Add BigQuery Column and Table Check Operators (#26368)
    
    * Add Column and Table Check Operators
    
    Add two new operators based on the SQLColumnCheckOperator and
    SQLTableCheckOperator that also provide job_ids so results
    of the queries can be pulled and parsed, and so OpenLineage
    can parse datasets and provide lineage information.
---
 .../providers/google/cloud/operators/bigquery.py   | 239 +++++++++++++++++++++
 .../operators/cloud/bigquery.rst                   |  28 +++
 .../cloud/bigquery/example_bigquery_queries.py     |  18 ++
 3 files changed, 285 insertions(+)

diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py
index 084205f4ed..33e51833e5 100644
--- a/airflow/providers/google/cloud/operators/bigquery.py
+++ b/airflow/providers/google/cloud/operators/bigquery.py
@@ -34,8 +34,12 @@ from airflow.models import BaseOperator, BaseOperatorLink
 from airflow.models.xcom import XCom
 from airflow.providers.common.sql.operators.sql import (
     SQLCheckOperator,
+    SQLColumnCheckOperator,
     SQLIntervalCheckOperator,
+    SQLTableCheckOperator,
     SQLValueCheckOperator,
+    _get_failed_checks,
+    parse_boolean,
 )
 from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook, BigQueryJob
 from airflow.providers.google.cloud.hooks.gcs import GCSHook, _parse_gcs_url
@@ -520,6 +524,241 @@ class BigQueryIntervalCheckOperator(_BigQueryDbHookMixin, SQLIntervalCheckOperat
         )
 
 
+class BigQueryColumnCheckOperator(_BigQueryDbHookMixin, SQLColumnCheckOperator):
+    """
+    BigQueryColumnCheckOperator subclasses the SQLColumnCheckOperator
+    in order to provide a job id for OpenLineage to parse. See base class
+    docstring for usage.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:BigQueryColumnCheckOperator`
+
+    :param table: the table name
+    :param column_mapping: a dictionary relating columns to their checks
+    :param partition_clause: a string SQL statement added to a WHERE clause
+        to partition data
+    :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud.
+    :param use_legacy_sql: Whether to use legacy SQL (true)
+        or standard SQL (false).
+    :param location: The geographic location of the job. See details at:
+        https://cloud.google.com/bigquery/docs/locations#specifying_your_location
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    :param labels: a dictionary containing labels for the table, passed to BigQuery
+    """
+
+    def __init__(
+        self,
+        *,
+        table: str,
+        column_mapping: dict,
+        partition_clause: str | None = None,
+        gcp_conn_id: str = "google_cloud_default",
+        use_legacy_sql: bool = True,
+        location: str | None = None,
+        impersonation_chain: str | Sequence[str] | None = None,
+        labels: dict | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(
+            table=table, column_mapping=column_mapping, partition_clause=partition_clause, **kwargs
+        )
+        self.table = table
+        self.column_mapping = column_mapping
+        self.partition_clause = partition_clause
+        self.gcp_conn_id = gcp_conn_id
+        self.use_legacy_sql = use_legacy_sql
+        self.location = location
+        self.impersonation_chain = impersonation_chain
+        self.labels = labels
+        # OpenLineage needs a valid SQL query with the input/output table(s) to parse
+        self.sql = ""
+
+    def _job_configuration(self) -> dict:
+        """Build the query-job configuration for the SQL currently held in ``self.sql``."""
+        # ``useLegacySql`` is sent explicitly; otherwise the operator's ``use_legacy_sql``
+        # setting would never reach BigQuery and the service default would apply instead.
+        return {"query": {"query": self.sql, "useLegacySql": self.use_legacy_sql}}
+
+    def _submit_job(self, hook: BigQueryHook, job_id: str) -> BigQueryJob:
+        """Submit ``self.sql`` as a query job with ``nowait=False`` and return the resulting job."""
+        return hook.insert_job(
+            configuration=self._job_configuration(),
+            project_id=hook.project_id,
+            location=self.location,
+            job_id=job_id,
+            nowait=False,
+        )
+
+    def execute(self, context=None):
+        """Run one query job per column and evaluate its checks; raise AirflowException on failure."""
+        hook = self.get_db_hook()
+        failed_tests = []
+        for column in self.column_mapping:
+            checks = [*self.column_mapping[column]]
+            checks_sql = ",".join([self.column_checks[check].replace("column", column) for check in checks])
+            partition_clause_statement = f"WHERE {self.partition_clause}" if self.partition_clause else ""
+            self.sql = f"SELECT {checks_sql} FROM {self.table} {partition_clause_statement};"
+
+            # Generate the id from the same configuration that is submitted for this column's query.
+            job_id = hook.generate_job_id(
+                dag_id=self.dag_id,
+                task_id=self.task_id,
+                logical_date=context["logical_date"],
+                configuration=self._job_configuration(),
+            )
+            job = self._submit_job(hook, job_id=job_id)
+            context["ti"].xcom_push(key="job_id", value=job.job_id)
+            records = list(job.result().to_dataframe().values.flatten())
+
+            if not records:
+                raise AirflowException(f"The following query returned zero rows: {self.sql}")
+
+            self.log.info("Record: %s", records)
+
+            # The SELECT list was built from ``checks``, so values are paired with checks by position.
+            for check, result in zip(checks, records):
+                check_values = self.column_mapping[column][check]
+                tolerance = check_values.get("tolerance")
+                check_values["result"] = result
+                check_values["success"] = self._get_match(check_values, result, tolerance)
+
+            failed_tests.extend(_get_failed_checks(self.column_mapping[column], column))
+        if failed_tests:
+            raise AirflowException(
+                f"Test failed.\nResults:\n{records!s}\n"
+                "The following tests have failed:"
+                f"\n{''.join(failed_tests)}"
+            )
+
+        self.log.info("All tests have passed")
+
+
+class BigQueryTableCheckOperator(_BigQueryDbHookMixin, SQLTableCheckOperator):
+    """
+    BigQueryTableCheckOperator subclasses the SQLTableCheckOperator
+    in order to provide a job id for OpenLineage to parse. See base class
+    for usage.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:BigQueryTableCheckOperator`
+
+    :param table: the table name
+    :param checks: a dictionary of check names and boolean SQL statements
+    :param partition_clause: a string SQL statement added to a WHERE clause
+        to partition data
+    :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud.
+    :param use_legacy_sql: Whether to use legacy SQL (true)
+        or standard SQL (false).
+    :param location: The geographic location of the job. See details at:
+        https://cloud.google.com/bigquery/docs/locations#specifying_your_location
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    :param labels: a dictionary containing labels for the table, passed to BigQuery
+    """
+
+    def __init__(
+        self,
+        *,
+        table: str,
+        checks: dict,
+        partition_clause: str | None = None,
+        gcp_conn_id: str = "google_cloud_default",
+        use_legacy_sql: bool = True,
+        location: str | None = None,
+        impersonation_chain: str | Sequence[str] | None = None,
+        labels: dict | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(table=table, checks=checks, partition_clause=partition_clause, **kwargs)
+        self.table = table
+        self.checks = checks
+        self.partition_clause = partition_clause
+        self.gcp_conn_id = gcp_conn_id
+        self.use_legacy_sql = use_legacy_sql
+        self.location = location
+        self.impersonation_chain = impersonation_chain
+        self.labels = labels
+        # OpenLineage needs a valid SQL query with the input/output table(s) to parse
+        self.sql = ""
+
+    def _job_configuration(self) -> dict:
+        """Build the query-job configuration for the SQL currently held in ``self.sql``."""
+        # ``useLegacySql`` is sent explicitly; otherwise the operator's ``use_legacy_sql``
+        # setting would never reach BigQuery and the service default would apply instead.
+        return {"query": {"query": self.sql, "useLegacySql": self.use_legacy_sql}}
+
+    def _submit_job(self, hook: BigQueryHook, job_id: str) -> BigQueryJob:
+        """Submit ``self.sql`` as a query job with ``nowait=False`` and return the resulting job."""
+        return hook.insert_job(
+            configuration=self._job_configuration(),
+            project_id=hook.project_id,
+            location=self.location,
+            job_id=job_id,
+            nowait=False,
+        )
+
+    def execute(self, context=None):
+        """Run all table checks in one query job; raise AirflowException if any check fails."""
+        hook = self.get_db_hook()
+        checks_sql = " UNION ALL ".join(
+            [
+                self.sql_check_template.replace("check_statement", value["check_statement"])
+                .replace("_check_name", check_name)
+                .replace("table", self.table)
+                for check_name, value in self.checks.items()
+            ]
+        )
+        partition_clause_statement = f"WHERE {self.partition_clause}" if self.partition_clause else ""
+        self.sql = (
+            f"SELECT check_name, check_result FROM ({checks_sql}) "
+            f"AS check_table {partition_clause_statement};"
+        )
+
+        # Generate the id from the same configuration that is submitted for this query.
+        job_id = hook.generate_job_id(
+            dag_id=self.dag_id,
+            task_id=self.task_id,
+            logical_date=context["logical_date"],
+            configuration=self._job_configuration(),
+        )
+        job = self._submit_job(hook, job_id=job_id)
+        context["ti"].xcom_push(key="job_id", value=job.job_id)
+        records = job.result().to_dataframe()
+        if records.empty:
+            raise AirflowException(f"The following query returned zero rows: {self.sql}")
+
+        records.columns = records.columns.str.lower()
+        self.log.info("Record:\n%s", records)
+
+        for _, row in records.iterrows():
+            self.checks[row.get("check_name")]["success"] = parse_boolean(str(row.get("check_result")))
+
+        failed_tests = _get_failed_checks(self.checks)
+        if failed_tests:
+            raise AirflowException(
+                f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}\n"
+                "The following tests have failed:"
+                f"\n{', '.join(failed_tests)}"
+            )
+
+        self.log.info("All tests have passed")
+
+
 class BigQueryGetDataOperator(BaseOperator):
     """
     Fetches the data from a BigQuery table (alternatively fetch data for selected columns)
diff --git a/docs/apache-airflow-providers-google/operators/cloud/bigquery.rst b/docs/apache-airflow-providers-google/operators/cloud/bigquery.rst
index 548c37ace2..919c3dd898 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/bigquery.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/bigquery.rst
@@ -438,6 +438,34 @@ Also you can use deferrable mode in this operator
     :start-after: [START howto_operator_bigquery_interval_check_async]
     :end-before: [END howto_operator_bigquery_interval_check_async]
 
+.. _howto/operator:BigQueryColumnCheckOperator:
+
+Check columns with predefined tests
+"""""""""""""""""""""""""""""""""""
+
+To check that columns pass user-configurable tests you can use
+:class:`~airflow.providers.google.cloud.operators.bigquery.BigQueryColumnCheckOperator`
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/bigquery/example_bigquery_queries.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_bigquery_column_check]
+    :end-before: [END howto_operator_bigquery_column_check]
+
+.. _howto/operator:BigQueryTableCheckOperator:
+
+Check table level data quality
+""""""""""""""""""""""""""""""
+
+To check that tables pass user-defined tests you can use
+:class:`~airflow.providers.google.cloud.operators.bigquery.BigQueryTableCheckOperator`
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/bigquery/example_bigquery_queries.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_bigquery_table_check]
+    :end-before: [END howto_operator_bigquery_table_check]
+
 Sensors
 ^^^^^^^
 
diff --git a/tests/system/providers/google/cloud/bigquery/example_bigquery_queries.py b/tests/system/providers/google/cloud/bigquery/example_bigquery_queries.py
index eeb5454541..52637d37ce 100644
--- a/tests/system/providers/google/cloud/bigquery/example_bigquery_queries.py
+++ b/tests/system/providers/google/cloud/bigquery/example_bigquery_queries.py
@@ -27,12 +27,14 @@ from airflow import models
 from airflow.operators.bash import BashOperator
 from airflow.providers.google.cloud.operators.bigquery import (
     BigQueryCheckOperator,
+    BigQueryColumnCheckOperator,
     BigQueryCreateEmptyDatasetOperator,
     BigQueryCreateEmptyTableOperator,
     BigQueryDeleteDatasetOperator,
     BigQueryGetDataOperator,
     BigQueryInsertJobOperator,
     BigQueryIntervalCheckOperator,
+    BigQueryTableCheckOperator,
     BigQueryValueCheckOperator,
 )
 from airflow.utils.trigger_rule import TriggerRule
@@ -209,6 +211,22 @@ for index, location in enumerate(locations, 1):
         )
         # [END howto_operator_bigquery_interval_check]
 
+        # [START howto_operator_bigquery_column_check]
+        column_check = BigQueryColumnCheckOperator(
+            task_id="column_check",
+            table=f"{DATASET}.{TABLE_1}",
+            column_mapping={"value": {"null_check": {"equal_to": 0}}},
+        )
+        # [END howto_operator_bigquery_column_check]
+
+        # [START howto_operator_bigquery_table_check]
+        table_check = BigQueryTableCheckOperator(
+            task_id="table_check",
+            table=f"{DATASET}.{TABLE_1}",
+            checks={"row_count_check": {"check_statement": "COUNT(*) = 4"}},
+        )
+        # [END howto_operator_bigquery_table_check]
+
         delete_dataset = BigQueryDeleteDatasetOperator(
             task_id="delete_dataset",
             dataset_id=DATASET,