You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ep...@apache.org on 2022/09/14 13:14:12 UTC
[airflow] 02/02: Handle list when serializing expand_kwargs (#26369)
This is an automated email from the ASF dual-hosted git repository.
ephraimanierobi pushed a commit to branch v2-4-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 7e9fd34cd07ec0f26c7a72589e327f48389771ed
Author: Tzu-ping Chung <ur...@gmail.com>
AuthorDate: Wed Sep 14 18:01:11 2022 +0800
Handle list when serializing expand_kwargs (#26369)
(cherry picked from commit b816a6b243d16da87ca00e443619c75e9f6f5816)
---
airflow/serialization/serialized_objects.py | 45 +++++++++++++++++++--
tests/serialization/test_dag_serialization.py | 57 ++++++++++++++++++++++++++-
2 files changed, 98 insertions(+), 4 deletions(-)
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index fb298cc79e..969b6014db 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -17,6 +17,7 @@
"""Serialized DAG and BaseOperator"""
from __future__ import annotations
+import collections.abc
import datetime
import enum
import logging
@@ -24,7 +25,7 @@ import warnings
import weakref
from dataclasses import dataclass
from inspect import Parameter, signature
-from typing import TYPE_CHECKING, Any, Iterable, NamedTuple, Type
+from typing import TYPE_CHECKING, Any, Collection, Iterable, Mapping, NamedTuple, Type, Union
import cattr
import lazy_object_proxy
@@ -207,6 +208,26 @@ class _XComRef(NamedTuple):
return deserialize_xcom_arg(self.data, dag)
+# These two should be kept in sync. Note that these are intentionally not using
+# the type declarations in expandinput.py so we always remember to update
+# serialization logic when adding new ExpandInput variants. If you add things to
+# the unions, be sure to update _ExpandInputRef to match.
+_ExpandInputOriginalValue = Union[
+ # For .expand(**kwargs).
+ Mapping[str, Any],
+ # For expand_kwargs(arg).
+ XComArg,
+ Collection[Union[XComArg, Mapping[str, Any]]],
+]
+_ExpandInputSerializedValue = Union[
+ # For .expand(**kwargs).
+ Mapping[str, Any],
+ # For expand_kwargs(arg).
+ _XComRef,
+ Collection[Union[_XComRef, Mapping[str, Any]]],
+]
+
+
class _ExpandInputRef(NamedTuple):
"""Used to store info needed to create a mapped operator's expand input.
@@ -215,13 +236,29 @@ class _ExpandInputRef(NamedTuple):
"""
key: str
- value: _XComRef | dict[str, Any]
+ value: _ExpandInputSerializedValue
+
+ @classmethod
+ def validate_expand_input_value(cls, value: _ExpandInputOriginalValue) -> None:
+ """Validate we've covered all ``ExpandInput.value`` types.
+
+ This function does not actually do anything, but is called during
+ serialization so Mypy will *statically* check we have handled all
+ possible ExpandInput cases.
+ """
def deref(self, dag: DAG) -> ExpandInput:
+ """De-reference into a concrete ExpandInput object.
+
+ If you add more cases here, be sure to update _ExpandInputOriginalValue
+ and _ExpandInputSerializedValue to match the logic.
+ """
if isinstance(self.value, _XComRef):
value: Any = self.value.deref(dag)
- else:
+ elif isinstance(self.value, collections.abc.Mapping):
value = {k: v.deref(dag) if isinstance(v, _XComRef) else v for k, v in self.value.items()}
+ else:
+ value = [v.deref(dag) if isinstance(v, _XComRef) else v for v in self.value]
return create_expand_input(self.key, value)
@@ -663,6 +700,8 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
serialized_op = cls._serialize_node(op, include_deps=op.deps != MappedOperator.deps_for(BaseOperator))
# Handle expand_input and op_kwargs_expand_input.
expansion_kwargs = op._get_specified_expand_input()
+ if TYPE_CHECKING: # Let Mypy check the input type for us!
+ _ExpandInputRef.validate_expand_input_value(expansion_kwargs.value)
serialized_op[op._expand_input_attr] = {
"type": get_map_type_key(expansion_kwargs),
"value": cls.serialize(expansion_kwargs.value),
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index bd171fd50d..aa409183e8 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -1961,7 +1961,62 @@ def test_operator_expand_xcomarg_serde():
@pytest.mark.parametrize("strict", [True, False])
-def test_operator_expand_kwargs_serde(strict):
+def test_operator_expand_kwargs_literal_serde(strict):
+    from airflow.models.xcom_arg import PlainXComArg, XComArg
+    from airflow.serialization.serialized_objects import _XComRef
+
+    with DAG("test-dag", start_date=datetime(2020, 1, 1)) as dag:
+        task1 = BaseOperator(task_id="op1")
+        mapped = MockOperator.partial(task_id='task_2').expand_kwargs(
+            [{"a": "x"}, {"a": XComArg(task1)}],
+            strict=strict,
+        )
+
+    serialized = SerializedBaseOperator.serialize(mapped)
+    assert serialized == {
+        '_is_empty': False,
+        '_is_mapped': True,
+        '_task_module': 'tests.test_utils.mock_operators',
+        '_task_type': 'MockOperator',
+        'downstream_task_ids': [],
+        'expand_input': {
+            "type": "list-of-dicts",
+            "value": [
+                {"__type": "dict", "__var": {"a": "x"}},
+                {
+                    "__type": "dict",
+                    "__var": {"a": {'__type': 'xcomref', '__var': {'task_id': 'op1', 'key': 'return_value'}}},
+                },
+            ],
+        },
+        'partial_kwargs': {},
+        'task_id': 'task_2',
+        'template_fields': ['arg1', 'arg2'],
+        'template_ext': [],
+        'template_fields_renderers': {},
+        'operator_extra_links': [],
+        'ui_color': '#fff',
+        'ui_fgcolor': '#000',
+        "_disallow_kwargs_override": strict,
+        '_expand_input_attr': 'expand_input',
+    }
+
+    op = SerializedBaseOperator.deserialize_operator(serialized)
+    assert op.deps is MappedOperator.deps_for(BaseOperator)
+    assert op._disallow_kwargs_override == strict
+
+    # The XComArg can't be deserialized before the DAG is.
+    expand_value = op.expand_input.value
+    assert expand_value == [{"a": "x"}, {"a": _XComRef({"task_id": "op1", "key": XCOM_RETURN_KEY})}]
+
+    serialized_dag: DAG = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
+
+    resolved_expand_value = serialized_dag.task_dict['task_2'].expand_input.value
+    assert resolved_expand_value == [{"a": "x"}, {"a": PlainXComArg(serialized_dag.task_dict['op1'])}]
+
+
+@pytest.mark.parametrize("strict", [True, False])
+def test_operator_expand_kwargs_xcomarg_serde(strict):
from airflow.models.xcom_arg import PlainXComArg, XComArg
from airflow.serialization.serialized_objects import _XComRef