You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by as...@apache.org on 2022/04/11 09:29:42 UTC

[airflow] branch main updated: Allow using mapped upstream's aggregated XCom (#22849)

This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 8af77127f1 Allow using mapped upstream's aggregated XCom (#22849)
8af77127f1 is described below

commit 8af77127f1aa332c6e976c14c8b98b28c8a4cd26
Author: Tzu-ping Chung <tp...@astronomer.io>
AuthorDate: Mon Apr 11 17:29:32 2022 +0800

    Allow using mapped upstream's aggregated XCom (#22849)
    
    This needs two changes. First, when the upstream pushes the return value
    to XCom, we need to identify that the pushed value is not used on its
    own, but only aggregated with other return values from other mapped task
    instances. Fortunately, this is actually the only possible case right
    now, since we have not implemented support for depending on individual
    return values from a mapped task (aka nested mapping). So we instead
    skip recording any TaskMap metadata from a mapped task to avoid the
    problem altogether.
    
    The second change is for when the downstream task is expanded. Since the
    task depends on the mapped upstream as a whole, we should not use
    TaskMap from the upstream (which corresponds to individual task
    instances, as mentioned above), but the XComs pushed by every instance
    of the mapped task. Again, since we don't support nested mapping yet, we
    can cut corners and simply check whether the upstream is mapped to decide
    what to do, leaving further logic for the future.
    
    Co-authored-by: Ash Berlin-Taylor <as...@apache.org>
---
 airflow/models/mappedoperator.py  | 37 +++++++++++++---
 airflow/models/taskinstance.py    | 92 +++++++++++++++++++++++++++++++++------
 tests/models/test_taskinstance.py | 41 +++++++++++++----
 3 files changed, 143 insertions(+), 27 deletions(-)

diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py
index a7e557ff96..a28a6ea858 100644
--- a/airflow/models/mappedoperator.py
+++ b/airflow/models/mappedoperator.py
@@ -509,6 +509,7 @@ class MappedOperator(AbstractOperator):
     def _get_map_lengths(self, run_id: str, *, session: Session) -> Dict[str, int]:
         # TODO: Find a way to cache this.
         from airflow.models.taskmap import TaskMap
+        from airflow.models.xcom import XCom
         from airflow.models.xcom_arg import XComArg
 
         expansion_kwargs = self._get_expansion_kwargs()
@@ -518,19 +519,45 @@ class MappedOperator(AbstractOperator):
         map_lengths.update((k, len(v)) for k, v in expansion_kwargs.items() if not isinstance(v, XComArg))
 
         # Build a reverse mapping of what arguments each task contributes to.
-        dep_keys: Dict[str, Set[str]] = collections.defaultdict(set)
+        mapped_dep_keys: Dict[str, Set[str]] = collections.defaultdict(set)
+        non_mapped_dep_keys: Dict[str, Set[str]] = collections.defaultdict(set)
         for k, v in expansion_kwargs.items():
             if not isinstance(v, XComArg):
                 continue
-            dep_keys[v.operator.task_id].add(k)
-
+            if v.operator.is_mapped:
+                mapped_dep_keys[v.operator.task_id].add(k)
+            else:
+                non_mapped_dep_keys[v.operator.task_id].add(k)
+            # TODO: It's not possible now, but in the future (AIP-42 Phase 2)
+            # we will add support for depending on one single mapped task
+            # instance. When that happens, we need to further analyze the mapped
+            # case to contain only tasks we depend on "as a whole", and put
+            # those we only depend on individually to the non-mapped lookup.
+
+        # Collect lengths from unmapped upstreams.
         taskmap_query = session.query(TaskMap.task_id, TaskMap.length).filter(
             TaskMap.dag_id == self.dag_id,
             TaskMap.run_id == run_id,
-            TaskMap.task_id.in_(list(dep_keys)),
+            TaskMap.task_id.in_(non_mapped_dep_keys),
+            TaskMap.map_index < 0,
         )
         for task_id, length in taskmap_query:
