You are viewing a plain text version of this content. The canonical link for it (a hyperlink labelled "here") is not preserved in this plain-text version.
Posted to commits@airflow.apache.org by ep...@apache.org on 2023/03/07 20:52:40 UTC

[airflow] branch v2-5-test updated: Fix on_failure_callback when task receives a SIGTERM (#29743)

This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v2-5-test
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/v2-5-test by this push:
     new 0a6cc4c4d3 Fix on_failure_callback when task receives a SIGTERM (#29743)
0a6cc4c4d3 is described below

commit 0a6cc4c4d397a966552b1e33213db4ca505f581b
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Sat Feb 25 00:08:32 2023 +0100

    Fix on_failure_callback when task receives a SIGTERM (#29743)
    
    This fixes on_failure_callback handling when a task receives a SIGTERM.
    The SIGTERM handler now raises a dedicated exception (AirflowTermSignal),
    which is caught during task execution so that on_kill and the failure
    callback can be run directly before an AirflowException is raised.
    
    (cherry picked from commit 671b88eb3423e86bb331eaf7829659080cbd184e)
---
 airflow/exceptions.py             |  6 ++++++
 airflow/models/taskinstance.py    | 17 +++++++++++------
 tests/models/test_taskinstance.py | 22 ++++++++++++++++++++++
 3 files changed, 39 insertions(+), 6 deletions(-)

diff --git a/airflow/exceptions.py b/airflow/exceptions.py
index 4bf946fd8e..e6ef9bd4e1 100644
--- a/airflow/exceptions.py
+++ b/airflow/exceptions.py
@@ -29,6 +29,12 @@ if TYPE_CHECKING:
     from airflow.models import DagRun
 
 
+class AirflowTermSignal(Exception):
+    """Raise when we receive a TERM signal"""
+
+    status_code = HTTPStatus.INTERNAL_SERVER_ERROR
+
+
 class AirflowException(Exception):
     """
     Base class for all Airflow's errors.
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 1eb778df55..ad8d7b614b 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -75,6 +75,7 @@ from airflow.exceptions import (
     AirflowSensorTimeout,
     AirflowSkipException,
     AirflowTaskTimeout,
+    AirflowTermSignal,
     DagRunNotFound,
     RemovedInAirflow3Warning,
     TaskDeferralError,
@@ -1471,8 +1472,7 @@ class TaskInstance(Base, LoggingMixin):
                 os._exit(1)
                 return
             self.log.error("Received SIGTERM. Terminating subprocesses.")
-            self.task.on_kill()
-            raise AirflowException("Task received SIGTERM signal")
+            raise AirflowTermSignal("Task received SIGTERM signal")
 
         signal.signal(signal.SIGTERM, signal_handler)
 
@@ -1511,10 +1511,15 @@ class TaskInstance(Base, LoggingMixin):
 
             # Execute the task
             with set_current_context(context):
-                result = self._execute_task(context, task_orig)
-
-            # Run post_execute callback
-            self.task.post_execute(context=context, result=result)
+                try:
+                    result = self._execute_task(context, task_orig)
+                    # Run post_execute callback
+                    self.task.post_execute(context=context, result=result)
+                except AirflowTermSignal:
+                    self.task.on_kill()
+                    if self.task.on_failure_callback:
+                        self._run_finished_callback(self.task.on_failure_callback, context, "on_failure")
+                    raise AirflowException("Task received SIGTERM signal")
 
         Stats.incr(f"operator_successes_{self.task.task_type}", 1, 1)
         Stats.incr("ti_successes")
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index 9b126c68c9..80b6d6d302 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -471,6 +471,28 @@ class TestTaskInstance:
         ti.refresh_from_db()
         assert ti.state == State.UP_FOR_RETRY
 
+    def test_task_sigterm_calls_on_failure_callack(self, dag_maker, caplog):
+        """
+        Test that ensures that tasks call on_failure_callback when they receive sigterm
+        """
+
+        def task_function(ti):
+            os.kill(ti.pid, signal.SIGTERM)
+
+        with dag_maker():
+            task_ = PythonOperator(
+                task_id="test_on_failure",
+                python_callable=task_function,
+                on_failure_callback=lambda context: context["ti"].log.info("on_failure_callback called"),
+            )
+
+        dr = dag_maker.create_dagrun()
+        ti = dr.task_instances[0]
+        ti.task = task_
+        with pytest.raises(AirflowException):
+            ti.run()
+        assert "on_failure_callback called" in caplog.text
+
     @pytest.mark.parametrize("state", [State.SUCCESS, State.FAILED, State.SKIPPED])
     def test_task_sigterm_doesnt_change_state_of_finished_tasks(self, state, dag_maker):
         session = settings.Session()