You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ur...@apache.org on 2022/09/12 21:05:18 UTC
[airflow] branch main updated: Flag to deserialize value on custom XCom backend (#26343)
This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new ffee6bceb3 Flag to deserialize value on custom XCom backend (#26343)
ffee6bceb3 is described below
commit ffee6bceb32eba159a7a25a4613d573884a6a58d
Author: Tzu-ping Chung <ur...@gmail.com>
AuthorDate: Tue Sep 13 05:05:02 2022 +0800
Flag to deserialize value on custom XCom backend (#26343)
---
airflow/api_connexion/endpoints/xcom_endpoint.py | 25 ++++++++--
airflow/api_connexion/openapi/v1.yaml | 17 +++++++
airflow/www/static/js/types/api-generated.ts | 16 ++++++-
.../api_connexion/endpoints/test_xcom_endpoint.py | 53 ++++++++++++++++++----
4 files changed, 96 insertions(+), 15 deletions(-)
diff --git a/airflow/api_connexion/endpoints/xcom_endpoint.py b/airflow/api_connexion/endpoints/xcom_endpoint.py
index 62c7262f7e..6114d4d98b 100644
--- a/airflow/api_connexion/endpoints/xcom_endpoint.py
+++ b/airflow/api_connexion/endpoints/xcom_endpoint.py
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+import copy
from typing import Optional
from flask import g
@@ -68,7 +69,7 @@ def get_xcom_entries(
query = query.order_by(DR.execution_date, XCom.task_id, XCom.dag_id, XCom.key)
total_entries = query.count()
query = query.offset(offset).limit(limit)
- return xcom_collection_schema.dump(XComCollection(xcom_entries=query.all(), total_entries=total_entries))
+ return xcom_collection_schema.dump(XComCollection(xcom_entries=query, total_entries=total_entries))
@security.requires_access(
@@ -86,14 +87,28 @@ def get_xcom_entry(
task_id: str,
dag_run_id: str,
xcom_key: str,
+ deserialize: bool = False,
session: Session = NEW_SESSION,
) -> APIResponse:
"""Get an XCom entry"""
- query = session.query(XCom).filter(XCom.dag_id == dag_id, XCom.task_id == task_id, XCom.key == xcom_key)
+ if deserialize:
+ query = session.query(XCom, XCom.value)
+ else:
+ query = session.query(XCom)
+
+ query = query.filter(XCom.dag_id == dag_id, XCom.task_id == task_id, XCom.key == xcom_key)
query = query.join(DR, and_(XCom.dag_id == DR.dag_id, XCom.run_id == DR.run_id))
query = query.filter(DR.run_id == dag_run_id)
- query_object = query.one_or_none()
- if not query_object:
+ item = query.one_or_none()
+ if item is None:
raise NotFound("XCom entry not found")
- return xcom_schema.dump(query_object)
+
+ if deserialize:
+ xcom, value = item
+ stub = copy.copy(xcom)
+ stub.value = value
+ stub.value = XCom.deserialize_value(stub)
+ item = stub
+
+ return xcom_schema.dump(item)
diff --git a/airflow/api_connexion/openapi/v1.yaml b/airflow/api_connexion/openapi/v1.yaml
index d4685c0e0b..2cd55b00e9 100644
--- a/airflow/api_connexion/openapi/v1.yaml
+++ b/airflow/api_connexion/openapi/v1.yaml
@@ -1412,6 +1412,23 @@ paths:
x-openapi-router-controller: airflow.api_connexion.endpoints.xcom_endpoint
operationId: get_xcom_entry
tags: [XCom]
+ parameters:
+ - in: query
+ name: deserialize
+ schema:
+ type: boolean
+ default: false
+ required: false
+ description: |
+ Whether to deserialize an XCom value when using a custom XCom backend.
+
+ The XCom API endpoint calls `orm_deserialize_value` by default since an XCom may contain a value
+ that is potentially expensive to deserialize in the web server. Setting this to true overrides
+ that consideration, and calls `deserialize_value` instead.
+
+ This parameter is not meaningful when using the default XCom backend.
+
+ *New in version 2.5.0*
responses:
'200':
description: Success.
diff --git a/airflow/www/static/js/types/api-generated.ts b/airflow/www/static/js/types/api-generated.ts
index 8443c169aa..50bfb3c653 100644
--- a/airflow/www/static/js/types/api-generated.ts
+++ b/airflow/www/static/js/types/api-generated.ts
@@ -3413,6 +3413,20 @@ export interface operations {
/** The XCom key. */
xcom_key: components["parameters"]["XComKey"];
};
+ query: {
+ /**
+ * Whether to deserialize an XCom value when using a custom XCom backend.
+ *
+ * The XCom API endpoint calls `orm_deserialize_value` by default since an XCom may contain a value
+ * that is potentially expensive to deserialize in the web server. Setting this to true overrides
+ * that consideration, and calls `deserialize_value` instead.
+ *
+ * This parameter is not meaningful when using the default XCom backend.
+ *
+ * *New in version 2.5.0*
+ */
+ deserialize?: boolean;
+ };
};
responses: {
/** Success. */
@@ -4221,7 +4235,7 @@ export type GetVariableVariables = CamelCasedPropertiesDeep<operations['get_vari
export type DeleteVariableVariables = CamelCasedPropertiesDeep<operations['delete_variable']['parameters']['path']>;
export type PatchVariableVariables = CamelCasedPropertiesDeep<operations['patch_variable']['parameters']['path'] & operations['patch_variable']['parameters']['query'] & operations['patch_variable']['requestBody']['content']['application/json']>;
export type GetXcomEntriesVariables = CamelCasedPropertiesDeep<operations['get_xcom_entries']['parameters']['path'] & operations['get_xcom_entries']['parameters']['query']>;
-export type GetXcomEntryVariables = CamelCasedPropertiesDeep<operations['get_xcom_entry']['parameters']['path']>;
+export type GetXcomEntryVariables = CamelCasedPropertiesDeep<operations['get_xcom_entry']['parameters']['path'] & operations['get_xcom_entry']['parameters']['query']>;
export type GetExtraLinksVariables = CamelCasedPropertiesDeep<operations['get_extra_links']['parameters']['path']>;
export type GetLogVariables = CamelCasedPropertiesDeep<operations['get_log']['parameters']['path'] & operations['get_log']['parameters']['query']>;
export type GetDagDetailsVariables = CamelCasedPropertiesDeep<operations['get_dag_details']['parameters']['path']>;
diff --git a/tests/api_connexion/endpoints/test_xcom_endpoint.py b/tests/api_connexion/endpoints/test_xcom_endpoint.py
index efcba32711..7bc5b51439 100644
--- a/tests/api_connexion/endpoints/test_xcom_endpoint.py
+++ b/tests/api_connexion/endpoints/test_xcom_endpoint.py
@@ -15,20 +15,34 @@
# specific language governing permissions and limitations
# under the License.
from datetime import timedelta
+from unittest import mock
import pytest
-from parameterized import parameterized
-from airflow.models import DagModel, DagRun, TaskInstance, XCom
+from airflow.models.dag import DagModel
+from airflow.models.dagrun import DagRun
+from airflow.models.taskinstance import TaskInstance
+from airflow.models.xcom import BaseXCom, XCom, resolve_xcom_backend
from airflow.operators.empty import EmptyOperator
from airflow.security import permissions
from airflow.utils.dates import parse_execution_date
from airflow.utils.session import create_session
+from airflow.utils.timezone import utcnow
from airflow.utils.types import DagRunType
from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user
+from tests.test_utils.config import conf_vars
from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_xcom
+class CustomXCom(BaseXCom):
+ @classmethod
+ def deserialize_value(cls, xcom: XCom):
+ return f"real deserialized {super().deserialize_value(xcom)}"
+
+ def orm_deserialize_value(self):
+ return f"orm deserialized {super().orm_deserialize_value()}"
+
+
@pytest.fixture(scope="module")
def configured_app(minimal_app_for_api):
app = minimal_app_for_api
@@ -145,7 +159,7 @@ class TestGetXComEntry(TestXComEndpoint):
)
assert response.status_code == 403
- def _create_xcom_entry(self, dag_id, run_id, execution_date, task_id, xcom_key):
+ def _create_xcom_entry(self, dag_id, run_id, execution_date, task_id, xcom_key, *, backend=XCom):
with create_session() as session:
dagrun = DagRun(
dag_id=dag_id,
@@ -158,7 +172,7 @@ class TestGetXComEntry(TestXComEndpoint):
ti = TaskInstance(EmptyOperator(task_id=task_id), run_id=run_id)
ti.dag_id = dag_id
session.add(ti)
- XCom.set(
+ backend.set(
key=xcom_key,
value="TEST_VALUE",
run_id=run_id,
@@ -166,6 +180,26 @@ class TestGetXComEntry(TestXComEndpoint):
dag_id=dag_id,
)
+ @pytest.mark.parametrize(
+ "query, expected_value",
+ [
+ pytest.param("?deserialize=true", "real deserialized TEST_VALUE", id="true"),
+ pytest.param("?deserialize=false", "orm deserialized TEST_VALUE", id="false"),
+ pytest.param("", "orm deserialized TEST_VALUE", id="default"),
+ ],
+ )
+ @conf_vars({("core", "xcom_backend"): "tests.api_connexion.endpoints.test_xcom_endpoint.CustomXCom"})
+ def test_custom_xcom_deserialize(self, query, expected_value):
+ XCom = resolve_xcom_backend()
+ self._create_xcom_entry("dag", "run", utcnow(), "task", "key", backend=XCom)
+
+ url = f"/api/v1/dags/dag/dagRuns/run/taskInstances/task/xcomEntries/key{query}"
+ with mock.patch("airflow.api_connexion.endpoints.xcom_endpoint.XCom", XCom):
+ response = self.client.get(url, environ_overrides={'REMOTE_USER': "test"})
+
+ assert response.status_code == 200
+ assert response.json["value"] == expected_value
+
class TestGetXComEntries(TestXComEndpoint):
def test_should_respond_200(self):
@@ -386,7 +420,8 @@ class TestPaginationGetXComEntries(TestXComEndpoint):
self.execution_date_parsed = parse_execution_date(self.execution_date)
self.run_id = DagRun.generate_run_id(DagRunType.MANUAL, self.execution_date_parsed)
- @parameterized.expand(
+ @pytest.mark.parametrize(
+ "query_params, expected_xcom_ids",
[
(
"limit=1",
@@ -433,12 +468,12 @@ class TestPaginationGetXComEntries(TestXComEndpoint):
"limit=2&offset=2",
["TEST_XCOM_KEY2", "TEST_XCOM_KEY3"],
),
- ]
+ ],
)
def test_handle_limit_offset(self, query_params, expected_xcom_ids):
- url = "/api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/xcomEntries?{query_params}"
- url = url.format(
- dag_id=self.dag_id, dag_run_id=self.run_id, task_id=self.task_id, query_params=query_params
+ url = (
+ f"/api/v1/dags/{self.dag_id}/dagRuns/{self.run_id}/taskInstances/{self.task_id}/xcomEntries"
+ f"?{query_params}"
)
with create_session() as session:
dagrun = DagRun(