You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ka...@apache.org on 2021/01/18 23:39:51 UTC

[airflow] branch master updated: Fix race conditions in task callback invocations (#10917)

This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new f1d4f54  Fix race conditions in task callback invocations (#10917)
f1d4f54 is described below

commit f1d4f54b3479cd7549ce79efadd25cc6859dd420
Author: QP Hou <qp...@scribd.com>
AuthorDate: Mon Jan 18 15:39:41 2021 -0800

    Fix race conditions in task callback invocations (#10917)
    
    This race condition resulted in task success and failure callbacks being
    called more than once. Here is the order of events that could lead to
    this issue:
    
    * task started running within process 2
    * (process 1) local_task_job polled the task return code and got None (task still running)
    * (process 2) task exited with failure state, task state updated as failed in DB
    * (process 2) task failure callback invoked through taskinstance.handle_failure method
    * (process 1) local_task_job heartbeat noticed the task state set to
      failure, mistook it for the state being updated externally, and also
      invoked the task failure callback
    
    To avoid this race condition, we need to make sure task callbacks are
    only invoked within a single process.
---
 airflow/cli/cli_parser.py                          |   2 +
 airflow/cli/commands/task_command.py               |   5 +-
 airflow/executors/debug_executor.py                |   4 +
 airflow/jobs/backfill_job.py                       |   2 +-
 airflow/jobs/local_task_job.py                     |  36 +++--
 airflow/jobs/scheduler_job.py                      |   6 +-
 airflow/models/dag.py                              |   2 +-
 airflow/models/taskinstance.py                     | 139 +++++++++++++-----
 airflow/task/task_runner/base_task_runner.py       |  35 +++--
 airflow/task/task_runner/standard_task_runner.py   |   3 +-
 airflow/utils/process_utils.py                     |   7 +-
 tests/core/test_core.py                            |   4 +-
 tests/jobs/test_local_task_job.py                  | 156 ++++++++++++++++-----
 tests/jobs/test_scheduler_job.py                   |   5 +-
 tests/models/test_taskinstance.py                  |   6 +-
 .../apache/hive/transfers/test_mysql_to_hive.py    |  67 +++++----
 16 files changed, 343 insertions(+), 136 deletions(-)

diff --git a/airflow/cli/cli_parser.py b/airflow/cli/cli_parser.py
index 35949ea..8c52c86 100644
--- a/airflow/cli/cli_parser.py
+++ b/airflow/cli/cli_parser.py
@@ -405,6 +405,7 @@ ARG_SHIP_DAG = Arg(
     ("--ship-dag",), help="Pickles (serializes) the DAG and ships it to the worker", action="store_true"
 )
 ARG_PICKLE = Arg(("-p", "--pickle"), help="Serialized pickle object of the entire dag (used internally)")
+ARG_ERROR_FILE = Arg(("--error-file",), help="File to store task failure error")
 ARG_JOB_ID = Arg(("-j", "--job-id"), help=argparse.SUPPRESS)
 ARG_CFG_PATH = Arg(("--cfg-path",), help="Path to config file to use instead of airflow.cfg")
 ARG_MIGRATION_TIMEOUT = Arg(
@@ -962,6 +963,7 @@ TASKS_COMMANDS = (
             ARG_PICKLE,
             ARG_JOB_ID,
             ARG_INTERACTIVE,
+            ARG_ERROR_FILE,
             ARG_SHUT_DOWN_LOGGING,
         ),
     ),
diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py
index c3bce77..b89ecfa 100644
--- a/airflow/cli/commands/task_command.py
+++ b/airflow/cli/commands/task_command.py
@@ -47,7 +47,7 @@ from airflow.utils.net import get_hostname
 from airflow.utils.session import create_session
 
 
-def _run_task_by_selected_method(args, dag, ti):
+def _run_task_by_selected_method(args, dag: DAG, ti: TaskInstance) -> None:
     """
     Runs the task in one of 3 modes
 
@@ -132,7 +132,7 @@ RAW_TASK_UNSUPPORTED_OPTION = [
 ]
 
 
-def _run_raw_task(args, ti):
+def _run_raw_task(args, ti: TaskInstance) -> None:
     """Runs the main task handling code"""
     unsupported_options = [o for o in RAW_TASK_UNSUPPORTED_OPTION if getattr(args, o)]
 
@@ -149,6 +149,7 @@ def _run_raw_task(args, ti):
         mark_success=args.mark_success,
         job_id=args.job_id,
         pool=args.pool,
+        error_file=args.error_file,
     )
 
 
diff --git a/airflow/executors/debug_executor.py b/airflow/executors/debug_executor.py
index 580dc65..25aace0 100644
--- a/airflow/executors/debug_executor.py
+++ b/airflow/executors/debug_executor.py
@@ -66,6 +66,7 @@ class DebugExecutor(BaseExecutor):
                 self.log.info("Executor is terminated! Stopping %s to %s", ti.key, State.FAILED)
                 ti.set_state(State.FAILED)
                 self.change_state(ti.key, State.FAILED)
+                ti._run_finished_callback()  # pylint: disable=protected-access
                 continue
 
             task_succeeded = self._run_task(ti)
@@ -77,9 +78,12 @@ class DebugExecutor(BaseExecutor):
             params = self.tasks_params.pop(ti.key, {})
             ti._run_raw_task(job_id=ti.job_id, **params)  # pylint: disable=protected-access
             self.change_state(key, State.SUCCESS)
+            ti._run_finished_callback()  # pylint: disable=protected-access
             return True
         except Exception as e:  # pylint: disable=broad-except
+            ti.set_state(State.FAILED)
             self.change_state(key, State.FAILED)
+            ti._run_finished_callback()  # pylint: disable=protected-access
             self.log.exception("Failed to execute task: %s.", str(e))
             return False
 
diff --git a/airflow/jobs/backfill_job.py b/airflow/jobs/backfill_job.py
index da64b21..0d3d057 100644
--- a/airflow/jobs/backfill_job.py
+++ b/airflow/jobs/backfill_job.py
@@ -280,7 +280,7 @@ class BackfillJob(BaseJob):
                     "killed externally? Info: {}".format(ti, state, ti.state, info)
                 )
                 self.log.error(msg)
-                ti.handle_failure(msg)
+                ti.handle_failure_with_callback(error=msg)
 
     @provide_session
     def _get_dag_run(self, run_date: datetime, dag: DAG, session: Session = None):
diff --git a/airflow/jobs/local_task_job.py b/airflow/jobs/local_task_job.py
index f4d4ef0..d68bfc7 100644
--- a/airflow/jobs/local_task_job.py
+++ b/airflow/jobs/local_task_job.py
@@ -98,7 +98,14 @@ class LocalTaskJob(BaseJob):
 
             heartbeat_time_limit = conf.getint('scheduler', 'scheduler_zombie_task_threshold')
 
-            while True:
+            # task callback invocation happens either here or in
+            # self.heartbeat() instead of taskinstance._run_raw_task to
+            # avoid race conditions
+            #
+            # When self.terminating is set to True by heartbeat_callback, this
+            # loop should not be restarted. Otherwise self.handle_task_exit
+            # will be invoked and we will end up with duplicated callbacks
+            while not self.terminating:
                 # Monitor the task to see if it's done. Wait in a syscall
                 # (`os.wait`) for as long as possible so we notice the
                 # subprocess finishing as quick as we can
@@ -115,7 +122,7 @@ class LocalTaskJob(BaseJob):
 
                 return_code = self.task_runner.return_code(timeout=max_wait_time)
                 if return_code is not None:
-                    self.log.info("Task exited with return code %s", return_code)
+                    self.handle_task_exit(return_code)
                     return
 
                 self.heartbeat()
@@ -134,6 +141,17 @@ class LocalTaskJob(BaseJob):
         finally:
             self.on_kill()
 
+    def handle_task_exit(self, return_code: int) -> None:
+        """Handle case where self.task_runner exits by itself"""
+        self.log.info("Task exited with return code %s", return_code)
+        self.task_instance.refresh_from_db()
+        # task exited by itself, so we need to check for error file
+        # incase it failed due to runtime exception/error
+        error = None
+        if self.task_instance.state != State.SUCCESS:
+            error = self.task_runner.deserialize_run_error()
+        self.task_instance._run_finished_callback(error=error)  # pylint: disable=protected-access
+
     def on_kill(self):
         self.task_runner.terminate()
         self.task_runner.on_finish()
@@ -169,11 +187,13 @@ class LocalTaskJob(BaseJob):
             self.log.warning(
                 "State of this instance has been externally set to %s. " "Terminating instance.", ti.state
             )
-            if ti.state == State.FAILED and ti.task.on_failure_callback:
-                context = ti.get_template_context()
-                ti.task.on_failure_callback(context)
-            if ti.state == State.SUCCESS and ti.task.on_success_callback:
-                context = ti.get_template_context()
-                ti.task.on_success_callback(context)
             self.task_runner.terminate()
+            if ti.state == State.SUCCESS:
+                error = None
+            else:
+                # if ti.state is not set by taskinstance.handle_failure, then
+                # error file will not be populated and it must be updated by
+                # external source suck as web UI
+                error = self.task_runner.deserialize_run_error() or "task marked as failed externally"
+            ti._run_finished_callback(error=error)  # pylint: disable=protected-access
             self.terminating = True
diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py
index 3d58077..82b7561 100644
--- a/airflow/jobs/scheduler_job.py
+++ b/airflow/jobs/scheduler_job.py
@@ -590,7 +590,7 @@ class DagFileProcessor(LoggingMixin):
                 ti.state = simple_ti.state
                 ti.test_mode = self.UNIT_TEST_MODE
                 if request.is_failure_callback:
-                    ti.handle_failure(request.msg, ti.test_mode, ti.get_template_context())
+                    ti.handle_failure_with_callback(error=request.msg, test_mode=ti.test_mode)
                     self.log.info('Executed failure callback for %s in state %s', ti, ti.state)
 
     @provide_session
@@ -1732,8 +1732,8 @@ class SchedulerJob(BaseJob):  # pylint: disable=too-many-instance-attributes
         pools = models.Pool.slots_stats(session=session)
         for pool_name, slot_stats in pools.items():
             Stats.gauge(f'pool.open_slots.{pool_name}', slot_stats["open"])
-            Stats.gauge(f'pool.queued_slots.{pool_name}', slot_stats[State.QUEUED])
-            Stats.gauge(f'pool.running_slots.{pool_name}', slot_stats[State.RUNNING])
+            Stats.gauge(f'pool.queued_slots.{pool_name}', slot_stats[State.QUEUED])  # type: ignore
+            Stats.gauge(f'pool.running_slots.{pool_name}', slot_stats[State.RUNNING])  # type: ignore
 
     @provide_session
     def heartbeat_callback(self, session: Session = None) -> None:
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 1d0105d..d9096ff 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -584,7 +584,7 @@ class DAG(LoggingMixin):
         next_run_date = None
         if not date_last_automated_dagrun:
             # First run
-            task_start_dates = [t.start_date for t in self.tasks]
+            task_start_dates = [t.start_date for t in self.tasks if t.start_date]
             if task_start_dates:
                 next_run_date = self.normalize_schedule(min(task_start_dates))
                 self.log.debug("Next run date based on tasks %s", next_run_date)
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 203b5db..6b6f906 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -21,10 +21,12 @@ import hashlib
 import logging
 import math
 import os
+import pickle
 import signal
 import warnings
 from datetime import datetime, timedelta
-from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Tuple, Union
+from tempfile import NamedTemporaryFile
+from typing import IO, Any, Dict, Iterable, List, NamedTuple, Optional, Tuple, Union
 from urllib.parse import quote
 
 import dill
@@ -105,6 +107,29 @@ def set_current_context(context: Context):
             )
 
 
+def load_error_file(fd: IO[bytes]) -> Optional[Union[str, Exception]]:
+    """Load and return error from error file"""
+    fd.seek(0, os.SEEK_SET)
+    data = fd.read()
+    if not data:
+        return None
+    try:
+        return pickle.loads(data)
+    except Exception:  # pylint: disable=broad-except
+        return "Failed to load task run error"
+
+
+def set_error_file(error_file: str, error: Union[str, Exception]) -> None:
+    """Write error into error file by path"""
+    with open(error_file, "wb") as fd:
+        try:
+            pickle.dump(error, fd)
+        except Exception:  # pylint: disable=broad-except
+            # local class objects cannot be pickled, so we fallback
+            # to store the string representation instead
+            pickle.dump(str(error), fd)
+
+
 def clear_task_instances(
     tis,
     session,
@@ -1053,6 +1078,7 @@ class TaskInstance(Base, LoggingMixin):  # pylint: disable=R0902,R0904
         test_mode: bool = False,
         job_id: Optional[str] = None,
         pool: Optional[str] = None,
+        error_file: Optional[str] = None,
         session=None,
     ) -> None:
         """
@@ -1111,7 +1137,7 @@ class TaskInstance(Base, LoggingMixin):  # pylint: disable=R0902,R0904
             return
         except AirflowFailException as e:
             self.refresh_from_db()
-            self.handle_failure(e, test_mode, context, force_fail=True)
+            self.handle_failure(e, test_mode, force_fail=True, error_file=error_file)
             raise
         except AirflowException as e:
             self.refresh_from_db()
@@ -1120,16 +1146,14 @@ class TaskInstance(Base, LoggingMixin):  # pylint: disable=R0902,R0904
             if self.state in {State.SUCCESS, State.FAILED}:
                 return
             else:
-                self.handle_failure(e, test_mode, context)
+                self.handle_failure(e, test_mode, error_file=error_file)
                 raise
         except (Exception, KeyboardInterrupt) as e:
-            self.handle_failure(e, test_mode, context)
+            self.handle_failure(e, test_mode, error_file=error_file)
             raise
         finally:
             Stats.incr(f'ti.finish.{task.dag_id}.{task.task_id}.{self.state}')
 
-        self._run_success_callback(context, task)
-
         # Recording SUCCESS
         self.end_date = timezone.utcnow()
         self.log.info(
@@ -1275,16 +1299,6 @@ class TaskInstance(Base, LoggingMixin):  # pylint: disable=R0902,R0904
         # Raise exception for sensing state
         raise AirflowSmartSensorException("Task successfully registered in smart sensor.")
 
-    def _run_success_callback(self, context, task):
-        """Functions that need to be run if Task is successful"""
-        # Success callback
-        try:
-            if task.on_success_callback:
-                task.on_success_callback(context)
-        except Exception as exc:  # pylint: disable=broad-except
-            self.log.error("Failed when executing success callback")
-            self.log.exception(exc)
-
     def _execute_task(self, context, task_copy):
         """Executes Task (optionally with a Timeout) and pushes Xcom results"""
         # If a timeout is specified for the task, make it fail
@@ -1303,7 +1317,7 @@ class TaskInstance(Base, LoggingMixin):  # pylint: disable=R0902,R0904
             self.xcom_push(key=XCOM_RETURN_KEY, value=result)
         return result
 
-    def _run_execute_callback(self, context, task):
+    def _run_execute_callback(self, context: Context, task):
         """Functions that need to be run before a Task is executed"""
         try:
             if task.on_execute_callback:
@@ -1312,6 +1326,31 @@ class TaskInstance(Base, LoggingMixin):  # pylint: disable=R0902,R0904
             self.log.error("Failed when executing execute callback")
             self.log.exception(exc)
 
+    def _run_finished_callback(self, error: Optional[Union[str, Exception]] = None) -> None:
+        """
+        Call callback defined for finished state change.
+
+        NOTE: Only invoke this function from caller of self._run_raw_task or
+        self.run
+        """
+        if self.state == State.FAILED:
+            task = self.task
+            if task.on_failure_callback is not None:
+                context = self.get_template_context()
+                context["exception"] = error
+                task.on_failure_callback(context)
+        elif self.state == State.SUCCESS:
+            task = self.task
+            if task.on_success_callback is not None:
+                context = self.get_template_context()
+                task.on_success_callback(context)
+        elif self.state == State.UP_FOR_RETRY:
+            task = self.task
+            if task.on_retry_callback is not None:
+                context = self.get_template_context()
+                context["exception"] = error
+                task.on_retry_callback(context)
+
     @provide_session
     def run(  # pylint: disable=too-many-arguments
         self,
@@ -1339,10 +1378,23 @@ class TaskInstance(Base, LoggingMixin):  # pylint: disable=R0902,R0904
             pool=pool,
             session=session,
         )
-        if res:
+        if not res:
+            return
+
+        try:
+            error_fd = NamedTemporaryFile(delete=True)
             self._run_raw_task(
-                mark_success=mark_success, test_mode=test_mode, job_id=job_id, pool=pool, session=session
+                mark_success=mark_success,
+                test_mode=test_mode,
+                job_id=job_id,
+                pool=pool,
+                error_file=error_fd.name,
+                session=session,
             )
+        finally:
+            error = None if self.state == State.SUCCESS else load_error_file(error_fd)
+            error_fd.close()
+            self._run_finished_callback(error=error)
 
     def dry_run(self):
         """Only Renders Templates for the TI"""
@@ -1386,14 +1438,25 @@ class TaskInstance(Base, LoggingMixin):  # pylint: disable=R0902,R0904
         self.log.info('Rescheduling task, marking task as UP_FOR_RESCHEDULE')
 
     @provide_session
-    def handle_failure(self, error, test_mode=None, context=None, force_fail=False, session=None):
+    def handle_failure(
+        self,
+        error: Union[str, Exception],
+        test_mode: Optional[bool] = None,
+        force_fail: bool = False,
+        error_file: Optional[str] = None,
+        session=None,
+    ) -> None:
         """Handle Failure for the TaskInstance"""
         if test_mode is None:
             test_mode = self.test_mode
-        if context is None:
-            context = self.get_template_context()
 
-        self.log.exception(error)
+        if error:
+            self.log.exception(error)
+            # external monitoring process provides pickle file so _run_raw_task
+            # can send its runtime errors for access by failure callback
+            if error_file:
+                set_error_file(error_file, error)
+
         task = self.task
         self.end_date = timezone.utcnow()
         self.set_duration()
@@ -1405,11 +1468,12 @@ class TaskInstance(Base, LoggingMixin):  # pylint: disable=R0902,R0904
         # Log failure duration
         session.add(TaskFail(task, self.execution_date, self.start_date, self.end_date))
 
-        if context is not None:
-            context['exception'] = error
+        # Set state correctly and figure out how to log it and decide whether
+        # to email
 
-        # Set state correctly and figure out how to log it,
-        # what callback to call if any, and how to decide whether to email
+        # Note, callback invocation needs to be handled by caller of
+        # _run_raw_task to avoid race conditions which could lead to duplicate
+        # invocations or miss invocation.
 
         # Since this function is called only when the TaskInstance state is running,
         # try_number contains the current try_number (not the next). We
@@ -1423,12 +1487,10 @@ class TaskInstance(Base, LoggingMixin):  # pylint: disable=R0902,R0904
             else:
                 log_message = "Marking task as FAILED."
             email_for_state = task.email_on_failure
-            callback = task.on_failure_callback
         else:
             self.state = State.UP_FOR_RETRY
             log_message = "Marking task as UP_FOR_RETRY."
             email_for_state = task.email_on_retry
-            callback = task.on_retry_callback
 
         self.log.info(
             '%s dag_id=%s, task_id=%s, execution_date=%s, start_date=%s, end_date=%s',
@@ -1446,18 +1508,21 @@ class TaskInstance(Base, LoggingMixin):  # pylint: disable=R0902,R0904
                 self.log.error('Failed to send email to: %s', task.email)
                 self.log.exception(exec2)
 
-        # Handling callbacks pessimistically
-        if callback:
-            try:
-                callback(context)
-            except Exception as exec3:  # pylint: disable=broad-except
-                self.log.error("Failed at executing callback")
-                self.log.exception(exec3)
-
         if not test_mode:
             session.merge(self)
         session.commit()
 
+    @provide_session
+    def handle_failure_with_callback(
+        self,
+        error: Union[str, Exception],
+        test_mode: Optional[bool] = None,
+        force_fail: bool = False,
+        session=None,
+    ) -> None:
+        self.handle_failure(error=error, test_mode=test_mode, force_fail=force_fail, session=session)
+        self._run_finished_callback(error=error)
+
     def is_eligible_to_retry(self):
         """Is task instance is eligible for retry"""
         return self.task.retries and self.try_number <= self.max_tries
diff --git a/airflow/task/task_runner/base_task_runner.py b/airflow/task/task_runner/base_task_runner.py
index 743685e..81235ea 100644
--- a/airflow/task/task_runner/base_task_runner.py
+++ b/airflow/task/task_runner/base_task_runner.py
@@ -20,9 +20,12 @@ import getpass
 import os
 import subprocess
 import threading
+from tempfile import NamedTemporaryFile
+from typing import Optional, Union
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowConfigException
+from airflow.models.taskinstance import load_error_file
 from airflow.utils.configuration import tmp_configuration_copy
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.net import get_hostname
@@ -81,17 +84,26 @@ class BaseTaskRunner(LoggingMixin):
             # - the runner can read/execute those values as it needs
             cfg_path = tmp_configuration_copy(chmod=0o600)
 
+        self._error_file = NamedTemporaryFile(delete=True)
         self._cfg_path = cfg_path
-        self._command = popen_prepend + self._task_instance.command_as_list(
-            raw=True,
-            pickle_id=local_task_job.pickle_id,
-            mark_success=local_task_job.mark_success,
-            job_id=local_task_job.id,
-            pool=local_task_job.pool,
-            cfg_path=cfg_path,
+        self._command = (
+            popen_prepend
+            + self._task_instance.command_as_list(
+                raw=True,
+                pickle_id=local_task_job.pickle_id,
+                mark_success=local_task_job.mark_success,
+                job_id=local_task_job.id,
+                pool=local_task_job.pool,
+                cfg_path=cfg_path,
+            )
+            + ["--error-file", self._error_file.name]
         )
         self.process = None
 
+    def deserialize_run_error(self) -> Optional[Union[str, Exception]]:
+        """Return task runtime error if its written to provided error file."""
+        return load_error_file(self._error_file)
+
     def _read_task_logs(self, stream):
         while True:
             line = stream.readline()
@@ -144,7 +156,7 @@ class BaseTaskRunner(LoggingMixin):
         """Start running the task instance in a subprocess."""
         raise NotImplementedError()
 
-    def return_code(self):
+    def return_code(self) -> Optional[int]:
         """
         :return: The return code associated with running the task instance or
             None if the task is not yet done.
@@ -152,14 +164,15 @@ class BaseTaskRunner(LoggingMixin):
         """
         raise NotImplementedError()
 
-    def terminate(self):
-        """Kill the running task instance."""
+    def terminate(self) -> None:
+        """Force kill the running task instance."""
         raise NotImplementedError()
 
-    def on_finish(self):
+    def on_finish(self) -> None:
         """A callback that should be called when this is done running."""
         if self._cfg_path and os.path.isfile(self._cfg_path):
             if self.run_as_user:
                 subprocess.call(['sudo', 'rm', self._cfg_path], close_fds=True)
             else:
                 os.remove(self._cfg_path)
+        self._error_file.close()
diff --git a/airflow/task/task_runner/standard_task_runner.py b/airflow/task/task_runner/standard_task_runner.py
index fee9b0d..505b225 100644
--- a/airflow/task/task_runner/standard_task_runner.py
+++ b/airflow/task/task_runner/standard_task_runner.py
@@ -18,6 +18,7 @@
 """Standard task runner"""
 import logging
 import os
+from typing import Optional
 
 import psutil
 from setproctitle import setproctitle  # pylint: disable=no-name-in-module
@@ -91,7 +92,7 @@ class StandardTaskRunner(BaseTaskRunner):
                 logging.shutdown()
                 os._exit(return_code)  # pylint: disable=protected-access
 
-    def return_code(self, timeout=0):
+    def return_code(self, timeout: int = 0) -> Optional[int]:
         # We call this multiple times, but we can only wait on the process once
         if self._rc is not None or not self.process:
             return self._rc
diff --git a/airflow/utils/process_utils.py b/airflow/utils/process_utils.py
index 1ee1ac6..38607bd 100644
--- a/airflow/utils/process_utils.py
+++ b/airflow/utils/process_utils.py
@@ -44,7 +44,12 @@ log = logging.getLogger(__name__)
 DEFAULT_TIME_TO_WAIT_AFTER_SIGTERM = conf.getint('core', 'KILLED_TASK_CLEANUP_TIME')
 
 
-def reap_process_group(pgid, logger, sig=signal.SIGTERM, timeout=DEFAULT_TIME_TO_WAIT_AFTER_SIGTERM):
+def reap_process_group(
+    pgid: int,
+    logger,
+    sig: 'signal.Signals' = signal.SIGTERM,
+    timeout: int = DEFAULT_TIME_TO_WAIT_AFTER_SIGTERM,
+) -> Dict[int, int]:
     """
     Tries really hard to terminate all processes in the group (including grandchildren). Will send
     sig (SIGTERM) to the process group of pid. If any process is alive after timeout
diff --git a/tests/core/test_core.py b/tests/core/test_core.py
index 5073aa4..fae2c73 100644
--- a/tests/core/test_core.py
+++ b/tests/core/test_core.py
@@ -190,9 +190,9 @@ class TestCore(unittest.TestCase):
         # Annoying workaround for nonlocal not existing in python 2
         data = {'called': False}
 
-        def check_failure(context, test_case=self):
+        def check_failure(context, test_case=self):  # pylint: disable=unused-argument
             data['called'] = True
-            error = context.get('exception')
+            error = context.get("exception")
             test_case.assertIsInstance(error, AirflowException)
 
         op = BashOperator(
diff --git a/tests/jobs/test_local_task_job.py b/tests/jobs/test_local_task_job.py
index fdd0163..537a242 100644
--- a/tests/jobs/test_local_task_job.py
+++ b/tests/jobs/test_local_task_job.py
@@ -21,20 +21,22 @@ import os
 import time
 import unittest
 import uuid
+from multiprocessing import Lock, Value
 from unittest import mock
 from unittest.mock import patch
 
 import pytest
 
 from airflow import settings
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowFailException
 from airflow.executors.sequential_executor import SequentialExecutor
 from airflow.jobs.local_task_job import LocalTaskJob
 from airflow.models.dag import DAG
 from airflow.models.dagbag import DagBag
 from airflow.models.taskinstance import TaskInstance
-from airflow.operators.dummy import DummyOperator
+from airflow.operators.dummy_operator import DummyOperator
 from airflow.operators.python import PythonOperator
+from airflow.task.task_runner.standard_task_runner import StandardTaskRunner
 from airflow.utils import timezone
 from airflow.utils.net import get_hostname
 from airflow.utils.session import create_session
@@ -242,8 +244,6 @@ class TestLocalTaskJob(unittest.TestCase):
         ti_run = TaskInstance(task=task, execution_date=DEFAULT_DATE)
         ti_run.refresh_from_db()
         job1 = LocalTaskJob(task_instance=ti_run, executor=SequentialExecutor())
-        from airflow.task.task_runner.standard_task_runner import StandardTaskRunner
-
         with patch.object(StandardTaskRunner, 'start', return_value=None) as mock_method:
             job1.run()
             mock_method.assert_not_called()
@@ -286,8 +286,6 @@ class TestLocalTaskJob(unittest.TestCase):
             return return_codes.pop(0)
 
         time_start = time.time()
-        from airflow.task.task_runner.standard_task_runner import StandardTaskRunner
-
         with patch.object(StandardTaskRunner, 'start', return_value=None) as mock_start:
             with patch.object(StandardTaskRunner, 'return_code') as mock_ret_code:
                 mock_ret_code.side_effect = multi_return_code
@@ -311,14 +309,18 @@ class TestLocalTaskJob(unittest.TestCase):
         Test that ensures that mark_failure in the UI fails
         the task, and executes on_failure_callback
         """
-        data = {'called': False}
+        # use shared memory value so we can properly track value change even if
+        # it's been updated across processes.
+        failure_callback_called = Value('i', 0)
+        task_terminated_externally = Value('i', 1)
 
         def check_failure(context):
+            with failure_callback_called.get_lock():
+                failure_callback_called.value += 1
             assert context['dag_run'].dag_id == 'test_mark_failure'
-            data['called'] = True
+            assert context['exception'] == "task marked as failed externally"
 
         def task_function(ti):
-            print("python_callable run in pid %s", os.getpid())
             with create_session() as session:
                 assert State.RUNNING == ti.state
                 ti.log.info("Marking TI as failed 'externally'")
@@ -326,9 +328,10 @@ class TestLocalTaskJob(unittest.TestCase):
                 session.merge(ti)
                 session.commit()
 
-            time.sleep(60)
+            time.sleep(10)
             # This should not happen -- the state change should be noticed and the task should get killed
-            data['reached_end_of_sleep'] = True
+            with task_terminated_externally.get_lock():
+                task_terminated_externally.value = 0
 
         with DAG(dag_id='test_mark_failure', start_date=DEFAULT_DATE) as dag:
             task = PythonOperator(
@@ -337,16 +340,15 @@ class TestLocalTaskJob(unittest.TestCase):
                 on_failure_callback=check_failure,
             )
 
-        session = settings.Session()
-
         dag.clear()
-        dag.create_dagrun(
-            run_id="test",
-            state=State.RUNNING,
-            execution_date=DEFAULT_DATE,
-            start_date=DEFAULT_DATE,
-            session=session,
-        )
+        with create_session() as session:
+            dag.create_dagrun(
+                run_id="test",
+                state=State.RUNNING,
+                execution_date=DEFAULT_DATE,
+                start_date=DEFAULT_DATE,
+                session=session,
+            )
         ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
         ti.refresh_from_db()
 
@@ -358,24 +360,106 @@ class TestLocalTaskJob(unittest.TestCase):
 
         ti.refresh_from_db()
         assert ti.state == State.FAILED
-        assert data['called']
-        assert 'reached_end_of_sleep' not in data, 'Task should not have been allowed to run to completion'
+        assert failure_callback_called.value == 1
+        assert task_terminated_externally.value == 1
+
+    @patch('airflow.utils.process_utils.subprocess.check_call')
+    @patch.object(StandardTaskRunner, 'return_code')
+    def test_failure_callback_only_called_once(self, mock_return_code, _check_call):
+        """
+        Test that ensures that when a task exits with failure by itself,
+        failure callback is only called once
+        """
+        # use shared memory value so we can properly track value change even if
+        # it's been updated across processes.
+        failure_callback_called = Value('i', 0)
+        callback_count_lock = Lock()
+
+        def failure_callback(context):
+            with callback_count_lock:
+                failure_callback_called.value += 1
+            assert context['dag_run'].dag_id == 'test_failure_callback_race'
+            assert isinstance(context['exception'], AirflowFailException)
+
+        def task_function(ti):  # pylint: disable=unused-argument
+            raise AirflowFailException()
+
+        dag = DAG(dag_id='test_failure_callback_race', start_date=DEFAULT_DATE)
+        task = PythonOperator(
+            task_id='test_exit_on_failure',
+            python_callable=task_function,
+            on_failure_callback=failure_callback,
+            dag=dag,
+        )
+
+        dag.clear()
+        with create_session() as session:
+            dag.create_dagrun(
+                run_id="test",
+                state=State.RUNNING,
+                execution_date=DEFAULT_DATE,
+                start_date=DEFAULT_DATE,
+                session=session,
+            )
+        ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
+        ti.refresh_from_db()
+
+        job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())
+
+        # Simulate race condition where job1 heartbeat ran right after task
+        # state got set to failed by ti.handle_failure but before task process
+        # fully exits. See _execute loop in airflow/jobs/local_task_job.py.
+        # In this case, we have:
+        #  * task_runner.return_code() is None
+        #  * ti.state == State.FAILED
+        #
+        # We also need to set return_code to a valid int after job1.terminating
+        # is set to True so _execute loop won't loop forever.
+        def dummy_return_code(*args, **kwargs):
+            return None if not job1.terminating else -9
+
+        mock_return_code.side_effect = dummy_return_code
+
+        with timeout(10):
+            # job1.run() should return almost immediately; hitting this timeout means the
+            # _execute loop kept polling because return_code never stopped returning None.
+            job1.run()
+
+        ti.refresh_from_db()
+        assert ti.state == State.FAILED  # task exits with failure state
+        assert failure_callback_called.value == 1
 
-    @pytest.mark.quarantined
     def test_mark_success_on_success_callback(self):
         """
         Test that ensures that where a task is marked suceess in the UI
         on_success_callback gets executed
         """
-        data = {'called': False}
+        # use shared memory value so we can properly track value change even if
+        # it's been updated across processes.
+        success_callback_called = Value('i', 0)
+        task_terminated_externally = Value('i', 1)
+        shared_mem_lock = Lock()
 
         def success_callback(context):
+            with shared_mem_lock:
+                success_callback_called.value += 1
             assert context['dag_run'].dag_id == 'test_mark_success'
-            data['called'] = True
 
         dag = DAG(dag_id='test_mark_success', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
 
-        task = DummyOperator(task_id='test_state_succeeded1', dag=dag, on_success_callback=success_callback)
+        def task_function(ti):
+            # pylint: disable=unused-argument
+            time.sleep(60)
+            # This should not happen -- the state change should be noticed and the task should get killed
+            with shared_mem_lock:
+                task_terminated_externally.value = 0
+
+        task = PythonOperator(
+            task_id='test_state_succeeded1',
+            python_callable=task_function,
+            on_success_callback=success_callback,
+            dag=dag,
+        )
 
         session = settings.Session()
 
@@ -390,25 +474,25 @@ class TestLocalTaskJob(unittest.TestCase):
         ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
         ti.refresh_from_db()
         job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())
-        from airflow.task.task_runner.standard_task_runner import StandardTaskRunner
-
         job1.task_runner = StandardTaskRunner(job1)
+
+        settings.engine.dispose()
         process = multiprocessing.Process(target=job1.run)
         process.start()
-        ti.refresh_from_db()
-        for _ in range(0, 50):
+
+        for _ in range(0, 25):
+            ti.refresh_from_db()
             if ti.state == State.RUNNING:
                 break
-            time.sleep(0.1)
-            ti.refresh_from_db()
-        assert State.RUNNING == ti.state
+            time.sleep(0.2)
+        assert ti.state == State.RUNNING
         ti.state = State.SUCCESS
         session.merge(ti)
         session.commit()
 
-        job1.heartbeat_callback(session=None)
-        assert data['called']
         process.join(timeout=10)
+        assert success_callback_called.value == 1
+        assert task_terminated_externally.value == 1
         assert not process.is_alive()
 
 
@@ -436,5 +520,5 @@ class TestLocalTaskJobPerformance:
         mock_get_task_runner.return_value.return_code.side_effects = return_codes
 
         job = LocalTaskJob(task_instance=ti, executor=MockExecutor())
-        with assert_queries_count(12):
+        with assert_queries_count(13):
             job.run()
diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py
index c5340e8..ef6c52b 100644
--- a/tests/jobs/test_scheduler_job.py
+++ b/tests/jobs/test_scheduler_job.py
@@ -642,7 +642,7 @@ class TestDagFileProcessor(unittest.TestCase):
             num_scheduled = scheduler._schedule_dag_run(dr3, {dr1.execution_date}, session)
             assert num_scheduled == 0
 
-    @patch.object(TaskInstance, 'handle_failure')
+    @patch.object(TaskInstance, 'handle_failure_with_callback')
     def test_execute_on_failure_callbacks(self, mock_ti_handle_failure):
         dagbag = DagBag(dag_folder="/dev/null", include_examples=True, read_dags_from_db=False)
         dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())
@@ -663,7 +663,8 @@ class TestDagFileProcessor(unittest.TestCase):
             ]
             dag_file_processor.execute_callbacks(dagbag, requests)
             mock_ti_handle_failure.assert_called_once_with(
-                "Message", conf.getboolean('core', 'unit_test_mode'), mock.ANY
+                error="Message",
+                test_mode=conf.getboolean('core', 'unit_test_mode'),
             )
 
     def test_process_file_should_failure_callback(self):
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index 012c547..cd99b02 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -1389,8 +1389,9 @@ class TestTaskInstance(unittest.TestCase):
 
         callback_wrapper.wrap_task_instance(ti)
         ti._run_raw_task()
+        ti._run_finished_callback()
         assert callback_wrapper.callback_ran
-        assert callback_wrapper.task_state_in_callback == State.RUNNING
+        assert callback_wrapper.task_state_in_callback == State.SUCCESS
         ti.refresh_from_db()
         assert ti.state == State.SUCCESS
 
@@ -1618,6 +1619,7 @@ class TestTaskInstance(unittest.TestCase):
         ti1 = TI(task=task1, execution_date=start_date)
         ti1.state = State.FAILED
         ti1.handle_failure("test failure handling")
+        ti1._run_finished_callback()
 
         context_arg_1 = mock_on_failure_1.call_args[0][0]
         assert context_arg_1 and "task_instance" in context_arg_1
@@ -1635,6 +1637,7 @@ class TestTaskInstance(unittest.TestCase):
         ti2 = TI(task=task2, execution_date=start_date)
         ti2.state = State.FAILED
         ti2.handle_failure("test retry handling")
+        ti2._run_finished_callback()
 
         mock_on_failure_2.assert_not_called()
 
@@ -1654,6 +1657,7 @@ class TestTaskInstance(unittest.TestCase):
         ti3 = TI(task=task3, execution_date=start_date)
         ti3.state = State.FAILED
         ti3.handle_failure("test force_fail handling", force_fail=True)
+        ti3._run_finished_callback()
 
         context_arg_3 = mock_on_failure_3.call_args[0][0]
         assert context_arg_3 and "task_instance" in context_arg_3
diff --git a/tests/providers/apache/hive/transfers/test_mysql_to_hive.py b/tests/providers/apache/hive/transfers/test_mysql_to_hive.py
index b921c74..c6f7736 100644
--- a/tests/providers/apache/hive/transfers/test_mysql_to_hive.py
+++ b/tests/providers/apache/hive/transfers/test_mysql_to_hive.py
@@ -18,6 +18,7 @@
 
 import unittest
 from collections import OrderedDict
+from os import path
 from unittest import mock
 
 import pytest
@@ -35,6 +36,27 @@ DEFAULT_DATE_DS = DEFAULT_DATE_ISO[:10]
 TEST_DAG_ID = 'unit_test_dag'
 
 
+class HiveopTempFile:
+    """Match any temp file path shaped like "/tmp/airflow_hiveop_t_78lpye/tmpour2_kig"."""
+
+    def __repr__(self) -> str:
+        return "<any file inside a dir named airflow_hiveop_*>"
+
+    def __eq__(self, other: object) -> bool:
+        # Only the parent directory's prefix is checked; the rest of the path is not pinned.
+        return isinstance(other, str) and path.basename(path.dirname(other)).startswith("airflow_hiveop_")
+
+
+class HiveopTempDir:
+    """Match any temp dir path shaped like "/tmp/airflow_hiveop_t_78lpye"."""
+
+    def __repr__(self) -> str:
+        return "<any dir named airflow_hiveop_*>"
+
+    def __eq__(self, other: object) -> bool:
+        return isinstance(other, str) and path.basename(other).startswith("airflow_hiveop_")
+
+
 @pytest.mark.backend("mysql")
 class TestTransfer(unittest.TestCase):
     def setUp(self):
@@ -126,13 +148,10 @@ class TestTransfer(unittest.TestCase):
         with MySqlHook().get_conn() as cur:
             cur.execute("DROP TABLE IF EXISTS baby_names CASCADE;")
 
-    @mock.patch('tempfile.tempdir', '/tmp/')
-    @mock.patch('tempfile._RandomNameSequence.__next__')
     @mock.patch('subprocess.Popen')
-    def test_mysql_to_hive(self, mock_popen, mock_temp_dir):
+    def test_mysql_to_hive(self, mock_popen):
         mock_subprocess = MockSubProcess()
         mock_popen.return_value = mock_subprocess
-        mock_temp_dir.return_value = "test_mysql_to_hive"
 
         with mock.patch.dict('os.environ', self.env_vars):
             sql = "SELECT * FROM baby_names LIMIT 1000;"
@@ -170,24 +189,21 @@ class TestTransfer(unittest.TestCase):
             '-hiveconf',
             'tez.queue.name=airflow',
             '-f',
-            '/tmp/airflow_hiveop_test_mysql_to_hive/tmptest_mysql_to_hive',
+            HiveopTempFile(),
         ]
 
         mock_popen.assert_called_with(
             hive_cmd,
             stdout=mock_subprocess.PIPE,
             stderr=mock_subprocess.STDOUT,
-            cwd="/tmp/airflow_hiveop_test_mysql_to_hive",
+            cwd=HiveopTempDir(),
             close_fds=True,
         )
 
-    @mock.patch('tempfile.tempdir', '/tmp/')
-    @mock.patch('tempfile._RandomNameSequence.__next__')
     @mock.patch('subprocess.Popen')
-    def test_mysql_to_hive_partition(self, mock_popen, mock_temp_dir):
+    def test_mysql_to_hive_partition(self, mock_popen):
         mock_subprocess = MockSubProcess()
         mock_popen.return_value = mock_subprocess
-        mock_temp_dir.return_value = "test_mysql_to_hive_partition"
 
         with mock.patch.dict('os.environ', self.env_vars):
             sql = "SELECT * FROM baby_names LIMIT 1000;"
@@ -227,24 +243,21 @@ class TestTransfer(unittest.TestCase):
             '-hiveconf',
             'tez.queue.name=airflow',
             '-f',
-            '/tmp/airflow_hiveop_test_mysql_to_hive_partition/tmptest_mysql_to_hive_partition',
+            HiveopTempFile(),
         ]
 
         mock_popen.assert_called_with(
             hive_cmd,
             stdout=mock_subprocess.PIPE,
             stderr=mock_subprocess.STDOUT,
-            cwd="/tmp/airflow_hiveop_test_mysql_to_hive_partition",
+            cwd=HiveopTempDir(),
             close_fds=True,
         )
 
-    @mock.patch('tempfile.tempdir', '/tmp/')
-    @mock.patch('tempfile._RandomNameSequence.__next__')
     @mock.patch('subprocess.Popen')
-    def test_mysql_to_hive_tblproperties(self, mock_popen, mock_temp_dir):
+    def test_mysql_to_hive_tblproperties(self, mock_popen):
         mock_subprocess = MockSubProcess()
         mock_popen.return_value = mock_subprocess
-        mock_temp_dir.return_value = "test_mysql_to_hive"
 
         with mock.patch.dict('os.environ', self.env_vars):
             sql = "SELECT * FROM baby_names LIMIT 1000;"
@@ -283,14 +296,14 @@ class TestTransfer(unittest.TestCase):
             '-hiveconf',
             'tez.queue.name=airflow',
             '-f',
-            '/tmp/airflow_hiveop_test_mysql_to_hive/tmptest_mysql_to_hive',
+            HiveopTempFile(),
         ]
 
         mock_popen.assert_called_with(
             hive_cmd,
             stdout=mock_subprocess.PIPE,
             stderr=mock_subprocess.STDOUT,
-            cwd="/tmp/airflow_hiveop_test_mysql_to_hive",
+            cwd=HiveopTempDir(),
             close_fds=True,
         )
 
