Posted to commits@airflow.apache.org by je...@apache.org on 2022/07/12 23:41:02 UTC

[airflow] branch main updated: Add read-only REST API endpoint for Datasets (#24696)

This is an automated email from the ASF dual-hosted git repository.

jedcunningham pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 809d95ec06 Add read-only REST API endpoint for Datasets (#24696)
809d95ec06 is described below

commit 809d95ec06447c9579383d15136190c0963b3c1b
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Wed Jul 13 00:40:50 2022 +0100

    Add read-only REST API endpoint for Datasets (#24696)
    
    Co-authored-by: Ash Berlin-Taylor <as...@firemirror.com>
    Co-authored-by: Jed Cunningham <je...@apache.org>
---
 .../api_connexion/endpoints/dataset_endpoint.py    |  61 +++++
 airflow/api_connexion/openapi/v1.yaml              |  98 ++++++++
 airflow/api_connexion/schemas/dataset_schema.py    |  56 +++++
 airflow/security/permissions.py                    |   1 +
 .../endpoints/test_dataset_endpoint.py             | 261 +++++++++++++++++++++
 tests/api_connexion/schemas/test_dataset_schema.py |  95 ++++++++
 6 files changed, 572 insertions(+)

diff --git a/airflow/api_connexion/endpoints/dataset_endpoint.py b/airflow/api_connexion/endpoints/dataset_endpoint.py
new file mode 100644
index 0000000000..a3c691b923
--- /dev/null
+++ b/airflow/api_connexion/endpoints/dataset_endpoint.py
@@ -0,0 +1,61 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from sqlalchemy import func
+from sqlalchemy.orm import Session
+
+from airflow import Dataset
+from airflow.api_connexion import security
+from airflow.api_connexion.exceptions import NotFound
+from airflow.api_connexion.parameters import apply_sorting, check_limit, format_parameters
+from airflow.api_connexion.schemas.dataset_schema import (
+    DatasetCollection,
+    dataset_collection_schema,
+    dataset_schema,
+)
+from airflow.api_connexion.types import APIResponse
+from airflow.security import permissions
+from airflow.utils.session import NEW_SESSION, provide_session
+
+
+@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DATASET)])
+@provide_session
+def get_dataset(id, session):
+    """Get a Dataset"""
+    dataset = session.query(Dataset).get(id)
+    if not dataset:
+        raise NotFound(
+            "Dataset not found",
+            detail=f"The Dataset with id: `{id}` was not found",
+        )
+    return dataset_schema.dump(dataset)
+
+
+@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DATASET)])
+@format_parameters({'limit': check_limit})
+@provide_session
+def get_datasets(
+    *, limit: int, offset: int = 0, order_by: str = "id", session: Session = NEW_SESSION
+) -> APIResponse:
+    """Get datasets"""
+    allowed_filter_attrs = ['id', 'uri', 'created_at', 'updated_at']
+
+    total_entries = session.query(func.count(Dataset.id)).scalar()
+    query = session.query(Dataset)
+    query = apply_sorting(query, order_by, {}, allowed_filter_attrs)
+    datasets = query.offset(offset).limit(limit).all()
+    return dataset_collection_schema.dump(DatasetCollection(datasets=datasets, total_entries=total_entries))
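
As a rough illustration of how these new read-only routes can be exercised once a
webserver exposes them, here is a minimal sketch using the requests library. The
base URL, credentials, and basic-auth setup are assumptions for the sketch, not
part of this change:

    import requests

    BASE_URL = "http://localhost:8080/api/v1"  # assumed local Airflow webserver
    AUTH = ("admin", "admin")                  # hypothetical basic-auth credentials

    # List datasets; limit, offset and order_by mirror the endpoint parameters above
    # (order_by accepts id, uri, created_at or updated_at).
    resp = requests.get(
        f"{BASE_URL}/datasets",
        params={"limit": 2, "offset": 0, "order_by": "uri"},
        auth=AUTH,
    )
    resp.raise_for_status()
    for dataset in resp.json()["datasets"]:
        print(dataset["id"], dataset["uri"])

    # Fetch a single dataset by its numeric id; an unknown id returns a 404 payload.
    detail = requests.get(f"{BASE_URL}/datasets/1", auth=AUTH)
    print(detail.status_code, detail.json())
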
diff --git a/airflow/api_connexion/openapi/v1.yaml b/airflow/api_connexion/openapi/v1.yaml
index d9c0cc6ed8..4bfedbfb98 100644
--- a/airflow/api_connexion/openapi/v1.yaml
+++ b/airflow/api_connexion/openapi/v1.yaml
@@ -1609,6 +1609,51 @@ paths:
         '403':
           $ref: '#/components/responses/PermissionDenied'
 
+  /datasets:
+    get:
+      summary: List datasets
+      x-openapi-router-controller: airflow.api_connexion.endpoints.dataset_endpoint
+      operationId: get_datasets
+      tags: [Dataset]
+      parameters:
+        - $ref: '#/components/parameters/PageLimit'
+        - $ref: '#/components/parameters/PageOffset'
+        - $ref: '#/components/parameters/OrderBy'
+      responses:
+        '200':
+          description: Success.
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/DatasetCollection'
+        '401':
+          $ref: '#/components/responses/Unauthenticated'
+        '403':
+          $ref: '#/components/responses/PermissionDenied'
+
+  /datasets/{id}:
+    parameters:
+      - $ref: '#/components/parameters/DatasetID'
+    get:
+      summary: Get a dataset
+      description: Get a dataset by id.
+      x-openapi-router-controller: airflow.api_connexion.endpoints.dataset_endpoint
+      operationId: get_dataset
+      tags: [Dataset]
+      responses:
+        '200':
+          description: Success.
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/Dataset'
+        '401':
+          $ref: '#/components/responses/Unauthenticated'
+        '403':
+          $ref: '#/components/responses/PermissionDenied'
+        '404':
+          $ref: '#/components/responses/NotFound'
+
   /config:
     get:
       summary: Get current configuration
@@ -3374,6 +3419,49 @@ components:
           $ref: '#/components/schemas/Resource'
           description: The permission resource
 
+    Dataset:
+      description: |
+        A dataset item.
+
+        *New in version 2.4.0*
+      type: object
+      properties:
+        id:
+          type: integer
+          description: The dataset id
+        uri:
+          type: string
+          description: The dataset uri
+          nullable: false
+        extra:
+          type: string
+          description: The dataset extra
+          nullable: true
+        created_at:
+          type: string
+          description: The dataset creation time
+          nullable: false
+        updated_at:
+          type: string
+          description: The dataset update time
+          nullable: false
+
+    DatasetCollection:
+      description: |
+        A collection of datasets.
+
+        *New in version 2.4.0*
+      type: object
+      allOf:
+        - type: object
+          properties:
+            datasets:
+              type: array
+              items:
+                $ref: '#/components/schemas/Dataset'
+        - $ref: '#/components/schemas/CollectionInfo'
+
+
     # Configuration
     ConfigOption:
       type: object
@@ -4018,6 +4106,14 @@ components:
       required: true
       description: The import error ID.
 
+    DatasetID:
+      in: path
+      name: id
+      schema:
+        type: integer
+      required: true
+      description: The Dataset ID
+
     PoolName:
       in: path
       name: pool_name
@@ -4347,6 +4443,8 @@ tags:
   - name: Role
   - name: Permission
   - name: User
+  - name: DagWarning
+  - name: Dataset
 
 externalDocs:
   url: https://airflow.apache.org/docs/apache-airflow/stable/
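
Given the Dataset and DatasetCollection schemas above, a successful response from
the list endpoint takes roughly the following shape (values are illustrative only;
note that "extra" is serialized as a string and the timestamps as ISO 8601 strings):

    {
        "datasets": [
            {
                "id": 1,
                "uri": "s3://bucket/key/1",
                "extra": "{'foo': 'bar'}",
                "created_at": "2022-06-10T12:02:44+00:00",
                "updated_at": "2022-06-10T12:02:44+00:00"
            }
        ],
        "total_entries": 1
    }
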
diff --git a/airflow/api_connexion/schemas/dataset_schema.py b/airflow/api_connexion/schemas/dataset_schema.py
new file mode 100644
index 0000000000..2228ac4c9b
--- /dev/null
+++ b/airflow/api_connexion/schemas/dataset_schema.py
@@ -0,0 +1,56 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import List, NamedTuple
+
+from marshmallow import Schema, fields
+from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field
+
+from airflow import Dataset
+
+
+class DatasetSchema(SQLAlchemySchema):
+    """Dataset DB schema"""
+
+    class Meta:
+        """Meta"""
+
+        model = Dataset
+
+    id = auto_field()
+    uri = auto_field()
+    extra = auto_field()
+    created_at = auto_field()
+    updated_at = auto_field()
+
+
+class DatasetCollection(NamedTuple):
+    """List of Datasets with meta"""
+
+    datasets: List[Dataset]
+    total_entries: int
+
+
+class DatasetCollectionSchema(Schema):
+    """Dataset Collection Schema"""
+
+    datasets = fields.List(fields.Nested(DatasetSchema))
+    total_entries = fields.Int()
+
+
+dataset_schema = DatasetSchema()
+dataset_collection_schema = DatasetCollectionSchema()
diff --git a/airflow/security/permissions.py b/airflow/security/permissions.py
index 9ffc07885d..b0e0b6df60 100644
--- a/airflow/security/permissions.py
+++ b/airflow/security/permissions.py
@@ -53,6 +53,7 @@ RESOURCE_USER_STATS_CHART = "User Stats Chart"
 RESOURCE_VARIABLE = "Variables"
 RESOURCE_WEBSITE = "Website"
 RESOURCE_XCOM = "XComs"
+RESOURCE_DATASET = "Datasets"
 
 
 # Action Constants
diff --git a/tests/api_connexion/endpoints/test_dataset_endpoint.py b/tests/api_connexion/endpoints/test_dataset_endpoint.py
new file mode 100644
index 0000000000..3d6dd59c8e
--- /dev/null
+++ b/tests/api_connexion/endpoints/test_dataset_endpoint.py
@@ -0,0 +1,261 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+from parameterized import parameterized
+
+from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP
+from airflow.models import Dataset
+from airflow.security import permissions
+from airflow.utils import timezone
+from airflow.utils.session import provide_session
+from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user
+from tests.test_utils.config import conf_vars
+from tests.test_utils.db import clear_db_datasets
+
+
+@pytest.fixture(scope="module")
+def configured_app(minimal_app_for_api):
+    app = minimal_app_for_api
+    create_user(
+        app,  # type: ignore
+        username="test",
+        role_name="Test",
+        permissions=[
+            (permissions.ACTION_CAN_READ, permissions.RESOURCE_DATASET),
+        ],
+    )
+    create_user(app, username="test_no_permissions", role_name="TestNoPermissions")  # type: ignore
+
+    yield app
+
+    delete_user(app, username="test")  # type: ignore
+    delete_user(app, username="test_no_permissions")  # type: ignore
+
+
+class TestDatasetEndpoint:
+
+    default_time = "2020-06-11T18:00:00+00:00"
+
+    @pytest.fixture(autouse=True)
+    def setup_attrs(self, configured_app) -> None:
+        self.app = configured_app
+        self.client = self.app.test_client()
+        clear_db_datasets()
+
+    def teardown_method(self) -> None:
+        clear_db_datasets()
+
+    def _create_dataset(self, session):
+        dataset_model = Dataset(
+            uri="s3://bucket/key",
+            extra={"foo": "bar"},
+            created_at=timezone.parse(self.default_time),
+            updated_at=timezone.parse(self.default_time),
+        )
+        session.add(dataset_model)
+        session.commit()
+
+    @staticmethod
+    def _normalize_dataset_ids(datasets):
+        for i, dataset in enumerate(datasets, 1):
+            dataset["id"] = i
+
+
+class TestGetDatasetEndpoint(TestDatasetEndpoint):
+    def test_should_respond_200(self, session):
+        self._create_dataset(session)
+        result = session.query(Dataset).all()
+        assert len(result) == 1
+        response = self.client.get(
+            f"/api/v1/datasets/{result[0].id}", environ_overrides={'REMOTE_USER': "test"}
+        )
+        assert response.status_code == 200
+        assert response.json == {
+            "id": result[0].id,
+            "uri": "s3://bucket/key",
+            "extra": "{'foo': 'bar'}",
+            "created_at": self.default_time,
+            "updated_at": self.default_time,
+        }
+
+    def test_should_respond_404(self):
+        response = self.client.get("/api/v1/datasets/1", environ_overrides={'REMOTE_USER': "test"})
+        assert response.status_code == 404
+        assert {
+            'detail': "The Dataset with id: `1` was not found",
+            'status': 404,
+            'title': 'Dataset not found',
+            'type': EXCEPTIONS_LINK_MAP[404],
+        } == response.json
+
+    def test_should_raises_401_unauthenticated(self, session):
+        self._create_dataset(session)
+        dataset = session.query(Dataset).first()
+        response = self.client.get(f"/api/v1/datasets/{dataset.id}")
+
+        assert_401(response)
+
+
+class TestGetDatasets(TestDatasetEndpoint):
+    def test_should_respond_200(self, session):
+        datasets = [
+            Dataset(
+                uri=f"s3://bucket/key/{i+1}",
+                extra={"foo": "bar"},
+                created_at=timezone.parse(self.default_time),
+                updated_at=timezone.parse(self.default_time),
+            )
+            for i in range(2)
+        ]
+        session.add_all(datasets)
+        session.commit()
+        result = session.query(Dataset).all()
+        assert len(result) == 2
+        response = self.client.get("/api/v1/datasets", environ_overrides={'REMOTE_USER': "test"})
+
+        assert response.status_code == 200
+        response_data = response.json
+        self._normalize_dataset_ids(response_data['datasets'])
+        assert response_data == {
+            "datasets": [
+                {
+                    "id": 1,
+                    "uri": "s3://bucket/key/1",
+                    "extra": "{'foo': 'bar'}",
+                    "created_at": self.default_time,
+                    "updated_at": self.default_time,
+                },
+                {
+                    "id": 2,
+                    "uri": "s3://bucket/key/2",
+                    "extra": "{'foo': 'bar'}",
+                    "created_at": self.default_time,
+                    "updated_at": self.default_time,
+                },
+            ],
+            "total_entries": 2,
+        }
+
+    def test_order_by_raises_400_for_invalid_attr(self, session):
+        datasets = [
+            Dataset(
+                uri=f"s3://bucket/key/{i+1}",
+                extra={"foo": "bar"},
+                created_at=timezone.parse(self.default_time),
+                updated_at=timezone.parse(self.default_time),
+            )
+            for i in range(2)
+        ]
+        session.add_all(datasets)
+        session.commit()
+        result = session.query(Dataset).all()
+        assert len(result) == 2
+
+        response = self.client.get(
+            "/api/v1/datasets?order_by=fake", environ_overrides={'REMOTE_USER': "test"}
+        )  # missing attr
+
+        assert response.status_code == 400
+        msg = "Ordering with 'fake' is disallowed or the attribute does not exist on the model"
+        assert response.json['detail'] == msg
+
+    def test_should_raises_401_unauthenticated(self, session):
+        datasets = [
+            Dataset(
+                uri=f"s3://bucket/key/{i+1}",
+                extra={"foo": "bar"},
+                created_at=timezone.parse(self.default_time),
+                updated_at=timezone.parse(self.default_time),
+            )
+            for i in range(2)
+        ]
+        session.add_all(datasets)
+        session.commit()
+        result = session.query(Dataset).all()
+        assert len(result) == 2
+
+        response = self.client.get("/api/v1/datasets")
+
+        assert_401(response)
+
+
+class TestGetDatasetsEndpointPagination(TestDatasetEndpoint):
+    @parameterized.expand(
+        [
+            # Limit test data
+            ("/api/v1/datasets?limit=1", ["s3://bucket/key/1"]),
+            ("/api/v1/datasets?limit=100", [f"s3://bucket/key/{i}" for i in range(1, 101)]),
+            # Offset test data
+            ("/api/v1/datasets?offset=1", [f"s3://bucket/key/{i}" for i in range(2, 102)]),
+            ("/api/v1/datasets?offset=3", [f"s3://bucket/key/{i}" for i in range(4, 104)]),
+            # Limit and offset test data
+            ("/api/v1/datasets?offset=3&limit=3", [f"s3://bucket/key/{i}" for i in [4, 5, 6]]),
+        ]
+    )
+    @provide_session
+    def test_limit_and_offset(self, url, expected_dataset_uris, session):
+        datasets = [
+            Dataset(
+                uri=f"s3://bucket/key/{i}",
+                extra={"foo": "bar"},
+                created_at=timezone.parse(self.default_time),
+                updated_at=timezone.parse(self.default_time),
+            )
+            for i in range(1, 110)
+        ]
+        session.add_all(datasets)
+        session.commit()
+
+        response = self.client.get(url, environ_overrides={'REMOTE_USER': "test"})
+
+        assert response.status_code == 200
+        dataset_uris = [dataset["uri"] for dataset in response.json["datasets"]]
+        assert dataset_uris == expected_dataset_uris
+
+    def test_should_respect_page_size_limit_default(self, session):
+        datasets = [
+            Dataset(
+                uri=f"s3://bucket/key/{i+1}",
+                extra={"foo": "bar"},
+                created_at=timezone.parse(self.default_time),
+                updated_at=timezone.parse(self.default_time),
+            )
+            for i in range(1, 110)
+        ]
+        session.add_all(datasets)
+        session.commit()
+        response = self.client.get("/api/v1/datasets", environ_overrides={'REMOTE_USER': "test"})
+        assert response.status_code == 200
+        assert len(response.json['datasets']) == 100
+
+    @conf_vars({("api", "maximum_page_limit"): "150"})
+    def test_should_return_conf_max_if_req_max_above_conf(self, session):
+        datasets = [
+            Dataset(
+                uri=f"s3://bucket/key/{i+1}",
+                extra={"foo": "bar"},
+                created_at=timezone.parse(self.default_time),
+                updated_at=timezone.parse(self.default_time),
+            )
+            for i in range(200)
+        ]
+        session.add_all(datasets)
+        session.commit()
+        response = self.client.get("/api/v1/datasets?limit=180", environ_overrides={'REMOTE_USER': "test"})
+        assert response.status_code == 200
+        assert len(response.json['datasets']) == 150
diff --git a/tests/api_connexion/schemas/test_dataset_schema.py b/tests/api_connexion/schemas/test_dataset_schema.py
new file mode 100644
index 0000000000..bf30b9b12e
--- /dev/null
+++ b/tests/api_connexion/schemas/test_dataset_schema.py
@@ -0,0 +1,95 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from airflow import Dataset
+from airflow.api_connexion.schemas.dataset_schema import (
+    DatasetCollection,
+    dataset_collection_schema,
+    dataset_schema,
+)
+from airflow.utils import timezone
+from tests.test_utils.db import clear_db_datasets
+
+
+class TestDatasetSchemaBase:
+    def setup_method(self) -> None:
+        clear_db_datasets()
+        self.timestamp = "2022-06-10T12:02:44+00:00"
+
+    def teardown_method(self) -> None:
+        clear_db_datasets()
+
+
+class TestDatasetSchema(TestDatasetSchemaBase):
+    def test_serialize(self, session):
+        dataset = Dataset(
+            uri="s3://bucket/key",
+            extra={"foo": "bar"},
+            created_at=timezone.parse(self.timestamp),
+            updated_at=timezone.parse(self.timestamp),
+        )
+        session.add(dataset)
+        session.flush()
+        serialized_data = dataset_schema.dump(dataset)
+        serialized_data['id'] = 1
+        assert serialized_data == {
+            "id": 1,
+            "uri": "s3://bucket/key",
+            "extra": "{'foo': 'bar'}",
+            "created_at": self.timestamp,
+            "updated_at": self.timestamp,
+        }
+
+
+class TestDatasetCollectionSchema(TestDatasetSchemaBase):
+    def test_serialize(self, session):
+
+        datasets = [
+            Dataset(
+                uri=f"s3://bucket/key/{i+1}",
+                extra={"foo": "bar"},
+                created_at=timezone.parse(self.timestamp),
+                updated_at=timezone.parse(self.timestamp),
+            )
+            for i in range(2)
+        ]
+        session.add_all(datasets)
+        session.flush()
+        serialized_data = dataset_collection_schema.dump(
+            DatasetCollection(datasets=datasets, total_entries=2)
+        )
+        serialized_data['datasets'][0]['id'] = 1
+        serialized_data['datasets'][1]['id'] = 2
+        assert serialized_data == {
+            "datasets": [
+                {
+                    "id": 1,
+                    "uri": "s3://bucket/key/1",
+                    "extra": "{'foo': 'bar'}",
+                    "created_at": self.timestamp,
+                    "updated_at": self.timestamp,
+                },
+                {
+                    "id": 2,
+                    "uri": "s3://bucket/key/2",
+                    "extra": "{'foo': 'bar'}",
+                    "created_at": self.timestamp,
+                    "updated_at": self.timestamp,
+                },
+            ],
+            "total_entries": 2,
+        }