Posted to commits@airflow.apache.org by mo...@apache.org on 2023/08/04 14:44:01 UTC
[airflow] 01/01: openlineage, bigquery: add openlineage method support for BigQueryExecuteQueryOperator
This is an automated email from the ASF dual-hosted git repository.
mobuchowski pushed a commit to branch openlineage-bigquery-operation
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit d047c6586dfc8666618e5680b87679d6b44ffc53
Author: Maciej Obuchowski <ob...@gmail.com>
AuthorDate: Mon May 15 17:21:47 2023 +0200
openlineage, bigquery: add openlineage method support for BigQueryExecuteQueryOperator
Signed-off-by: Maciej Obuchowski <ob...@gmail.com>
---
airflow/providers/google/cloud/hooks/bigquery.py | 2 +-
.../providers/google/cloud/operators/bigquery.py | 92 +++++++-
airflow/providers/openlineage/extractors/base.py | 6 +
airflow/providers/openlineage/utils/utils.py | 9 +-
.../google/cloud/operators/job_details.json | 240 +++++++++++++++++++++
.../google/cloud/operators/test_bigquery.py | 91 ++++++++
6 files changed, 428 insertions(+), 12 deletions(-)
diff --git a/airflow/providers/google/cloud/hooks/bigquery.py b/airflow/providers/google/cloud/hooks/bigquery.py
index bfaef219ad..2ac6c2645e 100644
--- a/airflow/providers/google/cloud/hooks/bigquery.py
+++ b/airflow/providers/google/cloud/hooks/bigquery.py
@@ -2245,7 +2245,7 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
self.running_job_id = job.job_id
return job.job_id
- def generate_job_id(self, job_id, dag_id, task_id, logical_date, configuration, force_rerun=False):
+ def generate_job_id(self, job_id, dag_id, task_id, logical_date, configuration, force_rerun=False) -> str:
if force_rerun:
hash_base = str(uuid.uuid4())
else:
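The hunk above only shows the new return annotation and the force_rerun branch, but the idea behind generate_job_id is worth spelling out: a random UUID hash base makes every run produce a fresh job id, while a deterministic base keeps reruns reattachable. A minimal sketch of that scheme, with the deterministic base and the md5 suffix assumed for illustration rather than taken from the diff:

    import uuid
    from hashlib import md5

    def sketch_generate_job_id(job_id: str, configuration: dict, force_rerun: bool = False) -> str:
        # Random base -> unique id each run; deterministic base -> stable id across reruns.
        hash_base = str(uuid.uuid4()) if force_rerun else str(configuration)  # assumed base
        uniqueness_suffix = md5(hash_base.encode()).hexdigest()
        # The "{job_id}_{suffix}" shape matches the expectations in the tests below.
        return f"{job_id}_{uniqueness_suffix}"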
diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py
index f5e5a9634f..ca6f290004 100644
--- a/airflow/providers/google/cloud/operators/bigquery.py
+++ b/airflow/providers/google/cloud/operators/bigquery.py
@@ -133,6 +133,68 @@ class _BigQueryDbHookMixin:
)
+class _BigQueryOpenLineageMixin:
+ def get_openlineage_facets_on_complete(self, task_instance):
+ """
+ Retrieve OpenLineage data for a COMPLETE BigQuery job.
+
+ This method retrieves statistics for the specified job_ids using the BigQueryDatasetsProvider.
+ It calls the BigQuery API, retrieving input and output dataset info as well as run-level
+ usage statistics.
+
+ Run facets should contain:
+ - ExternalQueryRunFacet
+ - BigQueryJobRunFacet
+
+ Job facets should contain:
+ - SqlJobFacet if operator has self.sql
+
+ Input datasets should contain facets:
+ - DataSourceDatasetFacet
+ - SchemaDatasetFacet
+
+ Output datasets should contain facets:
+ - DataSourceDatasetFacet
+ - SchemaDatasetFacet
+ - OutputStatisticsOutputDatasetFacet
+ """
+ from openlineage.client.facet import SqlJobFacet
+ from openlineage.common.provider.bigquery import BigQueryDatasetsProvider
+
+ from airflow.providers.openlineage.extractors import OperatorLineage
+ from airflow.providers.openlineage.utils.utils import normalize_sql
+
+ if not self.job_id:
+ return OperatorLineage()
+
+ client = self.hook.get_client(project_id=self.hook.project_id)
+ job_ids = self.job_id
+ if isinstance(self.job_id, str):
+ job_ids = [self.job_id]
+ inputs, outputs, run_facets = {}, {}, {}
+ for job_id in job_ids:
+ stats = BigQueryDatasetsProvider(client=client).get_facets(job_id=job_id)
+ for input in stats.inputs:
+ input = input.to_openlineage_dataset()
+ inputs[input.name] = input
+ if stats.output:
+ output = stats.output.to_openlineage_dataset()
+ outputs[output.name] = output
+ for key, value in stats.run_facets.items():
+ run_facets[key] = value
+
+ job_facets = {}
+ if hasattr(self, "sql"):
+ job_facets["sql"] = SqlJobFacet(query=normalize_sql(self.sql))
+
+ return OperatorLineage(
+ inputs=list(inputs.values()),
+ outputs=list(outputs.values()),
+ run_facets=run_facets,
+ job_facets=job_facets,
+ )
+
+
class BigQueryCheckOperator(_BigQueryDbHookMixin, SQLCheckOperator):
"""Performs checks against BigQuery.
@@ -1153,6 +1215,7 @@ class BigQueryExecuteQueryOperator(GoogleCloudBaseOperator):
self.encryption_configuration = encryption_configuration
self.hook: BigQueryHook | None = None
self.impersonation_chain = impersonation_chain
+ self.job_id: str | list[str] | None = None
def execute(self, context: Context):
if self.hook is None:
@@ -1164,7 +1227,7 @@ class BigQueryExecuteQueryOperator(GoogleCloudBaseOperator):
impersonation_chain=self.impersonation_chain,
)
if isinstance(self.sql, str):
- job_id: str | list[str] = self.hook.run_query(
+ self.job_id = self.hook.run_query(
sql=self.sql,
destination_dataset_table=self.destination_dataset_table,
write_disposition=self.write_disposition,
@@ -1184,7 +1247,7 @@ class BigQueryExecuteQueryOperator(GoogleCloudBaseOperator):
encryption_configuration=self.encryption_configuration,
)
elif isinstance(self.sql, Iterable):
- job_id = [
+ self.job_id = [
self.hook.run_query(
sql=s,
destination_dataset_table=self.destination_dataset_table,
@@ -1210,9 +1273,9 @@ class BigQueryExecuteQueryOperator(GoogleCloudBaseOperator):
raise AirflowException(f"argument 'sql' of type {type(str)} is neither a string nor an iterable")
project_id = self.hook.project_id
if project_id:
- job_id_path = convert_job_id(job_id=job_id, project_id=project_id, location=self.location)
+ job_id_path = convert_job_id(job_id=self.job_id, project_id=project_id, location=self.location)
context["task_instance"].xcom_push(key="job_id_path", value=job_id_path)
- return job_id
+ return self.job_id
def on_kill(self) -> None:
super().on_kill()
@@ -2562,7 +2625,7 @@ class BigQueryUpdateTableSchemaOperator(GoogleCloudBaseOperator):
return table
-class BigQueryInsertJobOperator(GoogleCloudBaseOperator):
+class BigQueryInsertJobOperator(GoogleCloudBaseOperator, _BigQueryOpenLineageMixin):
"""Execute a BigQuery job.
Waits for the job to complete and returns job id.
@@ -2663,6 +2726,13 @@ class BigQueryInsertJobOperator(GoogleCloudBaseOperator):
self.deferrable = deferrable
self.poll_interval = poll_interval
+ @property
+ def sql(self) -> str | None:
+ try:
+ return self.configuration["query"]["query"]
+ except KeyError:
+ return None
+
def prepare_template(self) -> None:
# If .json is passed then we have to read the file
if isinstance(self.configuration, str) and self.configuration.endswith(".json"):
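The new `sql` property simply unwraps configuration["query"]["query"], returning None for job types without a query. A small illustrative snippet (not part of the commit):

    op = BigQueryInsertJobOperator(
        task_id="example",
        configuration={"query": {"query": "SELECT 1", "useLegacySql": False}},
    )
    op.sql  # "SELECT 1"; for load/extract/copy configurations this is None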
@@ -2697,7 +2767,7 @@ class BigQueryInsertJobOperator(GoogleCloudBaseOperator):
)
self.hook = hook
- job_id = hook.generate_job_id(
+ self.job_id = hook.generate_job_id(
job_id=self.job_id,
dag_id=self.dag_id,
task_id=self.task_id,
@@ -2708,13 +2778,13 @@ class BigQueryInsertJobOperator(GoogleCloudBaseOperator):
try:
self.log.info("Executing: %s'", self.configuration)
- job: BigQueryJob | UnknownJob = self._submit_job(hook, job_id)
+ job: BigQueryJob | UnknownJob = self._submit_job(hook, self.job_id)
except Conflict:
# If the job already exists retrieve it
job = hook.get_job(
project_id=self.project_id,
location=self.location,
- job_id=job_id,
+ job_id=self.job_id,
)
if job.state in self.reattach_states:
# We are reattaching to a job
@@ -2723,7 +2793,7 @@ class BigQueryInsertJobOperator(GoogleCloudBaseOperator):
else:
# Same job configuration so we need force_rerun
raise AirflowException(
- f"Job with id: {job_id} already exists and is in {job.state} state. If you "
+ f"Job with id: {self.job_id} already exists and is in {job.state} state. If you "
f"want to force rerun it consider setting `force_rerun=True`."
f"Or, if you want to reattach in this scenario add {job.state} to `reattach_states`"
)
@@ -2757,7 +2827,9 @@ class BigQueryInsertJobOperator(GoogleCloudBaseOperator):
self.job_id = job.job_id
project_id = self.project_id or self.hook.project_id
if project_id:
- job_id_path = convert_job_id(job_id=job_id, project_id=project_id, location=self.location)
+ job_id_path = convert_job_id(
+ job_id=self.job_id, project_id=project_id, location=self.location # type: ignore[arg-type]
+ )
context["ti"].xcom_push(key="job_id_path", value=job_id_path)
# Wait for the job to complete
if not self.deferrable:
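Taken together, the operator-side changes above give BigQueryInsertJobOperator the method the OpenLineage extractor looks for. As an illustration (not part of this commit) of what a caller gets back once the operator has run — the facet keys mirror the assertions in the new tests at the end of this patch:

    # `op` is a BigQueryInsertJobOperator that has already executed.
    lineage = op.get_openlineage_facets_on_complete(task_instance)

    [ds.name for ds in lineage.inputs]  # e.g. ["airflow-openlineage.new_dataset.test_table"]
    lineage.run_facets.keys()           # e.g. dict_keys(["bigQuery_job", "externalQuery"])
    lineage.job_facets["sql"]           # SqlJobFacet(query="SELECT * FROM test_table")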
diff --git a/airflow/providers/openlineage/extractors/base.py b/airflow/providers/openlineage/extractors/base.py
index 95d8fa6f28..0926489c0d 100644
--- a/airflow/providers/openlineage/extractors/base.py
+++ b/airflow/providers/openlineage/extractors/base.py
@@ -86,6 +86,12 @@ class DefaultExtractor(BaseExtractor):
# OpenLineage methods are optional - if there's no method, return None
try:
return self._get_openlineage_facets(self.operator.get_openlineage_facets_on_start) # type: ignore
+ except ImportError:
+ self.log.error(
+ "OpenLineage provider method failed to import OpenLineage integration. "
+ "This should not happen. Please report this bug to developers."
+ )
+ return None
except AttributeError:
return None
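The surrounding try/except implements the extractor's optional-method protocol: operators that expose lineage define get_openlineage_facets_on_start, all others raise AttributeError and are silently skipped, and the new branch downgrades a broken OpenLineage import from a task failure to a logged error. A simplified sketch of that dispatch (restructured for clarity, not the committed code):

    try:
        method = self.operator.get_openlineage_facets_on_start  # optional on operators
    except AttributeError:
        return None  # operator has no OpenLineage support
    try:
        return self._get_openlineage_facets(method)
    except ImportError:
        self.log.error("OpenLineage integration failed to import.")  # new behavior in this commit
        return None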
diff --git a/airflow/providers/openlineage/utils/utils.py b/airflow/providers/openlineage/utils/utils.py
index ca8b559e3a..20b9afef49 100644
--- a/airflow/providers/openlineage/utils/utils.py
+++ b/airflow/providers/openlineage/utils/utils.py
@@ -23,7 +23,7 @@ import logging
import os
from contextlib import suppress
from functools import wraps
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Iterable
from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse
import attrs
@@ -414,3 +414,10 @@ def is_source_enabled() -> bool:
def get_filtered_unknown_operator_keys(operator: BaseOperator) -> dict:
not_required_keys = {"dag", "task_group"}
return {attr: value for attr, value in operator.__dict__.items() if attr not in not_required_keys}
+
+
+def normalize_sql(sql: str | Iterable[str]):
+ if isinstance(sql, str):
+ sql = [stmt for stmt in sql.split(";") if stmt != ""]
+ sql = [obj for stmt in sql for obj in stmt.split(";") if obj != ""]
+ return ";\n".join(sql)
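For reference, the behavior that follows from the body above — statements are split on ";", empty fragments are dropped, and the result is rejoined one statement per line:

    normalize_sql("SELECT 1; SELECT 2;")
    # -> "SELECT 1;\n SELECT 2"

    normalize_sql(["SELECT 1; SELECT 2", "SELECT 3"])
    # -> "SELECT 1;\n SELECT 2;\nSELECT 3"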
diff --git a/tests/providers/google/cloud/operators/job_details.json b/tests/providers/google/cloud/operators/job_details.json
new file mode 100644
index 0000000000..f12ec1321d
--- /dev/null
+++ b/tests/providers/google/cloud/operators/job_details.json
@@ -0,0 +1,240 @@
+{
+ "kind": "bigquery#job",
+ "etag": "vd2aBaVVX6a4bUJW13+Tqg==",
+ "id": "airflow:US.job_IDnbVW6NACdFDkermznYm9o4mcVH",
+ "selfLink": "https://bigquery.googleapis.com/bigquery/v2/projects/airflow-openlineage/jobs/job_IDnbVW6NACdFDkermznYm9o4mcVH?location=US",
+ "user_email": "svc-account@airflow-openlineage.iam.gserviceaccount.com",
+ "configuration": {
+ "query": {
+ "query": "Select * from test_table",
+ "destinationTable": {
+ "projectId": "airflow-openlineage",
+ "datasetId": "new_dataset",
+ "tableId": "output_table"
+ },
+ "createDisposition": "CREATE_IF_NEEDED",
+ "writeDisposition": "WRITE_TRUNCATE",
+ "priority": "INTERACTIVE",
+ "allowLargeResults": false,
+ "useLegacySql": false
+ },
+ "jobType": "QUERY"
+ },
+ "jobReference": {
+ "projectId": "airflow-openlineage",
+ "jobId": "job_IDnbVW6NACdFDkermznYm9o4mcVH",
+ "location": "US"
+ },
+ "statistics": {
+ "creationTime": 1.60390893E12,
+ "startTime": 1.60390893E12,
+ "endTime": 1.60390893E12,
+ "totalBytesProcessed": "110355534",
+ "query": {
+ "queryPlan": [{
+ "name": "S00: Input",
+ "id": "0",
+ "startMs": "1603908925668",
+ "endMs": "1603908925880",
+ "waitRatioAvg": 0.0070422534,
+ "waitMsAvg": "2",
+ "waitRatioMax": 0.0070422534,
+ "waitMsMax": "2",
+ "readRatioAvg": 0.14084508,
+ "readMsAvg": "40",
+ "readRatioMax": 0.14084508,
+ "readMsMax": "40",
+ "computeRatioAvg": 1,
+ "computeMsAvg": "284",
+ "computeRatioMax": 1,
+ "computeMsMax": "284",
+ "writeRatioAvg": 0.017605634,
+ "writeMsAvg": "5",
+ "writeRatioMax": 0.017605634,
+ "writeMsMax": "5",
+ "shuffleOutputBytes": "439409",
+ "shuffleOutputBytesSpilled": "0",
+ "recordsRead": "5552452",
+ "recordsWritten": "16142",
+ "parallelInputs": "1",
+ "completedParallelInputs": "1",
+ "status": "COMPLETE",
+ "steps": [{
+ "kind": "READ",
+ "substeps": [
+ "$1:state, $2:name, $3:number",
+ "FROM bigquery-public-data.usa_names.usa_1910_2013",
+ "WHERE equal($1, 'TX')"
+ ]
+ },
+ {
+ "kind": "AGGREGATE",
+ "substeps": [
+ "GROUP BY $30 := $2, $31 := $1",
+ "$20 := SUM($3)"
+ ]
+ },
+ {
+ "kind": "WRITE",
+ "substeps": [
+ "$31, $30, $20",
+ "TO __stage00_output",
+ "BY HASH($30, $31)"
+ ]
+ }
+ ],
+ "slotMs": "448"
+ },
+ {
+ "name": "S01: Sort+",
+ "id": "1",
+ "startMs": "1603908925891",
+ "endMs": "1603908925911",
+ "inputStages": [
+ "0"
+ ],
+ "waitRatioAvg": 0.0070422534,
+ "waitMsAvg": "2",
+ "waitRatioMax": 0.0070422534,
+ "waitMsMax": "2",
+ "readRatioAvg": 0,
+ "readMsAvg": "0",
+ "readRatioMax": 0,
+ "readMsMax": "0",
+ "computeRatioAvg": 0.049295776,
+ "computeMsAvg": "14",
+ "computeRatioMax": 0.049295776,
+ "computeMsMax": "14",
+ "writeRatioAvg": 0.0070422534,
+ "writeMsAvg": "2",
+ "writeRatioMax": 0.0070422534,
+ "writeMsMax": "2",
+ "shuffleOutputBytes": "401",
+ "shuffleOutputBytesSpilled": "0",
+ "recordsRead": "16142",
+ "recordsWritten": "20",
+ "parallelInputs": "1",
+ "completedParallelInputs": "1",
+ "status": "COMPLETE",
+ "steps": [{
+ "kind": "READ",
+ "substeps": [
+ "$31, $30, $20",
+ "FROM __stage00_output"
+ ]
+ },
+ {
+ "kind": "SORT",
+ "substeps": [
+ "$10 DESC",
+ "LIMIT 20"
+ ]
+ },
+ {
+ "kind": "AGGREGATE",
+ "substeps": [
+ "GROUP BY $40 := $30, $41 := $31",
+ "$10 := SUM($20)"
+ ]
+ },
+ {
+ "kind": "WRITE",
+ "substeps": [
+ "$50, $51",
+ "TO __stage01_output"
+ ]
+ }
+ ],
+ "slotMs": "33"
+ },
+ {
+ "name": "S02: Output",
+ "id": "2",
+ "startMs": "1603908926017",
+ "endMs": "1603908926191",
+ "inputStages": [
+ "1"
+ ],
+ "waitRatioAvg": 0.4471831,
+ "waitMsAvg": "127",
+ "waitRatioMax": 0.4471831,
+ "waitMsMax": "127",
+ "readRatioAvg": 0,
+ "readMsAvg": "0",
+ "readRatioMax": 0,
+ "readMsMax": "0",
+ "computeRatioAvg": 0.03169014,
+ "computeMsAvg": "9",
+ "computeRatioMax": 0.03169014,
+ "computeMsMax": "9",
+ "writeRatioAvg": 0.5633803,
+ "writeMsAvg": "160",
+ "writeRatioMax": 0.5633803,
+ "writeMsMax": "160",
+ "shuffleOutputBytes": "321",
+ "shuffleOutputBytesSpilled": "0",
+ "recordsRead": "20",
+ "recordsWritten": "20",
+ "parallelInputs": "1",
+ "completedParallelInputs": "1",
+ "status": "COMPLETE",
+ "steps": [{
+ "kind": "READ",
+ "substeps": [
+ "$50, $51",
+ "FROM __stage01_output"
+ ]
+ },
+ {
+ "kind": "SORT",
+ "substeps": [
+ "$51 DESC",
+ "LIMIT 20"
+ ]
+ },
+ {
+ "kind": "WRITE",
+ "substeps": [
+ "$60, $61",
+ "TO __stage02_output"
+ ]
+ }
+ ],
+ "slotMs": "342"
+ }
+ ],
+ "estimatedBytesProcessed": "110355534",
+ "timeline": [{
+ "elapsedMs": "736",
+ "totalSlotMs": "482",
+ "pendingUnits": "1",
+ "completedUnits": "2",
+ "activeUnits": "1"
+ },
+ {
+ "elapsedMs": "1045",
+ "totalSlotMs": "825",
+ "pendingUnits": "0",
+ "completedUnits": "3",
+ "activeUnits": "1"
+ }
+ ],
+ "totalPartitionsProcessed": "0",
+ "totalBytesProcessed": "110355534",
+ "totalBytesBilled": "111149056",
+ "billingTier": 1,
+ "totalSlotMs": "825",
+ "cacheHit": false,
+ "referencedTables": [{
+ "projectId": "airflow-openlineage",
+ "datasetId": "new_dataset",
+ "tableId": "test_table"
+ }],
+ "statementType": "SELECT"
+ },
+ "totalSlotMs": "825"
+ },
+ "status": {
+ "state": "DONE"
+ }
+}
diff --git a/tests/providers/google/cloud/operators/test_bigquery.py b/tests/providers/google/cloud/operators/test_bigquery.py
index 4026b4ba45..d17f5498e2 100644
--- a/tests/providers/google/cloud/operators/test_bigquery.py
+++ b/tests/providers/google/cloud/operators/test_bigquery.py
@@ -17,6 +17,8 @@
# under the License.
from __future__ import annotations
+import json
+from contextlib import suppress
from unittest import mock
from unittest.mock import ANY, MagicMock
@@ -24,6 +26,13 @@ import pandas as pd
import pytest
from google.cloud.bigquery import DEFAULT_RETRY
from google.cloud.exceptions import Conflict
+from openlineage.client.facet import (
+ DataSourceDatasetFacet,
+ ExternalQueryRunFacet,
+ SqlJobFacet,
+)
+from openlineage.client.run import Dataset
+from openlineage.common.provider.bigquery import BigQueryErrorRunFacet
from airflow.exceptions import AirflowException, AirflowSkipException, AirflowTaskTimeout, TaskDeferred
from airflow.providers.google.cloud.operators.bigquery import (
@@ -1520,6 +1529,88 @@ class TestBigQueryInsertJobOperator:
force_rerun=True,
)
+ @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
+ def test_execute_openlineage_events(self, mock_hook):
+ job_id = "123456"
+ hash_ = "hash"
+ real_job_id = f"{job_id}_{hash_}"
+
+ configuration = {
+ "query": {
+ "query": "SELECT * FROM test_table",
+ "useLegacySql": False,
+ }
+ }
+ mock_hook.return_value.insert_job.return_value = MagicMock(job_id=real_job_id, error_result=False)
+ mock_hook.return_value.generate_job_id.return_value = real_job_id
+
+ op = BigQueryInsertJobOperator(
+ task_id="insert_query_job",
+ configuration=configuration,
+ location=TEST_DATASET_LOCATION,
+ job_id=job_id,
+ project_id=TEST_GCP_PROJECT_ID,
+ )
+ result = op.execute(context=MagicMock())
+
+ mock_hook.return_value.insert_job.assert_called_once_with(
+ configuration=configuration,
+ location=TEST_DATASET_LOCATION,
+ job_id=real_job_id,
+ nowait=True,
+ project_id=TEST_GCP_PROJECT_ID,
+ retry=DEFAULT_RETRY,
+ timeout=None,
+ )
+
+ assert result == real_job_id
+
+ with open(file="tests/providers/google/cloud/operators/job_details.json") as f:
+ job_details = json.loads(f.read())
+ mock_hook.return_value.get_client.return_value.get_job.return_value._properties = job_details
+
+ lineage = op.get_openlineage_facets_on_complete(None)
+ assert lineage.inputs == [
+ Dataset(
+ namespace="bigquery",
+ name="airflow-openlineage.new_dataset.test_table",
+ facets={"dataSource": DataSourceDatasetFacet(name="bigquery", uri="bigquery")},
+ )
+ ]
+
+ assert lineage.run_facets == {
+ "bigQuery_job": mock.ANY,
+ "externalQuery": ExternalQueryRunFacet(externalQueryId=mock.ANY, source="bigquery"),
+ }
+ assert lineage.job_facets == {"sql": SqlJobFacet(query="SELECT * FROM test_table")}
+
+ @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
+ def test_execute_fails_openlineage_events(self, mock_hook):
+ job_id = "1234"
+
+ configuration = {
+ "query": {
+ "query": "SELECT * FROM test_table",
+ "useLegacySql": False,
+ }
+ }
+ operator = BigQueryInsertJobOperator(
+ task_id="insert_query_job_failed",
+ configuration=configuration,
+ location=TEST_DATASET_LOCATION,
+ job_id=job_id,
+ project_id=TEST_GCP_PROJECT_ID,
+ )
+ mock_hook.return_value.generate_job_id.return_value = "1234"
+ mock_hook.return_value.get_client.return_value.get_job.side_effect = RuntimeError()
+ mock_hook.return_value.insert_job.side_effect = RuntimeError()
+
+ with suppress(RuntimeError):
+ operator.execute(MagicMock())
+ lineage = operator.get_openlineage_facets_on_complete(None)
+
+ assert lineage.run_facets == {"bigQuery_error": BigQueryErrorRunFacet(clientError=mock.ANY)}
+
@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
def test_execute_force_rerun_async(self, mock_hook, create_task_instance_of_operator):
job_id = "123456"