Posted to commits@airflow.apache.org by mo...@apache.org on 2023/08/04 14:44:01 UTC

[airflow] 01/01: openlineage, bigquery: add openlineage method support for BigQueryExecuteQueryOperator

This is an automated email from the ASF dual-hosted git repository.

mobuchowski pushed a commit to branch openlineage-bigquery-operation
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit d047c6586dfc8666618e5680b87679d6b44ffc53
Author: Maciej Obuchowski <ob...@gmail.com>
AuthorDate: Mon May 15 17:21:47 2023 +0200

    openlineage, bigquery: add openlineage method support for BigQueryExecuteQueryOperator
    
    Signed-off-by: Maciej Obuchowski <ob...@gmail.com>
---
 airflow/providers/google/cloud/hooks/bigquery.py   |   2 +-
 .../providers/google/cloud/operators/bigquery.py   |  92 +++++++-
 airflow/providers/openlineage/extractors/base.py   |   6 +
 airflow/providers/openlineage/utils/utils.py       |   9 +-
 .../google/cloud/operators/job_details.json        | 240 +++++++++++++++++++++
 .../google/cloud/operators/test_bigquery.py        |  91 ++++++++
 6 files changed, 428 insertions(+), 12 deletions(-)

diff --git a/airflow/providers/google/cloud/hooks/bigquery.py b/airflow/providers/google/cloud/hooks/bigquery.py
index bfaef219ad..2ac6c2645e 100644
--- a/airflow/providers/google/cloud/hooks/bigquery.py
+++ b/airflow/providers/google/cloud/hooks/bigquery.py
@@ -2245,7 +2245,7 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
         self.running_job_id = job.job_id
         return job.job_id
 
-    def generate_job_id(self, job_id, dag_id, task_id, logical_date, configuration, force_rerun=False):
+    def generate_job_id(self, job_id, dag_id, task_id, logical_date, configuration, force_rerun=False) -> str:
         if force_rerun:
             hash_base = str(uuid.uuid4())
         else:
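
The new return annotation documents what generate_job_id already does: it always returns a string. As a rough sketch of the idea, with illustrative names rather than the hook's exact code: unless force_rerun asks for a fresh id, the id is derived deterministically from the task's identity and configuration, so retries hash to the same BigQuery job and can reattach to it.

    import hashlib
    import json
    import uuid

    def example_generate_job_id(job_id, dag_id, task_id, logical_date, configuration, force_rerun=False) -> str:
        if force_rerun:
            # random base -> a brand-new BigQuery job id on every run
            hash_base = str(uuid.uuid4())
        else:
            # deterministic base -> retries produce the same id and can reattach
            hash_base = json.dumps(
                {"dag_id": dag_id, "task_id": task_id,
                 "logical_date": str(logical_date), "configuration": configuration},
                sort_keys=True,
            )
        suffix = hashlib.md5(hash_base.encode()).hexdigest()
        return f"{job_id or 'airflow_job'}_{suffix}"
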
diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py
index f5e5a9634f..ca6f290004 100644
--- a/airflow/providers/google/cloud/operators/bigquery.py
+++ b/airflow/providers/google/cloud/operators/bigquery.py
@@ -133,6 +133,68 @@ class _BigQueryDbHookMixin:
         )
 
 
+class _BigQueryOpenLineageMixin:
+    def get_openlineage_facets_on_complete(self, task_instance):
+        """
+        Retrieve OpenLineage data for a COMPLETE BigQuery job.
+
+        This method retrieves statistics for the specified job_ids using the BigQueryDatasetsProvider.
+        It calls the BigQuery API, retrieving input and output dataset info from it, as well as run-level
+        usage statistics.
+
+        Run facets should contain:
+            - ExternalQueryRunFacet
+            - BigQueryJobRunFacet
+
+        Job facets should contain:
+            - SqlJobFacet if operator has self.sql
+
+        Input datasets should contain facets:
+            - DataSourceDatasetFacet
+            - SchemaDatasetFacet
+
+        Output datasets should contain facets:
+            - DataSourceDatasetFacet
+            - SchemaDatasetFacet
+            - OutputStatisticsOutputDatasetFacet
+        """
+        from openlineage.client.facet import SqlJobFacet
+        from openlineage.common.provider.bigquery import BigQueryDatasetsProvider
+
+        from airflow.providers.openlineage.extractors import OperatorLineage
+        from airflow.providers.openlineage.utils.utils import normalize_sql
+
+        if not self.job_id:
+            return OperatorLineage()
+
+        client = self.hook.get_client(project_id=self.hook.project_id)
+        job_ids = self.job_id
+        if isinstance(self.job_id, str):
+            job_ids = [self.job_id]
+        inputs, outputs, run_facets = {}, {}, {}
+        for job_id in job_ids:
+            stats = BigQueryDatasetsProvider(client=client).get_facets(job_id=job_id)
+            for input in stats.inputs:
+                input = input.to_openlineage_dataset()
+                inputs[input.name] = input
+            if stats.output:
+                output = stats.output.to_openlineage_dataset()
+                outputs[output.name] = output
+            for key, value in stats.run_facets.items():
+                run_facets[key] = value
+
+        job_facets = {}
+        if hasattr(self, "sql"):
+            job_facets["sql"] = SqlJobFacet(query=normalize_sql(self.sql))
+
+        return OperatorLineage(
+            inputs=list(inputs.values()),
+            outputs=list(outputs.values()),
+            run_facets=run_facets,
+            job_facets=job_facets,
+        )
+
+
 class BigQueryCheckOperator(_BigQueryDbHookMixin, SQLCheckOperator):
     """Performs checks against BigQuery.
 
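
End to end, the mixin is meant to be invoked by the OpenLineage listener once the task finishes. A minimal usage sketch, mirroring the mocked test added at the bottom of this commit (the project id and query are illustrative):

    from unittest import mock

    from airflow.providers.google.cloud.operators.bigquery import BigQueryInsertJobOperator

    op = BigQueryInsertJobOperator(
        task_id="insert_query_job",
        configuration={"query": {"query": "SELECT * FROM test_table", "useLegacySql": False}},
        project_id="my-project",  # illustrative
    )
    op.execute(context=mock.MagicMock())  # stores the submitted id on op.job_id
    lineage = op.get_openlineage_facets_on_complete(task_instance=None)
    # lineage.inputs / lineage.outputs: one openlineage Dataset per referenced table
    # lineage.run_facets: ExternalQueryRunFacet and BigQueryJobRunFacet
    # lineage.job_facets: {"sql": SqlJobFacet(...)} because op.sql is set
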
@@ -1153,6 +1215,7 @@ class BigQueryExecuteQueryOperator(GoogleCloudBaseOperator):
         self.encryption_configuration = encryption_configuration
         self.hook: BigQueryHook | None = None
         self.impersonation_chain = impersonation_chain
+        self.job_id: str | list[str] | None = None
 
     def execute(self, context: Context):
         if self.hook is None:
@@ -1164,7 +1227,7 @@ class BigQueryExecuteQueryOperator(GoogleCloudBaseOperator):
                 impersonation_chain=self.impersonation_chain,
             )
         if isinstance(self.sql, str):
-            job_id: str | list[str] = self.hook.run_query(
+            self.job_id = self.hook.run_query(
                 sql=self.sql,
                 destination_dataset_table=self.destination_dataset_table,
                 write_disposition=self.write_disposition,
@@ -1184,7 +1247,7 @@ class BigQueryExecuteQueryOperator(GoogleCloudBaseOperator):
                 encryption_configuration=self.encryption_configuration,
             )
         elif isinstance(self.sql, Iterable):
-            job_id = [
+            self.job_id = [
                 self.hook.run_query(
                     sql=s,
                     destination_dataset_table=self.destination_dataset_table,
@@ -1210,9 +1273,9 @@ class BigQueryExecuteQueryOperator(GoogleCloudBaseOperator):
             raise AirflowException(f"argument 'sql' of type {type(str)} is neither a string nor an iterable")
         project_id = self.hook.project_id
         if project_id:
-            job_id_path = convert_job_id(job_id=job_id, project_id=project_id, location=self.location)
+            job_id_path = convert_job_id(job_id=self.job_id, project_id=project_id, location=self.location)
             context["task_instance"].xcom_push(key="job_id_path", value=job_id_path)
-        return job_id
+        return self.job_id
 
     def on_kill(self) -> None:
         super().on_kill()
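
Promoting job_id from a local variable to self.job_id is the load-bearing change in this operator: get_openlineage_facets_on_complete runs only after execute() has returned, so anything the lineage code needs must survive on the operator instance. Schematically:

    class ExampleOperator:
        def __init__(self):
            # populated by execute(), read later by the on-complete lineage hook
            self.job_id: str | list[str] | None = None

        def execute(self, context):
            self.job_id = "job_123"  # illustrative id from the submission call
            return self.job_id

        def get_openlineage_facets_on_complete(self, task_instance):
            if not self.job_id:
                return None  # nothing was submitted, nothing to report
            ...  # look up each job id and build OperatorLineage, as in the mixin above
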
@@ -2562,7 +2625,7 @@ class BigQueryUpdateTableSchemaOperator(GoogleCloudBaseOperator):
         return table
 
 
-class BigQueryInsertJobOperator(GoogleCloudBaseOperator):
+class BigQueryInsertJobOperator(GoogleCloudBaseOperator, _BigQueryOpenLineageMixin):
     """Execute a BigQuery job.
 
     Waits for the job to complete and returns job id.
@@ -2663,6 +2726,13 @@ class BigQueryInsertJobOperator(GoogleCloudBaseOperator):
         self.deferrable = deferrable
         self.poll_interval = poll_interval
 
+    @property
+    def sql(self) -> str | None:
+        try:
+            return self.configuration["query"]["query"]
+        except KeyError:
+            return None
+
     def prepare_template(self) -> None:
         # If .json is passed then we have to read the file
         if isinstance(self.configuration, str) and self.configuration.endswith(".json"):
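
The sql property is what lets _BigQueryOpenLineageMixin attach a SqlJobFacet to BigQueryInsertJobOperator: a query job exposes its statement, while any other job type (load, copy, extract) falls through the KeyError to None. For example:

    op.configuration = {"query": {"query": "SELECT * FROM test_table", "useLegacySql": False}}
    op.sql  # -> "SELECT * FROM test_table"

    op.configuration = {"load": {"sourceUris": ["gs://bucket/file.csv"]}}  # illustrative load job
    op.sql  # -> None, so no SqlJobFacet is emitted
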
@@ -2697,7 +2767,7 @@ class BigQueryInsertJobOperator(GoogleCloudBaseOperator):
         )
         self.hook = hook
 
-        job_id = hook.generate_job_id(
+        self.job_id = hook.generate_job_id(
             job_id=self.job_id,
             dag_id=self.dag_id,
             task_id=self.task_id,
@@ -2708,13 +2778,13 @@ class BigQueryInsertJobOperator(GoogleCloudBaseOperator):
 
         try:
             self.log.info("Executing: %s'", self.configuration)
-            job: BigQueryJob | UnknownJob = self._submit_job(hook, job_id)
+            job: BigQueryJob | UnknownJob = self._submit_job(hook, self.job_id)
         except Conflict:
             # If the job already exists retrieve it
             job = hook.get_job(
                 project_id=self.project_id,
                 location=self.location,
-                job_id=job_id,
+                job_id=self.job_id,
             )
             if job.state in self.reattach_states:
                 # We are reattaching to a job
@@ -2723,7 +2793,7 @@ class BigQueryInsertJobOperator(GoogleCloudBaseOperator):
             else:
                 # Same job configuration so we need force_rerun
                 raise AirflowException(
-                    f"Job with id: {job_id} already exists and is in {job.state} state. If you "
+                    f"Job with id: {self.job_id} already exists and is in {job.state} state. If you "
                     f"want to force rerun it consider setting `force_rerun=True`."
                     f"Or, if you want to reattach in this scenario add {job.state} to `reattach_states`"
                 )
@@ -2757,7 +2827,9 @@ class BigQueryInsertJobOperator(GoogleCloudBaseOperator):
         self.job_id = job.job_id
         project_id = self.project_id or self.hook.project_id
         if project_id:
-            job_id_path = convert_job_id(job_id=job_id, project_id=project_id, location=self.location)
+            job_id_path = convert_job_id(
+                job_id=self.job_id, project_id=project_id, location=self.location  # type: ignore[arg-type]
+            )
             context["ti"].xcom_push(key="job_id_path", value=job_id_path)
         # Wait for the job to complete
         if not self.deferrable:
diff --git a/airflow/providers/openlineage/extractors/base.py b/airflow/providers/openlineage/extractors/base.py
index 95d8fa6f28..0926489c0d 100644
--- a/airflow/providers/openlineage/extractors/base.py
+++ b/airflow/providers/openlineage/extractors/base.py
@@ -86,6 +86,12 @@ class DefaultExtractor(BaseExtractor):
         # OpenLineage methods are optional - if there's no method, return None
         try:
             return self._get_openlineage_facets(self.operator.get_openlineage_facets_on_start)  # type: ignore
+        except ImportError:
+            self.log.error(
+                "OpenLineage provider method failed to import OpenLineage integration. "
+                "This should not happen. Please report this bug to developers."
+            )
+            return None
         except AttributeError:
             return None
 
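
The new ImportError branch matters because operator-level lineage methods import the OpenLineage packages lazily, inside the method body, exactly as _BigQueryOpenLineageMixin does above. A condensed sketch of the two cases the extractor now distinguishes:

    class PlainOperator:
        pass  # no get_openlineage_facets_on_start -> AttributeError, skipped silently

    class LineageOperator:
        def get_openlineage_facets_on_start(self):
            # raises ImportError at call time if the integration is not installed;
            # the extractor now logs that as a bug instead of crashing the listener
            from openlineage.common.provider.bigquery import BigQueryDatasetsProvider
            ...
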
diff --git a/airflow/providers/openlineage/utils/utils.py b/airflow/providers/openlineage/utils/utils.py
index ca8b559e3a..20b9afef49 100644
--- a/airflow/providers/openlineage/utils/utils.py
+++ b/airflow/providers/openlineage/utils/utils.py
@@ -23,7 +23,7 @@ import logging
 import os
 from contextlib import suppress
 from functools import wraps
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Iterable
 from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse
 
 import attrs
@@ -414,3 +414,10 @@ def is_source_enabled() -> bool:
 def get_filtered_unknown_operator_keys(operator: BaseOperator) -> dict:
     not_required_keys = {"dag", "task_group"}
     return {attr: value for attr, value in operator.__dict__.items() if attr not in not_required_keys}
+
+
+def normalize_sql(sql: str | Iterable[str]):
+    if isinstance(sql, str):
+        sql = [stmt for stmt in sql.split(";") if stmt != ""]
+    sql = [obj for stmt in sql for obj in stmt.split(";") if obj != ""]
+    return ";\n".join(sql)
diff --git a/tests/providers/google/cloud/operators/job_details.json b/tests/providers/google/cloud/operators/job_details.json
new file mode 100644
index 0000000000..f12ec1321d
--- /dev/null
+++ b/tests/providers/google/cloud/operators/job_details.json
@@ -0,0 +1,240 @@
+{
+    "kind": "bigquery#job",
+    "etag": "vd2aBaVVX6a4bUJW13+Tqg==",
+    "id": "airflow:US.job_IDnbVW6NACdFDkermznYm9o4mcVH",
+    "selfLink": "https://bigquery.googleapis.com/bigquery/v2/projects/airflow-openlineage/jobs/job_IDnbVW6NACdFDkermznYm9o4mcVH?location=US",
+    "user_email": "svc-account@airflow-openlineage.iam.gserviceaccount.com",
+    "configuration": {
+        "query": {
+            "query": "Select * from test_table",
+            "destinationTable": {
+                "projectId": "airflow-openlineage",
+                "datasetId": "new_dataset",
+                "tableId": "output_table"
+            },
+            "createDisposition": "CREATE_IF_NEEDED",
+            "writeDisposition": "WRITE_TRUNCATE",
+            "priority": "INTERACTIVE",
+            "allowLargeResults": false,
+            "useLegacySql": false
+        },
+        "jobType": "QUERY"
+    },
+    "jobReference": {
+        "projectId": "airflow-openlineage",
+        "jobId": "job_IDnbVW6NACdFDkermznYm9o4mcVH",
+        "location": "US"
+    },
+    "statistics": {
+        "creationTime": 1.60390893E12,
+        "startTime": 1.60390893E12,
+        "endTime": 1.60390893E12,
+        "totalBytesProcessed": "110355534",
+        "query": {
+            "queryPlan": [{
+                    "name": "S00: Input",
+                    "id": "0",
+                    "startMs": "1603908925668",
+                    "endMs": "1603908925880",
+                    "waitRatioAvg": 0.0070422534,
+                    "waitMsAvg": "2",
+                    "waitRatioMax": 0.0070422534,
+                    "waitMsMax": "2",
+                    "readRatioAvg": 0.14084508,
+                    "readMsAvg": "40",
+                    "readRatioMax": 0.14084508,
+                    "readMsMax": "40",
+                    "computeRatioAvg": 1,
+                    "computeMsAvg": "284",
+                    "computeRatioMax": 1,
+                    "computeMsMax": "284",
+                    "writeRatioAvg": 0.017605634,
+                    "writeMsAvg": "5",
+                    "writeRatioMax": 0.017605634,
+                    "writeMsMax": "5",
+                    "shuffleOutputBytes": "439409",
+                    "shuffleOutputBytesSpilled": "0",
+                    "recordsRead": "5552452",
+                    "recordsWritten": "16142",
+                    "parallelInputs": "1",
+                    "completedParallelInputs": "1",
+                    "status": "COMPLETE",
+                    "steps": [{
+                            "kind": "READ",
+                            "substeps": [
+                                "$1:state, $2:name, $3:number",
+                                "FROM bigquery-public-data.usa_names.usa_1910_2013",
+                                "WHERE equal($1, 'TX')"
+                            ]
+                        },
+                        {
+                            "kind": "AGGREGATE",
+                            "substeps": [
+                                "GROUP BY $30 := $2, $31 := $1",
+                                "$20 := SUM($3)"
+                            ]
+                        },
+                        {
+                            "kind": "WRITE",
+                            "substeps": [
+                                "$31, $30, $20",
+                                "TO __stage00_output",
+                                "BY HASH($30, $31)"
+                            ]
+                        }
+                    ],
+                    "slotMs": "448"
+                },
+                {
+                    "name": "S01: Sort+",
+                    "id": "1",
+                    "startMs": "1603908925891",
+                    "endMs": "1603908925911",
+                    "inputStages": [
+                        "0"
+                    ],
+                    "waitRatioAvg": 0.0070422534,
+                    "waitMsAvg": "2",
+                    "waitRatioMax": 0.0070422534,
+                    "waitMsMax": "2",
+                    "readRatioAvg": 0,
+                    "readMsAvg": "0",
+                    "readRatioMax": 0,
+                    "readMsMax": "0",
+                    "computeRatioAvg": 0.049295776,
+                    "computeMsAvg": "14",
+                    "computeRatioMax": 0.049295776,
+                    "computeMsMax": "14",
+                    "writeRatioAvg": 0.0070422534,
+                    "writeMsAvg": "2",
+                    "writeRatioMax": 0.0070422534,
+                    "writeMsMax": "2",
+                    "shuffleOutputBytes": "401",
+                    "shuffleOutputBytesSpilled": "0",
+                    "recordsRead": "16142",
+                    "recordsWritten": "20",
+                    "parallelInputs": "1",
+                    "completedParallelInputs": "1",
+                    "status": "COMPLETE",
+                    "steps": [{
+                            "kind": "READ",
+                            "substeps": [
+                                "$31, $30, $20",
+                                "FROM __stage00_output"
+                            ]
+                        },
+                        {
+                            "kind": "SORT",
+                            "substeps": [
+                                "$10 DESC",
+                                "LIMIT 20"
+                            ]
+                        },
+                        {
+                            "kind": "AGGREGATE",
+                            "substeps": [
+                                "GROUP BY $40 := $30, $41 := $31",
+                                "$10 := SUM($20)"
+                            ]
+                        },
+                        {
+                            "kind": "WRITE",
+                            "substeps": [
+                                "$50, $51",
+                                "TO __stage01_output"
+                            ]
+                        }
+                    ],
+                    "slotMs": "33"
+                },
+                {
+                    "name": "S02: Output",
+                    "id": "2",
+                    "startMs": "1603908926017",
+                    "endMs": "1603908926191",
+                    "inputStages": [
+                        "1"
+                    ],
+                    "waitRatioAvg": 0.4471831,
+                    "waitMsAvg": "127",
+                    "waitRatioMax": 0.4471831,
+                    "waitMsMax": "127",
+                    "readRatioAvg": 0,
+                    "readMsAvg": "0",
+                    "readRatioMax": 0,
+                    "readMsMax": "0",
+                    "computeRatioAvg": 0.03169014,
+                    "computeMsAvg": "9",
+                    "computeRatioMax": 0.03169014,
+                    "computeMsMax": "9",
+                    "writeRatioAvg": 0.5633803,
+                    "writeMsAvg": "160",
+                    "writeRatioMax": 0.5633803,
+                    "writeMsMax": "160",
+                    "shuffleOutputBytes": "321",
+                    "shuffleOutputBytesSpilled": "0",
+                    "recordsRead": "20",
+                    "recordsWritten": "20",
+                    "parallelInputs": "1",
+                    "completedParallelInputs": "1",
+                    "status": "COMPLETE",
+                    "steps": [{
+                            "kind": "READ",
+                            "substeps": [
+                                "$50, $51",
+                                "FROM __stage01_output"
+                            ]
+                        },
+                        {
+                            "kind": "SORT",
+                            "substeps": [
+                                "$51 DESC",
+                                "LIMIT 20"
+                            ]
+                        },
+                        {
+                            "kind": "WRITE",
+                            "substeps": [
+                                "$60, $61",
+                                "TO __stage02_output"
+                            ]
+                        }
+                    ],
+                    "slotMs": "342"
+                }
+            ],
+            "estimatedBytesProcessed": "110355534",
+            "timeline": [{
+                    "elapsedMs": "736",
+                    "totalSlotMs": "482",
+                    "pendingUnits": "1",
+                    "completedUnits": "2",
+                    "activeUnits": "1"
+                },
+                {
+                    "elapsedMs": "1045",
+                    "totalSlotMs": "825",
+                    "pendingUnits": "0",
+                    "completedUnits": "3",
+                    "activeUnits": "1"
+                }
+            ],
+            "totalPartitionsProcessed": "0",
+            "totalBytesProcessed": "110355534",
+            "totalBytesBilled": "111149056",
+            "billingTier": 1,
+            "totalSlotMs": "825",
+            "cacheHit": false,
+            "referencedTables": [{
+                "projectId": "airflow-openlineage",
+                "datasetId": "new_dataset",
+                "tableId": "test_table"
+            }],
+            "statementType": "SELECT"
+        },
+        "totalSlotMs": "825"
+    },
+    "status": {
+        "state": "DONE"
+    }
+}
diff --git a/tests/providers/google/cloud/operators/test_bigquery.py b/tests/providers/google/cloud/operators/test_bigquery.py
index 4026b4ba45..d17f5498e2 100644
--- a/tests/providers/google/cloud/operators/test_bigquery.py
+++ b/tests/providers/google/cloud/operators/test_bigquery.py
@@ -17,6 +17,8 @@
 # under the License.
 from __future__ import annotations
 
+import json
+from contextlib import suppress
 from unittest import mock
 from unittest.mock import ANY, MagicMock
 
@@ -24,6 +26,13 @@ import pandas as pd
 import pytest
 from google.cloud.bigquery import DEFAULT_RETRY
 from google.cloud.exceptions import Conflict
+from openlineage.client.facet import (
+    DataSourceDatasetFacet,
+    ExternalQueryRunFacet,
+    SqlJobFacet,
+)
+from openlineage.client.run import Dataset
+from openlineage.common.provider.bigquery import BigQueryErrorRunFacet
 
 from airflow.exceptions import AirflowException, AirflowSkipException, AirflowTaskTimeout, TaskDeferred
 from airflow.providers.google.cloud.operators.bigquery import (
@@ -1520,6 +1529,88 @@ class TestBigQueryInsertJobOperator:
             force_rerun=True,
         )
 
+    @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
+    def test_execute_openlineage_events(self, mock_hook):
+        job_id = "123456"
+        hash_ = "hash"
+        real_job_id = f"{job_id}_{hash_}"
+
+        configuration = {
+            "query": {
+                "query": "SELECT * FROM test_table",
+                "useLegacySql": False,
+            }
+        }
+        mock_hook.return_value.insert_job.return_value = MagicMock(job_id=real_job_id, error_result=False)
+        mock_hook.return_value.generate_job_id.return_value = real_job_id
+
+        op = BigQueryInsertJobOperator(
+            task_id="insert_query_job",
+            configuration=configuration,
+            location=TEST_DATASET_LOCATION,
+            job_id=job_id,
+            project_id=TEST_GCP_PROJECT_ID,
+        )
+        result = op.execute(context=MagicMock())
+
+        mock_hook.return_value.insert_job.assert_called_once_with(
+            configuration=configuration,
+            location=TEST_DATASET_LOCATION,
+            job_id=real_job_id,
+            nowait=True,
+            project_id=TEST_GCP_PROJECT_ID,
+            retry=DEFAULT_RETRY,
+            timeout=None,
+        )
+
+        assert result == real_job_id
+
+        with open(file="tests/providers/google/cloud/operators/job_details.json") as f:
+            job_details = json.loads(f.read())
+        mock_hook.return_value.get_client.return_value.get_job.return_value._properties = job_details
+
+        lineage = op.get_openlineage_facets_on_complete(None)
+        assert lineage.inputs == [
+            Dataset(
+                namespace="bigquery",
+                name="airflow-openlineage.new_dataset.test_table",
+                facets={"dataSource": DataSourceDatasetFacet(name="bigquery", uri="bigquery")},
+            )
+        ]
+
+        assert lineage.run_facets == {
+            "bigQuery_job": mock.ANY,
+            "externalQuery": ExternalQueryRunFacet(externalQueryId=mock.ANY, source="bigquery"),
+        }
+        assert lineage.job_facets == {"sql": SqlJobFacet(query="SELECT * FROM test_table")}
+
+    @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
+    def test_execute_fails_openlineage_events(self, mock_hook):
+        job_id = "1234"
+
+        configuration = {
+            "query": {
+                "query": "SELECT * FROM test_table",
+                "useLegacySql": False,
+            }
+        }
+        operator = BigQueryInsertJobOperator(
+            task_id="insert_query_job_failed",
+            configuration=configuration,
+            location=TEST_DATASET_LOCATION,
+            job_id=job_id,
+            project_id=TEST_GCP_PROJECT_ID,
+        )
+        mock_hook.return_value.generate_job_id.return_value = "1234"
+        mock_hook.return_value.get_client.return_value.get_job.side_effect = RuntimeError()
+        mock_hook.return_value.insert_job.side_effect = RuntimeError()
+
+        with suppress(RuntimeError):
+            operator.execute(MagicMock())
+        lineage = operator.get_openlineage_facets_on_complete(None)
+
+        assert lineage.run_facets == {"bigQuery_error": BigQueryErrorRunFacet(clientError=mock.ANY)}
+
     @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
     def test_execute_force_rerun_async(self, mock_hook, create_task_instance_of_operator):
         job_id = "123456"