You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ka...@apache.org on 2020/06/25 09:55:16 UTC

[airflow] branch master updated: Read only endpoint for XCom #8134 (#9170)

This is an automated email from the ASF dual-hosted git repository.

kamilbregula pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new 5744a47  Read only endpoint for XCom #8134 (#9170)
5744a47 is described below

commit 5744a4797e10fad04ac4814c02889af309b65130
Author: S S Rohit <ro...@gmail.com>
AuthorDate: Thu Jun 25 15:24:39 2020 +0530

    Read only endpoint for XCom #8134 (#9170)
    
    Co-authored-by: Ash Berlin-Taylor <as...@firemirror.com>
    Co-authored-by: Kamil Breguła <ka...@polidea.com>
---
 airflow/api_connexion/endpoints/xcom_endpoint.py   |  61 +++++-
 airflow/api_connexion/schemas/xcom_schema.py       |  63 ++++++
 .../api_connexion/endpoints/test_xcom_endpoint.py  | 219 +++++++++++++++++++--
 tests/api_connexion/schemas/test_xcom_schema.py    | 211 ++++++++++++++++++++
 4 files changed, 532 insertions(+), 22 deletions(-)

diff --git a/airflow/api_connexion/endpoints/xcom_endpoint.py b/airflow/api_connexion/endpoints/xcom_endpoint.py
index c67af40..cd317ad 100644
--- a/airflow/api_connexion/endpoints/xcom_endpoint.py
+++ b/airflow/api_connexion/endpoints/xcom_endpoint.py
@@ -14,9 +14,18 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+from flask import request
+from sqlalchemy import and_, func
+from sqlalchemy.orm.session import Session
 
-# TODO(mik-laj): We have to implement it.
-#     Do you want to help? Please look at: sshttps://github.com/apache/airflow/issues/8134
+from airflow.api_connexion import parameters
+from airflow.api_connexion.exceptions import NotFound
+from airflow.api_connexion.schemas.xcom_schema import (
+    XComCollection, XComCollectionItemSchema, XComCollectionSchema, xcom_collection_item_schema,
+    xcom_collection_schema,
+)
+from airflow.models import DagRun as DR, XCom
+from airflow.utils.session import provide_session
 
 
 def delete_xcom_entry():
@@ -26,18 +35,58 @@ def delete_xcom_entry():
     raise NotImplementedError("Not implemented yet.")
 
 
-def get_xcom_entries():
+@provide_session
+def get_xcom_entries(
+    dag_id: str,
+    dag_run_id: str,
+    task_id: str,
+    session: Session
+) -> XComCollectionSchema:
     """
     Get all XCom values
     """
-    raise NotImplementedError("Not implemented yet.")
+    offset = request.args.get(parameters.page_offset, 0)
+    limit = min(int(request.args.get(parameters.page_limit, 100)), 100)
+    query = session.query(XCom)
+    if dag_id != '~':
+        query = query.filter(XCom.dag_id == dag_id)
+        query.join(DR, and_(XCom.dag_id == DR.dag_id, XCom.execution_date == DR.execution_date))
+    else:
+        query.join(DR, XCom.execution_date == DR.execution_date)
+    if task_id != '~':
+        query = query.filter(XCom.task_id == task_id)
+    if dag_run_id != '~':
+        query = query.filter(DR.run_id == dag_run_id)
+    query = query.order_by(
+        XCom.execution_date, XCom.task_id, XCom.dag_id, XCom.key
+    )
+    total_entries = session.query(func.count(XCom.key)).scalar()
+    query = query.offset(offset).limit(limit)
+    return xcom_collection_schema.dump(XComCollection(xcom_entries=query.all(), total_entries=total_entries))
 
 
-def get_xcom_entry():
+@provide_session
+def get_xcom_entry(
+    dag_id: str,
+    task_id: str,
+    dag_run_id: str,
+    xcom_key: str,
+    session: Session
+) -> XComCollectionItemSchema:
     """
     Get an XCom entry
     """
