You are viewing a plain-text version of this content. The link to the canonical version was a hyperlink that is not preserved in this plain-text copy.
Posted to commits@airflow.apache.org by hu...@apache.org on 2023/12/16 20:08:48 UTC

(airflow) branch main updated: Stop deserializing pickle when enable_xcom_pickling is False (#36255)

This is an automated email from the ASF dual-hosted git repository.

husseinawala pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 63e97abec5 Stop deserializing pickle when enable_xcom_pickling is False (#36255)
63e97abec5 is described below

commit 63e97abec5d56bc62a293c93f5227f364561e51c
Author: Hussein Awala <hu...@awala.fr>
AuthorDate: Sat Dec 16 21:08:41 2023 +0100

    Stop deserializing pickle when enable_xcom_pickling is False (#36255)
    
    * Stop deserializing pickle when enable_xcom_pickling is False
    
    * Fix unit tests
---
 airflow/models/xcom.py                          |  6 ++----
 tests/api_connexion/schemas/test_xcom_schema.py |  3 +++
 tests/models/test_xcom.py                       | 18 +++++++++---------
 3 files changed, 14 insertions(+), 13 deletions(-)

diff --git a/airflow/models/xcom.py b/airflow/models/xcom.py
index 23d33e268d..a55fe99b37 100644
--- a/airflow/models/xcom.py
+++ b/airflow/models/xcom.py
@@ -685,10 +685,8 @@ class BaseXCom(Base, LoggingMixin):
             except pickle.UnpicklingError:
                 return json.loads(result.value.decode("UTF-8"), cls=XComDecoder, object_hook=object_hook)
         else:
-            try:
-                return json.loads(result.value.decode("UTF-8"), cls=XComDecoder, object_hook=object_hook)
-            except (json.JSONDecodeError, UnicodeDecodeError):
-                return pickle.loads(result.value)
+            # Since xcom_pickling is disabled, we should only try to deserialize with JSON
+            return json.loads(result.value.decode("UTF-8"), cls=XComDecoder, object_hook=object_hook)
 
     @staticmethod
     def deserialize_value(result: XCom) -> Any:
diff --git a/tests/api_connexion/schemas/test_xcom_schema.py b/tests/api_connexion/schemas/test_xcom_schema.py
index 7d83cdcc57..f3a373e0ad 100644
--- a/tests/api_connexion/schemas/test_xcom_schema.py
+++ b/tests/api_connexion/schemas/test_xcom_schema.py
@@ -30,6 +30,7 @@ from airflow.api_connexion.schemas.xcom_schema import (
 from airflow.models import DagRun, XCom
 from airflow.utils.dates import parse_execution_date
 from airflow.utils.session import create_session
+from tests.test_utils.config import conf_vars
 
 pytestmark = pytest.mark.db_test
 
@@ -188,6 +189,7 @@ class TestXComSchema:
     default_time = "2016-04-02T21:00:00+00:00"
     default_time_parsed = parse_execution_date(default_time)
 
+    @conf_vars({("core", "enable_xcom_pickling"): "True"})
     def test_serialize(self, create_xcom, session):
         create_xcom(
             dag_id="test_dag",
@@ -208,6 +210,7 @@ class TestXComSchema:
             "map_index": -1,
         }
 
+    @conf_vars({("core", "enable_xcom_pickling"): "True"})
     def test_deserialize(self):
         xcom_dump = {
             "key": "test_key",
diff --git a/tests/models/test_xcom.py b/tests/models/test_xcom.py
index db290c0e85..8ab7d4eb56 100644
--- a/tests/models/test_xcom.py
+++ b/tests/models/test_xcom.py
@@ -140,7 +140,7 @@ class TestXCom:
             ret_value = XCom.get_value(key="xcom_test3", ti_key=ti_key, session=session)
         assert ret_value == {"key": "value"}
 
-    def test_xcom_deserialize_with_pickle_to_json_switch(self, task_instance, session):
+    def test_xcom_deserialize_pickle_when_xcom_pickling_is_disabled(self, task_instance, session):
         with conf_vars({("core", "enable_xcom_pickling"): "True"}):
             XCom.set(
                 key="xcom_test3",
@@ -151,14 +151,14 @@ class TestXCom:
                 session=session,
             )
         with conf_vars({("core", "enable_xcom_pickling"): "False"}):
-            ret_value = XCom.get_one(
-                key="xcom_test3",
-                dag_id=task_instance.dag_id,
-                task_id=task_instance.task_id,
-                run_id=task_instance.run_id,
-                session=session,
-            )
-        assert ret_value == {"key": "value"}
+            with pytest.raises(UnicodeDecodeError):
+                XCom.get_one(
+                    key="xcom_test3",
+                    dag_id=task_instance.dag_id,
+                    task_id=task_instance.task_id,
+                    run_id=task_instance.run_id,
+                    session=session,
+                )
 
     @conf_vars({("core", "xcom_enable_pickling"): "False"})
     def test_xcom_disable_pickle_type_fail_on_non_json(self, task_instance, session):