@@ -340,13 +353,10 @@ class TestTransfer(unittest.TestCase):
             with hook.get_conn() as conn:
                 conn.execute(f"DROP TABLE IF EXISTS {mysql_table}")
 
-    @mock.patch('tempfile.tempdir', '/tmp/')
-    @mock.patch('tempfile._RandomNameSequence.__next__')
     @mock.patch('subprocess.Popen')
-    def test_mysql_to_hive_verify_csv_special_char(self, mock_popen, mock_temp_dir):
+    def test_mysql_to_hive_verify_csv_special_char(self, mock_popen):
         mock_subprocess = MockSubProcess()
         mock_popen.return_value = mock_subprocess
-        mock_temp_dir.return_value = "test_mysql_to_hive"
 
         mysql_table = 'test_mysql_to_hive'
         hive_table = 'test_mysql_to_hive'
@@ -424,27 +434,24 @@ class TestTransfer(unittest.TestCase):
                 '-hiveconf',
                 'tez.queue.name=airflow',
                 '-f',
-                '/tmp/airflow_hiveop_test_mysql_to_hive/tmptest_mysql_to_hive',
+                HiveopTempFile(),
             ]
 
             mock_popen.assert_called_with(
                 hive_cmd,
                 stdout=mock_subprocess.PIPE,
                 stderr=mock_subprocess.STDOUT,
-                cwd="/tmp/airflow_hiveop_test_mysql_to_hive",
+                cwd=HiveopTempDir(),
                 close_fds=True,
             )
         finally:
             with hook.get_conn() as conn:
                 conn.execute(f"DROP TABLE IF EXISTS {mysql_table}")
 