-    raise NotImplementedError("Not implemented yet.")
+    query = session.query(XCom)
+    query = query.filter(and_(XCom.dag_id == dag_id,
+                              XCom.task_id == task_id,
+                              XCom.key == xcom_key))
+    query = query.join(DR, and_(XCom.dag_id == DR.dag_id, XCom.execution_date == DR.execution_date))
+    query = query.filter(DR.run_id == dag_run_id)
+
+    query_object = query.one_or_none()
+    if not query_object:
+        raise NotFound("XCom entry not found")
+    return xcom_collection_item_schema.dump(query_object)
 
 
 def patch_xcom_entry():
diff --git a/airflow/api_connexion/schemas/xcom_schema.py b/airflow/api_connexion/schemas/xcom_schema.py
new file mode 100644
index 0000000..5adc36d
--- /dev/null
+++ b/airflow/api_connexion/schemas/xcom_schema.py
@@ -0,0 +1,63 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from typing import List, NamedTuple
+
+from marshmallow import Schema, fields
+from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field
+
+from airflow.models import XCom
+
+
+class XComCollectionItemSchema(SQLAlchemySchema):
+    """
+    Schema for a xcom item
+    """
+
+    class Meta:
+        """ Meta """
+        model = XCom
+
+    key = auto_field()
+    timestamp = auto_field()
+    execution_date = auto_field()
+    task_id = auto_field()
+    dag_id = auto_field()
+
+
+class XComSchema(XComCollectionItemSchema):
+    """
+    XCom schema
+    """
+
+    value = auto_field()
+
+
+class XComCollection(NamedTuple):
+    """ List of XComs with meta"""
+    xcom_entries: List[XCom]
+    total_entries: int
+
+
+class XComCollectionSchema(Schema):
+    """ XCom Collection Schema"""
+    xcom_entries = fields.List(fields.Nested(XComCollectionItemSchema))
+    total_entries = fields.Int()
+
+
+xcom_schema = XComSchema(strict=True)
+xcom_collection_item_schema = XComCollectionItemSchema(strict=True)
+xcom_collection_schema = XComCollectionSchema(strict=True)
diff --git a/tests/api_connexion/endpoints/test_xcom_endpoint.py b/tests/api_connexion/endpoints/test_xcom_endpoint.py
index 2eabdcd..3f36e60 100644
--- a/tests/api_connexion/endpoints/test_xcom_endpoint.py
+++ b/tests/api_connexion/endpoints/test_xcom_endpoint.py
@@ -17,60 +17,247 @@
 import unittest
 
 import pytest
+from parameterized import parameterized
 
+from airflow.models import DagRun as DR, XCom
+from airflow.utils.dates import parse_execution_date
+from airflow.utils.session import create_session, provide_session
+from airflow.utils.types import DagRunType
 from airflow.www import app
 
 
-class TesXComEndpoint(unittest.TestCase):
+class TestXComEndpoint(unittest.TestCase):
     @classmethod
     def setUpClass(cls) -> None:
         super().setUpClass()
         cls.app = app.create_app(testing=True)  # type:ignore
 
     def setUp(self) -> None:
+        """
+        Setup For XCom endpoint TC
+        """
         self.client = self.app.test_client()  # type:ignore
+        # clear existing xcoms
+        with create_session() as session:
+            session.query(XCom).delete()
+            session.query(DR).delete()
 
+    def tearDown(self) -> None:
+        """
+        Clear Hanging XComs
+        """
+        with create_session() as session:
+            session.query(XCom).delete()
+            session.query(DR).delete()
 
