You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ep...@apache.org on 2023/01/11 21:25:58 UTC

[airflow] 17/27: Guard not-yet-expanded ti in trigger rule dep (#28592)

This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v2-5-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit e7e10c16727c1282eed588561e2eebc7de404371
Author: Tzu-ping Chung <ur...@gmail.com>
AuthorDate: Wed Dec 28 23:47:11 2022 +0800

    Guard not-yet-expanded ti in trigger rule dep (#28592)
    
    Previously, if a mapped task is not yet expanded when the trigger rule
    dep is evaluated, it would raise an exception and fail the scheduler.
    This adds an additional try-except to guard against this.
    
    The problematic scenario is when a mapped task depends on another mapped
    task, and its trigger rule is evaluated before that other mapped task is
    expanded (e.g. the other task also has a task-mapping dependency that is
    not yet finished). Since we can be certain the upstream task has not yet
    satisfied the expansion dep, we can simply declare the task we're checking
    as unsatisfied.
    
    (cherry picked from commit d4dbb0077aec33e5b3b4793bf9e2902e6cbdaa7f)
---
 airflow/ti_deps/deps/trigger_rule_dep.py    |  8 +++++++-
 tests/ti_deps/deps/test_trigger_rule_dep.py | 30 +++++++++++++++++++++++++++++
 2 files changed, 37 insertions(+), 1 deletion(-)

diff --git a/airflow/ti_deps/deps/trigger_rule_dep.py b/airflow/ti_deps/deps/trigger_rule_dep.py
index d932a6dd21..7d78b591af 100644
--- a/airflow/ti_deps/deps/trigger_rule_dep.py
+++ b/airflow/ti_deps/deps/trigger_rule_dep.py
@@ -105,6 +105,8 @@ class TriggerRuleDep(BaseTIDep):
         :param dep_context: The current dependency context.
         :param session: Database session.
         """
+        from airflow.models.abstractoperator import NotMapped
+        from airflow.models.expandinput import NotFullyPopulated
         from airflow.models.operator import needs_expansion
         from airflow.models.taskinstance import TaskInstance
 
@@ -129,9 +131,13 @@ class TriggerRuleDep(BaseTIDep):
             and at most once for each task (instead of once for each expanded
             task instance of the same task).
             """
+            try:
+                expanded_ti_count = _get_expanded_ti_count()
+            except (NotFullyPopulated, NotMapped):
+                return None
             return ti.get_relevant_upstream_map_indexes(
                 upstream_tasks[upstream_id],
-                _get_expanded_ti_count(),
+                expanded_ti_count,
                 session=session,
             )
 
diff --git a/tests/ti_deps/deps/test_trigger_rule_dep.py b/tests/ti_deps/deps/test_trigger_rule_dep.py
index 509909d974..42c979c93a 100644
--- a/tests/ti_deps/deps/test_trigger_rule_dep.py
+++ b/tests/ti_deps/deps/test_trigger_rule_dep.py
@@ -22,6 +22,7 @@ from typing import Iterator
 
 import pytest
 
+from airflow.decorators import task, task_group
 from airflow.models.baseoperator import BaseOperator
 from airflow.models.dagrun import DagRun
 from airflow.models.taskinstance import TaskInstance
@@ -947,3 +948,32 @@ def test_upstream_in_mapped_group_triggers_only_relevant(dag_maker, session):
     tis["tg.t2", 1].run()
     tis = _one_scheduling_decision_iteration()
     assert sorted(tis) == [("t3", -1)]
+
+
+def test_mapped_task_check_before_expand(dag_maker, session):
+    with dag_maker(session=session):
+
+        @task
+        def t(x):
+            return x
+
+        @task_group
+        def tg(a):
+            b = t.override(task_id="t2")(a)
+            c = t.override(task_id="t3")(b)
+            return c
+
+        tg.expand(a=t([1, 2, 3]))
+
+    dr: DagRun = dag_maker.create_dagrun()
+    result_iterator = TriggerRuleDep()._evaluate_trigger_rule(
+        # t3 depends on t2, which depends on t for expansion. Since t has not
+        # yet run, t2 has not expanded yet, and we need to guarantee this lack
+        # of expansion does not fail the dependency-checking logic.
+        ti=next(ti for ti in dr.task_instances if ti.task_id == "tg.t3" and ti.map_index == -1),
+        dep_context=DepContext(),
+        session=session,
+    )
+    results = list(result_iterator)
+    assert len(results) == 1
+    assert results[0].passed is False