You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by mo...@apache.org on 2023/08/23 09:43:00 UTC
[airflow] 01/01: system tests: implement operator, variable transport
This is an automated email from the ASF dual-hosted git repository.
mobuchowski pushed a commit to branch openlineage-system-tests
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 29143bd372ecad00cd474ebd36f03e4ddefd774d
Author: Maciej Obuchowski <ob...@gmail.com>
AuthorDate: Thu Jun 1 15:12:53 2023 +0200
system tests: implement operator, variable transport
Signed-off-by: Maciej Obuchowski <ob...@gmail.com>
---
.../providers/google/cloud/operators/bigquery.py | 18 +++
airflow/providers/openlineage/plugins/listener.py | 11 +-
airflow/providers/openlineage/provider.yaml | 4 +-
.../providers/openlineage/transport/__init__.py | 16 ++
.../providers/openlineage/transport/variable.py | 51 ++++++
.../apache-airflow-providers-openlineage/index.rst | 4 +-
generated/provider_dependencies.json | 4 +-
tests/system/conftest.py | 14 ++
.../cloud/bigquery/example_bigquery_queries.py | 13 +-
tests/test_utils/openlineage.py | 173 +++++++++++++++++++++
10 files changed, 298 insertions(+), 10 deletions(-)
diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py
index 75ae73b6f9..1222bb1ab8 100644
--- a/airflow/providers/google/cloud/operators/bigquery.py
+++ b/airflow/providers/google/cloud/operators/bigquery.py
@@ -1074,6 +1074,24 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
self.log.info("Total extracted rows: %s", len(event["records"]))
return event["records"]
def get_openlineage_facets_on_start(self):
    """Report the table this operator reads as OpenLineage input, before execution.

    The table is emitted as a single input dataset in the ``bigquery`` namespace, named
    ``<project_id>.<dataset_id>.<table_id>``. When ``project_id`` was not given, it is resolved
    via ``BigQueryHook(...).project_id`` and stored back on the operator.

    :return: ``OperatorLineage`` with one input dataset.
    """
    from openlineage.client.run import Dataset

    from airflow.providers.openlineage.extractors import OperatorLineage

    if self.project_id is None:
        hook = BigQueryHook(
            gcp_conn_id=self.gcp_conn_id,
            impersonation_chain=self.impersonation_chain,
            use_legacy_sql=self.use_legacy_sql,
        )
        # Cached on the operator, so later reads of self.project_id see the resolved value.
        self.project_id = hook.project_id

    table_name = f"{self.project_id}.{self.dataset_id}.{self.table_id}"
    source_table = Dataset(namespace="bigquery", name=table_name)
    return OperatorLineage(inputs=[source_table])
