You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ur...@apache.org on 2022/07/15 04:08:04 UTC

[airflow] branch main updated: Skip mapping against mapped ti if it returns None (#25047)

This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 6383a2772b Skip mapping against mapped ti if it returns None (#25047)
6383a2772b is described below

commit 6383a2772bac463bb1335cea2ad4554a3f94c6f7
Author: Tzu-ping Chung <ur...@gmail.com>
AuthorDate: Fri Jul 15 12:07:55 2022 +0800

    Skip mapping against mapped ti if it returns None (#25047)
---
 airflow/models/taskinstance.py    |  4 ++--
 tests/models/test_taskinstance.py | 44 +++++++++++++++++++++++++++------------
 2 files changed, 33 insertions(+), 15 deletions(-)

diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 4dd3e2545d..e083a79fd5 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -2327,14 +2327,14 @@ class TaskInstance(Base, LoggingMixin):
         validators = {m.validate_upstream_return_value for m in task.iter_mapped_dependants()}
         if not validators:  # No mapped dependants, no need to validate.
             return
-        if value is None:
-            raise XComForMappingNotPushed()
         # TODO: We don't push TaskMap for mapped task instances because it's not
         # currently possible for a downstream to depend on one individual mapped
         # task instance. This will change when we implement task group mapping,
         # and we'll need to further analyze the mapped task case.
         if task.is_mapped:
             return
+        if value is None:
+            raise XComForMappingNotPushed()
         for validator in validators:
             validator(value)
         assert isinstance(value, collections.abc.Collection)  # The validators type-guard this.
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index 05ac8daae5..28651c9dc2 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -2949,22 +2949,40 @@ def test_ti_mapped_depends_on_mapped_xcom_arg(dag_maker, session):
     assert [x.value for x in query.order_by(None).order_by(XCom.map_index)] == [3, 4, 5]
 
 
-def test_ti_mapped_depends_on_mapped_xcom_arg_XXX(dag_maker, session):
-    with dag_maker(session=session) as dag:
+def test_mapped_upstream_return_none_should_skip(dag_maker, session):
+    results = set()
 
-        @dag.task
-        def add_one(x):
-            x + 1
+    with dag_maker(dag_id="test_mapped_upstream_return_none_should_skip", session=session) as dag:
 
-        two_three_four = add_one.expand(x=[1, 2, 3])
-        add_one.expand(x=two_three_four)
+        @dag.task()
+        def transform(value):
+            if value == "b":  # Now downstream doesn't map against this!
+                return None
+            return value
 
-    dagrun = dag_maker.create_dagrun()
-    for map_index in range(3):
-        ti = dagrun.get_task_instance("add_one", map_index=map_index)
-        ti.refresh_from_task(dag.get_task("add_one"))
-        with pytest.raises(XComForMappingNotPushed):
-            ti.run()
+        @dag.task()
+        def pull(value):
+            results.add(value)
+
+        original = ["a", "b", "c"]
+        transformed = transform.expand(value=original)  # ["a", None, "c"]
+        pull.expand(value=transformed)  # ["a", "c"]
+
+    dr = dag_maker.create_dagrun()
+
+    decision = dr.task_instance_scheduling_decisions(session=session)
+    tis = {(ti.task_id, ti.map_index): ti for ti in decision.schedulable_tis}
+    assert sorted(tis) == [("transform", 0), ("transform", 1), ("transform", 2)]
+    for ti in tis.values():
+        ti.run()
+
+    decision = dr.task_instance_scheduling_decisions(session=session)
+    tis = {(ti.task_id, ti.map_index): ti for ti in decision.schedulable_tis}
+    assert sorted(tis) == [("pull", 0), ("pull", 1)]
+    for ti in tis.values():
+        ti.run()
+
+    assert results == {"a", "c"}
 
 
 def test_expand_non_templated_field(dag_maker, session):