-class TestDeleteXComEntry(TesXComEndpoint):
+
+class TestDeleteXComEntry(TestXComEndpoint):
     @pytest.mark.skip(reason="Not implemented yet")
     def test_should_response_200(self):
         response = self.client.delete(
-            "/dags/TEST_DAG_ID}/taskInstances/TEST_TASK_ID/2005-04-02T21:37:42Z/xcomEntries/XCOM_KEY"
+            "/dags/TEST_DAG_ID/taskInstances/TEST_TASK_ID/2005-04-02T00:00:00Z/xcomEntries/XCOM_KEY"
         )
         assert response.status_code == 204
 
 
-class TestGetXComEntry(TesXComEndpoint):
-    @pytest.mark.skip(reason="Not implemented yet")
-    def test_should_response_200(self):
+class TestGetXComEntry(TestXComEndpoint):
+
+    @provide_session
+    def test_should_response_200(self, session):
+        dag_id = 'test-dag-id'
+        task_id = 'test-task-id'
+        execution_date = '2005-04-02T00:00:00+00:00'
+        xcom_key = 'test-xcom-key'
+        execution_date_parsed = parse_execution_date(execution_date)
+        xcom_model = XCom(key=xcom_key,
+                          execution_date=execution_date_parsed,
+                          task_id=task_id,
+                          dag_id=dag_id,
+                          timestamp=execution_date_parsed)
+        dag_run_id = DR.generate_run_id(DagRunType.MANUAL, execution_date_parsed)
+        dagrun = DR(dag_id=dag_id,
+                    run_id=dag_run_id,
+                    execution_date=execution_date_parsed,
+                    start_date=execution_date_parsed,
+                    run_type=DagRunType.MANUAL.value)
+        session.add(xcom_model)
+        session.add(dagrun)
+        session.commit()
         response = self.client.get(
-            "/dags/TEST_DAG_ID}/taskInstances/TEST_TASK_ID/2005-04-02T21:37:42Z/xcomEntries/XCOM_KEY"
+            f"/api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/xcomEntries/{xcom_key}"
+        )
+        self.assertEqual(200, response.status_code)
+        self.assertEqual(
+            response.json,
+            {
+                'dag_id': dag_id,
+                'execution_date': execution_date,
+                'key': xcom_key,
+                'task_id': task_id,
+                'timestamp': execution_date
+            }
         )
-        assert response.status_code == 200
 
 
-class TestGetXComEntries(TesXComEndpoint):
-    @pytest.mark.skip(reason="Not implemented yet")
-    def test_should_response_200(self):
+class TestGetXComEntries(TestXComEndpoint):
+    @provide_session
+    def test_should_response_200(self, session):
+        dag_id = 'test-dag-id'
+        task_id = 'test-task-id'
+        execution_date = '2005-04-02T00:00:00+00:00'
+        execution_date_parsed = parse_execution_date(execution_date)
+        xcom_model_1 = XCom(key='test-xcom-key-1',
+                            execution_date=execution_date_parsed,
+                            task_id=task_id,
+                            dag_id=dag_id,
+                            timestamp=execution_date_parsed)
+        xcom_model_2 = XCom(key='test-xcom-key-2',
+                            execution_date=execution_date_parsed,
+                            task_id=task_id,
+                            dag_id=dag_id,
+                            timestamp=execution_date_parsed)
+        dag_run_id = DR.generate_run_id(DagRunType.MANUAL, execution_date_parsed)
+        dagrun = DR(dag_id=dag_id,
+                    run_id=dag_run_id,
+                    execution_date=execution_date_parsed,
+                    start_date=execution_date_parsed,
+                    run_type=DagRunType.MANUAL.value)
+        xcom_models = [xcom_model_1, xcom_model_2]
+        session.add_all(xcom_models)
+        session.add(dagrun)
+        session.commit()
         response = self.client.get(
-            "/dags/TEST_DAG_ID}/taskInstances/TEST_TASK_ID/2005-04-02T21:37:42Z/xcomEntries/"
+            f"/api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/xcomEntries"
+        )
+        self.assertEqual(200, response.status_code)
+        self.assertEqual(
+            response.json,
+            {
+                'xcom_entries': [
+                    {
+                        'dag_id': dag_id,
+                        'execution_date': execution_date,
+                        'key': 'test-xcom-key-1',
+                        'task_id': task_id,
+                        'timestamp': execution_date
+                    },
+                    {
+                        'dag_id': dag_id,
+                        'execution_date': execution_date,
+                        'key': 'test-xcom-key-2',
+                        'task_id': task_id,
+                        'timestamp': execution_date
+                    }
+                ],
+                'total_entries': 2,
+            }
         )
