You are viewing a plain text version of this content. The canonical HTML version is available in the mailing list archive.
Posted to commits@airflow.apache.org by ur...@apache.org on 2022/08/01 09:12:41 UTC
[airflow] branch main updated: Implement XComArg.zip(*xcom_args) (#25176)
This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new b90fc14e0c Implement XComArg.zip(*xcom_args) (#25176)
b90fc14e0c is described below
commit b90fc14e0c35e679b164ef7bcab22d7a44e0210e
Author: Tzu-ping Chung <ur...@gmail.com>
AuthorDate: Mon Aug 1 17:12:31 2022 +0800
Implement XComArg.zip(*xcom_args) (#25176)
---
airflow/models/expandinput.py | 88 ++-------
airflow/models/mappedoperator.py | 8 +-
airflow/models/xcom_arg.py | 269 +++++++++++++++++++++++---
airflow/serialization/serialized_objects.py | 19 +-
tests/decorators/test_python.py | 7 +-
tests/models/test_xcom_arg.py | 44 +++++
tests/models/test_xcom_arg_map.py | 44 +++++
tests/serialization/test_dag_serialization.py | 28 +--
8 files changed, 372 insertions(+), 135 deletions(-)
diff --git a/airflow/models/expandinput.py b/airflow/models/expandinput.py
index 5a94698f9b..1b8f7fa80e 100644
--- a/airflow/models/expandinput.py
+++ b/airflow/models/expandinput.py
@@ -24,13 +24,12 @@ import functools
import operator
from typing import TYPE_CHECKING, Any, Iterable, Mapping, NamedTuple, Sequence, Sized, Union
-from sqlalchemy import func
-from sqlalchemy.orm import Session
-
from airflow.compat.functools import cache
from airflow.utils.context import Context
if TYPE_CHECKING:
+ from sqlalchemy.orm import Session
+
from airflow.models.xcom_arg import XComArg
ExpandInput = Union["DictOfListsExpandInput", "ListOfDictsExpandInput"]
@@ -95,63 +94,16 @@ class DictOfListsExpandInput(NamedTuple):
If any arguments are not known right now (upstream task not finished),
they will not be present in the dict.
"""
- from airflow.models.taskmap import TaskMap
- from airflow.models.xcom import XCOM_RETURN_KEY, XCom
from airflow.models.xcom_arg import XComArg
- # Populate literal mapped arguments first.
- map_lengths: dict[str, int] = collections.defaultdict(int)
- map_lengths.update((k, len(v)) for k, v in self.value.items() if not isinstance(v, XComArg))
-
- try:
- dag_id = next(v.operator.dag_id for v in self.value.values() if isinstance(v, XComArg))
- except StopIteration: # All mapped arguments are literal. We're done.
- return map_lengths
-
- # Build a reverse mapping of what arguments each task contributes to.
- mapped_dep_keys: dict[str, set[str]] = collections.defaultdict(set)
- non_mapped_dep_keys: dict[str, set[str]] = collections.defaultdict(set)
- for k, v in self.value.items():
- if not isinstance(v, XComArg):
- continue
- assert v.operator.dag_id == dag_id
- if v.operator.is_mapped:
- mapped_dep_keys[v.operator.task_id].add(k)
- else:
- non_mapped_dep_keys[v.operator.task_id].add(k)
- # TODO: It's not possible now, but in the future we may support
- # depending on one single mapped task instance. When that happens,
- # we need to further analyze the mapped case to contain only tasks
- # we depend on "as a whole", and put those we only depend on
- # individually to the non-mapped lookup.
-
- # Collect lengths from unmapped upstreams.
- taskmap_query = session.query(TaskMap.task_id, TaskMap.length).filter(
- TaskMap.dag_id == dag_id,
- TaskMap.run_id == run_id,
- TaskMap.task_id.in_(non_mapped_dep_keys),
- TaskMap.map_index < 0,
- )
- for task_id, length in taskmap_query:
- for mapped_arg_name in non_mapped_dep_keys[task_id]:
- map_lengths[mapped_arg_name] += length
-
- # Collect lengths from mapped upstreams.
- xcom_query = (
- session.query(XCom.task_id, func.count(XCom.map_index))
- .group_by(XCom.task_id)
- .filter(
- XCom.dag_id == dag_id,
- XCom.run_id == run_id,
- XCom.key == XCOM_RETURN_KEY,
- XCom.task_id.in_(mapped_dep_keys),
- XCom.map_index >= 0,
- )
+ # TODO: This initiates one database call for each XComArg. Would it be
+ # more efficient to do one single db call and unpack the value here?
+ map_lengths_iterator = (
+ (k, (v.get_task_map_length(run_id, session=session) if isinstance(v, XComArg) else len(v)))
+ for k, v in self.value.items()
)
- for task_id, length in xcom_query:
- for mapped_arg_name in mapped_dep_keys[task_id]:
- map_lengths[mapped_arg_name] += length
+ map_lengths = {k: v for k, v in map_lengths_iterator if v is not None}
if len(map_lengths) < len(self.value):
raise NotFullyPopulated(set(self.value).difference(map_lengths))
return map_lengths
@@ -228,28 +180,10 @@ class ListOfDictsExpandInput(NamedTuple):
return None
def get_total_map_length(self, run_id: str, *, session: Session) -> int:
- from airflow.models.taskmap import TaskMap
- from airflow.models.xcom import XCom
-
- task = self.value.operator
- if task.is_mapped:
- query = session.query(func.count(XCom.map_index)).filter(
- XCom.dag_id == task.dag_id,
- XCom.run_id == run_id,
- XCom.task_id == task.task_id,
- XCom.map_index >= 0,
- )
- else:
- query = session.query(TaskMap.length).filter(
- TaskMap.dag_id == task.dag_id,
- TaskMap.run_id == run_id,
- TaskMap.task_id == task.task_id,
- TaskMap.map_index < 0,
- )
- value = query.scalar()
- if value is None:
+ length = self.value.get_task_map_length(run_id, session=session)
+ if length is None:
raise NotFullyPopulated({"expand_kwargs() argument"})
- return value
+ return length
def resolve(self, context: Context, session: Session) -> Mapping[str, Any]:
map_index = context["ti"].map_index
diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py
index 53ea072a79..f62d1c687b 100644
--- a/airflow/models/mappedoperator.py
+++ b/airflow/models/mappedoperator.py
@@ -138,8 +138,9 @@ def ensure_xcomarg_return_value(arg: Any) -> None:
from airflow.models.xcom_arg import XCOM_RETURN_KEY, XComArg
if isinstance(arg, XComArg):
- if arg.key != XCOM_RETURN_KEY:
- raise ValueError(f"cannot map over XCom with custom key {arg.key!r} from {arg.operator}")
+ for operator, key in arg.iter_references():
+ if key != XCOM_RETURN_KEY:
+ raise ValueError(f"cannot map over XCom with custom key {key!r} from {operator}")
elif not is_container(arg):
return
elif isinstance(arg, collections.abc.Mapping):
@@ -704,7 +705,8 @@ class MappedOperator(AbstractOperator):
from airflow.models.xcom_arg import XComArg
for ref in XComArg.iter_xcom_args(self._get_specified_expand_input()):
- yield ref.operator
+ for operator, _ in ref.iter_references():
+ yield operator
@cached_property
def parse_time_mapped_ti_count(self) -> Optional[int]:
diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py
index 0a602d0487..a4c2b4d46d 100644
--- a/airflow/models/xcom_arg.py
+++ b/airflow/models/xcom_arg.py
@@ -15,7 +15,25 @@
# specific language governing permissions and limitations
# under the License.
#
-from typing import TYPE_CHECKING, Any, Callable, Iterator, List, Optional, Sequence, Type, Union, overload
+import inspect
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Dict,
+ Iterator,
+ List,
+ Mapping,
+ Optional,
+ Sequence,
+ Tuple,
+ Type,
+ Union,
+ overload,
+)
+
+from sqlalchemy import func
+from sqlalchemy.orm import Session
from airflow.exceptions import AirflowException
from airflow.models.abstractoperator import AbstractOperator
@@ -27,8 +45,7 @@ from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.types import NOTSET
if TYPE_CHECKING:
- from sqlalchemy.orm import Session
-
+ from airflow.models.dag import DAG
from airflow.models.operator import Operator
@@ -65,9 +82,6 @@ class XComArg(DependencyMixin):
i.e. the referenced operator's return value.
"""
- operator: "Operator"
- key: str
-
@overload
def __new__(cls: Type["XComArg"], operator: "Operator", key: str = XCOM_RETURN_KEY) -> "XComArg":
"""Called when the user writes ``XComArg(...)`` directly."""
@@ -109,17 +123,18 @@ class XComArg(DependencyMixin):
sets the relationship to ``op`` on any found.
"""
for ref in XComArg.iter_xcom_args(arg):
- op.set_upstream(ref.operator)
+ for operator, _ in ref.iter_references():
+ op.set_upstream(operator)
@property
def roots(self) -> List[DAGNode]:
"""Required by TaskMixin"""
- return [self.operator]
+ return [op for op, _ in self.iter_references()]
@property
def leaves(self) -> List[DAGNode]:
"""Required by TaskMixin"""
- return [self.operator]
+ return [op for op, _ in self.iter_references()]
def set_upstream(
self,
@@ -127,7 +142,8 @@ class XComArg(DependencyMixin):
edge_modifier: Optional[EdgeModifier] = None,
):
"""Proxy to underlying operator set_upstream method. Required by TaskMixin."""
- self.operator.set_upstream(task_or_task_list, edge_modifier)
+ for operator, _ in self.iter_references():
+ operator.set_upstream(task_or_task_list, edge_modifier)
def set_downstream(
self,
@@ -135,9 +151,51 @@ class XComArg(DependencyMixin):
edge_modifier: Optional[EdgeModifier] = None,
):
"""Proxy to underlying operator set_downstream method. Required by TaskMixin."""
- self.operator.set_downstream(task_or_task_list, edge_modifier)
+ for operator, _ in self.iter_references():
+ operator.set_downstream(task_or_task_list, edge_modifier)
+
+ def _serialize(self) -> Dict[str, Any]:
+ """Called by DAG serialization.
+
+ The implementation should be the inverse function to ``deserialize``,
+ returning a data dict converted from this XComArg derivative. DAG
+ serialization does not call this directly, but ``serialize_xcom_arg``
+ instead, which adds additional information to dispatch deserialization
+ to the correct class.
+ """
+ raise NotImplementedError()
+
+ @classmethod
+ def _deserialize(cls, data: Dict[str, Any], dag: "DAG") -> "XComArg":
+ """Called when deserializing a DAG.
+
+ The implementation should be the inverse function to ``serialize``,
+ implementing given a data dict converted from this XComArg derivative,
+ how the original XComArg should be created. DAG serialization relies on
+ additional information added in ``serialize_xcom_arg`` to dispatch data
+ dicts to the correct ``_deserialize`` information, so this function does
+ not need to validate whether the incoming data contains correct keys.
+ """
+ raise NotImplementedError()
+
+ def iter_references(self) -> Iterator[Tuple["Operator", str]]:
+ """Iterate through (operator, key) references."""
+ raise NotImplementedError()
def map(self, f: Callable[[Any], Any]) -> "MapXComArg":
+ return MapXComArg(self, [f])
+
+ def zip(self, *others: "XComArg", fillvalue: Any = NOTSET) -> "ZipXComArg":
+ return ZipXComArg([self, *others], fillvalue=fillvalue)
+
+ def get_task_map_length(self, run_id: str, *, session: "Session") -> Optional[int]:
+ """Inspect length of pushed value for task-mapping.
+
+ This is used to determine how many task instances the scheduler should
+ create for a downstream using this XComArg for task-mapping.
+
+ *None* may be returned if the depended XCom has not been pushed.
+ """
raise NotImplementedError()
def resolve(self, context: Context, session: "Session" = NEW_SESSION) -> Any:
@@ -166,7 +224,7 @@ class PlainXComArg(XComArg):
self.operator = operator
self.key = key
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> bool:
if not isinstance(other, PlainXComArg):
return NotImplemented
return self.operator == other.operator and self.key == other.key
@@ -191,7 +249,12 @@ class PlainXComArg(XComArg):
"""
raise TypeError("'XComArg' object is not iterable")
- def __str__(self):
+ def __repr__(self) -> str:
+ if self.key == XCOM_RETURN_KEY:
+ return f"XComArg({self.operator!r})"
+ return f"XComArg({self.operator!r}, {self.key!r})"
+
+ def __str__(self) -> str:
"""
Backward compatibility for old-style jinja used in Airflow Operators
@@ -203,20 +266,57 @@ class PlainXComArg(XComArg):
"""
xcom_pull_kwargs = [
f"task_ids='{self.operator.task_id}'",
- f"dag_id='{self.operator.dag.dag_id}'",
+ f"dag_id='{self.operator.dag_id}'",
]
if self.key is not None:
xcom_pull_kwargs.append(f"key='{self.key}'")
- xcom_pull_kwargs = ", ".join(xcom_pull_kwargs)
+ xcom_pull_str = ", ".join(xcom_pull_kwargs)
# {{{{ are required for escape {{ in f-string
- xcom_pull = f"{{{{ task_instance.xcom_pull({xcom_pull_kwargs}) }}}}"
+ xcom_pull = f"{{{{ task_instance.xcom_pull({xcom_pull_str}) }}}}"
return xcom_pull
+ def _serialize(self) -> Dict[str, Any]:
+ return {"task_id": self.operator.task_id, "key": self.key}
+
+ @classmethod
+ def _deserialize(cls, data: Dict[str, Any], dag: "DAG") -> XComArg:
+ return cls(dag.get_task(data["task_id"]), data["key"])
+
+ def iter_references(self) -> Iterator[Tuple["Operator", str]]:
+ yield self.operator, self.key
+
def map(self, f: Callable[[Any], Any]) -> "MapXComArg":
if self.key != XCOM_RETURN_KEY:
- raise ValueError
- return MapXComArg(self, [f])
+ raise ValueError("cannot map against non-return XCom")
+ return super().map(f)
+
+ def zip(self, *others: "XComArg", fillvalue: Any = NOTSET) -> "ZipXComArg":
+ if self.key != XCOM_RETURN_KEY:
+ raise ValueError("cannot map against non-return XCom")
+ return super().zip(*others, fillvalue=fillvalue)
+
+ def get_task_map_length(self, run_id: str, *, session: "Session") -> Optional[int]:
+ from airflow.models.taskmap import TaskMap
+ from airflow.models.xcom import XCom
+
+ task = self.operator
+ if task.is_mapped:
+ query = session.query(func.count(XCom.map_index)).filter(
+ XCom.dag_id == task.dag_id,
+ XCom.run_id == run_id,
+ XCom.task_id == task.task_id,
+ XCom.map_index >= 0,
+ XCom.key == XCOM_RETURN_KEY,
+ )
+ else:
+ query = session.query(TaskMap.length).filter(
+ TaskMap.dag_id == task.dag_id,
+ TaskMap.run_id == run_id,
+ TaskMap.task_id == task.task_id,
+ TaskMap.map_index < 0,
+ )
+ return query.scalar()
@provide_session
def resolve(self, context: Context, session: "Session" = NEW_SESSION) -> Any:
@@ -257,23 +357,140 @@ class MapXComArg(XComArg):
convert the pulled XCom value.
"""
- def __init__(self, arg: PlainXComArg, callables: Sequence[Callable[[Any], Any]]) -> None:
+ def __init__(self, arg: XComArg, callables: Sequence[Callable[[Any], Any]]) -> None:
self.arg = arg
self.callables = callables
- @property
- def operator(self) -> "Operator": # type: ignore[override]
- return self.arg.operator
+ def __repr__(self) -> str:
+ return f"{self.arg!r}.map([{len(self.callables)} functions])"
- @property
- def key(self) -> str: # type: ignore[override]
- return self.arg.key
+ def _serialize(self) -> Dict[str, Any]:
+ return {
+ "arg": serialize_xcom_arg(self.arg),
+ "callables": [inspect.getsource(c) for c in self.callables],
+ }
+
+ @classmethod
+ def _deserialize(cls, data: Dict[str, Any], dag: "DAG") -> XComArg:
+ # We are deliberately NOT deserializing the callables. These are shown
+ # in the UI, and displaying a function object is useless.
+ return cls(deserialize_xcom_arg(data["arg"], dag), data["callables"])
+
+ def iter_references(self) -> Iterator[Tuple["Operator", str]]:
+ yield from self.arg.iter_references()
def map(self, f: Callable[[Any], Any]) -> "MapXComArg":
+ # Flatten arg.map(f1).map(f2) into one MapXComArg.
return MapXComArg(self.arg, [*self.callables, f])
+ def get_task_map_length(self, run_id: str, *, session: "Session") -> Optional[int]:
+ return self.arg.get_task_map_length(run_id, session=session)
+
@provide_session
def resolve(self, context: Context, session: "Session" = NEW_SESSION) -> Any:
value = self.arg.resolve(context, session=session)
- assert isinstance(value, (Sequence, dict)) # Validation was done when XCom was pushed.
+ if not isinstance(value, (Sequence, dict)):
+ raise ValueError(f"XCom map expects sequence or dict, not {type(value).__name__}")
return _MapResult(value, self.callables)
+
+
+class _ZipResult(Sequence):
+ def __init__(self, values: Sequence[Union[Sequence, dict]], *, fillvalue: Any = NOTSET) -> None:
+ self.values = values
+ self.fillvalue = fillvalue
+
+ @staticmethod
+ def _get_or_fill(container: Union[Sequence, dict], index: Any, fillvalue: Any) -> Any:
+ try:
+ return container[index]
+ except (IndexError, KeyError):
+ return fillvalue
+
+ def __getitem__(self, index: Any) -> Any:
+ if index >= len(self):
+ raise IndexError(index)
+ return tuple(self._get_or_fill(value, index, self.fillvalue) for value in self.values)
+
+ def __len__(self) -> int:
+ lengths = (len(v) for v in self.values)
+ if self.fillvalue is NOTSET:
+ return min(lengths)
+ return max(lengths)
+
+
+class ZipXComArg(XComArg):
+ """An XCom reference with ``zip()`` applied.
+
+ This is constructed from multiple XComArg instances, and presents an
+ iterable that "zips" them together like the built-in ``zip()`` (and
+ ``itertools.zip_longest()`` if ``fillvalue`` is provided).
+ """
+
+ def __init__(self, args: Sequence[XComArg], *, fillvalue: Any = NOTSET) -> None:
+ if not args:
+ raise ValueError("At least one input is required")
+ self.args = args
+ self.fillvalue = fillvalue
+
+ def __repr__(self) -> str:
+ args_iter = iter(self.args)
+ first = repr(next(args_iter))
+ rest = ", ".join(repr(arg) for arg in args_iter)
+ if self.fillvalue is NOTSET:
+ return f"{first}.zip({rest})"
+ return f"{first}.zip({rest}, fillvalue={self.fillvalue!r})"
+
+ def _serialize(self) -> Dict[str, Any]:
+ args = [serialize_xcom_arg(arg) for arg in self.args]
+ if self.fillvalue is NOTSET:
+ return {"args": args}
+ return {"args": args, "fillvalue": self.fillvalue}
+
+ @classmethod
+ def _deserialize(cls, data: Dict[str, Any], dag: "DAG") -> XComArg:
+ return cls(
+ [deserialize_xcom_arg(arg, dag) for arg in data["args"]],
+ fillvalue=data.get("fillvalue", NOTSET),
+ )
+
+ def iter_references(self) -> Iterator[Tuple["Operator", str]]:
+ for arg in self.args:
+ yield from arg.iter_references()
+
+ def get_task_map_length(self, run_id: str, *, session: "Session") -> Optional[int]:
+ all_lengths = (arg.get_task_map_length(run_id, session=session) for arg in self.args)
+ ready_lengths = [length for length in all_lengths if length is not None]
+ if len(ready_lengths) != len(self.args):
+ return None # If any of the referenced XComs is not ready, we are not ready either.
+ if self.fillvalue is NOTSET:
+ return min(ready_lengths)
+ return max(ready_lengths)
+
+ @provide_session
+ def resolve(self, context: Context, session: "Session" = NEW_SESSION) -> Any:
+ values = [arg.resolve(context, session=session) for arg in self.args]
+ for value in values:
+ if not isinstance(value, (Sequence, dict)):
+ raise ValueError(f"XCom zip expects sequence or dict, not {type(value).__name__}")
+ return _ZipResult(values, fillvalue=self.fillvalue)
+
+
+_XCOM_ARG_TYPES: Mapping[str, Type[XComArg]] = {
+ "": PlainXComArg,
+ "map": MapXComArg,
+ "zip": ZipXComArg,
+}
+
+
+def serialize_xcom_arg(value: XComArg) -> Dict[str, Any]:
+ """DAG serialization interface."""
+ key = next(k for k, v in _XCOM_ARG_TYPES.items() if v == type(value))
+ if key:
+ return {"type": key, **value._serialize()}
+ return value._serialize()
+
+
+def deserialize_xcom_arg(data: Dict[str, Any], dag: "DAG") -> XComArg:
+ """DAG serialization interface."""
+ klass = _XCOM_ARG_TYPES[data.get("type", "")]
+ return klass._deserialize(data, dag)
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index e24aa826a4..3ace47e2b0 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -43,7 +43,7 @@ from airflow.models.mappedoperator import MappedOperator
from airflow.models.operator import Operator
from airflow.models.param import Param, ParamsDict
from airflow.models.taskmixin import DAGNode
-from airflow.models.xcom_arg import XComArg
+from airflow.models.xcom_arg import XComArg, deserialize_xcom_arg, serialize_xcom_arg
from airflow.operators.trigger_dagrun import TriggerDagRunOperator
from airflow.providers_manager import ProvidersManager
from airflow.sensors.external_task import ExternalTaskSensor
@@ -202,11 +202,10 @@ class _XComRef(NamedTuple):
post-process it in ``deserialize_dag``.
"""
- task_id: str
- key: str
+ data: dict
def deref(self, dag: DAG) -> XComArg:
- return XComArg(operator=dag.get_task(self.task_id), key=self.key)
+ return deserialize_xcom_arg(self.data, dag)
class _ExpandInputRef(NamedTuple):
@@ -393,7 +392,7 @@ class BaseSerialization:
elif isinstance(var, Param):
return cls._encode(cls._serialize_param(var), type_=DAT.PARAM)
elif isinstance(var, XComArg):
- return cls._encode(cls._serialize_xcomarg(var), type_=DAT.XCOM_REF)
+ return cls._encode(serialize_xcom_arg(var), type_=DAT.XCOM_REF)
elif isinstance(var, Dataset):
return cls._encode(dict(uri=var.uri, extra=var.extra), type_=DAT.DATASET)
else:
@@ -440,7 +439,7 @@ class BaseSerialization:
elif type_ == DAT.PARAM:
return cls._deserialize_param(var)
elif type_ == DAT.XCOM_REF:
- return cls._deserialize_xcomref(var)
+ return _XComRef(var) # Delay deserializing XComArg objects until we have the entire DAG.
elif type_ == DAT.DATASET:
return Dataset(**var)
else:
@@ -545,14 +544,6 @@ class BaseSerialization:
return ParamsDict(op_params)
- @classmethod
- def _serialize_xcomarg(cls, arg: XComArg) -> dict:
- return {"key": arg.key, "task_id": arg.operator.task_id}
-
- @classmethod
- def _deserialize_xcomref(cls, encoded: dict) -> _XComRef:
- return _XComRef(key=encoded['key'], task_id=encoded['task_id'])
-
class DependencyDetector:
"""
diff --git a/tests/decorators/test_python.py b/tests/decorators/test_python.py
index 58ae1c5f87..199366df8e 100644
--- a/tests/decorators/test_python.py
+++ b/tests/decorators/test_python.py
@@ -32,7 +32,7 @@ from airflow.models.expandinput import DictOfListsExpandInput
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskmap import TaskMap
from airflow.models.xcom import XCOM_RETURN_KEY
-from airflow.models.xcom_arg import XComArg
+from airflow.models.xcom_arg import PlainXComArg, XComArg
from airflow.utils import timezone
from airflow.utils.state import State
from airflow.utils.task_group import TaskGroup
@@ -649,13 +649,16 @@ def test_partial_mapped_decorator() -> None:
product.partial(multiple=2) # No operator is actually created.
+ assert isinstance(doubled, PlainXComArg)
+ assert isinstance(trippled, PlainXComArg)
+ assert isinstance(quadrupled, PlainXComArg)
+
assert dag.task_dict == {
"product": quadrupled.operator,
"product__1": doubled.operator,
"product__2": trippled.operator,
}
- assert isinstance(doubled, XComArg)
assert isinstance(doubled.operator, DecoratedMappedOperator)
assert doubled.operator.op_kwargs_expand_input == DictOfListsExpandInput({"number": literal})
assert doubled.operator.partial_kwargs["op_kwargs"] == {"multiple": 2}
diff --git a/tests/models/test_xcom_arg.py b/tests/models/test_xcom_arg.py
index cd3b548285..047412a248 100644
--- a/tests/models/test_xcom_arg.py
+++ b/tests/models/test_xcom_arg.py
@@ -19,6 +19,7 @@ import pytest
from airflow.models.xcom_arg import XComArg
from airflow.operators.bash import BashOperator
from airflow.operators.python import PythonOperator
+from airflow.utils.types import NOTSET
from tests.test_utils.config import conf_vars
from tests.test_utils.db import clear_db_dags, clear_db_runs
@@ -177,3 +178,46 @@ class TestXComArgRuntime:
)
op1 >> op2
dag.run()
+
+
+@pytest.mark.parametrize(
+ "fillvalue, expected_results",
+ [
+ (NOTSET, {("a", 1), ("b", 2), ("c", 3)}),
+ (None, {("a", 1), ("b", 2), ("c", 3), (None, 4)}),
+ ],
+)
+def test_xcom_zip(dag_maker, session, fillvalue, expected_results):
+ results = set()
+ with dag_maker(session=session) as dag:
+
+ @dag.task
+ def push_letters():
+ return ["a", "b", "c"]
+
+ @dag.task
+ def push_numbers():
+ return [1, 2, 3, 4]
+
+ @dag.task
+ def pull(value):
+ results.add(value)
+
+ pull.expand(value=push_letters().zip(push_numbers(), fillvalue=fillvalue))
+
+ dr = dag_maker.create_dagrun()
+
+ # Run "push_letters" and "push_numbers".
+ decision = dr.task_instance_scheduling_decisions(session=session)
+ assert decision.schedulable_tis and all(ti.task_id.startswith("push_") for ti in decision.schedulable_tis)
+ for ti in decision.schedulable_tis:
+ ti.run(session=session)
+ session.commit()
+
+ # Run "pull".
+ decision = dr.task_instance_scheduling_decisions(session=session)
+ assert decision.schedulable_tis and all(ti.task_id == "pull" for ti in decision.schedulable_tis)
+ for ti in decision.schedulable_tis:
+ ti.run(session=session)
+
+ assert results == expected_results
diff --git a/tests/models/test_xcom_arg_map.py b/tests/models/test_xcom_arg_map.py
index 144eb20327..9da732b93d 100644
--- a/tests/models/test_xcom_arg_map.py
+++ b/tests/models/test_xcom_arg_map.py
@@ -257,3 +257,47 @@ def test_xcom_map_nest(dag_maker, session):
for ti in decision.schedulable_tis:
ti.run()
assert results == {"aa", "bb", "cc"}
+
+
+def test_xcom_map_zip_nest(dag_maker, session):
+ results = set()
+
+ with dag_maker(session=session) as dag:
+
+ @dag.task
+ def push_letters():
+ return ["a", "b", "c", "d"]
+
+ @dag.task
+ def push_numbers():
+ return [1, 2, 3, 4]
+
+ @dag.task
+ def pull(value):
+ results.add(value)
+
+ doubled = push_numbers().map(lambda v: v * 2)
+ combined = doubled.zip(push_letters())
+
+ def convert_zipped(zipped):
+ letter, number = zipped
+ return letter * number
+
+ pull.expand(value=combined.map(convert_zipped))
+
+ dr = dag_maker.create_dagrun()
+
+ # Run "push_letters" and "push_numbers".
+ decision = dr.task_instance_scheduling_decisions(session=session)
+ assert decision.schedulable_tis and all(ti.task_id.startswith("push_") for ti in decision.schedulable_tis)
+ for ti in decision.schedulable_tis:
+ ti.run(session=session)
+ session.commit()
+
+ # Run "pull".
+ decision = dr.task_instance_scheduling_decisions(session=session)
+ assert decision.schedulable_tis and all(ti.task_id == "pull" for ti in decision.schedulable_tis)
+ for ti in decision.schedulable_tis:
+ ti.run(session=session)
+
+ assert results == {"aa", "bbbb", "cccccc", "dddddddd"}
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index 5751ae137c..c6444b8aaa 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -1731,7 +1731,8 @@ def test_operator_expand_serde():
def test_operator_expand_xcomarg_serde():
- from airflow.models.xcom_arg import XComArg
+ from airflow.models.xcom_arg import PlainXComArg, XComArg
+ from airflow.serialization.serialized_objects import _XComRef
with DAG("test-dag", start_date=datetime(2020, 1, 1)) as dag:
task1 = BaseOperator(task_id="op1")
@@ -1766,20 +1767,21 @@ def test_operator_expand_xcomarg_serde():
op = SerializedBaseOperator.deserialize_operator(serialized)
assert op.deps is MappedOperator.deps_for(BaseOperator)
- arg = op.expand_input.value['arg2']
- assert arg.task_id == 'op1'
- assert arg.key == XCOM_RETURN_KEY
+ # The XComArg can't be deserialized before the DAG is.
+ xcom_ref = op.expand_input.value['arg2']
+ assert xcom_ref == _XComRef({"task_id": "op1", "key": XCOM_RETURN_KEY})
serialized_dag: DAG = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
xcom_arg = serialized_dag.task_dict['task_2'].expand_input.value['arg2']
- assert isinstance(xcom_arg, XComArg)
+ assert isinstance(xcom_arg, PlainXComArg)
assert xcom_arg.operator is serialized_dag.task_dict['op1']
@pytest.mark.parametrize("strict", [True, False])
def test_operator_expand_kwargs_serde(strict):
- from airflow.models.xcom_arg import XComArg
+ from airflow.models.xcom_arg import PlainXComArg, XComArg
+ from airflow.serialization.serialized_objects import _XComRef
with DAG("test-dag", start_date=datetime(2020, 1, 1)) as dag:
task1 = BaseOperator(task_id="op1")
@@ -1812,14 +1814,14 @@ def test_operator_expand_kwargs_serde(strict):
assert op.deps is MappedOperator.deps_for(BaseOperator)
assert op._disallow_kwargs_override == strict
+ # The XComArg can't be deserialized before the DAG is.
xcom_ref = op.expand_input.value
- assert xcom_ref.task_id == 'op1'
- assert xcom_ref.key == XCOM_RETURN_KEY
+ assert xcom_ref == _XComRef({"task_id": "op1", "key": XCOM_RETURN_KEY})
serialized_dag: DAG = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
xcom_arg = serialized_dag.task_dict['task_2'].expand_input.value
- assert isinstance(xcom_arg, XComArg)
+ assert isinstance(xcom_arg, PlainXComArg)
assert xcom_arg.operator is serialized_dag.task_dict['op1']
@@ -1913,7 +1915,7 @@ def test_taskflow_expand_serde():
assert deserialized.op_kwargs_expand_input == _ExpandInputRef(
key="dict-of-lists",
- value={"arg2": {"a": 1, "b": 2}, "arg3": _XComRef("op1", XCOM_RETURN_KEY)},
+ value={"arg2": {"a": 1, "b": 2}, "arg3": _XComRef({"task_id": "op1", "key": XCOM_RETURN_KEY})},
)
assert deserialized.partial_kwargs == {
"op_args": [],
@@ -1928,7 +1930,7 @@ def test_taskflow_expand_serde():
pickled = pickle.loads(pickle.dumps(deserialized))
assert pickled.op_kwargs_expand_input == _ExpandInputRef(
key="dict-of-lists",
- value={"arg2": {"a": 1, "b": 2}, "arg3": _XComRef("op1", XCOM_RETURN_KEY)},
+ value={"arg2": {"a": 1, "b": 2}, "arg3": _XComRef({"task_id": "op1", "key": XCOM_RETURN_KEY})},
)
assert pickled.partial_kwargs == {
"op_args": [],
@@ -1996,7 +1998,7 @@ def test_taskflow_expand_kwargs_serde(strict):
assert deserialized.op_kwargs_expand_input == _ExpandInputRef(
key="list-of-dicts",
- value=_XComRef("op1", XCOM_RETURN_KEY),
+ value=_XComRef({"task_id": "op1", "key": XCOM_RETURN_KEY}),
)
assert deserialized.partial_kwargs == {
"op_args": [],
@@ -2011,7 +2013,7 @@ def test_taskflow_expand_kwargs_serde(strict):
pickled = pickle.loads(pickle.dumps(deserialized))
assert pickled.op_kwargs_expand_input == _ExpandInputRef(
"list-of-dicts",
- _XComRef("op1", XCOM_RETURN_KEY),
+ _XComRef({"task_id": "op1", "key": XCOM_RETURN_KEY}),
)
assert pickled.partial_kwargs == {
"op_args": [],