You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink was not preserved in this plain-text copy).
Posted to commits@airflow.apache.org by ds...@apache.org on 2023/08/01 20:52:42 UTC
[airflow] branch main updated: Fail stop feature can work with setup / teardown (#32985)
This is an automated email from the ASF dual-hosted git repository.
dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 5e78a09495 Fail stop feature can work with setup / teardown (#32985)
5e78a09495 is described below
commit 5e78a0949523f4489c78e0d956459913376bad0e
Author: Daniel Standish <15...@users.noreply.github.com>
AuthorDate: Tue Aug 1 13:52:32 2023 -0700
Fail stop feature can work with setup / teardown (#32985)
---
airflow/exceptions.py | 19 ++++--
airflow/models/baseoperator.py | 11 ++-
airflow/models/dag.py | 6 +-
airflow/models/taskinstance.py | 30 ++++++---
tests/models/test_baseoperator.py | 74 +++++++++++++++++++--
tests/models/test_dag.py | 12 ++--
tests/models/test_mappedoperator.py | 129 +++++++++++++++++++++++++++++++++++-
7 files changed, 250 insertions(+), 31 deletions(-)
diff --git a/airflow/exceptions.py b/airflow/exceptions.py
index ea162fe8db..fe0c4e416b 100644
--- a/airflow/exceptions.py
+++ b/airflow/exceptions.py
@@ -25,6 +25,8 @@ import warnings
from http import HTTPStatus
from typing import TYPE_CHECKING, Any, NamedTuple, Sized
+from airflow.utils.trigger_rule import TriggerRule
+
if TYPE_CHECKING:
from airflow.models import DAG, DagRun
@@ -214,20 +216,23 @@ class DagFileExists(AirflowBadRequest):
warnings.warn("DagFileExists is deprecated and will be removed.", DeprecationWarning, stacklevel=2)
-class DagInvalidTriggerRule(AirflowException):
+class FailStopDagInvalidTriggerRule(AirflowException):
"""Raise when a dag has 'fail_stop' enabled yet has a non-default trigger rule."""
+ _allowed_rules = (TriggerRule.ALL_SUCCESS, TriggerRule.ALL_DONE_SETUP_SUCCESS)
+
@classmethod
- def check(cls, dag: DAG | None, trigger_rule: str):
- from airflow.models.abstractoperator import DEFAULT_TRIGGER_RULE
+ def check(cls, *, dag: DAG | None, trigger_rule: TriggerRule):
+ """
+ Check that fail_stop dag tasks have allowable trigger rules.
- if dag is not None and dag.fail_stop and trigger_rule != DEFAULT_TRIGGER_RULE:
+ :meta private:
+ """
+ if dag is not None and dag.fail_stop and trigger_rule not in cls._allowed_rules:
raise cls()
def __str__(self) -> str:
- from airflow.models.abstractoperator import DEFAULT_TRIGGER_RULE
-
- return f"A 'fail-stop' dag can only have {DEFAULT_TRIGGER_RULE} trigger rule"
+ return f"A 'fail-stop' dag can only have {TriggerRule.ALL_SUCCESS} trigger rule"
class DuplicateTaskIdFound(AirflowException):
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index d49babbeb0..99911f322e 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -54,7 +54,12 @@ from sqlalchemy.orm import Session
from sqlalchemy.orm.exc import NoResultFound
from airflow.configuration import conf
-from airflow.exceptions import AirflowException, DagInvalidTriggerRule, RemovedInAirflow3Warning, TaskDeferred
+from airflow.exceptions import (
+ AirflowException,
+ FailStopDagInvalidTriggerRule,
+ RemovedInAirflow3Warning,
+ TaskDeferred,
+)
from airflow.lineage import apply_lineage, prepare_lineage
from airflow.models.abstractoperator import (
DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
@@ -801,8 +806,6 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
dag = dag or DagContext.get_current_dag()
task_group = task_group or TaskGroupContext.get_current_task_group(dag)
- DagInvalidTriggerRule.check(dag, trigger_rule)
-
self.task_id = task_group.child_id(task_id) if task_group else task_id
if not self.__from_mapped and task_group:
task_group.add(self)
@@ -868,6 +871,8 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
)
self.trigger_rule: TriggerRule = TriggerRule(trigger_rule)
+ FailStopDagInvalidTriggerRule.check(dag=dag, trigger_rule=self.trigger_rule)
+
self.depends_on_past: bool = depends_on_past
self.ignore_first_depends_on_past: bool = ignore_first_depends_on_past
self.wait_for_past_depends_before_skipping: bool = wait_for_past_depends_before_skipping
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 5d2172c56d..8a5f6c714b 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -85,8 +85,8 @@ from airflow.exceptions import (
AirflowDagInconsistent,
AirflowException,
AirflowSkipException,
- DagInvalidTriggerRule,
DuplicateTaskIdFound,
+ FailStopDagInvalidTriggerRule,
RemovedInAirflow3Warning,
TaskNotFound,
)
@@ -722,6 +722,7 @@ class DAG(LoggingMixin):
f"Dag has teardown task without an upstream work task: dag='{self.dag_id}',"
f" task='{task.task_id}'"
)
+ FailStopDagInvalidTriggerRule.check(dag=self, trigger_rule=task.trigger_rule)
def __repr__(self):
return f"<DAG: {self.dag_id}>"
@@ -2520,7 +2521,7 @@ class DAG(LoggingMixin):
:param task: the task you want to add
"""
- DagInvalidTriggerRule.check(self, task.trigger_rule)
+ FailStopDagInvalidTriggerRule.check(dag=self, trigger_rule=task.trigger_rule)
from airflow.utils.task_group import TaskGroupContext
@@ -2711,6 +2712,7 @@ class DAG(LoggingMixin):
secrets_backend_list.insert(0, local_secrets)
execution_date = execution_date or timezone.utcnow()
+ self.validate()
self.log.debug("Clearing existing task instances for execution date %s", execution_date)
self.clear(
start_date=execution_date,
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 4d946acbea..eb6388ca3d 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -186,19 +186,32 @@ def set_current_context(context: Context) -> Generator[Context, None, None]:
)
-def stop_all_tasks_in_dag(tis: list[TaskInstance], session: Session, task_id_to_ignore: int):
+def _stop_remaining_tasks(*, self, session: Session):
+ """
+ Stop non-teardown tasks in dag.
+
+ :meta private:
+ """
+ tis = self.dag_run.get_task_instances(session=session)
+ if TYPE_CHECKING:
+ assert isinstance(self.task.dag, DAG)
+
for ti in tis:
- if ti.task_id == task_id_to_ignore or ti.state in (
+ if ti.task_id == self.task_id or ti.state in (
TaskInstanceState.SUCCESS,
TaskInstanceState.FAILED,
):
continue
- if ti.state == TaskInstanceState.RUNNING:
- log.info("Forcing task %s to fail", ti.task_id)
- ti.error(session)
+ task = self.task.dag.task_dict[ti.task_id]
+ if not task.is_teardown:
+ if ti.state == TaskInstanceState.RUNNING:
+ log.info("Forcing task %s to fail due to dag's `fail_stop` setting", ti.task_id)
+ ti.error(session)
+ else:
+ log.info("Setting task %s to SKIPPED due to dag's `fail_stop` setting.", ti.task_id)
+ ti.set_state(state=TaskInstanceState.SKIPPED, session=session)
else:
- log.info("Setting task %s to SKIPPED", ti.task_id)
- ti.set_state(state=TaskInstanceState.SKIPPED, session=session)
+ log.info("Not skipping teardown task '%s'", ti.task_id)
def clear_task_instances(
@@ -1980,8 +1993,7 @@ class TaskInstance(Base, LoggingMixin):
callback_type = "on_failure"
if task and task.dag and task.dag.fail_stop:
- tis = self.get_dagrun(session).get_task_instances()
- stop_all_tasks_in_dag(tis, session, self.task_id)
+ _stop_remaining_tasks(self=self, session=session)
else:
if self.state == TaskInstanceState.QUEUED:
# We increase the try_number so as to fail the task if it fails to start after sometime
diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py
index c49e0bb034..dbb842c314 100644
--- a/tests/models/test_baseoperator.py
+++ b/tests/models/test_baseoperator.py
@@ -20,6 +20,7 @@ from __future__ import annotations
import copy
import logging
import uuid
+from collections import defaultdict
from datetime import date, datetime, timedelta
from typing import Any, NamedTuple
from unittest import mock
@@ -28,7 +29,7 @@ import jinja2
import pytest
from airflow.decorators import task as task_decorator
-from airflow.exceptions import AirflowException, DagInvalidTriggerRule, RemovedInAirflow3Warning
+from airflow.exceptions import AirflowException, FailStopDagInvalidTriggerRule, RemovedInAirflow3Warning
from airflow.lineage.entities import File
from airflow.models import DAG
from airflow.models.baseoperator import (
@@ -184,7 +185,7 @@ class TestBaseOperator:
BaseOperator(
task_id="test_valid_trigger_rule", dag=fail_stop_dag, trigger_rule=DEFAULT_TRIGGER_RULE
)
- except DagInvalidTriggerRule as exception:
+ except FailStopDagInvalidTriggerRule as exception:
assert (
False
), f"BaseOperator raises exception with fail-stop dag & default trigger rule: {exception}"
@@ -194,13 +195,13 @@ class TestBaseOperator:
BaseOperator(
task_id="test_valid_trigger_rule", dag=non_fail_stop_dag, trigger_rule=TriggerRule.DUMMY
)
- except DagInvalidTriggerRule as exception:
+ except FailStopDagInvalidTriggerRule as exception:
assert (
False
), f"BaseOperator raises exception with non fail-stop dag & non-default trigger rule: {exception}"
# An operator with non default trigger rule and a fail stop dag should not be allowed
- with pytest.raises(DagInvalidTriggerRule):
+ with pytest.raises(FailStopDagInvalidTriggerRule):
BaseOperator(
task_id="test_invalid_trigger_rule", dag=fail_stop_dag, trigger_rule=TriggerRule.DUMMY
)
@@ -919,6 +920,7 @@ def test_render_template_fields_logging(
caplog, monkeypatch, task, context, expected_exception, expected_rendering, expected_log, not_expected_log
):
"""Verify if operator attributes are correctly templated."""
+
# Trigger templating and verify results
def _do_render():
task.render_template_fields(context=context)
@@ -957,3 +959,67 @@ def test_find_mapped_dependants_in_another_group(dag_maker):
dependants = list(gen_result.operator.iter_mapped_dependants())
assert dependants == [add_result.operator]
+
+
+def get_states(dr):
+ """
+ For a given dag run, get a dict of states.
+
+ Example::
+ {
+ "my_setup": "success",
+ "my_teardown": {0: "success", 1: "success", 2: "success"},
+ "my_work": "failed",
+ }
+ """
+ ti_dict = defaultdict(dict)
+ for ti in dr.get_task_instances():
+ if ti.map_index == -1:
+ ti_dict[ti.task_id] = ti.state
+ else:
+ ti_dict[ti.task_id][ti.map_index] = ti.state
+ return dict(ti_dict)
+
+
+def test_teardown_and_fail_stop(dag_maker):
+ """
+ when fail_stop enabled, teardowns should run according to their setups.
+ in this case, the second teardown skips because its setup skips.
+ """
+
+ with dag_maker(fail_stop=True) as dag:
+ for num in (1, 2):
+ with TaskGroup(f"tg_{num}"):
+
+ @task_decorator
+ def my_setup():
+ print("setting up multiple things")
+ return [1, 2, 3]
+
+ @task_decorator
+ def my_work(val):
+ print(f"doing work with multiple things: {val}")
+ raise ValueError("this fails")
+ return val
+
+ @task_decorator
+ def my_teardown():
+ print("teardown")
+
+ s = my_setup()
+ t = my_teardown().as_teardown(setups=s)
+ with t:
+ my_work(s)
+ tg1, tg2 = dag.task_group.children.values()
+ tg1 >> tg2
+ dr = dag.test()
+ states = get_states(dr)
+ expected = {
+ "tg_1.my_setup": "success",
+ "tg_1.my_teardown": "success",
+ "tg_1.my_work": "failed",
+ "tg_2.my_setup": "skipped",
+ "tg_2.my_teardown": "skipped",
+ "tg_2.my_work": "skipped",
+ }
+ assert states == expected
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index 494c7c7c66..a4be6396be 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -1731,7 +1731,7 @@ class TestDag:
def test_dag_add_task_checks_trigger_rule(self):
# A non fail stop dag should allow any trigger rule
- from airflow.exceptions import DagInvalidTriggerRule
+ from airflow.exceptions import FailStopDagInvalidTriggerRule
from airflow.utils.trigger_rule import TriggerRule
task_with_non_default_trigger_rule = EmptyOperator(
@@ -1742,8 +1742,10 @@ class TestDag:
)
try:
non_fail_stop_dag.add_task(task_with_non_default_trigger_rule)
- except DagInvalidTriggerRule as exception:
- assert False, f"dag add_task() raises DagInvalidTriggerRule for non fail stop dag: {exception}"
+ except FailStopDagInvalidTriggerRule as exception:
+ assert (
+ False
+ ), f"dag add_task() raises FailStopDagInvalidTriggerRule for non fail stop dag: {exception}"
# a fail stop dag should allow default trigger rule
from airflow.models.abstractoperator import DEFAULT_TRIGGER_RULE
@@ -1756,13 +1758,13 @@ class TestDag:
)
try:
fail_stop_dag.add_task(task_with_default_trigger_rule)
- except DagInvalidTriggerRule as exception:
+ except FailStopDagInvalidTriggerRule as exception:
assert (
False
), f"dag.add_task() raises exception for fail-stop dag & default trigger rule: {exception}"
# a fail stop dag should not allow a non-default trigger rule
- with pytest.raises(DagInvalidTriggerRule):
+ with pytest.raises(FailStopDagInvalidTriggerRule):
fail_stop_dag.add_task(task_with_non_default_trigger_rule)
def test_dag_add_task_sets_default_task_group(self):
diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py
index d626b8499e..a7f6d0660c 100644
--- a/tests/models/test_mappedoperator.py
+++ b/tests/models/test_mappedoperator.py
@@ -666,7 +666,7 @@ class TestMappedSetupTeardown:
ti_dict[ti.task_id] = ti.state
else:
ti_dict[ti.task_id][ti.map_index] = ti.state
- return ti_dict
+ return dict(ti_dict)
def classic_operator(self, task_id, ret=None, partial=False, fail=False):
def success_callable(ret=None):
@@ -1365,3 +1365,130 @@ class TestMappedSetupTeardown:
"my_work": "upstream_failed",
}
assert states == expected
+
+ def test_one_to_many_with_teardown_and_fail_stop(self, dag_maker):
+ """
+ With fail_stop enabled, the teardown for an already-completed setup
+ should not be skipped.
+ """
+ with dag_maker(fail_stop=True) as dag:
+
+ @task
+ def my_setup():
+ print("setting up multiple things")
+ return [1, 2, 3]
+
+ @task
+ def my_work(val):
+ print(f"doing work with multiple things: {val}")
+ raise ValueError("this fails")
+ return val
+
+ @task
+ def my_teardown(val):
+ print(f"teardown: {val}")
+
+ s = my_setup()
+ t = my_teardown.expand(val=s).as_teardown(setups=s)
+ with t:
+ my_work(s)
+
+ dr = dag.test()
+ states = self.get_states(dr)
+ expected = {
+ "my_setup": "success",
+ "my_teardown": {0: "success", 1: "success", 2: "success"},
+ "my_work": "failed",
+ }
+ assert states == expected
+
+ def test_one_to_many_with_teardown_and_fail_stop_more_tasks(self, dag_maker):
+ """
+ when fail_stop enabled, teardowns should run according to their setups.
+ in this case, the second teardown skips because its setup skips.
+ """
+ with dag_maker(fail_stop=True) as dag:
+ for num in (1, 2):
+ with TaskGroup(f"tg_{num}"):
+
+ @task
+ def my_setup():
+ print("setting up multiple things")
+ return [1, 2, 3]
+
+ @task
+ def my_work(val):
+ print(f"doing work with multiple things: {val}")
+ raise ValueError("this fails")
+ return val
+
+ @task
+ def my_teardown(val):
+ print(f"teardown: {val}")
+
+ s = my_setup()
+ t = my_teardown.expand(val=s).as_teardown(setups=s)
+ with t:
+ my_work(s)
+ tg1, tg2 = dag.task_group.children.values()
+ tg1 >> tg2
+ dr = dag.test()
+ states = self.get_states(dr)
+ expected = {
+ "tg_1.my_setup": "success",
+ "tg_1.my_teardown": {0: "success", 1: "success", 2: "success"},
+ "tg_1.my_work": "failed",
+ "tg_2.my_setup": "skipped",
+ "tg_2.my_teardown": "skipped",
+ "tg_2.my_work": "skipped",
+ }
+ assert states == expected
+
+ def test_one_to_many_with_teardown_and_fail_stop_more_tasks_mapped_setup(self, dag_maker):
+ """
+ when fail_stop enabled, teardowns should run according to their setups.
+ in this case, the second teardown skips because its setup skips.
+ """
+ with dag_maker(fail_stop=True) as dag:
+ for num in (1, 2):
+ with TaskGroup(f"tg_{num}"):
+
+ @task
+ def my_pre_setup():
+ print("input to the setup")
+ return [1, 2, 3]
+
+ @task
+ def my_setup(val):
+ print("setting up multiple things")
+ return val
+
+ @task
+ def my_work(val):
+ print(f"doing work with multiple things: {val}")
+ raise ValueError("this fails")
+ return val
+
+ @task
+ def my_teardown(val):
+ print(f"teardown: {val}")
+
+ s = my_setup.expand(val=my_pre_setup())
+ t = my_teardown.expand(val=s).as_teardown(setups=s)
+ with t:
+ my_work(s)
+ tg1, tg2 = dag.task_group.children.values()
+ tg1 >> tg2
+ dr = dag.test()
+ states = self.get_states(dr)
+ expected = {
+ "tg_1.my_pre_setup": "success",
+ "tg_1.my_setup": {0: "success", 1: "success", 2: "success"},
+ "tg_1.my_teardown": {0: "success", 1: "success", 2: "success"},
+ "tg_1.my_work": "failed",
+ "tg_2.my_pre_setup": "skipped",
+ "tg_2.my_setup": "skipped",
+ "tg_2.my_teardown": "skipped",
+ "tg_2.my_work": "skipped",
+ }
+ assert states == expected