+
+
+class TestPaginationGetXComEntries(TestXComEndpoint):
+
+    def setUp(self):
+        super().setUp()
+        self.dag_id = 'test-dag-id'
+        self.task_id = 'test-task-id'
+        self.execution_date = '2005-04-02T00:00:00+00:00'
+        self.execution_date_parsed = parse_execution_date(self.execution_date)
+        self.dag_run_id = DR.generate_run_id(DagRunType.MANUAL, self.execution_date_parsed)
+
+    @parameterized.expand(
+        [
+            (
+                "limit=1",
+                ["TEST_XCOM_KEY1"],
+            ),
+            (
+                "limit=2",
+                ["TEST_XCOM_KEY1", "TEST_XCOM_KEY10"],
+            ),
+            (
+                "offset=5",
+                [
+                    "TEST_XCOM_KEY5",
+                    "TEST_XCOM_KEY6",
+                    "TEST_XCOM_KEY7",
+                    "TEST_XCOM_KEY8",
+                    "TEST_XCOM_KEY9",
+                ]
+            ),
+            (
+                "offset=0",
+                [
+                    "TEST_XCOM_KEY1",
+                    "TEST_XCOM_KEY10",
+                    "TEST_XCOM_KEY2",
+                    "TEST_XCOM_KEY3",
+                    "TEST_XCOM_KEY4",
+                    "TEST_XCOM_KEY5",
+                    "TEST_XCOM_KEY6",
+                    "TEST_XCOM_KEY7",
+                    "TEST_XCOM_KEY8",
+                    "TEST_XCOM_KEY9"
+                ]
+            ),
+            (
+                "limit=1&offset=5",
+                ["TEST_XCOM_KEY5"],
+            ),
+            (
+                "limit=1&offset=1",
+                ["TEST_XCOM_KEY10"],
+            ),
+            (
+                "limit=2&offset=2",
+                ["TEST_XCOM_KEY2", "TEST_XCOM_KEY3"],
+            ),
+        ]
+    )
+    @provide_session
+    def test_handle_limit_offset(self, query_params, expected_xcom_ids, session):
+        url = "/api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/xcomEntries?{query_params}"
+        url = url.format(dag_id=self.dag_id,
+                         dag_run_id=self.dag_run_id,
+                         task_id=self.task_id,
+                         query_params=query_params)
+        dagrun = DR(dag_id=self.dag_id,
+                    run_id=self.dag_run_id,
+                    execution_date=self.execution_date_parsed,
+                    start_date=self.execution_date_parsed,
+                    run_type=DagRunType.MANUAL.value)
+        xcom_models = self._create_xcoms(10)
+        session.add_all(xcom_models)
+        session.add(dagrun)
+        session.commit()
+        response = self.client.get(url)
         assert response.status_code == 200
