You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ep...@apache.org on 2023/11/15 11:53:15 UTC

(airflow) branch main updated: Revert "Fix pre-mature evaluation of tasks in mapped task group (#34337)" (#35651)

This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 95bf5dd620 Revert "Fix pre-mature evaluation of tasks in mapped task group (#34337)" (#35651)
95bf5dd620 is described below

commit 95bf5dd620ec996dda834ba13048a77314d6d915
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Wed Nov 15 12:53:02 2023 +0100

    Revert "Fix pre-mature evaluation of tasks in mapped task group (#34337)" (#35651)
    
    This reverts commit 69938fd163045d750b8c218500d79bc89858f9c1.
---
 airflow/ti_deps/deps/trigger_rule_dep.py    | 18 -----------
 tests/models/test_mappedoperator.py         |  4 +--
 tests/ti_deps/deps/test_trigger_rule_dep.py | 47 ++++-------------------------
 3 files changed, 8 insertions(+), 61 deletions(-)

diff --git a/airflow/ti_deps/deps/trigger_rule_dep.py b/airflow/ti_deps/deps/trigger_rule_dep.py
index 6203b2a79b..ca2a6100a2 100644
--- a/airflow/ti_deps/deps/trigger_rule_dep.py
+++ b/airflow/ti_deps/deps/trigger_rule_dep.py
@@ -27,7 +27,6 @@ from sqlalchemy import and_, func, or_, select
 from airflow.models.taskinstance import PAST_DEPENDS_MET
 from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
 from airflow.utils.state import TaskInstanceState
-from airflow.utils.task_group import MappedTaskGroup
 from airflow.utils.trigger_rule import TriggerRule as TR
 
 if TYPE_CHECKING:
@@ -133,20 +132,6 @@ class TriggerRuleDep(BaseTIDep):
             """
             return ti.task.get_mapped_ti_count(ti.run_id, session=session)
 
-        def _iter_expansion_dependencies() -> Iterator[str]:
-            from airflow.models.mappedoperator import MappedOperator
-
-            if isinstance(ti.task, MappedOperator):
-                for op in ti.task.iter_mapped_dependencies():
-                    yield op.task_id
-            task_group = ti.task.task_group
-            if task_group and task_group.iter_mapped_task_groups():
-                yield from (
-                    op.task_id
-                    for tg in task_group.iter_mapped_task_groups()
-                    for op in tg.iter_mapped_dependencies()
-                )
-
         @functools.lru_cache
         def _get_relevant_upstream_map_indexes(upstream_id: str) -> int | range | None:
             """Get the given task's map indexes relevant to the current ti.
@@ -157,9 +142,6 @@ class TriggerRuleDep(BaseTIDep):
             """
             if TYPE_CHECKING:
                 assert isinstance(ti.task.dag, DAG)
-            if isinstance(ti.task.task_group, MappedTaskGroup):
-                if upstream_id not in set(_iter_expansion_dependencies()):
-                    return None
             try:
                 expanded_ti_count = _get_expanded_ti_count()
             except (NotFullyPopulated, NotMapped):
diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py
index 5c2e23c1f9..7244c55774 100644
--- a/tests/models/test_mappedoperator.py
+++ b/tests/models/test_mappedoperator.py
@@ -1305,8 +1305,8 @@ class TestMappedSetupTeardown:
         states = self.get_states(dr)
         expected = {
             "file_transforms.my_setup": {0: "success", 1: "failed", 2: "skipped"},
-            "file_transforms.my_work": {2: "upstream_failed", 1: "upstream_failed", 0: "upstream_failed"},
-            "file_transforms.my_teardown": {2: "success", 1: "success", 0: "success"},
+            "file_transforms.my_work": {0: "success", 1: "upstream_failed", 2: "skipped"},
+            "file_transforms.my_teardown": {0: "success", 1: "upstream_failed", 2: "skipped"},
         }
 
         assert states == expected
diff --git a/tests/ti_deps/deps/test_trigger_rule_dep.py b/tests/ti_deps/deps/test_trigger_rule_dep.py
index 1bc8808cb8..00cbcd449a 100644
--- a/tests/ti_deps/deps/test_trigger_rule_dep.py
+++ b/tests/ti_deps/deps/test_trigger_rule_dep.py
@@ -1165,23 +1165,19 @@ def test_upstream_in_mapped_group_triggers_only_relevant(dag_maker, session):
     tis = _one_scheduling_decision_iteration()
     assert sorted(tis) == [("tg.t1", 0), ("tg.t1", 1), ("tg.t1", 2)]
 
-    # After running the first t1, the remaining t1 must be run before t2 is available.
+    # After running the first t1, the first t2 becomes immediately available.
     tis["tg.t1", 0].run()
     tis = _one_scheduling_decision_iteration()
-    assert sorted(tis) == [("tg.t1", 1), ("tg.t1", 2)]
+    assert sorted(tis) == [("tg.t1", 1), ("tg.t1", 2), ("tg.t2", 0)]
 
-    # After running all t1, t2 is available.
-    tis["tg.t1", 1].run()
+    # Similarly for the subsequent t2 instances.
     tis["tg.t1", 2].run()
     tis = _one_scheduling_decision_iteration()
-    assert sorted(tis) == [("tg.t2", 0), ("tg.t2", 1), ("tg.t2", 2)]
-
-    # Similarly for t2 instances. They both have to complete before t3 is available
-    tis["tg.t2", 0].run()
-    tis = _one_scheduling_decision_iteration()
-    assert sorted(tis) == [("tg.t2", 1), ("tg.t2", 2)]
+    assert sorted(tis) == [("tg.t1", 1), ("tg.t2", 0), ("tg.t2", 2)]
 
     # But running t2 partially does not make t3 available.
+    tis["tg.t1", 1].run()
+    tis["tg.t2", 0].run()
     tis["tg.t2", 2].run()
     tis = _one_scheduling_decision_iteration()
     assert sorted(tis) == [("tg.t2", 1)]
@@ -1411,34 +1407,3 @@ class TestTriggerRuleDepSetupConstraint:
             (status,) = self.get_dep_statuses(dr, "w2", flag_upstream_failed=True, session=session)
         assert status.reason.startswith("All setup tasks must complete successfully")
         assert self.get_ti(dr, "w2").state == expected
-
-
-def test_mapped_tasks_in_mapped_task_group_waits_for_upstreams_to_complete(dag_maker, session):
-    """Test that one failed trigger rule works well in mapped task group"""
-    with dag_maker() as dag:
-
-        @dag.task
-        def t1():
-            return [1, 2, 3]
-
-        @task_group("tg1")
-        def tg1(a):
-            @dag.task()
-            def t2(a):
-                return a
-
-            @dag.task(trigger_rule=TriggerRule.ONE_FAILED)
-            def t3(a):
-                return a
-
-            t2(a) >> t3(a)
-
-        t = t1()
-        tg1.expand(a=t)
-
-    dr = dag_maker.create_dagrun()
-    ti = dr.get_task_instance(task_id="t1")
-    ti.run()
-    dr.task_instance_scheduling_decisions()
-    ti3 = dr.get_task_instance(task_id="tg1.t3")
-    assert not ti3.state