You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ds...@apache.org on 2023/08/01 20:52:42 UTC

[airflow] branch main updated: Fail stop feature can work with setup / teardown (#32985)

This is an automated email from the ASF dual-hosted git repository.

dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 5e78a09495 Fail stop feature can work with setup / teardown (#32985)
5e78a09495 is described below

commit 5e78a0949523f4489c78e0d956459913376bad0e
Author: Daniel Standish <15...@users.noreply.github.com>
AuthorDate: Tue Aug 1 13:52:32 2023 -0700

    Fail stop feature can work with setup / teardown (#32985)
---
 airflow/exceptions.py               |  19 ++++--
 airflow/models/baseoperator.py      |  11 ++-
 airflow/models/dag.py               |   6 +-
 airflow/models/taskinstance.py      |  30 ++++++---
 tests/models/test_baseoperator.py   |  74 +++++++++++++++++++--
 tests/models/test_dag.py            |  12 ++--
 tests/models/test_mappedoperator.py | 129 +++++++++++++++++++++++++++++++++++-
 7 files changed, 250 insertions(+), 31 deletions(-)

diff --git a/airflow/exceptions.py b/airflow/exceptions.py
index ea162fe8db..fe0c4e416b 100644
--- a/airflow/exceptions.py
+++ b/airflow/exceptions.py
@@ -25,6 +25,8 @@ import warnings
 from http import HTTPStatus
 from typing import TYPE_CHECKING, Any, NamedTuple, Sized
 
+from airflow.utils.trigger_rule import TriggerRule
+
 if TYPE_CHECKING:
     from airflow.models import DAG, DagRun
 
@@ -214,20 +216,23 @@ class DagFileExists(AirflowBadRequest):
         warnings.warn("DagFileExists is deprecated and will be removed.", DeprecationWarning, stacklevel=2)
 
 
-class DagInvalidTriggerRule(AirflowException):
+class FailStopDagInvalidTriggerRule(AirflowException):
     """Raise when a dag has 'fail_stop' enabled yet has a non-default trigger rule."""
 
+    _allowed_rules = (TriggerRule.ALL_SUCCESS, TriggerRule.ALL_DONE_SETUP_SUCCESS)
+
     @classmethod
-    def check(cls, dag: DAG | None, trigger_rule: str):
-        from airflow.models.abstractoperator import DEFAULT_TRIGGER_RULE
+    def check(cls, *, dag: DAG | None, trigger_rule: TriggerRule):
+        """
+        Check that fail_stop dag tasks have allowable trigger rules.
 
-        if dag is not None and dag.fail_stop and trigger_rule != DEFAULT_TRIGGER_RULE:
+        :meta private:
+        """
+        if dag is not None and dag.fail_stop and trigger_rule not in cls._allowed_rules:
             raise cls()
 
     def __str__(self) -> str:
-        from airflow.models.abstractoperator import DEFAULT_TRIGGER_RULE
-
-        return f"A 'fail-stop' dag can only have {DEFAULT_TRIGGER_RULE} trigger rule"
+        return f"A 'fail-stop' dag can only have {TriggerRule.ALL_SUCCESS} trigger rule"
 
 
 class DuplicateTaskIdFound(AirflowException):
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index d49babbeb0..99911f322e 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -54,7 +54,12 @@ from sqlalchemy.orm import Session
 from sqlalchemy.orm.exc import NoResultFound
 
 from airflow.configuration import conf
-from airflow.exceptions import AirflowException, DagInvalidTriggerRule, RemovedInAirflow3Warning, TaskDeferred
+from airflow.exceptions import (
+    AirflowException,
+    FailStopDagInvalidTriggerRule,
+    RemovedInAirflow3Warning,
+    TaskDeferred,
+)
 from airflow.lineage import apply_lineage, prepare_lineage
 from airflow.models.abstractoperator import (
     DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
@@ -801,8 +806,6 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
         dag = dag or DagContext.get_current_dag()
         task_group = task_group or TaskGroupContext.get_current_task_group(dag)
 
-        DagInvalidTriggerRule.check(dag, trigger_rule)
-
         self.task_id = task_group.child_id(task_id) if task_group else task_id
         if not self.__from_mapped and task_group:
             task_group.add(self)
@@ -868,6 +871,8 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
             )
 
         self.trigger_rule: TriggerRule = TriggerRule(trigger_rule)
+        FailStopDagInvalidTriggerRule.check(dag=dag, trigger_rule=self.trigger_rule)
+
         self.depends_on_past: bool = depends_on_past
         self.ignore_first_depends_on_past: bool = ignore_first_depends_on_past
         self.wait_for_past_depends_before_skipping: bool = wait_for_past_depends_before_skipping
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 5d2172c56d..8a5f6c714b 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -85,8 +85,8 @@ from airflow.exceptions import (
     AirflowDagInconsistent,
     AirflowException,
     AirflowSkipException,
-    DagInvalidTriggerRule,
     DuplicateTaskIdFound,
+    FailStopDagInvalidTriggerRule,
     RemovedInAirflow3Warning,
     TaskNotFound,
 )
@@ -722,6 +722,7 @@ class DAG(LoggingMixin):
                     f"Dag has teardown task without an upstream work task: dag='{self.dag_id}',"
                     f" task='{task.task_id}'"
                 )
+            FailStopDagInvalidTriggerRule.check(dag=self, trigger_rule=task.trigger_rule)
 
     def __repr__(self):
         return f"<DAG: {self.dag_id}>"
@@ -2520,7 +2521,7 @@ class DAG(LoggingMixin):
 
         :param task: the task you want to add
         """
-        DagInvalidTriggerRule.check(self, task.trigger_rule)
+        FailStopDagInvalidTriggerRule.check(dag=self, trigger_rule=task.trigger_rule)
 
         from airflow.utils.task_group import TaskGroupContext
 
@@ -2711,6 +2712,7 @@ class DAG(LoggingMixin):
             secrets_backend_list.insert(0, local_secrets)
 
         execution_date = execution_date or timezone.utcnow()
+        self.validate()
         self.log.debug("Clearing existing task instances for execution date %s", execution_date)
         self.clear(
             start_date=execution_date,
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 4d946acbea..eb6388ca3d 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -186,19 +186,32 @@ def set_current_context(context: Context) -> Generator[Context, None, None]:
             )
 
 
-def stop_all_tasks_in_dag(tis: list[TaskInstance], session: Session, task_id_to_ignore: int):
+def _stop_remaining_tasks(*, self, session: Session):
+    """
+    Stop non-teardown tasks in dag.
+
+    :meta private:
+    """
+    tis = self.dag_run.get_task_instances(session=session)
+    if TYPE_CHECKING:
+        assert isinstance(self.task.dag, DAG)
+
     for ti in tis:
-        if ti.task_id == task_id_to_ignore or ti.state in (
+        if ti.task_id == self.task_id or ti.state in (
             TaskInstanceState.SUCCESS,
             TaskInstanceState.FAILED,
         ):
             continue
-        if ti.state == TaskInstanceState.RUNNING:
-            log.info("Forcing task %s to fail", ti.task_id)
-            ti.error(session)
+        task = self.task.dag.task_dict[ti.task_id]
+        if not task.is_teardown:
+            if ti.state == TaskInstanceState.RUNNING:
+                log.info("Forcing task %s to fail due to dag's `fail_stop` setting", ti.task_id)
+                ti.error(session)
+            else:
+                log.info("Setting task %s to SKIPPED due to dag's `fail_stop` setting.", ti.task_id)
+                ti.set_state(state=TaskInstanceState.SKIPPED, session=session)
         else:
-            log.info("Setting task %s to SKIPPED", ti.task_id)
-            ti.set_state(state=TaskInstanceState.SKIPPED, session=session)
+            log.info("Not skipping teardown task '%s'", ti.task_id)
 
 
 def clear_task_instances(
@@ -1980,8 +1993,7 @@ class TaskInstance(Base, LoggingMixin):
             callback_type = "on_failure"
 
             if task and task.dag and task.dag.fail_stop:
-                tis = self.get_dagrun(session).get_task_instances()
-                stop_all_tasks_in_dag(tis, session, self.task_id)
+                _stop_remaining_tasks(self=self, session=session)
         else:
             if self.state == TaskInstanceState.QUEUED:
                 # We increase the try_number so as to fail the task if it fails to start after sometime
diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py
index c49e0bb034..dbb842c314 100644
--- a/tests/models/test_baseoperator.py
+++ b/tests/models/test_baseoperator.py
@@ -20,6 +20,7 @@ from __future__ import annotations
 import copy
 import logging
 import uuid
+from collections import defaultdict
 from datetime import date, datetime, timedelta
 from typing import Any, NamedTuple
 from unittest import mock
@@ -28,7 +29,7 @@ import jinja2
 import pytest
 
 from airflow.decorators import task as task_decorator
-from airflow.exceptions import AirflowException, DagInvalidTriggerRule, RemovedInAirflow3Warning
+from airflow.exceptions import AirflowException, FailStopDagInvalidTriggerRule, RemovedInAirflow3Warning
 from airflow.lineage.entities import File
 from airflow.models import DAG
 from airflow.models.baseoperator import (
@@ -184,7 +185,7 @@ class TestBaseOperator:
             BaseOperator(
                 task_id="test_valid_trigger_rule", dag=fail_stop_dag, trigger_rule=DEFAULT_TRIGGER_RULE
             )
-        except DagInvalidTriggerRule as exception:
+        except FailStopDagInvalidTriggerRule as exception:
             assert (
                 False
             ), f"BaseOperator raises exception with fail-stop dag & default trigger rule: {exception}"
@@ -194,13 +195,13 @@ class TestBaseOperator:
             BaseOperator(
                 task_id="test_valid_trigger_rule", dag=non_fail_stop_dag, trigger_rule=TriggerRule.DUMMY
             )
-        except DagInvalidTriggerRule as exception:
+        except FailStopDagInvalidTriggerRule as exception:
             assert (
                 False
             ), f"BaseOperator raises exception with non fail-stop dag & non-default trigger rule: {exception}"
 
         # An operator with non default trigger rule and a fail stop dag should not be allowed
-        with pytest.raises(DagInvalidTriggerRule):
+        with pytest.raises(FailStopDagInvalidTriggerRule):
             BaseOperator(
                 task_id="test_invalid_trigger_rule", dag=fail_stop_dag, trigger_rule=TriggerRule.DUMMY
             )
@@ -919,6 +920,7 @@ def test_render_template_fields_logging(
     caplog, monkeypatch, task, context, expected_exception, expected_rendering, expected_log, not_expected_log
 ):
     """Verify if operator attributes are correctly templated."""
+
     # Trigger templating and verify results
     def _do_render():
         task.render_template_fields(context=context)
@@ -957,3 +959,67 @@ def test_find_mapped_dependants_in_another_group(dag_maker):
 
     dependants = list(gen_result.operator.iter_mapped_dependants())
     assert dependants == [add_result.operator]
+
+
+def get_states(dr):
+    """
+    For a given dag run, get a dict of states.
+
+    Example::
+        {
+            "my_setup": "success",
+            "my_teardown": {0: "success", 1: "success", 2: "success"},
+            "my_work": "failed",
+        }
+    """
+    ti_dict = defaultdict(dict)
+    for ti in dr.get_task_instances():
+        if ti.map_index == -1:
+            ti_dict[ti.task_id] = ti.state
+        else:
+            ti_dict[ti.task_id][ti.map_index] = ti.state
+    return dict(ti_dict)
+
+
+def test_teardown_and_fail_stop(dag_maker):
+    """
+    When fail_stop is enabled, teardowns should run according to their setups.
+    In this case, the second teardown is skipped because its setup is skipped.
+    """
+
+    with dag_maker(fail_stop=True) as dag:
+        for num in (1, 2):
+            with TaskGroup(f"tg_{num}"):
+
+                @task_decorator
+                def my_setup():
+                    print("setting up multiple things")
+                    return [1, 2, 3]
+
+                @task_decorator
+                def my_work(val):
+                    print(f"doing work with multiple things: {val}")
+                    raise ValueError("this fails")
+                    return val
+
+                @task_decorator
+                def my_teardown():
+                    print("teardown")
+
+                s = my_setup()
+                t = my_teardown().as_teardown(setups=s)
+                with t:
+                    my_work(s)
+    tg1, tg2 = dag.task_group.children.values()
+    tg1 >> tg2
+    dr = dag.test()
+    states = get_states(dr)
+    expected = {
+        "tg_1.my_setup": "success",
+        "tg_1.my_teardown": "success",
+        "tg_1.my_work": "failed",
+        "tg_2.my_setup": "skipped",
+        "tg_2.my_teardown": "skipped",
+        "tg_2.my_work": "skipped",
+    }
+    assert states == expected
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index 494c7c7c66..a4be6396be 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -1731,7 +1731,7 @@ class TestDag:
 
     def test_dag_add_task_checks_trigger_rule(self):
         # A non fail stop dag should allow any trigger rule
-        from airflow.exceptions import DagInvalidTriggerRule
+        from airflow.exceptions import FailStopDagInvalidTriggerRule
         from airflow.utils.trigger_rule import TriggerRule
 
         task_with_non_default_trigger_rule = EmptyOperator(
@@ -1742,8 +1742,10 @@ class TestDag:
         )
         try:
             non_fail_stop_dag.add_task(task_with_non_default_trigger_rule)
-        except DagInvalidTriggerRule as exception:
-            assert False, f"dag add_task() raises DagInvalidTriggerRule for non fail stop dag: {exception}"
+        except FailStopDagInvalidTriggerRule as exception:
+            assert (
+                False
+            ), f"dag add_task() raises FailStopDagInvalidTriggerRule for non fail stop dag: {exception}"
 
         # a fail stop dag should allow default trigger rule
         from airflow.models.abstractoperator import DEFAULT_TRIGGER_RULE
@@ -1756,13 +1758,13 @@ class TestDag:
         )
         try:
             fail_stop_dag.add_task(task_with_default_trigger_rule)
-        except DagInvalidTriggerRule as exception:
+        except FailStopDagInvalidTriggerRule as exception:
             assert (
                 False
             ), f"dag.add_task() raises exception for fail-stop dag & default trigger rule: {exception}"
 
         # a fail stop dag should not allow a non-default trigger rule
-        with pytest.raises(DagInvalidTriggerRule):
+        with pytest.raises(FailStopDagInvalidTriggerRule):
             fail_stop_dag.add_task(task_with_non_default_trigger_rule)
 
     def test_dag_add_task_sets_default_task_group(self):
diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py
index d626b8499e..a7f6d0660c 100644
--- a/tests/models/test_mappedoperator.py
+++ b/tests/models/test_mappedoperator.py
@@ -666,7 +666,7 @@ class TestMappedSetupTeardown:
                 ti_dict[ti.task_id] = ti.state
             else:
                 ti_dict[ti.task_id][ti.map_index] = ti.state
-        return ti_dict
+        return dict(ti_dict)
 
     def classic_operator(self, task_id, ret=None, partial=False, fail=False):
         def success_callable(ret=None):
@@ -1365,3 +1365,130 @@ class TestMappedSetupTeardown:
             "my_work": "upstream_failed",
         }
         assert states == expected
+
+    def test_one_to_many_with_teardown_and_fail_stop(self, dag_maker):
+        """
+        With fail_stop enabled, the teardown for an already-completed setup
+        should not be skipped.
+        """
+        with dag_maker(fail_stop=True) as dag:
+
+            @task
+            def my_setup():
+                print("setting up multiple things")
+                return [1, 2, 3]
+
+            @task
+            def my_work(val):
+                print(f"doing work with multiple things: {val}")
+                raise ValueError("this fails")
+                return val
+
+            @task
+            def my_teardown(val):
+                print(f"teardown: {val}")
+
+            s = my_setup()
+            t = my_teardown.expand(val=s).as_teardown(setups=s)
+            with t:
+                my_work(s)
+
+        dr = dag.test()
+        states = self.get_states(dr)
+        expected = {
+            "my_setup": "success",
+            "my_teardown": {0: "success", 1: "success", 2: "success"},
+            "my_work": "failed",
+        }
+        assert states == expected
+
+    def test_one_to_many_with_teardown_and_fail_stop_more_tasks(self, dag_maker):
+        """
+        When fail_stop is enabled, teardowns should run according to their setups.
+        In this case, the second teardown is skipped because its setup is skipped.
+        """
+        with dag_maker(fail_stop=True) as dag:
+            for num in (1, 2):
+                with TaskGroup(f"tg_{num}"):
+
+                    @task
+                    def my_setup():
+                        print("setting up multiple things")
+                        return [1, 2, 3]
+
+                    @task
+                    def my_work(val):
+                        print(f"doing work with multiple things: {val}")
+                        raise ValueError("this fails")
+                        return val
+
+                    @task
+                    def my_teardown(val):
+                        print(f"teardown: {val}")
+
+                    s = my_setup()
+                    t = my_teardown.expand(val=s).as_teardown(setups=s)
+                    with t:
+                        my_work(s)
+        tg1, tg2 = dag.task_group.children.values()
+        tg1 >> tg2
+        dr = dag.test()
+        states = self.get_states(dr)
+        expected = {
+            "tg_1.my_setup": "success",
+            "tg_1.my_teardown": {0: "success", 1: "success", 2: "success"},
+            "tg_1.my_work": "failed",
+            "tg_2.my_setup": "skipped",
+            "tg_2.my_teardown": "skipped",
+            "tg_2.my_work": "skipped",
+        }
+        assert states == expected
+
+    def test_one_to_many_with_teardown_and_fail_stop_more_tasks_mapped_setup(self, dag_maker):
+        """
+        When fail_stop is enabled, teardowns should run according to their setups.
+        In this case, the second teardown is skipped because its setup is skipped.
+        """
+        with dag_maker(fail_stop=True) as dag:
+            for num in (1, 2):
+                with TaskGroup(f"tg_{num}"):
+
+                    @task
+                    def my_pre_setup():
+                        print("input to the setup")
+                        return [1, 2, 3]
+
+                    @task
+                    def my_setup(val):
+                        print("setting up multiple things")
+                        return val
+
+                    @task
+                    def my_work(val):
+                        print(f"doing work with multiple things: {val}")
+                        raise ValueError("this fails")
+                        return val
+
+                    @task
+                    def my_teardown(val):
+                        print(f"teardown: {val}")
+
+                    s = my_setup.expand(val=my_pre_setup())
+                    t = my_teardown.expand(val=s).as_teardown(setups=s)
+                    with t:
+                        my_work(s)
+        tg1, tg2 = dag.task_group.children.values()
+        tg1 >> tg2
+        dr = dag.test()
+        states = self.get_states(dr)
+        expected = {
+            "tg_1.my_pre_setup": "success",
+            "tg_1.my_setup": {0: "success", 1: "success", 2: "success"},
+            "tg_1.my_teardown": {0: "success", 1: "success", 2: "success"},
+            "tg_1.my_work": "failed",
+            "tg_2.my_pre_setup": "skipped",
+            "tg_2.my_setup": "skipped",
+            "tg_2.my_teardown": "skipped",
+            "tg_2.my_work": "skipped",
+        }
+        assert states == expected