-            for mapped_arg_name in dep_keys[task_id]:
+            for mapped_arg_name in non_mapped_dep_keys[task_id]:
+                map_lengths[mapped_arg_name] += length
+
+        # Collect lengths from mapped upstreams.
+        xcom_query = (
+            session.query(XCom.task_id, func.count(XCom.map_index))
+            .group_by(XCom.task_id)
+            .filter(
+                XCom.dag_id == self.dag_id,
+                XCom.run_id == run_id,
+                XCom.task_id.in_(mapped_dep_keys),
+                XCom.map_index >= 0,
+            )
+        )
+        for task_id, length in xcom_query:
+            for mapped_arg_name in mapped_dep_keys[task_id]:
                 map_lengths[mapped_arg_name] += length
 
         if len(map_lengths) < len(expansion_kwargs):
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 00ea4e2c05..2f3d75436d 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -35,7 +35,9 @@ from typing import (
     IO,
     TYPE_CHECKING,
     Any,
+    ContextManager,
     Dict,
+    Generator,
     Iterable,
     Iterator,
     List,
@@ -69,6 +71,8 @@ from sqlalchemy import (
 from sqlalchemy.ext.associationproxy import association_proxy
 from sqlalchemy.orm import reconstructor, relationship
 from sqlalchemy.orm.attributes import NO_VALUE, set_committed_value
+from sqlalchemy.orm.exc import NoResultFound
+from sqlalchemy.orm.query import Query
 from sqlalchemy.orm.session import Session
 from sqlalchemy.sql.elements import BooleanClauseList
 from sqlalchemy.sql.sqltypes import BigInteger
@@ -295,6 +299,71 @@ def clear_task_instances(
                 dr.start_date = None
 
 
+class _LazyXComAccessIterator(collections.abc.Iterator):
+    __slots__ = ['_cm', '_it']
+
+    def __init__(self, cm: ContextManager[Query]):
+        self._cm = cm
+        self._it = None
+
+    def __del__(self):
+        if self._it:
+            self._cm.__exit__(None, None, None)
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        if not self._it:
+            self._it = iter(self._cm.__enter__())
+        return XCom.deserialize_value(next(self._it))
+
+
+class _LazyXComAccess(collections.abc.Sequence):
+    """Wrapper to lazily pull XCom with a sequence-like interface.
+
+    Note that since the session bound to the parent query may have died when we
+    actually access the sequence's content, we must create a new session
+    for every function call with ``with_session()``.
+    """
+
+    def __init__(self, query: Query):
+        self._q = query
+        self._len = None
+
+    def __len__(self):
+        if self._len is None:
+            with self._get_bound_query() as query:
+                self._len = query.count()
+        return self._len
+
+    def __iter__(self):
+        return _LazyXComAccessIterator(self._get_bound_query())
+
+    def __getitem__(self, key):
+        if not isinstance(key, int):
+            raise ValueError("only support index access for now")
+        try:
+            with self._get_bound_query() as query:
+                r = query.offset(key).limit(1).one()
+        except NoResultFound:
+            raise IndexError(key) from None
+        return XCom.deserialize_value(r)
+
+    @contextlib.contextmanager
+    def _get_bound_query(self) -> Generator[Query, None, None]:
+        # Do we have a valid session already?
+        if self._q.session and self._q.session.is_active:
+            yield self._q
+            return
+
+        session = settings.Session()
+        try:
+            yield self._q.with_session(session)
+        finally:
+            session.close()
+
+
 class TaskInstanceKey(NamedTuple):
     """Key used to identify task instance."""
 
@@ -2233,14 +2302,19 @@ class TaskInstance(Base, LoggingMixin):
         self.log.debug("Task Duration set to %s", self.duration)
 
     def _record_task_map_for_downstreams(self, task: "Operator", value: Any, *, session: Session) -> None:
-        if not task.has_mapped_dependants():
+        # TODO: We don't push TaskMap for mapped task instances because it's not
+        # currently possible for a downstream to depend on one individual mapped
+        # task instance, only a task as a whole. This will change in AIP-42
+        # Phase 2, and we'll need to further analyze the mapped task case.
+        if task.is_mapped or not task.has_mapped_dependants():
             return
         if not isinstance(value, collections.abc.Collection) or isinstance(value, (bytes, str)):
             raise UnmappableXComTypePushed(value)
+        task_map = TaskMap.from_task_instance_xcom(self, value)
         max_map_length = conf.getint("core", "max_map_length", fallback=1024)
-        if len(value) > max_map_length:
+        if task_map.length > max_map_length:
             raise UnmappableXComLengthPushed(value, max_map_length)
-        session.merge(TaskMap.from_task_instance_xcom(self, value))
+        session.merge(task_map)
 
     @provide_session
     def xcom_push(
@@ -2351,21 +2425,13 @@ class TaskInstance(Base, LoggingMixin):
             # make sure all XComs come from one task and run (for task_ids=None
             # and include_prior_dates=True), and re-order by map index (reset
             # needed because XCom.get_many() orders by XCom timestamp).
-            query = (
+            return _LazyXComAccess(
                 query.with_entities(XCom.value)
-                .filter(XCom.run_id == first.run_id, XCom.task_id == first.task_id)
+                .filter(XCom.run_id == first.run_id, XCom.task_id == first.task_id, XCom.map_index >= 0)
                 .order_by(None)
                 .order_by(XCom.map_index.asc())
             )
 
-            def iter_xcom_values(query):
-                # The session passed to xcom_pull() may die before this is
-                # iterated through, so we need to bind to a new session.
-                for r in query.with_session(settings.Session()):
-                    yield XCom.deserialize_value(r)
-
-            return iter_xcom_values(query)
-
         # At this point either task_ids or map_indexes is explicitly multi-value.
 
         results = (
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index 8b4da46692..2a1ea0889e 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -1089,10 +1089,7 @@ class TestTaskInstance:
         assert ti_2.xcom_pull(["task_1"], map_indexes=[0, 1], session=session) == ["a", "b"]
 
         assert ti_2.xcom_pull("task_1", map_indexes=1, session=session) == "b"
-
-        joined = ti_2.xcom_pull("task_1", session=session)
-        assert iter(joined) is joined, "should be iterator"
-        assert list(joined) == ["a", "b"]
+        assert list(ti_2.xcom_pull("task_1", session=session)) == ["a", "b"]
 
     def test_xcom_pull_after_success(self, create_task_instance):
         """
@@ -2635,7 +2632,7 @@ class TestMappedTaskInstanceReceiveValue:
 
 
 @mock.patch("airflow.models.taskinstance.XCom.deserialize_value", side_effect=XCom.deserialize_value)
-def test_ti_xcom_pull_on_mapped_operator_return_lazy_iterator(mock_deserialize_value, dag_maker, session):
+def test_ti_xcom_pull_on_mapped_operator_return_lazy_iterable(mock_deserialize_value, dag_maker, session):
     """Ensure we access XCom lazily when pulling from a mapped operator."""
     with dag_maker(dag_id="test_xcom", session=session):
         task_1 = DummyOperator.partial(task_id="task_1").expand()
@@ -2657,10 +2654,36 @@ def test_ti_xcom_pull_on_mapped_operator_return_lazy_iterator(mock_deserialize_v
     joined = ti_2.xcom_pull("task_1", session=session)
     assert mock_deserialize_value.call_count == 0
 
-    # Only when we go through the iterator does deserialization happen.
-    assert next(joined) == "a"
+    # Only when we go through the iterable does deserialization happen.
+    it = iter(joined)
+    assert next(it) == "a"
     assert mock_deserialize_value.call_count == 1
-    assert next(joined) == "b"
+    assert next(it) == "b"
     assert mock_deserialize_value.call_count == 2
     with pytest.raises(StopIteration):
-        next(joined)
+        next(it)
+
+
+def test_ti_mapped_depends_on_mapped_xcom_arg(dag_maker, session):
+    with dag_maker(session=session) as dag:
+
+        @dag.task
+        def add_one(x):
+            return x + 1
+
+        two_three_four = add_one.expand(x=[1, 2, 3])
+        add_one.expand(x=two_three_four)
+
+    dagrun = dag_maker.create_dagrun()
+    for map_index in range(3):
+        ti = dagrun.get_task_instance("add_one", map_index=map_index)
+        ti.refresh_from_task(dag.get_task("add_one"))
+        ti.run()
+
+    task_345 = dag.get_task("add_one__1")
+    for ti in task_345.expand_mapped_task(dagrun.run_id, session=session):
+        ti.refresh_from_task(task_345)
+        ti.run()
+
+    query = XCom.get_many(run_id=dagrun.run_id, task_ids=["add_one__1"], session=session)
+    assert [x.value for x in query.order_by(None).order_by(XCom.map_index)] == [3, 4, 5]