You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ka...@apache.org on 2023/11/13 14:15:17 UTC
(airflow) branch main updated: Add Listener hooks for Datasets (#34418)
This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 9439111e73 Add Listener hooks for Datasets (#34418)
9439111e73 is described below
commit 9439111e739e24f0e3751350186b0e2130d2c821
Author: Pasha Yermalovich <14...@users.noreply.github.com>
AuthorDate: Mon Nov 13 15:15:10 2023 +0100
Add Listener hooks for Datasets (#34418)
This PR creates listener hooks for the following Dataset events:
* on_dataset_created
* on_dataset_changed
closes: #34327
---
airflow/datasets/manager.py | 23 ++++++++-
airflow/listeners/listener.py | 3 +-
airflow/listeners/spec/dataset.py | 41 +++++++++++++++
airflow/models/dag.py | 16 +++---
.../administration-and-deployment/listeners.rst | 7 +++
tests/datasets/test_manager.py | 35 +++++++++++++
tests/listeners/dataset_listener.py | 45 ++++++++++++++++
tests/listeners/test_dataset_listener.py | 60 ++++++++++++++++++++++
8 files changed, 221 insertions(+), 9 deletions(-)
diff --git a/airflow/datasets/manager.py b/airflow/datasets/manager.py
index 8714ba0658..08871c9f65 100644
--- a/airflow/datasets/manager.py
+++ b/airflow/datasets/manager.py
@@ -22,6 +22,8 @@ from typing import TYPE_CHECKING
from sqlalchemy import exc, select
from airflow.configuration import conf
+from airflow.datasets import Dataset
+from airflow.listeners.listener import get_listener_manager
from airflow.models.dataset import DatasetDagRunQueue, DatasetEvent, DatasetModel
from airflow.stats import Stats
from airflow.utils.log.logging_mixin import LoggingMixin
@@ -29,7 +31,6 @@ from airflow.utils.log.logging_mixin import LoggingMixin
if TYPE_CHECKING:
from sqlalchemy.orm.session import Session
- from airflow.datasets import Dataset
from airflow.models.taskinstance import TaskInstance
@@ -44,6 +45,15 @@ class DatasetManager(LoggingMixin):
def __init__(self, **kwargs):
super().__init__(**kwargs)
+ def create_datasets(self, dataset_models: list[DatasetModel], session: Session) -> None:
+ """Create new datasets."""
+ for dataset_model in dataset_models:
+ session.add(dataset_model)
+ session.flush()
+
+ for dataset_model in dataset_models:
+ self.notify_dataset_created(dataset=Dataset(uri=dataset_model.uri, extra=dataset_model.extra))
+
def register_dataset_change(
self, *, task_instance: TaskInstance, dataset: Dataset, extra=None, session: Session, **kwargs
) -> None:
@@ -68,11 +78,22 @@ class DatasetManager(LoggingMixin):
)
)
session.flush()
+
+ self.notify_dataset_changed(dataset=dataset)
+
Stats.incr("dataset.updates")
if dataset_model.consuming_dags:
self._queue_dagruns(dataset_model, session)
session.flush()
+ def notify_dataset_created(self, dataset: Dataset):
+ """Run applicable notification actions when a dataset is created."""
+ get_listener_manager().hook.on_dataset_created(dataset=dataset)
+
+ def notify_dataset_changed(self, dataset: Dataset):
+ """Run applicable notification actions when a dataset is changed."""
+ get_listener_manager().hook.on_dataset_changed(dataset=dataset)
+
def _queue_dagruns(self, dataset: DatasetModel, session: Session) -> None:
# Possible race condition: if multiple dags or multiple (usually
# mapped) tasks update the same dataset, this can fail with a unique
diff --git a/airflow/listeners/listener.py b/airflow/listeners/listener.py
index eb738c3e91..d7944aa4eb 100644
--- a/airflow/listeners/listener.py
+++ b/airflow/listeners/listener.py
@@ -37,11 +37,12 @@ class ListenerManager:
"""Manage listener registration and provides hook property for calling them."""
def __init__(self):
- from airflow.listeners.spec import dagrun, lifecycle, taskinstance
+ from airflow.listeners.spec import dagrun, dataset, lifecycle, taskinstance
self.pm = pluggy.PluginManager("airflow")
self.pm.add_hookspecs(lifecycle)
self.pm.add_hookspecs(dagrun)
+ self.pm.add_hookspecs(dataset)
self.pm.add_hookspecs(taskinstance)
@property
diff --git a/airflow/listeners/spec/dataset.py b/airflow/listeners/spec/dataset.py
new file mode 100644
index 0000000000..214ddad3ff
--- /dev/null
+++ b/airflow/listeners/spec/dataset.py
@@ -0,0 +1,41 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from pluggy import HookspecMarker
+
+if TYPE_CHECKING:
+ from airflow.datasets import Dataset
+
+hookspec = HookspecMarker("airflow")
+
+
+@hookspec
+def on_dataset_created(
+ dataset: Dataset,
+):
+ """Execute when a new dataset is created."""
+
+
+@hookspec
+def on_dataset_changed(
+ dataset: Dataset,
+):
+ """Execute when dataset change is registered."""
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 9a09b706e5..428f79e14e 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -78,6 +78,7 @@ import airflow.templates
from airflow import settings, utils
from airflow.api_internal.internal_api_call import internal_api_call
from airflow.configuration import conf as airflow_conf, secrets_backend_list
+from airflow.datasets.manager import dataset_manager
from airflow.exceptions import (
AirflowDagInconsistent,
AirflowException,
@@ -3137,8 +3138,8 @@ class DAG(LoggingMixin):
dag_references = collections.defaultdict(set)
outlet_references = collections.defaultdict(set)
# We can't use a set here as we want to preserve order
- outlet_datasets: dict[Dataset, None] = {}
- input_datasets: dict[Dataset, None] = {}
+ outlet_datasets: dict[DatasetModel, None] = {}
+ input_datasets: dict[DatasetModel, None] = {}
# here we go through dags and tasks to check for dataset references
# if there are now None and previously there were some, we delete them
@@ -3171,7 +3172,8 @@ class DAG(LoggingMixin):
all_datasets.update(input_datasets)
# store datasets
- stored_datasets = {}
+ stored_datasets: dict[str, DatasetModel] = {}
+ new_datasets: list[DatasetModel] = []
for dataset in all_datasets:
stored_dataset = session.scalar(
select(DatasetModel).where(DatasetModel.uri == dataset.uri).limit(1)
@@ -3183,11 +3185,11 @@ class DAG(LoggingMixin):
stored_dataset.is_orphaned = expression.false()
stored_datasets[stored_dataset.uri] = stored_dataset
else:
- session.add(dataset)
- stored_datasets[dataset.uri] = dataset
-
- session.flush() # this is required to ensure each dataset has its PK loaded
+ new_datasets.append(dataset)
+ dataset_manager.create_datasets(dataset_models=new_datasets, session=session)
+ stored_datasets.update({dataset.uri: dataset for dataset in new_datasets})
+ del new_datasets
del all_datasets
# reconcile dag-schedule-on-dataset references
diff --git a/docs/apache-airflow/administration-and-deployment/listeners.rst b/docs/apache-airflow/administration-and-deployment/listeners.rst
index 4182d135a1..0672e07779 100644
--- a/docs/apache-airflow/administration-and-deployment/listeners.rst
+++ b/docs/apache-airflow/administration-and-deployment/listeners.rst
@@ -50,6 +50,13 @@ TaskInstance State Change Events
TaskInstance state change events occur when a ``TaskInstance`` changes state.
You can use these events to react to ``LocalTaskJob`` state changes.
+Dataset Events
+--------------
+
+- ``on_dataset_created``
+- ``on_dataset_changed``
+
+Dataset events occur when Dataset management operations are run.
Usage
-----
diff --git a/tests/datasets/test_manager.py b/tests/datasets/test_manager.py
index 19b6b1ed45..514ed8877a 100644
--- a/tests/datasets/test_manager.py
+++ b/tests/datasets/test_manager.py
@@ -24,8 +24,10 @@ import pytest
from airflow.datasets import Dataset
from airflow.datasets.manager import DatasetManager
+from airflow.listeners.listener import get_listener_manager
from airflow.models.dag import DagModel
from airflow.models.dataset import DagScheduleDatasetReference, DatasetDagRunQueue, DatasetEvent, DatasetModel
+from tests.listeners import dataset_listener
pytestmark = pytest.mark.db_test
@@ -96,3 +98,36 @@ class TestDatasetManager:
# Ensure we've created a dataset
assert session.query(DatasetEvent).filter_by(dataset_id=dsm.id).count() == 1
assert session.query(DatasetDagRunQueue).count() == 0
+
+ def test_register_dataset_change_notifies_dataset_listener(self, session, mock_task_instance):
+ dsem = DatasetManager()
+ dataset_listener.clear()
+ get_listener_manager().add_listener(dataset_listener)
+
+ ds = Dataset(uri="test_dataset_uri")
+ dag1 = DagModel(dag_id="dag1")
+ session.add_all([dag1])
+
+ dsm = DatasetModel(uri="test_dataset_uri")
+ session.add(dsm)
+ dsm.consuming_dags = [DagScheduleDatasetReference(dag_id=dag1.dag_id)]
+ session.flush()
+
+ dsem.register_dataset_change(task_instance=mock_task_instance, dataset=ds, session=session)
+
+ # Ensure the listener was notified
+ assert len(dataset_listener.changed) == 1
+ assert dataset_listener.changed[0].uri == ds.uri
+
+ def test_create_datasets_notifies_dataset_listener(self, session):
+ dsem = DatasetManager()
+ dataset_listener.clear()
+ get_listener_manager().add_listener(dataset_listener)
+
+ dsm = DatasetModel(uri="test_dataset_uri")
+
+ dsem.create_datasets([dsm], session)
+
+ # Ensure the listener was notified
+ assert len(dataset_listener.created) == 1
+ assert dataset_listener.created[0].uri == dsm.uri
diff --git a/tests/listeners/dataset_listener.py b/tests/listeners/dataset_listener.py
new file mode 100644
index 0000000000..0e4b768c69
--- /dev/null
+++ b/tests/listeners/dataset_listener.py
@@ -0,0 +1,45 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import copy
+import typing
+
+from airflow.listeners import hookimpl
+
+if typing.TYPE_CHECKING:
+ from airflow.datasets import Dataset
+
+
+changed: list[Dataset] = []
+created: list[Dataset] = []
+
+
+@hookimpl
+def on_dataset_changed(dataset):
+ changed.append(copy.deepcopy(dataset))
+
+
+@hookimpl
+def on_dataset_created(dataset):
+ created.append(copy.deepcopy(dataset))
+
+
+def clear():
+ global changed, created
+ changed, created = [], []
diff --git a/tests/listeners/test_dataset_listener.py b/tests/listeners/test_dataset_listener.py
new file mode 100644
index 0000000000..d17f079e2d
--- /dev/null
+++ b/tests/listeners/test_dataset_listener.py
@@ -0,0 +1,60 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import pytest
+
+from airflow.datasets import Dataset
+from airflow.listeners.listener import get_listener_manager
+from airflow.models.dataset import DatasetModel
+from airflow.operators.empty import EmptyOperator
+from airflow.utils.session import provide_session
+from tests.listeners import dataset_listener
+
+
+@pytest.fixture(autouse=True)
+def clean_listener_manager():
+ lm = get_listener_manager()
+ lm.clear()
+ lm.add_listener(dataset_listener)
+ yield
+ lm = get_listener_manager()
+ lm.clear()
+ dataset_listener.clear()
+
+
+@pytest.mark.db_test
+@provide_session
+def test_dataset_listener_on_dataset_changed_gets_calls(create_task_instance_of_operator, session):
+ dataset_uri = "test_dataset_uri"
+ ds = Dataset(uri=dataset_uri)
+ ds_model = DatasetModel(uri=dataset_uri)
+ session.add(ds_model)
+
+ session.flush()
+
+ ti = create_task_instance_of_operator(
+ operator_class=EmptyOperator,
+ dag_id="producing_dag",
+ task_id="test_task",
+ session=session,
+ outlets=[ds],
+ )
+ ti.run()
+
+ assert len(dataset_listener.changed) == 1
+ assert dataset_listener.changed[0].uri == dataset_uri