You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink was not preserved in this plain-text copy).
Posted to commits@airflow.apache.org by ur...@apache.org on 2022/08/01 09:12:41 UTC

[airflow] branch main updated: Implement XComArg.zip(*xcom_args) (#25176)

This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new b90fc14e0c Implement XComArg.zip(*xcom_args) (#25176)
b90fc14e0c is described below

commit b90fc14e0c35e679b164ef7bcab22d7a44e0210e
Author: Tzu-ping Chung <ur...@gmail.com>
AuthorDate: Mon Aug 1 17:12:31 2022 +0800

    Implement XComArg.zip(*xcom_args) (#25176)
---
 airflow/models/expandinput.py                 |  88 ++-------
 airflow/models/mappedoperator.py              |   8 +-
 airflow/models/xcom_arg.py                    | 269 +++++++++++++++++++++++---
 airflow/serialization/serialized_objects.py   |  19 +-
 tests/decorators/test_python.py               |   7 +-
 tests/models/test_xcom_arg.py                 |  44 +++++
 tests/models/test_xcom_arg_map.py             |  44 +++++
 tests/serialization/test_dag_serialization.py |  28 +--
 8 files changed, 372 insertions(+), 135 deletions(-)

diff --git a/airflow/models/expandinput.py b/airflow/models/expandinput.py
index 5a94698f9b..1b8f7fa80e 100644
--- a/airflow/models/expandinput.py
+++ b/airflow/models/expandinput.py
@@ -24,13 +24,12 @@ import functools
 import operator
 from typing import TYPE_CHECKING, Any, Iterable, Mapping, NamedTuple, Sequence, Sized, Union
 
-from sqlalchemy import func
-from sqlalchemy.orm import Session
-
 from airflow.compat.functools import cache
 from airflow.utils.context import Context
 
 if TYPE_CHECKING:
+    from sqlalchemy.orm import Session
+
     from airflow.models.xcom_arg import XComArg
 
 ExpandInput = Union["DictOfListsExpandInput", "ListOfDictsExpandInput"]
@@ -95,63 +94,16 @@ class DictOfListsExpandInput(NamedTuple):
         If any arguments are not known right now (upstream task not finished),
         they will not be present in the dict.
         """
-        from airflow.models.taskmap import TaskMap
-        from airflow.models.xcom import XCOM_RETURN_KEY, XCom
         from airflow.models.xcom_arg import XComArg
 
-        # Populate literal mapped arguments first.
-        map_lengths: dict[str, int] = collections.defaultdict(int)
-        map_lengths.update((k, len(v)) for k, v in self.value.items() if not isinstance(v, XComArg))
-
-        try:
-            dag_id = next(v.operator.dag_id for v in self.value.values() if isinstance(v, XComArg))
-        except StopIteration:  # All mapped arguments are literal. We're done.
-            return map_lengths
-
-        # Build a reverse mapping of what arguments each task contributes to.
-        mapped_dep_keys: dict[str, set[str]] = collections.defaultdict(set)
-        non_mapped_dep_keys: dict[str, set[str]] = collections.defaultdict(set)
-        for k, v in self.value.items():
-            if not isinstance(v, XComArg):
-                continue
-            assert v.operator.dag_id == dag_id
-            if v.operator.is_mapped:
-                mapped_dep_keys[v.operator.task_id].add(k)
-            else:
-                non_mapped_dep_keys[v.operator.task_id].add(k)
-            # TODO: It's not possible now, but in the future we may support
-            # depending on one single mapped task instance. When that happens,
-            # we need to further analyze the mapped case to contain only tasks
-            # we depend on "as a whole", and put those we only depend on
-            # individually to the non-mapped lookup.
-
-        # Collect lengths from unmapped upstreams.
-        taskmap_query = session.query(TaskMap.task_id, TaskMap.length).filter(
-            TaskMap.dag_id == dag_id,
-            TaskMap.run_id == run_id,
-            TaskMap.task_id.in_(non_mapped_dep_keys),
-            TaskMap.map_index < 0,
-        )
-        for task_id, length in taskmap_query:
-            for mapped_arg_name in non_mapped_dep_keys[task_id]:
-                map_lengths[mapped_arg_name] += length
-
-        # Collect lengths from mapped upstreams.
-        xcom_query = (
-            session.query(XCom.task_id, func.count(XCom.map_index))
-            .group_by(XCom.task_id)
-            .filter(
-                XCom.dag_id == dag_id,
-                XCom.run_id == run_id,
-                XCom.key == XCOM_RETURN_KEY,
-                XCom.task_id.in_(mapped_dep_keys),
-                XCom.map_index >= 0,
-            )
+        # TODO: This initiates one database call for each XComArg. Would it be
+        # more efficient to do one single db call and unpack the value here?
+        map_lengths_iterator = (
+            (k, (v.get_task_map_length(run_id, session=session) if isinstance(v, XComArg) else len(v)))
+            for k, v in self.value.items()
         )
-        for task_id, length in xcom_query:
-            for mapped_arg_name in mapped_dep_keys[task_id]:
-                map_lengths[mapped_arg_name] += length
 
+        map_lengths = {k: v for k, v in map_lengths_iterator if v is not None}
         if len(map_lengths) < len(self.value):
             raise NotFullyPopulated(set(self.value).difference(map_lengths))
         return map_lengths
@@ -228,28 +180,10 @@ class ListOfDictsExpandInput(NamedTuple):
         return None
 
     def get_total_map_length(self, run_id: str, *, session: Session) -> int:
-        from airflow.models.taskmap import TaskMap
-        from airflow.models.xcom import XCom
-
-        task = self.value.operator
-        if task.is_mapped:
-            query = session.query(func.count(XCom.map_index)).filter(
-                XCom.dag_id == task.dag_id,
-                XCom.run_id == run_id,
-                XCom.task_id == task.task_id,
-                XCom.map_index >= 0,
-            )
-        else:
-            query = session.query(TaskMap.length).filter(
-                TaskMap.dag_id == task.dag_id,
-                TaskMap.run_id == run_id,
-                TaskMap.task_id == task.task_id,
-                TaskMap.map_index < 0,
-            )
-        value = query.scalar()
-        if value is None:
+        length = self.value.get_task_map_length(run_id, session=session)
+        if length is None:
             raise NotFullyPopulated({"expand_kwargs() argument"})
-        return value
+        return length
 
     def resolve(self, context: Context, session: Session) -> Mapping[str, Any]:
         map_index = context["ti"].map_index
diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py
index 53ea072a79..f62d1c687b 100644
--- a/airflow/models/mappedoperator.py
+++ b/airflow/models/mappedoperator.py
@@ -138,8 +138,9 @@ def ensure_xcomarg_return_value(arg: Any) -> None:
     from airflow.models.xcom_arg import XCOM_RETURN_KEY, XComArg
 
     if isinstance(arg, XComArg):
-        if arg.key != XCOM_RETURN_KEY:
-            raise ValueError(f"cannot map over XCom with custom key {arg.key!r} from {arg.operator}")
+        for operator, key in arg.iter_references():
+            if key != XCOM_RETURN_KEY:
+                raise ValueError(f"cannot map over XCom with custom key {key!r} from {operator}")
     elif not is_container(arg):
         return
     elif isinstance(arg, collections.abc.Mapping):
@@ -704,7 +705,8 @@ class MappedOperator(AbstractOperator):
         from airflow.models.xcom_arg import XComArg
 
         for ref in XComArg.iter_xcom_args(self._get_specified_expand_input()):
-            yield ref.operator
+            for operator, _ in ref.iter_references():
+                yield operator
 
     @cached_property
     def parse_time_mapped_ti_count(self) -> Optional[int]:
diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py
index 0a602d0487..a4c2b4d46d 100644
--- a/airflow/models/xcom_arg.py
+++ b/airflow/models/xcom_arg.py
@@ -15,7 +15,25 @@
 # specific language governing permissions and limitations
 # under the License.
 #
-from typing import TYPE_CHECKING, Any, Callable, Iterator, List, Optional, Sequence, Type, Union, overload
+import inspect
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    Iterator,
+    List,
+    Mapping,
+    Optional,
+    Sequence,
+    Tuple,
+    Type,
+    Union,
+    overload,
+)
+
+from sqlalchemy import func
+from sqlalchemy.orm import Session
 
 from airflow.exceptions import AirflowException
 from airflow.models.abstractoperator import AbstractOperator
@@ -27,8 +45,7 @@ from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.types import NOTSET
 
 if TYPE_CHECKING:
-    from sqlalchemy.orm import Session
-
+    from airflow.models.dag import DAG
     from airflow.models.operator import Operator
 
 
@@ -65,9 +82,6 @@ class XComArg(DependencyMixin):
         i.e. the referenced operator's return value.
     """
 
-    operator: "Operator"
-    key: str
-
     @overload
     def __new__(cls: Type["XComArg"], operator: "Operator", key: str = XCOM_RETURN_KEY) -> "XComArg":
         """Called when the user writes ``XComArg(...)`` directly."""
@@ -109,17 +123,18 @@ class XComArg(DependencyMixin):
         sets the relationship to ``op`` on any found.
         """
         for ref in XComArg.iter_xcom_args(arg):
-            op.set_upstream(ref.operator)
+            for operator, _ in ref.iter_references():
+                op.set_upstream(operator)
 
     @property
     def roots(self) -> List[DAGNode]:
         """Required by TaskMixin"""
-        return [self.operator]
+        return [op for op, _ in self.iter_references()]
 
     @property
     def leaves(self) -> List[DAGNode]:
         """Required by TaskMixin"""
-        return [self.operator]
+        return [op for op, _ in self.iter_references()]
 
     def set_upstream(
         self,
@@ -127,7 +142,8 @@ class XComArg(DependencyMixin):
         edge_modifier: Optional[EdgeModifier] = None,
     ):
         """Proxy to underlying operator set_upstream method. Required by TaskMixin."""
-        self.operator.set_upstream(task_or_task_list, edge_modifier)
+        for operator, _ in self.iter_references():
+            operator.set_upstream(task_or_task_list, edge_modifier)
 
     def set_downstream(
         self,
@@ -135,9 +151,51 @@ class XComArg(DependencyMixin):
         edge_modifier: Optional[EdgeModifier] = None,
     ):
         """Proxy to underlying operator set_downstream method. Required by TaskMixin."""
-        self.operator.set_downstream(task_or_task_list, edge_modifier)
+        for operator, _ in self.iter_references():
+            operator.set_downstream(task_or_task_list, edge_modifier)
+
+    def _serialize(self) -> Dict[str, Any]:
+        """Called by DAG serialization.
+
+        The implementation should be the inverse function to ``deserialize``,
+        returning a data dict converted from this XComArg derivative. DAG
+        serialization does not call this directly, but ``serialize_xcom_arg``
+        instead, which adds additional information to dispatch deserialization
+        to the correct class.
+        """
+        raise NotImplementedError()
+
+    @classmethod
+    def _deserialize(cls, data: Dict[str, Any], dag: "DAG") -> "XComArg":
+        """Called when deserializing a DAG.
+
+        The implementation should be the inverse function to ``serialize``,
+        implementing given a data dict converted from this XComArg derivative,
+        how the original XComArg should be created. DAG serialization relies on
+        additional information added in ``serialize_xcom_arg`` to dispatch data
+        dicts to the correct ``_deserialize`` information, so this function does
+        not need to validate whether the incoming data contains correct keys.
+        """
+        raise NotImplementedError()
+
+    def iter_references(self) -> Iterator[Tuple["Operator", str]]:
+        """Iterate through (operator, key) references."""
+        raise NotImplementedError()
 
     def map(self, f: Callable[[Any], Any]) -> "MapXComArg":
+        return MapXComArg(self, [f])
+
+    def zip(self, *others: "XComArg", fillvalue: Any = NOTSET) -> "ZipXComArg":
+        return ZipXComArg([self, *others], fillvalue=fillvalue)
+
+    def get_task_map_length(self, run_id: str, *, session: "Session") -> Optional[int]:
+        """Inspect length of pushed value for task-mapping.
+
+        This is used to determine how many task instances the scheduler should
+        create for a downstream using this XComArg for task-mapping.
+
+        *None* may be returned if the depended XCom has not been pushed.
+        """
         raise NotImplementedError()
 
     def resolve(self, context: Context, session: "Session" = NEW_SESSION) -> Any:
@@ -166,7 +224,7 @@ class PlainXComArg(XComArg):
         self.operator = operator
         self.key = key
 
-    def __eq__(self, other):
+    def __eq__(self, other: Any) -> bool:
         if not isinstance(other, PlainXComArg):
             return NotImplemented
         return self.operator == other.operator and self.key == other.key
@@ -191,7 +249,12 @@ class PlainXComArg(XComArg):
         """
         raise TypeError("'XComArg' object is not iterable")
 
-    def __str__(self):
+    def __repr__(self) -> str:
+        if self.key == XCOM_RETURN_KEY:
+            return f"XComArg({self.operator!r})"
+        return f"XComArg({self.operator!r}, {self.key!r})"
+
+    def __str__(self) -> str:
         """
         Backward compatibility for old-style jinja used in Airflow Operators
 
@@ -203,20 +266,57 @@ class PlainXComArg(XComArg):
         """
         xcom_pull_kwargs = [
             f"task_ids='{self.operator.task_id}'",
-            f"dag_id='{self.operator.dag.dag_id}'",
+            f"dag_id='{self.operator.dag_id}'",
         ]
         if self.key is not None:
             xcom_pull_kwargs.append(f"key='{self.key}'")
 
-        xcom_pull_kwargs = ", ".join(xcom_pull_kwargs)
+        xcom_pull_str = ", ".join(xcom_pull_kwargs)
         # {{{{ are required for escape {{ in f-string
-        xcom_pull = f"{{{{ task_instance.xcom_pull({xcom_pull_kwargs}) }}}}"
+        xcom_pull = f"{{{{ task_instance.xcom_pull({xcom_pull_str}) }}}}"
         return xcom_pull
 
+    def _serialize(self) -> Dict[str, Any]:
+        return {"task_id": self.operator.task_id, "key": self.key}
+
+    @classmethod
+    def _deserialize(cls, data: Dict[str, Any], dag: "DAG") -> XComArg:
+        return cls(dag.get_task(data["task_id"]), data["key"])
+
+    def iter_references(self) -> Iterator[Tuple["Operator", str]]:
+        yield self.operator, self.key
+
     def map(self, f: Callable[[Any], Any]) -> "MapXComArg":
         if self.key != XCOM_RETURN_KEY:
-            raise ValueError
-        return MapXComArg(self, [f])
+            raise ValueError("cannot map against non-return XCom")
+        return super().map(f)
+
+    def zip(self, *others: "XComArg", fillvalue: Any = NOTSET) -> "ZipXComArg":
+        if self.key != XCOM_RETURN_KEY:
+            raise ValueError("cannot map against non-return XCom")
+        return super().zip(*others, fillvalue=fillvalue)
+
+    def get_task_map_length(self, run_id: str, *, session: "Session") -> Optional[int]:
+        from airflow.models.taskmap import TaskMap
+        from airflow.models.xcom import XCom
+
+        task = self.operator
+        if task.is_mapped:
+            query = session.query(func.count(XCom.map_index)).filter(
+                XCom.dag_id == task.dag_id,
+                XCom.run_id == run_id,
+                XCom.task_id == task.task_id,
+                XCom.map_index >= 0,
+                XCom.key == XCOM_RETURN_KEY,
+            )
+        else:
+            query = session.query(TaskMap.length).filter(
+                TaskMap.dag_id == task.dag_id,
+                TaskMap.run_id == run_id,
+                TaskMap.task_id == task.task_id,
+                TaskMap.map_index < 0,
+            )
+        return query.scalar()
 
     @provide_session
     def resolve(self, context: Context, session: "Session" = NEW_SESSION) -> Any:
@@ -257,23 +357,140 @@ class MapXComArg(XComArg):
     convert the pulled XCom value.
     """
 
-    def __init__(self, arg: PlainXComArg, callables: Sequence[Callable[[Any], Any]]) -> None:
+    def __init__(self, arg: XComArg, callables: Sequence[Callable[[Any], Any]]) -> None:
         self.arg = arg
         self.callables = callables
 
-    @property
-    def operator(self) -> "Operator":  # type: ignore[override]
-        return self.arg.operator
+    def __repr__(self) -> str:
+        return f"{self.arg!r}.map([{len(self.callables)} functions])"
 
-    @property
-    def key(self) -> str:  # type: ignore[override]
-        return self.arg.key
+    def _serialize(self) -> Dict[str, Any]:
+        return {
+            "arg": serialize_xcom_arg(self.arg),
+            "callables": [inspect.getsource(c) for c in self.callables],
+        }
+
+    @classmethod
+    def _deserialize(cls, data: Dict[str, Any], dag: "DAG") -> XComArg:
+        # We are deliberately NOT deserializing the callables. These are shown
+        # in the UI, and displaying a function object is useless.
+        return cls(deserialize_xcom_arg(data["arg"], dag), data["callables"])
+
+    def iter_references(self) -> Iterator[Tuple["Operator", str]]:
+        yield from self.arg.iter_references()
 
     def map(self, f: Callable[[Any], Any]) -> "MapXComArg":
+        # Flatten arg.map(f1).map(f2) into one MapXComArg.
         return MapXComArg(self.arg, [*self.callables, f])
 
+    def get_task_map_length(self, run_id: str, *, session: "Session") -> Optional[int]:
+        return self.arg.get_task_map_length(run_id, session=session)
+
     @provide_session
     def resolve(self, context: Context, session: "Session" = NEW_SESSION) -> Any:
         value = self.arg.resolve(context, session=session)
-        assert isinstance(value, (Sequence, dict))  # Validation was done when XCom was pushed.
+        if not isinstance(value, (Sequence, dict)):
+            raise ValueError(f"XCom map expects sequence or dict, not {type(value).__name__}")
         return _MapResult(value, self.callables)
+
+
+class _ZipResult(Sequence):
+    def __init__(self, values: Sequence[Union[Sequence, dict]], *, fillvalue: Any = NOTSET) -> None:
+        self.values = values
+        self.fillvalue = fillvalue
+
+    @staticmethod
+    def _get_or_fill(container: Union[Sequence, dict], index: Any, fillvalue: Any) -> Any:
+        try:
+            return container[index]
+        except (IndexError, KeyError):
+            return fillvalue
+
+    def __getitem__(self, index: Any) -> Any:
+        if index >= len(self):
+            raise IndexError(index)
+        return tuple(self._get_or_fill(value, index, self.fillvalue) for value in self.values)
+
+    def __len__(self) -> int:
+        lengths = (len(v) for v in self.values)
+        if self.fillvalue is NOTSET:
+            return min(lengths)
+        return max(lengths)
+
+
+class ZipXComArg(XComArg):
+    """An XCom reference with ``zip()`` applied.
+
+    This is constructed from multiple XComArg instances, and presents an
+    iterable that "zips" them together like the built-in ``zip()`` (and
+    ``itertools.zip_longest()`` if ``fillvalue`` is provided).
+    """
+
+    def __init__(self, args: Sequence[XComArg], *, fillvalue: Any = NOTSET) -> None:
+        if not args:
+            raise ValueError("At least one input is required")
+        self.args = args
+        self.fillvalue = fillvalue
+
+    def __repr__(self) -> str:
+        args_iter = iter(self.args)
+        first = repr(next(args_iter))
+        rest = ", ".join(repr(arg) for arg in args_iter)
+        if self.fillvalue is NOTSET:
+            return f"{first}.zip({rest})"
+        return f"{first}.zip({rest}, fillvalue={self.fillvalue!r})"
+
+    def _serialize(self) -> Dict[str, Any]:
+        args = [serialize_xcom_arg(arg) for arg in self.args]
+        if self.fillvalue is NOTSET:
+            return {"args": args}
+        return {"args": args, "fillvalue": self.fillvalue}
+
+    @classmethod
+    def _deserialize(cls, data: Dict[str, Any], dag: "DAG") -> XComArg:
+        return cls(
+            [deserialize_xcom_arg(arg, dag) for arg in data["args"]],
+            fillvalue=data.get("fillvalue", NOTSET),
+        )
+
+    def iter_references(self) -> Iterator[Tuple["Operator", str]]:
+        for arg in self.args:
+            yield from arg.iter_references()
+
+    def get_task_map_length(self, run_id: str, *, session: "Session") -> Optional[int]:
+        all_lengths = (arg.get_task_map_length(run_id, session=session) for arg in self.args)
+        ready_lengths = [length for length in all_lengths if length is not None]
+        if len(ready_lengths) != len(self.args):
+            return None  # If any of the referenced XComs is not ready, we are not ready either.
+        if self.fillvalue is NOTSET:
+            return min(ready_lengths)
+        return max(ready_lengths)
+
+    @provide_session
+    def resolve(self, context: Context, session: "Session" = NEW_SESSION) -> Any:
+        values = [arg.resolve(context, session=session) for arg in self.args]
+        for value in values:
+            if not isinstance(value, (Sequence, dict)):
+                raise ValueError(f"XCom zip expects sequence or dict, not {type(value).__name__}")
+        return _ZipResult(values, fillvalue=self.fillvalue)
+
+
+_XCOM_ARG_TYPES: Mapping[str, Type[XComArg]] = {
+    "": PlainXComArg,
+    "map": MapXComArg,
+    "zip": ZipXComArg,
+}
+
+
+def serialize_xcom_arg(value: XComArg) -> Dict[str, Any]:
+    """DAG serialization interface."""
+    key = next(k for k, v in _XCOM_ARG_TYPES.items() if v == type(value))
+    if key:
+        return {"type": key, **value._serialize()}
+    return value._serialize()
+
+
+def deserialize_xcom_arg(data: Dict[str, Any], dag: "DAG") -> XComArg:
+    """DAG serialization interface."""
+    klass = _XCOM_ARG_TYPES[data.get("type", "")]
+    return klass._deserialize(data, dag)
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index e24aa826a4..3ace47e2b0 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -43,7 +43,7 @@ from airflow.models.mappedoperator import MappedOperator
 from airflow.models.operator import Operator
 from airflow.models.param import Param, ParamsDict
 from airflow.models.taskmixin import DAGNode
-from airflow.models.xcom_arg import XComArg
+from airflow.models.xcom_arg import XComArg, deserialize_xcom_arg, serialize_xcom_arg
 from airflow.operators.trigger_dagrun import TriggerDagRunOperator
 from airflow.providers_manager import ProvidersManager
 from airflow.sensors.external_task import ExternalTaskSensor
@@ -202,11 +202,10 @@ class _XComRef(NamedTuple):
     post-process it in ``deserialize_dag``.
     """
 
-    task_id: str
-    key: str
+    data: dict
 
     def deref(self, dag: DAG) -> XComArg:
-        return XComArg(operator=dag.get_task(self.task_id), key=self.key)
+        return deserialize_xcom_arg(self.data, dag)
 
 
 class _ExpandInputRef(NamedTuple):
@@ -393,7 +392,7 @@ class BaseSerialization:
         elif isinstance(var, Param):
             return cls._encode(cls._serialize_param(var), type_=DAT.PARAM)
         elif isinstance(var, XComArg):
-            return cls._encode(cls._serialize_xcomarg(var), type_=DAT.XCOM_REF)
+            return cls._encode(serialize_xcom_arg(var), type_=DAT.XCOM_REF)
         elif isinstance(var, Dataset):
             return cls._encode(dict(uri=var.uri, extra=var.extra), type_=DAT.DATASET)
         else:
@@ -440,7 +439,7 @@ class BaseSerialization:
         elif type_ == DAT.PARAM:
             return cls._deserialize_param(var)
         elif type_ == DAT.XCOM_REF:
-            return cls._deserialize_xcomref(var)
+            return _XComRef(var)  # Delay deserializing XComArg objects until we have the entire DAG.
         elif type_ == DAT.DATASET:
             return Dataset(**var)
         else:
@@ -545,14 +544,6 @@ class BaseSerialization:
 
         return ParamsDict(op_params)
 
-    @classmethod
-    def _serialize_xcomarg(cls, arg: XComArg) -> dict:
-        return {"key": arg.key, "task_id": arg.operator.task_id}
-
-    @classmethod
-    def _deserialize_xcomref(cls, encoded: dict) -> _XComRef:
-        return _XComRef(key=encoded['key'], task_id=encoded['task_id'])
-
 
 class DependencyDetector:
     """
diff --git a/tests/decorators/test_python.py b/tests/decorators/test_python.py
index 58ae1c5f87..199366df8e 100644
--- a/tests/decorators/test_python.py
+++ b/tests/decorators/test_python.py
@@ -32,7 +32,7 @@ from airflow.models.expandinput import DictOfListsExpandInput
 from airflow.models.taskinstance import TaskInstance
 from airflow.models.taskmap import TaskMap
 from airflow.models.xcom import XCOM_RETURN_KEY
-from airflow.models.xcom_arg import XComArg
+from airflow.models.xcom_arg import PlainXComArg, XComArg
 from airflow.utils import timezone
 from airflow.utils.state import State
 from airflow.utils.task_group import TaskGroup
@@ -649,13 +649,16 @@ def test_partial_mapped_decorator() -> None:
 
         product.partial(multiple=2)  # No operator is actually created.
 
+    assert isinstance(doubled, PlainXComArg)
+    assert isinstance(trippled, PlainXComArg)
+    assert isinstance(quadrupled, PlainXComArg)
+
     assert dag.task_dict == {
         "product": quadrupled.operator,
         "product__1": doubled.operator,
         "product__2": trippled.operator,
     }
 
-    assert isinstance(doubled, XComArg)
     assert isinstance(doubled.operator, DecoratedMappedOperator)
     assert doubled.operator.op_kwargs_expand_input == DictOfListsExpandInput({"number": literal})
     assert doubled.operator.partial_kwargs["op_kwargs"] == {"multiple": 2}
diff --git a/tests/models/test_xcom_arg.py b/tests/models/test_xcom_arg.py
index cd3b548285..047412a248 100644
--- a/tests/models/test_xcom_arg.py
+++ b/tests/models/test_xcom_arg.py
@@ -19,6 +19,7 @@ import pytest
 from airflow.models.xcom_arg import XComArg
 from airflow.operators.bash import BashOperator
 from airflow.operators.python import PythonOperator
+from airflow.utils.types import NOTSET
 from tests.test_utils.config import conf_vars
 from tests.test_utils.db import clear_db_dags, clear_db_runs
 
@@ -177,3 +178,46 @@ class TestXComArgRuntime:
             )
             op1 >> op2
         dag.run()
+
+
+@pytest.mark.parametrize(
+    "fillvalue, expected_results",
+    [
+        (NOTSET, {("a", 1), ("b", 2), ("c", 3)}),
+        (None, {("a", 1), ("b", 2), ("c", 3), (None, 4)}),
+    ],
+)
+def test_xcom_zip(dag_maker, session, fillvalue, expected_results):
+    results = set()
+    with dag_maker(session=session) as dag:
+
+        @dag.task
+        def push_letters():
+            return ["a", "b", "c"]
+
+        @dag.task
+        def push_numbers():
+            return [1, 2, 3, 4]
+
+        @dag.task
+        def pull(value):
+            results.add(value)
+
+        pull.expand(value=push_letters().zip(push_numbers(), fillvalue=fillvalue))
+
+    dr = dag_maker.create_dagrun()
+
+    # Run "push_letters" and "push_numbers".
+    decision = dr.task_instance_scheduling_decisions(session=session)
+    assert decision.schedulable_tis and all(ti.task_id.startswith("push_") for ti in decision.schedulable_tis)
+    for ti in decision.schedulable_tis:
+        ti.run(session=session)
+    session.commit()
+
+    # Run "pull".
+    decision = dr.task_instance_scheduling_decisions(session=session)
+    assert decision.schedulable_tis and all(ti.task_id == "pull" for ti in decision.schedulable_tis)
+    for ti in decision.schedulable_tis:
+        ti.run(session=session)
+
+    assert results == expected_results
diff --git a/tests/models/test_xcom_arg_map.py b/tests/models/test_xcom_arg_map.py
index 144eb20327..9da732b93d 100644
--- a/tests/models/test_xcom_arg_map.py
+++ b/tests/models/test_xcom_arg_map.py
@@ -257,3 +257,47 @@ def test_xcom_map_nest(dag_maker, session):
     for ti in decision.schedulable_tis:
         ti.run()
     assert results == {"aa", "bb", "cc"}
+
+
+def test_xcom_map_zip_nest(dag_maker, session):
+    results = set()
+
+    with dag_maker(session=session) as dag:
+
+        @dag.task
+        def push_letters():
+            return ["a", "b", "c", "d"]
+
+        @dag.task
+        def push_numbers():
+            return [1, 2, 3, 4]
+
+        @dag.task
+        def pull(value):
+            results.add(value)
+
+        doubled = push_numbers().map(lambda v: v * 2)
+        combined = doubled.zip(push_letters())
+
+        def convert_zipped(zipped):
+            letter, number = zipped
+            return letter * number
+
+        pull.expand(value=combined.map(convert_zipped))
+
+    dr = dag_maker.create_dagrun()
+
+    # Run "push_letters" and "push_numbers".
+    decision = dr.task_instance_scheduling_decisions(session=session)
+    assert decision.schedulable_tis and all(ti.task_id.startswith("push_") for ti in decision.schedulable_tis)
+    for ti in decision.schedulable_tis:
+        ti.run(session=session)
+    session.commit()
+
+    # Run "pull".
+    decision = dr.task_instance_scheduling_decisions(session=session)
+    assert decision.schedulable_tis and all(ti.task_id == "pull" for ti in decision.schedulable_tis)
+    for ti in decision.schedulable_tis:
+        ti.run(session=session)
+
+    assert results == {"aa", "bbbb", "cccccc", "dddddddd"}
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index 5751ae137c..c6444b8aaa 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -1731,7 +1731,8 @@ def test_operator_expand_serde():
 
 
 def test_operator_expand_xcomarg_serde():
-    from airflow.models.xcom_arg import XComArg
+    from airflow.models.xcom_arg import PlainXComArg, XComArg
+    from airflow.serialization.serialized_objects import _XComRef
 
     with DAG("test-dag", start_date=datetime(2020, 1, 1)) as dag:
         task1 = BaseOperator(task_id="op1")
@@ -1766,20 +1767,21 @@ def test_operator_expand_xcomarg_serde():
     op = SerializedBaseOperator.deserialize_operator(serialized)
     assert op.deps is MappedOperator.deps_for(BaseOperator)
 
-    arg = op.expand_input.value['arg2']
-    assert arg.task_id == 'op1'
-    assert arg.key == XCOM_RETURN_KEY
+    # The XComArg can't be deserialized before the DAG is.
+    xcom_ref = op.expand_input.value['arg2']
+    assert xcom_ref == _XComRef({"task_id": "op1", "key": XCOM_RETURN_KEY})
 
     serialized_dag: DAG = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
 
     xcom_arg = serialized_dag.task_dict['task_2'].expand_input.value['arg2']
-    assert isinstance(xcom_arg, XComArg)
+    assert isinstance(xcom_arg, PlainXComArg)
     assert xcom_arg.operator is serialized_dag.task_dict['op1']
 
 
 @pytest.mark.parametrize("strict", [True, False])
 def test_operator_expand_kwargs_serde(strict):
-    from airflow.models.xcom_arg import XComArg
+    from airflow.models.xcom_arg import PlainXComArg, XComArg
+    from airflow.serialization.serialized_objects import _XComRef
 
     with DAG("test-dag", start_date=datetime(2020, 1, 1)) as dag:
         task1 = BaseOperator(task_id="op1")
@@ -1812,14 +1814,14 @@ def test_operator_expand_kwargs_serde(strict):
     assert op.deps is MappedOperator.deps_for(BaseOperator)
     assert op._disallow_kwargs_override == strict
 
+    # The XComArg can't be deserialized before the DAG is.
     xcom_ref = op.expand_input.value
-    assert xcom_ref.task_id == 'op1'
-    assert xcom_ref.key == XCOM_RETURN_KEY
+    assert xcom_ref == _XComRef({"task_id": "op1", "key": XCOM_RETURN_KEY})
 
     serialized_dag: DAG = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
 
     xcom_arg = serialized_dag.task_dict['task_2'].expand_input.value
-    assert isinstance(xcom_arg, XComArg)
+    assert isinstance(xcom_arg, PlainXComArg)
     assert xcom_arg.operator is serialized_dag.task_dict['op1']
 
 
@@ -1913,7 +1915,7 @@ def test_taskflow_expand_serde():
 
     assert deserialized.op_kwargs_expand_input == _ExpandInputRef(
         key="dict-of-lists",
-        value={"arg2": {"a": 1, "b": 2}, "arg3": _XComRef("op1", XCOM_RETURN_KEY)},
+        value={"arg2": {"a": 1, "b": 2}, "arg3": _XComRef({"task_id": "op1", "key": XCOM_RETURN_KEY})},
     )
     assert deserialized.partial_kwargs == {
         "op_args": [],
@@ -1928,7 +1930,7 @@ def test_taskflow_expand_serde():
     pickled = pickle.loads(pickle.dumps(deserialized))
     assert pickled.op_kwargs_expand_input == _ExpandInputRef(
         key="dict-of-lists",
-        value={"arg2": {"a": 1, "b": 2}, "arg3": _XComRef("op1", XCOM_RETURN_KEY)},
+        value={"arg2": {"a": 1, "b": 2}, "arg3": _XComRef({"task_id": "op1", "key": XCOM_RETURN_KEY})},
     )
     assert pickled.partial_kwargs == {
         "op_args": [],
@@ -1996,7 +1998,7 @@ def test_taskflow_expand_kwargs_serde(strict):
 
     assert deserialized.op_kwargs_expand_input == _ExpandInputRef(
         key="list-of-dicts",
-        value=_XComRef("op1", XCOM_RETURN_KEY),
+        value=_XComRef({"task_id": "op1", "key": XCOM_RETURN_KEY}),
     )
     assert deserialized.partial_kwargs == {
         "op_args": [],
@@ -2011,7 +2013,7 @@ def test_taskflow_expand_kwargs_serde(strict):
     pickled = pickle.loads(pickle.dumps(deserialized))
     assert pickled.op_kwargs_expand_input == _ExpandInputRef(
         "list-of-dicts",
-        _XComRef("op1", XCOM_RETURN_KEY),
+        _XComRef({"task_id": "op1", "key": XCOM_RETURN_KEY}),
     )
     assert pickled.partial_kwargs == {
         "op_args": [],