You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this copy.)
Posted to commits@airflow.apache.org by ur...@apache.org on 2022/07/25 03:41:09 UTC
[airflow] branch main updated: Implement map() semantic (#25085)
This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 877dc89004 Implement map() semantic (#25085)
877dc89004 is described below
commit 877dc890047357b664c5f9ee97d58653ccc24237
Author: Tzu-ping Chung <ur...@gmail.com>
AuthorDate: Mon Jul 25 11:41:00 2022 +0800
Implement map() semantic (#25085)
---
airflow/exceptions.py | 12 -
airflow/models/expandinput.py | 18 --
airflow/models/mappedoperator.py | 15 --
airflow/models/taskinstance.py | 13 +-
airflow/models/xcom_arg.py | 219 ++++++++++++-----
.../pre_commit_base_operator_partial_arguments.py | 1 -
tests/models/test_taskinstance.py | 3 -
tests/models/test_xcom_arg_map.py | 259 +++++++++++++++++++++
8 files changed, 423 insertions(+), 117 deletions(-)
diff --git a/airflow/exceptions.py b/airflow/exceptions.py
index 7a91100f11..924d326ca9 100644
--- a/airflow/exceptions.py
+++ b/airflow/exceptions.py
@@ -123,18 +123,6 @@ class UnmappableXComTypePushed(AirflowException):
return f"unmappable return type {typename!r}"
-class UnmappableXComValuePushed(AirflowException):
- """Raise when an invalid value is pushed as a mapped downstream's dependency."""
-
- def __init__(self, value: Any, reason: str) -> None:
- super().__init__(value, reason)
- self.value = value
- self.reason = reason
-
- def __str__(self) -> str:
- return f"unmappable return value {self.value!r} ({self.reason})"
-
-
class UnmappableXComLengthPushed(AirflowException):
"""Raise when the pushed value is too large to map as a downstream's dependency."""
diff --git a/airflow/models/expandinput.py b/airflow/models/expandinput.py
index b5b922f9df..7aab2446e2 100644
--- a/airflow/models/expandinput.py
+++ b/airflow/models/expandinput.py
@@ -28,7 +28,6 @@ from sqlalchemy import func
from sqlalchemy.orm import Session
from airflow.compat.functools import cache
-from airflow.exceptions import UnmappableXComTypePushed, UnmappableXComValuePushed
from airflow.utils.context import Context
if TYPE_CHECKING:
@@ -72,11 +71,6 @@ class DictOfListsExpandInput(NamedTuple):
value: dict[str, Mappable]
- @staticmethod
- def validate_xcom(value: Any) -> None:
- if not isinstance(value, collections.abc.Collection) or isinstance(value, (bytes, str)):
- raise UnmappableXComTypePushed(value)
-
def get_unresolved_kwargs(self) -> dict[str, Any]:
"""Get the kwargs dict that can be inferred without resolving."""
return self.value
@@ -212,18 +206,6 @@ class ListOfDictsExpandInput(NamedTuple):
value: XComArg
- @staticmethod
- def validate_xcom(value: Any) -> None:
- if not isinstance(value, collections.abc.Collection):
- raise UnmappableXComTypePushed(value)
- if isinstance(value, (str, bytes, collections.abc.Mapping)):
- raise UnmappableXComTypePushed(value)
- for item in value:
- if not isinstance(item, collections.abc.Mapping):
- raise UnmappableXComTypePushed(value, item)
- if not all(isinstance(k, str) for k in item):
- raise UnmappableXComValuePushed(value, reason="dict keys must be str")
-
def get_unresolved_kwargs(self) -> dict[str, Any]:
"""Get the kwargs dict that can be inferred without resolving.
diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py
index a883ff2404..dead62753c 100644
--- a/airflow/models/mappedoperator.py
+++ b/airflow/models/mappedoperator.py
@@ -23,7 +23,6 @@ import warnings
from typing import (
TYPE_CHECKING,
Any,
- Callable,
ClassVar,
Collection,
Dict,
@@ -613,20 +612,6 @@ class MappedOperator(AbstractOperator):
"""Input received from the expand call on the operator."""
return getattr(self, self._expand_input_attr)
- @property
- def validate_upstream_return_value(self) -> Callable[[Any], None]:
- """Validate an upstream's return value satisfies this task's needs.
-
- This is implemented as a property (instead of a function calling
- ``validate_xcom``) so the call site in TaskInstance can de-duplicate
- validation functions. If this is an instance method, each
- ``validate_upstream_return_value`` would be a different object (due to
- how Python handles bounded functions), and de-duplication won't work.
-
- :meta private:
- """
- return self._get_specified_expand_input().validate_xcom
-
def expand_mapped_task(self, run_id: str, *, session: Session) -> Tuple[Sequence["TaskInstance"], int]:
"""Create the mapped task instances for mapped task.
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index bb4eff6faf..725070500c 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -92,6 +92,7 @@ from airflow.exceptions import (
TaskDeferralError,
TaskDeferred,
UnmappableXComLengthPushed,
+ UnmappableXComTypePushed,
XComForMappingNotPushed,
)
from airflow.models.base import Base, StringID
@@ -2330,8 +2331,7 @@ class TaskInstance(Base, LoggingMixin):
self.log.debug("Task Duration set to %s", self.duration)
def _record_task_map_for_downstreams(self, task: "Operator", value: Any, *, session: Session) -> None:
- validators = {m.validate_upstream_return_value for m in task.iter_mapped_dependants()}
- if not validators: # No mapped dependants, no need to validate.
+ if next(task.iter_mapped_dependants(), None) is None: # No mapped dependants, no need to validate.
return
# TODO: We don't push TaskMap for mapped task instances because it's not
# currently possible for a downstream to depend on one individual mapped
@@ -2341,9 +2341,12 @@ class TaskInstance(Base, LoggingMixin):
return
if value is None:
raise XComForMappingNotPushed()
- for validator in validators:
- validator(value)
- assert isinstance(value, collections.abc.Collection) # The validators type-guard this.
+ if not isinstance(value, (collections.abc.Sequence, dict)):
+ raise UnmappableXComTypePushed(value)
+ if isinstance(value, (bytes, str)):
+ raise UnmappableXComTypePushed(value)
+ if TYPE_CHECKING: # The isinstance() checks above guard this.
+ assert isinstance(value, collections.abc.Collection)
task_map = TaskMap.from_task_instance_xcom(self, value)
max_map_length = conf.getint("core", "max_map_length", fallback=1024)
if task_map.length > max_map_length:
diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py
index 2c2a8f9b58..0a602d0487 100644
--- a/airflow/models/xcom_arg.py
+++ b/airflow/models/xcom_arg.py
@@ -14,7 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Sequence, Union
+#
+from typing import TYPE_CHECKING, Any, Callable, Iterator, List, Optional, Sequence, Type, Union, overload
from airflow.exceptions import AirflowException
from airflow.models.abstractoperator import AbstractOperator
@@ -32,15 +33,14 @@ if TYPE_CHECKING:
class XComArg(DependencyMixin):
- """
- Class that represents a XCom push from a previous operator.
- Defaults to "return_value" as only key.
+ """Reference to an XCom value pushed from another operator.
+
+ The implementation supports::
- Current implementation supports
xcomarg >> op
xcomarg << op
- op >> xcomarg (by BaseOperator code)
- op << xcomarg (by BaseOperator code)
+ op >> xcomarg # By BaseOperator code
+ op << xcomarg # By BaseOperator code
**Example**: The moment you get a result from any operator (decorated or regular) you can ::
@@ -53,15 +53,113 @@ class XComArg(DependencyMixin):
This object can be used in legacy Operators via Jinja.
- **Example**: You can make this result to be part of any generated string ::
+ **Example**: You can make this result to be part of any generated string::
any_op = AnyOperator()
xcomarg = any_op.output
op1 = MyOperator(my_text_message=f"the value is {xcomarg}")
op2 = MyOperator(my_text_message=f"the value is {xcomarg['topic']}")
- :param operator: operator to which the XComArg belongs to
- :param key: key value which is used for xcom_pull (key in the XCom table)
+ :param operator: Operator instance to which the XComArg references.
+ :param key: Key used to pull the XCom value. Defaults to *XCOM_RETURN_KEY*,
+ i.e. the referenced operator's return value.
+ """
+
+ operator: "Operator"
+ key: str
+
+ @overload
+ def __new__(cls: Type["XComArg"], operator: "Operator", key: str = XCOM_RETURN_KEY) -> "XComArg":
+ """Called when the user writes ``XComArg(...)`` directly."""
+
+ @overload
+ def __new__(cls: Type["XComArg"]) -> "XComArg":
+ """Called by Python internals from subclasses."""
+
+ def __new__(cls, *args, **kwargs) -> "XComArg":
+ if cls is XComArg:
+ return PlainXComArg(*args, **kwargs)
+ return super().__new__(cls)
+
+ @staticmethod
+ def iter_xcom_args(arg: Any) -> Iterator["XComArg"]:
+ """Return XComArg instances in an arbitrary value.
+
+ Recursively traverse ``arg`` and look for XComArg instances in any
+ collection objects, and instances with ``template_fields`` set.
+ """
+ if isinstance(arg, XComArg):
+ yield arg
+ elif isinstance(arg, (tuple, set, list)):
+ for elem in arg:
+ yield from XComArg.iter_xcom_args(elem)
+ elif isinstance(arg, dict):
+ for elem in arg.values():
+ yield from XComArg.iter_xcom_args(elem)
+ elif isinstance(arg, AbstractOperator):
+ for elem in arg.template_fields:
+ yield from XComArg.iter_xcom_args(elem)
+
+ @staticmethod
+ def apply_upstream_relationship(op: "Operator", arg: Any):
+ """Set dependency for XComArgs.
+
+ This looks for XComArg objects in ``arg`` "deeply" (looking inside
+ collections objects and classes decorated with ``template_fields``), and
+ sets the relationship to ``op`` on any found.
+ """
+ for ref in XComArg.iter_xcom_args(arg):
+ op.set_upstream(ref.operator)
+
+ @property
+ def roots(self) -> List[DAGNode]:
+ """Required by TaskMixin"""
+ return [self.operator]
+
+ @property
+ def leaves(self) -> List[DAGNode]:
+ """Required by TaskMixin"""
+ return [self.operator]
+
+ def set_upstream(
+ self,
+ task_or_task_list: Union[DependencyMixin, Sequence[DependencyMixin]],
+ edge_modifier: Optional[EdgeModifier] = None,
+ ):
+ """Proxy to underlying operator set_upstream method. Required by TaskMixin."""
+ self.operator.set_upstream(task_or_task_list, edge_modifier)
+
+ def set_downstream(
+ self,
+ task_or_task_list: Union[DependencyMixin, Sequence[DependencyMixin]],
+ edge_modifier: Optional[EdgeModifier] = None,
+ ):
+ """Proxy to underlying operator set_downstream method. Required by TaskMixin."""
+ self.operator.set_downstream(task_or_task_list, edge_modifier)
+
+ def map(self, f: Callable[[Any], Any]) -> "MapXComArg":
+ raise NotImplementedError()
+
+ def resolve(self, context: Context, session: "Session" = NEW_SESSION) -> Any:
+ raise NotImplementedError()
+
+
+class PlainXComArg(XComArg):
+ """Reference to one single XCom without any additional semantics.
+
+ This class should not be accessed directly, but only through XComArg. The
+ class inheritance chain and ``__new__`` is implemented in this slightly
+ convoluted way because we want to
+
+ a. Allow the user to continue using XComArg directly for the simple
+ semantics (see documentation of the base class for details).
+ b. Make ``isinstance(thing, XComArg)`` be able to detect all kinds of XCom
+ references.
+ c. Not allow many properties of PlainXComArg (including ``__getitem__`` and
+ ``__str__``) to exist on other kinds of XComArg implementations since
+ they don't make sense.
+
+ :meta private:
"""
def __init__(self, operator: "Operator", key: str = XCOM_RETURN_KEY):
@@ -69,13 +167,15 @@ class XComArg(DependencyMixin):
self.key = key
def __eq__(self, other):
+ if not isinstance(other, PlainXComArg):
+ return NotImplemented
return self.operator == other.operator and self.key == other.key
def __getitem__(self, item: str) -> "XComArg":
"""Implements xcomresult['some_result_key']"""
if not isinstance(item, str):
raise ValueError(f"XComArg only supports str lookup, received {type(item).__name__}")
- return XComArg(operator=self.operator, key=item)
+ return PlainXComArg(operator=self.operator, key=item)
def __iter__(self):
"""Override iterable protocol to raise error explicitly.
@@ -89,7 +189,7 @@ class XComArg(DependencyMixin):
This override catches the error eagerly, so an incorrectly implemented
DAG fails fast and avoids wasting resources on nonsensical iterating.
"""
- raise TypeError(f"{self.__class__.__name__!r} object is not iterable")
+ raise TypeError("'XComArg' object is not iterable")
def __str__(self):
"""
@@ -113,31 +213,10 @@ class XComArg(DependencyMixin):
xcom_pull = f"{{{{ task_instance.xcom_pull({xcom_pull_kwargs}) }}}}"
return xcom_pull
- @property
- def roots(self) -> List[DAGNode]:
- """Required by TaskMixin"""
- return [self.operator]
-
- @property
- def leaves(self) -> List[DAGNode]:
- """Required by TaskMixin"""
- return [self.operator]
-
- def set_upstream(
- self,
- task_or_task_list: Union[DependencyMixin, Sequence[DependencyMixin]],
- edge_modifier: Optional[EdgeModifier] = None,
- ):
- """Proxy to underlying operator set_upstream method. Required by TaskMixin."""
- self.operator.set_upstream(task_or_task_list, edge_modifier)
-
- def set_downstream(
- self,
- task_or_task_list: Union[DependencyMixin, Sequence[DependencyMixin]],
- edge_modifier: Optional[EdgeModifier] = None,
- ):
- """Proxy to underlying operator set_downstream method. Required by TaskMixin."""
- self.operator.set_downstream(task_or_task_list, edge_modifier)
+ def map(self, f: Callable[[Any], Any]) -> "MapXComArg":
+ if self.key != XCOM_RETURN_KEY:
+ raise ValueError
+ return MapXComArg(self, [f])
@provide_session
def resolve(self, context: Context, session: "Session" = NEW_SESSION) -> Any:
@@ -155,32 +234,46 @@ class XComArg(DependencyMixin):
)
return result
- @staticmethod
- def iter_xcom_args(arg: Any) -> Iterator["XComArg"]:
- """Return XComArg instances in an arbitrary value.
- This recursively traverse ``arg`` and look for XComArg instances in any
- collection objects, and instances with ``template_fields`` set.
- """
- if isinstance(arg, XComArg):
- yield arg
- elif isinstance(arg, (tuple, set, list)):
- for elem in arg:
- yield from XComArg.iter_xcom_args(elem)
- elif isinstance(arg, dict):
- for elem in arg.values():
- yield from XComArg.iter_xcom_args(elem)
- elif isinstance(arg, AbstractOperator):
- for elem in arg.template_fields:
- yield from XComArg.iter_xcom_args(elem)
+class _MapResult(Sequence):
+ def __init__(self, value: Union[Sequence, dict], callables: Sequence[Callable[[Any], Any]]) -> None:
+ self.value = value
+ self.callables = callables
- @staticmethod
- def apply_upstream_relationship(op: "Operator", arg: Any):
- """Set dependency for XComArgs.
+ def __getitem__(self, index: Any) -> Any:
+ value = self.value[index]
+ for f in self.callables:
+ value = f(value)
+ return value
- This looks for XComArg objects in ``arg`` "deeply" (looking inside
- collections objects and classes decorated with ``template_fields``), and
- sets the relationship to ``op`` on any found.
- """
- for ref in XComArg.iter_xcom_args(arg):
- op.set_upstream(ref.operator)
+ def __len__(self) -> int:
+ return len(self.value)
+
+
+class MapXComArg(XComArg):
+ """An XCom reference with ``map()`` call(s) applied.
+
+ This is based on an XComArg, but also applies a series of "transforms" that
+ convert the pulled XCom value.
+ """
+
+ def __init__(self, arg: PlainXComArg, callables: Sequence[Callable[[Any], Any]]) -> None:
+ self.arg = arg
+ self.callables = callables
+
+ @property
+ def operator(self) -> "Operator": # type: ignore[override]
+ return self.arg.operator
+
+ @property
+ def key(self) -> str: # type: ignore[override]
+ return self.arg.key
+
+ def map(self, f: Callable[[Any], Any]) -> "MapXComArg":
+ return MapXComArg(self.arg, [*self.callables, f])
+
+ @provide_session
+ def resolve(self, context: Context, session: "Session" = NEW_SESSION) -> Any:
+ value = self.arg.resolve(context, session=session)
+ assert isinstance(value, (Sequence, dict)) # Validation was done when XCom was pushed.
+ return _MapResult(value, self.callables)
diff --git a/scripts/ci/pre_commit/pre_commit_base_operator_partial_arguments.py b/scripts/ci/pre_commit/pre_commit_base_operator_partial_arguments.py
index 909693de5a..29d6c7df4d 100755
--- a/scripts/ci/pre_commit/pre_commit_base_operator_partial_arguments.py
+++ b/scripts/ci/pre_commit/pre_commit_base_operator_partial_arguments.py
@@ -52,7 +52,6 @@ IGNORED = {
# Only on MappedOperator.
"expand_input",
"partial_kwargs",
- "validate_upstream_return_value",
}
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index f0b53fbf5d..87ac9b69cc 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -43,7 +43,6 @@ from airflow.exceptions import (
AirflowSkipException,
UnmappableXComLengthPushed,
UnmappableXComTypePushed,
- UnmappableXComValuePushed,
XComForMappingNotPushed,
)
from airflow.models import (
@@ -2614,8 +2613,6 @@ class TestTaskInstanceRecordTaskMapXComPush:
"return_value, exception_type, error_message",
[
(123, UnmappableXComTypePushed, "unmappable return type 'int'"),
- ([123], UnmappableXComTypePushed, "unmappable return type 'list[int]'"),
- ([{1: 3}], UnmappableXComValuePushed, "unmappable return value [{1: 3}] (dict keys must be str)"),
(None, XComForMappingNotPushed, "did not push XCom for task mapping"),
],
)
diff --git a/tests/models/test_xcom_arg_map.py b/tests/models/test_xcom_arg_map.py
new file mode 100644
index 0000000000..5807f1b296
--- /dev/null
+++ b/tests/models/test_xcom_arg_map.py
@@ -0,0 +1,259 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+import pytest
+
+from airflow.exceptions import AirflowSkipException
+from airflow.utils.state import TaskInstanceState
+from airflow.utils.trigger_rule import TriggerRule
+
+
+def test_xcom_map(dag_maker, session):
+ results = set()
+ with dag_maker(session=session) as dag:
+
+ @dag.task
+ def push():
+ return ["a", "b", "c"]
+
+ @dag.task
+ def pull(value):
+ results.add(value)
+
+ pull.expand_kwargs(push().map(lambda v: {"value": v * 2}))
+
+ # The function passed to "map" is *NOT* a task.
+ assert set(dag.task_dict) == {"push", "pull"}
+
+ dr = dag_maker.create_dagrun()
+
+ # Run "push".
+ decision = dr.task_instance_scheduling_decisions(session=session)
+ for ti in decision.schedulable_tis:
+ ti.run(session=session)
+ session.commit()
+
+ # Run "pull".
+ decision = dr.task_instance_scheduling_decisions(session=session)
+ tis = {(ti.task_id, ti.map_index): ti for ti in decision.schedulable_tis}
+ assert sorted(tis) == [("pull", 0), ("pull", 1), ("pull", 2)]
+ for ti in tis.values():
+ ti.run(session=session)
+
+ assert results == {"aa", "bb", "cc"}
+
+
+def test_xcom_map_transform_to_none(dag_maker, session):
+ results = set()
+
+ with dag_maker(session=session) as dag:
+
+ @dag.task()
+ def push():
+ return ["a", "b", "c"]
+
+ @dag.task()
+ def pull(value):
+ results.add(value)
+
+ def c_to_none(v):
+ if v == "c":
+ return None
+ return v
+
+ pull.expand(value=push().map(c_to_none))
+
+ dr = dag_maker.create_dagrun()
+
+ # Run "push".
+ decision = dr.task_instance_scheduling_decisions(session=session)
+ for ti in decision.schedulable_tis:
+ ti.run()
+
+ # Run "pull". This should automatically convert "c" to None.
+ decision = dr.task_instance_scheduling_decisions(session=session)
+ for ti in decision.schedulable_tis:
+ ti.run()
+ assert results == {"a", "b", None}
+
+
+def test_xcom_convert_to_kwargs_fails_task(dag_maker, session):
+ results = set()
+
+ with dag_maker(session=session) as dag:
+
+ @dag.task()
+ def push():
+ return ["a", "b", "c"]
+
+ @dag.task()
+ def pull(value):
+ results.add(value)
+
+ def c_to_none(v):
+ if v == "c":
+ return None
+ return {"value": v}
+
+ pull.expand_kwargs(push().map(c_to_none))
+
+ dr = dag_maker.create_dagrun()
+
+ # Run "push".
+ decision = dr.task_instance_scheduling_decisions(session=session)
+ for ti in decision.schedulable_tis:
+ ti.run()
+
+ # Prepare to run "pull"...
+ decision = dr.task_instance_scheduling_decisions(session=session)
+ tis = {(ti.task_id, ti.map_index): ti for ti in decision.schedulable_tis}
+
+ # The first two "pull" tis should also succeed.
+ tis[("pull", 0)].run()
+ tis[("pull", 1)].run()
+
+ # But the third one fails because the map() result cannot be used as kwargs.
+ with pytest.raises(TypeError) as ctx:
+ tis[("pull", 2)].run()
+ assert str(ctx.value) == "'NoneType' object is not iterable"
+
+ assert [tis[("pull", i)].state for i in range(3)] == [
+ TaskInstanceState.SUCCESS,
+ TaskInstanceState.SUCCESS,
+ TaskInstanceState.FAILED,
+ ]
+
+
+def test_xcom_map_error_fails_task(dag_maker, session):
+ with dag_maker(session=session) as dag:
+
+ @dag.task()
+ def push():
+ return ["a", "b", "c"]
+
+ @dag.task()
+ def pull(value):
+ print(value)
+
+ def does_not_work_with_c(v):
+ if v == "c":
+ raise ValueError("nope")
+ return {"value": v * 2}
+
+ pull.expand_kwargs(push().map(does_not_work_with_c))
+
+ dr = dag_maker.create_dagrun()
+
+ # The "push" task should not fail.
+ decision = dr.task_instance_scheduling_decisions(session=session)
+ for ti in decision.schedulable_tis:
+ ti.run()
+ assert [ti.state for ti in decision.schedulable_tis] == [TaskInstanceState.SUCCESS]
+
+ # Prepare to run "pull"...
+ decision = dr.task_instance_scheduling_decisions(session=session)
+ tis = {(ti.task_id, ti.map_index): ti for ti in decision.schedulable_tis}
+
+ # The first two "pull" tis should also succeed.
+ tis[("pull", 0)].run()
+ tis[("pull", 1)].run()
+
+ # But the third one (for "c") will fail.
+ with pytest.raises(ValueError) as ctx:
+ tis[("pull", 2)].run()
+ assert str(ctx.value) == "nope"
+
+ assert [tis[("pull", i)].state for i in range(3)] == [
+ TaskInstanceState.SUCCESS,
+ TaskInstanceState.SUCCESS,
+ TaskInstanceState.FAILED,
+ ]
+
+
+def test_xcom_map_raise_to_skip(dag_maker, session):
+ result = None
+
+ with dag_maker(session=session) as dag:
+
+ @dag.task()
+ def push():
+ return ["a", "b", "c"]
+
+ @dag.task()
+ def forward(value):
+ return value
+
+ @dag.task(trigger_rule=TriggerRule.ALL_DONE)
+ def collect(value):
+ nonlocal result
+ result = list(value)
+
+ def skip_c(v):
+ if v == "c":
+ raise AirflowSkipException
+ return {"value": v}
+
+ collect(value=forward.expand_kwargs(push().map(skip_c)))
+
+ dr = dag_maker.create_dagrun()
+
+ # Run "push".
+ decision = dr.task_instance_scheduling_decisions(session=session)
+ for ti in decision.schedulable_tis:
+ ti.run()
+
+ # Run "forward". This should automatically skip "c".
+ decision = dr.task_instance_scheduling_decisions(session=session)
+ for ti in decision.schedulable_tis:
+ ti.run()
+
+ # Now "collect" should only get "a" and "b".
+ decision = dr.task_instance_scheduling_decisions(session=session)
+ for ti in decision.schedulable_tis:
+ ti.run()
+ assert result == ["a", "b"]
+
+
+def test_xcom_map_nest(dag_maker, session):
+ results = set()
+
+ with dag_maker(session=session) as dag:
+
+ @dag.task()
+ def push():
+ return ["a", "b", "c"]
+
+ @dag.task()
+ def pull(value):
+ results.add(value)
+
+ converted = push().map(lambda v: v * 2).map(lambda v: {"value": v})
+ pull.expand_kwargs(converted)
+
+ dr = dag_maker.create_dagrun()
+
+ # Run "push".
+ decision = dr.task_instance_scheduling_decisions(session=session)
+ for ti in decision.schedulable_tis:
+ ti.run()
+
+ # Now "pull" should apply the mapping functions in order.
+ decision = dr.task_instance_scheduling_decisions(session=session)
+ for ti in decision.schedulable_tis:
+ ti.run()
+ assert results == {"aa", "bb", "cc"}