Posted to commits@airflow.apache.org by mo...@apache.org on 2023/08/23 09:43:00 UTC

[airflow] 01/01: system tests: implement operator, variable transport

This is an automated email from the ASF dual-hosted git repository.

mobuchowski pushed a commit to branch openlineage-system-tests
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 29143bd372ecad00cd474ebd36f03e4ddefd774d
Author: Maciej Obuchowski <ob...@gmail.com>
AuthorDate: Thu Jun 1 15:12:53 2023 +0200

    system tests: implement operator, variable transport
    
    Signed-off-by: Maciej Obuchowski <ob...@gmail.com>
---
 .../providers/google/cloud/operators/bigquery.py   |  18 +++
 airflow/providers/openlineage/plugins/listener.py  |  11 +-
 airflow/providers/openlineage/provider.yaml        |   4 +-
 .../providers/openlineage/transport/__init__.py    |  16 ++
 .../providers/openlineage/transport/variable.py    |  51 ++++++
 .../apache-airflow-providers-openlineage/index.rst |   4 +-
 generated/provider_dependencies.json               |   4 +-
 tests/system/conftest.py                           |  14 ++
 .../cloud/bigquery/example_bigquery_queries.py     |  13 +-
 tests/test_utils/openlineage.py                    | 173 +++++++++++++++++++++
 10 files changed, 298 insertions(+), 10 deletions(-)
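
For orientation, here is a minimal sketch of how the transport added by this commit gets
selected. It simply mirrors the setup_openlineage fixture in tests/system/conftest.py below;
the environment variable name and JSON payload are taken from that fixture, not from
provider documentation:

    import os

    # Point the OpenLineage client at the VariableTransport introduced in this commit.
    # The listener will then store every emitted event in an Airflow Variable keyed as
    # <DAG_ID>.<TASK_ID>.event.<EVENT_TYPE>, where OpenlineageTestOperator can read it back.
    os.environ["AIRFLOW__OPENLINEAGE__TRANSPORT"] = (
        '{"type": "airflow.providers.openlineage.transport.variable.VariableTransport"}'
    )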

diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py
index 75ae73b6f9..1222bb1ab8 100644
--- a/airflow/providers/google/cloud/operators/bigquery.py
+++ b/airflow/providers/google/cloud/operators/bigquery.py
@@ -1074,6 +1074,24 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
         self.log.info("Total extracted rows: %s", len(event["records"]))
         return event["records"]
 
+    def get_openlineage_facets_on_start(self):
+        from openlineage.client.run import Dataset
+
+        from airflow.providers.openlineage.extractors import OperatorLineage
+
+        if self.project_id is None:
+            self.project_id = BigQueryHook(
+                gcp_conn_id=self.gcp_conn_id,
+                impersonation_chain=self.impersonation_chain,
+                use_legacy_sql=self.use_legacy_sql,
+            ).project_id
+
+        return OperatorLineage(
+            inputs=[
+                Dataset(namespace="bigquery", name=f"{self.project_id}.{self.dataset_id}.{self.table_id}")
+            ]
+        )
+
 
 class BigQueryExecuteQueryOperator(GoogleCloudBaseOperator):
     """Executes BigQuery SQL queries in a specific BigQuery database.
diff --git a/airflow/providers/openlineage/plugins/listener.py b/airflow/providers/openlineage/plugins/listener.py
index 4a6b75f677..38156f2d02 100644
--- a/airflow/providers/openlineage/plugins/listener.py
+++ b/airflow/providers/openlineage/plugins/listener.py
@@ -46,6 +46,7 @@ class OpenLineageListener:
         self.log = logging.getLogger(__name__)
         self.extractor_manager = ExtractorManager()
         self.adapter = OpenLineageAdapter()
+        self.current_ti: TaskInstance | None = None
 
     @hookimpl
     def on_task_instance_running(
@@ -59,6 +60,7 @@ class OpenLineageListener:
             return
 
         self.log.debug("OpenLineage listener got notification about task instance start")
+        self.current_ti = task_instance
         dagrun = task_instance.dag_run
         task = task_instance.task
         dag = task.dag
@@ -101,12 +103,13 @@ class OpenLineageListener:
                     **get_airflow_run_facet(dagrun, dag, task_instance, task, task_uuid),
                 },
             )
-
         on_running()
 
+
     @hookimpl
     def on_task_instance_success(self, previous_state, task_instance: TaskInstance, session):
         self.log.debug("OpenLineage listener got notification about task instance success")
+        self.current_ti = task_instance
 
         dagrun = task_instance.dag_run
         task = task_instance.task
@@ -135,6 +138,7 @@ class OpenLineageListener:
     @hookimpl
     def on_task_instance_failed(self, previous_state, task_instance: TaskInstance, session):
         self.log.debug("OpenLineage listener got notification about task instance failure")
+        self.current_ti = task_instance
 
         dagrun = task_instance.dag_run
         task = task_instance.task
@@ -174,8 +178,9 @@ class OpenLineageListener:
     def before_stopping(self, component):
         self.log.debug("before_stopping: %s", component.__class__.__name__)
         # TODO: configure this with Airflow config
-        with timeout(30):
-            self.executor.shutdown(wait=True)
+        if self._executor:
+            with timeout(30):
+                self.executor.shutdown(wait=True)
 
     @hookimpl
     def on_dag_run_running(self, dag_run: DagRun, msg: str):
diff --git a/airflow/providers/openlineage/provider.yaml b/airflow/providers/openlineage/provider.yaml
index 0839428e44..2512b33249 100644
--- a/airflow/providers/openlineage/provider.yaml
+++ b/airflow/providers/openlineage/provider.yaml
@@ -30,8 +30,8 @@ dependencies:
   - apache-airflow>=2.7.0
   - apache-airflow-providers-common-sql>=1.6.0
   - attrs>=22.2
-  - openlineage-integration-common>=0.28.0
-  - openlineage-python>=0.28.0
+  - openlineage-integration-common>=0.29.2
+  - openlineage-python>=0.29.2
 
 integrations:
   - integration-name: OpenLineage
diff --git a/airflow/providers/openlineage/transport/__init__.py b/airflow/providers/openlineage/transport/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/airflow/providers/openlineage/transport/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/openlineage/transport/variable.py b/airflow/providers/openlineage/transport/variable.py
new file mode 100644
index 0000000000..fc9dd4a630
--- /dev/null
+++ b/airflow/providers/openlineage/transport/variable.py
@@ -0,0 +1,51 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow.models import Variable
+from airflow.plugins_manager import AirflowPlugin, plugins
+from airflow.utils.log.logging_mixin import LoggingMixin
+from openlineage.client.run import DatasetEvent, JobEvent, RunEvent
+from openlineage.client.serde import Serde
+from openlineage.client.transport import Transport
+
+
+class VariableTransport(Transport, LoggingMixin):
+    """This transport sends OpenLineage events to Variables.
+    Key schema is <DAG_ID>.<TASK_ID>.event.<EVENT_TYPE>.
+    It's made to be used in system tests, stored data read by OpenLineageTestOperator.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def emit(self, event: RunEvent | DatasetEvent | JobEvent):
+        from airflow.providers.openlineage.plugins.openlineage import OpenLineageProviderPlugin
+
+        plugin: AirflowPlugin | None = next(  # type: ignore[assignment]
+            filter(lambda x: isinstance(x, OpenLineageProviderPlugin), plugins),  # type: ignore[arg-type]
+            None,
+        )
+        if not plugin:
+            raise RuntimeError("OpenLineage listener should be set up here")
+
+        listener = plugin.listeners[0]  # type: ignore
+        ti = listener.current_ti  # type: ignore
+
+        key = f"{ti.dag_id}.{ti.task_id}.event.{event.eventType.value.lower()}"  # type: ignore[union-attr]
+        str_event = Serde.to_json(event)
+        Variable.set(key=key, value=str_event)
diff --git a/docs/apache-airflow-providers-openlineage/index.rst b/docs/apache-airflow-providers-openlineage/index.rst
index 5d6c006d4e..a4a52f46db 100644
--- a/docs/apache-airflow-providers-openlineage/index.rst
+++ b/docs/apache-airflow-providers-openlineage/index.rst
@@ -116,8 +116,8 @@ PIP package                              Version required
 ``apache-airflow``                       ``>=2.7.0``
 ``apache-airflow-providers-common-sql``  ``>=1.6.0``
 ``attrs``                                ``>=22.2``
-``openlineage-integration-common``       ``>=0.28.0``
-``openlineage-python``                   ``>=0.28.0``
+``openlineage-integration-common``       ``>=0.29.2``
+``openlineage-python``                   ``>=0.29.2``
 =======================================  ==================
 
 Cross provider package dependencies
diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index 0d5a7408ea..950507e3f9 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -662,8 +662,8 @@
       "apache-airflow-providers-common-sql>=1.6.0",
       "apache-airflow>=2.7.0",
       "attrs>=22.2",
-      "openlineage-integration-common>=0.28.0",
-      "openlineage-python>=0.28.0"
+      "openlineage-integration-common>=0.29.2",
+      "openlineage-python>=0.29.2"
     ],
     "cross-providers-deps": [
       "common.sql"
diff --git a/tests/system/conftest.py b/tests/system/conftest.py
index 58eca1287c..21bc381c28 100644
--- a/tests/system/conftest.py
+++ b/tests/system/conftest.py
@@ -24,9 +24,23 @@ from unittest import mock
 
 import pytest
 
+from airflow import plugins_manager
+from airflow.providers.openlineage.plugins.openlineage import OpenLineageProviderPlugin
+
 REQUIRED_ENV_VARS = ("SYSTEM_TESTS_ENV_ID",)
 
 
+@pytest.fixture(scope="package", autouse=True)
+def setup_openlineage():
+    with mock.patch.dict(
+        "os.environ",
+        AIRFLOW__OPENLINEAGE__TRANSPORT='{"type": "airflow.providers.openlineage.transport.variable'
+        '.VariableTransport"}',
+    ):
+        plugins_manager.register_plugin(OpenLineageProviderPlugin())
+        yield
+
+
 @pytest.fixture(scope="package", autouse=True)
 def use_debug_executor():
     with mock.patch.dict("os.environ", AIRFLOW__CORE__EXECUTOR="DebugExecutor"):
diff --git a/tests/system/providers/google/cloud/bigquery/example_bigquery_queries.py b/tests/system/providers/google/cloud/bigquery/example_bigquery_queries.py
index 3ce1bc2801..57d0ace625 100644
--- a/tests/system/providers/google/cloud/bigquery/example_bigquery_queries.py
+++ b/tests/system/providers/google/cloud/bigquery/example_bigquery_queries.py
@@ -38,6 +38,7 @@ from airflow.providers.google.cloud.operators.bigquery import (
     BigQueryValueCheckOperator,
 )
 from airflow.utils.trigger_rule import TriggerRule
+from tests.test_utils.openlineage import OpenlineageTestOperator
 
 ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
 PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT")
@@ -235,11 +236,21 @@ for index, location in enumerate(locations, 1):
             trigger_rule=TriggerRule.ALL_DONE,
         )
 
+        openlineage_test = OpenlineageTestOperator(
+            task_id="openlineage_test",
+            event_templates={
+                f"{DAG_ID}.get_data.event.start": {
+                    "eventType": "START",
+                    "inputs": [{"namespace": "bigquery", "name": f"{PROJECT_ID}.{DATASET}.{TABLE_1}"}],
+                }
+            },
+        )
+
         # TEST SETUP
         create_dataset >> [create_table_1, create_table_2]
         # TEST BODY
         [create_table_1, create_table_2] >> insert_query_job >> [select_query_job, execute_insert_query]
-        execute_insert_query >> get_data >> get_data_result >> delete_dataset
+        execute_insert_query >> get_data >> get_data_result >> delete_dataset >> openlineage_test
         execute_insert_query >> execute_query_save >> bigquery_execute_multi_query >> delete_dataset
         execute_insert_query >> [check_count, check_value, check_interval] >> delete_dataset
         execute_insert_query >> [column_check, table_check] >> delete_dataset
diff --git a/tests/test_utils/openlineage.py b/tests/test_utils/openlineage.py
new file mode 100644
index 0000000000..5dc7e7c45f
--- /dev/null
+++ b/tests/test_utils/openlineage.py
@@ -0,0 +1,173 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import json
+import logging
+import os
+import uuid
+from urllib.parse import urlparse
+
+from dateutil.parser import parse
+from jinja2 import Environment
+
+from airflow.models.baseoperator import BaseOperator
+from airflow.models.variable import Variable
+from airflow.utils.context import Context
+
+log = logging.getLogger(__name__)
+
+
+def any(result):
+    # Intentionally shadows the builtin; registered as the jinja global "any" in setup_jinja()
+    # so a template like {{ any(result) }} accepts whatever value was received.
+    return result
+
+
+def is_datetime(result):
+    try:
+        parse(result)
+        return "true"
+    except Exception:
+        pass
+    return "false"
+
+
+def is_uuid(result):
+    try:
+        uuid.UUID(result)
+        return "true"
+    except Exception:
+        pass
+    return "false"
+
+
+def env_var(var: str, default: str | None = None) -> str:
+    """The env_var() function. Return the environment variable named 'var'.
+    If there is no such environment variable set, return the default.
+    If the default is None, raise an exception for an undefined variable.
+    """
+    if var in os.environ:
+        return os.environ[var]
+    elif default is not None:
+        return default
+    else:
+        msg = f"Env var required but not provided: '{var}'"
+        raise Exception(msg)
+
+
+def not_match(result, pattern) -> str:
+    if pattern in result:
+        raise Exception(f"Found {pattern} in {result}")
+    return "true"
+
+
+def url_scheme_authority(url) -> str:
+    parsed = urlparse(url)
+    return f"{parsed.scheme}://{parsed.netloc}"
+
+
+def url_path(url) -> str:
+    return urlparse(url).path
+
+
+def setup_jinja() -> Environment:
+    env = Environment()
+    env.globals["any"] = any
+    env.globals["is_datetime"] = is_datetime
+    env.globals["is_uuid"] = is_uuid
+    env.globals["env_var"] = env_var
+    env.globals["not_match"] = not_match
+    env.filters["url_scheme_authority"] = url_scheme_authority
+    env.filters["url_path"] = url_path
+    return env
+
+
+env = setup_jinja()
+
+
+def match(expected, result) -> bool:
+    """
+    Check if result is "equal" to expected value. Omits keys not specified in expected value
+    and resolves any jinja templates found.
+    """
+    if isinstance(expected, dict):
+        # Take a look only at keys present at expected dictionary
+        for k, v in expected.items():
+            if k not in result:
+                log.error("Key %s not in received event %s\nExpected %s", k, result, expected)
+                return False
+            if not match(v, result[k]):
+                log.error(
+                    "For key %s, expected value %s does not equal received %s\nExpected: %s, result: %s",
+                    k,
+                    v,
+                    result[k],
+                    expected,
+                    result,
+                )
+                return False
+    elif isinstance(expected, list):
+        if len(expected) != len(result):
+            log.error("Length does not match: expected %d, result: %d", len(expected), len(result))
+            return False
+        for i, x in enumerate(expected):
+            if not match(x, result[i]):
+                log.error(
+                    "List mismatch at index %d\nexpected:\n%s\nresult:\n%s",
+                    i,
+                    json.dumps(x),
+                    json.dumps(result[i]),
+                )
+                return False
+    elif isinstance(expected, str):
+        if "{{" in expected:
+            # Evaluate jinja: in some cases, we want to check only if key exists, or if
+            # value has the right type
+            rendered = env.from_string(expected).render(result=result)
+            if rendered == "true" or rendered == result:
+                return True
+            log.error("Rendered value %s does not equal 'true' or %s", rendered, result)
+            return False
+        elif expected != result:
+            log.error("Expected value %s does not equal result %s", expected, result)
+            return False
+    elif expected != result:
+        log.error("Object of type %s: %s does not match %s", type(expected), expected, result)
+        return False
+    return True
+
+
+class OpenlineageTestOperator(BaseOperator):
+    """Operator for testing purposes.
+    It compares expected event templates set on initialization with ones emitted by OpenLineage integration
+    and stored in Variables by VariableTransport.
+    :param event_templates: dictionary where key is the key used by VariableTransport in format of
+        <DAG_ID>.<TASK_ID>.event.<EVENT_TYPE>, and value is event template (fragment)
+         that need to be in received events.
+    :raises: ValueError if the received events do not match with expected ones.
+    """
+
+    def __init__(self, event_templates: dict[str, dict], **kwargs):
+        super().__init__(**kwargs)
+        self.event_templates = event_templates
+
+    def execute(self, context: Context):
+        for key, template in self.event_templates.items():
+            sent_event = Variable.get(key=key)
+            self.log.debug("Received event for key %s: %s", key, sent_event)
+            if not match(template, json.loads(sent_event)):
+                raise ValueError("Event received does not match one specified in test")
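
As a brief illustration (not part of the commit) of how match() treats jinja templates in an
expected fragment: a string containing "{{" is rendered with the received value bound to the
name "result", and the comparison passes when the rendered text is "true" or equals the
received value. The UUID below is a made-up example:

    from tests.test_utils.openlineage import match

    expected = {"run": {"runId": "{{ is_uuid(result) }}"}}
    received = {
        "run": {"runId": "01890a5d-ac68-7a74-b921-9f3c6c4f4e5a"},
        "eventType": "START",
    }

    # Keys present in the received event but absent from the expected fragment
    # (here: eventType) are ignored, and is_uuid() renders to "true" for the
    # received runId, so the fragment matches.
    assert match(expected, received)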