+
class BigQueryExecuteQueryOperator(GoogleCloudBaseOperator):
"""Executes BigQuery SQL queries in a specific BigQuery database.
diff --git a/airflow/providers/openlineage/plugins/listener.py b/airflow/providers/openlineage/plugins/listener.py
index 4a6b75f677..38156f2d02 100644
--- a/airflow/providers/openlineage/plugins/listener.py
+++ b/airflow/providers/openlineage/plugins/listener.py
@@ -46,6 +46,7 @@ class OpenLineageListener:
self.log = logging.getLogger(__name__)
self.extractor_manager = ExtractorManager()
self.adapter = OpenLineageAdapter()
+ self.current_ti: TaskInstance | None = None
@hookimpl
def on_task_instance_running(
@@ -59,6 +60,7 @@ class OpenLineageListener:
return
self.log.debug("OpenLineage listener got notification about task instance start")
+ self.current_ti = task_instance
dagrun = task_instance.dag_run
task = task_instance.task
dag = task.dag
@@ -101,12 +103,13 @@ class OpenLineageListener:
**get_airflow_run_facet(dagrun, dag, task_instance, task, task_uuid),
},
)
-
on_running()
+
@hookimpl
def on_task_instance_success(self, previous_state, task_instance: TaskInstance, session):
self.log.debug("OpenLineage listener got notification about task instance success")
+ self.current_ti = task_instance
dagrun = task_instance.dag_run
task = task_instance.task
@@ -135,6 +138,7 @@ class OpenLineageListener:
@hookimpl
def on_task_instance_failed(self, previous_state, task_instance: TaskInstance, session):
self.log.debug("OpenLineage listener got notification about task instance failure")
+ self.current_ti = task_instance
dagrun = task_instance.dag_run
task = task_instance.task
@@ -174,8 +178,9 @@ class OpenLineageListener:
def before_stopping(self, component):
self.log.debug("before_stopping: %s", component.__class__.__name__)
# TODO: configure this with Airflow config
- with timeout(30):
- self.executor.shutdown(wait=True)
+ if self._executor:
+ with timeout(30):
+ self.executor.shutdown(wait=True)
@hookimpl
def on_dag_run_running(self, dag_run: DagRun, msg: str):
diff --git a/airflow/providers/openlineage/provider.yaml b/airflow/providers/openlineage/provider.yaml
index 0839428e44..2512b33249 100644
--- a/airflow/providers/openlineage/provider.yaml
+++ b/airflow/providers/openlineage/provider.yaml
@@ -30,8 +30,8 @@ dependencies:
- apache-airflow>=2.7.0
- apache-airflow-providers-common-sql>=1.6.0
- attrs>=22.2
- - openlineage-integration-common>=0.28.0
- - openlineage-python>=0.28.0
+ - openlineage-integration-common>=0.29.2
+ - openlineage-python>=0.29.2
integrations:
- integration-name: OpenLineage
diff --git a/airflow/providers/openlineage/transport/__init__.py b/airflow/providers/openlineage/transport/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/airflow/providers/openlineage/transport/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/openlineage/transport/variable.py b/airflow/providers/openlineage/transport/variable.py
new file mode 100644
index 0000000000..fc9dd4a630
--- /dev/null
+++ b/airflow/providers/openlineage/transport/variable.py
@@ -0,0 +1,51 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow.models import Variable
+from airflow.plugins_manager import AirflowPlugin, plugins
+from airflow.utils.log.logging_mixin import LoggingMixin
+from openlineage.client.run import DatasetEvent, JobEvent, RunEvent
+from openlineage.client.serde import Serde
+from openlineage.client.transport import Transport
+
+
class VariableTransport(Transport, LoggingMixin):
    """Send OpenLineage events to Airflow Variables.

    Each event is stored under the key ``<DAG_ID>.<TASK_ID>.event.<EVENT_TYPE>``, with the event
    type lower-cased. It is made to be used in system tests: the stored data is read back by
    ``OpenlineageTestOperator``.
    """

    def emit(self, event: RunEvent | DatasetEvent | JobEvent):
        """Serialize ``event`` to JSON and store it in a Variable keyed by the current task.

        The task is taken from ``current_ti`` of the OpenLineage plugin's first listener.

        :param event: OpenLineage event to store. Its ``eventType`` is used in the key.
        :raises RuntimeError: if the OpenLineage plugin or its listener is not registered, or the
            listener has not recorded a task instance yet.
        """
        from airflow import plugins_manager
        from airflow.providers.openlineage.plugins.openlineage import OpenLineageProviderPlugin

        # Read ``plugins_manager.plugins`` at call time. The module global is rebound when plugins
        # are (re)loaded, so a name imported at module import time can be stale or still ``None``.
        plugin: AirflowPlugin | None = next(
            (p for p in plugins_manager.plugins or [] if isinstance(p, OpenLineageProviderPlugin)),
            None,
        )
        if plugin is None or not plugin.listeners:
            raise RuntimeError("OpenLineage listener should be set up here")

        listener = plugin.listeners[0]
        ti = listener.current_ti  # type: ignore
        if ti is None:
            raise RuntimeError("OpenLineage listener has not recorded any task instance yet")

        key = f"{ti.dag_id}.{ti.task_id}.event.{event.eventType.value.lower()}"  # type: ignore[union-attr]
        Variable.set(key=key, value=Serde.to_json(event))
diff --git a/docs/apache-airflow-providers-openlineage/index.rst b/docs/apache-airflow-providers-openlineage/index.rst
index 5d6c006d4e..a4a52f46db 100644
--- a/docs/apache-airflow-providers-openlineage/index.rst
+++ b/docs/apache-airflow-providers-openlineage/index.rst
@@ -116,8 +116,8 @@ PIP package Version required
``apache-airflow`` ``>=2.7.0``
``apache-airflow-providers-common-sql`` ``>=1.6.0``
``attrs`` ``>=22.2``
-``openlineage-integration-common`` ``>=0.28.0``
-``openlineage-python`` ``>=0.28.0``
+``openlineage-integration-common`` ``>=0.29.2``
+``openlineage-python`` ``>=0.29.2``
======================================= ==================
Cross provider package dependencies
diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index 0d5a7408ea..950507e3f9 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -662,8 +662,8 @@
"apache-airflow-providers-common-sql>=1.6.0",
"apache-airflow>=2.7.0",
"attrs>=22.2",
- "openlineage-integration-common>=0.28.0",
- "openlineage-python>=0.28.0"
+ "openlineage-integration-common>=0.29.2",
+ "openlineage-python>=0.29.2"
],
"cross-providers-deps": [
"common.sql"
diff --git a/tests/system/conftest.py b/tests/system/conftest.py
index 58eca1287c..21bc381c28 100644
--- a/tests/system/conftest.py
+++ b/tests/system/conftest.py
@@ -24,9 +24,23 @@ from unittest import mock
import pytest
+from airflow import plugins_manager
+from airflow.providers.openlineage.plugins.openlineage import OpenLineageProviderPlugin
+
REQUIRED_ENV_VARS = ("SYSTEM_TESTS_ENV_ID",)
@pytest.fixture(scope="package", autouse=True)
def setup_openlineage():
    """Route OpenLineage events to Airflow Variables during system tests.

    While the fixture is active, ``AIRFLOW__OPENLINEAGE__TRANSPORT`` points at ``VariableTransport``.
    The OpenLineage provider plugin is registered explicitly inside that patched environment.
    The registration itself is not undone when the fixture finishes.
    """
    transport_config = '{"type": "airflow.providers.openlineage.transport.variable.VariableTransport"}'
    with mock.patch.dict("os.environ", {"AIRFLOW__OPENLINEAGE__TRANSPORT": transport_config}):
        plugins_manager.register_plugin(OpenLineageProviderPlugin())
        yield
+
+
@pytest.fixture(scope="package", autouse=True)
def use_debug_executor():
with mock.patch.dict("os.environ", AIRFLOW__CORE__EXECUTOR="DebugExecutor"):
diff --git a/tests/system/providers/google/cloud/bigquery/example_bigquery_queries.py b/tests/system/providers/google/cloud/bigquery/example_bigquery_queries.py
index 3ce1bc2801..57d0ace625 100644
--- a/tests/system/providers/google/cloud/bigquery/example_bigquery_queries.py
+++ b/tests/system/providers/google/cloud/bigquery/example_bigquery_queries.py
@@ -38,6 +38,7 @@ from airflow.providers.google.cloud.operators.bigquery import (
BigQueryValueCheckOperator,
)
from airflow.utils.trigger_rule import TriggerRule
+from tests.test_utils.openlineage import OpenlineageTestOperator
ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT")
@@ -235,11 +236,21 @@ for index, location in enumerate(locations, 1):
trigger_rule=TriggerRule.ALL_DONE,
)
+ openlineage_test = OpenlineageTestOperator(
+ task_id="openlineage_test",
+ event_templates={
+ f"{DAG_ID}.get_data.event.start": {
+ "eventType": "START",
+ "inputs": [{"namespace": "bigquery", "name": f"{PROJECT_ID}.{DATASET}.{TABLE_1}"}],
+ }
+ },
+ )
+
# TEST SETUP
create_dataset >> [create_table_1, create_table_2]
# TEST BODY
[create_table_1, create_table_2] >> insert_query_job >> [select_query_job, execute_insert_query]
- execute_insert_query >> get_data >> get_data_result >> delete_dataset
+ execute_insert_query >> get_data >> get_data_result >> delete_dataset >> openlineage_test
execute_insert_query >> execute_query_save >> bigquery_execute_multi_query >> delete_dataset
execute_insert_query >> [check_count, check_value, check_interval] >> delete_dataset
execute_insert_query >> [column_check, table_check] >> delete_dataset
diff --git a/tests/test_utils/openlineage.py b/tests/test_utils/openlineage.py
new file mode 100644
index 0000000000..5dc7e7c45f
--- /dev/null
+++ b/tests/test_utils/openlineage.py
@@ -0,0 +1,173 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import json
+import logging
+import os
+from urllib.parse import urlparse
+
+from jinja2 import Environment
+
+from airflow.models.baseoperator import BaseOperator
+from airflow.models.variable import Variable
+from airflow.utils.context import Context
+
# Module-level logger used by the matching helpers to report mismatches.
log = logging.getLogger(name=__name__)
+
+
def any(result):
    """Wildcard Jinja helper: return ``result`` unchanged.

    It is exposed to templates under the name ``any``, which is why it shadows the builtin.
    ``match`` accepts a rendered template equal to the received value, so ``{{ any(result) }}``
    acts as "any value is fine".
    """
    received = result
    return received
+
+
def is_datetime(result):
    """Jinja helper: report whether ``result`` is an ISO-8601 date/datetime string.

    :param result: value received in the event.
    :return: ``"true"`` if ``result`` parses as an ISO-8601 date/datetime, otherwise ``"false"``.
        Strings are returned because ``match`` compares the rendered template to ``"true"``.
    """
    from datetime import datetime

    if not isinstance(result, str):
        return "false"
    # ``datetime.fromisoformat`` only understands a trailing "Z" (UTC) from Python 3.11 on.
    candidate = result[:-1] + "+00:00" if result.endswith("Z") else result
    try:
        datetime.fromisoformat(candidate)
    except ValueError:
        return "false"
    return "true"
+
+
def is_uuid(result):
    """Jinja helper: report whether ``result`` is a valid UUID string.

    :param result: value received in the event.
    :return: ``"true"`` if ``uuid.UUID(result)`` accepts it, otherwise ``"false"``.
        Strings are returned because ``match`` compares the rendered template to ``"true"``.
    """
    import uuid

    try:
        uuid.UUID(result)
    except (AttributeError, TypeError, ValueError):
        # AttributeError: non-str input such as an int; TypeError: None; ValueError: malformed hex.
        return "false"
    return "true"
+
+
def env_var(var: str, default: str | None = None) -> str:
    """Jinja helper: look up the environment variable ``var``.

    :param var: name of the environment variable.
    :param default: value returned when ``var`` is not set.
    :return: the variable's value, or ``default`` when it is unset.
    :raises Exception: when ``var`` is unset and no ``default`` is given.
    """
    # Environment values are never None, so None here means "unset and no default".
    value = os.environ.get(var, default)
    if value is None:
        raise Exception(f"Env var required but not provided: '{var}'")
    return value
+
+
def not_match(result, pattern) -> str:
    """Jinja helper: assert that ``pattern`` does not occur in ``result``.

    :return: ``"true"`` when ``pattern`` is absent.
    :raises Exception: when ``pattern`` is found in ``result``.
    """
    if pattern not in result:
        return "true"
    raise Exception(f"Found {pattern} in {result}")
+
+
def url_scheme_authority(url) -> str:
    """Jinja filter: reduce ``url`` to ``<scheme>://<netloc>`` (path, query and fragment dropped)."""
    scheme, netloc, *_unused = urlparse(url)
    return f"{scheme}://{netloc}"
