Posted to commits@airflow.apache.org by ds...@apache.org on 2022/07/11 17:37:51 UTC

[airflow] branch main updated: Dataset event table (#24908)

This is an automated email from the ASF dual-hosted git repository.

dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new ed70b696fb Dataset event table (#24908)
ed70b696fb is described below

commit ed70b696fb4bb6247f7019e9a70aec74d37f944d
Author: Daniel Standish <15...@users.noreply.github.com>
AuthorDate: Mon Jul 11 10:37:42 2022 -0700

    Dataset event table (#24908)
    
    Part of AIP-48. Adds a table to record all dataset events. These are currently used to trigger dag runs, but may in the future also carry information about the nature of the update, which downstream tasks working with the datasets can read.
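
For illustration, a minimal sketch of how downstream code might consume these events once the DatasetEvent model below is in place; the session handling and the URI value are hypothetical and not part of this commit:

    from sqlalchemy.orm import Session

    from airflow.models.dataset import Dataset, DatasetEvent


    def latest_event_extra(session: Session, uri: str):
        """Fetch the `extra` payload of the newest event for a dataset URI (illustrative)."""
        return (
            session.query(DatasetEvent.extra)
            # No foreign key links the tables, so join on the id column explicitly.
            .join(Dataset, Dataset.id == DatasetEvent.dataset_id)
            .filter(Dataset.uri == uri)
            .order_by(DatasetEvent.created_at.desc())
            .limit(1)
            .scalar()
        )
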
---
 .../versions/0114_2_4_0_add_dataset_model.py       | 18 +++++
 airflow/models/dataset.py                          | 86 +++++++++++++++++++++-
 airflow/models/taskinstance.py                     | 13 +++-
 airflow/utils/db_cleanup.py                        |  1 +
 tests/models/test_taskinstance.py                  | 17 ++++-
 5 files changed, 131 insertions(+), 4 deletions(-)

diff --git a/airflow/migrations/versions/0114_2_4_0_add_dataset_model.py b/airflow/migrations/versions/0114_2_4_0_add_dataset_model.py
index 9cfca3766c..ba4be54451 100644
--- a/airflow/migrations/versions/0114_2_4_0_add_dataset_model.py
+++ b/airflow/migrations/versions/0114_2_4_0_add_dataset_model.py
@@ -118,12 +118,29 @@ def _create_dataset_dag_run_queue_table():
     )
 
 
+def _create_dataset_event_table():
+    op.create_table(
+        'dataset_event',
+        sa.Column('id', Integer, primary_key=True, autoincrement=True),
+        sa.Column('dataset_id', Integer, nullable=False),
+        sa.Column('extra', ExtendedJSON, nullable=True),
+        sa.Column('task_id', String(250), nullable=True),
+        sa.Column('dag_id', String(250), nullable=True),
+        sa.Column('run_id', String(250), nullable=True),
+        sa.Column('map_index', sa.Integer(), nullable=True, server_default='-1'),
+        sa.Column('created_at', TIMESTAMP, nullable=False),
+        sqlite_autoincrement=True,  # ensures PK values not reused
+    )
+    op.create_index('idx_dataset_id_created_at', 'dataset_event', ['dataset_id', 'created_at'])
+
+
 def upgrade():
     """Apply Add Dataset model"""
     _create_dataset_table()
     _create_dataset_dag_ref_table()
     _create_dataset_task_ref_table()
     _create_dataset_dag_run_queue_table()
+    _create_dataset_event_table()
 
 
 def downgrade():
@@ -131,4 +148,5 @@ def downgrade():
     op.drop_table('dataset_dag_ref')
     op.drop_table('dataset_task_ref')
     op.drop_table('dataset_dag_run_queue')
+    op.drop_table('dataset_event')
     op.drop_table('dataset')
diff --git a/airflow/models/dataset.py b/airflow/models/dataset.py
index 256ed2293e..7646eaae40 100644
--- a/airflow/models/dataset.py
+++ b/airflow/models/dataset.py
@@ -17,7 +17,7 @@
 # under the License.
 from urllib.parse import urlparse
 
-from sqlalchemy import Column, ForeignKeyConstraint, Index, Integer, PrimaryKeyConstraint, String
+from sqlalchemy import Column, ForeignKeyConstraint, Index, Integer, PrimaryKeyConstraint, String, text
 from sqlalchemy.orm import relationship
 
 from airflow.models.base import ID_LEN, Base, StringID
@@ -199,3 +199,87 @@ class DatasetDagRunQueue(Base):
         for attr in [x.name for x in self.__mapper__.primary_key]:
             args.append(f"{attr}={getattr(self, attr)!r}")
         return f"{self.__class__.__name__}({', '.join(args)})"
+
+
+class DatasetEvent(Base):
+    """
+    A table to store dataset events.
+
+    :param dataset_id: reference to Dataset record
+    :param extra: JSON field for arbitrary extra info
+    :param task_id: the task_id of the TI which updated the dataset
+    :param dag_id: the dag_id of the TI which updated the dataset
+    :param run_id: the run_id of the TI which updated the dataset
+    :param map_index: the map_index of the TI which updated the dataset
+
+    We use relationships instead of foreign keys so that dataset events are not deleted even
+    if the foreign key object is.
+    """
+
+    id = Column(Integer, primary_key=True, autoincrement=True)
+    dataset_id = Column(Integer, nullable=False)
+    extra = Column(ExtendedJSON, nullable=True)
+    task_id = Column(StringID(), nullable=True)
+    dag_id = Column(StringID(), nullable=True)
+    run_id = Column(StringID(), nullable=True)
+    map_index = Column(Integer, nullable=True, server_default=text("-1"))
+    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
+
+    __tablename__ = "dataset_event"
+    __table_args__ = (
+        Index('idx_dataset_id_created_at', dataset_id, created_at),
+        {'sqlite_autoincrement': True},  # ensures PK values not reused
+    )
+
+    source_task_instance = relationship(
+        "TaskInstance",
+        primaryjoin="""and_(
+            DatasetEvent.dag_id == foreign(TaskInstance.dag_id),
+            DatasetEvent.run_id == foreign(TaskInstance.run_id),
+            DatasetEvent.task_id == foreign(TaskInstance.task_id),
+            DatasetEvent.map_index == foreign(TaskInstance.map_index),
+        )""",
+        viewonly=True,
+        lazy="select",
+        uselist=False,
+    )
+    source_dag_run = relationship(
+        "DagRun",
+        primaryjoin="""and_(
+            DatasetEvent.dag_id == foreign(DagRun.dag_id),
+            DatasetEvent.run_id == foreign(DagRun.run_id),
+        )""",
+        viewonly=True,
+        lazy="select",
+        uselist=False,
+    )
+    dataset = relationship(
+        Dataset,
+        primaryjoin="DatasetEvent.dataset_id == foreign(Dataset.id)",
+        viewonly=True,
+        lazy="select",
+        uselist=False,
+    )
+
+    def __eq__(self, other) -> bool:
+        if isinstance(other, self.__class__):
+            return self.dataset_id == other.dataset_id and self.created_at == other.created_at
+        else:
+            return NotImplemented
+
+    def __hash__(self) -> int:
+        return hash((self.dataset_id, self.created_at))
+
+    def __repr__(self) -> str:
+        args = []
+        for attr in [
+            'id',
+            'dataset_id',
+            'extra',
+            'task_id',
+            'dag_id',
+            'run_id',
+            'map_index',
+        ]:
+            args.append(f"{attr}={getattr(self, attr)!r}")
+        return f"{self.__class__.__name__}({', '.join(args)})"
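
The relationships defined above are viewonly and join on plain columns rather than database foreign keys, so an event row survives even if its source task instance, dag run, or dataset is later deleted. A minimal sketch of navigating them (illustrative; the primary key value is made up):

    from airflow.models.dataset import DatasetEvent
    from airflow.utils.session import create_session

    event_id = 1  # hypothetical primary key, for illustration only

    with create_session() as session:
        event = session.query(DatasetEvent).filter(DatasetEvent.id == event_id).one_or_none()
        if event is not None:
            # Any of these may be None if the source row has since been removed.
            print(event.dataset.uri if event.dataset else None)
            print(event.source_dag_run.run_id if event.source_dag_run else None)
            print(event.source_task_instance.task_id if event.source_task_instance else None)
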
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 18bae29a0c..024a0a4b1b 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -95,7 +95,7 @@ from airflow.exceptions import (
     XComForMappingNotPushed,
 )
 from airflow.models.base import Base, StringID
-from airflow.models.dataset import DatasetDagRunQueue
+from airflow.models.dataset import DatasetDagRunQueue, DatasetEvent
 from airflow.models.log import Log
 from airflow.models.param import ParamsDict
 from airflow.models.taskfail import TaskFail
@@ -1516,7 +1516,7 @@ class TaskInstance(Base, LoggingMixin):
             self._create_dataset_dag_run_queue_records(session=session)
             session.commit()
 
-    def _create_dataset_dag_run_queue_records(self, *, session):
+    def _create_dataset_dag_run_queue_records(self, *, session: Session) -> None:
         from airflow.models import Dataset
 
         for obj in getattr(self.task, '_outlets', []):
@@ -1528,6 +1528,15 @@ class TaskInstance(Base, LoggingMixin):
                     continue
                 downstream_dag_ids = [x.dag_id for x in dataset.dag_references]
                 self.log.debug("downstream dag ids %s", downstream_dag_ids)
+                session.add(
+                    DatasetEvent(
+                        dataset_id=dataset.id,
+                        task_id=self.task_id,
+                        dag_id=self.dag_id,
+                        run_id=self.run_id,
+                        map_index=self.map_index,
+                    )
+                )
                 for dag_id in downstream_dag_ids:
                     session.merge(DatasetDagRunQueue(dataset_id=dataset.id, target_dag_id=dag_id))
 
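For context, the event row added here is written when a task that lists a Dataset among its outlets finishes successfully, alongside the existing queue records. A hedged sketch of such a producer task, assuming the in-progress AIP-48 wiring where an operator's outlets may contain Dataset objects (the dag id, URI, and schedule are made up):

    import pendulum

    from airflow import DAG
    from airflow.models.dataset import Dataset
    from airflow.operators.bash import BashOperator

    with DAG(
        dag_id="producer_dag",
        start_date=pendulum.datetime(2022, 7, 1, tz="UTC"),
        schedule_interval=None,
    ):
        # On success, _create_dataset_dag_run_queue_records() above records one
        # DatasetEvent row for this task instance, plus one DatasetDagRunQueue
        # row per dag that depends on the dataset.
        BashOperator(
            task_id="produce",
            bash_command="echo producing",
            outlets=[Dataset("s3://producer_dag/output_1.txt")],
        )
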
diff --git a/airflow/utils/db_cleanup.py b/airflow/utils/db_cleanup.py
index f77ae52a60..feba360f05 100644
--- a/airflow/utils/db_cleanup.py
+++ b/airflow/utils/db_cleanup.py
@@ -94,6 +94,7 @@ config_list: List[_TableConfig] = [
         keep_last_filters=[column('external_trigger') == false()],
         keep_last_group_by=['dag_id'],
     ),
+    _TableConfig(table_name='dataset_event', recency_column_name='created_at'),
     _TableConfig(table_name='import_error', recency_column_name='timestamp'),
     _TableConfig(table_name='log', recency_column_name='dttm'),
     _TableConfig(table_name='rendered_task_instance_fields', recency_column_name='execution_date'),
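
With the table registered here, stale rows become eligible for the existing database cleanup routine, keyed on the created_at recency column. A rough sketch of a dry-run purge, assuming run_cleanup in this module accepts clean_before_timestamp and table_names keyword arguments:

    import pendulum

    from airflow.utils.db_cleanup import run_cleanup

    # Preview which dataset_event rows older than 90 days would be removed.
    # The keyword names below are assumptions about run_cleanup's signature.
    run_cleanup(
        clean_before_timestamp=pendulum.now("UTC").subtract(days=90),
        table_names=["dataset_event"],
        dry_run=True,
        confirm=False,
    )
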
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index 0bb38071e8..c0868b0b7f 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -57,7 +57,7 @@ from airflow.models import (
     Variable,
     XCom,
 )
-from airflow.models.dataset import DatasetDagRunQueue, DatasetTaskRef
+from airflow.models.dataset import Dataset, DatasetDagRunQueue, DatasetEvent, DatasetTaskRef
 from airflow.models.serialized_dag import SerializedDagModel
 from airflow.models.taskfail import TaskFail
 from airflow.models.taskinstance import TaskInstance
@@ -1499,10 +1499,25 @@ class TestTaskInstance:
         ti._run_raw_task()
         ti.refresh_from_db()
         assert ti.state == State.SUCCESS
+
+        # check that one queue record created for each dag that depends on dataset 1
         assert session.query(DatasetDagRunQueue.target_dag_id).filter(
             DatasetTaskRef.dag_id == dag1.dag_id, DatasetTaskRef.task_id == 'upstream_task_1'
         ).all() == [('dag3',), ('dag4',), ('dag5',)]
 
+        # check that one event record created for dataset1 and this TI
+        assert session.query(Dataset.uri).join(DatasetEvent.dataset).filter(
+            DatasetEvent.source_task_instance == ti
+        ).one() == ('s3://dag1/output_1.txt',)
+
+        # check that no other dataset events recorded
+        assert (
+            session.query(Dataset.uri)
+            .join(DatasetEvent.dataset)
+            .filter(DatasetEvent.source_task_instance == ti)
+            .count()
+        ) == 1
+
     @staticmethod
     def _test_previous_dates_setup(
         schedule_interval: Union[str, datetime.timedelta, None],