You are viewing a plain-text version of this content; the link to its canonical version was not preserved in this copy.
Posted to commits@airflow.apache.org by ur...@apache.org on 2022/07/27 08:36:33 UTC

[airflow] branch main updated: Filter XCOM by key when calculating map lengths (#24530)

This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new a69095fea1 Filter XCOM by key when calculating map lengths (#24530)
a69095fea1 is described below

commit a69095fea1722e153a95ef9da93b002b82a02426
Author: Henry Hinnefeld <he...@gmail.com>
AuthorDate: Wed Jul 27 03:36:23 2022 -0500

    Filter XCOM by key when calculating map lengths (#24530)
    
    Co-authored-by: Tzu-ping Chung <ur...@gmail.com>
---
 airflow/models/expandinput.py       |  3 ++-
 tests/models/test_mappedoperator.py | 37 +++++++++++++++++++++++++++++++++++++
 2 files changed, 39 insertions(+), 1 deletion(-)

diff --git a/airflow/models/expandinput.py b/airflow/models/expandinput.py
index 7aab2446e2..f3fa84a302 100644
--- a/airflow/models/expandinput.py
+++ b/airflow/models/expandinput.py
@@ -96,7 +96,7 @@ class DictOfListsExpandInput(NamedTuple):
         they will not be present in the dict.
         """
         from airflow.models.taskmap import TaskMap
-        from airflow.models.xcom import XCom
+        from airflow.models.xcom import XCOM_RETURN_KEY, XCom
         from airflow.models.xcom_arg import XComArg
 
         # Populate literal mapped arguments first.
@@ -143,6 +143,7 @@ class DictOfListsExpandInput(NamedTuple):
             .filter(
                 XCom.dag_id == dag_id,
                 XCom.run_id == run_id,
+                XCom.key == XCOM_RETURN_KEY,
                 XCom.task_id.in_(mapped_dep_keys),
                 XCom.map_index >= 0,
             )
diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py
index 09ab87524b..483ec97c54 100644
--- a/tests/models/test_mappedoperator.py
+++ b/tests/models/test_mappedoperator.py
@@ -99,6 +99,43 @@ def test_map_xcom_arg():
     assert task1.downstream_list == [mapped]
 
 
+def test_map_xcom_arg_multiple_upstream_xcoms(dag_maker, session):
+    """Test that the correct number of downstream tasks are generated when mapping with an XComArg"""
+
+    class PushExtraXComOperator(BaseOperator):
+        """Push an extra XCom value along with the default return value."""
+
+        def __init__(self, return_value, **kwargs):
+            super().__init__(**kwargs)
+            self.return_value = return_value
+
+        def execute(self, context):
+            context['task_instance'].xcom_push(key='extra_key', value="extra_value")
+            return self.return_value
+
+    with dag_maker("test-dag", session=session, start_date=DEFAULT_DATE) as dag:
+        upstream_return = [1, 2, 3]
+        task1 = PushExtraXComOperator(return_value=upstream_return, task_id="task_1")
+        task2 = PushExtraXComOperator.partial(task_id='task_2').expand(return_value=XComArg(task1))
+        task3 = PushExtraXComOperator.partial(task_id='task_3').expand(return_value=XComArg(task2))
+
+    dr = dag_maker.create_dagrun()
+    ti_1 = dr.get_task_instance("task_1", session)
+    ti_1.run()
+
+    ti_2s, _ = task2.expand_mapped_task(dr.run_id, session=session)
+    for ti in ti_2s:
+        ti.refresh_from_task(dag.get_task("task_2"))
+        ti.run()
+
+    ti_3s, _ = task3.expand_mapped_task(dr.run_id, session=session)
+    for ti in ti_3s:
+        ti.refresh_from_task(dag.get_task("task_3"))
+        ti.run()
+
+    assert len(ti_3s) == len(ti_2s) == len(upstream_return)
+
+
 def test_partial_on_instance() -> None:
     """`.partial` on an instance should fail -- it's only designed to be called on classes"""
     with pytest.raises(TypeError):