+
+
def url_path(url) -> str:
    """Jinja filter: return only the path component of ``url``."""
    parsed = urlparse(url)
    path = parsed.path
    return path
+
+
def setup_jinja() -> Environment:
    """Build the Jinja environment used to render expectation templates in ``match``.

    Registers the helper functions as template globals and the URL helpers as filters.
    """
    jinja_env = Environment()
    jinja_env.globals.update(
        {
            "any": any,
            "is_datetime": is_datetime,
            "is_uuid": is_uuid,
            "env_var": env_var,
            "not_match": not_match,
        }
    )
    jinja_env.filters.update(
        {
            "url_scheme_authority": url_scheme_authority,
            "url_path": url_path,
        }
    )
    return jinja_env


# Shared environment used by ``match`` to render templated expectations.
env = setup_jinja()
+
+
def match(expected, result) -> bool:
    """
    Check if result is "equal" to expected value.

    Keys not specified in the expected value are ignored, and Jinja templates are resolved.
    Matching depends on the type of ``expected``:

    * ``dict``: ``result`` must also be a ``dict``, and every key of ``expected`` must be present
      in it with a matching value. Extra keys in ``result`` are ignored.
    * ``list``: ``result`` must be a list (or tuple) of the same length, matched element-wise.
    * ``str`` containing ``{{``: rendered with Jinja, with the received value bound to ``result``.
      It matches when the output is ``"true"`` or equals ``str(result)``.
    * anything else: compared with ``==``.

    Every mismatch is logged at error level before returning ``False``.

    :param expected: expected value or template fragment.
    :param result: received value, e.g. a decoded JSON event.
    :return: ``True`` if ``result`` matches ``expected``.
    """
    if isinstance(expected, dict):
        # Guard first: ``k in result`` on a list/str would test membership, and on None would raise.
        if not isinstance(result, dict):
            log.error("Expected a mapping %s, received %s", expected, result)
            return False
        # Take a look only at keys present at expected dictionary
        for k, v in expected.items():
            if k not in result:
                log.error("Key %s not in received event %s\nExpected %s", k, result, expected)
                return False
            if not match(v, result[k]):
                log.error(
                    "For key %s, expected value %s not equals received %s\nExpected: %s, request: %s",
                    k,
                    v,
                    result[k],
                    expected,
                    result,
                )
                return False
    elif isinstance(expected, list):
        if not isinstance(result, (list, tuple)):
            log.error("Expected a list %s, received %s", expected, result)
            return False
        if len(expected) != len(result):
            log.error("Length does not match: expected %d, result: %d", len(expected), len(result))
            return False
        for i, (x, y) in enumerate(zip(expected, result)):
            if not match(x, y):
                log.error(
                    "List not matched at %d\nexpected:\n%s\nresult: \n%s",
                    i,
                    json.dumps(x),
                    json.dumps(y),
                )
                return False
    elif isinstance(expected, str):
        if "{{" in expected:
            # Evaluate jinja: in some cases, we want to check only if key exists, or if
            # value has the right type
            rendered = env.from_string(expected).render(result=result)
            # Jinja always renders to a string, so compare against the string form of the received
            # value; otherwise e.g. ``{{ any(result) }}`` could never match a number or a bool.
            if rendered == "true" or rendered == str(result):
                return True
            log.error("Rendered value %s does not equal 'true' or %s", rendered, result)
            return False
        if expected != result:
            log.error("Expected value %s does not equal result %s", expected, result)
            return False
    elif expected != result:
        log.error("Object of type %s: %s does not match %s", type(expected), expected, result)
        return False
    return True
+
+
class OpenlineageTestOperator(BaseOperator):
    """Operator for testing purposes.

    It compares the expected event templates set on initialization with the events emitted by the
    OpenLineage integration and stored in Variables by ``VariableTransport``.

    :param event_templates: dictionary whose keys are the Variable keys used by
        ``VariableTransport`` (``<DAG_ID>.<TASK_ID>.event.<EVENT_TYPE>``). Each value is the event
        template (fragment) that needs to be present in the received event.
    :raises ValueError: if an expected event was not stored, or if a received event does not match
        its template.
    """

    def __init__(self, event_templates: dict[str, dict], **kwargs):
        super().__init__(**kwargs)
        self.event_templates = event_templates

    def execute(self, context: Context):
        for key, template in self.event_templates.items():
            send_event = Variable.get(key=key, default_var=None)
            self.log.info("Event stored under %s: %s", key, send_event)
            # A missing event must fail the test rather than letting it pass vacuously.
            if not send_event:
                raise ValueError(f"No OpenLineage event stored under Variable key {key!r}")
            if not match(template, json.loads(send_event)):
                raise ValueError("Event received does not match one specified in test")