You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/07/04 20:45:03 UTC
[airflow] branch main updated: Serialize pod_override to JSON before pickling executor_config (#24356)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new c1d621c7ce Serialize pod_override to JSON before pickling executor_config (#24356)
c1d621c7ce is described below
commit c1d621c7ce352cb900ff5fb7da214e1fbcf0a15f
Author: Daniel Standish <15...@users.noreply.github.com>
AuthorDate: Mon Jul 4 13:44:55 2022 -0700
Serialize pod_override to JSON before pickling executor_config (#24356)
* Serialize pod_override to JSON before pickling executor_config
If we unpickle a k8s object that was pickled under an earlier k8s library version, then the unpickled object may throw an error when to_dict is called. To be more tolerant of version changes we convert to JSON using Airflow's serializer before pickling.
---
airflow/models/taskinstance.py | 25 ++++-------
airflow/utils/sqlalchemy.py | 63 ++++++++++++++++++++++++++-
tests/models/test_taskinstance.py | 21 +--------
tests/utils/test_sqlalchemy.py | 90 ++++++++++++++++++++++++++++++++++++++-
4 files changed, 160 insertions(+), 39 deletions(-)
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index b02241abe0..8950385248 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -58,7 +58,6 @@ from sqlalchemy import (
ForeignKeyConstraint,
Index,
Integer,
- PickleType,
String,
and_,
false,
@@ -120,7 +119,13 @@ from airflow.utils.operator_helpers import context_to_airflow_vars
from airflow.utils.platform import getuser
from airflow.utils.retries import run_with_db_retries
from airflow.utils.session import NEW_SESSION, create_session, provide_session
-from airflow.utils.sqlalchemy import ExtendedJSON, UtcDateTime, tuple_in_condition, with_row_locks
+from airflow.utils.sqlalchemy import (
+ ExecutorConfigType,
+ ExtendedJSON,
+ UtcDateTime,
+ tuple_in_condition,
+ with_row_locks,
+)
from airflow.utils.state import DagRunState, State, TaskInstanceState
from airflow.utils.timeout import timeout
@@ -401,20 +406,6 @@ class TaskInstanceKey(NamedTuple):
return self
-def _executor_config_comparator(x, y):
- """
- The TaskInstance.executor_config attribute is a pickled object that may contain
- kubernetes objects. If the installed library version has changed since the
- object was originally pickled, due to the underlying ``__eq__`` method on these
- objects (which converts them to JSON), we may encounter attribute errors. In this
- case we should replace the stored object.
- """
- try:
- return x == y
- except AttributeError:
- return False
-
-
class TaskInstance(Base, LoggingMixin):
"""
Task instances store the state of a task instance. This table is the
@@ -457,7 +448,7 @@ class TaskInstance(Base, LoggingMixin):
queued_dttm = Column(UtcDateTime)
queued_by_job_id = Column(Integer)
pid = Column(Integer)
- executor_config = Column(PickleType(pickler=dill, comparator=_executor_config_comparator))
+ executor_config = Column(ExecutorConfigType(pickler=dill))
external_executor_id = Column(StringID())
diff --git a/airflow/utils/sqlalchemy.py b/airflow/utils/sqlalchemy.py
index 5dd47d157b..ab0bc5dff8 100644
--- a/airflow/utils/sqlalchemy.py
+++ b/airflow/utils/sqlalchemy.py
@@ -23,7 +23,7 @@ from typing import Any, Dict, Iterable, Tuple
import pendulum
from dateutil import relativedelta
-from sqlalchemy import TIMESTAMP, and_, event, false, nullsfirst, or_, tuple_
+from sqlalchemy import TIMESTAMP, PickleType, and_, event, false, nullsfirst, or_, tuple_
from sqlalchemy.dialects import mssql, mysql
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm.session import Session
@@ -33,6 +33,7 @@ from sqlalchemy.types import JSON, Text, TypeDecorator, TypeEngine, UnicodeText
from airflow import settings
from airflow.configuration import conf
+from airflow.serialization.enums import Encoding
log = logging.getLogger(__name__)
@@ -146,6 +147,66 @@ class ExtendedJSON(TypeDecorator):
return BaseSerialization._deserialize(value)
+class ExecutorConfigType(PickleType):
+ """
+ Adds special handling for K8s executor config. If we unpickle a k8s object that was
+ pickled under an earlier k8s library version, then the unpickled object may throw an error
+ when to_dict is called. To be more tolerant of version changes we convert to JSON using
+ Airflow's serializer before pickling.
+ """
+
+ def bind_processor(self, dialect):
+
+ from airflow.serialization.serialized_objects import BaseSerialization
+
+ super_process = super().bind_processor(dialect)
+
+ def process(value):
+ if isinstance(value, dict) and 'pod_override' in value:
+ value['pod_override'] = BaseSerialization()._serialize(value['pod_override'])
+ return super_process(value)
+
+ return process
+
+ def result_processor(self, dialect, coltype):
+ from airflow.serialization.serialized_objects import BaseSerialization
+
+ super_process = super().result_processor(dialect, coltype)
+
+ def process(value):
+ value = super_process(value) # unpickle
+
+ if isinstance(value, dict) and 'pod_override' in value:
+ pod_override = value['pod_override']
+
+ # If pod_override was serialized with Airflow's BaseSerialization, deserialize it
+ if isinstance(pod_override, dict) and pod_override.get(Encoding.TYPE):
+ value['pod_override'] = BaseSerialization()._deserialize(pod_override)
+ return value
+
+ return process
+
+ def compare_values(self, x, y):
+ """
+ The TaskInstance.executor_config attribute is a pickled object that may contain
+ kubernetes objects. If the installed library version has changed since the
+ object was originally pickled, due to the underlying ``__eq__`` method on these
+ objects (which converts them to JSON), we may encounter attribute errors. In this
+ case we should replace the stored object.
+
+ From https://github.com/apache/airflow/pull/24356 we use our serializer to store
+ k8s objects, but there could still be raw pickled k8s objects in the database,
+ stored from an earlier version, so we still compare them defensively here.
+ """
+ if self.comparator:
+ return self.comparator(x, y)
+ else:
+ try:
+ return x == y
+ except AttributeError:
+ return False
+
+
class Interval(TypeDecorator):
"""Base class representing a time interval."""
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index 71c6c2f8c4..da3d138306 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -57,7 +57,7 @@ from airflow.models import (
)
from airflow.models.serialized_dag import SerializedDagModel
from airflow.models.taskfail import TaskFail
-from airflow.models.taskinstance import TaskInstance, _executor_config_comparator
+from airflow.models.taskinstance import TaskInstance
from airflow.models.taskmap import TaskMap
from airflow.models.xcom import XCOM_RETURN_KEY
from airflow.operators.bash import BashOperator
@@ -2848,22 +2848,3 @@ def test_expand_non_templated_field(dag_maker, session):
echo_task = dag.get_task("echo")
assert "get_extra_env" in echo_task.upstream_task_ids
-
-
-def test_executor_config_comparator():
- """
- When comparison raises AttributeError, return False.
- This can happen when executor config contains kubernetes objects pickled
- under older kubernetes library version.
- """
-
- class MockAttrError:
- def __eq__(self, other):
- raise AttributeError('hello')
-
- a = MockAttrError()
- with pytest.raises(AttributeError):
- # just verify for ourselves that this throws
- assert a == a
- assert _executor_config_comparator(a, a) is False
- assert _executor_config_comparator('a', 'a') is True
diff --git a/tests/utils/test_sqlalchemy.py b/tests/utils/test_sqlalchemy.py
index 250bd37756..0038577b8a 100644
--- a/tests/utils/test_sqlalchemy.py
+++ b/tests/utils/test_sqlalchemy.py
@@ -17,20 +17,29 @@
# under the License.
#
import datetime
+import pickle
import unittest
+from copy import copy
from unittest import mock
+from unittest.mock import MagicMock
import pytest
+from kubernetes.client import models as k8s
from parameterized import parameterized
+from pytest import param
from sqlalchemy.exc import StatementError
from airflow import settings
from airflow.models import DAG
+from airflow.serialization.enums import Encoding
+from airflow.serialization.serialized_objects import BaseSerialization
from airflow.settings import Session
-from airflow.utils.sqlalchemy import nowait, prohibit_commit, skip_locked, with_row_locks
+from airflow.utils.sqlalchemy import ExecutorConfigType, nowait, prohibit_commit, skip_locked, with_row_locks
from airflow.utils.state import State
from airflow.utils.timezone import utcnow
+TEST_POD = k8s.V1Pod(spec=k8s.V1PodSpec(containers=[k8s.V1Container(name="base")]))
+
class TestSqlAlchemyUtils(unittest.TestCase):
def setUp(self):
@@ -226,3 +235,82 @@ class TestSqlAlchemyUtils(unittest.TestCase):
def tearDown(self):
self.session.close()
settings.engine.dispose()
+
+
+class TestExecutorConfigType:
+ @pytest.mark.parametrize(
+ 'input',
+ ['anything', {'pod_override': TEST_POD}],
+ )
+ def test_bind_processor(self, input):
+ """
+ The returned bind processor should pickle the object as is, unless it is a dictionary with
+ a pod_override node, in which case it should run it through BaseSerialization.
+ """
+ config_type = ExecutorConfigType()
+ mock_dialect = MagicMock()
+ mock_dialect.dbapi = None
+ process = config_type.bind_processor(mock_dialect)
+ expected = copy(input)
+ if 'pod_override' in input:
+ expected['pod_override'] = BaseSerialization()._serialize(input['pod_override'])
+ assert pickle.loads(process(input)) == expected
+
+ @pytest.mark.parametrize(
+ 'input',
+ [
+ param(
+ pickle.dumps('anything'),
+ id='anything',
+ ),
+ param(
+ pickle.dumps({'pod_override': BaseSerialization()._serialize(TEST_POD)}),
+ id='serialized_pod',
+ ),
+ param(
+ pickle.dumps({'pod_override': TEST_POD}),
+ id='old_pickled_raw_pod',
+ ),
+ param(
+ pickle.dumps({'pod_override': {"name": "hi"}}),
+ id='arbitrary_dict',
+ ),
+ ],
+ )
+ def test_result_processor(self, input):
+ """
+ The returned result processor should unpickle the object as is, unless it is a dictionary with
+ a pod_override node whose value was serialized with BaseSerialization, in which case it should deserialize that value.
+ """
+ config_type = ExecutorConfigType()
+ mock_dialect = MagicMock()
+ mock_dialect.dbapi = None
+ process = config_type.result_processor(mock_dialect, None)
+ result = process(input)
+ expected = pickle.loads(input)
+ pod_override = isinstance(expected, dict) and expected.get('pod_override')
+ if pod_override and isinstance(pod_override, dict) and pod_override.get(Encoding.TYPE):
+ # We should only deserialize a pod_override with BaseSerialization if
+ # it was serialized with BaseSerialization (which is the behavior added in #24356).
+ expected['pod_override'] = BaseSerialization()._deserialize(expected['pod_override'])
+ assert result == expected
+
+ def test_compare_values(self):
+ """
+ When comparison raises AttributeError, return False.
+ This can happen when executor config contains kubernetes objects pickled
+ under older kubernetes library version.
+ """
+
+ class MockAttrError:
+ def __eq__(self, other):
+ raise AttributeError('hello')
+
+ a = MockAttrError()
+ with pytest.raises(AttributeError):
+ # just verify for ourselves that comparing directly will throw AttributeError
+ assert a == a
+
+ instance = ExecutorConfigType()
+ assert instance.compare_values(a, a) is False
+ assert instance.compare_values('a', 'a') is True