+        self.assertEqual(response.json["total_entries"], 10)
+        conn_ids = [conn["key"] for conn in response.json["xcom_entries"] if conn]
+        self.assertEqual(conn_ids, expected_xcom_ids)
+
+    def _create_xcoms(self, count):
+        return [XCom(
+            key=f'TEST_XCOM_KEY{i}',
+            execution_date=self.execution_date_parsed,
+            task_id=self.task_id,
+            dag_id=self.dag_id,
+            timestamp=self.execution_date_parsed,
+        ) for i in range(1, count + 1)]
 
 
-class TestPatchXComEntry(TesXComEndpoint):
+class TestPatchXComEntry(TestXComEndpoint):
     @pytest.mark.skip(reason="Not implemented yet")
     def test_should_response_200(self):
         response = self.client.patch(
-            "/dags/TEST_DAG_ID}/taskInstances/TEST_TASK_ID/2005-04-02T21:37:42Z/xcomEntries"
+            "/dags/TEST_DAG_ID/taskInstances/TEST_TASK_ID/2005-04-02T00:00:00Z/xcomEntries"
         )
         assert response.status_code == 200
 
 
-class TestPostXComEntry(TesXComEndpoint):
+class TestPostXComEntry(TestXComEndpoint):
     @pytest.mark.skip(reason="Not implemented yet")
     def test_should_response_200(self):
         response = self.client.post(
-            "/dags/TEST_DAG_ID}/taskInstances/TEST_TASK_ID/2005-04-02T21:37:42Z/xcomEntries/XCOM_KEY"
+            "/dags/TEST_DAG_ID/taskInstances/TEST_TASK_ID/2005-04-02T00:00:00Z/xcomEntries/XCOM_KEY"
         )
         assert response.status_code == 200
