Posted to commits@airflow.apache.org by ep...@apache.org on 2023/01/11 21:25:46 UTC

[airflow] 05/27: Fix bad pods pickled in executor_config (#28454)

This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v2-5-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 2e545122b6d5dd4741901eb098366e367026a3de
Author: Daniel Standish <15...@users.noreply.github.com>
AuthorDate: Sun Dec 25 23:56:13 2022 -0800

    Fix bad pods pickled in executor_config (#28454)
    
    We used to pickle raw pod objects, but found that when unpickling across k8s lib versions we would get missing-attribute errors.

    Now, we serialize to JSON.

    But we still get reports of issues when users upgrade, because that change only solves the problem on a go-forward basis: pods pickled under the old code remain in the database.

    We can fix these old bad executor configs as they pop up by roundtripping the pod to JSON in a more tolerant fashion than the openapi-generated code does, i.e. by populating missing attrs with None (see the sketch after the diffstat below).
    
    (cherry picked from commit 27f07b0bf5ed088c4186296668a36dc89da25617)
---
 airflow/utils/sqlalchemy.py    | 101 ++++++++++++++++++++++++++++++++++++++++-
 tests/utils/test_sqlalchemy.py |  53 ++++++++++++++++++++-
 2 files changed, 151 insertions(+), 3 deletions(-)
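
Before the patch itself, a minimal sketch of the failure mode being fixed,
mirroring the test added below (assumes the kubernetes Python client is
installed; `tty` is just one example of a container attr added in a newer
client version):

    import pickle

    from kubernetes.client import models as k8s

    pod = k8s.V1Pod(spec=k8s.V1PodSpec(containers=[k8s.V1Container(name="base")]))
    # Simulate a pod created under an older k8s client: remove a protected
    # attr that the current client's openapi_types now expects.
    del pod.spec.containers[0]._tty
    stale_pod = pickle.loads(pickle.dumps(pod))

    # to_dict() walks openapi_types and calls getattr for every known attr,
    # so the missing one raises AttributeError -- this is the "bad pod".
    stale_pod.to_dict()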

diff --git a/airflow/utils/sqlalchemy.py b/airflow/utils/sqlalchemy.py
index 08a4b890b0..1ee482cc51 100644
--- a/airflow/utils/sqlalchemy.py
+++ b/airflow/utils/sqlalchemy.py
@@ -21,7 +21,7 @@ import copy
 import datetime
 import json
 import logging
-from typing import Any, Iterable
+from typing import TYPE_CHECKING, Any, Iterable
 
 import pendulum
 from dateutil import relativedelta
@@ -37,6 +37,9 @@ from airflow import settings
 from airflow.configuration import conf
 from airflow.serialization.enums import Encoding
 
+if TYPE_CHECKING:
+    from kubernetes.client.models.v1_pod import V1Pod
+
 log = logging.getLogger(__name__)
 
 utc = pendulum.tz.timezone("UTC")
@@ -153,6 +156,93 @@ class ExtendedJSON(TypeDecorator):
         return BaseSerialization.deserialize(value)
 
 
+def sanitize_for_serialization(obj: V1Pod):
+    """
+    Convert pod to dict.... but *safely*.
+
+    When pod objects created with one k8s version are unpickled in a python
+    env with a more recent k8s version (in which the object attrs may have
+    changed) the unpickled obj may throw an error because the attr
+    expected on new obj may not be there on the unpickled obj.
+
+    This function still converts the pod to a dict; the only difference is
+    it populates missing attrs with None. You may compare with
+    https://github.com/kubernetes-client/python/blob/5a96bbcbe21a552cc1f9cda13e0522fafb0dbac8/kubernetes/client/api_client.py#L202
+
+    If obj is None, return None.
+    If obj is str, int, long, float, bool, return directly.
+    If obj is datetime.datetime, datetime.date
+        convert to string in iso8601 format.
+    If obj is list, sanitize each element in the list.
+    If obj is dict, return the dict.
+    If obj is OpenAPI model, return the properties dict.
+
+    :param obj: The data to serialize.
+    :return: The serialized form of data.
+
+    :meta private:
+    """
+    if obj is None:
+        return None
+    elif isinstance(obj, (float, bool, bytes, str, int)):
+        return obj
+    elif isinstance(obj, list):
+        return [sanitize_for_serialization(sub_obj) for sub_obj in obj]
+    elif isinstance(obj, tuple):
+        return tuple(sanitize_for_serialization(sub_obj) for sub_obj in obj)
+    elif isinstance(obj, (datetime.datetime, datetime.date)):
+        return obj.isoformat()
+
+    if isinstance(obj, dict):
+        obj_dict = obj
+    else:
+        obj_dict = {
+            obj.attribute_map[attr]: getattr(obj, attr)
+            for attr, _ in obj.openapi_types.items()
+            # below is the only line we change, and we just add default=None for getattr
+            if getattr(obj, attr, None) is not None
+        }
+
+    return {key: sanitize_for_serialization(val) for key, val in obj_dict.items()}
+
+
+def ensure_pod_is_valid_after_unpickling(pod: V1Pod) -> V1Pod | None:
+    """
+    Convert pod to json and back so that pod is safe.
+
+    The pod_override in executor_config is a V1Pod object.
+    Such objects created with one k8s version, when unpickled in
+    an env with upgraded k8s version, may blow up when
+    `to_dict` is called, because openapi client code gen calls
+    getattr on all attrs in openapi_types for each object, and when
+    new attrs are added to that list, getattr will fail.
+
+    Here we re-serialize it to ensure it is not going to blow up.
+
+    :meta private:
+    """
+    try:
+        # if to_dict works, the pod is fine
+        pod.to_dict()
+        return pod
+    except AttributeError:
+        pass
+    try:
+        from kubernetes.client.models.v1_pod import V1Pod
+    except ImportError:
+        return None
+    if not isinstance(pod, V1Pod):
+        return None
+    try:
+        from airflow.kubernetes.pod_generator import PodGenerator
+
+        # now we actually reserialize / deserialize the pod
+        pod_dict = sanitize_for_serialization(pod)
+        return PodGenerator.deserialize_model_dict(pod_dict)
+    except Exception:
+        return None
+
+
 class ExecutorConfigType(PickleType):
     """
     Adds special handling for K8s executor config. If we unpickle a k8s object that was
@@ -188,9 +278,16 @@ class ExecutorConfigType(PickleType):
             if isinstance(value, dict) and "pod_override" in value:
                 pod_override = value["pod_override"]
 
-                # If pod_override was serialized with Airflow's BaseSerialization, deserialize it
                 if isinstance(pod_override, dict) and pod_override.get(Encoding.TYPE):
+                    # If pod_override was serialized with Airflow's BaseSerialization, deserialize it
                     value["pod_override"] = BaseSerialization.deserialize(pod_override)
+                else:
+                    # backcompat path
+                    # we no longer pickle raw pods but this code may be reached
+                    # when accessing executor configs created in a prior version
+                    new_pod = ensure_pod_is_valid_after_unpickling(pod_override)
+                    if new_pod:
+                        value["pod_override"] = new_pod
             return value
 
         return process
diff --git a/tests/utils/test_sqlalchemy.py b/tests/utils/test_sqlalchemy.py
index 2650aea6bc..2bf4ad1be5 100644
--- a/tests/utils/test_sqlalchemy.py
+++ b/tests/utils/test_sqlalchemy.py
@@ -19,6 +19,7 @@ from __future__ import annotations
 
 import datetime
 import pickle
+from copy import deepcopy
 from unittest import mock
 from unittest.mock import MagicMock
 
@@ -32,7 +33,14 @@ from airflow.models import DAG
 from airflow.serialization.enums import DagAttributeTypes, Encoding
 from airflow.serialization.serialized_objects import BaseSerialization
 from airflow.settings import Session
-from airflow.utils.sqlalchemy import ExecutorConfigType, nowait, prohibit_commit, skip_locked, with_row_locks
+from airflow.utils.sqlalchemy import (
+    ExecutorConfigType,
+    ensure_pod_is_valid_after_unpickling,
+    nowait,
+    prohibit_commit,
+    skip_locked,
+    with_row_locks,
+)
 from airflow.utils.state import State
 from airflow.utils.timezone import utcnow
 
@@ -324,3 +332,46 @@ class TestExecutorConfigType:
         instance = ExecutorConfigType()
         assert instance.compare_values(a, a) is False
         assert instance.compare_values("a", "a") is True
+
+    def test_result_processor_bad_pickled_obj(self):
+        """
+        If unpickled obj is missing attrs that curr lib expects
+        """
+        test_container = k8s.V1Container(name="base")
+        test_pod = k8s.V1Pod(spec=k8s.V1PodSpec(containers=[test_container]))
+        copy_of_test_pod = deepcopy(test_pod)
+        # curr api expects attr `tty`
+        assert "tty" in test_container.openapi_types
+        # it lives in protected attr _tty
+        assert hasattr(test_container, "_tty")
+        # so, let's remove it before pickling, to simulate what happens in real life
+        del test_container._tty
+        # now let's prove that this blows up when calling to_dict
+        with pytest.raises(AttributeError):
+            test_pod.to_dict()
+        # no such problem with the copy
+        assert copy_of_test_pod.to_dict()
+        # so we need to roundtrip it through json
+        fixed_pod = ensure_pod_is_valid_after_unpickling(test_pod)
+        # and, since the missing attr was None anyway, we actually have the same pod
+        assert fixed_pod.to_dict() == copy_of_test_pod.to_dict()
+
+        # now, let's verify that result processor makes this all work
+        # first, check that bad pod is still bad
+        with pytest.raises(AttributeError):
+            test_pod.to_dict()
+        # define what will be retrieved from db
+        input = pickle.dumps({"pod_override": test_pod})
+
+        # get the result processor method
+        config_type = ExecutorConfigType()
+        mock_dialect = MagicMock()
+        mock_dialect.dbapi = None
+        process = config_type.result_processor(mock_dialect, None)
+
+        # apply the result processor
+        result = process(input)
+
+        # show that the pickled (bad) pod is now a good pod, and same as the copy made
+        # before making it bad
+        assert result["pod_override"].to_dict() == copy_of_test_pod.to_dict()
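
To round this out, a sketch (not part of the commit) of how the new backcompat
path repairs such a pod, using the two helpers added above; assumes an Airflow
install that includes this patch:

    import pickle

    from kubernetes.client import models as k8s

    from airflow.kubernetes.pod_generator import PodGenerator
    from airflow.utils.sqlalchemy import sanitize_for_serialization

    pod = k8s.V1Pod(spec=k8s.V1PodSpec(containers=[k8s.V1Container(name="base")]))
    del pod.spec.containers[0]._tty  # same old-client simulation as above
    stale_pod = pickle.loads(pickle.dumps(pod))

    # The tolerant roundtrip skips the missing attr instead of raising...
    pod_dict = sanitize_for_serialization(stale_pod)
    # ...and deserializing the dict yields a pod valid under the current lib.
    repaired = PodGenerator.deserialize_model_dict(pod_dict)
    repaired.to_dict()  # no longer raises AttributeError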