You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ka...@apache.org on 2020/06/25 09:55:16 UTC
[airflow] branch master updated: Read only endpoint for XCom #8134
(#9170)
This is an automated email from the ASF dual-hosted git repository.
kamilbregula pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/master by this push:
new 5744a47 Read only endpoint for XCom #8134 (#9170)
5744a47 is described below
commit 5744a4797e10fad04ac4814c02889af309b65130
Author: S S Rohit <ro...@gmail.com>
AuthorDate: Thu Jun 25 15:24:39 2020 +0530
Read only endpoint for XCom #8134 (#9170)
Co-authored-by: Ash Berlin-Taylor <as...@firemirror.com>
Co-authored-by: Kamil Breguła <ka...@polidea.com>
---
airflow/api_connexion/endpoints/xcom_endpoint.py | 61 +++++-
airflow/api_connexion/schemas/xcom_schema.py | 63 ++++++
.../api_connexion/endpoints/test_xcom_endpoint.py | 219 +++++++++++++++++++--
tests/api_connexion/schemas/test_xcom_schema.py | 211 ++++++++++++++++++++
4 files changed, 532 insertions(+), 22 deletions(-)
diff --git a/airflow/api_connexion/endpoints/xcom_endpoint.py b/airflow/api_connexion/endpoints/xcom_endpoint.py
index c67af40..cd317ad 100644
--- a/airflow/api_connexion/endpoints/xcom_endpoint.py
+++ b/airflow/api_connexion/endpoints/xcom_endpoint.py
@@ -14,9 +14,18 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from flask import request
+from sqlalchemy import and_, func
+from sqlalchemy.orm.session import Session
-# TODO(mik-laj): We have to implement it.
-# Do you want to help? Please look at: sshttps://github.com/apache/airflow/issues/8134
+from airflow.api_connexion import parameters
+from airflow.api_connexion.exceptions import NotFound
+from airflow.api_connexion.schemas.xcom_schema import (
+ XComCollection, XComCollectionItemSchema, XComCollectionSchema, xcom_collection_item_schema,
+ xcom_collection_schema,
+)
+from airflow.models import DagRun as DR, XCom
+from airflow.utils.session import provide_session
def delete_xcom_entry():
@@ -26,18 +35,58 @@ def delete_xcom_entry():
raise NotImplementedError("Not implemented yet.")
-def get_xcom_entries():
+@provide_session
+def get_xcom_entries(
+    dag_id: str,
+    dag_run_id: str,
+    task_id: str,
+    session: Session
+) -> XComCollectionSchema:
-    """
-    Get all XCom values
-    """
-    raise NotImplementedError("Not implemented yet.")
+    """
+    Get all XCom values, filtered by the path parameters and paginated.
+
+    Passing ``~`` for ``dag_id``, ``task_id`` or ``dag_run_id`` skips the
+    corresponding filter.
+
+    :param dag_id: DAG whose XComs are listed, or ``~`` for any DAG
+    :param dag_run_id: ``run_id`` of the DagRun to match, or ``~`` for any run
+    :param task_id: task whose XComs are listed, or ``~`` for any task
+    :param session: SQLAlchemy session used for the queries
+    :return: dump of an ``XComCollection`` holding one page of entries and the
+        number of entries that match the filters (before pagination)
+    """
+    # Query-string values arrive as strings; cast both paging values alike.
+    offset = int(request.args.get(parameters.page_offset, 0))
+    limit = min(int(request.args.get(parameters.page_limit, 100)), 100)
+
+    # Query.join() is generative: it returns a new Query and leaves the
+    # receiver untouched, so the result must be re-assigned. The join always
+    # matches on dag_id as well as execution_date so that an XCom is never
+    # paired with a DagRun of another DAG that shares the execution date.
+    # NOTE: this is an inner join, XCom rows without a DagRun are not listed.
+    query = session.query(XCom).join(
+        DR, and_(XCom.dag_id == DR.dag_id, XCom.execution_date == DR.execution_date)
+    )
+    if dag_id != '~':
+        query = query.filter(XCom.dag_id == dag_id)
+    if task_id != '~':
+        query = query.filter(XCom.task_id == task_id)
+    if dag_run_id != '~':
+        query = query.filter(DR.run_id == dag_run_id)
+
+    # Count the filtered query (before ordering and paging) so that
+    # total_entries describes the collection the client is paging through.
+    total_entries = query.count()
+
+    query = query.order_by(
+        XCom.execution_date, XCom.task_id, XCom.dag_id, XCom.key
+    )
+    query = query.offset(offset).limit(limit)
+    return xcom_collection_schema.dump(XComCollection(xcom_entries=query.all(), total_entries=total_entries))
-def get_xcom_entry():
+@provide_session
+def get_xcom_entry(
+    dag_id: str,
+    task_id: str,
+    dag_run_id: str,
+    xcom_key: str,
+    session: Session
+) -> XComCollectionItemSchema:
-    """
-    Get an XCom entry
-    """
-    raise NotImplementedError("Not implemented yet.")
+    """
+    Get a single XCom entry.
+
+    The entry is looked up by DAG, task and key, and must belong to the
+    DagRun whose ``run_id`` equals ``dag_run_id``.
+
+    :raises NotFound: when no XCom row matches
+    """
+    # Identifies the requested XCom row.
+    wanted_entry = and_(XCom.dag_id == dag_id,
+                        XCom.task_id == task_id,
+                        XCom.key == xcom_key)
+    # Links an XCom row to the DagRun that produced it.
+    same_run = and_(XCom.dag_id == DR.dag_id, XCom.execution_date == DR.execution_date)
+
+    xcom_entry = (
+        session.query(XCom)
+        .filter(wanted_entry)
+        .join(DR, same_run)
+        .filter(DR.run_id == dag_run_id)
+        .one_or_none()
+    )
+    if xcom_entry:
+        return xcom_collection_item_schema.dump(xcom_entry)
+    raise NotFound("XCom entry not found")
def patch_xcom_entry():
diff --git a/airflow/api_connexion/schemas/xcom_schema.py b/airflow/api_connexion/schemas/xcom_schema.py
new file mode 100644
index 0000000..5adc36d
--- /dev/null
+++ b/airflow/api_connexion/schemas/xcom_schema.py
@@ -0,0 +1,63 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from typing import List, NamedTuple
+
+from marshmallow import Schema, fields
+from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field
+
+from airflow.models import XCom
+
+
+class XComCollectionItemSchema(SQLAlchemySchema):
+ """
+ Schema for a single XCom entry as listed in a collection.
+
+ Declares only the key, timestamp, execution_date, task_id and dag_id
+ fields of the XCom model; no ``value`` field is declared here.
+ """
+
+ class Meta:
+ """ Binds the schema to the XCom model. """
+ model = XCom
+
+ # Every field is derived from the XCom model attribute of the same name.
+ key = auto_field()
+ timestamp = auto_field()
+ execution_date = auto_field()
+ task_id = auto_field()
+ dag_id = auto_field()
+
+
+class XComSchema(XComCollectionItemSchema):
+ """
+ Full XCom schema: the collection item fields plus the stored ``value``.
+ """
+
+ # Derived from the XCom model attribute of the same name.
+ value = auto_field()
+
+
+class XComCollection(NamedTuple):
+ """ List of XComs together with the entry count reported next to it. """
+ # XCom model instances to serialize.
+ xcom_entries: List[XCom]
+ # Entry count supplied by the caller; not derived from xcom_entries here.
+ total_entries: int
+
+
+class XComCollectionSchema(Schema):
+ """ Schema for an XComCollection: a list of XCom items and an integer total. """
+ # Each entry is rendered with XComCollectionItemSchema (no ``value`` field).
+ xcom_entries = fields.List(fields.Nested(XComCollectionItemSchema))
+ total_entries = fields.Int()
+
+
+# Shared, module-level schema instances imported by the endpoints and tests.
+# NOTE(review): strict=True is presumably the marshmallow 2 option that raises
+# on validation errors instead of returning them - confirm against the pinned
+# marshmallow version.
+xcom_schema = XComSchema(strict=True)
+xcom_collection_item_schema = XComCollectionItemSchema(strict=True)
+xcom_collection_schema = XComCollectionSchema(strict=True)
diff --git a/tests/api_connexion/endpoints/test_xcom_endpoint.py b/tests/api_connexion/endpoints/test_xcom_endpoint.py
index 2eabdcd..3f36e60 100644
--- a/tests/api_connexion/endpoints/test_xcom_endpoint.py
+++ b/tests/api_connexion/endpoints/test_xcom_endpoint.py
@@ -17,60 +17,247 @@
import unittest
import pytest
+from parameterized import parameterized
+from airflow.models import DagRun as DR, XCom
+from airflow.utils.dates import parse_execution_date
+from airflow.utils.session import create_session, provide_session
+from airflow.utils.types import DagRunType
from airflow.www import app
-class TesXComEndpoint(unittest.TestCase):
+class TestXComEndpoint(unittest.TestCase):
+ """ Base class for the XCom endpoint tests: app, test client and table cleanup. """
@classmethod
def setUpClass(cls) -> None:
super().setUpClass()
cls.app = app.create_app(testing=True) # type:ignore
def setUp(self) -> None:
+ """
+ Create a test client and delete all XCom and DagRun rows before each test.
+ """
self.client = self.app.test_client() # type:ignore
+ # clear existing XComs and DagRuns so that every test starts from empty tables
+ with create_session() as session:
+ session.query(XCom).delete()
+ session.query(DR).delete()
+ def tearDown(self) -> None:
+ """
+ Delete all XCom and DagRun rows left behind by the test.
+ """
+ with create_session() as session:
+ session.query(XCom).delete()
+ session.query(DR).delete()
-class TestDeleteXComEntry(TesXComEndpoint):
+
+class TestDeleteXComEntry(TestXComEndpoint):
+ # Placeholder for the not-yet-implemented DELETE endpoint; the test is skipped.
+ # NOTE(review): the method name says 200 but the assertion expects 204, and the
+ # URL has neither the /api/v1 prefix nor the dagRuns/<run_id> segment used by
+ # the implemented GET tests - confirm both before removing the skip.
@pytest.mark.skip(reason="Not implemented yet")
def test_should_response_200(self):
response = self.client.delete(
- "/dags/TEST_DAG_ID}/taskInstances/TEST_TASK_ID/2005-04-02T21:37:42Z/xcomEntries/XCOM_KEY"
+ "/dags/TEST_DAG_ID/taskInstances/TEST_TASK_ID/2005-04-02T00:00:00Z/xcomEntries/XCOM_KEY"
)
assert response.status_code == 204
-class TestGetXComEntry(TesXComEndpoint):
- @pytest.mark.skip(reason="Not implemented yet")
- def test_should_response_200(self):
+class TestGetXComEntry(TestXComEndpoint):
+ """ Tests for fetching a single XCom entry by DAG, DagRun, task and key. """
+
+ @provide_session
+ def test_should_response_200(self, session):
+ # Arrange: one XCom row and the DagRun sharing its dag_id and execution_date.
+ dag_id = 'test-dag-id'
+ task_id = 'test-task-id'
+ execution_date = '2005-04-02T00:00:00+00:00'
+ xcom_key = 'test-xcom-key'
+ execution_date_parsed = parse_execution_date(execution_date)
+ xcom_model = XCom(key=xcom_key,
+ execution_date=execution_date_parsed,
+ task_id=task_id,
+ dag_id=dag_id,
+ timestamp=execution_date_parsed)
+ dag_run_id = DR.generate_run_id(DagRunType.MANUAL, execution_date_parsed)
+ dagrun = DR(dag_id=dag_id,
+ run_id=dag_run_id,
+ execution_date=execution_date_parsed,
+ start_date=execution_date_parsed,
+ run_type=DagRunType.MANUAL.value)
+ session.add(xcom_model)
+ session.add(dagrun)
+ session.commit()
+ # Act: request the entry through the run_id based URL.
response = self.client.get(
- "/dags/TEST_DAG_ID}/taskInstances/TEST_TASK_ID/2005-04-02T21:37:42Z/xcomEntries/XCOM_KEY"
+ f"/api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/xcomEntries/{xcom_key}"
+ )
+ # Assert: the body carries the item fields only; no 'value' key is expected.
+ self.assertEqual(200, response.status_code)
+ self.assertEqual(
+ response.json,
+ {
+ 'dag_id': dag_id,
+ 'execution_date': execution_date,
+ 'key': xcom_key,
+ 'task_id': task_id,
+ 'timestamp': execution_date
+ }
)
- assert response.status_code == 200
-class TestGetXComEntries(TesXComEndpoint):
- @pytest.mark.skip(reason="Not implemented yet")
- def test_should_response_200(self):
+class TestGetXComEntries(TestXComEndpoint):
+ """ Tests for listing the XCom entries of one task in one DagRun. """
+ @provide_session
+ def test_should_response_200(self, session):
+ # Arrange: two XCom rows for the same task plus the matching DagRun.
+ dag_id = 'test-dag-id'
+ task_id = 'test-task-id'
+ execution_date = '2005-04-02T00:00:00+00:00'
+ execution_date_parsed = parse_execution_date(execution_date)
+ xcom_model_1 = XCom(key='test-xcom-key-1',
+ execution_date=execution_date_parsed,
+ task_id=task_id,
+ dag_id=dag_id,
+ timestamp=execution_date_parsed)
+ xcom_model_2 = XCom(key='test-xcom-key-2',
+ execution_date=execution_date_parsed,
+ task_id=task_id,
+ dag_id=dag_id,
+ timestamp=execution_date_parsed)
+ dag_run_id = DR.generate_run_id(DagRunType.MANUAL, execution_date_parsed)
+ dagrun = DR(dag_id=dag_id,
+ run_id=dag_run_id,
+ execution_date=execution_date_parsed,
+ start_date=execution_date_parsed,
+ run_type=DagRunType.MANUAL.value)
+ xcom_models = [xcom_model_1, xcom_model_2]
+ session.add_all(xcom_models)
+ session.add(dagrun)
+ session.commit()
response = self.client.get(
- "/dags/TEST_DAG_ID}/taskInstances/TEST_TASK_ID/2005-04-02T21:37:42Z/xcomEntries/"
+ f"/api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/xcomEntries"
+ )
+ # Assert: both entries are returned, ordered by key, with the matching total.
+ self.assertEqual(200, response.status_code)
+ self.assertEqual(
+ response.json,
+ {
+ 'xcom_entries': [
+ {
+ 'dag_id': dag_id,
+ 'execution_date': execution_date,
+ 'key': 'test-xcom-key-1',
+ 'task_id': task_id,
+ 'timestamp': execution_date
+ },
+ {
+ 'dag_id': dag_id,
+ 'execution_date': execution_date,
+ 'key': 'test-xcom-key-2',
+ 'task_id': task_id,
+ 'timestamp': execution_date
+ }
+ ],
+ 'total_entries': 2,
+ }
)
+
+
+class TestPaginationGetXComEntries(TestXComEndpoint):
+ """ Tests for the limit and offset query parameters of the XCom list endpoint. """
+
+ def setUp(self):
+ super().setUp()
+ self.dag_id = 'test-dag-id'
+ self.task_id = 'test-task-id'
+ self.execution_date = '2005-04-02T00:00:00+00:00'
+ self.execution_date_parsed = parse_execution_date(self.execution_date)
+ self.dag_run_id = DR.generate_run_id(DagRunType.MANUAL, self.execution_date_parsed)
+
+ # Each case is (query string, expected keys). The expected keys follow string
+ # order, which is why TEST_XCOM_KEY10 comes right after TEST_XCOM_KEY1.
+ @parameterized.expand(
+ [
+ (
+ "limit=1",
+ ["TEST_XCOM_KEY1"],
+ ),
+ (
+ "limit=2",
+ ["TEST_XCOM_KEY1", "TEST_XCOM_KEY10"],
+ ),
+ (
+ "offset=5",
+ [
+ "TEST_XCOM_KEY5",
+ "TEST_XCOM_KEY6",
+ "TEST_XCOM_KEY7",
+ "TEST_XCOM_KEY8",
+ "TEST_XCOM_KEY9",
+ ]
+ ),
+ (
+ "offset=0",
+ [
+ "TEST_XCOM_KEY1",
+ "TEST_XCOM_KEY10",
+ "TEST_XCOM_KEY2",
+ "TEST_XCOM_KEY3",
+ "TEST_XCOM_KEY4",
+ "TEST_XCOM_KEY5",
+ "TEST_XCOM_KEY6",
+ "TEST_XCOM_KEY7",
+ "TEST_XCOM_KEY8",
+ "TEST_XCOM_KEY9"
+ ]
+ ),
+ (
+ "limit=1&offset=5",
+ ["TEST_XCOM_KEY5"],
+ ),
+ (
+ "limit=1&offset=1",
+ ["TEST_XCOM_KEY10"],
+ ),
+ (
+ "limit=2&offset=2",
+ ["TEST_XCOM_KEY2", "TEST_XCOM_KEY3"],
+ ),
+ ]
+ )
+ @provide_session
+ def test_handle_limit_offset(self, query_params, expected_xcom_ids, session):
+ url = "/api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/xcomEntries?{query_params}"
+ url = url.format(dag_id=self.dag_id,
+ dag_run_id=self.dag_run_id,
+ task_id=self.task_id,
+ query_params=query_params)
+ dagrun = DR(dag_id=self.dag_id,
+ run_id=self.dag_run_id,
+ execution_date=self.execution_date_parsed,
+ start_date=self.execution_date_parsed,
+ run_type=DagRunType.MANUAL.value)
+ xcom_models = self._create_xcoms(10)
+ session.add_all(xcom_models)
+ session.add(dagrun)
+ session.commit()
+ response = self.client.get(url)
assert response.status_code == 200
+ # total_entries is expected to stay at 10 whatever the page size is.
+ self.assertEqual(response.json["total_entries"], 10)
+ # NOTE(review): despite the "conn" naming, these are XCom keys.
+ conn_ids = [conn["key"] for conn in response.json["xcom_entries"] if conn]
+ self.assertEqual(conn_ids, expected_xcom_ids)
+
+ def _create_xcoms(self, count):
+ # Builds `count` unsaved XCom models keyed TEST_XCOM_KEY1..TEST_XCOM_KEY<count>.
+ return [XCom(
+ key=f'TEST_XCOM_KEY{i}',
+ execution_date=self.execution_date_parsed,
+ task_id=self.task_id,
+ dag_id=self.dag_id,
+ timestamp=self.execution_date_parsed,
+ ) for i in range(1, count + 1)]
-class TestPatchXComEntry(TesXComEndpoint):
+class TestPatchXComEntry(TestXComEndpoint):
+ # Placeholder for the not-yet-implemented PATCH endpoint; the test is skipped.
+ # NOTE(review): the URL has neither the /api/v1 prefix nor the dagRuns/<run_id>
+ # segment used by the implemented GET tests - confirm before removing the skip.
@pytest.mark.skip(reason="Not implemented yet")
def test_should_response_200(self):
response = self.client.patch(
- "/dags/TEST_DAG_ID}/taskInstances/TEST_TASK_ID/2005-04-02T21:37:42Z/xcomEntries"
+ "/dags/TEST_DAG_ID/taskInstances/TEST_TASK_ID/2005-04-02T00:00:00Z/xcomEntries"
)
assert response.status_code == 200
-class TestPostXComEntry(TesXComEndpoint):
+class TestPostXComEntry(TestXComEndpoint):
+ # Placeholder for the not-yet-implemented POST endpoint; the test is skipped.
+ # NOTE(review): the URL has neither the /api/v1 prefix nor the dagRuns/<run_id>
+ # segment used by the implemented GET tests - confirm before removing the skip.
@pytest.mark.skip(reason="Not implemented yet")
def test_should_response_200(self):
response = self.client.post(
- "/dags/TEST_DAG_ID}/taskInstances/TEST_TASK_ID/2005-04-02T21:37:42Z/xcomEntries/XCOM_KEY"
+ "/dags/TEST_DAG_ID/taskInstances/TEST_TASK_ID/2005-04-02T00:00:00Z/xcomEntries/XCOM_KEY"
)
assert response.status_code == 200
diff --git a/tests/api_connexion/schemas/test_xcom_schema.py b/tests/api_connexion/schemas/test_xcom_schema.py
new file mode 100644
index 0000000..d66c8ce
--- /dev/null
+++ b/tests/api_connexion/schemas/test_xcom_schema.py
@@ -0,0 +1,211 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import unittest
+
+from sqlalchemy import or_
+
+from airflow.api_connexion.schemas.xcom_schema import (
+ XComCollection, xcom_collection_item_schema, xcom_collection_schema, xcom_schema,
+)
+from airflow.models import XCom
+from airflow.utils.dates import parse_execution_date
+from airflow.utils.session import create_session, provide_session
+
+
+class TestXComSchemaBase(unittest.TestCase):
+    """Base test case that empties the XCom table around every test."""
+
+    @staticmethod
+    def _purge_xcoms() -> None:
+        """Delete every XCom row inside a dedicated session."""
+        with create_session() as db:
+            db.query(XCom).delete()
+
+    def setUp(self):
+        """
+        Remove XComs left over from earlier tests
+        """
+        self._purge_xcoms()
+
+    def tearDown(self) -> None:
+        """
+        Remove XComs created by the test that just ran
+        """
+        self._purge_xcoms()
+
+
+class TestXComCollectionItemSchema(TestXComSchemaBase):
+    """Dump and load checks for ``xcom_collection_item_schema``."""
+
+    def setUp(self) -> None:
+        super().setUp()
+        self.default_time = '2005-04-02T21:00:00+00:00'
+        self.default_time_parsed = parse_execution_date(self.default_time)
+
+    @provide_session
+    def test_serialize(self, session):
+        """A stored XCom dumps to its item fields, with dates as ISO strings."""
+        session.add(XCom(
+            key='test_key',
+            timestamp=self.default_time_parsed,
+            execution_date=self.default_time_parsed,
+            task_id='test_task_id',
+            dag_id='test_dag',
+        ))
+        session.commit()
+
+        stored_xcom = session.query(XCom).first()
+        dumped = xcom_collection_item_schema.dump(stored_xcom)
+
+        expected = {
+            'key': 'test_key',
+            'timestamp': self.default_time,
+            'execution_date': self.default_time,
+            'task_id': 'test_task_id',
+            'dag_id': 'test_dag',
+        }
+        self.assertEqual(dumped[0], expected)
+
+    def test_deserialize(self):
+        """Loading a dumped item parses both date strings back into datetimes."""
+        payload = {
+            'key': 'test_key',
+            'timestamp': self.default_time,
+            'execution_date': self.default_time,
+            'task_id': 'test_task_id',
+            'dag_id': 'test_dag',
+        }
+        expected = {
+            'key': 'test_key',
+            'timestamp': self.default_time_parsed,
+            'execution_date': self.default_time_parsed,
+            'task_id': 'test_task_id',
+            'dag_id': 'test_dag',
+        }
+        loaded = xcom_collection_item_schema.load(payload)
+        self.assertEqual(loaded[0], expected)
+
+
+class TestXComCollectionSchema(TestXComSchemaBase):
+    """Dump checks for ``xcom_collection_schema``."""
+
+    def setUp(self) -> None:
+        super().setUp()
+        self.default_time_1 = '2005-04-02T21:00:00+00:00'
+        self.default_time_2 = '2005-04-02T21:01:00+00:00'
+        self.time_1 = parse_execution_date(self.default_time_1)
+        self.time_2 = parse_execution_date(self.default_time_2)
+
+    @provide_session
+    def test_serialize(self, session):
+        """A collection dumps to a list of item dicts plus the total count."""
+        xcom_model_1 = XCom(
+            key='test_key_1',
+            timestamp=self.time_1,
+            execution_date=self.time_1,
+            task_id='test_task_id_1',
+            dag_id='test_dag_1',
+        )
+        xcom_model_2 = XCom(
+            key='test_key_2',
+            timestamp=self.time_2,
+            execution_date=self.time_2,
+            task_id='test_task_id_2',
+            dag_id='test_dag_2',
+        )
+        xcom_models = [xcom_model_1, xcom_model_2]
+        session.add_all(xcom_models)
+        session.commit()
+        # The expected list below is ordered, so the query must be ordered too:
+        # without ORDER BY the database may return the two rows in any order.
+        xcom_models_query = session.query(XCom).filter(
+            or_(XCom.execution_date == self.time_1, XCom.execution_date == self.time_2)
+        ).order_by(XCom.execution_date)
+        xcom_models_queried = xcom_models_query.all()
+        serialized_xcoms = xcom_collection_schema.dump(XComCollection(
+            xcom_entries=xcom_models_queried,
+            total_entries=xcom_models_query.count(),
+        ))
+        self.assertEqual(
+            serialized_xcoms[0],
+            {
+                'xcom_entries': [
+                    {
+                        'key': 'test_key_1',
+                        'timestamp': self.default_time_1,
+                        'execution_date': self.default_time_1,
+                        'task_id': 'test_task_id_1',
+                        'dag_id': 'test_dag_1',
+                    },
+                    {
+                        'key': 'test_key_2',
+                        'timestamp': self.default_time_2,
+                        'execution_date': self.default_time_2,
+                        'task_id': 'test_task_id_2',
+                        'dag_id': 'test_dag_2',
+                    }
+                ],
+                'total_entries': len(xcom_models),
+            }
+        )
+
+
+class TestXComSchema(TestXComSchemaBase):
+    """Dump and load checks for ``xcom_schema``, which also carries ``value``."""
+
+    def setUp(self) -> None:
+        super().setUp()
+        self.default_time = '2005-04-02T21:00:00+00:00'
+        self.default_time_parsed = parse_execution_date(self.default_time)
+
+    @provide_session
+    def test_serialize(self, session):
+        """A stored XCom dumps to its item fields plus the value as text."""
+        session.add(XCom(
+            key='test_key',
+            timestamp=self.default_time_parsed,
+            execution_date=self.default_time_parsed,
+            task_id='test_task_id',
+            dag_id='test_dag',
+            value=b'test_binary',
+        ))
+        session.commit()
+
+        stored_xcom = session.query(XCom).first()
+        dumped = xcom_schema.dump(stored_xcom)
+
+        expected = {
+            'key': 'test_key',
+            'timestamp': self.default_time,
+            'execution_date': self.default_time,
+            'task_id': 'test_task_id',
+            'dag_id': 'test_dag',
+            'value': 'test_binary',
+        }
+        self.assertEqual(dumped[0], expected)
+
+    def test_deserialize(self):
+        """Loading parses the date strings and yields the value as text."""
+        payload = {
+            'key': 'test_key',
+            'timestamp': self.default_time,
+            'execution_date': self.default_time,
+            'task_id': 'test_task_id',
+            'dag_id': 'test_dag',
+            'value': b'test_binary',
+        }
+        expected = {
+            'key': 'test_key',
+            'timestamp': self.default_time_parsed,
+            'execution_date': self.default_time_parsed,
+            'task_id': 'test_task_id',
+            'dag_id': 'test_dag',
+            'value': 'test_binary',
+        }
+        loaded = xcom_schema.load(payload)
+        self.assertEqual(loaded[0], expected)