You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ur...@apache.org on 2022/07/27 08:36:33 UTC
[airflow] branch main updated: Filter XCOM by key when calculating map lengths (#24530)
This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new a69095fea1 Filter XCOM by key when calculating map lengths (#24530)
a69095fea1 is described below
commit a69095fea1722e153a95ef9da93b002b82a02426
Author: Henry Hinnefeld <he...@gmail.com>
AuthorDate: Wed Jul 27 03:36:23 2022 -0500
Filter XCOM by key when calculating map lengths (#24530)
Co-authored-by: Tzu-ping Chung <ur...@gmail.com>
---
airflow/models/expandinput.py | 3 ++-
tests/models/test_mappedoperator.py | 37 +++++++++++++++++++++++++++++++++++++
2 files changed, 39 insertions(+), 1 deletion(-)
diff --git a/airflow/models/expandinput.py b/airflow/models/expandinput.py
index 7aab2446e2..f3fa84a302 100644
--- a/airflow/models/expandinput.py
+++ b/airflow/models/expandinput.py
@@ -96,7 +96,7 @@ class DictOfListsExpandInput(NamedTuple):
they will not be present in the dict.
"""
from airflow.models.taskmap import TaskMap
- from airflow.models.xcom import XCom
+ from airflow.models.xcom import XCOM_RETURN_KEY, XCom
from airflow.models.xcom_arg import XComArg
# Populate literal mapped arguments first.
@@ -143,6 +143,7 @@ class DictOfListsExpandInput(NamedTuple):
.filter(
XCom.dag_id == dag_id,
XCom.run_id == run_id,
+ XCom.key == XCOM_RETURN_KEY,
XCom.task_id.in_(mapped_dep_keys),
XCom.map_index >= 0,
)
diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py
index 09ab87524b..483ec97c54 100644
--- a/tests/models/test_mappedoperator.py
+++ b/tests/models/test_mappedoperator.py
@@ -99,6 +99,43 @@ def test_map_xcom_arg():
assert task1.downstream_list == [mapped]
+def test_map_xcom_arg_multiple_upstream_xcoms(dag_maker, session):
+ """Test that the correct number of downstream tasks are generated when mapping with an XComArg"""
+
+ class PushExtraXComOperator(BaseOperator):
+ """Push an extra XCom value along with the default return value."""
+
+ def __init__(self, return_value, **kwargs):
+ super().__init__(**kwargs)
+ self.return_value = return_value
+
+ def execute(self, context):
+ context['task_instance'].xcom_push(key='extra_key', value="extra_value")
+ return self.return_value
+
+ with dag_maker("test-dag", session=session, start_date=DEFAULT_DATE) as dag:
+ upstream_return = [1, 2, 3]
+ task1 = PushExtraXComOperator(return_value=upstream_return, task_id="task_1")
+ task2 = PushExtraXComOperator.partial(task_id='task_2').expand(return_value=XComArg(task1))
+ task3 = PushExtraXComOperator.partial(task_id='task_3').expand(return_value=XComArg(task2))
+
+ dr = dag_maker.create_dagrun()
+ ti_1 = dr.get_task_instance("task_1", session)
+ ti_1.run()
+
+ ti_2s, _ = task2.expand_mapped_task(dr.run_id, session=session)
+ for ti in ti_2s:
+ ti.refresh_from_task(dag.get_task("task_2"))
+ ti.run()
+
+ ti_3s, _ = task3.expand_mapped_task(dr.run_id, session=session)
+ for ti in ti_3s:
+ ti.refresh_from_task(dag.get_task("task_3"))
+ ti.run()
+
+ assert len(ti_3s) == len(ti_2s) == len(upstream_return)
+
+
def test_partial_on_instance() -> None:
"""`.partial` on an instance should fail -- it's only designed to be called on classes"""
with pytest.raises(TypeError):