Posted to commits@airflow.apache.org by ur...@apache.org on 2022/07/25 03:41:09 UTC

[airflow] branch main updated: Implement map() semantic (#25085)

This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 877dc89004 Implement map() semantic (#25085)
877dc89004 is described below

commit 877dc890047357b664c5f9ee97d58653ccc24237
Author: Tzu-ping Chung <ur...@gmail.com>
AuthorDate: Mon Jul 25 11:41:00 2022 +0800

    Implement map() semantic (#25085)
---
 airflow/exceptions.py                              |  12 -
 airflow/models/expandinput.py                      |  18 --
 airflow/models/mappedoperator.py                   |  15 --
 airflow/models/taskinstance.py                     |  13 +-
 airflow/models/xcom_arg.py                         | 219 ++++++++++++-----
 .../pre_commit_base_operator_partial_arguments.py  |   1 -
 tests/models/test_taskinstance.py                  |   3 -
 tests/models/test_xcom_arg_map.py                  | 259 +++++++++++++++++++++
 8 files changed, 423 insertions(+), 117 deletions(-)
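
A minimal usage sketch of the map() semantic this commit adds, pieced together from the tests introduced below (the DAG id, start date, and scheduling arguments are illustrative, not part of the commit). The callable passed to map() is applied lazily to each item of the upstream XCom and does not create an extra task; the transformed values then feed expand() or expand_kwargs() as usual.

    from datetime import datetime

    from airflow import DAG
    from airflow.decorators import task

    with DAG(dag_id="map_example", start_date=datetime(2022, 7, 1), schedule_interval=None):

        @task
        def push():
            return ["a", "b", "c"]

        @task
        def pull(value):
            print(value)

        # The lambda is *not* a task; it runs when each mapped "pull"
        # instance resolves its arguments, producing one kwargs dict per item.
        pull.expand_kwargs(push().map(lambda v: {"value": v * 2}))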

diff --git a/airflow/exceptions.py b/airflow/exceptions.py
index 7a91100f11..924d326ca9 100644
--- a/airflow/exceptions.py
+++ b/airflow/exceptions.py
@@ -123,18 +123,6 @@ class UnmappableXComTypePushed(AirflowException):
         return f"unmappable return type {typename!r}"
 
 
-class UnmappableXComValuePushed(AirflowException):
-    """Raise when an invalid value is pushed as a mapped downstream's dependency."""
-
-    def __init__(self, value: Any, reason: str) -> None:
-        super().__init__(value, reason)
-        self.value = value
-        self.reason = reason
-
-    def __str__(self) -> str:
-        return f"unmappable return value {self.value!r} ({self.reason})"
-
-
 class UnmappableXComLengthPushed(AirflowException):
     """Raise when the pushed value is too large to map as a downstream's dependency."""
 
diff --git a/airflow/models/expandinput.py b/airflow/models/expandinput.py
index b5b922f9df..7aab2446e2 100644
--- a/airflow/models/expandinput.py
+++ b/airflow/models/expandinput.py
@@ -28,7 +28,6 @@ from sqlalchemy import func
 from sqlalchemy.orm import Session
 
 from airflow.compat.functools import cache
-from airflow.exceptions import UnmappableXComTypePushed, UnmappableXComValuePushed
 from airflow.utils.context import Context
 
 if TYPE_CHECKING:
@@ -72,11 +71,6 @@ class DictOfListsExpandInput(NamedTuple):
 
     value: dict[str, Mappable]
 
-    @staticmethod
-    def validate_xcom(value: Any) -> None:
-        if not isinstance(value, collections.abc.Collection) or isinstance(value, (bytes, str)):
-            raise UnmappableXComTypePushed(value)
-
     def get_unresolved_kwargs(self) -> dict[str, Any]:
         """Get the kwargs dict that can be inferred without resolving."""
         return self.value
@@ -212,18 +206,6 @@ class ListOfDictsExpandInput(NamedTuple):
 
     value: XComArg
 
-    @staticmethod
-    def validate_xcom(value: Any) -> None:
-        if not isinstance(value, collections.abc.Collection):
-            raise UnmappableXComTypePushed(value)
-        if isinstance(value, (str, bytes, collections.abc.Mapping)):
-            raise UnmappableXComTypePushed(value)
-        for item in value:
-            if not isinstance(item, collections.abc.Mapping):
-                raise UnmappableXComTypePushed(value, item)
-            if not all(isinstance(k, str) for k in item):
-                raise UnmappableXComValuePushed(value, reason="dict keys must be str")
-
     def get_unresolved_kwargs(self) -> dict[str, Any]:
         """Get the kwargs dict that can be inferred without resolving.
 
diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py
index a883ff2404..dead62753c 100644
--- a/airflow/models/mappedoperator.py
+++ b/airflow/models/mappedoperator.py
@@ -23,7 +23,6 @@ import warnings
 from typing import (
     TYPE_CHECKING,
     Any,
-    Callable,
     ClassVar,
     Collection,
     Dict,
@@ -613,20 +612,6 @@ class MappedOperator(AbstractOperator):
         """Input received from the expand call on the operator."""
         return getattr(self, self._expand_input_attr)
 
-    @property
-    def validate_upstream_return_value(self) -> Callable[[Any], None]:
-        """Validate an upstream's return value satisfies this task's needs.
-
-        This is implemented as a property (instead of a function calling
-        ``validate_xcom``) so the call site in TaskInstance can de-duplicate
-        validation functions. If this is an instance method, each
-        ``validate_upstream_return_value`` would be a different object (due to
-        how Python handles bounded functions), and de-duplication won't work.
-
-        :meta private:
-        """
-        return self._get_specified_expand_input().validate_xcom
-
     def expand_mapped_task(self, run_id: str, *, session: Session) -> Tuple[Sequence["TaskInstance"], int]:
         """Create the mapped task instances for mapped task.
 
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index bb4eff6faf..725070500c 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -92,6 +92,7 @@ from airflow.exceptions import (
     TaskDeferralError,
     TaskDeferred,
     UnmappableXComLengthPushed,
+    UnmappableXComTypePushed,
     XComForMappingNotPushed,
 )
 from airflow.models.base import Base, StringID
@@ -2330,8 +2331,7 @@ class TaskInstance(Base, LoggingMixin):
         self.log.debug("Task Duration set to %s", self.duration)
 
     def _record_task_map_for_downstreams(self, task: "Operator", value: Any, *, session: Session) -> None:
-        validators = {m.validate_upstream_return_value for m in task.iter_mapped_dependants()}
-        if not validators:  # No mapped dependants, no need to validate.
+        if next(task.iter_mapped_dependants(), None) is None:  # No mapped dependants, no need to validate.
             return
         # TODO: We don't push TaskMap for mapped task instances because it's not
         # currently possible for a downstream to depend on one individual mapped
@@ -2341,9 +2341,12 @@ class TaskInstance(Base, LoggingMixin):
             return
         if value is None:
             raise XComForMappingNotPushed()
-        for validator in validators:
-            validator(value)
-        assert isinstance(value, collections.abc.Collection)  # The validators type-guard this.
+        if not isinstance(value, (collections.abc.Sequence, dict)):
+            raise UnmappableXComTypePushed(value)
+        if isinstance(value, (bytes, str)):
+            raise UnmappableXComTypePushed(value)
+        if TYPE_CHECKING:  # The isinstance() checks above guard this.
+            assert isinstance(value, collections.abc.Collection)
         task_map = TaskMap.from_task_instance_xcom(self, value)
         max_map_length = conf.getint("core", "max_map_length", fallback=1024)
         if task_map.length > max_map_length:
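
With the per-operator validators removed, the mappable-value rule is now checked inline when the XCom is pushed. A rough restatement of that rule as a standalone predicate (the function name is mine; the actual code raises UnmappableXComTypePushed instead of returning False):

    import collections.abc

    def is_mappable_push(value) -> bool:
        # Only sequences and dicts can seed mapped task instances; str and
        # bytes are technically sequences but are rejected on purpose.
        if not isinstance(value, (collections.abc.Sequence, dict)):
            return False
        if isinstance(value, (bytes, str)):
            return False
        return True

    assert is_mappable_push(["a", "b", "c"])
    assert is_mappable_push({"x": 1})
    assert not is_mappable_push("abc")
    assert not is_mappable_push(123)
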
diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py
index 2c2a8f9b58..0a602d0487 100644
--- a/airflow/models/xcom_arg.py
+++ b/airflow/models/xcom_arg.py
@@ -14,7 +14,8 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Sequence, Union
+#
+from typing import TYPE_CHECKING, Any, Callable, Iterator, List, Optional, Sequence, Type, Union, overload
 
 from airflow.exceptions import AirflowException
 from airflow.models.abstractoperator import AbstractOperator
@@ -32,15 +33,14 @@ if TYPE_CHECKING:
 
 
 class XComArg(DependencyMixin):
-    """
-    Class that represents a XCom push from a previous operator.
-    Defaults to "return_value" as only key.
+    """Reference to an XCom value pushed from another operator.
+
+    The implementation supports::
 
-    Current implementation supports
         xcomarg >> op
         xcomarg << op
-        op >> xcomarg   (by BaseOperator code)
-        op << xcomarg   (by BaseOperator code)
+        op >> xcomarg   # By BaseOperator code
+        op << xcomarg   # By BaseOperator code
 
     **Example**: The moment you get a result from any operator (decorated or regular) you can ::
 
@@ -53,15 +53,113 @@ class XComArg(DependencyMixin):
 
     This object can be used in legacy Operators via Jinja.
 
-    **Example**: You can make this result to be part of any generated string ::
+    **Example**: You can make this result to be part of any generated string::
 
         any_op = AnyOperator()
         xcomarg = any_op.output
         op1 = MyOperator(my_text_message=f"the value is {xcomarg}")
         op2 = MyOperator(my_text_message=f"the value is {xcomarg['topic']}")
 
-    :param operator: operator to which the XComArg belongs to
-    :param key: key value which is used for xcom_pull (key in the XCom table)
+    :param operator: Operator instance to which the XComArg references.
+    :param key: Key used to pull the XCom value. Defaults to *XCOM_RETURN_KEY*,
+        i.e. the referenced operator's return value.
+    """
+
+    operator: "Operator"
+    key: str
+
+    @overload
+    def __new__(cls: Type["XComArg"], operator: "Operator", key: str = XCOM_RETURN_KEY) -> "XComArg":
+        """Called when the user writes ``XComArg(...)`` directly."""
+
+    @overload
+    def __new__(cls: Type["XComArg"]) -> "XComArg":
+        """Called by Python internals from subclasses."""
+
+    def __new__(cls, *args, **kwargs) -> "XComArg":
+        if cls is XComArg:
+            return PlainXComArg(*args, **kwargs)
+        return super().__new__(cls)
+
+    @staticmethod
+    def iter_xcom_args(arg: Any) -> Iterator["XComArg"]:
+        """Return XComArg instances in an arbitrary value.
+
+        Recursively traverse ``arg`` and look for XComArg instances in any
+        collection objects, and instances with ``template_fields`` set.
+        """
+        if isinstance(arg, XComArg):
+            yield arg
+        elif isinstance(arg, (tuple, set, list)):
+            for elem in arg:
+                yield from XComArg.iter_xcom_args(elem)
+        elif isinstance(arg, dict):
+            for elem in arg.values():
+                yield from XComArg.iter_xcom_args(elem)
+        elif isinstance(arg, AbstractOperator):
+            for elem in arg.template_fields:
+                yield from XComArg.iter_xcom_args(elem)
+
+    @staticmethod
+    def apply_upstream_relationship(op: "Operator", arg: Any):
+        """Set dependency for XComArgs.
+
+        This looks for XComArg objects in ``arg`` "deeply" (looking inside
+        collections objects and classes decorated with ``template_fields``), and
+        sets the relationship to ``op`` on any found.
+        """
+        for ref in XComArg.iter_xcom_args(arg):
+            op.set_upstream(ref.operator)
+
+    @property
+    def roots(self) -> List[DAGNode]:
+        """Required by TaskMixin"""
+        return [self.operator]
+
+    @property
+    def leaves(self) -> List[DAGNode]:
+        """Required by TaskMixin"""
+        return [self.operator]
+
+    def set_upstream(
+        self,
+        task_or_task_list: Union[DependencyMixin, Sequence[DependencyMixin]],
+        edge_modifier: Optional[EdgeModifier] = None,
+    ):
+        """Proxy to underlying operator set_upstream method. Required by TaskMixin."""
+        self.operator.set_upstream(task_or_task_list, edge_modifier)
+
+    def set_downstream(
+        self,
+        task_or_task_list: Union[DependencyMixin, Sequence[DependencyMixin]],
+        edge_modifier: Optional[EdgeModifier] = None,
+    ):
+        """Proxy to underlying operator set_downstream method. Required by TaskMixin."""
+        self.operator.set_downstream(task_or_task_list, edge_modifier)
+
+    def map(self, f: Callable[[Any], Any]) -> "MapXComArg":
+        raise NotImplementedError()
+
+    def resolve(self, context: Context, session: "Session" = NEW_SESSION) -> Any:
+        raise NotImplementedError()
+
+
+class PlainXComArg(XComArg):
+    """Reference to one single XCom without any additional semantics.
+
+    This class should not be accessed directly, but only through XComArg. The
+    class inheritance chain and ``__new__`` is implemented in this slightly
+    convoluted way because we want to
+
+    a. Allow the user to continue using XComArg directly for the simple
+       semantics (see documentation of the base class for details).
+    b. Make ``isinstance(thing, XComArg)`` be able to detect all kinds of XCom
+       references.
+    c. Not allow many properties of PlainXComArg (including ``__getitem__`` and
+       ``__str__``) to exist on other kinds of XComArg implementations since
+       they don't make sense.
+
+    :meta private:
     """
 
     def __init__(self, operator: "Operator", key: str = XCOM_RETURN_KEY):
@@ -69,13 +167,15 @@ class XComArg(DependencyMixin):
         self.key = key
 
     def __eq__(self, other):
+        if not isinstance(other, PlainXComArg):
+            return NotImplemented
         return self.operator == other.operator and self.key == other.key
 
     def __getitem__(self, item: str) -> "XComArg":
         """Implements xcomresult['some_result_key']"""
         if not isinstance(item, str):
             raise ValueError(f"XComArg only supports str lookup, received {type(item).__name__}")
-        return XComArg(operator=self.operator, key=item)
+        return PlainXComArg(operator=self.operator, key=item)
 
     def __iter__(self):
         """Override iterable protocol to raise error explicitly.
@@ -89,7 +189,7 @@ class XComArg(DependencyMixin):
         This override catches the error eagerly, so an incorrectly implemented
         DAG fails fast and avoids wasting resources on nonsensical iterating.
         """
-        raise TypeError(f"{self.__class__.__name__!r} object is not iterable")
+        raise TypeError("'XComArg' object is not iterable")
 
     def __str__(self):
         """
@@ -113,31 +213,10 @@ class XComArg(DependencyMixin):
         xcom_pull = f"{{{{ task_instance.xcom_pull({xcom_pull_kwargs}) }}}}"
         return xcom_pull
 
-    @property
-    def roots(self) -> List[DAGNode]:
-        """Required by TaskMixin"""
-        return [self.operator]
-
-    @property
-    def leaves(self) -> List[DAGNode]:
-        """Required by TaskMixin"""
-        return [self.operator]
-
-    def set_upstream(
-        self,
-        task_or_task_list: Union[DependencyMixin, Sequence[DependencyMixin]],
-        edge_modifier: Optional[EdgeModifier] = None,
-    ):
-        """Proxy to underlying operator set_upstream method. Required by TaskMixin."""
-        self.operator.set_upstream(task_or_task_list, edge_modifier)
-
-    def set_downstream(
-        self,
-        task_or_task_list: Union[DependencyMixin, Sequence[DependencyMixin]],
-        edge_modifier: Optional[EdgeModifier] = None,
-    ):
-        """Proxy to underlying operator set_downstream method. Required by TaskMixin."""
-        self.operator.set_downstream(task_or_task_list, edge_modifier)
+    def map(self, f: Callable[[Any], Any]) -> "MapXComArg":
+        if self.key != XCOM_RETURN_KEY:
+            raise ValueError
+        return MapXComArg(self, [f])
 
     @provide_session
     def resolve(self, context: Context, session: "Session" = NEW_SESSION) -> Any:
@@ -155,32 +234,46 @@ class XComArg(DependencyMixin):
             )
         return result
 
-    @staticmethod
-    def iter_xcom_args(arg: Any) -> Iterator["XComArg"]:
-        """Return XComArg instances in an arbitrary value.
 
-        This recursively traverse ``arg`` and look for XComArg instances in any
-        collection objects, and instances with ``template_fields`` set.
-        """
-        if isinstance(arg, XComArg):
-            yield arg
-        elif isinstance(arg, (tuple, set, list)):
-            for elem in arg:
-                yield from XComArg.iter_xcom_args(elem)
-        elif isinstance(arg, dict):
-            for elem in arg.values():
-                yield from XComArg.iter_xcom_args(elem)
-        elif isinstance(arg, AbstractOperator):
-            for elem in arg.template_fields:
-                yield from XComArg.iter_xcom_args(elem)
+class _MapResult(Sequence):
+    def __init__(self, value: Union[Sequence, dict], callables: Sequence[Callable[[Any], Any]]) -> None:
+        self.value = value
+        self.callables = callables
 
-    @staticmethod
-    def apply_upstream_relationship(op: "Operator", arg: Any):
-        """Set dependency for XComArgs.
+    def __getitem__(self, index: Any) -> Any:
+        value = self.value[index]
+        for f in self.callables:
+            value = f(value)
+        return value
 
-        This looks for XComArg objects in ``arg`` "deeply" (looking inside
-        collections objects and classes decorated with ``template_fields``), and
-        sets the relationship to ``op`` on any found.
-        """
-        for ref in XComArg.iter_xcom_args(arg):
-            op.set_upstream(ref.operator)
+    def __len__(self) -> int:
+        return len(self.value)
+
+
+class MapXComArg(XComArg):
+    """An XCom reference with ``map()`` call(s) applied.
+
+    This is based on an XComArg, but also applies a series of "transforms" that
+    convert the pulled XCom value.
+    """
+
+    def __init__(self, arg: PlainXComArg, callables: Sequence[Callable[[Any], Any]]) -> None:
+        self.arg = arg
+        self.callables = callables
+
+    @property
+    def operator(self) -> "Operator":  # type: ignore[override]
+        return self.arg.operator
+
+    @property
+    def key(self) -> str:  # type: ignore[override]
+        return self.arg.key
+
+    def map(self, f: Callable[[Any], Any]) -> "MapXComArg":
+        return MapXComArg(self.arg, [*self.callables, f])
+
+    @provide_session
+    def resolve(self, context: Context, session: "Session" = NEW_SESSION) -> Any:
+        value = self.arg.resolve(context, session=session)
+        assert isinstance(value, (Sequence, dict))  # Validation was done when XCom was pushed.
+        return _MapResult(value, self.callables)
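
A small illustration of the _MapResult helper added above: the callables accumulate across chained map() calls and run per item only when the mapped argument is accessed (the private class is imported here purely for demonstration):

    from airflow.models.xcom_arg import _MapResult  # private helper from this commit

    result = _MapResult(["a", "b", "c"], [lambda v: v * 2, lambda v: {"value": v}])
    assert len(result) == 3
    assert result[0] == {"value": "aa"}  # both callables applied, in order, on access
    assert result[2] == {"value": "cc"}
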
diff --git a/scripts/ci/pre_commit/pre_commit_base_operator_partial_arguments.py b/scripts/ci/pre_commit/pre_commit_base_operator_partial_arguments.py
index 909693de5a..29d6c7df4d 100755
--- a/scripts/ci/pre_commit/pre_commit_base_operator_partial_arguments.py
+++ b/scripts/ci/pre_commit/pre_commit_base_operator_partial_arguments.py
@@ -52,7 +52,6 @@ IGNORED = {
     # Only on MappedOperator.
     "expand_input",
     "partial_kwargs",
-    "validate_upstream_return_value",
 }
 
 
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index f0b53fbf5d..87ac9b69cc 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -43,7 +43,6 @@ from airflow.exceptions import (
     AirflowSkipException,
     UnmappableXComLengthPushed,
     UnmappableXComTypePushed,
-    UnmappableXComValuePushed,
     XComForMappingNotPushed,
 )
 from airflow.models import (
@@ -2614,8 +2613,6 @@ class TestTaskInstanceRecordTaskMapXComPush:
         "return_value, exception_type, error_message",
         [
             (123, UnmappableXComTypePushed, "unmappable return type 'int'"),
-            ([123], UnmappableXComTypePushed, "unmappable return type 'list[int]'"),
-            ([{1: 3}], UnmappableXComValuePushed, "unmappable return value [{1: 3}] (dict keys must be str)"),
             (None, XComForMappingNotPushed, "did not push XCom for task mapping"),
         ],
     )
diff --git a/tests/models/test_xcom_arg_map.py b/tests/models/test_xcom_arg_map.py
new file mode 100644
index 0000000000..5807f1b296
--- /dev/null
+++ b/tests/models/test_xcom_arg_map.py
@@ -0,0 +1,259 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+import pytest
+
+from airflow.exceptions import AirflowSkipException
+from airflow.utils.state import TaskInstanceState
+from airflow.utils.trigger_rule import TriggerRule
+
+
+def test_xcom_map(dag_maker, session):
+    results = set()
+    with dag_maker(session=session) as dag:
+
+        @dag.task
+        def push():
+            return ["a", "b", "c"]
+
+        @dag.task
+        def pull(value):
+            results.add(value)
+
+        pull.expand_kwargs(push().map(lambda v: {"value": v * 2}))
+
+    # The function passed to "map" is *NOT* a task.
+    assert set(dag.task_dict) == {"push", "pull"}
+
+    dr = dag_maker.create_dagrun()
+
+    # Run "push".
+    decision = dr.task_instance_scheduling_decisions(session=session)
+    for ti in decision.schedulable_tis:
+        ti.run(session=session)
+    session.commit()
+
+    # Run "pull".
+    decision = dr.task_instance_scheduling_decisions(session=session)
+    tis = {(ti.task_id, ti.map_index): ti for ti in decision.schedulable_tis}
+    assert sorted(tis) == [("pull", 0), ("pull", 1), ("pull", 2)]
+    for ti in tis.values():
+        ti.run(session=session)
+
+    assert results == {"aa", "bb", "cc"}
+
+
+def test_xcom_map_transform_to_none(dag_maker, session):
+    results = set()
+
+    with dag_maker(session=session) as dag:
+
+        @dag.task()
+        def push():
+            return ["a", "b", "c"]
+
+        @dag.task()
+        def pull(value):
+            results.add(value)
+
+        def c_to_none(v):
+            if v == "c":
+                return None
+            return v
+
+        pull.expand(value=push().map(c_to_none))
+
+    dr = dag_maker.create_dagrun()
+
+    # Run "push".
+    decision = dr.task_instance_scheduling_decisions(session=session)
+    for ti in decision.schedulable_tis:
+        ti.run()
+
+    # Run "pull". This should automatically convert "c" to None.
+    decision = dr.task_instance_scheduling_decisions(session=session)
+    for ti in decision.schedulable_tis:
+        ti.run()
+    assert results == {"a", "b", None}
+
+
+def test_xcom_convert_to_kwargs_fails_task(dag_maker, session):
+    results = set()
+
+    with dag_maker(session=session) as dag:
+
+        @dag.task()
+        def push():
+            return ["a", "b", "c"]
+
+        @dag.task()
+        def pull(value):
+            results.add(value)
+
+        def c_to_none(v):
+            if v == "c":
+                return None
+            return {"value": v}
+
+        pull.expand_kwargs(push().map(c_to_none))
+
+    dr = dag_maker.create_dagrun()
+
+    # Run "push".
+    decision = dr.task_instance_scheduling_decisions(session=session)
+    for ti in decision.schedulable_tis:
+        ti.run()
+
+    # Prepare to run "pull"...
+    decision = dr.task_instance_scheduling_decisions(session=session)
+    tis = {(ti.task_id, ti.map_index): ti for ti in decision.schedulable_tis}
+
+    # The first two "pull" tis should also succeed.
+    tis[("pull", 0)].run()
+    tis[("pull", 1)].run()
+
+    # But the third one fails because the map() result cannot be used as kwargs.
+    with pytest.raises(TypeError) as ctx:
+        tis[("pull", 2)].run()
+    assert str(ctx.value) == "'NoneType' object is not iterable"
+
+    assert [tis[("pull", i)].state for i in range(3)] == [
+        TaskInstanceState.SUCCESS,
+        TaskInstanceState.SUCCESS,
+        TaskInstanceState.FAILED,
+    ]
+
+
+def test_xcom_map_error_fails_task(dag_maker, session):
+    with dag_maker(session=session) as dag:
+
+        @dag.task()
+        def push():
+            return ["a", "b", "c"]
+
+        @dag.task()
+        def pull(value):
+            print(value)
+
+        def does_not_work_with_c(v):
+            if v == "c":
+                raise ValueError("nope")
+            return {"value": v * 2}
+
+        pull.expand_kwargs(push().map(does_not_work_with_c))
+
+    dr = dag_maker.create_dagrun()
+
+    # The "push" task should not fail.
+    decision = dr.task_instance_scheduling_decisions(session=session)
+    for ti in decision.schedulable_tis:
+        ti.run()
+    assert [ti.state for ti in decision.schedulable_tis] == [TaskInstanceState.SUCCESS]
+
+    # Prepare to run "pull"...
+    decision = dr.task_instance_scheduling_decisions(session=session)
+    tis = {(ti.task_id, ti.map_index): ti for ti in decision.schedulable_tis}
+
+    # The first two "pull" tis should also succeed.
+    tis[("pull", 0)].run()
+    tis[("pull", 1)].run()
+
+    # But the third one (for "c") will fail.
+    with pytest.raises(ValueError) as ctx:
+        tis[("pull", 2)].run()
+    assert str(ctx.value) == "nope"
+
+    assert [tis[("pull", i)].state for i in range(3)] == [
+        TaskInstanceState.SUCCESS,
+        TaskInstanceState.SUCCESS,
+        TaskInstanceState.FAILED,
+    ]
+
+
+def test_xcom_map_raise_to_skip(dag_maker, session):
+    result = None
+
+    with dag_maker(session=session) as dag:
+
+        @dag.task()
+        def push():
+            return ["a", "b", "c"]
+
+        @dag.task()
+        def forward(value):
+            return value
+
+        @dag.task(trigger_rule=TriggerRule.ALL_DONE)
+        def collect(value):
+            nonlocal result
+            result = list(value)
+
+        def skip_c(v):
+            if v == "c":
+                raise AirflowSkipException
+            return {"value": v}
+
+        collect(value=forward.expand_kwargs(push().map(skip_c)))
+
+    dr = dag_maker.create_dagrun()
+
+    # Run "push".
+    decision = dr.task_instance_scheduling_decisions(session=session)
+    for ti in decision.schedulable_tis:
+        ti.run()
+
+    # Run "forward". This should automatically skip "c".
+    decision = dr.task_instance_scheduling_decisions(session=session)
+    for ti in decision.schedulable_tis:
+        ti.run()
+
+    # Now "collect" should only get "a" and "b".
+    decision = dr.task_instance_scheduling_decisions(session=session)
+    for ti in decision.schedulable_tis:
+        ti.run()
+    assert result == ["a", "b"]
+
+
+def test_xcom_map_nest(dag_maker, session):
+    results = set()
+
+    with dag_maker(session=session) as dag:
+
+        @dag.task()
+        def push():
+            return ["a", "b", "c"]
+
+        @dag.task()
+        def pull(value):
+            results.add(value)
+
+        converted = push().map(lambda v: v * 2).map(lambda v: {"value": v})
+        pull.expand_kwargs(converted)
+
+    dr = dag_maker.create_dagrun()
+
+    # Run "push".
+    decision = dr.task_instance_scheduling_decisions(session=session)
+    for ti in decision.schedulable_tis:
+        ti.run()
+
+    # Now "pull" should apply the mapping functions in order.
+    decision = dr.task_instance_scheduling_decisions(session=session)
+    for ti in decision.schedulable_tis:
+        ti.run()
+    assert results == {"aa", "bb", "cc"}