diff --git a/tests/api_connexion/schemas/test_xcom_schema.py b/tests/api_connexion/schemas/test_xcom_schema.py
new file mode 100644
index 0000000..d66c8ce
--- /dev/null
+++ b/tests/api_connexion/schemas/test_xcom_schema.py
@@ -0,0 +1,211 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import unittest
+
+from sqlalchemy import or_
+
+from airflow.api_connexion.schemas.xcom_schema import (
+    XComCollection, xcom_collection_item_schema, xcom_collection_schema, xcom_schema,
+)
+from airflow.models import XCom
+from airflow.utils.dates import parse_execution_date
+from airflow.utils.session import create_session, provide_session
+
+
+class TestXComSchemaBase(unittest.TestCase):
+
+    def setUp(self):
+        """
+        Clear Hanging XComs pre test
+        """
+        with create_session() as session:
+            session.query(XCom).delete()
+
+    def tearDown(self) -> None:
+        """
+        Clear Hanging XComs post test
+        """
+        with create_session() as session:
+            session.query(XCom).delete()
+
+
+class TestXComCollectionItemSchema(TestXComSchemaBase):
+
+    def setUp(self) -> None:
+        super().setUp()
+        self.default_time = '2005-04-02T21:00:00+00:00'
+        self.default_time_parsed = parse_execution_date(self.default_time)
+
+    @provide_session
+    def test_serialize(self, session):
+        xcom_model = XCom(
+            key='test_key',
+            timestamp=self.default_time_parsed,
+            execution_date=self.default_time_parsed,
+            task_id='test_task_id',
+            dag_id='test_dag',
+        )
+        session.add(xcom_model)
+        session.commit()
+        xcom_model = session.query(XCom).first()
+        deserialized_xcom = xcom_collection_item_schema.dump(xcom_model)
+        self.assertEqual(
+            deserialized_xcom[0],
+            {
+                'key': 'test_key',
+                'timestamp': self.default_time,
+                'execution_date': self.default_time,
+                'task_id': 'test_task_id',
+                'dag_id': 'test_dag',
+            }
+        )
+
+    def test_deserialize(self):
+        xcom_dump = {
+            'key': 'test_key',
+            'timestamp': self.default_time,
+            'execution_date': self.default_time,
+            'task_id': 'test_task_id',
+            'dag_id': 'test_dag',
+        }
+        result = xcom_collection_item_schema.load(xcom_dump)
+        self.assertEqual(
+            result[0],
+            {
+                'key': 'test_key',
+                'timestamp': self.default_time_parsed,
+                'execution_date': self.default_time_parsed,
+                'task_id': 'test_task_id',
+                'dag_id': 'test_dag',
+            }
+        )
+
+
+class TestXComCollectionSchema(TestXComSchemaBase):
+
+    def setUp(self) -> None:
+        super().setUp()
+        self.default_time_1 = '2005-04-02T21:00:00+00:00'
+        self.default_time_2 = '2005-04-02T21:01:00+00:00'
+        self.time_1 = parse_execution_date(self.default_time_1)
+        self.time_2 = parse_execution_date(self.default_time_2)
+
+    @provide_session
+    def test_serialize(self, session):
+        xcom_model_1 = XCom(
+            key='test_key_1',
+            timestamp=self.time_1,
+            execution_date=self.time_1,
+            task_id='test_task_id_1',
+            dag_id='test_dag_1',
+        )
+        xcom_model_2 = XCom(
+            key='test_key_2',
+            timestamp=self.time_2,
+            execution_date=self.time_2,
+            task_id='test_task_id_2',
+            dag_id='test_dag_2',
+        )
+        xcom_models = [xcom_model_1, xcom_model_2]
+        session.add_all(xcom_models)
+        session.commit()
+        xcom_models_query = session.query(XCom).filter(
+            or_(XCom.execution_date == self.time_1, XCom.execution_date == self.time_2)
+        )
+        xcom_models_queried = xcom_models_query.all()
+        deserialized_xcoms = xcom_collection_schema.dump(XComCollection(
+            xcom_entries=xcom_models_queried,
+            total_entries=xcom_models_query.count(),
+        ))
+        self.assertEqual(
+            deserialized_xcoms[0],
+            {
+                'xcom_entries': [
+                    {
+                        'key': 'test_key_1',
+                        'timestamp': self.default_time_1,
+                        'execution_date': self.default_time_1,
+                        'task_id': 'test_task_id_1',
+                        'dag_id': 'test_dag_1',
+                    },
+                    {
+                        'key': 'test_key_2',
+                        'timestamp': self.default_time_2,
+                        'execution_date': self.default_time_2,
+                        'task_id': 'test_task_id_2',
+                        'dag_id': 'test_dag_2',
+                    }
+                ],
+                'total_entries': len(xcom_models),
+            }
+        )
+
+
+class TestXComSchema(TestXComSchemaBase):
+
+    def setUp(self) -> None:
+        super().setUp()
+        self.default_time = '2005-04-02T21:00:00+00:00'
+        self.default_time_parsed = parse_execution_date(self.default_time)
+
+    @provide_session
+    def test_serialize(self, session):
+        xcom_model = XCom(
+            key='test_key',
+            timestamp=self.default_time_parsed,
+            execution_date=self.default_time_parsed,
+            task_id='test_task_id',
+            dag_id='test_dag',
+            value=b'test_binary',
+        )
+        session.add(xcom_model)
+        session.commit()
+        xcom_model = session.query(XCom).first()
+        deserialized_xcom = xcom_schema.dump(xcom_model)
+        self.assertEqual(
+            deserialized_xcom[0],
+            {
+                'key': 'test_key',
+                'timestamp': self.default_time,
+                'execution_date': self.default_time,
+                'task_id': 'test_task_id',
+                'dag_id': 'test_dag',
+                'value': 'test_binary',
+            }
+        )
+
+    def test_deserialize(self):
+        xcom_dump = {
+            'key': 'test_key',
+            'timestamp': self.default_time,
+            'execution_date': self.default_time,
+            'task_id': 'test_task_id',
+            'dag_id': 'test_dag',
+            'value': b'test_binary',
+        }
+        result = xcom_schema.load(xcom_dump)
+        self.assertEqual(
+            result[0],
+            {
+                'key': 'test_key',
+                'timestamp': self.default_time_parsed,
+                'execution_date': self.default_time_parsed,
+                'task_id': 'test_task_id',
+                'dag_id': 'test_dag',
+                'value': 'test_binary',
+            }
+        )