Posted to commits@airflow.apache.org by je...@apache.org on 2022/07/15 16:26:24 UTC
[airflow] branch main updated: Add dataset events to dataset api (#25039)
This is an automated email from the ASF dual-hosted git repository.
jedcunningham pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new fcf8cc26f7 Add dataset events to dataset api (#25039)
fcf8cc26f7 is described below
commit fcf8cc26f7d94fb0baa78ac4c981a14cc88af533
Author: Jed Cunningham <66...@users.noreply.github.com>
AuthorDate: Fri Jul 15 10:26:14 2022 -0600
Add dataset events to dataset api (#25039)
---
.../api_connexion/endpoints/dataset_endpoint.py | 45 +++-
airflow/api_connexion/openapi/v1.yaml | 116 +++++++++
airflow/api_connexion/schemas/dataset_schema.py | 38 ++-
.../endpoints/test_dataset_endpoint.py | 272 ++++++++++++++++++---
tests/api_connexion/schemas/test_dataset_schema.py | 58 ++++-
tests/test_utils/db.py | 3 +-
6 files changed, 497 insertions(+), 35 deletions(-)
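For orientation, a minimal sketch of exercising the new endpoint from a
client, assuming a local webserver at http://localhost:8080 with the
basic_auth API backend enabled (host, credentials, and the example dag id
are illustrative, not part of this commit):

    import requests

    # Query the new /datasets/events endpoint, filtering by source DAG.
    resp = requests.get(
        "http://localhost:8080/api/v1/datasets/events",
        params={"source_dag_id": "example_dag", "limit": 10, "order_by": "created_at"},
        auth=("admin", "admin"),
    )
    resp.raise_for_status()
    for event in resp.json()["dataset_events"]:
        print(event["dataset_id"], event["source_run_id"], event["created_at"])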
diff --git a/airflow/api_connexion/endpoints/dataset_endpoint.py b/airflow/api_connexion/endpoints/dataset_endpoint.py
index cfdc04f615..5c9e7606fc 100644
--- a/airflow/api_connexion/endpoints/dataset_endpoint.py
+++ b/airflow/api_connexion/endpoints/dataset_endpoint.py
@@ -15,19 +15,23 @@
# specific language governing permissions and limitations
# under the License.
+from typing import Optional
+
from sqlalchemy import func
from sqlalchemy.orm import Session
-from airflow import Dataset
from airflow.api_connexion import security
from airflow.api_connexion.exceptions import NotFound
from airflow.api_connexion.parameters import apply_sorting, check_limit, format_parameters
from airflow.api_connexion.schemas.dataset_schema import (
DatasetCollection,
+ DatasetEventCollection,
dataset_collection_schema,
+ dataset_event_collection_schema,
dataset_schema,
)
from airflow.api_connexion.types import APIResponse
+from airflow.models.dataset import Dataset, DatasetEvent
from airflow.security import permissions
from airflow.utils.session import NEW_SESSION, provide_session
@@ -59,3 +63,42 @@ def get_datasets(
query = apply_sorting(query, order_by, {}, allowed_attrs)
datasets = query.offset(offset).limit(limit).all()
return dataset_collection_schema.dump(DatasetCollection(datasets=datasets, total_entries=total_entries))
+
+
+@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DATASET)])
+@provide_session
+@format_parameters({'limit': check_limit})
+def get_dataset_events(
+ *,
+ limit: int,
+ offset: int = 0,
+ order_by: str = "created_at",
+ dataset_id: Optional[int] = None,
+ source_dag_id: Optional[str] = None,
+ source_task_id: Optional[str] = None,
+ source_run_id: Optional[str] = None,
+ source_map_index: Optional[int] = None,
+ session: Session = NEW_SESSION,
+) -> APIResponse:
+ """Get dataset events"""
+ allowed_attrs = ['source_dag_id', 'source_task_id', 'source_run_id', 'source_map_index', 'created_at']
+
+ query = session.query(DatasetEvent)
+
+ if dataset_id:
+ query = query.filter(DatasetEvent.dataset_id == dataset_id)
+ if source_dag_id:
+ query = query.filter(DatasetEvent.source_dag_id == source_dag_id)
+ if source_task_id:
+ query = query.filter(DatasetEvent.source_task_id == source_task_id)
+ if source_run_id:
+ query = query.filter(DatasetEvent.source_run_id == source_run_id)
+ if source_map_index:
+ query = query.filter(DatasetEvent.source_map_index == source_map_index)
+
+ total_entries = query.count()
+ query = apply_sorting(query, order_by, {}, allowed_attrs)
+ events = query.offset(offset).limit(limit).all()
+ return dataset_event_collection_schema.dump(
+ DatasetEventCollection(dataset_events=events, total_entries=total_entries)
+ )
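One subtlety in the truthiness checks above: a client filtering on
source_map_index=0 (a legitimate map index) takes the falsy branch and gets
unfiltered results. Purely as an illustration, an explicit None-comparison
variant of the same filtering loop would look like this (not what the
commit ships; the names reuse the parameters above):

    # Illustrative only: treat "parameter was supplied" as "is not None"
    # so that 0 still filters; the committed code uses plain truthiness.
    filters = {
        "dataset_id": dataset_id,
        "source_dag_id": source_dag_id,
        "source_task_id": source_task_id,
        "source_run_id": source_run_id,
        "source_map_index": source_map_index,
    }
    for attr, value in filters.items():
        if value is not None:
            query = query.filter(getattr(DatasetEvent, attr) == value)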
diff --git a/airflow/api_connexion/openapi/v1.yaml b/airflow/api_connexion/openapi/v1.yaml
index 4bfedbfb98..2dfef4780c 100644
--- a/airflow/api_connexion/openapi/v1.yaml
+++ b/airflow/api_connexion/openapi/v1.yaml
@@ -1654,6 +1654,36 @@ paths:
'404':
$ref: '#/components/responses/NotFound'
+ /datasets/events:
+ parameters:
+ - $ref: '#/components/parameters/PageLimit'
+ - $ref: '#/components/parameters/PageOffset'
+ - $ref: '#/components/parameters/OrderBy'
+ - $ref: '#/components/parameters/FilterDatasetID'
+ - $ref: '#/components/parameters/FilterSourceDAGID'
+ - $ref: '#/components/parameters/FilterSourceTaskID'
+ - $ref: '#/components/parameters/FilterSourceRunID'
+ - $ref: '#/components/parameters/FilterSourceMapIndex'
+ get:
+ summary: Get dataset events
+ description: Get dataset events
+ x-openapi-router-controller: airflow.api_connexion.endpoints.dataset_endpoint
+ operationId: get_dataset_events
+ tags: [Dataset]
+ responses:
+ '200':
+ description: Success.
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/DatasetEventCollection'
+ '401':
+ $ref: '#/components/responses/Unauthenticated'
+ '403':
+ $ref: '#/components/responses/PermissionDenied'
+ '404':
+ $ref: '#/components/responses/NotFound'
+
/config:
get:
summary: Get current configuration
@@ -3461,6 +3491,57 @@ components:
$ref: '#/components/schemas/Dataset'
- $ref: '#/components/schemas/CollectionInfo'
+ DatasetEvent:
+ description: |
+ A dataset event.
+
+ *New in version 2.4.0*
+ type: object
+ properties:
+ dataset_id:
+ type: integer
+ description: The dataset id
+ extra:
+ type: string
+ description: The dataset extra
+ nullable: true
+ source_dag_id:
+ type: string
+ description: The DAG ID that updated the dataset.
+ nullable: false
+ source_task_id:
+ type: string
+ description: The task ID that updated the dataset.
+ nullable: false
+ source_run_id:
+ type: string
+ description: The DAG run ID that updated the dataset.
+ nullable: false
+ source_map_index:
+ type: integer
+ description: The task map index that updated the dataset.
+ nullable: false
+ created_at:
+ type: string
+ description: The dataset event creation time
+ nullable: false
+
+
+ DatasetEventCollection:
+ description: |
+ A collection of dataset events.
+
+ *New in version 2.4.0*
+ type: object
+ allOf:
+ - type: object
+ properties:
+ dataset_events:
+ type: array
+ items:
+ $ref: '#/components/schemas/DatasetEvent'
+ - $ref: '#/components/schemas/CollectionInfo'
+
# Configuration
ConfigOption:
@@ -4287,6 +4368,41 @@ components:
*New in version 2.2.0*
+ FilterDatasetID:
+ in: query
+ name: dataset_id
+ schema:
+ type: integer
+ description: The Dataset ID that the dataset events relate to.
+
+ FilterSourceDAGID:
+ in: query
+ name: source_dag_id
+ schema:
+ type: string
+ description: The DAG ID that updated the dataset.
+
+ FilterSourceTaskID:
+ in: query
+ name: source_task_id
+ schema:
+ type: string
+ description: The task ID that updated the dataset.
+
+ FilterSourceRunID:
+ in: query
+ name: source_run_id
+ schema:
+ type: string
+ description: The DAG run ID that updated the dataset.
+
+ FilterSourceMapIndex:
+ in: query
+ name: source_map_index
+ schema:
+ type: integer
+ description: The map index that updated the dataset.
+
OrderBy:
in: query
name: order_by
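Per the schemas above, a successful GET /datasets/events response body has
the shape below (all values are made up for illustration; note that the
marshmallow schema and the tests later in this diff also serialize an id
field that the YAML properties above do not list):

    # Illustrative payload matching the DatasetEventCollection schema.
    example_response = {
        "dataset_events": [
            {
                "id": 1,  # dumped by DatasetEventSchema, absent from the YAML properties
                "dataset_id": 1,
                "extra": None,
                "source_dag_id": "example_dag",
                "source_task_id": "example_task",
                "source_run_id": "manual__2022-07-15T00:00:00+00:00",
                "source_map_index": -1,
                "created_at": "2022-07-15T16:26:14+00:00",
            }
        ],
        "total_entries": 1,
    }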
diff --git a/airflow/api_connexion/schemas/dataset_schema.py b/airflow/api_connexion/schemas/dataset_schema.py
index 2228ac4c9b..5b2601cea7 100644
--- a/airflow/api_connexion/schemas/dataset_schema.py
+++ b/airflow/api_connexion/schemas/dataset_schema.py
@@ -20,7 +20,7 @@ from typing import List, NamedTuple
from marshmallow import Schema, fields
from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field
-from airflow import Dataset
+from airflow.models.dataset import Dataset, DatasetEvent
class DatasetSchema(SQLAlchemySchema):
@@ -54,3 +54,39 @@ class DatasetCollectionSchema(Schema):
dataset_schema = DatasetSchema()
dataset_collection_schema = DatasetCollectionSchema()
+
+
+class DatasetEventSchema(SQLAlchemySchema):
+ """Dataset Event DB schema"""
+
+ class Meta:
+ """Meta"""
+
+ model = DatasetEvent
+
+ id = auto_field()
+ dataset_id = auto_field()
+ extra = auto_field()
+ source_task_id = auto_field()
+ source_dag_id = auto_field()
+ source_run_id = auto_field()
+ source_map_index = auto_field()
+ created_at = auto_field()
+
+
+class DatasetEventCollection(NamedTuple):
+ """List of Dataset events with meta"""
+
+ dataset_events: List[DatasetEvent]
+ total_entries: int
+
+
+class DatasetEventCollectionSchema(Schema):
+ """Dataset Event Collection Schema"""
+
+ dataset_events = fields.List(fields.Nested(DatasetEventSchema))
+ total_entries = fields.Int()
+
+
+dataset_event_schema = DatasetEventSchema()
+dataset_event_collection_schema = DatasetEventCollectionSchema()
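As a quick sketch of the new schema in use (assuming an Airflow dev
environment where these modules import; this mirrors the unit test later
in this diff rather than prescribing an API, and the timestamp is made up):

    from airflow.api_connexion.schemas.dataset_schema import dataset_event_schema
    from airflow.models.dataset import DatasetEvent
    from airflow.utils import timezone

    event = DatasetEvent(
        id=1,
        dataset_id=10,
        extra={"foo": "bar"},
        source_dag_id="foo",
        source_task_id="bar",
        source_run_id="custom",
        source_map_index=-1,
        created_at=timezone.parse("2022-06-10T12:02:44+00:00"),
    )
    # Dumping a populated transient instance; the tests flush to a session first.
    print(dataset_event_schema.dump(event))
    # {'id': 1, 'dataset_id': 10, 'extra': "{'foo': 'bar'}", ...}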
diff --git a/tests/api_connexion/endpoints/test_dataset_endpoint.py b/tests/api_connexion/endpoints/test_dataset_endpoint.py
index 3d6dd59c8e..2696e60300 100644
--- a/tests/api_connexion/endpoints/test_dataset_endpoint.py
+++ b/tests/api_connexion/endpoints/test_dataset_endpoint.py
@@ -19,7 +19,7 @@ import pytest
from parameterized import parameterized
from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP
-from airflow.models import Dataset
+from airflow.models.dataset import Dataset, DatasetEvent
from airflow.security import permissions
from airflow.utils import timezone
from airflow.utils.session import provide_session
@@ -62,6 +62,7 @@ class TestDatasetEndpoint:
def _create_dataset(self, session):
dataset_model = Dataset(
+ id=1,
uri="s3://bucket/key",
extra={"foo": "bar"},
created_at=timezone.parse(self.default_time),
@@ -70,23 +71,17 @@ class TestDatasetEndpoint:
session.add(dataset_model)
session.commit()
- @staticmethod
- def _normalize_dataset_ids(datasets):
- for i, dataset in enumerate(datasets, 1):
- dataset["id"] = i
-
class TestGetDatasetEndpoint(TestDatasetEndpoint):
def test_should_respond_200(self, session):
self._create_dataset(session)
- result = session.query(Dataset).all()
- assert len(result) == 1
- response = self.client.get(
- f"/api/v1/datasets/{result[0].id}", environ_overrides={'REMOTE_USER': "test"}
- )
+ assert session.query(Dataset).count() == 1
+
+ response = self.client.get("/api/v1/datasets/1", environ_overrides={'REMOTE_USER': "test"})
+
assert response.status_code == 200
assert response.json == {
- "id": result[0].id,
+ "id": 1,
"uri": "s3://bucket/key",
"extra": "{'foo': 'bar'}",
"created_at": self.default_time,
@@ -105,9 +100,7 @@ class TestGetDatasetEndpoint(TestDatasetEndpoint):
def test_should_raises_401_unauthenticated(self, session):
self._create_dataset(session)
- dataset = session.query(Dataset).first()
- response = self.client.get(f"/api/v1/datasets/{dataset.id}")
-
+ response = self.client.get("/api/v1/datasets/1")
assert_401(response)
@@ -115,22 +108,22 @@ class TestGetDatasets(TestDatasetEndpoint):
def test_should_respond_200(self, session):
datasets = [
Dataset(
- uri=f"s3://bucket/key/{i+1}",
+ id=i,
+ uri=f"s3://bucket/key/{i}",
extra={"foo": "bar"},
created_at=timezone.parse(self.default_time),
updated_at=timezone.parse(self.default_time),
)
- for i in range(2)
+ for i in [1, 2]
]
session.add_all(datasets)
session.commit()
- result = session.query(Dataset).all()
- assert len(result) == 2
+ assert session.query(Dataset).count() == 2
+
response = self.client.get("/api/v1/datasets", environ_overrides={'REMOTE_USER': "test"})
assert response.status_code == 200
response_data = response.json
- self._normalize_dataset_ids(response_data['datasets'])
assert response_data == {
"datasets": [
{
@@ -154,17 +147,16 @@ class TestGetDatasets(TestDatasetEndpoint):
def test_order_by_raises_400_for_invalid_attr(self, session):
datasets = [
Dataset(
- uri=f"s3://bucket/key/{i+1}",
+ uri=f"s3://bucket/key/{i}",
extra={"foo": "bar"},
created_at=timezone.parse(self.default_time),
updated_at=timezone.parse(self.default_time),
)
- for i in range(2)
+ for i in [1, 2]
]
session.add_all(datasets)
session.commit()
- result = session.query(Dataset).all()
- assert len(result) == 2
+ assert session.query(Dataset).count() == 2
response = self.client.get(
"/api/v1/datasets?order_by=fake", environ_overrides={'REMOTE_USER': "test"}
@@ -177,17 +169,16 @@ class TestGetDatasets(TestDatasetEndpoint):
def test_should_raises_401_unauthenticated(self, session):
datasets = [
Dataset(
- uri=f"s3://bucket/key/{i+1}",
+ uri=f"s3://bucket/key/{i}",
extra={"foo": "bar"},
created_at=timezone.parse(self.default_time),
updated_at=timezone.parse(self.default_time),
)
- for i in range(2)
+ for i in [1, 2]
]
session.add_all(datasets)
session.commit()
- result = session.query(Dataset).all()
- assert len(result) == 2
+ assert session.query(Dataset).count() == 2
response = self.client.get("/api/v1/datasets")
@@ -230,7 +221,7 @@ class TestGetDatasetsEndpointPagination(TestDatasetEndpoint):
def test_should_respect_page_size_limit_default(self, session):
datasets = [
Dataset(
- uri=f"s3://bucket/key/{i+1}",
+ uri=f"s3://bucket/key/{i}",
extra={"foo": "bar"},
created_at=timezone.parse(self.default_time),
updated_at=timezone.parse(self.default_time),
@@ -239,7 +230,9 @@ class TestGetDatasetsEndpointPagination(TestDatasetEndpoint):
]
session.add_all(datasets)
session.commit()
+
response = self.client.get("/api/v1/datasets", environ_overrides={'REMOTE_USER': "test"})
+
assert response.status_code == 200
assert len(response.json['datasets']) == 100
@@ -247,15 +240,232 @@ class TestGetDatasetsEndpointPagination(TestDatasetEndpoint):
def test_should_return_conf_max_if_req_max_above_conf(self, session):
datasets = [
Dataset(
- uri=f"s3://bucket/key/{i+1}",
+ uri=f"s3://bucket/key/{i}",
extra={"foo": "bar"},
created_at=timezone.parse(self.default_time),
updated_at=timezone.parse(self.default_time),
)
- for i in range(200)
+ for i in range(1, 200)
]
session.add_all(datasets)
session.commit()
+
response = self.client.get("/api/v1/datasets?limit=180", environ_overrides={'REMOTE_USER': "test"})
+
assert response.status_code == 200
assert len(response.json['datasets']) == 150
+
+
+class TestGetDatasetEvents(TestDatasetEndpoint):
+ def test_should_respond_200(self, session):
+ self._create_dataset(session)
+ common = {
+ "dataset_id": 1,
+ "extra": "{'foo': 'bar'}",
+ "source_dag_id": "foo",
+ "source_task_id": "bar",
+ "source_run_id": "custom",
+ "source_map_index": -1,
+ }
+
+ events = [DatasetEvent(id=i, created_at=timezone.parse(self.default_time), **common) for i in [1, 2]]
+ session.add_all(events)
+ session.commit()
+ assert session.query(DatasetEvent).count() == 2
+
+ response = self.client.get("/api/v1/datasets/events", environ_overrides={'REMOTE_USER': "test"})
+
+ assert response.status_code == 200
+ response_data = response.json
+ assert response_data == {
+ "dataset_events": [
+ {"id": 1, "created_at": self.default_time, **common},
+ {"id": 2, "created_at": self.default_time, **common},
+ ],
+ "total_entries": 2,
+ }
+
+ @parameterized.expand(
+ [
+ ('dataset_id', '2'),
+ ('source_dag_id', 'dag2'),
+ ('source_task_id', 'task2'),
+ ('source_run_id', 'run2'),
+ ('source_map_index', '2'),
+ ]
+ )
+ @provide_session
+ def test_filtering(self, attr, value, session):
+ datasets = [
+ Dataset(
+ id=i,
+ uri=f"s3://bucket/key/{i}",
+ extra={"foo": "bar"},
+ created_at=timezone.parse(self.default_time),
+ updated_at=timezone.parse(self.default_time),
+ )
+ for i in [1, 2, 3]
+ ]
+ session.add_all(datasets)
+ session.commit()
+ events = [
+ DatasetEvent(
+ id=i,
+ dataset_id=i,
+ source_dag_id=f"dag{i}",
+ source_task_id=f"task{i}",
+ source_run_id=f"run{i}",
+ source_map_index=i,
+ created_at=timezone.parse(self.default_time),
+ )
+ for i in [1, 2, 3]
+ ]
+ session.add_all(events)
+ session.commit()
+ assert session.query(DatasetEvent).count() == 3
+
+ response = self.client.get(
+ f"/api/v1/datasets/events?{attr}={value}", environ_overrides={'REMOTE_USER': "test"}
+ )
+
+ assert response.status_code == 200
+ response_data = response.json
+ assert response_data == {
+ "dataset_events": [
+ {
+ "id": 2,
+ "dataset_id": 2,
+ "extra": None,
+ "source_dag_id": "dag2",
+ "source_task_id": "task2",
+ "source_run_id": "run2",
+ "source_map_index": 2,
+ "created_at": self.default_time,
+ }
+ ],
+ "total_entries": 1,
+ }
+
+ def test_order_by_raises_400_for_invalid_attr(self, session):
+ self._create_dataset(session)
+ events = [
+ DatasetEvent(
+ dataset_id=1,
+ extra="{'foo': 'bar'}",
+ source_dag_id="foo",
+ source_task_id="bar",
+ source_run_id="custom",
+ source_map_index=-1,
+ created_at=timezone.parse(self.default_time),
+ )
+ for i in [1, 2]
+ ]
+ session.add_all(events)
+ session.commit()
+ assert session.query(DatasetEvent).count() == 2
+
+ response = self.client.get(
+ "/api/v1/datasets/events?order_by=fake", environ_overrides={'REMOTE_USER': "test"}
+ ) # missing attr
+
+ assert response.status_code == 400
+ msg = "Ordering with 'fake' is disallowed or the attribute does not exist on the model"
+ assert response.json['detail'] == msg
+
+ def test_should_raises_401_unauthenticated(self, session):
+ response = self.client.get("/api/v1/datasets/events")
+ assert_401(response)
+
+
+class TestGetDatasetEventsEndpointPagination(TestDatasetEndpoint):
+ @parameterized.expand(
+ [
+ # Limit test data
+ ("/api/v1/datasets/events?limit=1&order_by=source_run_id", ["run1"]),
+ (
+ "/api/v1/datasets/events?limit=3&order_by=source_run_id",
+ [f"run{i}" for i in range(1, 4)],
+ ),
+ # Offset test data
+ (
+ "/api/v1/datasets/events?offset=1&order_by=source_run_id",
+ [f"run{i}" for i in range(2, 10)],
+ ),
+ (
+ "/api/v1/datasets/events?offset=3&order_by=source_run_id",
+ [f"run{i}" for i in range(4, 10)],
+ ),
+ # Limit and offset test data
+ (
+ "/api/v1/datasets/events?offset=3&limit=3&order_by=source_run_id",
+ [f"run{i}" for i in [4, 5, 6]],
+ ),
+ ]
+ )
+ @provide_session
+ def test_limit_and_offset(self, url, expected_event_runids, session):
+ self._create_dataset(session)
+ events = [
+ DatasetEvent(
+ dataset_id=1,
+ source_dag_id="foo",
+ source_task_id="bar",
+ source_run_id=f"run{i}",
+ source_map_index=-1,
+ created_at=timezone.parse(self.default_time),
+ )
+ for i in range(1, 10)
+ ]
+ session.add_all(events)
+ session.commit()
+
+ response = self.client.get(url, environ_overrides={'REMOTE_USER': "test"})
+
+ assert response.status_code == 200
+ event_runids = [event["source_run_id"] for event in response.json["dataset_events"]]
+ assert event_runids == expected_event_runids
+
+ def test_should_respect_page_size_limit_default(self, session):
+ self._create_dataset(session)
+ events = [
+ DatasetEvent(
+ dataset_id=1,
+ source_dag_id="foo",
+ source_task_id="bar",
+ source_run_id=f"run{i}",
+ source_map_index=-1,
+ created_at=timezone.parse(self.default_time),
+ )
+ for i in range(1, 110)
+ ]
+ session.add_all(events)
+ session.commit()
+
+ response = self.client.get("/api/v1/datasets/events", environ_overrides={'REMOTE_USER': "test"})
+
+ assert response.status_code == 200
+ assert len(response.json['dataset_events']) == 100
+
+ @conf_vars({("api", "maximum_page_limit"): "150"})
+ def test_should_return_conf_max_if_req_max_above_conf(self, session):
+ self._create_dataset(session)
+ events = [
+ DatasetEvent(
+ dataset_id=1,
+ source_dag_id="foo",
+ source_task_id="bar",
+ source_run_id=f"run{i}",
+ source_map_index=-1,
+ created_at=timezone.parse(self.default_time),
+ )
+ for i in range(1, 200)
+ ]
+ session.add_all(events)
+ session.commit()
+
+ response = self.client.get(
+ "/api/v1/datasets/events?limit=180", environ_overrides={'REMOTE_USER': "test"}
+ )
+
+ assert response.status_code == 200
+ assert len(response.json['dataset_events']) == 150
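The two page-size tests above encode the stable API's limit handling: with
no limit the endpoint serves the default page size (100), and a requested
limit above [api] maximum_page_limit is clamped (150 in the test). A
standalone sketch of that clamping rule, illustrative rather than
Airflow's actual check_limit implementation:

    def clamp_limit(requested, fallback=100, maximum=150):
        # No limit supplied -> fallback page size; oversized -> clamp to max.
        if not requested:
            return fallback
        return min(requested, maximum)

    assert clamp_limit(None) == 100  # default page size test
    assert clamp_limit(180) == 150   # maximum_page_limit test
    assert clamp_limit(3) == 3       # small limits pass through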
diff --git a/tests/api_connexion/schemas/test_dataset_schema.py b/tests/api_connexion/schemas/test_dataset_schema.py
index bf30b9b12e..1d33fbf813 100644
--- a/tests/api_connexion/schemas/test_dataset_schema.py
+++ b/tests/api_connexion/schemas/test_dataset_schema.py
@@ -15,12 +15,15 @@
# specific language governing permissions and limitations
# under the License.
-from airflow import Dataset
from airflow.api_connexion.schemas.dataset_schema import (
DatasetCollection,
+ DatasetEventCollection,
dataset_collection_schema,
+ dataset_event_collection_schema,
+ dataset_event_schema,
dataset_schema,
)
+from airflow.models.dataset import Dataset, DatasetEvent
from airflow.utils import timezone
from tests.test_utils.db import clear_db_datasets
@@ -93,3 +96,56 @@ class TestDatasetCollectionSchema(TestDatasetSchemaBase):
],
"total_entries": 2,
}
+
+
+class TestDatasetEventSchema(TestDatasetSchemaBase):
+ def test_serialize(self, session):
+ event = DatasetEvent(
+ id=1,
+ dataset_id=10,
+ extra={"foo": "bar"},
+ source_dag_id="foo",
+ source_task_id="bar",
+ source_run_id="custom",
+ source_map_index=-1,
+ created_at=timezone.parse(self.timestamp),
+ )
+ session.add(event)
+ session.flush()
+ serialized_data = dataset_event_schema.dump(event)
+ assert serialized_data == {
+ "id": 1,
+ "dataset_id": 10,
+ "extra": "{'foo': 'bar'}",
+ "source_dag_id": "foo",
+ "source_task_id": "bar",
+ "source_run_id": "custom",
+ "source_map_index": -1,
+ "created_at": self.timestamp,
+ }
+
+
+class TestDatasetEventCollectionSchema(TestDatasetSchemaBase):
+ def test_serialize(self, session):
+ common = {
+ "dataset_id": 10,
+ "extra": "{'foo': 'bar'}",
+ "source_dag_id": "foo",
+ "source_task_id": "bar",
+ "source_run_id": "custom",
+ "source_map_index": -1,
+ }
+
+ events = [DatasetEvent(id=i, created_at=timezone.parse(self.timestamp), **common) for i in [1, 2]]
+ session.add_all(events)
+ session.flush()
+ serialized_data = dataset_event_collection_schema.dump(
+ DatasetEventCollection(dataset_events=events, total_entries=2)
+ )
+ assert serialized_data == {
+ "dataset_events": [
+ {"id": 1, "created_at": self.timestamp, **common},
+ {"id": 2, "created_at": self.timestamp, **common},
+ ],
+ "total_entries": 2,
+ }
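Note that these expected payloads assert extra as the string
"{'foo': 'bar'}" rather than a nested JSON object: the serialized value is
the Python repr of the stored dict (apparently coerced via str()), not a
JSON re-encoding. A two-line illustration of the distinction:

    import json

    assert str({"foo": "bar"}) == "{'foo': 'bar'}"         # what the tests expect
    assert json.dumps({"foo": "bar"}) == '{"foo": "bar"}'  # JSON encoding differs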
diff --git a/tests/test_utils/db.py b/tests/test_utils/db.py
index be91413b55..92629020fa 100644
--- a/tests/test_utils/db.py
+++ b/tests/test_utils/db.py
@@ -23,7 +23,6 @@ from airflow.models import (
DagRun,
DagTag,
DagWarning,
- Dataset,
DbCallbackRequest,
Log,
Pool,
@@ -38,6 +37,7 @@ from airflow.models import (
errors,
)
from airflow.models.dagcode import DagCode
+from airflow.models.dataset import Dataset, DatasetEvent
from airflow.models.serialized_dag import SerializedDagModel
from airflow.security.permissions import RESOURCE_DAG_PREFIX
from airflow.utils.db import add_default_pool_if_not_exists, create_default_connections, reflect_tables
@@ -55,6 +55,7 @@ def clear_db_runs():
def clear_db_datasets():
with create_session() as session:
+ session.query(DatasetEvent).delete()
session.query(Dataset).delete()