You are viewing a plain-text version of this content. (The canonical link to it was not preserved in this copy.)
Posted to commits@airflow.apache.org by ep...@apache.org on 2023/03/07 20:52:40 UTC
[airflow] branch v2-5-test updated: Fix on_failure_callback when task receives a SIGTERM (#29743)
This is an automated email from the ASF dual-hosted git repository.
ephraimanierobi pushed a commit to branch v2-5-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/v2-5-test by this push:
new 0a6cc4c4d3 Fix on_failure_callback when task receives a SIGTERM (#29743)
0a6cc4c4d3 is described below
commit 0a6cc4c4d397a966552b1e33213db4ca505f581b
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Sat Feb 25 00:08:32 2023 +0100
Fix on_failure_callback when task receives a SIGTERM (#29743)
This fixes on_failure_callback when a task receives a SIGTERM. The signal
handler now raises a dedicated exception (AirflowTermSignal), which is
caught during task execution so that on_kill and the failure callback can
be run directly before an AirflowException is raised.
(cherry picked from commit 671b88eb3423e86bb331eaf7829659080cbd184e)
---
airflow/exceptions.py | 6 ++++++
airflow/models/taskinstance.py | 17 +++++++++++------
tests/models/test_taskinstance.py | 22 ++++++++++++++++++++++
3 files changed, 39 insertions(+), 6 deletions(-)
diff --git a/airflow/exceptions.py b/airflow/exceptions.py
index 4bf946fd8e..e6ef9bd4e1 100644
--- a/airflow/exceptions.py
+++ b/airflow/exceptions.py
@@ -29,6 +29,12 @@ if TYPE_CHECKING:
from airflow.models import DagRun
+class AirflowTermSignal(Exception):
+ """Raise when we receive a TERM signal"""
+
+ status_code = HTTPStatus.INTERNAL_SERVER_ERROR
+
+
class AirflowException(Exception):
"""
Base class for all Airflow's errors.
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 1eb778df55..ad8d7b614b 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -75,6 +75,7 @@ from airflow.exceptions import (
AirflowSensorTimeout,
AirflowSkipException,
AirflowTaskTimeout,
+ AirflowTermSignal,
DagRunNotFound,
RemovedInAirflow3Warning,
TaskDeferralError,
@@ -1471,8 +1472,7 @@ class TaskInstance(Base, LoggingMixin):
os._exit(1)
return
self.log.error("Received SIGTERM. Terminating subprocesses.")
- self.task.on_kill()
- raise AirflowException("Task received SIGTERM signal")
+ raise AirflowTermSignal("Task received SIGTERM signal")
signal.signal(signal.SIGTERM, signal_handler)
@@ -1511,10 +1511,15 @@ class TaskInstance(Base, LoggingMixin):
# Execute the task
with set_current_context(context):
- result = self._execute_task(context, task_orig)
-
- # Run post_execute callback
- self.task.post_execute(context=context, result=result)
+ try:
+ result = self._execute_task(context, task_orig)
+ # Run post_execute callback
+ self.task.post_execute(context=context, result=result)
+ except AirflowTermSignal:
+ self.task.on_kill()
+ if self.task.on_failure_callback:
+ self._run_finished_callback(self.task.on_failure_callback, context, "on_failure")
+ raise AirflowException("Task received SIGTERM signal")
Stats.incr(f"operator_successes_{self.task.task_type}", 1, 1)
Stats.incr("ti_successes")
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index 9b126c68c9..80b6d6d302 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -471,6 +471,28 @@ class TestTaskInstance:
ti.refresh_from_db()
assert ti.state == State.UP_FOR_RETRY
+ def test_task_sigterm_calls_on_failure_callack(self, dag_maker, caplog):
+ """
+ Test that ensures that tasks call on_failure_callback when they receive sigterm
+ """
+
+ def task_function(ti):
+ os.kill(ti.pid, signal.SIGTERM)
+
+ with dag_maker():
+ task_ = PythonOperator(
+ task_id="test_on_failure",
+ python_callable=task_function,
+ on_failure_callback=lambda context: context["ti"].log.info("on_failure_callback called"),
+ )
+
+ dr = dag_maker.create_dagrun()
+ ti = dr.task_instances[0]
+ ti.task = task_
+ with pytest.raises(AirflowException):
+ ti.run()
+ assert "on_failure_callback called" in caplog.text
+
@pytest.mark.parametrize("state", [State.SUCCESS, State.FAILED, State.SKIPPED])
def test_task_sigterm_doesnt_change_state_of_finished_tasks(self, state, dag_maker):
session = settings.Session()