Posted to commits@airflow.apache.org by je...@apache.org on 2022/07/15 16:26:24 UTC

[airflow] branch main updated: Add dataset events to dataset api (#25039)

This is an automated email from the ASF dual-hosted git repository.

jedcunningham pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new fcf8cc26f7 Add dataset events to dataset api (#25039)
fcf8cc26f7 is described below

commit fcf8cc26f7d94fb0baa78ac4c981a14cc88af533
Author: Jed Cunningham <66...@users.noreply.github.com>
AuthorDate: Fri Jul 15 10:26:14 2022 -0600

    Add dataset events to dataset api (#25039)
---
 .../api_connexion/endpoints/dataset_endpoint.py    |  45 +++-
 airflow/api_connexion/openapi/v1.yaml              | 116 +++++++++
 airflow/api_connexion/schemas/dataset_schema.py    |  38 ++-
 .../endpoints/test_dataset_endpoint.py             | 272 ++++++++++++++++++---
 tests/api_connexion/schemas/test_dataset_schema.py |  58 ++++-
 tests/test_utils/db.py                             |   3 +-
 6 files changed, 497 insertions(+), 35 deletions(-)
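
Before the per-file diffs, a quick orientation: the commit adds a new REST endpoint, GET /api/v1/datasets/events, which lists dataset events and supports filtering by dataset_id, source_dag_id, source_task_id, source_run_id, and source_map_index, plus the usual limit/offset/order_by parameters. A minimal sketch of calling it with the requests library (the webserver URL and basic-auth credentials below are assumptions for illustration, not part of this commit):

    import requests

    # List events emitted by one DAG, ordered by creation time.
    resp = requests.get(
        "http://localhost:8080/api/v1/datasets/events",  # assumed local webserver
        params={"source_dag_id": "example_dag", "order_by": "created_at", "limit": 10},
        auth=("admin", "admin"),  # hypothetical credentials
    )
    resp.raise_for_status()
    for event in resp.json()["dataset_events"]:
        print(event["dataset_id"], event["source_run_id"], event["created_at"])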

diff --git a/airflow/api_connexion/endpoints/dataset_endpoint.py b/airflow/api_connexion/endpoints/dataset_endpoint.py
index cfdc04f615..5c9e7606fc 100644
--- a/airflow/api_connexion/endpoints/dataset_endpoint.py
+++ b/airflow/api_connexion/endpoints/dataset_endpoint.py
@@ -15,19 +15,23 @@
 # specific language governing permissions and limitations
 # under the License.
 
+from typing import Optional
+
 from sqlalchemy import func
 from sqlalchemy.orm import Session
 
-from airflow import Dataset
 from airflow.api_connexion import security
 from airflow.api_connexion.exceptions import NotFound
 from airflow.api_connexion.parameters import apply_sorting, check_limit, format_parameters
 from airflow.api_connexion.schemas.dataset_schema import (
     DatasetCollection,
+    DatasetEventCollection,
     dataset_collection_schema,
+    dataset_event_collection_schema,
     dataset_schema,
 )
 from airflow.api_connexion.types import APIResponse
+from airflow.models.dataset import Dataset, DatasetEvent
 from airflow.security import permissions
 from airflow.utils.session import NEW_SESSION, provide_session
 
@@ -59,3 +63,42 @@ def get_datasets(
     query = apply_sorting(query, order_by, {}, allowed_attrs)
     datasets = query.offset(offset).limit(limit).all()
     return dataset_collection_schema.dump(DatasetCollection(datasets=datasets, total_entries=total_entries))
+
+
+@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DATASET)])
+@provide_session
+@format_parameters({'limit': check_limit})
+def get_dataset_events(
+    *,
+    limit: int,
+    offset: int = 0,
+    order_by: str = "created_at",
+    dataset_id: Optional[int] = None,
+    source_dag_id: Optional[str] = None,
+    source_task_id: Optional[str] = None,
+    source_run_id: Optional[str] = None,
+    source_map_index: Optional[int] = None,
+    session: Session = NEW_SESSION,
+) -> APIResponse:
+    """Get dataset events"""
+    allowed_attrs = ['source_dag_id', 'source_task_id', 'source_run_id', 'source_map_index', 'created_at']
+
+    query = session.query(DatasetEvent)
+
+    if dataset_id:
+        query = query.filter(DatasetEvent.dataset_id == dataset_id)
+    if source_dag_id:
+        query = query.filter(DatasetEvent.source_dag_id == source_dag_id)
+    if source_task_id:
+        query = query.filter(DatasetEvent.source_task_id == source_task_id)
+    if source_run_id:
+        query = query.filter(DatasetEvent.source_run_id == source_run_id)
+    if source_map_index:
+        query = query.filter(DatasetEvent.source_map_index == source_map_index)
+
+    total_entries = query.count()
+    query = apply_sorting(query, order_by, {}, allowed_attrs)
+    events = query.offset(offset).limit(limit).all()
+    return dataset_event_collection_schema.dump(
+        DatasetEventCollection(dataset_events=events, total_entries=total_entries)
+    )
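
The handler above adds a filter only for each parameter that was supplied, counts the matching rows before pagination, and then applies ordering, offset, and limit. A self-contained sketch of the same query pattern in plain SQLAlchemy (1.4+), using a simplified stand-in model rather than Airflow's DatasetEvent:

    from sqlalchemy import Column, Integer, String, create_engine
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()

    class Event(Base):
        # Simplified stand-in for DatasetEvent.
        __tablename__ = "event"
        id = Column(Integer, primary_key=True)
        source_dag_id = Column(String)

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        query = session.query(Event)
        source_dag_id = "example_dag"  # stands in for the query parameter
        if source_dag_id:              # filter applied only when the parameter is given
            query = query.filter(Event.source_dag_id == source_dag_id)
        total_entries = query.count()  # count before ordering and pagination
        events = query.order_by(Event.id).offset(0).limit(100).all()
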
diff --git a/airflow/api_connexion/openapi/v1.yaml b/airflow/api_connexion/openapi/v1.yaml
index 4bfedbfb98..2dfef4780c 100644
--- a/airflow/api_connexion/openapi/v1.yaml
+++ b/airflow/api_connexion/openapi/v1.yaml
@@ -1654,6 +1654,36 @@ paths:
         '404':
           $ref: '#/components/responses/NotFound'
 
+  /datasets/events:
+    parameters:
+      - $ref: '#/components/parameters/PageLimit'
+      - $ref: '#/components/parameters/PageOffset'
+      - $ref: '#/components/parameters/OrderBy'
+      - $ref: '#/components/parameters/FilterDatasetID'
+      - $ref: '#/components/parameters/FilterSourceDAGID'
+      - $ref: '#/components/parameters/FilterSourceTaskID'
+      - $ref: '#/components/parameters/FilterSourceRunID'
+      - $ref: '#/components/parameters/FilterSourceMapIndex'
+    get:
+      summary: Get dataset events
+      description: Get dataset events
+      x-openapi-router-controller: airflow.api_connexion.endpoints.dataset_endpoint
+      operationId: get_dataset_events
+      tags: [Dataset]
+      responses:
+        '200':
+          description: Success.
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/DatasetEventCollection'
+        '401':
+          $ref: '#/components/responses/Unauthenticated'
+        '403':
+          $ref: '#/components/responses/PermissionDenied'
+        '404':
+          $ref: '#/components/responses/NotFound'
+
   /config:
     get:
       summary: Get current configuration
@@ -3461,6 +3491,57 @@ components:
                 $ref: '#/components/schemas/Dataset'
         - $ref: '#/components/schemas/CollectionInfo'
 
+    DatasetEvent:
+      description: |
+        A dataset event.
+
+        *New in version 2.4.0*
+      type: object
+      properties:
+        dataset_id:
+          type: integer
+          description: The dataset id
+        extra:
+          type: string
+          description: The dataset extra
+          nullable: true
+        source_dag_id:
+          type: string
+          description: The DAG ID that updated the dataset.
+          nullable: false
+        source_task_id:
+          type: string
+          description: The task ID that updated the dataset.
+          nullable: false
+        source_run_id:
+          type: string
+          description: The DAG run ID that updated the dataset.
+          nullable: false
+        source_map_index:
+          type: integer
+          description: The task map index that updated the dataset.
+          nullable: false
+        created_at:
+          type: string
+          description: The dataset event creation time
+          nullable: false
+
+
+    DatasetEventCollection:
+      description: |
+        A collection of dataset events.
+
+        *New in version 2.4.0*
+      type: object
+      allOf:
+        - type: object
+          properties:
+            dataset_events:
+              type: array
+              items:
+                $ref: '#/components/schemas/DatasetEvent'
+        - $ref: '#/components/schemas/CollectionInfo'
+
 
     # Configuration
     ConfigOption:
@@ -4287,6 +4368,41 @@ components:
 
           *New in version 2.2.0*
 
+    FilterDatasetID:
+      in: query
+      name: dataset_id
+      schema:
+        type: integer
+      description: The Dataset ID that updated the dataset.
+
+    FilterSourceDAGID:
+      in: query
+      name: source_dag_id
+      schema:
+        type: string
+      description: The DAG ID that updated the dataset.
+
+    FilterSourceTaskID:
+      in: query
+      name: source_task_id
+      schema:
+        type: string
+      description: The task ID that updated the dataset.
+
+    FilterSourceRunID:
+      in: query
+      name: source_run_id
+      schema:
+        type: string
+      description: The DAG run ID that updated the dataset.
+
+    FilterSourceMapIndex:
+      in: query
+      name: source_map_index
+      schema:
+        type: integer
+      description: The map index that updated the dataset.
+
     OrderBy:
       in: query
       name: order_by
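
Combining the DatasetEvent and DatasetEventCollection components above, a response body from GET /datasets/events has this shape (field names per the spec; the values are illustrative only):

    example_response = {
        "dataset_events": [
            {
                "dataset_id": 1,
                "extra": None,
                "source_dag_id": "example_dag",
                "source_task_id": "produce",
                "source_run_id": "manual__2022-07-15T00:00:00+00:00",
                "source_map_index": -1,
                "created_at": "2022-07-15T16:26:14+00:00",
            }
        ],
        "total_entries": 1,
    }
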
diff --git a/airflow/api_connexion/schemas/dataset_schema.py b/airflow/api_connexion/schemas/dataset_schema.py
index 2228ac4c9b..5b2601cea7 100644
--- a/airflow/api_connexion/schemas/dataset_schema.py
+++ b/airflow/api_connexion/schemas/dataset_schema.py
@@ -20,7 +20,7 @@ from typing import List, NamedTuple
 from marshmallow import Schema, fields
 from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field
 
-from airflow import Dataset
+from airflow.models.dataset import Dataset, DatasetEvent
 
 
 class DatasetSchema(SQLAlchemySchema):
@@ -54,3 +54,39 @@ class DatasetCollectionSchema(Schema):
 
 dataset_schema = DatasetSchema()
 dataset_collection_schema = DatasetCollectionSchema()
+
+
+class DatasetEventSchema(SQLAlchemySchema):
+    """Dataset Event DB schema"""
+
+    class Meta:
+        """Meta"""
+
+        model = DatasetEvent
+
+    id = auto_field()
+    dataset_id = auto_field()
+    extra = auto_field()
+    source_task_id = auto_field()
+    source_dag_id = auto_field()
+    source_run_id = auto_field()
+    source_map_index = auto_field()
+    created_at = auto_field()
+
+
+class DatasetEventCollection(NamedTuple):
+    """List of Dataset events with meta"""
+
+    dataset_events: List[DatasetEvent]
+    total_entries: int
+
+
+class DatasetEventCollectionSchema(Schema):
+    """Dataset Event Collection Schema"""
+
+    dataset_events = fields.List(fields.Nested(DatasetEventSchema))
+    total_entries = fields.Int()
+
+
+dataset_event_schema = DatasetEventSchema()
+dataset_event_collection_schema = DatasetEventCollectionSchema()
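
DatasetEventCollectionSchema follows the same NamedTuple-plus-marshmallow pattern as the existing dataset collection schema. A toy, self-contained sketch of that pattern (not the Airflow classes) showing how the nested list and total_entries serialize:

    from typing import List, NamedTuple

    from marshmallow import Schema, fields

    class ItemSchema(Schema):
        id = fields.Int()
        name = fields.Str()

    class ItemCollection(NamedTuple):
        items: List[dict]
        total_entries: int

    class ItemCollectionSchema(Schema):
        items = fields.List(fields.Nested(ItemSchema))
        total_entries = fields.Int()

    dump = ItemCollectionSchema().dump(
        ItemCollection(items=[{"id": 1, "name": "a"}], total_entries=1)
    )
    # -> {'items': [{'id': 1, 'name': 'a'}], 'total_entries': 1}
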
diff --git a/tests/api_connexion/endpoints/test_dataset_endpoint.py b/tests/api_connexion/endpoints/test_dataset_endpoint.py
index 3d6dd59c8e..2696e60300 100644
--- a/tests/api_connexion/endpoints/test_dataset_endpoint.py
+++ b/tests/api_connexion/endpoints/test_dataset_endpoint.py
@@ -19,7 +19,7 @@ import pytest
 from parameterized import parameterized
 
 from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP
-from airflow.models import Dataset
+from airflow.models.dataset import Dataset, DatasetEvent
 from airflow.security import permissions
 from airflow.utils import timezone
 from airflow.utils.session import provide_session
@@ -62,6 +62,7 @@ class TestDatasetEndpoint:
 
     def _create_dataset(self, session):
         dataset_model = Dataset(
+            id=1,
             uri="s3://bucket/key",
             extra={"foo": "bar"},
             created_at=timezone.parse(self.default_time),
@@ -70,23 +71,17 @@ class TestDatasetEndpoint:
         session.add(dataset_model)
         session.commit()
 
-    @staticmethod
-    def _normalize_dataset_ids(datasets):
-        for i, dataset in enumerate(datasets, 1):
-            dataset["id"] = i
-
 
 class TestGetDatasetEndpoint(TestDatasetEndpoint):
     def test_should_respond_200(self, session):
         self._create_dataset(session)
-        result = session.query(Dataset).all()
-        assert len(result) == 1
-        response = self.client.get(
-            f"/api/v1/datasets/{result[0].id}", environ_overrides={'REMOTE_USER': "test"}
-        )
+        assert session.query(Dataset).count() == 1
+
+        response = self.client.get("/api/v1/datasets/1", environ_overrides={'REMOTE_USER': "test"})
+
         assert response.status_code == 200
         assert response.json == {
-            "id": result[0].id,
+            "id": 1,
             "uri": "s3://bucket/key",
             "extra": "{'foo': 'bar'}",
             "created_at": self.default_time,
@@ -105,9 +100,7 @@ class TestGetDatasetEndpoint(TestDatasetEndpoint):
 
     def test_should_raises_401_unauthenticated(self, session):
         self._create_dataset(session)
-        dataset = session.query(Dataset).first()
-        response = self.client.get(f"/api/v1/datasets/{dataset.id}")
-
+        response = self.client.get("/api/v1/datasets/1")
         assert_401(response)
 
 
@@ -115,22 +108,22 @@ class TestGetDatasets(TestDatasetEndpoint):
     def test_should_respond_200(self, session):
         datasets = [
             Dataset(
-                uri=f"s3://bucket/key/{i+1}",
+                id=i,
+                uri=f"s3://bucket/key/{i}",
                 extra={"foo": "bar"},
                 created_at=timezone.parse(self.default_time),
                 updated_at=timezone.parse(self.default_time),
             )
-            for i in range(2)
+            for i in [1, 2]
         ]
         session.add_all(datasets)
         session.commit()
-        result = session.query(Dataset).all()
-        assert len(result) == 2
+        assert session.query(Dataset).count() == 2
+
         response = self.client.get("/api/v1/datasets", environ_overrides={'REMOTE_USER': "test"})
 
         assert response.status_code == 200
         response_data = response.json
-        self._normalize_dataset_ids(response_data['datasets'])
         assert response_data == {
             "datasets": [
                 {
@@ -154,17 +147,16 @@ class TestGetDatasets(TestDatasetEndpoint):
     def test_order_by_raises_400_for_invalid_attr(self, session):
         datasets = [
             Dataset(
-                uri=f"s3://bucket/key/{i+1}",
+                uri=f"s3://bucket/key/{i}",
                 extra={"foo": "bar"},
                 created_at=timezone.parse(self.default_time),
                 updated_at=timezone.parse(self.default_time),
             )
-            for i in range(2)
+            for i in [1, 2]
         ]
         session.add_all(datasets)
         session.commit()
-        result = session.query(Dataset).all()
-        assert len(result) == 2
+        assert session.query(Dataset).count() == 2
 
         response = self.client.get(
             "/api/v1/datasets?order_by=fake", environ_overrides={'REMOTE_USER': "test"}
@@ -177,17 +169,16 @@ class TestGetDatasets(TestDatasetEndpoint):
     def test_should_raises_401_unauthenticated(self, session):
         datasets = [
             Dataset(
-                uri=f"s3://bucket/key/{i+1}",
+                uri=f"s3://bucket/key/{i}",
                 extra={"foo": "bar"},
                 created_at=timezone.parse(self.default_time),
                 updated_at=timezone.parse(self.default_time),
             )
-            for i in range(2)
+            for i in [1, 2]
         ]
         session.add_all(datasets)
         session.commit()
-        result = session.query(Dataset).all()
-        assert len(result) == 2
+        assert session.query(Dataset).count() == 2
 
         response = self.client.get("/api/v1/datasets")
 
@@ -230,7 +221,7 @@ class TestGetDatasetsEndpointPagination(TestDatasetEndpoint):
     def test_should_respect_page_size_limit_default(self, session):
         datasets = [
             Dataset(
-                uri=f"s3://bucket/key/{i+1}",
+                uri=f"s3://bucket/key/{i}",
                 extra={"foo": "bar"},
                 created_at=timezone.parse(self.default_time),
                 updated_at=timezone.parse(self.default_time),
@@ -239,7 +230,9 @@ class TestGetDatasetsEndpointPagination(TestDatasetEndpoint):
         ]
         session.add_all(datasets)
         session.commit()
+
         response = self.client.get("/api/v1/datasets", environ_overrides={'REMOTE_USER': "test"})
+
         assert response.status_code == 200
         assert len(response.json['datasets']) == 100
 
@@ -247,15 +240,232 @@ class TestGetDatasetsEndpointPagination(TestDatasetEndpoint):
     def test_should_return_conf_max_if_req_max_above_conf(self, session):
         datasets = [
             Dataset(
-                uri=f"s3://bucket/key/{i+1}",
+                uri=f"s3://bucket/key/{i}",
                 extra={"foo": "bar"},
                 created_at=timezone.parse(self.default_time),
                 updated_at=timezone.parse(self.default_time),
             )
-            for i in range(200)
+            for i in range(1, 200)
         ]
         session.add_all(datasets)
         session.commit()
+
         response = self.client.get("/api/v1/datasets?limit=180", environ_overrides={'REMOTE_USER': "test"})
+
         assert response.status_code == 200
         assert len(response.json['datasets']) == 150
+
+
+class TestGetDatasetEvents(TestDatasetEndpoint):
+    def test_should_respond_200(self, session):
+        self._create_dataset(session)
+        common = {
+            "dataset_id": 1,
+            "extra": "{'foo': 'bar'}",
+            "source_dag_id": "foo",
+            "source_task_id": "bar",
+            "source_run_id": "custom",
+            "source_map_index": -1,
+        }
+
+        events = [DatasetEvent(id=i, created_at=timezone.parse(self.default_time), **common) for i in [1, 2]]
+        session.add_all(events)
+        session.commit()
+        assert session.query(DatasetEvent).count() == 2
+
+        response = self.client.get("/api/v1/datasets/events", environ_overrides={'REMOTE_USER': "test"})
+
+        assert response.status_code == 200
+        response_data = response.json
+        assert response_data == {
+            "dataset_events": [
+                {"id": 1, "created_at": self.default_time, **common},
+                {"id": 2, "created_at": self.default_time, **common},
+            ],
+            "total_entries": 2,
+        }
+
+    @parameterized.expand(
+        [
+            ('dataset_id', '2'),
+            ('source_dag_id', 'dag2'),
+            ('source_task_id', 'task2'),
+            ('source_run_id', 'run2'),
+            ('source_map_index', '2'),
+        ]
+    )
+    @provide_session
+    def test_filtering(self, attr, value, session):
+        datasets = [
+            Dataset(
+                id=i,
+                uri=f"s3://bucket/key/{i}",
+                extra={"foo": "bar"},
+                created_at=timezone.parse(self.default_time),
+                updated_at=timezone.parse(self.default_time),
+            )
+            for i in [1, 2, 3]
+        ]
+        session.add_all(datasets)
+        session.commit()
+        events = [
+            DatasetEvent(
+                id=i,
+                dataset_id=i,
+                source_dag_id=f"dag{i}",
+                source_task_id=f"task{i}",
+                source_run_id=f"run{i}",
+                source_map_index=i,
+                created_at=timezone.parse(self.default_time),
+            )
+            for i in [1, 2, 3]
+        ]
+        session.add_all(events)
+        session.commit()
+        assert session.query(DatasetEvent).count() == 3
+
+        response = self.client.get(
+            f"/api/v1/datasets/events?{attr}={value}", environ_overrides={'REMOTE_USER': "test"}
+        )
+
+        assert response.status_code == 200
+        response_data = response.json
+        assert response_data == {
+            "dataset_events": [
+                {
+                    "id": 2,
+                    "dataset_id": 2,
+                    "extra": None,
+                    "source_dag_id": "dag2",
+                    "source_task_id": "task2",
+                    "source_run_id": "run2",
+                    "source_map_index": 2,
+                    "created_at": self.default_time,
+                }
+            ],
+            "total_entries": 1,
+        }
+
+    def test_order_by_raises_400_for_invalid_attr(self, session):
+        self._create_dataset(session)
+        events = [
+            DatasetEvent(
+                dataset_id=1,
+                extra="{'foo': 'bar'}",
+                source_dag_id="foo",
+                source_task_id="bar",
+                source_run_id="custom",
+                source_map_index=-1,
+                created_at=timezone.parse(self.default_time),
+            )
+            for i in [1, 2]
+        ]
+        session.add_all(events)
+        session.commit()
+        assert session.query(DatasetEvent).count() == 2
+
+        response = self.client.get(
+            "/api/v1/datasets/events?order_by=fake", environ_overrides={'REMOTE_USER': "test"}
+        )  # missing attr
+
+        assert response.status_code == 400
+        msg = "Ordering with 'fake' is disallowed or the attribute does not exist on the model"
+        assert response.json['detail'] == msg
+
+    def test_should_raises_401_unauthenticated(self, session):
+        response = self.client.get("/api/v1/datasets/events")
+        assert_401(response)
+
+
+class TestGetDatasetEventsEndpointPagination(TestDatasetEndpoint):
+    @parameterized.expand(
+        [
+            # Limit test data
+            ("/api/v1/datasets/events?limit=1&order_by=source_run_id", ["run1"]),
+            (
+                "/api/v1/datasets/events?limit=3&order_by=source_run_id",
+                [f"run{i}" for i in range(1, 4)],
+            ),
+            # Offset test data
+            (
+                "/api/v1/datasets/events?offset=1&order_by=source_run_id",
+                [f"run{i}" for i in range(2, 10)],
+            ),
+            (
+                "/api/v1/datasets/events?offset=3&order_by=source_run_id",
+                [f"run{i}" for i in range(4, 10)],
+            ),
+            # Limit and offset test data
+            (
+                "/api/v1/datasets/events?offset=3&limit=3&order_by=source_run_id",
+                [f"run{i}" for i in [4, 5, 6]],
+            ),
+        ]
+    )
+    @provide_session
+    def test_limit_and_offset(self, url, expected_event_runids, session):
+        self._create_dataset(session)
+        events = [
+            DatasetEvent(
+                dataset_id=1,
+                source_dag_id="foo",
+                source_task_id="bar",
+                source_run_id=f"run{i}",
+                source_map_index=-1,
+                created_at=timezone.parse(self.default_time),
+            )
+            for i in range(1, 10)
+        ]
+        session.add_all(events)
+        session.commit()
+
+        response = self.client.get(url, environ_overrides={'REMOTE_USER': "test"})
+
+        assert response.status_code == 200
+        event_runids = [event["source_run_id"] for event in response.json["dataset_events"]]
+        assert event_runids == expected_event_runids
+
+    def test_should_respect_page_size_limit_default(self, session):
+        self._create_dataset(session)
+        events = [
+            DatasetEvent(
+                dataset_id=1,
+                source_dag_id="foo",
+                source_task_id="bar",
+                source_run_id=f"run{i}",
+                source_map_index=-1,
+                created_at=timezone.parse(self.default_time),
+            )
+            for i in range(1, 110)
+        ]
+        session.add_all(events)
+        session.commit()
+
+        response = self.client.get("/api/v1/datasets/events", environ_overrides={'REMOTE_USER': "test"})
+
+        assert response.status_code == 200
+        assert len(response.json['dataset_events']) == 100
+
+    @conf_vars({("api", "maximum_page_limit"): "150"})
+    def test_should_return_conf_max_if_req_max_above_conf(self, session):
+        self._create_dataset(session)
+        events = [
+            DatasetEvent(
+                dataset_id=1,
+                source_dag_id="foo",
+                source_task_id="bar",
+                source_run_id=f"run{i}",
+                source_map_index=-1,
+                created_at=timezone.parse(self.default_time),
+            )
+            for i in range(1, 200)
+        ]
+        session.add_all(events)
+        session.commit()
+
+        response = self.client.get(
+            "/api/v1/datasets/events?limit=180", environ_overrides={'REMOTE_USER': "test"}
+        )
+
+        assert response.status_code == 200
+        assert len(response.json['dataset_events']) == 150
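
The endpoint tests above authenticate by overriding REMOTE_USER in the WSGI environ rather than sending credentials. A minimal standalone sketch of that mechanism with a bare Flask app (not the Airflow test harness):

    from flask import Flask, request

    app = Flask(__name__)

    @app.route("/whoami")
    def whoami():
        # Remote-user style auth reads the user from the WSGI environ,
        # which the test client can inject via environ_overrides.
        return request.environ.get("REMOTE_USER", "anonymous")

    with app.test_client() as client:
        resp = client.get("/whoami", environ_overrides={"REMOTE_USER": "test"})
        assert resp.get_data(as_text=True) == "test"
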
diff --git a/tests/api_connexion/schemas/test_dataset_schema.py b/tests/api_connexion/schemas/test_dataset_schema.py
index bf30b9b12e..1d33fbf813 100644
--- a/tests/api_connexion/schemas/test_dataset_schema.py
+++ b/tests/api_connexion/schemas/test_dataset_schema.py
@@ -15,12 +15,15 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from airflow import Dataset
 from airflow.api_connexion.schemas.dataset_schema import (
     DatasetCollection,
+    DatasetEventCollection,
     dataset_collection_schema,
+    dataset_event_collection_schema,
+    dataset_event_schema,
     dataset_schema,
 )
+from airflow.models.dataset import Dataset, DatasetEvent
 from airflow.utils import timezone
 from tests.test_utils.db import clear_db_datasets
 
@@ -93,3 +96,56 @@ class TestDatasetCollectionSchema(TestDatasetSchemaBase):
             ],
             "total_entries": 2,
         }
+
+
+class TestDatasetEventSchema(TestDatasetSchemaBase):
+    def test_serialize(self, session):
+        event = DatasetEvent(
+            id=1,
+            dataset_id=10,
+            extra={"foo": "bar"},
+            source_dag_id="foo",
+            source_task_id="bar",
+            source_run_id="custom",
+            source_map_index=-1,
+            created_at=timezone.parse(self.timestamp),
+        )
+        session.add(event)
+        session.flush()
+        serialized_data = dataset_event_schema.dump(event)
+        assert serialized_data == {
+            "id": 1,
+            "dataset_id": 10,
+            "extra": "{'foo': 'bar'}",
+            "source_dag_id": "foo",
+            "source_task_id": "bar",
+            "source_run_id": "custom",
+            "source_map_index": -1,
+            "created_at": self.timestamp,
+        }
+
+
+class TestDatasetEventCollectionSchema(TestDatasetSchemaBase):
+    def test_serialize(self, session):
+        common = {
+            "dataset_id": 10,
+            "extra": "{'foo': 'bar'}",
+            "source_dag_id": "foo",
+            "source_task_id": "bar",
+            "source_run_id": "custom",
+            "source_map_index": -1,
+        }
+
+        events = [DatasetEvent(id=i, created_at=timezone.parse(self.timestamp), **common) for i in [1, 2]]
+        session.add_all(events)
+        session.flush()
+        serialized_data = dataset_event_collection_schema.dump(
+            DatasetEventCollection(dataset_events=events, total_entries=2)
+        )
+        assert serialized_data == {
+            "dataset_events": [
+                {"id": 1, "created_at": self.timestamp, **common},
+                {"id": 2, "created_at": self.timestamp, **common},
+            ],
+            "total_entries": 2,
+        }
diff --git a/tests/test_utils/db.py b/tests/test_utils/db.py
index be91413b55..92629020fa 100644
--- a/tests/test_utils/db.py
+++ b/tests/test_utils/db.py
@@ -23,7 +23,6 @@ from airflow.models import (
     DagRun,
     DagTag,
     DagWarning,
-    Dataset,
     DbCallbackRequest,
     Log,
     Pool,
@@ -38,6 +37,7 @@ from airflow.models import (
     errors,
 )
 from airflow.models.dagcode import DagCode
+from airflow.models.dataset import Dataset, DatasetEvent
 from airflow.models.serialized_dag import SerializedDagModel
 from airflow.security.permissions import RESOURCE_DAG_PREFIX
 from airflow.utils.db import add_default_pool_if_not_exists, create_default_connections, reflect_tables
@@ -55,6 +55,7 @@ def clear_db_runs():
 
 def clear_db_datasets():
     with create_session() as session:
+        session.query(DatasetEvent).delete()
         session.query(Dataset).delete()