-    @mock.patch('tempfile.tempdir', '/tmp/')
-    @mock.patch('tempfile._RandomNameSequence.__next__')
     @mock.patch('subprocess.Popen')
-    def test_mysql_to_hive_verify_loaded_values(self, mock_popen, mock_temp_dir):
+    def test_mysql_to_hive_verify_loaded_values(self, mock_popen):
         mock_subprocess = MockSubProcess()
         mock_popen.return_value = mock_subprocess
-        mock_temp_dir.return_value = "test_mysql_to_hive"
 
         mysql_table = 'test_mysql_to_hive'
         hive_table = 'test_mysql_to_hive'
@@ -537,14 +544,14 @@ class TestTransfer(unittest.TestCase):
                     '-hiveconf',
                     'tez.queue.name=airflow',
                     '-f',
-                    '/tmp/airflow_hiveop_test_mysql_to_hive/tmptest_mysql_to_hive',
+                    HiveopTempFile(),
                 ]
 
                 mock_popen.assert_called_with(
                     hive_cmd,
                     stdout=mock_subprocess.PIPE,
                     stderr=mock_subprocess.STDOUT,
-                    cwd="/tmp/airflow_hiveop_test_mysql_to_hive",
+                    cwd=HiveopTempDir(),
                     close_fds=True,
                 )