You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/07/04 20:45:03 UTC

[airflow] branch main updated: Serialize pod_override to JSON before pickling executor_config (#24356)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new c1d621c7ce Serialize pod_override to JSON before pickling executor_config (#24356)
c1d621c7ce is described below

commit c1d621c7ce352cb900ff5fb7da214e1fbcf0a15f
Author: Daniel Standish <15...@users.noreply.github.com>
AuthorDate: Mon Jul 4 13:44:55 2022 -0700

    Serialize pod_override to JSON before pickling executor_config (#24356)
    
    * Serialize pod_override to JSON before pickling executor_config
    
    If we unpickle a k8s object that was pickled under an earlier k8s library version, then the unpickled object may throw an error when to_dict is called.  To be more tolerant of version changes we convert to JSON using Airflow's serializer before pickling.
---
 airflow/models/taskinstance.py    | 25 ++++-------
 airflow/utils/sqlalchemy.py       | 63 ++++++++++++++++++++++++++-
 tests/models/test_taskinstance.py | 21 +--------
 tests/utils/test_sqlalchemy.py    | 90 ++++++++++++++++++++++++++++++++++++++-
 4 files changed, 160 insertions(+), 39 deletions(-)

diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index b02241abe0..8950385248 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -58,7 +58,6 @@ from sqlalchemy import (
     ForeignKeyConstraint,
     Index,
     Integer,
-    PickleType,
     String,
     and_,
     false,
@@ -120,7 +119,13 @@ from airflow.utils.operator_helpers import context_to_airflow_vars
 from airflow.utils.platform import getuser
 from airflow.utils.retries import run_with_db_retries
 from airflow.utils.session import NEW_SESSION, create_session, provide_session
-from airflow.utils.sqlalchemy import ExtendedJSON, UtcDateTime, tuple_in_condition, with_row_locks
+from airflow.utils.sqlalchemy import (
+    ExecutorConfigType,
+    ExtendedJSON,
+    UtcDateTime,
+    tuple_in_condition,
+    with_row_locks,
+)
 from airflow.utils.state import DagRunState, State, TaskInstanceState
 from airflow.utils.timeout import timeout
 
@@ -401,20 +406,6 @@ class TaskInstanceKey(NamedTuple):
         return self
 
 
-def _executor_config_comparator(x, y):
-    """
-    The TaskInstance.executor_config attribute is a pickled object that may contain
-    kubernetes objects.  If the installed library version has changed since the
-    object was originally pickled, due to the underlying ``__eq__`` method on these
-    objects (which converts them to JSON), we may encounter attribute errors. In this
-    case we should replace the stored object.
-    """
-    try:
-        return x == y
-    except AttributeError:
-        return False
-
-
 class TaskInstance(Base, LoggingMixin):
     """
     Task instances store the state of a task instance. This table is the
@@ -457,7 +448,7 @@ class TaskInstance(Base, LoggingMixin):
     queued_dttm = Column(UtcDateTime)
     queued_by_job_id = Column(Integer)
     pid = Column(Integer)
-    executor_config = Column(PickleType(pickler=dill, comparator=_executor_config_comparator))
+    executor_config = Column(ExecutorConfigType(pickler=dill))
 
     external_executor_id = Column(StringID())
 
diff --git a/airflow/utils/sqlalchemy.py b/airflow/utils/sqlalchemy.py
index 5dd47d157b..ab0bc5dff8 100644
--- a/airflow/utils/sqlalchemy.py
+++ b/airflow/utils/sqlalchemy.py
@@ -23,7 +23,7 @@ from typing import Any, Dict, Iterable, Tuple
 
 import pendulum
 from dateutil import relativedelta
-from sqlalchemy import TIMESTAMP, and_, event, false, nullsfirst, or_, tuple_
+from sqlalchemy import TIMESTAMP, PickleType, and_, event, false, nullsfirst, or_, tuple_
 from sqlalchemy.dialects import mssql, mysql
 from sqlalchemy.exc import OperationalError
 from sqlalchemy.orm.session import Session
@@ -33,6 +33,7 @@ from sqlalchemy.types import JSON, Text, TypeDecorator, TypeEngine, UnicodeText
 
 from airflow import settings
 from airflow.configuration import conf
+from airflow.serialization.enums import Encoding
 
 log = logging.getLogger(__name__)
 
@@ -146,6 +147,66 @@ class ExtendedJSON(TypeDecorator):
         return BaseSerialization._deserialize(value)
 
 
+class ExecutorConfigType(PickleType):
+    """
+    PickleType that stores a K8s ``pod_override`` using Airflow's BaseSerialization.
+
+    A k8s object pickled under an earlier k8s library version may throw an error when
+    ``to_dict`` is called after unpickling, so the pod is converted to primitives first.
+    """
+
+    cache_ok = True  # SQLAlchemy checks this on the subclass itself; PickleType's is not inherited
+
+    def bind_processor(self, dialect):
+        """Return a processor that serializes ``pod_override`` (on a copy), then pickles."""
+        import copy
+
+        from airflow.serialization.serialized_objects import BaseSerialization
+
+        super_process = super().bind_processor(dialect)
+
+        def process(value):
+            if isinstance(value, dict) and 'pod_override' in value:
+                # Copy so the caller's (possibly shared) config keeps its pod object; mutating
+                # it in place would make a later flush serialize the serialized form again.
+                value = copy.copy(value)
+                value['pod_override'] = BaseSerialization()._serialize(value['pod_override'])
+            return super_process(value)
+
+        return process
+
+    def result_processor(self, dialect, coltype):
+        """Return a processor that unpickles, then restores a serialized ``pod_override``."""
+        from airflow.serialization.serialized_objects import BaseSerialization
+
+        super_process = super().result_processor(dialect, coltype)
+
+        def process(value):
+            value = super_process(value)  # unpickle
+            if isinstance(value, dict) and 'pod_override' in value:
+                pod_override = value['pod_override']
+                # Raw k8s objects pickled before #24356 and unmarked dicts are kept as is.
+                if isinstance(pod_override, dict) and pod_override.get(Encoding.TYPE):
+                    value['pod_override'] = BaseSerialization()._deserialize(pod_override)
+            return value
+
+        return process
+
+    def compare_values(self, x, y):
+        """
+        Without a custom comparator, treat an AttributeError raised by ``==`` as "not equal".
+
+        Rows stored before #24356 may hold raw k8s objects pickled under an older kubernetes
+        library whose ``__eq__`` (which converts to JSON) raises; such objects get replaced.
+        """
+        if self.comparator:
+            return self.comparator(x, y)
+        try:
+            return x == y
+        except AttributeError:
+            return False
+
+
 class Interval(TypeDecorator):
     """Base class representing a time interval."""
 
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index 71c6c2f8c4..da3d138306 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -57,7 +57,7 @@ from airflow.models import (
 )
 from airflow.models.serialized_dag import SerializedDagModel
 from airflow.models.taskfail import TaskFail
-from airflow.models.taskinstance import TaskInstance, _executor_config_comparator
+from airflow.models.taskinstance import TaskInstance
 from airflow.models.taskmap import TaskMap
 from airflow.models.xcom import XCOM_RETURN_KEY
 from airflow.operators.bash import BashOperator
@@ -2848,22 +2848,3 @@ def test_expand_non_templated_field(dag_maker, session):
 
     echo_task = dag.get_task("echo")
     assert "get_extra_env" in echo_task.upstream_task_ids
-
-
-def test_executor_config_comparator():
-    """
-    When comparison raises AttributeError, return False.
-    This can happen when executor config contains kubernetes objects pickled
-    under older kubernetes library version.
-    """
-
-    class MockAttrError:
-        def __eq__(self, other):
-            raise AttributeError('hello')
-
-    a = MockAttrError()
-    with pytest.raises(AttributeError):
-        # just verify for ourselves that this throws
-        assert a == a
-    assert _executor_config_comparator(a, a) is False
-    assert _executor_config_comparator('a', 'a') is True
diff --git a/tests/utils/test_sqlalchemy.py b/tests/utils/test_sqlalchemy.py
index 250bd37756..0038577b8a 100644
--- a/tests/utils/test_sqlalchemy.py
+++ b/tests/utils/test_sqlalchemy.py
@@ -17,20 +17,29 @@
 # under the License.
 #
 import datetime
+import pickle
 import unittest
+from copy import copy
 from unittest import mock
+from unittest.mock import MagicMock
 
 import pytest
+from kubernetes.client import models as k8s
 from parameterized import parameterized
+from pytest import param
 from sqlalchemy.exc import StatementError
 
 from airflow import settings
 from airflow.models import DAG
+from airflow.serialization.enums import Encoding
+from airflow.serialization.serialized_objects import BaseSerialization
 from airflow.settings import Session
-from airflow.utils.sqlalchemy import nowait, prohibit_commit, skip_locked, with_row_locks
+from airflow.utils.sqlalchemy import ExecutorConfigType, nowait, prohibit_commit, skip_locked, with_row_locks
 from airflow.utils.state import State
 from airflow.utils.timezone import utcnow
 
+TEST_POD = k8s.V1Pod(spec=k8s.V1PodSpec(containers=[k8s.V1Container(name="base")]))  # pod_override fixture
+
 
 class TestSqlAlchemyUtils(unittest.TestCase):
     def setUp(self):
@@ -226,3 +235,82 @@ class TestSqlAlchemyUtils(unittest.TestCase):
     def tearDown(self):
         self.session.close()
         settings.engine.dispose()
+
+
+class TestExecutorConfigType:
+    @pytest.mark.parametrize(
+        'value',
+        ['anything', {'pod_override': TEST_POD}],
+    )
+    def test_bind_processor(self, value):
+        """
+        The returned bind processor should pickle the object as is, unless it is a dictionary with
+        a pod_override node, in which case it should run it through BaseSerialization.
+        """
+        config_type = ExecutorConfigType()
+        mock_dialect = MagicMock()
+        mock_dialect.dbapi = None
+        process = config_type.bind_processor(mock_dialect)
+        expected = copy(value)
+        if isinstance(value, dict) and 'pod_override' in value:
+            expected['pod_override'] = BaseSerialization()._serialize(value['pod_override'])
+        assert pickle.loads(process(value)) == expected
+
+    @pytest.mark.parametrize(
+        'pickled',
+        [
+            param(
+                pickle.dumps('anything'),
+                id='anything',
+            ),
+            param(
+                pickle.dumps({'pod_override': BaseSerialization()._serialize(TEST_POD)}),
+                id='serialized_pod',
+            ),
+            param(
+                pickle.dumps({'pod_override': TEST_POD}),
+                id='old_pickled_raw_pod',
+            ),
+            param(
+                pickle.dumps({'pod_override': {"name": "hi"}}),
+                id='arbitrary_dict',
+            ),
+        ],
+    )
+    def test_result_processor(self, pickled):
+        """
+        The returned result processor should unpickle the value as is, unless it is a dictionary
+        with a pod_override node serialized with BaseSerialization, which it should deserialize.
+        """
+        config_type = ExecutorConfigType()
+        mock_dialect = MagicMock()
+        mock_dialect.dbapi = None
+        process = config_type.result_processor(mock_dialect, None)
+        result = process(pickled)
+        expected = pickle.loads(pickled)
+        pod_override = isinstance(expected, dict) and expected.get('pod_override')
+        if pod_override and isinstance(pod_override, dict) and pod_override.get(Encoding.TYPE):
+            # We should only deserialize a pod_override with BaseSerialization if
+            # it was serialized with BaseSerialization (which is the behavior added in #24356).
+            expected['pod_override'] = BaseSerialization()._deserialize(expected['pod_override'])
+        assert result == expected
+
+    def test_compare_values(self):
+        """
+        When comparison raises AttributeError, return False.
+        This can happen when executor config contains kubernetes objects pickled
+        under an older kubernetes library version.
+        """
+
+        class MockAttrError:
+            def __eq__(self, other):
+                raise AttributeError('hello')
+
+        a = MockAttrError()
+        with pytest.raises(AttributeError):
+            # just verify for ourselves that comparing directly will throw AttributeError
+            assert a == a
+
+        instance = ExecutorConfigType()
+        assert instance.compare_values(a, a) is False
+        assert instance.compare_values('a', 'a') is True