Posted to commits@airflow.apache.org by je...@apache.org on 2022/07/12 23:41:02 UTC
[airflow] branch main updated: Add read-only REST API endpoint for Datasets (#24696)
This is an automated email from the ASF dual-hosted git repository.
jedcunningham pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 809d95ec06 Add read-only REST API endpoint for Datasets (#24696)
809d95ec06 is described below
commit 809d95ec06447c9579383d15136190c0963b3c1b
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Wed Jul 13 00:40:50 2022 +0100
Add read-only REST API endpoint for Datasets (#24696)
Co-authored-by: Ash Berlin-Taylor <as...@firemirror.com>
Co-authored-by: Jed Cunningham <je...@apache.org>
---
.../api_connexion/endpoints/dataset_endpoint.py | 61 +++++
airflow/api_connexion/openapi/v1.yaml | 98 ++++++++
airflow/api_connexion/schemas/dataset_schema.py | 56 +++++
airflow/security/permissions.py | 1 +
.../endpoints/test_dataset_endpoint.py | 261 +++++++++++++++++++++
tests/api_connexion/schemas/test_dataset_schema.py | 95 ++++++++
6 files changed, 572 insertions(+)
diff --git a/airflow/api_connexion/endpoints/dataset_endpoint.py b/airflow/api_connexion/endpoints/dataset_endpoint.py
new file mode 100644
index 0000000000..a3c691b923
--- /dev/null
+++ b/airflow/api_connexion/endpoints/dataset_endpoint.py
@@ -0,0 +1,61 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from sqlalchemy import func
+from sqlalchemy.orm import Session
+
+from airflow import Dataset
+from airflow.api_connexion import security
+from airflow.api_connexion.exceptions import NotFound
+from airflow.api_connexion.parameters import apply_sorting, check_limit, format_parameters
+from airflow.api_connexion.schemas.dataset_schema import (
+    DatasetCollection,
+    dataset_collection_schema,
+    dataset_schema,
+)
+from airflow.api_connexion.types import APIResponse
+from airflow.security import permissions
+from airflow.utils.session import NEW_SESSION, provide_session
+
+
+@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DATASET)])
+@provide_session
+def get_dataset(id, session):
+    """Get a Dataset"""
+    dataset = session.query(Dataset).get(id)
+    if not dataset:
+        raise NotFound(
+            "Dataset not found",
+            detail=f"The Dataset with id: `{id}` was not found",
+        )
+    return dataset_schema.dump(dataset)
+
+
+@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DATASET)])
+@format_parameters({'limit': check_limit})
+@provide_session
+def get_datasets(
+    *, limit: int, offset: int = 0, order_by: str = "id", session: Session = NEW_SESSION
+) -> APIResponse:
+    """Get datasets"""
+    allowed_filter_attrs = ['id', 'uri', 'created_at', 'updated_at']
+
+    total_entries = session.query(func.count(Dataset.id)).scalar()
+    query = session.query(Dataset)
+    query = apply_sorting(query, order_by, {}, allowed_filter_attrs)
+    datasets = query.offset(offset).limit(limit).all()
+    return dataset_collection_schema.dump(DatasetCollection(datasets=datasets, total_entries=total_entries))
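As a quick orientation for these two handlers, here is a minimal client sketch. It assumes a webserver at http://localhost:8080 with the basic-auth API backend enabled and an "admin"/"admin" user; all of that is an assumption about the deployment, not part of this commit.

```python
import requests  # third-party HTTP client

BASE_URL = "http://localhost:8080/api/v1"  # assumed local deployment
AUTH = ("admin", "admin")  # assumed basic-auth credentials

# List datasets; limit/offset/order_by map onto get_datasets() above.
resp = requests.get(
    f"{BASE_URL}/datasets",
    params={"limit": 10, "offset": 0, "order_by": "uri"},
    auth=AUTH,
)
resp.raise_for_status()
collection = resp.json()
print(collection["total_entries"], "dataset(s) total")

# Fetch a single dataset by id; an unknown id returns the 404 raised in get_dataset().
if collection["datasets"]:
    first_id = collection["datasets"][0]["id"]
    print(requests.get(f"{BASE_URL}/datasets/{first_id}", auth=AUTH).json()["uri"])
```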
diff --git a/airflow/api_connexion/openapi/v1.yaml b/airflow/api_connexion/openapi/v1.yaml
index d9c0cc6ed8..4bfedbfb98 100644
--- a/airflow/api_connexion/openapi/v1.yaml
+++ b/airflow/api_connexion/openapi/v1.yaml
@@ -1609,6 +1609,51 @@ paths:
        '403':
          $ref: '#/components/responses/PermissionDenied'

+  /datasets:
+    get:
+      summary: List datasets
+      x-openapi-router-controller: airflow.api_connexion.endpoints.dataset_endpoint
+      operationId: get_datasets
+      tags: [Dataset]
+      parameters:
+        - $ref: '#/components/parameters/PageLimit'
+        - $ref: '#/components/parameters/PageOffset'
+        - $ref: '#/components/parameters/OrderBy'
+      responses:
+        '200':
+          description: Success.
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/DatasetCollection'
+        '401':
+          $ref: '#/components/responses/Unauthenticated'
+        '403':
+          $ref: '#/components/responses/PermissionDenied'
+
+  /datasets/{id}:
+    parameters:
+      - $ref: '#/components/parameters/DatasetID'
+    get:
+      summary: Get a dataset
+      description: Get a dataset by id.
+      x-openapi-router-controller: airflow.api_connexion.endpoints.dataset_endpoint
+      operationId: get_dataset
+      tags: [Dataset]
+      responses:
+        '200':
+          description: Success.
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/Dataset'
+        '401':
+          $ref: '#/components/responses/Unauthenticated'
+        '403':
+          $ref: '#/components/responses/PermissionDenied'
+        '404':
+          $ref: '#/components/responses/NotFound'
+
  /config:
    get:
      summary: Get current configuration
@@ -3374,6 +3419,49 @@ components:
          $ref: '#/components/schemas/Resource'
        description: The permission resource

+    Dataset:
+      description: |
+        A dataset item.
+
+        *New in version 2.4.0*
+      type: object
+      properties:
+        id:
+          type: integer
+          description: The dataset id
+        uri:
+          type: string
+          description: The dataset uri
+          nullable: false
+        extra:
+          type: string
+          description: The dataset extra
+          nullable: true
+        created_at:
+          type: string
+          description: The dataset creation time
+          nullable: false
+        updated_at:
+          type: string
+          description: The dataset update time
+          nullable: false
+
+    DatasetCollection:
+      description: |
+        A collection of datasets.
+
+        *New in version 2.4.0*
+      type: object
+      allOf:
+        - type: object
+          properties:
+            datasets:
+              type: array
+              items:
+                $ref: '#/components/schemas/Dataset'
+        - $ref: '#/components/schemas/CollectionInfo'
+
+
    # Configuration
    ConfigOption:
      type: object
@@ -4018,6 +4106,14 @@ components:
      required: true
      description: The import error ID.

+    DatasetID:
+      in: path
+      name: id
+      schema:
+        type: integer
+      required: true
+      description: The Dataset ID
+
    PoolName:
      in: path
      name: pool_name
@@ -4347,6 +4443,8 @@ tags:
  - name: Role
  - name: Permission
  - name: User
+  - name: DagWarning
+  - name: Dataset

externalDocs:
  url: https://airflow.apache.org/docs/apache-airflow/stable/
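For reference, a response body that validates against the new DatasetCollection schema would look roughly like this (illustrative values only; the `total_entries` key comes from the shared CollectionInfo schema):

```python
# Illustrative DatasetCollection payload; all values are made up.
example_collection = {
    "datasets": [
        {
            "id": 1,
            "uri": "s3://bucket/key",
            "extra": "{'foo': 'bar'}",
            "created_at": "2022-06-10T12:02:44+00:00",
            "updated_at": "2022-06-10T12:02:44+00:00",
        },
    ],
    "total_entries": 1,
}
```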
diff --git a/airflow/api_connexion/schemas/dataset_schema.py b/airflow/api_connexion/schemas/dataset_schema.py
new file mode 100644
index 0000000000..2228ac4c9b
--- /dev/null
+++ b/airflow/api_connexion/schemas/dataset_schema.py
@@ -0,0 +1,56 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import List, NamedTuple
+
+from marshmallow import Schema, fields
+from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field
+
+from airflow import Dataset
+
+
+class DatasetSchema(SQLAlchemySchema):
+    """Dataset DB schema"""
+
+    class Meta:
+        """Meta"""
+
+        model = Dataset
+
+    id = auto_field()
+    uri = auto_field()
+    extra = auto_field()
+    created_at = auto_field()
+    updated_at = auto_field()
+
+
+class DatasetCollection(NamedTuple):
+    """List of Datasets with meta"""
+
+    datasets: List[Dataset]
+    total_entries: int
+
+
+class DatasetCollectionSchema(Schema):
+    """Dataset Collection Schema"""
+
+    datasets = fields.List(fields.Nested(DatasetSchema))
+    total_entries = fields.Int()
+
+
+dataset_schema = DatasetSchema()
+dataset_collection_schema = DatasetCollectionSchema()
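The serialization pattern above can be tried standalone. Here is a minimal sketch with the SQLAlchemy-backed schema swapped for plain marshmallow fields, so nothing is imported from airflow; the `DatasetRow` stand-in is hypothetical:

```python
from datetime import datetime, timezone
from typing import List, NamedTuple

from marshmallow import Schema, fields


class DatasetSchema(Schema):
    # Plain fields standing in for the auto_field() mappings above.
    id = fields.Int()
    uri = fields.Str()
    extra = fields.Str(allow_none=True)
    created_at = fields.DateTime()
    updated_at = fields.DateTime()


class DatasetRow(NamedTuple):
    # Hypothetical stand-in for the Dataset model.
    id: int
    uri: str
    extra: str
    created_at: datetime
    updated_at: datetime


class DatasetCollection(NamedTuple):
    datasets: List[DatasetRow]
    total_entries: int


class DatasetCollectionSchema(Schema):
    datasets = fields.List(fields.Nested(DatasetSchema))
    total_entries = fields.Int()


ts = datetime(2022, 6, 10, 12, 2, 44, tzinfo=timezone.utc)
row = DatasetRow(1, "s3://bucket/key", "{'foo': 'bar'}", ts, ts)
print(DatasetCollectionSchema().dump(DatasetCollection([row], total_entries=1)))
# -> {'datasets': [{'id': 1, 'uri': 's3://bucket/key', ...}], 'total_entries': 1}
```

Dumping a NamedTuple works because marshmallow falls back to attribute access, which is exactly how the commit's `dataset_collection_schema.dump(DatasetCollection(...))` call operates.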
diff --git a/airflow/security/permissions.py b/airflow/security/permissions.py
index 9ffc07885d..b0e0b6df60 100644
--- a/airflow/security/permissions.py
+++ b/airflow/security/permissions.py
@@ -53,6 +53,7 @@ RESOURCE_USER_STATS_CHART = "User Stats Chart"
RESOURCE_VARIABLE = "Variables"
RESOURCE_WEBSITE = "Website"
RESOURCE_XCOM = "XComs"
+RESOURCE_DATASET = "Datasets"
# Action Constants
diff --git a/tests/api_connexion/endpoints/test_dataset_endpoint.py b/tests/api_connexion/endpoints/test_dataset_endpoint.py
new file mode 100644
index 0000000000..3d6dd59c8e
--- /dev/null
+++ b/tests/api_connexion/endpoints/test_dataset_endpoint.py
@@ -0,0 +1,261 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+from parameterized import parameterized
+
+from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP
+from airflow.models import Dataset
+from airflow.security import permissions
+from airflow.utils import timezone
+from airflow.utils.session import provide_session
+from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user
+from tests.test_utils.config import conf_vars
+from tests.test_utils.db import clear_db_datasets
+
+
+@pytest.fixture(scope="module")
+def configured_app(minimal_app_for_api):
+    app = minimal_app_for_api
+    create_user(
+        app,  # type: ignore
+        username="test",
+        role_name="Test",
+        permissions=[
+            (permissions.ACTION_CAN_READ, permissions.RESOURCE_DATASET),
+        ],
+    )
+    create_user(app, username="test_no_permissions", role_name="TestNoPermissions")  # type: ignore
+
+    yield app
+
+    delete_user(app, username="test")  # type: ignore
+    delete_user(app, username="test_no_permissions")  # type: ignore
+
+
+class TestDatasetEndpoint:
+
+    default_time = "2020-06-11T18:00:00+00:00"
+
+    @pytest.fixture(autouse=True)
+    def setup_attrs(self, configured_app) -> None:
+        self.app = configured_app
+        self.client = self.app.test_client()
+        clear_db_datasets()
+
+    def teardown_method(self) -> None:
+        clear_db_datasets()
+
+    def _create_dataset(self, session):
+        dataset_model = Dataset(
+            uri="s3://bucket/key",
+            extra={"foo": "bar"},
+            created_at=timezone.parse(self.default_time),
+            updated_at=timezone.parse(self.default_time),
+        )
+        session.add(dataset_model)
+        session.commit()
+
+    @staticmethod
+    def _normalize_dataset_ids(datasets):
+        for i, dataset in enumerate(datasets, 1):
+            dataset["id"] = i
+
+
+class TestGetDatasetEndpoint(TestDatasetEndpoint):
+    def test_should_respond_200(self, session):
+        self._create_dataset(session)
+        result = session.query(Dataset).all()
+        assert len(result) == 1
+        response = self.client.get(
+            f"/api/v1/datasets/{result[0].id}", environ_overrides={'REMOTE_USER': "test"}
+        )
+        assert response.status_code == 200
+        assert response.json == {
+            "id": result[0].id,
+            "uri": "s3://bucket/key",
+            "extra": "{'foo': 'bar'}",
+            "created_at": self.default_time,
+            "updated_at": self.default_time,
+        }
+
+    def test_should_respond_404(self):
+        response = self.client.get("/api/v1/datasets/1", environ_overrides={'REMOTE_USER': "test"})
+        assert response.status_code == 404
+        assert {
+            'detail': "The Dataset with id: `1` was not found",
+            'status': 404,
+            'title': 'Dataset not found',
+            'type': EXCEPTIONS_LINK_MAP[404],
+        } == response.json
+
+    def test_should_raises_401_unauthenticated(self, session):
+        self._create_dataset(session)
+        dataset = session.query(Dataset).first()
+        response = self.client.get(f"/api/v1/datasets/{dataset.id}")
+
+        assert_401(response)
+
+
+class TestGetDatasets(TestDatasetEndpoint):
+    def test_should_respond_200(self, session):
+        datasets = [
+            Dataset(
+                uri=f"s3://bucket/key/{i+1}",
+                extra={"foo": "bar"},
+                created_at=timezone.parse(self.default_time),
+                updated_at=timezone.parse(self.default_time),
+            )
+            for i in range(2)
+        ]
+        session.add_all(datasets)
+        session.commit()
+        result = session.query(Dataset).all()
+        assert len(result) == 2
+        response = self.client.get("/api/v1/datasets", environ_overrides={'REMOTE_USER': "test"})
+
+        assert response.status_code == 200
+        response_data = response.json
+        self._normalize_dataset_ids(response_data['datasets'])
+        assert response_data == {
+            "datasets": [
+                {
+                    "id": 1,
+                    "uri": "s3://bucket/key/1",
+                    "extra": "{'foo': 'bar'}",
+                    "created_at": self.default_time,
+                    "updated_at": self.default_time,
+                },
+                {
+                    "id": 2,
+                    "uri": "s3://bucket/key/2",
+                    "extra": "{'foo': 'bar'}",
+                    "created_at": self.default_time,
+                    "updated_at": self.default_time,
+                },
+            ],
+            "total_entries": 2,
+        }
+
+    def test_order_by_raises_400_for_invalid_attr(self, session):
+        datasets = [
+            Dataset(
+                uri=f"s3://bucket/key/{i+1}",
+                extra={"foo": "bar"},
+                created_at=timezone.parse(self.default_time),
+                updated_at=timezone.parse(self.default_time),
+            )
+            for i in range(2)
+        ]
+        session.add_all(datasets)
+        session.commit()
+        result = session.query(Dataset).all()
+        assert len(result) == 2
+
+        response = self.client.get(
+            "/api/v1/datasets?order_by=fake", environ_overrides={'REMOTE_USER': "test"}
+        )  # missing attr
+
+        assert response.status_code == 400
+        msg = "Ordering with 'fake' is disallowed or the attribute does not exist on the model"
+        assert response.json['detail'] == msg
+
+    def test_should_raises_401_unauthenticated(self, session):
+        datasets = [
+            Dataset(
+                uri=f"s3://bucket/key/{i+1}",
+                extra={"foo": "bar"},
+                created_at=timezone.parse(self.default_time),
+                updated_at=timezone.parse(self.default_time),
+            )
+            for i in range(2)
+        ]
+        session.add_all(datasets)
+        session.commit()
+        result = session.query(Dataset).all()
+        assert len(result) == 2
+
+        response = self.client.get("/api/v1/datasets")
+
+        assert_401(response)
+
+
+class TestGetDatasetsEndpointPagination(TestDatasetEndpoint):
+    @parameterized.expand(
+        [
+            # Limit test data
+            ("/api/v1/datasets?limit=1", ["s3://bucket/key/1"]),
+            ("/api/v1/datasets?limit=100", [f"s3://bucket/key/{i}" for i in range(1, 101)]),
+            # Offset test data
+            ("/api/v1/datasets?offset=1", [f"s3://bucket/key/{i}" for i in range(2, 102)]),
+            ("/api/v1/datasets?offset=3", [f"s3://bucket/key/{i}" for i in range(4, 104)]),
+            # Limit and offset test data
+            ("/api/v1/datasets?offset=3&limit=3", [f"s3://bucket/key/{i}" for i in [4, 5, 6]]),
+        ]
+    )
+    @provide_session
+    def test_limit_and_offset(self, url, expected_dataset_uris, session):
+        datasets = [
+            Dataset(
+                uri=f"s3://bucket/key/{i}",
+                extra={"foo": "bar"},
+                created_at=timezone.parse(self.default_time),
+                updated_at=timezone.parse(self.default_time),
+            )
+            for i in range(1, 110)
+        ]
+        session.add_all(datasets)
+        session.commit()
+
+        response = self.client.get(url, environ_overrides={'REMOTE_USER': "test"})
+
+        assert response.status_code == 200
+        dataset_uris = [dataset["uri"] for dataset in response.json["datasets"]]
+        assert dataset_uris == expected_dataset_uris
+
+    def test_should_respect_page_size_limit_default(self, session):
+        datasets = [
+            Dataset(
+                uri=f"s3://bucket/key/{i+1}",
+                extra={"foo": "bar"},
+                created_at=timezone.parse(self.default_time),
+                updated_at=timezone.parse(self.default_time),
+            )
+            for i in range(1, 110)
+        ]
+        session.add_all(datasets)
+        session.commit()
+        response = self.client.get("/api/v1/datasets", environ_overrides={'REMOTE_USER': "test"})
+        assert response.status_code == 200
+        assert len(response.json['datasets']) == 100
+
+    @conf_vars({("api", "maximum_page_limit"): "150"})
+    def test_should_return_conf_max_if_req_max_above_conf(self, session):
+        datasets = [
+            Dataset(
+                uri=f"s3://bucket/key/{i+1}",
+                extra={"foo": "bar"},
+                created_at=timezone.parse(self.default_time),
+                updated_at=timezone.parse(self.default_time),
+            )
+            for i in range(200)
+        ]
+        session.add_all(datasets)
+        session.commit()
+        response = self.client.get("/api/v1/datasets?limit=180", environ_overrides={'REMOTE_USER': "test"})
+        assert response.status_code == 200
+        assert len(response.json['datasets']) == 150
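The two limit tests above pin down the paging contract. Conceptually, the behavior that check_limit enforces reduces to something like the sketch below; this is not Airflow's actual implementation, and the real default and cap come from the [api] fallback_page_limit and maximum_page_limit settings (the 150 here mirrors the conf_vars override in the test):

```python
from typing import Optional


def cap_limit(requested: Optional[int], default: int = 100, maximum: int = 150) -> int:
    """Fall back to the default when no limit is given; never exceed the maximum."""
    if requested is None:
        return default
    return min(requested, maximum)


assert cap_limit(None) == 100  # matches test_should_respect_page_size_limit_default
assert cap_limit(180) == 150   # matches test_should_return_conf_max_if_req_max_above_conf
```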
diff --git a/tests/api_connexion/schemas/test_dataset_schema.py b/tests/api_connexion/schemas/test_dataset_schema.py
new file mode 100644
index 0000000000..bf30b9b12e
--- /dev/null
+++ b/tests/api_connexion/schemas/test_dataset_schema.py
@@ -0,0 +1,95 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from airflow import Dataset
+from airflow.api_connexion.schemas.dataset_schema import (
+ DatasetCollection,
+ dataset_collection_schema,
+ dataset_schema,
+)
+from airflow.utils import timezone
+from tests.test_utils.db import clear_db_datasets
+
+
+class TestDatasetSchemaBase:
+    def setup_method(self) -> None:
+        clear_db_datasets()
+        self.timestamp = "2022-06-10T12:02:44+00:00"
+
+    def teardown_method(self) -> None:
+        clear_db_datasets()
+
+
+class TestDatasetSchema(TestDatasetSchemaBase):
+    def test_serialize(self, session):
+        dataset = Dataset(
+            uri="s3://bucket/key",
+            extra={"foo": "bar"},
+            created_at=timezone.parse(self.timestamp),
+            updated_at=timezone.parse(self.timestamp),
+        )
+        session.add(dataset)
+        session.flush()
+        serialized_data = dataset_schema.dump(dataset)
+        serialized_data['id'] = 1
+        assert serialized_data == {
+            "id": 1,
+            "uri": "s3://bucket/key",
+            "extra": "{'foo': 'bar'}",
+            "created_at": self.timestamp,
+            "updated_at": self.timestamp,
+        }
+
+
+class TestDatasetCollectionSchema(TestDatasetSchemaBase):
+    def test_serialize(self, session):
+
+        datasets = [
+            Dataset(
+                uri=f"s3://bucket/key/{i+1}",
+                extra={"foo": "bar"},
+                created_at=timezone.parse(self.timestamp),
+                updated_at=timezone.parse(self.timestamp),
+            )
+            for i in range(2)
+        ]
+        session.add_all(datasets)
+        session.flush()
+        serialized_data = dataset_collection_schema.dump(
+            DatasetCollection(datasets=datasets, total_entries=2)
+        )
+        serialized_data['datasets'][0]['id'] = 1
+        serialized_data['datasets'][1]['id'] = 2
+        assert serialized_data == {
+            "datasets": [
+                {
+                    "id": 1,
+                    "uri": "s3://bucket/key/1",
+                    "extra": "{'foo': 'bar'}",
+                    "created_at": self.timestamp,
+                    "updated_at": self.timestamp,
+                },
+                {
+                    "id": 2,
+                    "uri": "s3://bucket/key/2",
+                    "extra": "{'foo': 'bar'}",
+                    "created_at": self.timestamp,
+                    "updated_at": self.timestamp,
+                },
+            ],
+            "total_entries": 2,
+        }