You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by jh...@apache.org on 2021/08/12 00:59:47 UTC

[airflow] branch v2-1-test updated (20ed40b -> fe70111)

This is an automated email from the ASF dual-hosted git repository.

jhtimmins pushed a change to branch v2-1-test
in repository https://gitbox.apache.org/repos/asf/airflow.git.


    from 20ed40b  Better diagnostics and self-healing of docker-compose (#17484)
     new 26e8d2d  Fix bug that log can't be shown when task runs failed (#16768)
     new 1be2003  Fix task retries when they receive sigkill and have retries and properly handle sigterm (#16301)
     new fe70111  Add 'queued' to DagRunState (#16854)

The 3 revisions listed above as "new" are entirely new to this
repository and will be described in separate emails.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.


Summary of changes:
 airflow/jobs/local_task_job.py         |  24 ++--
 airflow/jobs/scheduler_job.py          |  19 ++-
 airflow/models/dag.py                  |  10 +-
 airflow/models/dagrun.py               |  16 +--
 airflow/models/taskinstance.py         |   4 +-
 airflow/utils/state.py                 |   2 +-
 tests/dag_processing/test_processor.py |  23 ++++
 tests/jobs/test_local_task_job.py      | 206 +++++++++++++++++++++++++++++----
 tests/models/test_taskinstance.py      |  32 +++++
 9 files changed, 276 insertions(+), 60 deletions(-)

[airflow] 03/03: Add 'queued' to DagRunState (#16854)

Posted by jh...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

jhtimmins pushed a commit to branch v2-1-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit fe7011100e4c2e789d2ca185e2733373414bd3ae
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Wed Jul 7 21:37:22 2021 +0100

    Add 'queued' to DagRunState (#16854)
    
    This change adds 'queued' to DagRunState and improves typing for DagRun state.
    
    Co-authored-by: Kaxil Naik <ka...@gmail.com>
    (cherry picked from commit 5a5f30f9133a6c5f0c41886ff9ae80ea53c73989)
---
 airflow/jobs/scheduler_job.py  |  8 ++++----
 airflow/models/dag.py          | 10 +++++-----
 airflow/models/dagrun.py       | 16 +++++++++-------
 airflow/models/taskinstance.py |  4 ++--
 airflow/utils/state.py         |  2 +-
 5 files changed, 21 insertions(+), 19 deletions(-)

diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py
index 6d37a2b..9298b82 100644
--- a/airflow/jobs/scheduler_job.py
+++ b/airflow/jobs/scheduler_job.py
@@ -64,7 +64,7 @@ from airflow.utils.mixins import MultiprocessingStartMethodMixin
 from airflow.utils.retries import MAX_DB_RETRIES, retry_db_transaction, run_with_db_retries
 from airflow.utils.session import create_session, provide_session
 from airflow.utils.sqlalchemy import is_lock_not_available_error, prohibit_commit, skip_locked, with_row_locks
-from airflow.utils.state import State
+from airflow.utils.state import DagRunState, State, TaskInstanceState
 from airflow.utils.types import DagRunType
 
 TI = models.TaskInstance
@@ -796,7 +796,7 @@ class SchedulerJob(BaseJob):
 
     @provide_session
     def _change_state_for_tis_without_dagrun(
-        self, old_states: List[str], new_state: str, session: Session = None
+        self, old_states: List[TaskInstanceState], new_state: TaskInstanceState, session: Session = None
     ) -> None:
         """
         For all DAG IDs in the DagBag, look for task instances in the
@@ -870,7 +870,7 @@ class SchedulerJob(BaseJob):
 
     @provide_session
     def __get_concurrency_maps(
-        self, states: List[str], session: Session = None
+        self, states: List[TaskInstanceState], session: Session = None
     ) -> Tuple[DefaultDict[str, int], DefaultDict[Tuple[str, str], int]]:
         """
         Get the concurrency maps.
@@ -1546,7 +1546,7 @@ class SchedulerJob(BaseJob):
             return num_queued_tis
 
     @retry_db_transaction
-    def _get_next_dagruns_to_examine(self, state, session):
+    def _get_next_dagruns_to_examine(self, state: DagRunState, session: Session):
         """Get Next DagRuns to Examine with retries"""
         return DagRun.next_dagruns_to_examine(state, session)
 
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index a3d06db..2e66b40 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -75,7 +75,7 @@ from airflow.utils.helpers import validate_key
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.session import provide_session
 from airflow.utils.sqlalchemy import Interval, UtcDateTime, skip_locked, with_row_locks
-from airflow.utils.state import State
+from airflow.utils.state import DagRunState, State
 from airflow.utils.types import DagRunType, EdgeInfoType
 
 if TYPE_CHECKING:
@@ -1153,7 +1153,7 @@ class DAG(LoggingMixin):
         confirm_prompt=False,
         include_subdags=True,
         include_parentdag=True,
-        dag_run_state: str = State.QUEUED,
+        dag_run_state: DagRunState = DagRunState.QUEUED,
         dry_run=False,
         session=None,
         get_tis=False,
@@ -1369,7 +1369,7 @@ class DAG(LoggingMixin):
         confirm_prompt=False,
         include_subdags=True,
         include_parentdag=False,
-        dag_run_state=State.QUEUED,
+        dag_run_state=DagRunState.QUEUED,
         dry_run=False,
     ):
         all_tis = []
@@ -1731,7 +1731,7 @@ class DAG(LoggingMixin):
     @provide_session
     def create_dagrun(
         self,
-        state: State,
+        state: DagRunState,
         execution_date: Optional[datetime] = None,
         run_id: Optional[str] = None,
         start_date: Optional[datetime] = None,
@@ -1753,7 +1753,7 @@ class DAG(LoggingMixin):
         :param execution_date: the execution date of this dag run
         :type execution_date: datetime.datetime
         :param state: the state of the dag run
-        :type state: airflow.utils.state.State
+        :type state: airflow.utils.state.DagRunState
         :param start_date: the date this dag run should be evaluated
         :type start_date: datetime
         :param external_trigger: whether this dag run is externally triggered
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index c503ac4..66be9a3 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -36,7 +36,7 @@ from airflow.utils import callback_requests, timezone
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.session import provide_session
 from airflow.utils.sqlalchemy import UtcDateTime, nulls_first, skip_locked, with_row_locks
-from airflow.utils.state import State
+from airflow.utils.state import DagRunState, State
 from airflow.utils.types import DagRunType
 
 if TYPE_CHECKING:
@@ -110,7 +110,7 @@ class DagRun(Base, LoggingMixin):
         start_date: Optional[datetime] = None,
         external_trigger: Optional[bool] = None,
         conf: Optional[Any] = None,
-        state: Optional[str] = None,
+        state: Optional[DagRunState] = None,
         run_type: Optional[str] = None,
         dag_hash: Optional[str] = None,
         creating_job_id: Optional[int] = None,
@@ -144,7 +144,7 @@ class DagRun(Base, LoggingMixin):
     def get_state(self):
         return self._state
 
-    def set_state(self, state):
+    def set_state(self, state: DagRunState):
         if self._state != state:
             self._state = state
             self.end_date = timezone.utcnow() if self._state in State.finished else None
@@ -170,7 +170,7 @@ class DagRun(Base, LoggingMixin):
     @classmethod
     def next_dagruns_to_examine(
         cls,
-        state: str,
+        state: DagRunState,
         session: Session,
         max_number: Optional[int] = None,
     ):
@@ -219,7 +219,7 @@ class DagRun(Base, LoggingMixin):
         dag_id: Optional[Union[str, List[str]]] = None,
         run_id: Optional[str] = None,
         execution_date: Optional[datetime] = None,
-        state: Optional[str] = None,
+        state: Optional[DagRunState] = None,
         external_trigger: Optional[bool] = None,
         no_backfills: bool = False,
         run_type: Optional[DagRunType] = None,
@@ -239,7 +239,7 @@ class DagRun(Base, LoggingMixin):
         :param execution_date: the execution date
         :type execution_date: datetime.datetime or list[datetime.datetime]
         :param state: the state of the dag run
-        :type state: str
+        :type state: DagRunState
         :param external_trigger: whether this dag run is externally triggered
         :type external_trigger: bool
         :param no_backfills: return no backfills (True), return all (False).
@@ -341,7 +341,9 @@ class DagRun(Base, LoggingMixin):
         return self.dag
 
     @provide_session
-    def get_previous_dagrun(self, state: Optional[str] = None, session: Session = None) -> Optional['DagRun']:
+    def get_previous_dagrun(
+        self, state: Optional[DagRunState] = None, session: Session = None
+    ) -> Optional['DagRun']:
         """The previous DagRun, if there is one"""
         filters = [
             DagRun.dag_id == self.dag_id,
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 0e10567..c715f22 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -70,7 +70,7 @@ from airflow.utils.operator_helpers import context_to_airflow_vars
 from airflow.utils.platform import getuser
 from airflow.utils.session import provide_session
 from airflow.utils.sqlalchemy import UtcDateTime
-from airflow.utils.state import State
+from airflow.utils.state import DagRunState, State
 from airflow.utils.timeout import timeout
 
 try:
@@ -137,7 +137,7 @@ def clear_task_instances(
     session,
     activate_dag_runs=None,
     dag=None,
-    dag_run_state: Union[str, Literal[False]] = State.QUEUED,
+    dag_run_state: Union[DagRunState, Literal[False]] = DagRunState.QUEUED,
 ):
     """
     Clears a set of task instances, but makes sure the running ones
diff --git a/airflow/utils/state.py b/airflow/utils/state.py
index d5300e1..5ffcbd7 100644
--- a/airflow/utils/state.py
+++ b/airflow/utils/state.py
@@ -33,7 +33,6 @@ class State:
     # set by the executor (t.b.d.)
     # LAUNCHED = "launched"
 
-    # set by a task
     QUEUED = "queued"
     RUNNING = "running"
     SUCCESS = "success"
@@ -64,6 +63,7 @@ class State:
         SUCCESS,
         RUNNING,
         FAILED,
+        QUEUED,
     )
 
     state_color = {

[airflow] 02/03: Fix task retries when they receive sigkill and have retries and properly handle sigterm (#16301)

Posted by jh...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

jhtimmins pushed a commit to branch v2-1-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 1be2003b69b6f0cdd58d0743640bdb797acdfac4
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Wed Jul 28 15:57:35 2021 +0100

    Fix task retries when they receive sigkill and have retries and properly handle sigterm (#16301)
    
    Currently, tasks are not retried when they receive SIGKILL or SIGTERM, even if the task has retries configured. This change
    fixes that and adds tests for both SIGTERM and SIGKILL to guard against regressions.
    
    Also, on SIGTERM the task was set to failed and an AirflowException was raised, which the heartbeat sometimes interpreted
    as the task being externally marked as failed, so failure_callbacks were not called. This commit fixes that by calling
    handle_task_exit when a task receives SIGTERM.
    
    Co-authored-by: Ash Berlin-Taylor <as...@firemirror.com>
    (cherry picked from commit 4e2a94c6d1bde5ddf2aa0251190c318ac22f3b17)
---
 airflow/jobs/local_task_job.py    |  24 ++---
 tests/jobs/test_local_task_job.py | 206 ++++++++++++++++++++++++++++++++++----
 tests/models/test_taskinstance.py |  32 ++++++
 3 files changed, 228 insertions(+), 34 deletions(-)

diff --git a/airflow/jobs/local_task_job.py b/airflow/jobs/local_task_job.py
index e7f61f1..d82e115 100644
--- a/airflow/jobs/local_task_job.py
+++ b/airflow/jobs/local_task_job.py
@@ -78,12 +78,9 @@ class LocalTaskJob(BaseJob):
         def signal_handler(signum, frame):
             """Setting kill signal handler"""
             self.log.error("Received SIGTERM. Terminating subprocesses")
-            self.on_kill()
-            self.task_instance.refresh_from_db()
-            if self.task_instance.state not in State.finished:
-                self.task_instance.set_state(State.FAILED)
-            self.task_instance._run_finished_callback(error="task received sigterm")
-            raise AirflowException("LocalTaskJob received SIGTERM signal")
+            self.task_runner.terminate()
+            self.handle_task_exit(128 + signum)
+            return
 
         signal.signal(signal.SIGTERM, signal_handler)
 
@@ -148,16 +145,19 @@ class LocalTaskJob(BaseJob):
             self.on_kill()
 
     def handle_task_exit(self, return_code: int) -> None:
-        """Handle case where self.task_runner exits by itself"""
+        """Handle case where self.task_runner exits by itself or is externally killed"""
+        # Without setting this, heartbeat may get us
+        self.terminating = True
         self.log.info("Task exited with return code %s", return_code)
         self.task_instance.refresh_from_db()
-        # task exited by itself, so we need to check for error file
+
+        if self.task_instance.state == State.RUNNING:
+            # This is for a case where the task received a SIGKILL
+            # while running or the task runner received a sigterm
+            self.task_instance.handle_failure(error=None)
+        # We need to check for error file
         # in case it failed due to runtime exception/error
         error = None
-        if self.task_instance.state == State.RUNNING:
-            # This is for a case where the task received a sigkill
-            # while running
-            self.task_instance.set_state(State.FAILED)
         if self.task_instance.state != State.SUCCESS:
             error = self.task_runner.deserialize_run_error()
         self.task_instance._run_finished_callback(error=error)
diff --git a/tests/jobs/test_local_task_job.py b/tests/jobs/test_local_task_job.py
index b985cf6..8f6cba3 100644
--- a/tests/jobs/test_local_task_job.py
+++ b/tests/jobs/test_local_task_job.py
@@ -22,6 +22,7 @@ import signal
 import time
 import unittest
 import uuid
+from datetime import timedelta
 from multiprocessing import Lock, Value
 from unittest import mock
 from unittest.mock import patch
@@ -286,7 +287,6 @@ class TestLocalTaskJob(unittest.TestCase):
                 delta = (time2 - time1).total_seconds()
                 assert abs(delta - job.heartrate) < 0.5
 
-    @pytest.mark.quarantined
     def test_mark_success_no_kill(self):
         """
         Test that ensures that mark_success in the UI doesn't cause
@@ -314,7 +314,6 @@ class TestLocalTaskJob(unittest.TestCase):
         job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True)
         process = multiprocessing.Process(target=job1.run)
         process.start()
-        ti.refresh_from_db()
         for _ in range(0, 50):
             if ti.state == State.RUNNING:
                 break
@@ -543,8 +542,7 @@ class TestLocalTaskJob(unittest.TestCase):
         assert ti.state == State.FAILED  # task exits with failure state
         assert failure_callback_called.value == 1
 
-    @pytest.mark.quarantined
-    def test_mark_success_on_success_callback(self):
+    def test_mark_success_on_success_callback(self, dag_maker):
         """
         Test that ensures that where a task is marked success in the UI
         on_success_callback gets executed
@@ -610,15 +608,9 @@ class TestLocalTaskJob(unittest.TestCase):
         assert task_terminated_externally.value == 1
         assert not process.is_alive()
 
-    @parameterized.expand(
-        [
-            (signal.SIGTERM,),
-            (signal.SIGKILL,),
-        ]
-    )
-    def test_process_kill_calls_on_failure_callback(self, signal_type):
+    def test_task_sigkill_calls_on_failure_callback(self, dag_maker):
         """
-        Test that ensures that when a task is killed with sigterm or sigkill
+        Test that ensures that when a task is killed with sigkill
         on_failure_callback gets executed
         """
         # use shared memory value so we can properly track value change even if
@@ -630,12 +622,52 @@ class TestLocalTaskJob(unittest.TestCase):
         def failure_callback(context):
             with shared_mem_lock:
                 failure_callback_called.value += 1
-            assert context['dag_run'].dag_id == 'test_mark_failure'
+            assert context['dag_run'].dag_id == 'test_send_sigkill'
 
         dag = DAG(dag_id='test_mark_failure', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
 
         def task_function(ti):
+            os.kill(os.getpid(), signal.SIGKILL)
+            # This should not happen -- the state change should be noticed and the task should get killed
+            with shared_mem_lock:
+                task_terminated_externally.value = 0
+
+        with dag_maker(dag_id='test_send_sigkill'):
+            task = PythonOperator(
+                task_id='test_on_failure',
+                python_callable=task_function,
+                on_failure_callback=failure_callback,
+            )
+
+        ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
+        ti.refresh_from_db()
+        job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())
+        settings.engine.dispose()
+        process = multiprocessing.Process(target=job1.run)
+        process.start()
+        time.sleep(0.3)
+        process.join(timeout=10)
+        assert failure_callback_called.value == 1
+        assert task_terminated_externally.value == 1
+        assert not process.is_alive()
+
+    def test_process_sigterm_calls_on_failure_callback(self, dag_maker):
+        """
+        Test that ensures that when a task runner is killed with sigterm
+        on_failure_callback gets executed
+        """
+        # use shared memory value so we can properly track value change even if
+        # it's been updated across processes.
+        failure_callback_called = Value('i', 0)
+        task_terminated_externally = Value('i', 1)
+        shared_mem_lock = Lock()
+
+        def failure_callback(context):
+            with shared_mem_lock:
+                failure_callback_called.value += 1
+            assert context['dag_run'].dag_id == 'test_mark_failure'
 
+        def task_function(ti):
             time.sleep(60)
             # This should not happen -- the state change should be noticed and the task should get killed
             with shared_mem_lock:
@@ -661,20 +693,16 @@ class TestLocalTaskJob(unittest.TestCase):
         ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
         ti.refresh_from_db()
         job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())
-        job1.task_runner = StandardTaskRunner(job1)
-
         settings.engine.dispose()
         process = multiprocessing.Process(target=job1.run)
         process.start()
-
-        for _ in range(0, 20):
+        for _ in range(0, 25):
             ti.refresh_from_db()
-            if ti.state == State.RUNNING and ti.pid is not None:
+            if ti.state == State.RUNNING:
                 break
             time.sleep(0.2)
-        assert ti.pid is not None
-        assert ti.state == State.RUNNING
-        os.kill(ti.pid, signal_type)
+        os.kill(process.pid, signal.SIGTERM)
+        ti.refresh_from_db()
         process.join(timeout=10)
         assert failure_callback_called.value == 1
         assert task_terminated_externally.value == 1
@@ -819,6 +847,140 @@ class TestLocalTaskJob(unittest.TestCase):
             if scheduler_job.processor_agent:
                 scheduler_job.processor_agent.end()
 
+    def test_task_sigkill_works_with_retries(self, dag_maker):
+        """
+        Test that ensures that tasks are retried when they receive sigkill
+        """
+        # use shared memory value so we can properly track value change even if
+        # it's been updated across processes.
+        retry_callback_called = Value('i', 0)
+        task_terminated_externally = Value('i', 1)
+        shared_mem_lock = Lock()
+
+        def retry_callback(context):
+            with shared_mem_lock:
+                retry_callback_called.value += 1
+            assert context['dag_run'].dag_id == 'test_mark_failure_2'
+
+        def task_function(ti):
+            os.kill(os.getpid(), signal.SIGKILL)
+            # This should not happen -- the state change should be noticed and the task should get killed
+            with shared_mem_lock:
+                task_terminated_externally.value = 0
+
+        with dag_maker(
+            dag_id='test_mark_failure_2', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}
+        ):
+            task = PythonOperator(
+                task_id='test_on_failure',
+                python_callable=task_function,
+                retries=1,
+                retry_delay=timedelta(seconds=2),
+                on_retry_callback=retry_callback,
+            )
+        ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
+        ti.refresh_from_db()
+        job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())
+        job1.task_runner = StandardTaskRunner(job1)
+        job1.task_runner.start()
+        settings.engine.dispose()
+        process = multiprocessing.Process(target=job1.run)
+        process.start()
+        time.sleep(0.4)
+        process.join()
+        ti.refresh_from_db()
+        assert ti.state == State.UP_FOR_RETRY
+        assert retry_callback_called.value == 1
+        assert task_terminated_externally.value == 1
+
+    def test_process_sigterm_works_with_retries(self, dag_maker):
+        """
+        Test that ensures that task runner sets tasks to retry when they(task runner)
+         receive sigterm
+        """
+        # use shared memory value so we can properly track value change even if
+        # it's been updated across processes.
+        retry_callback_called = Value('i', 0)
+        task_terminated_externally = Value('i', 1)
+        shared_mem_lock = Lock()
+
+        def retry_callback(context):
+            with shared_mem_lock:
+                retry_callback_called.value += 1
+            assert context['dag_run'].dag_id == 'test_mark_failure_2'
+
+        def task_function(ti):
+            time.sleep(60)
+            # This should not happen -- the state change should be noticed and the task should get killed
+            with shared_mem_lock:
+                task_terminated_externally.value = 0
+
+        with dag_maker(dag_id='test_mark_failure_2'):
+            task = PythonOperator(
+                task_id='test_on_failure',
+                python_callable=task_function,
+                retries=1,
+                retry_delay=timedelta(seconds=2),
+                on_retry_callback=retry_callback,
+            )
+        ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
+        ti.refresh_from_db()
+        job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())
+        job1.task_runner = StandardTaskRunner(job1)
+        job1.task_runner.start()
+        settings.engine.dispose()
+        process = multiprocessing.Process(target=job1.run)
+        process.start()
+        for _ in range(0, 25):
+            ti.refresh_from_db()
+            if ti.state == State.RUNNING and ti.pid is not None:
+                break
+            time.sleep(0.2)
+        os.kill(process.pid, signal.SIGTERM)
+        process.join()
+        ti.refresh_from_db()
+        assert ti.state == State.UP_FOR_RETRY
+        assert retry_callback_called.value == 1
+        assert task_terminated_externally.value == 1
+
+    def test_task_exit_should_update_state_of_finished_dagruns_with_dag_paused(self):
+        """Test that with DAG paused, DagRun state will update when the tasks finishes the run"""
+        dag = DAG(dag_id='test_dags', start_date=DEFAULT_DATE)
+        op1 = PythonOperator(task_id='dummy', dag=dag, owner='airflow', python_callable=lambda: True)
+
+        session = settings.Session()
+        orm_dag = DagModel(
+            dag_id=dag.dag_id,
+            has_task_concurrency_limits=False,
+            next_dagrun=dag.start_date,
+            next_dagrun_create_after=dag.following_schedule(DEFAULT_DATE),
+            is_active=True,
+            is_paused=True,
+        )
+        session.add(orm_dag)
+        session.flush()
+        # Write Dag to DB
+        dagbag = DagBag(dag_folder="/dev/null", include_examples=False, read_dags_from_db=False)
+        dagbag.bag_dag(dag, root_dag=dag)
+        dagbag.sync_to_db()
+
+        dr = dag.create_dagrun(
+            run_type=DagRunType.SCHEDULED,
+            state=State.RUNNING,
+            execution_date=DEFAULT_DATE,
+            start_date=DEFAULT_DATE,
+            session=session,
+        )
+
+        assert dr.state == State.RUNNING
+        ti = TaskInstance(op1, dr.execution_date)
+        job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())
+        job1.task_runner = StandardTaskRunner(job1)
+        job1.run()
+        session.add(dr)
+        session.refresh(dr)
+        assert dr.state == State.SUCCESS
+
 
 @pytest.fixture()
 def clean_db_helper():
@@ -844,5 +1006,5 @@ class TestLocalTaskJobPerformance:
         mock_get_task_runner.return_value.return_code.side_effects = return_codes
 
         job = LocalTaskJob(task_instance=ti, executor=MockExecutor())
-        with assert_queries_count(16):
+        with assert_queries_count(18):
             job.run()
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index c1882e1..db23271 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -18,6 +18,7 @@
 
 import datetime
 import os
+import signal
 import time
 import unittest
 import urllib
@@ -522,6 +523,37 @@ class TestTaskInstance(unittest.TestCase):
         ti.run()
         assert State.SKIPPED == ti.state
 
+    def test_task_sigterm_works_with_retries(self):
+        """
+        Test that ensures that tasks are retried when they receive sigterm
+        """
+        dag = DAG(dag_id='test_mark_failure_2', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
+
+        def task_function(ti):
+            # pylint: disable=unused-argument
+            os.kill(ti.pid, signal.SIGTERM)
+
+        task = PythonOperator(
+            task_id='test_on_failure',
+            python_callable=task_function,
+            retries=1,
+            retry_delay=datetime.timedelta(seconds=2),
+            dag=dag,
+        )
+
+        dag.create_dagrun(
+            run_id="test",
+            state=State.RUNNING,
+            execution_date=DEFAULT_DATE,
+            start_date=DEFAULT_DATE,
+        )
+        ti = TI(task=task, execution_date=DEFAULT_DATE)
+        ti.refresh_from_db()
+        with self.assertRaises(AirflowException):
+            ti.run()
+        ti.refresh_from_db()
+        assert ti.state == State.UP_FOR_RETRY
+
     def test_retry_delay(self):
         """
         Test that retry delays are respected

[airflow] 01/03: Fix bug that log can't be shown when task runs failed (#16768)

Posted by jh...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

jhtimmins pushed a commit to branch v2-1-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 26e8d2d77aeb51b8dbfb758c477417ccd0c45063
Author: Zhanfeng Huo <hu...@gmail.com>
AuthorDate: Fri Jul 23 06:13:40 2021 +0800

    Fix bug that log can't be shown when task runs failed (#16768)
    
    When a task run fails, its log cannot be displayed; users only see unhelpful output like the following (#13692):
    
    <pre>
    *** Log file does not exist: /home/airflow/airflow/logs/dag_id/task_id/2021-06-28T00:00:00+08:00/28.log
    *** Fetching from: http://:8793/log/dag_id/task_id/2021-06-28T00:00:00+08:00/28.log
    *** Failed to fetch log file from worker. Unsupported URL protocol
    </pre>
    
    The root cause is that, when a task fails, the scheduler overwrites the hostname stored in the task_instance table with an empty string while running `_execute_task_callbacks`. Because the hostname is lost, the webserver can no longer look up the host that ran the task and therefore cannot fetch its log.
    
    Co-authored-by: huozhanfeng <hu...@vipkid.cn>
    (cherry picked from commit 34478c26d7de1328797e03bbf96d8261796fccbb)
---
 airflow/jobs/scheduler_job.py          | 11 ++++-------
 tests/dag_processing/test_processor.py | 23 +++++++++++++++++++++++
 2 files changed, 27 insertions(+), 7 deletions(-)

diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py
index 198635c..6d37a2b 100644
--- a/airflow/jobs/scheduler_job.py
+++ b/airflow/jobs/scheduler_job.py
@@ -588,14 +588,11 @@ class DagFileProcessor(LoggingMixin):
             if simple_ti.task_id in dag.task_ids:
                 task = dag.get_task(simple_ti.task_id)
                 ti = TI(task, simple_ti.execution_date)
-                # Get properties needed for failure handling from SimpleTaskInstance.
-                ti.start_date = simple_ti.start_date
-                ti.end_date = simple_ti.end_date
-                ti.try_number = simple_ti.try_number
-                ti.state = simple_ti.state
-                ti.test_mode = self.UNIT_TEST_MODE
                 if request.is_failure_callback:
-                    ti.handle_failure_with_callback(error=request.msg, test_mode=ti.test_mode)
+                    ti = TI(task, simple_ti.execution_date)
+                    # TODO: Use simple_ti to improve performance here in the future
+                    ti.refresh_from_db()
+                    ti.handle_failure_with_callback(error=request.msg, test_mode=self.UNIT_TEST_MODE)
                     self.log.info('Executed failure callback for %s in state %s', ti, ti.state)
 
     @provide_session
diff --git a/tests/dag_processing/test_processor.py b/tests/dag_processing/test_processor.py
index 243afc7..b6b9589 100644
--- a/tests/dag_processing/test_processor.py
+++ b/tests/dag_processing/test_processor.py
@@ -644,6 +644,29 @@ class TestDagFileProcessor(unittest.TestCase):
                 test_mode=conf.getboolean('core', 'unit_test_mode'),
             )
 
+    def test_failure_callbacks_should_not_drop_hostname(self):
+        dagbag = DagBag(dag_folder="/dev/null", include_examples=True, read_dags_from_db=False)
+        dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())
+        dag_file_processor.UNIT_TEST_MODE = False
+
+        with create_session() as session:
+            dag = dagbag.get_dag('example_branch_operator')
+            task = dag.get_task(task_id='run_this_first')
+
+            ti = TaskInstance(task, DEFAULT_DATE, State.RUNNING)
+            ti.hostname = "test_hostname"
+            session.add(ti)
+
+        with create_session() as session:
+            requests = [
+                TaskCallbackRequest(
+                    full_filepath="A", simple_task_instance=SimpleTaskInstance(ti), msg="Message"
+                )
+            ]
+            dag_file_processor.execute_callbacks(dagbag, requests)
+            tis = session.query(TaskInstance)
+            assert tis[0].hostname == "test_hostname"
+
     def test_process_file_should_failure_callback(self):
         dag_file = os.path.join(
             os.path.dirname(os.path.realpath(__file__)), '../dags/test_on_failure_callback.py'