You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ep...@apache.org on 2022/09/14 13:14:12 UTC

[airflow] 02/02: Handle list when serializing expand_kwargs (#26369)

This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v2-4-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 7e9fd34cd07ec0f26c7a72589e327f48389771ed
Author: Tzu-ping Chung <ur...@gmail.com>
AuthorDate: Wed Sep 14 18:01:11 2022 +0800

    Handle list when serializing expand_kwargs (#26369)
    
    (cherry picked from commit b816a6b243d16da87ca00e443619c75e9f6f5816)
---
 airflow/serialization/serialized_objects.py   | 45 +++++++++++++++++++--
 tests/serialization/test_dag_serialization.py | 57 ++++++++++++++++++++++++++-
 2 files changed, 98 insertions(+), 4 deletions(-)

diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index fb298cc79e..969b6014db 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -17,6 +17,7 @@
 """Serialized DAG and BaseOperator"""
 from __future__ import annotations
 
+import collections.abc
 import datetime
 import enum
 import logging
@@ -24,7 +25,7 @@ import warnings
 import weakref
 from dataclasses import dataclass
 from inspect import Parameter, signature
-from typing import TYPE_CHECKING, Any, Iterable, NamedTuple, Type
+from typing import TYPE_CHECKING, Any, Collection, Iterable, Mapping, NamedTuple, Type, Union
 
 import cattr
 import lazy_object_proxy
@@ -207,6 +208,26 @@ class _XComRef(NamedTuple):
         return deserialize_xcom_arg(self.data, dag)
 
 
+# ExpandInput.value types before serialization, and their serialized forms. The
+# two aliases mirror each other and must be kept in sync.
+#
+# They intentionally do not reuse the type declarations in expandinput.py, so
+# adding a new ExpandInput variant forces a matching serialization update here.
+# If you extend either union, update _ExpandInputRef to match as well.
+_ExpandInputOriginalValue = Union[
+    Mapping[str, Any],  # .expand(**kwargs)
+    XComArg,  # .expand_kwargs(xcom_arg)
+    # .expand_kwargs([...]): a literal collection of XComArgs and/or mappings.
+    Collection[Union[XComArg, Mapping[str, Any]]],
+]
+_ExpandInputSerializedValue = Union[
+    Mapping[str, Any],  # .expand(**kwargs)
+    _XComRef,  # .expand_kwargs(xcom_arg), stored as a reference resolved via deref(dag)
+    # .expand_kwargs([...]): items are _XComRef or mappings whose values may be _XComRef.
+    Collection[Union[_XComRef, Mapping[str, Any]]],
+]
+
+
 class _ExpandInputRef(NamedTuple):
     """Used to store info needed to create a mapped operator's expand input.
 
@@ -215,13 +236,29 @@ class _ExpandInputRef(NamedTuple):
     """
 
     key: str
-    value: _XComRef | dict[str, Any]
+    value: _ExpandInputSerializedValue
+
+    @classmethod
+    def validate_expand_input_value(cls, value: _ExpandInputOriginalValue) -> None:
+        """Let Mypy confirm every ``ExpandInput.value`` type is accounted for.
+
+        Intentionally a no-op. Serialization calls it only inside an
+        ``if TYPE_CHECKING:`` guard, so the type checker validates ``value``
+        against ``_ExpandInputOriginalValue`` without any runtime effect.
+        """
 
     def deref(self, dag: DAG) -> ExpandInput:
+        """De-reference into a concrete ExpandInput; update the _ExpandInput*Value unions for new cases."""
+        def _deref_item(item: Any) -> Any:  # List item: an XComRef or a dict of maybe-XComRef values.
+            if isinstance(item, collections.abc.Mapping):
+                return {k: v.deref(dag) if isinstance(v, _XComRef) else v for k, v in item.items()}
+            return item.deref(dag) if isinstance(item, _XComRef) else item
         if isinstance(self.value, _XComRef):
             value: Any = self.value.deref(dag)
-        else:
+        elif isinstance(self.value, collections.abc.Mapping):
             value = {k: v.deref(dag) if isinstance(v, _XComRef) else v for k, v in self.value.items()}
+        else:
+            value = [_deref_item(v) for v in self.value]
         return create_expand_input(self.key, value)
 
 
@@ -663,6 +700,8 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
         serialized_op = cls._serialize_node(op, include_deps=op.deps != MappedOperator.deps_for(BaseOperator))
         # Handle expand_input and op_kwargs_expand_input.
         expansion_kwargs = op._get_specified_expand_input()
+        if TYPE_CHECKING:  # Let Mypy check the input type for us!
+            _ExpandInputRef.validate_expand_input_value(expansion_kwargs.value)
         serialized_op[op._expand_input_attr] = {
             "type": get_map_type_key(expansion_kwargs),
             "value": cls.serialize(expansion_kwargs.value),
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index bd171fd50d..aa409183e8 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -1961,7 +1961,62 @@ def test_operator_expand_xcomarg_serde():
 
 
 @pytest.mark.parametrize("strict", [True, False])
-def test_operator_expand_kwargs_serde(strict):
+def test_operator_expand_kwargs_literal_serde(strict):
+    """expand_kwargs() given a literal list of dicts (one holding an XComArg) must round-trip."""
+    from airflow.models.xcom_arg import PlainXComArg, XComArg
+    from airflow.serialization.serialized_objects import _XComRef
+
+    with DAG("test-dag", start_date=datetime(2020, 1, 1)) as dag:
+        task1 = BaseOperator(task_id="op1")
+        mapped = MockOperator.partial(task_id='task_2').expand_kwargs(
+            [{"a": "x"}, {"a": XComArg(task1)}],
+            strict=strict,
+        )
+
+    serialized = SerializedBaseOperator.serialize(mapped)
+    assert serialized == {
+        '_is_empty': False,
+        '_is_mapped': True,
+        '_task_module': 'tests.test_utils.mock_operators',
+        '_task_type': 'MockOperator',
+        'downstream_task_ids': [],
+        'expand_input': {
+            "type": "list-of-dicts",
+            "value": [
+                {"__type": "dict", "__var": {"a": "x"}},
+                {
+                    "__type": "dict",
+                    "__var": {"a": {'__type': 'xcomref', '__var': {'task_id': 'op1', 'key': 'return_value'}}},
+                },
+            ],
+        },
+        'partial_kwargs': {},
+        'task_id': 'task_2',
+        'template_fields': ['arg1', 'arg2'],
+        'template_ext': [],
+        'template_fields_renderers': {},
+        'operator_extra_links': [],
+        'ui_color': '#fff',
+        'ui_fgcolor': '#000',
+        "_disallow_kwargs_override": strict,
+        '_expand_input_attr': 'expand_input',
+    }
+
+    op = SerializedBaseOperator.deserialize_operator(serialized)
+    assert op.deps is MappedOperator.deps_for(BaseOperator)
+    assert op._disallow_kwargs_override == strict
+
+    # The XComArg can't be deserialized before the DAG is.
+    expand_value = op.expand_input.value
+    assert expand_value == [{"a": "x"}, {"a": _XComRef({"task_id": "op1", "key": XCOM_RETURN_KEY})}]
+
+    serialized_dag: DAG = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
+    resolved_expand_value = serialized_dag.task_dict['task_2'].expand_input.value
+    assert resolved_expand_value == [{"a": "x"}, {"a": PlainXComArg(serialized_dag.task_dict['op1'])}]
+
+
+@pytest.mark.parametrize("strict", [True, False])
+def test_operator_expand_kwargs_xcomarg_serde(strict):
     from airflow.models.xcom_arg import PlainXComArg, XComArg
     from airflow.serialization.serialized_objects import _XComRef