You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ur...@apache.org on 2022/09/12 21:05:18 UTC

[airflow] branch main updated: Flag to deserialize value on custom XCom backend (#26343)

This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new ffee6bceb3 Flag to deserialize value on custom XCom backend (#26343)
ffee6bceb3 is described below

commit ffee6bceb32eba159a7a25a4613d573884a6a58d
Author: Tzu-ping Chung <ur...@gmail.com>
AuthorDate: Tue Sep 13 05:05:02 2022 +0800

    Flag to deserialize value on custom XCom backend (#26343)
---
 airflow/api_connexion/endpoints/xcom_endpoint.py   | 25 ++++++++--
 airflow/api_connexion/openapi/v1.yaml              | 17 +++++++
 airflow/www/static/js/types/api-generated.ts       | 16 ++++++-
 .../api_connexion/endpoints/test_xcom_endpoint.py  | 53 ++++++++++++++++++----
 4 files changed, 96 insertions(+), 15 deletions(-)

diff --git a/airflow/api_connexion/endpoints/xcom_endpoint.py b/airflow/api_connexion/endpoints/xcom_endpoint.py
index 62c7262f7e..6114d4d98b 100644
--- a/airflow/api_connexion/endpoints/xcom_endpoint.py
+++ b/airflow/api_connexion/endpoints/xcom_endpoint.py
@@ -14,6 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import copy
 from typing import Optional
 
 from flask import g
@@ -68,7 +69,7 @@ def get_xcom_entries(
     query = query.order_by(DR.execution_date, XCom.task_id, XCom.dag_id, XCom.key)
     total_entries = query.count()
     query = query.offset(offset).limit(limit)
-    return xcom_collection_schema.dump(XComCollection(xcom_entries=query.all(), total_entries=total_entries))
+    return xcom_collection_schema.dump(XComCollection(xcom_entries=query, total_entries=total_entries))
 
 
 @security.requires_access(
@@ -86,14 +87,28 @@ def get_xcom_entry(
     task_id: str,
     dag_run_id: str,
     xcom_key: str,
+    deserialize: bool = False,
     session: Session = NEW_SESSION,
 ) -> APIResponse:
     """Get an XCom entry"""
-    query = session.query(XCom).filter(XCom.dag_id == dag_id, XCom.task_id == task_id, XCom.key == xcom_key)
+    if deserialize:
+        query = session.query(XCom, XCom.value)
+    else:
+        query = session.query(XCom)
+
+    query = query.filter(XCom.dag_id == dag_id, XCom.task_id == task_id, XCom.key == xcom_key)
     query = query.join(DR, and_(XCom.dag_id == DR.dag_id, XCom.run_id == DR.run_id))
     query = query.filter(DR.run_id == dag_run_id)
 
-    query_object = query.one_or_none()
-    if not query_object:
+    item = query.one_or_none()
+    if item is None:
         raise NotFound("XCom entry not found")
-    return xcom_schema.dump(query_object)
+
+    if deserialize:
+        xcom, value = item
+        stub = copy.copy(xcom)
+        stub.value = value
+        stub.value = XCom.deserialize_value(stub)
+        item = stub
+
+    return xcom_schema.dump(item)
diff --git a/airflow/api_connexion/openapi/v1.yaml b/airflow/api_connexion/openapi/v1.yaml
index d4685c0e0b..2cd55b00e9 100644
--- a/airflow/api_connexion/openapi/v1.yaml
+++ b/airflow/api_connexion/openapi/v1.yaml
@@ -1412,6 +1412,23 @@ paths:
       x-openapi-router-controller: airflow.api_connexion.endpoints.xcom_endpoint
       operationId: get_xcom_entry
       tags: [XCom]
+      parameters:
+        - in: query
+          name: deserialize
+          schema:
+            type: boolean
+            default: false
+          required: false
+          description: |
+            Whether to deserialize an XCom value when using a custom XCom backend.
+
+            The XCom API endpoint calls `orm_deserialize_value` by default since an XCom may contain a value
+            that is potentially expensive to deserialize in the web server. Setting this to true overrides
+            that consideration and calls `deserialize_value` instead.
+
+            This parameter is not meaningful when using the default XCom backend.
+
+            *New in version 2.5.0*
       responses:
         '200':
           description: Success.
diff --git a/airflow/www/static/js/types/api-generated.ts b/airflow/www/static/js/types/api-generated.ts
index 8443c169aa..50bfb3c653 100644
--- a/airflow/www/static/js/types/api-generated.ts
+++ b/airflow/www/static/js/types/api-generated.ts
@@ -3413,6 +3413,20 @@ export interface operations {
         /** The XCom key. */
         xcom_key: components["parameters"]["XComKey"];
       };
+      query: {
+        /**
+         * Whether to deserialize an XCom value when using a custom XCom backend.
+         *
+         * The XCom API endpoint calls `orm_deserialize_value` by default since an XCom may contain a value
+         * that is potentially expensive to deserialize in the web server. Setting this to true overrides
+         * that consideration and calls `deserialize_value` instead.
+         *
+         * This parameter is not meaningful when using the default XCom backend.
+         *
+         * *New in version 2.5.0*
+         */
+        deserialize?: boolean;
+      };
     };
     responses: {
       /** Success. */
@@ -4221,7 +4235,7 @@ export type GetVariableVariables = CamelCasedPropertiesDeep<operations['get_vari
 export type DeleteVariableVariables = CamelCasedPropertiesDeep<operations['delete_variable']['parameters']['path']>;
 export type PatchVariableVariables = CamelCasedPropertiesDeep<operations['patch_variable']['parameters']['path'] & operations['patch_variable']['parameters']['query'] & operations['patch_variable']['requestBody']['content']['application/json']>;
 export type GetXcomEntriesVariables = CamelCasedPropertiesDeep<operations['get_xcom_entries']['parameters']['path'] & operations['get_xcom_entries']['parameters']['query']>;
-export type GetXcomEntryVariables = CamelCasedPropertiesDeep<operations['get_xcom_entry']['parameters']['path']>;
+export type GetXcomEntryVariables = CamelCasedPropertiesDeep<operations['get_xcom_entry']['parameters']['path'] & operations['get_xcom_entry']['parameters']['query']>;
 export type GetExtraLinksVariables = CamelCasedPropertiesDeep<operations['get_extra_links']['parameters']['path']>;
 export type GetLogVariables = CamelCasedPropertiesDeep<operations['get_log']['parameters']['path'] & operations['get_log']['parameters']['query']>;
 export type GetDagDetailsVariables = CamelCasedPropertiesDeep<operations['get_dag_details']['parameters']['path']>;
diff --git a/tests/api_connexion/endpoints/test_xcom_endpoint.py b/tests/api_connexion/endpoints/test_xcom_endpoint.py
index efcba32711..7bc5b51439 100644
--- a/tests/api_connexion/endpoints/test_xcom_endpoint.py
+++ b/tests/api_connexion/endpoints/test_xcom_endpoint.py
@@ -15,20 +15,34 @@
 # specific language governing permissions and limitations
 # under the License.
 from datetime import timedelta
+from unittest import mock
 
 import pytest
-from parameterized import parameterized
 
-from airflow.models import DagModel, DagRun, TaskInstance, XCom
+from airflow.models.dag import DagModel
+from airflow.models.dagrun import DagRun
+from airflow.models.taskinstance import TaskInstance
+from airflow.models.xcom import BaseXCom, XCom, resolve_xcom_backend
 from airflow.operators.empty import EmptyOperator
 from airflow.security import permissions
 from airflow.utils.dates import parse_execution_date
 from airflow.utils.session import create_session
+from airflow.utils.timezone import utcnow
 from airflow.utils.types import DagRunType
 from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user
+from tests.test_utils.config import conf_vars
 from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_xcom
 
 
+class CustomXCom(BaseXCom):
+    @classmethod
+    def deserialize_value(cls, xcom: XCom):
+        return f"real deserialized {super().deserialize_value(xcom)}"
+
+    def orm_deserialize_value(self):
+        return f"orm deserialized {super().orm_deserialize_value()}"
+
+
 @pytest.fixture(scope="module")
 def configured_app(minimal_app_for_api):
     app = minimal_app_for_api
@@ -145,7 +159,7 @@ class TestGetXComEntry(TestXComEndpoint):
         )
         assert response.status_code == 403
 
-    def _create_xcom_entry(self, dag_id, run_id, execution_date, task_id, xcom_key):
+    def _create_xcom_entry(self, dag_id, run_id, execution_date, task_id, xcom_key, *, backend=XCom):
         with create_session() as session:
             dagrun = DagRun(
                 dag_id=dag_id,
@@ -158,7 +172,7 @@ class TestGetXComEntry(TestXComEndpoint):
             ti = TaskInstance(EmptyOperator(task_id=task_id), run_id=run_id)
             ti.dag_id = dag_id
             session.add(ti)
-        XCom.set(
+        backend.set(
             key=xcom_key,
             value="TEST_VALUE",
             run_id=run_id,
@@ -166,6 +180,26 @@ class TestGetXComEntry(TestXComEndpoint):
             dag_id=dag_id,
         )
 
+    @pytest.mark.parametrize(
+        "query, expected_value",
+        [
+            pytest.param("?deserialize=true", "real deserialized TEST_VALUE", id="true"),
+            pytest.param("?deserialize=false", "orm deserialized TEST_VALUE", id="false"),
+            pytest.param("", "orm deserialized TEST_VALUE", id="default"),
+        ],
+    )
+    @conf_vars({("core", "xcom_backend"): "tests.api_connexion.endpoints.test_xcom_endpoint.CustomXCom"})
+    def test_custom_xcom_deserialize(self, query, expected_value):
+        XCom = resolve_xcom_backend()
+        self._create_xcom_entry("dag", "run", utcnow(), "task", "key", backend=XCom)
+
+        url = f"/api/v1/dags/dag/dagRuns/run/taskInstances/task/xcomEntries/key{query}"
+        with mock.patch("airflow.api_connexion.endpoints.xcom_endpoint.XCom", XCom):
+            response = self.client.get(url, environ_overrides={'REMOTE_USER': "test"})
+
+        assert response.status_code == 200
+        assert response.json["value"] == expected_value
+
 
 class TestGetXComEntries(TestXComEndpoint):
     def test_should_respond_200(self):
@@ -386,7 +420,8 @@ class TestPaginationGetXComEntries(TestXComEndpoint):
         self.execution_date_parsed = parse_execution_date(self.execution_date)
         self.run_id = DagRun.generate_run_id(DagRunType.MANUAL, self.execution_date_parsed)
 
-    @parameterized.expand(
+    @pytest.mark.parametrize(
+        "query_params, expected_xcom_ids",
         [
             (
                 "limit=1",
@@ -433,12 +468,12 @@ class TestPaginationGetXComEntries(TestXComEndpoint):
                 "limit=2&offset=2",
                 ["TEST_XCOM_KEY2", "TEST_XCOM_KEY3"],
             ),
-        ]
+        ],
     )
     def test_handle_limit_offset(self, query_params, expected_xcom_ids):
-        url = "/api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/xcomEntries?{query_params}"
-        url = url.format(
-            dag_id=self.dag_id, dag_run_id=self.run_id, task_id=self.task_id, query_params=query_params
+        url = (
+            f"/api/v1/dags/{self.dag_id}/dagRuns/{self.run_id}/taskInstances/{self.task_id}/xcomEntries"
+            f"?{query_params}"
         )
         with create_session() as session:
             dagrun = DagRun(