You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by as...@apache.org on 2021/09/16 12:27:28 UTC

[airflow] branch main updated: Make `XCom.get_one` return full, not abbreviated values (#18274)

This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new f8ba475  Make `XCom.get_one` return full, not abbreviated values (#18274)
f8ba475 is described below

commit f8ba4755ae77f3e08275d18e5df13c368363066b
Author: Ash Berlin-Taylor <as...@firemirror.com>
AuthorDate: Thu Sep 16 13:27:14 2021 +0100

    Make `XCom.get_one` return full, not abbreviated values (#18274)
    
    If you used this class method directly (such as in a custom operator
    link) then the value would _always_ be passed through
    `orm_deserialize_value`, which would likely give the wrong result on
    custom XCom backends.
    
    This wasn't a problem for anyone using `ti.xcom_pull` as it handled this
    directly.
---
 airflow/models/xcom.py    | 30 ++++++++++++++++++++----------
 tests/models/test_xcom.py | 31 ++++++++++++++++++++++++++-----
 2 files changed, 46 insertions(+), 15 deletions(-)

diff --git a/airflow/models/xcom.py b/airflow/models/xcom.py
index ef275aa..26dc3bc 100644
--- a/airflow/models/xcom.py
+++ b/airflow/models/xcom.py
@@ -128,6 +128,9 @@ class BaseXCom(Base, LoggingMixin):
 
         ``run_id`` and ``execution_date`` are mutually exclusive.
 
+        This method returns "full" XCom values (i.e. it uses ``deserialize_value`` from the XCom backend).
+        Please use :meth:`get_many` if you want the "shortened" value via ``orm_deserialize_value``.
+
         :param execution_date: Execution date for the task
         :type execution_date: pendulum.datetime
         :param run_id: Dag run id for the task
@@ -151,17 +154,21 @@ class BaseXCom(Base, LoggingMixin):
         if not (execution_date is None) ^ (run_id is None):
             raise ValueError("Exactly one of execution_date or run_id must be passed")
 
-        result = cls.get_many(
-            execution_date=execution_date,
-            run_id=run_id,
-            key=key,
-            task_ids=task_id,
-            dag_ids=dag_id,
-            include_prior_dates=include_prior_dates,
-            session=session,
-        ).first()
+        result = (
+            cls.get_many(
+                execution_date=execution_date,
+                run_id=run_id,
+                key=key,
+                task_ids=task_id,
+                dag_ids=dag_id,
+                include_prior_dates=include_prior_dates,
+                session=session,
+            )
+            .with_entities(cls.value)
+            .first()
+        )
         if result:
-            return result.value
+            return cls.deserialize_value(result)
         return None
 
     @classmethod
@@ -182,6 +189,9 @@ class BaseXCom(Base, LoggingMixin):
 
         ``run_id`` and ``execution_date`` are mutually exclusive.
 
+        This function returns an SQLAlchemy query of full XCom objects. If you just want one stored value then
+        use :meth:`get_one`.
+
         :param execution_date: Execution date for the task
         :type execution_date: pendulum.datetime
         :param run_id: Dag run id for the task
diff --git a/tests/models/test_xcom.py b/tests/models/test_xcom.py
index 1addd22..35c7e60 100644
--- a/tests/models/test_xcom.py
+++ b/tests/models/test_xcom.py
@@ -20,15 +20,14 @@ from unittest import mock
 import pytest
 
 from airflow.configuration import conf
-from airflow.models.xcom import BaseXCom, XCom, resolve_xcom_backend
+from airflow.models.xcom import XCOM_RETURN_KEY, BaseXCom, XCom, resolve_xcom_backend
 from airflow.utils import timezone
 from tests.test_utils.config import conf_vars
 
 
 class CustomXCom(BaseXCom):
-    @staticmethod
-    def serialize_value(_):
-        return "custom_value"
+    def orm_deserialize_value(self):
+        return 'Short value...'
 
 
 class TestXCom:
@@ -36,7 +35,6 @@ class TestXCom:
     def test_resolve_xcom_class(self):
         cls = resolve_xcom_backend()
         assert issubclass(cls, CustomXCom)
-        assert cls().serialize_value(None) == "custom_value"
 
     @conf_vars({("core", "xcom_backend"): "", ("core", "enable_xcom_pickling"): "False"})
     def test_resolve_xcom_class_fallback_to_basexcom(self):
@@ -217,3 +215,26 @@ class TestXCom:
 
         instance.init_on_load()
         mock_orm_deserialize.assert_called_once_with()
+
+    @conf_vars({("core", "xcom_backend"): "tests.models.test_xcom.CustomXCom"})
+    def test_get_one_doesnt_use_orm_deserialize_value(self, session):
+        """Test that XCom.get_one does not call orm_deserialize_value"""
+        json_obj = {"key": "value"}
+        execution_date = timezone.utcnow()
+        key = XCOM_RETURN_KEY
+        dag_id = "test_dag"
+        task_id = "test_task"
+
+        XCom = resolve_xcom_backend()
+        XCom.set(
+            key=key,
+            value=json_obj,
+            dag_id=dag_id,
+            task_id=task_id,
+            execution_date=execution_date,
+            session=session,
+        )
+
+        value = XCom.get_one(dag_id=dag_id, task_id=task_id, execution_date=execution_date, session=session)
+
+        assert value == json_obj